# Randomised dot-product check of a backward model against a forward black box.
# For random inputs (x, θ), random perturbations (dx, dθ) scaled by ε and a
# random output gradient go, it compares, per sample,
#     should = <gx, dx> + <gθ, dθ>     (gradients returned by bwd_model)
#     found  = <go, df>                (finite difference of blackbox)
# and reduces the mismatch to three scalar diagnostics (blue, red, green).
# N, dim_i, dim_o, device, ε, blackbox, bwd_model, mse, cossim and log are
# expected to be defined earlier in the file (not visible here).

x = torch.randn(N, dim_i).to(device)  # N random evaluation points, shape (N, dim_i)
# NOTE(review): this θ is immediately overwritten by the θ_batch=True variant
# on the next line; keep only the one matching the model's θ_batch setting.
θ = torch.randn(N, dim_i, dim_o).to(device)  # N independent θ's (if θ_batch=False)
θ = torch.randn(dim_i, dim_o).repeat(N, 1, 1).to(device)  # one θ repeated N times (if θ_batch=True)
dx = (torch.randn(x.shape) * ε).to(device)  # random perturbation of x, scaled by ε
dθ = (torch.randn(θ.shape) * ε).to(device)  # random perturbation of θ, scaled by ε
# Finite difference of the forward map: for a linear map A, A(x+dx) - A(x) = A dx.
# NOTE(review): only x is perturbed here, yet `should` below also includes the
# <gθ, dθ> term; confirm whether blackbox should be evaluated at θ + dθ as well.
df = blackbox(x + dx) - blackbox(x)
go = torch.randn(N, dim_o).to(device)  # random output gradient, one per sample
gx, gθ = bwd_model(x, θ, go)
# Per-sample inner products, each reshaped into a (1, N) row.
should = (gx * dx).sum(dim=1).view(1, -1) + (gθ * dθ).sum(dim=1).sum(dim=1).view(1, -1)
found = (go * df).sum(dim=1).view(1, -1)
loss_mse = mse(should, found)
loss_cos = 1 - cossim(should, found).mean()
# Normalised agreement: scales down by 1/|should|^2 (squared L2 norm of the row).
loss_dot = -(should * found / (torch.norm(should, dim=1) ** 2)).sum()
blue = log(loss_mse)
red = log(loss_cos)
green = loss_dot