import torch

# Finite-difference check of the backward model against the blackbox.
# N, dim_i, dim_o, ε, θ_batch, device, blackbox, bwd_model, mse, cossim and log
# are assumed to be defined by the surrounding setup (several appear in the parameter table below).

x = torch.randn(N, dim_i).to(device)                          # pick N evaluation points at random
if θ_batch:
    θ = torch.randn(dim_i, dim_o).repeat(N, 1, 1).to(device)  # pick 1 evaluation point at random, shared over the batch
else:
    θ = torch.randn(N, dim_i, dim_o).to(device)               # pick N evaluation points at random
dx = (torch.randn(x.shape) * ε).to(device)                    # pick a variation of x
dθ = (torch.randn(θ.shape) * ε).to(device)                    # pick a variation of θ
df = blackbox(x + dx) - blackbox(x)                           # A(x+ε) - A(x) = Aε (exact for a linear A)
go = torch.randn(N, dim_o).to(device)                         # pick a grad_o at random
gx, gθ = bwd_model(x, θ, go)                                  # predicted vector-Jacobian products

should = (gx * dx).sum(dim=1).view(1, -1) + (gθ * dθ).sum(dim=1).sum(dim=1).view(1, -1)  # for subspace: change predicted by the backward model
found = (go * df).sum(dim=1).view(1, -1)                      # for subspace (corresponding ones): change measured on the blackbox

loss_mse = mse(should, found)
loss_cos = 1 - cossim(should, found).mean()
loss_dot = -(should * found / (torch.norm(should, dim=1) ** 2)).sum()  # dot product, scaled down by ‖should‖²

blue = log(loss_mse)    # curves tracked for plotting
red = log(loss_cos)
green = loss_dot
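The check probes whether bwd_model returns valid vector-Jacobian products of the blackbox A: to first order in ε, the measured change ⟨go, A(x+dx) − A(x)⟩ (found) should match the predicted change ⟨gx, dx⟩ + ⟨gθ, dθ⟩ (should), and the three losses quantify their agreement as squared error, cosine similarity, and a normalised dot product. As a sanity check of that logic (not the actual experiment), here is a minimal self-contained toy version restricted to the x-part, with a hypothetical linear blackbox W whose exact VJP is known; W, blackbox_toy and bwd_toy are illustrative stand-ins.

# Toy sanity check of the 'should vs found' logic (x-part only).
# W, blackbox_toy and bwd_toy are hypothetical stand-ins, not the real experiment objects.
import torch

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)   # double precision so the comparison is tight
N, dim_i, dim_o, ε = 8, 5, 3, 1e-3

W = torch.randn(dim_i, dim_o)            # hidden linear map playing the role of the blackbox
blackbox_toy = lambda x: x @ W           # A(x) = xW
bwd_toy = lambda x, go: go @ W.T         # exact VJP g_x = g_o Wᵀ for this A

x  = torch.randn(N, dim_i)
dx = torch.randn(N, dim_i) * ε           # small variation of x
go = torch.randn(N, dim_o)               # random output gradient

gx     = bwd_toy(x, go)
should = (gx * dx).sum(dim=1)                                        # ⟨g_x, dx⟩ per sample
found  = (go * (blackbox_toy(x + dx) - blackbox_toy(x))).sum(dim=1)  # ⟨g_o, A(x+dx) - A(x)⟩

print(torch.allclose(should, found, atol=1e-9))  # True: with the exact VJP the two sides agree

With exact gradients the two sides coincide up to floating-point error; for a learned bwd_model, the losses above measure how far it is from that ideal.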
param          | value
---------------|------
dim_i, dim_o   |
loss_mode      |
opti_mode      |
dim_sample     |
enable_log_exp |
logexp_cap     |
θ_batch        |
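For reference, a minimal sketch of how these run parameters might be grouped into a config object. RunConfig is hypothetical, the field types and descriptions are assumptions inferred from the check above, and the actual values are experiment-specific and left unset.

from dataclasses import dataclass

@dataclass
class RunConfig:
    # Hypothetical container for the parameters in the table above; all
    # types and descriptions are assumptions, values are experiment-specific.
    dim_i: int            # input dimension of the blackbox (x has shape (N, dim_i))
    dim_o: int            # output dimension of the blackbox (go has shape (N, dim_o))
    loss_mode: str        # presumably selects among loss_mse / loss_cos / loss_dot
    opti_mode: str        # presumably selects the optimisation scheme
    dim_sample: int       # presumably the number of evaluation points N per check
    enable_log_exp: bool  # presumably toggles a log/exp transform
    logexp_cap: float     # presumably the cap used by that transform
    θ_batch: bool         # True: one θ repeated over the batch; False: N independent θ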