nn_freeze makes a model untrainable.
Commented out is a (broken) earlier version that replaced the parameters with constant buffers.
Checking the source of PyTorch's model.train() and model.eval(), setting requires_grad to False on every parameter seems to be sufficient.
def nn_freeze(model):
    # for n,p in model.named_parameters():
    #     x = p.data
    #     delattr(model,n)
    #     model.register_buffer(n,x)
    for param in model.parameters(): param.requires_grad = False
    for l in model.children(): nn_freeze(l)
    # for l in model.modules(): nn_freeze(l)
    return model
def nn_unfreeze(model):
    for param in model.parameters(): param.requires_grad = True
    for l in model.children(): nn_unfreeze(l)
    return model
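A first sanity check (a minimal sketch, assuming torch and torch.nn as nn are imported earlier in the notebook): freezing and unfreezing just toggles requires_grad on every parameter.
m = nn.Linear(4, 2)
nn_freeze(m)
assert all(not p.requires_grad for p in m.parameters())
nn_unfreeze(m)
assert all(p.requires_grad for p in m.parameters())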
To further check that this works as intended, here is a frozen model that is unable to train even after calling model.train().
def nn_freeze_check():
    torch.cuda.empty_cache()
    σ = nn.GELU()
    model = nn_freeze(nn.Sequential(*[
        nn.Conv2d(1 ,1,(3,3)), σ, nn.MaxPool2d((2,2)),
        nn.Conv2d(1 ,32,(3,3)), σ, nn.MaxPool2d((2,2)),
        nn.Conv2d(32,10,(3,3)), σ, nn.MaxPool2d((2,2)),
        nn.Flatten(), nn.Linear(10,10), nn.Softmax(1),
    ]).to(device))
    class DebugModule(nn.Module):
        def __init__(self):
            super(DebugModule,self).__init__()
            # useless trainable parameters: without at least one tensor that requires grad, torch complains and crashes
            self.L = nn.Linear(784,10).to(device)
            self.flat = nn.Flatten()
        def forward(self, x): return model(x) + 0*self.L(self.flat(x))
    m = DebugModule()
    mnist_train(m,10)
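A lighter variant of the same check, sketched without mnist_train: take an optimizer step next to a frozen layer and confirm that the frozen weights receive no gradient and do not move.
layer = nn_freeze(nn.Linear(4, 2))
before = layer.weight.detach().clone()
dummy = nn.Linear(4, 2)                     # plays the role of DebugModule's useless trainable layer
opt = torch.optim.SGD(dummy.parameters(), lr=0.1)
loss = (layer(torch.randn(8, 4)) + dummy(torch.randn(8, 4))).sum()
loss.backward()
opt.step()
assert layer.weight.grad is None            # no gradient was accumulated on the frozen layer
assert torch.equal(layer.weight, before)    # and its weights did not change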
modulize makes an nn.Module out of a function.
def modulize(fun):
    """
    Parameters
    ----------
    fun : function
    """
    class Modulized(nn.Module):
        def __init__(self): super(Modulized,self).__init__()
        def forward(self, x): return fun(x)
    return Modulized()
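For instance (a small sketch), a wrapped function behaves like a parameterless layer and composes with nn.Sequential.
square = modulize(lambda x: x**2)
assert len(list(square.parameters())) == 0  # nothing to train in the wrapper itself
net = nn.Sequential(nn.Linear(4, 4), square)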
gradiator makes an nn.Module out of a function and a given custom Jacobian.
# https://towardsdatascience.com/extending-pytorch-with-custom-activation-functions-2d8b065ef2fa
def gradiator(fun,grad_fun):
    """
    Parameters
    ----------
    fun      : x ↦ y                |
    grad_fun : x,grad_out ↦ grad_in | evaluation of jacobian for backpropagation
    """
    class GradiatorFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x) # save input for backward pass
            return fun(x)
        @staticmethod
        def backward(ctx, grad_output):
            if not ctx.needs_input_grad[0]: return None # if grad not required, don't compute
            x, = ctx.saved_tensors # restore input from context
            grad_input = grad_fun(x,grad_output)
            return grad_input
    return modulize(GradiatorFunction.apply)
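A quick check (a sketch) that the custom gradient is actually used: feed gradiator the true gradient of x² and compare against what autograd computes.
sq = gradiator(lambda x: x**2, lambda x,grad_out: 2*x*grad_out)
x = torch.randn(5, requires_grad=True)
sq(x).sum().backward()
assert torch.allclose(x.grad, 2*x.detach())  # matches the analytic gradient of x**2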
approximator makes a torch.autograd.Function out of a function and a given differentiable approximation of it: the forward pass evaluates the function itself, the backward pass differentiates through the approximation. It returns the raw apply function rather than a module, so wrap it with modulize when using it as a layer.
def approximator(fun,app): # PNNFunction
    """
    Parameters
    ----------
    fun : function      |
    app : diff function | differentiable approximation of fun
    """
    class ApproximatorFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *args):
            ctx.save_for_backward(*args) # save inputs for backward pass
            return fun(*args)
        @staticmethod
        def backward(ctx, grad_output):
            args = ctx.saved_tensors # restore inputs from context
            # https://pytorch.org/docs/stable/generated/torch.set_grad_enabled.html
            torch.set_grad_enabled(True) # vjp's output require_grad
            # https://pytorch.org/docs/stable/generated/torch.autograd.functional.vjp.html
            y = torch.autograd.functional.vjp(app, args, v=grad_output) # (app(*args), vector-Jacobian product)
            torch.set_grad_enabled(False)
            return y[1] # keep only the gradient part
    return ApproximatorFunction.apply
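To see that the backward pass really goes through the approximation (a sketch): approximate floor by the identity, so the incoming gradient passes through unchanged even though floor itself has zero gradient almost everywhere.
step = approximator(lambda x: torch.floor(x), lambda x: x)
x = torch.randn(5, requires_grad=True)
step(x).sum().backward()
assert torch.allclose(x.grad, torch.ones_like(x)) # identity gradient, not floor's (zero) gradient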
Here is a usage example. Each of the following lines defines an alternative activation σ; the last assignment is the one actually used by the model below.
σ = modulize(lambda x: torch.sin(x**2))
σ = gradiator(lambda x: x   , lambda x,grad_out: grad_out)
σ = gradiator(lambda x: x**2, lambda x,grad_out: 2*x*grad_out)
σ = modulize(approximator(lambda x: torch.floor(x), lambda x: x))
mnist_model = nn.Sequential(*[
    nn.Conv2d(1 ,1,(3,3)), σ, nn.MaxPool2d((2,2)),
    nn.Conv2d(1 ,32,(3,3)), σ, nn.MaxPool2d((2,2)),
    nn.Conv2d(32,10,(3,3)), σ, nn.MaxPool2d((2,2)),
    nn.Flatten(), nn.Linear(10,10), nn.Softmax(1),
]).to(device)
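A final sanity check (a sketch, reusing the device defined earlier in the notebook): a fake MNIST-sized batch flows through, and gradients reach the first convolution despite the non-differentiable floor in the middle.
x = torch.randn(8, 1, 28, 28).to(device)
y = mnist_model(x)
assert y.shape == (8, 10)
y.sum().backward()
assert mnist_model[0].weight.grad is not None # gradient flowed through the floor activation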