backprop
Let's denote $α = (x,θ)$, where $x$ is the input part and $θ$ are the trainable parameters, and define the model
$$ f(α) = f(x;θ) := f_n(...f_2(f_1(x,θ_1),θ_2)..., θ_n)$$
By taking the last layer to be the loss, $f_n := \text{loss}$, the goal becomes to minimize $f(x;θ)$ over $θ$:
$$ θ = \argmin_{Θ} f(x;Θ)$$
If $f$ is smooth enough and reasonably convex, a gradient-descent-like algorithm is a sensible approach:
$$ \al{ θ^{(k+1)}_m = θ^{(k)}_m - ε \blue{\frac{d}{dθ_m} f(x;θ^{(k)})} && ∀m}$$
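Assuming the blue gradient can be obtained (PyTorch's autograd will provide it), the update loop is only a few lines. A minimal sketch, where f, x and the sizes are placeholders of my own:

    import torch

    # minimal gradient-descent loop for the update rule above; f(x, thetas) is assumed to run
    # the forward pass f_1..f_n (with f_n the loss) and return a scalar
    thetas = [torch.randn(10, requires_grad=True) for _ in range(3)]   # θ_1..θ_n
    eps = 1e-2
    for _ in range(100):
        loss = f(x, thetas)            # forward pass
        loss.backward()                # fills theta.grad with d f / d θ_m for every m
        with torch.no_grad():
            for theta in thetas:
                theta -= eps * theta.grad
                theta.grad = None      # reset for the next step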
Expanding the blue gradient with the chain rule (derived below) turns it into a product of Jacobians:
$$\al{
\blue{\frac{d}{dθ_{m}} x_{n+1}}
&=
J_{\orange{x_{n }}}^{f_{n }}(x_{n },θ_{n }) ⋅
J_{\orange{x_{n-1}}}^{f_{n-1}}(x_{n-1},θ_{n-1}) \cdots
J_{\orange{x_{m+1}}}^{f_{m+1}}(x_{m+1},θ_{m+1}) ⋅
J_{\purple{θ_{m }}}^{f_{m }}(x_{m },θ_{m }) \\
}$$
where
$$J_{\orange{x_n}}^{f_n}(x,θ) = \mat{\frac{∂f_{n,i}}{∂\orange{x_{n}}_j}(x,θ)}_{i,j} $$
To see this, the forward pass can be written as
$$\al{
x_1 &= x && \text{input}\\
\green{x_{l+1}} &\green{= f_l(x_l, θ_l)} && \text{intermediate} \\
\red{f(x;θ)} &\red{= x_{n+1}} && \text{output} \\
}$$
From there, one can compute
$$\al{
\blue{\frac{d}{dθ_{m}} x_{n+1}}
&= \frac{d}{dθ_{m}} \green{f_n(x_n;θ_n)} \\
\text{(chain rule)}
&= J_{x_n}^{f_n}(x_n;θ_n) \blue{\frac{d}{dθ_{m}} x_n}
+ J_{θ_n}^{f_n}(x_n;θ_n) \comment{\frac{d}{dθ_{m}} θ_n}{δ_{m,n}}\\
\text{(recurse)}
&= J_{x_n}^{f_n}(x_n;θ_n) ⋅
\blue{\left(J_{x_{n-1}}^{f_{n-1}}(x_{n-1};θ_{n-1}) {\color{cyan}\frac{d}{dθ_{m}} x_{n-1}} + J_{θ_{n-1}}^{f_{n-1}}(x_{n-1};θ_{n-1}) \comment{\frac{d}{dθ_{m}} θ_{n-1}}{δ_{m,{n-1}}} \right)} \\
\text{(recurse)}
&= J_{x_n}^{f_n}(x_n;θ_n) ⋅
J_{x_{n-1}}^{f_{n-1}}(x_{n-1};θ_{n-1}) \cdots
J_{x_{m+1}}^{f_{m+1}}(x_{m+1};θ_{m+1}) ⋅
\left(
J_{x_m}^{f_m}(x_m;θ_m) \comment{\frac{d}{dθ_m} x_m}{0} +
J_{θ_{m}}^{f_{m}}(x_{m};θ_{m})
\right) \\
}$$
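As a sanity check, the Jacobian-product formula can be compared against autograd on a tiny two-layer example (the layers f1, f2 below are toy choices of mine, not from the text):

    import torch

    def f1(x, t): return torch.tanh(x * t)     # first layer, parameters θ_1 = t
    def f2(x, t): return (x * t).sum()         # last layer, plays the role of the loss, parameters θ_2 = t

    x  = torch.randn(3)
    t1 = torch.randn(3, requires_grad=True)
    t2 = torch.randn(3, requires_grad=True)

    x1 = x                                     # forward pass, keeping intermediates
    x2 = f1(x1, t1)
    x3 = f2(x2, t2)

    g_auto, = torch.autograd.grad(x3, t1)      # d x_3 / d θ_1 via autograd

    # same gradient as J_x^{f_2}(x_2;θ_2) ⋅ J_θ^{f_1}(x_1;θ_1)
    J_x_f2 = torch.autograd.functional.jacobian(lambda u: f2(u, t2), x2)   # shape (3,)
    J_t_f1 = torch.autograd.functional.jacobian(lambda t: f1(x1, t), t1)   # shape (3, 3)
    print(torch.allclose(g_auto, J_x_f2 @ J_t_f1))                         # expected: True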
Instead of targeting $θ$ specifically, let's write down the gradient with respect to $α$ to keep things lighter. This causes no problem: we simply ignore the components we don't want to optimize.
The process could be rewritten as
$$\al{
α_1 &= (x,θ_1) && \text{input}\\
\green{α_{l+1}} &\green{= (f_l(α_l),θ_{l+1})} && \text{intermediate} \\
\red{f(x;θ)} &\red{= f(α) = α_{n+1} = (f_n(α_n),∅)} && \text{output} \\
}$$
$$\al{
\blue{\frac{d}{dα_{m}} α_{n+1}}
&= J_{α_n}^{f_n}(α_n) ⋅
J_{α_{n-1}}^{f_{n-1}}(α_{n-1}) \cdots
J_{α_{m+1}}^{f_{m+1}}(α_{m+1}) ⋅
J_{α_{m}}^{f_{m}}(α_{m}) \\
}$$
One can tell there is a lot of computation to be recycled here!
$$\al{
\blue{\frac{dα_{n+1}}{dα_{m}}}
&= \blue{\frac{dα_{n+1}}{dα_{m+1}}} ⋅ J_{α_{m}}^{f_{m}}(α_{m})
}$$
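Spelled out by hand, this is one vector-Jacobian product per layer, each reusing the previous result. A sketch, assuming alphas is the list of saved forward values (with alphas[l+1] = layers[l](alphas[l]) and the last one being the scalar loss); both names are mine:

    import torch

    grad = torch.ones_like(alphas[-1])         # cotangent at the (scalar) output α_{n+1}
    grads = [None] * len(layers)
    for l in reversed(range(len(layers))):
        # recycling step: grad ← grad ⋅ J_α^f(α) for this layer
        _, grad = torch.autograd.functional.vjp(layers[l], alphas[l], grad)
        grads[l] = grad                        # = dα_{n+1}/dα at this layer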
This recycling is expressed in PyTorch as follows: each custom autograd Function receives the already-accumulated grad_out in its backward and only has to multiply it by its own Jacobian.
class MyFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)                # save input for backward pass
        return my_function(x)                   # actual evaluation
    @staticmethod
    def backward(ctx, grad_out):
        if not ctx.needs_input_grad[0]: return None            # if grad not required, don't compute
        x, = ctx.saved_tensors                                  # restore input from context
        _, grad_input = torch.autograd.functional.vjp(my_function, x, grad_out)   # grad_input = grad_out ⋅ J(x)
        return grad_input                                       # gradient of the loss w.r.t. x
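A possible usage sketch (my_function is assumed to be defined elsewhere and differentiable, which is what makes gradcheck applicable here):

    x = torch.randn(5, dtype=torch.double, requires_grad=True)
    y = MyFunc.apply(x)                            # forward goes through MyFunc.forward
    y.sum().backward()                             # backward goes through MyFunc.backward
    print(x.grad)

    torch.autograd.gradcheck(MyFunc.apply, (x,))   # numerically checks the custom backward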
idea
In the case of reservoir computing, part of $f$ is a physical device (physical_driver in the code below): it is given and cheap to evaluate, but we don't have its derivative.
Thus it is useful to introduce a surrogate model for it, so that backpropagation can still happen.
Option 1: we build a differentiable model, approximation, of how physical_driver maps inputs to measurements, and train it so that approximation(input) ≈ physical_driver(input).
class Option1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return physical_driver(x)               # real measurement in the forward pass
    @staticmethod
    def backward(ctx, grad_out):
        if not ctx.needs_input_grad[0]: return None
        x, = ctx.saved_tensors
        _, grad_input = torch.autograd.functional.vjp(approximation, x, grad_out)   # backward goes through the surrogate
        return grad_input
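A possible training sketch for the surrogate; the architecture, sample_inputs, d_in/d_out and hyper-parameters are placeholders of mine:

    import torch

    approximation = torch.nn.Sequential(
        torch.nn.Linear(d_in, 128), torch.nn.Tanh(), torch.nn.Linear(128, d_out))
    opt = torch.optim.Adam(approximation.parameters(), lr=1e-3)
    for _ in range(n_steps):
        x = sample_inputs()                              # inputs the device will actually see
        with torch.no_grad():
            y = physical_driver(x)                       # measurement from the device
        loss = torch.nn.functional.mse_loss(approximation(x), y)
        opt.zero_grad(); loss.backward(); opt.step()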
Option 2: we build a model J that maps an input to the Jacobian (all the derivatives) of physical_driver at that input, and train it so that J(x) ⋅ ε ≈ (physical_driver(x+ε) - physical_driver(x-ε)) / 2.
class Option2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return physical_driver(x)
    @staticmethod
    def backward(ctx, grad_out):
        if not ctx.needs_input_grad[0]: return None
        x, = ctx.saved_tensors
        return torch.einsum('bi,bij->bj', grad_out, J(x))   # grad_out ⋅ J(x)
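A possible training sketch for J, using the central-difference relation above (J is assumed to be an nn.Module returning a (batch, d_out, d_in) tensor; names and hyper-parameters are mine):

    import torch

    opt = torch.optim.Adam(J.parameters(), lr=1e-3)
    for _ in range(n_steps):
        x   = sample_inputs()                                          # (b, d_in)
        eps = 1e-2 * torch.randn_like(x)                               # small random probe
        with torch.no_grad():
            target = (physical_driver(x + eps) - physical_driver(x - eps)) / 2   # ≈ J(x) ⋅ ε
        pred = torch.einsum('bij,bj->bi', J(x), eps)                   # J(x) ⋅ ε
        loss = torch.nn.functional.mse_loss(pred, target)
        opt.zero_grad(); loss.backward(); opt.step()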
Option 3: we avoid representing the full Jacobian $J$ (it could be somewhat sparse) and directly learn the vector-Jacobian product J(x, grad_out) ≈ grad_out.T ⋅ J(x).
class Option3(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return physical_driver(x)
    @staticmethod
    def backward(ctx, grad_out):
        if not ctx.needs_input_grad[0]: return None
        x, = ctx.saved_tensors
        return J(x, grad_out)   # here is the difference with Option 2
However, how do you train this? Given that J(x, grad_out) ≈ grad_out.T ⋅ J(x), the following could make sense:
J(x, grad_out) ⋅ ε ≈ grad_out.T ⋅ (physical_driver(x+ε) - physical_driver(x-ε)) / 2
But now you need to choose grad_out. An obvious choice is to take all the vectors of the canonical basis. Another option is to just use whatever grad_out happens to show up during a training session on an actual problem; this way the model gets more accurate exactly where the device is being evaluated.
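A possible training sketch combining these pieces (sample_grad_outs could return canonical basis vectors or grad_outs recorded from a real run; J is assumed to be an nn.Module taking (x, g), the other names are placeholders of mine):

    import torch

    opt = torch.optim.Adam(J.parameters(), lr=1e-3)
    for _ in range(n_steps):
        x   = sample_inputs()                                          # (b, d_in)
        g   = sample_grad_outs()                                       # (b, d_out)
        eps = 1e-2 * torch.randn_like(x)
        with torch.no_grad():
            diff   = (physical_driver(x + eps) - physical_driver(x - eps)) / 2   # ≈ J(x) ⋅ ε
            target = torch.einsum('bi,bi->b', g, diff)                           # ≈ g.T ⋅ J(x) ⋅ ε
        pred = torch.einsum('bj,bj->b', J(x, g), eps)                            # (g.T ⋅ J(x)) ⋅ ε
        loss = torch.nn.functional.mse_loss(pred, target)
        opt.zero_grad(); loss.backward(); opt.step()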
We might take some symmetries into account:
$$\al{
J(x,λg) &= λJ(x,g) \\
J(x,g_1)+J(x,g_2) &= J(x,g_1+g_2)
}$$
Maybe knowing something about how $J(x,⋅)$ is sparse could suggest an expected structure for $J$.
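One way to exploit these constraints, purely as an assumption on my part, is to bake the linearity in grad_out into the architecture, e.g. a low-rank ansatz J(x, g) = g ⋅ U(x) ⋅ V(x), which is linear in g by construction and never materializes the full (d_out, d_in) Jacobian:

    import torch

    class LowRankVJP(torch.nn.Module):
        # hypothetical structure: J(x) ≈ U(x) V(x) of rank r, so J(x, g) = g ⋅ U(x) ⋅ V(x)
        def __init__(self, d_in, d_out, r=16):
            super().__init__()
            self.U = torch.nn.Sequential(torch.nn.Linear(d_in, 64), torch.nn.Tanh(), torch.nn.Linear(64, d_out * r))
            self.V = torch.nn.Sequential(torch.nn.Linear(d_in, 64), torch.nn.Tanh(), torch.nn.Linear(64, r * d_in))
            self.d_in, self.d_out, self.r = d_in, d_out, r
        def forward(self, x, g):
            U = self.U(x).view(-1, self.d_out, self.r)
            V = self.V(x).view(-1, self.r, self.d_in)
            return torch.einsum('bi,bir,brj->bj', g, U, V)   # linear in g by construction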