# Consistency check for the backward model: for small random perturbations, the
# change predicted from the returned gradients (`should`) is compared with the
# observed change of the blackbox output projected onto a random output
# gradient (`found`).
x = torch.randn(N,dim_i ).to(device) # N random evaluation points, one row of dim_i values each
# NOTE(review): the two θ assignments below read as alternatives; as written the
# second one always overwrites the first.
θ = torch.randn(N,dim_i,dim_o).to(device) # an independent random (dim_i, dim_o) θ per sample (labelled θ_batch=False in the original)
θ = torch.randn(dim_i,dim_o).repeat(N,1,1).to(device) # one random (dim_i, dim_o) θ tiled N times along a new leading axis (labelled θ_batch=True in the original)
dx = (torch.randn(x.shape)*ε).to(device) # random perturbation of x, scaled by ε
dθ = (torch.randn(θ.shape)*ε).to(device) # random perturbation of θ, scaled by ε
# NOTE(review): only x is perturbed in df, yet `should` below also adds a gθ·dθ
# term; confirm whether blackbox is meant to be evaluated with θ+dθ as well.
df = blackbox(x+dx) - blackbox(x) # observed output change caused by dx: A(x+ε) - A(x) = Aε
go = torch.randn(N,dim_o).to(device) # random gradient with respect to the output, one row per sample
gx,gθ = bwd_model(x,θ,go) # backward model output; presumably the gradients w.r.t. x and θ (TODO confirm)
# Predicted change: gx·dx summed over dim 1, plus gθ·dθ summed over its two
# trailing dims, each reshaped to a single row before being added.
should = (gx*dx).sum(dim=1).view(1,-1) + (gθ*dθ).sum(dim=1).sum(dim=1).view(1,-1) # for subspace
# Observed change: go·df summed over dim 1, reshaped to a single row.
found = (go*df).sum(dim=1).view(1,-1) # for subspace (corresponding ones)
loss_mse = mse(should, found) # `mse` is defined outside this view
loss_cos = 1 - cossim(should, found).mean() # `cossim` is defined outside this view
loss_dot = - (should*found/(torch.norm(should,dim=1)**2)).sum() # negated dot product divided by |should|² (the norm is squared)
# Shared activation bank applied in front of most linear layers: identity,
# GELU, an exponential capped at 2**6 and a logarithm floored at 2**-8.
# NOTE(review): `bwd_σ.m` is used below as the factor by which the bank widens
# its input (the literal `4*` further down matches its four branches) —
# confirm against the definition of Stack.
bwd_σ = Stack([nn.Identity(), GELU(), Exp(maximum=2**6), Log(minimum=2**-8)])

