Custom Gradient Rules
You can define your own gradient transformations or compose existing ones.
🔹 Registering a Rule Function
The easiest way to create a custom rule is to define a function and register it with `xg.register`.
A rule function must have the following signature:
`fn(ctx, grad_out, tin, **params)`
- `ctx`: The `torch.autograd` context object (rarely used, can be `None`).
- `grad_out`: The incoming gradient tensor (i.e., $dL/dy$).
- `tin`: The input tensor from the forward pass (i.e., $x$).
  - Note: This will be `None` if the rule is running on a parameter gradient (from `scope="params"`). Your rule should handle this case.
- `**params`: Collects any parameters passed to `xg.use(..., params={...})`.
The function must return the new gradient, $dL/dx$.
Example:
Here is a rule that applies a threshold and adds noise.
```python
import freegrad as xg
import torch

@xg.register("noisy_threshold")
def noisy_threshold(ctx, grad_out, tin, t: float = 0.0, sigma: float = 0.1):
    # Handle parameter hooks where tin is None
    if tin is None:
        mask = 1.0
    else:
        mask = (tin >= t).to(grad_out.dtype)
    noise = sigma * torch.randn_like(grad_out)
    return grad_out * mask + noise
```
Use it:
```python
import freegrad as xg
from freegrad.wrappers import Activation
import torch

x = torch.randn(5, requires_grad=True)
act = Activation("ReLU")

# Apply the custom rule
with xg.use("noisy_threshold", params={"t": 0.5, "sigma": 0.05}):
    y = act(x).sum()
    y.backward()
```
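Because `t` and `sigma` have defaults in `noisy_threshold`, omitting `params` should fall back to those defaults. This is a minimal sketch under the assumption that `xg.use` can be called with just a rule name, continuing the snippet above:

```python
# Assumption: xg.use accepts a rule name without a `params` argument,
# in which case the rule's own defaults (t=0.0, sigma=0.1) apply.
x.grad = None  # clear the gradient accumulated by the previous call
with xg.use("noisy_threshold"):
    y = act(x).sum()
    y.backward()
```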
🔹 Composing Rules
You can create a new rule by composing existing rules in series using `xg.compose`. The rules are applied in the order they are listed.
```python
import freegrad as xg
from freegrad.wrappers import Activation
import torch

# Create a new rule that first clips, then adds noise
clip_and_noise = xg.compose("clip_norm", "noise")

# You can register it (optional)
xg.register("clip_then_noise")(clip_and_noise)

# Example input and activation
x = torch.randn(5, requires_grad=True)
act = Activation("Linear")

# Use it directly
with xg.use(clip_and_noise, params={"max_norm": 1.0, "sigma": 0.1}, scope="activations"):
    y = act(x).sum()
    y.backward()
```
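Since the composed rule was also registered under the name `"clip_then_noise"`, it should be selectable by name like any other registered rule. A short sketch, continuing the snippet above:

```python
# Equivalent to the block above, but selecting the composed rule by its
# registered name rather than by the object returned by xg.compose.
x.grad = None  # clear gradients from the previous call
with xg.use("clip_then_noise", params={"max_norm": 1.0, "sigma": 0.1}, scope="activations"):
    y = act(x).sum()
    y.backward()
```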
🔹 Advanced: Custom `torch.autograd.Function`
For complete control over both forward and backward, or to create a new operation that is not just an activation, you can define a `torch.autograd.Function`.
Note: This is a standard PyTorch feature. A `Function` defined this way will not interact with the `xg.use` context manager. It has its own, hard-coded backward pass.
```python
import torch

class SquareNoiseFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, sigma=0.1):
        ctx.save_for_backward(x)
        ctx.sigma = sigma
        return x**2

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Custom backward: 2*x*grad_output + noise
        noise = ctx.sigma * torch.randn_like(x)
        grad_in = grad_output * (2 * x) + noise
        # Return grads for (x, sigma)
        return grad_in, None

def square_noise(x, sigma=0.1):
    return SquareNoiseFn.apply(x, sigma)

# ---
# Use it like a regular function
# ---
x = torch.randn(5, requires_grad=True)
y = square_noise(x, sigma=0.1).sum()
y.backward()
print(x.grad)  # Will have noise
```
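Since the forward pass computes $x^2$, the analytic gradient of the summed output is $2x$; subtracting it from `x.grad` isolates the injected noise. Continuing the snippet above:

```python
# The residual between the observed gradient and the analytic gradient 2*x
# is exactly the Gaussian noise added in backward().
print(x.grad - 2 * x)
```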