header
# Directional-derivative check of the backward model.
#
# For small random perturbations (dx, dθ), the first-order change predicted
# by the gradients returned by bwd_model ("should") is compared with the
# change actually observed on blackbox ("found"), both projected on a random
# output gradient go.
#
# Names read from the enclosing scope: N, dim_i, dim_o, device, ε, θ_batch,
# blackbox, bwd_model, mse, cossim, log.

x = torch.randn(N, dim_i).to(device)  # N random evaluation points, (N, dim_i)

# The two θ initialisations are alternatives selected by θ_batch (as stated
# by the original inline comments); written back to back, the second one
# would always overwrite the first.
if θ_batch:
    # one random evaluation point, repeated N times -> (N, dim_i, dim_o)
    θ = torch.randn(dim_i, dim_o).repeat(N, 1, 1).to(device)
else:
    # N independent random evaluation points -> (N, dim_i, dim_o)
    θ = torch.randn(N, dim_i, dim_o).to(device)

dx = (torch.randn(x.shape) * ε).to(device)  # random variation of x, scaled by ε
dθ = (torch.randn(θ.shape) * ε).to(device)  # random variation of θ, scaled by ε

# A(x + dx) - A(x), which equals A·dx when blackbox is a linear map A.
# NOTE(review): dθ is not applied here although `should` below contains a
# gθ·dθ term -- confirm how blackbox receives θ.
df = blackbox(x + dx) - blackbox(x)

go = torch.randn(N, dim_o).to(device)  # random output gradient, (N, dim_o)
gx, gθ = bwd_model(x, θ, go)  # presumably gradients w.r.t. x and θ -- TODO confirm shapes

# Predicted change per sample: <gx, dx> + <gθ, dθ>, laid out as one row.
should = (
    (gx * dx).sum(dim=1).view(1, -1)
    + (gθ * dθ).sum(dim=1).sum(dim=1).view(1, -1)
)
# Observed change per sample: <go, df>, same row layout as `should`.
found = (go * df).sum(dim=1).view(1, -1)

loss_mse = mse(should, found)
loss_cos = 1 - cossim(should, found).mean()
# Negative dot product, scaled down by 1/|should|**2 (squared norm).
loss_dot = -(should * found / (torch.norm(should, dim=1) ** 2)).sum()

# Reported quantities; the names presumably match the plot curve colours.
blue = log(loss_mse)
red = log(loss_cos)
green = loss_dot
param
dim_i, dim_o
loss
opti
dim_sample
bwd_σ
θ_batch
value