# Catalogue of candidate backward-model architectures.
# Each entry is [name, factory]; factory(bwd_i, bwd_o) builds an nn.Sequential
# whose first layer is sized from bwd_i and whose last layer outputs bwd_o.
bwd_models = [
    # One activation bank followed by one linear layer.
    ["basic", lambda bwd_i, bwd_o: nn.Sequential(
        bwd_σ, Linear(bwd_σ.m * bwd_i, bwd_o),
    )],
    # Same as "basic" but with the Exp/Log bounds disabled. The width factor is
    # still read from bwd_σ because both banks have four branches.
    ["unbounded", lambda bwd_i, bwd_o: nn.Sequential(
        Stack([nn.Identity(), GELU(), Exp(maximum=False), Log(minimum=False)]),
        Linear(bwd_σ.m * bwd_i, bwd_o),
    )],
    # "basic" plus a second bank + linear stage operating at the output width.
    ["extra", lambda bwd_i, bwd_o: nn.Sequential(
        bwd_σ, Linear(bwd_σ.m * bwd_i, bwd_o),
        bwd_σ, Linear(bwd_σ.m * bwd_o, bwd_o),
    )],
    # Bottleneck: width shrinks bwd_i -> bwd_i//2 -> bwd_i//4, crosses over to
    # bwd_o//4, then grows back to bwd_o. Layer sizes are derived from the
    # factory's own arguments; the previous version read `bwd_dim_i` /
    # `bwd_dim_o`, which are not parameters of the lambda, so the arguments
    # were ignored (or a NameError was raised when no such globals existed).
    # These layers use torch's nn.Linear with a complex dtype, as before.
    ["diabolo", lambda bwd_i, bwd_o: nn.Sequential(
        bwd_σ, nn.Linear(bwd_σ.m * (bwd_i // 1), bwd_i // 2, dtype=torch.cfloat),
        bwd_σ, nn.Linear(bwd_σ.m * (bwd_i // 2), bwd_i // 4, dtype=torch.cfloat),
        bwd_σ, nn.Linear(bwd_σ.m * (bwd_i // 4), bwd_o // 4, dtype=torch.cfloat),
        bwd_σ, nn.Linear(bwd_σ.m * (bwd_o // 4), bwd_o // 2, dtype=torch.cfloat),
        bwd_σ, nn.Linear(bwd_σ.m * (bwd_o // 2), bwd_o // 1, dtype=torch.cfloat),
    )],
    # Log -> linear -> exp -> linear, with bounded Log/Exp.
    ["loglinexp", lambda bwd_i, bwd_o: nn.Sequential(
        Log(minimum=2**-8),
        Linear(bwd_i, bwd_o),  # multiplications with a*b = exp(log(a)+log(b))
        Exp(maximum=2**6),
        Linear(bwd_o, bwd_o),  # normal additions
    )],
    # Same pipeline as "loglinexp" with the Log/Exp bounds disabled.
    ["loglinexp-unbounded", lambda bwd_i, bwd_o: nn.Sequential(
        Log(minimum=False),
        Linear(bwd_i, bwd_o),
        Exp(maximum=False),
        Linear(bwd_o, bwd_o),
    )],
    # Two bank + linear stages, each with its own freshly built four-branch
    # bank (hence the literal factor 4 on the linear input widths).
    ["logexpgelu", lambda bwd_i, bwd_o: nn.Sequential(
        Stack([nn.Identity(), GELU(), Exp(maximum=2**6), Log(minimum=2**-8)]),
        Linear(4 * bwd_i, bwd_o),
        Stack([nn.Identity(), GELU(), Exp(maximum=2**6), Log(minimum=2**-8)]),
        Linear(4 * bwd_o, bwd_o),
    )],
    # Five bank + linear stages: three at the input width, one mapping to the
    # output width, and a final one at the output width.
    ["deep-fc", lambda bwd_i, bwd_o: nn.Sequential(
        Stack([nn.Identity(), GELU(), Exp(maximum=2**6), Log(minimum=2**-8)]),
        Linear(4 * bwd_i, bwd_i),
        Stack([nn.Identity(), GELU(), Exp(maximum=2**6), Log(minimum=2**-8)]),
        Linear(4 * bwd_i, bwd_i),
        Stack([nn.Identity(), GELU(), Exp(maximum=2**6), Log(minimum=2**-8)]),
        Linear(4 * bwd_i, bwd_i),
        Stack([nn.Identity(), GELU(), Exp(maximum=2**6), Log(minimum=2**-8)]),
        Linear(4 * bwd_i, bwd_o),
        Stack([nn.Identity(), GELU(), Exp(maximum=2**6), Log(minimum=2**-8)]),
        Linear(4 * bwd_o, bwd_o),
    )],
]
# NOTE(review): the colour names suggest these are the curves shown in a plot
# (a legend) — confirm. `log` is defined outside this view.
blue = log(loss_mse) # log applied to the MSE loss
red = log(loss_cos) # log applied to the cosine loss
green = loss_dot # dot loss, used as-is (no log)
# The bare names below are expression statements: each one only evaluates the
# name and discards the result (raising NameError if the name is undefined).
# NOTE(review): they look like a leftover list of settings or notebook cell
# inspections — confirm whether they are still needed.
param
dim_i, dim_o
loss
opti
dim_sample
bwd_σ
value