Neural ODEs

Neural Ordinary Differential Equations (Neural ODEs) are an elegant type of mathematical model designed for machine learning. They were proposed in a 2018 paper by Chen et al. and have attracted considerable attention since. The main idea is to unify two powerful modelling tools: Ordinary Differential Equations (ODEs) and Machine Learning.

This post will be a maths-heavy look at the concepts that lead up to Neural ODEs.

Ordinary Differential Equations

Ordinary differential equations are a staple of modelling in many different fields. Typically, the trajectory of an object over time is described using an ODE. For example, in classical mechanics, the trajectory of a projectile flying through the air can be modelled with the following ODEs:

\begin{array}{cc}m\dfrac{d^2x}{dt^2}(t)=-b_l\dfrac{dx}{dt}(t)-b_n\left(\dfrac{dx}{dt}(t)\right)^2\\ m\dfrac{d^2y}{dt^2}(t)=-b_l\dfrac{dy}{dt}(t)-b_n\left(\dfrac{dy}{dt}(t)\right)^2-mg\end{array}

These equations come from Newton’s Second Law of Motion, F=ma. The breakdown is as follows:

  • In the horizontal (x) direction, the forces on the projectile are only drag forces which act against its velocity. There is linear drag with coefficient b_l and nonlinear drag with coefficient b_n.
  • In the vertical (y) direction, the forces on the projectile are drag forces (as in the horizontal direction) and gravity (downwards with acceleration g).

The usual problem of interest is to find the projectile’s position and velocity at some time T (i.e. x(T), y(T), \dfrac{dx}{dt}(T), \dfrac{dy}{dt}(T) ) given the initial conditions x(0), y(0), \dfrac{dx}{dt}(0), \dfrac{dy}{dt}(0) . This is an Initial Value Problem (IVP). There are several well-known techniques that rely on time marching to solve IVPs, most of which fall under the Runge-Kutta family. Some common ones are Forward Euler, RK4, Leapfrog and Implicit Euler.

The most important transformation these solvers require is converting the system of ODEs into a first-order system, if it isn’t one already. The above system is second-order (due to the presence of second derivatives) and hence cannot be used as-is. The trick is to promote all derivatives except the highest order to explicit state variables, which converts our system into a larger but first-order one. In this example, the first-order system is:

\dfrac{d}{dt}\left(\begin{array}{cc}x\\ y\\ \dot{x}\\ \dot{y}\end{array}\right)(t)=\left(\begin{array}{cc}\dot{x}\\ \dot{y}\\ -\dfrac{b_l}{m}\dot{x}-\dfrac{b_n}{m}\dot{x}^2\\ -\dfrac{b_l}{m}\dot{y}-\dfrac{b_n}{m}\dot{y}^2-g\end{array}\right)(t)
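
As a concrete illustration, here is a minimal sketch of solving this IVP numerically with SciPy’s solve_ivp; the parameter values and initial conditions are illustrative placeholders, not taken from any real projectile.

import numpy as np
from scipy.integrate import solve_ivp

m, b_l, b_n, g = 1.0, 0.1, 0.02, 9.81          # illustrative values

def f(t, h):
    # h = [x, y, dx/dt, dy/dt]; returns dh/dt for the first-order system
    x, y, xd, yd = h
    return [xd,
            yd,
            -(b_l / m) * xd - (b_n / m) * xd**2,
            -(b_l / m) * yd - (b_n / m) * yd**2 - g]

h0 = [0.0, 0.0, 20.0, 20.0]                    # x(0), y(0), dx/dt(0), dy/dt(0)
sol = solve_ivp(f, (0.0, 2.0), h0, method="RK45")
print(sol.y[:, -1])                            # state at T = 2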

The vector on the left is the state vector, usually denoted as h(t) . Observe that the vector on the right is purely a function of the state vector, f(h(t)) . In general, however, the function can explicitly depend on time (non-autonomous systems) and on other external factors. We shall denote these external factors as \beta . In time-series analysis, h(t) contains the endogenous variables and \beta contains the exogenous variables.

The problem with passing these external factors to an ODE solver is that we need to know an expression for these factors as a function of time, i.e. \beta(t) . If they are constant, this isn’t an issue; if they aren’t, we can interpolate/extrapolate them across time. Either way, we can absorb \beta(t) into f by letting f explicitly depend on time. Thus, the above system can be generalised as:

\dfrac{dh(t)}{dt}=f(h(t),t)

There is a caveat with the toy model above: anyone familiar with actual projectiles will know that it is not very realistic, as it makes many simplifying assumptions (e.g. the time-dependence of the drag coefficients and the effects of projectile rotation are unaccounted for).

What if we have data of an actual projectile’s motion through the air, as well as the external factors, and we want an ODE model to learn the dynamics? A Neural ODE would be the tool for this! Before that, I shall quickly discuss DNNs and ResNets.

Deep Neural Networks

Deep neural networks need no detailed introduction, as they are already well-explained everywhere. Neural networks are usually applied in situations where there is an input variable x , an actual output variable y , and the actual output is related to the input by a function f :

y=f(x)

The input and predicted output could have any structure, such as scalars or vectors. An NN aims to solve the problem where a set of \{x,y\} pairs is available, but the function that relates them is unknown. Every \{x,y\} pair is referred to as a sample, and the set of samples is referred to as the dataset. It is for this reason that NNs are also called function approximators.

Hence, an NN attempts to estimate the function by imposing an architecture \tilde{f} with parameters \theta . The function takes in the input x and returns a predicted output \tilde{y} :

\tilde{y}=\tilde{f}(x,\theta)

NNs are distinctive in that their architectures are loosely inspired by animal brains, and the term model is usually used to refer collectively to the architecture and its parameters. For DNNs, the inputs and outputs are vectors, and the architecture consists of multiple layers, each made up of elementary units known as perceptrons.

Every layer first applies a matrix multiplication with weights W and a vector addition with biases b to the previous layer’s output; then a nonlinear function known as the activation is applied to every element of the resulting vector. For example, in layer k :

z_k=\sigma(W_kz_{k-1}+b_k)

The weights and biases of all layers are collectively the parameters of the model. The activation function is not strictly necessary, but its purpose is to make the DNN’s architecture nonlinear. If the architecture is linear, having multiple layers would be no different from just having one layer. The nonlinearity allows us to increase the predictive power of the architecture by increasing the number of layers.
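
To make the layer equation concrete, here is a minimal sketch of a DNN forward pass in NumPy; the layer sizes and the tanh activation are illustrative choices, not prescribed by the post.

import numpy as np

rng = np.random.default_rng(0)
sizes = [3, 16, 16, 2]                         # input, two hidden layers, output
params = [(rng.normal(0, 0.1, (n, p)), np.zeros(n))
          for p, n in zip(sizes[:-1], sizes[1:])]

def forward(x, params):
    z = x
    for W, b in params[:-1]:
        z = np.tanh(W @ z + b)                 # z_k = sigma(W_k z_{k-1} + b_k)
    W, b = params[-1]
    return W @ z + b                           # final layer left linear

y_pred = forward(np.array([1.0, 2.0, 3.0]), params)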

NNs start out with random parameters, so the initial predicted outputs are generally random. Hence, there will generally be an error between every predicted output and actual output, E(\tilde{y},y) . A simple example is the mean-squared error:

E_{\text{MS}}(\tilde{y},y)=\Vert\tilde{y}-y\Vert_2^2

Since the actual outputs may come from measurements, they may contain measurement noise, and thus it might not be wise to minimise the error of every individual sample exactly. Typically, the errors of multiple samples are averaged to obtain the loss, and it is more effective to minimise the loss. A group of samples from which a loss is calculated is known as a batch, which is a subset of the dataset.

L(\tilde{y},y)=\dfrac{1}{\text{Batch Size}}\sum\limits_{\{x,y\}\in\text{Batch}}\left(E\left(\tilde{f}(x,\theta),y\right)\right)

An important process for an NN is model training, where the parameters of the model are repeatedly adjusted in small increments to minimise the loss of every batch, through mathematical optimisation. There are several optimisers for this purpose, the simplest being Stochastic Gradient Descent (SGD). This method updates the parameters by finding the gradient of the loss w.r.t. the parameters, and stepping the parameters against that direction with a distance determined by step size \alpha :

\theta_{n+1}=\theta_n-\alpha\nabla_{\theta_n}L(\tilde{y},y)

Since the parameters are mutually independent, the gradient is simply the transpose of the row vector of partial derivatives:

\theta_{n+1}=\theta_n-\alpha\left(\dfrac{dL(\tilde{y},y)}{d\theta_n}\right)^T
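
As a toy illustration of SGD, consider a one-parameter model \tilde{y}=\theta x with a mean-squared loss, so that the gradient can be written by hand; the data here is synthetic.

import numpy as np

rng = np.random.default_rng(1)
xs = np.array([0.0, 1.0, 2.0, 3.0])            # a batch of inputs
ys = 2.0 * xs + rng.normal(0, 0.1, xs.shape)   # noisy targets

theta, alpha = 0.0, 0.05
for _ in range(100):
    grad = np.mean(2.0 * (theta * xs - ys) * xs)   # dL/dtheta of the MSE loss
    theta -= alpha * grad                          # step against the gradient
print(theta)                                       # approaches 2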

The issue with computing these derivatives is that the parameters are not directly related to the loss. The mathematical operation of every layer has a derivative and the chain rule has to be applied backwards across all layers. This process is referred to as backpropagation.

Fortunately, the computational technique of automatic differentiation performs this automatically. It is commonly used as it is fast and gives exact derivatives. After training, the model becomes accurate at predicting outputs for inputs resembling those in the dataset. For inputs outside the dataset, the model may generalise poorly if overfitting has occurred.

Residual Networks

There is a popular class of NNs known as Residual Neural Networks (ResNets). Instead of modelling the relationship between an input and output, they model the difference (residual) between the input and output:

\tilde{y}=x+\tilde{f}(x,\theta)

This method became popular due to the following problem: as a state passes through layer after layer, more and more of the original information is lost, and gradients shrink as they are propagated backwards. Hence, the more layers a model has, the harder it becomes to train, with accuracy degrading even on the training data. This effectively set an upper bound on how many layers a model could have.

ResNets attempt to alleviate the issue as follows: We could take an earlier state and add it in at a later time, thereby allowing its information to be retained across the model. This is referred to as a skip connection. This has been very useful in image processing, in applications such as style transfer and image segmentation.
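
Here is a minimal sketch of a single residual block in NumPy, assuming a two-layer inner transformation; all shapes and values are illustrative.

import numpy as np

def residual_block(x, W1, b1, W2, b2):
    z = np.tanh(W1 @ x + b1)                   # inner transformation f~(x, theta)
    return x + (W2 @ z + b2)                   # skip connection: y~ = x + f~(x, theta)

rng = np.random.default_rng(0)
d = 4
W1, b1 = rng.normal(0, 0.1, (8, d)), np.zeros(8)
W2, b2 = rng.normal(0, 0.1, (d, 8)), np.zeros(d)
y = residual_block(np.ones(d), W1, b1, W2, b2)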

Recurrent ResNets

If we have a ResNet where the input is some state at time-step n , and the output is the state at time-step n+1 , then the network models the first-order difference. Since the output of the model is indirectly fed back into the input, this becomes a type of Recurrent Neural Network (RNN). This method of modelling discretises time using the forward difference scheme:

\tilde{h}_{n+1}=h_n+\tilde{f}(h_n,\theta)

We can easily rearrange this into the following Finite Difference Equation (FDE):

\dfrac{\tilde{h}_{n+1}-h_n}{\Delta t}=\frac{1}{\Delta t}\tilde{f}(h_n,\theta)

We can do a few things here:

  • We can make the model invariant to the time-step size by absorbing \frac{1}{\Delta t} into it. This makes it learn how to predict the approximate first derivative instead of the residual.
  • We can allow the model to take exogenous variables by adding an explicit dependence on the time-step.
  • We can allow the model to extrapolate for multiple time-steps by feeding its predicted states back in as input states.

\dfrac{\tilde{h}_{n+1}-\tilde{h}_n}{\Delta t}=\tilde{f}(\tilde{h}_n,\theta,n)

For convenience, we can use the following notation for the forward difference operator:

\Delta_{\Delta t}[\tilde{h}_n]:=\tilde{h}_{n+1}-\tilde{h}_n

\dfrac{\Delta_{\Delta t}[\tilde{h}_n]}{\Delta t}=\tilde{f}(\tilde{h}_n,\theta,n)

We can step the state forward in time multiple times by predicting the approximate first derivative, scaling it by the step size, and adding it to the solution at each time-step (the Forward Euler method):

\tilde{h}_{n+k}=h_n+\sum\limits_{i=n}^{n+k-1}\tilde{f}(\tilde{h}_i,\theta,i)\,\Delta t
\tilde{h}_n=h_n
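
A minimal sketch of this multi-step rollout, with a placeholder dynamics model standing in for the trained network; the parameters and step size are illustrative.

import numpy as np

def f_tilde(h, theta, n):
    return np.tanh(theta @ h)                  # placeholder for a trained network

def rollout(h_n, theta, n, k, dt):
    h = np.array(h_n, dtype=float)
    for i in range(n, n + k):
        h = h + f_tilde(h, theta, i) * dt      # h_{i+1} = h_i + f~(h_i, theta, i) dt
    return h

theta = np.eye(2) * 0.5                        # illustrative parameters
h_pred = rollout([1.0, 0.0], theta, n=0, k=10, dt=0.1)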

If our concern were to minimise the single-step loss, we could simply differentiate the loss directly. However, our intention here is to minimise the loss across multiple steps (the multi-step loss). In this case, we start by perturbing the loss as-is:

\delta L=\dfrac{\partial L(\tilde{h}_{n+k},h_{n+k})}{\partial\tilde{h}_{n+k}}\delta\tilde h_{n+k}

State Unfolding Method

We can get the change in state \delta\tilde h_{n+k} by unfolding the state from the predicted state back to the current one. This gives us the following loss gradient, where each step’s parameter sensitivity is propagated to the final step (the power form treats the Jacobian as locally constant):

\dfrac{dL}{d\theta}=\dfrac{\partial L(\tilde{h}_{n+k},h_{n+k})}{\partial\tilde{h}_{n+k}}\sum\limits_{i=n}^{n+k-1}\left(I+\dfrac{\partial\tilde{f}(\tilde{h}_i,\theta,i)}{\partial\tilde{h}_i}\Delta t\right)^{n+k-1-i}\dfrac{\partial\tilde{f}(\tilde{h}_i,\theta,i)}{\partial\theta}\Delta t
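
For a scalar linear toy model \tilde{f}(h,\theta)=\theta h , the unfolded gradient can be written out directly; this sketch uses illustrative values and the constant-Jacobian power form above.

import numpy as np

theta, h0, target, k, dt = 0.3, 1.0, 2.0, 10, 0.1
hs = [h0]
for _ in range(k):
    hs.append(hs[-1] + theta * hs[-1] * dt)    # forward rollout

dLdh = 2.0 * (hs[-1] - target)                 # dL/dh_{n+k} for L = (h - target)^2
grad = 0.0
for i in range(k):
    propagate = (1.0 + theta * dt) ** (k - 1 - i)   # (I + df/dh dt)^{n+k-1-i}
    grad += dLdh * propagate * hs[i] * dt           # df/dtheta at step i is h_i
print(grad)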

Adjoint Sensitivity Method

Alternatively, we can use the well-known adjoint sensitivity method, which leads to the same result but with a computationally efficient formulation. We include the FDE as a constraint using Lagrange multipliers. We shall use multipliers a_i^T\Delta t , in which each a_i^T remains finite as the time-step size goes to zero. For example, if the current time-step is n and we have the actual solution for k steps ahead:

L= L(\tilde{h}_{n+k},h_{n+k})+\sum\limits_{i=n}^{n+k-1}a_i^T\left(\dfrac{\Delta_{\Delta t}\left[\tilde{h}_i\right]}{\Delta t}-\tilde{f}(\tilde{h}_i,\theta,i)\right)\Delta t

We can then perturb the parameters to minimise this loss. The derivation of the loss gradient is not shown here but will be shown for Neural ODEs later. We end up with the result:

\left(\dfrac{dL}{d\theta}\right)_{n+k}=-\sum\limits_{i=n+k-1}^na_i^T\dfrac{\partial\tilde{f}(\tilde{h}_i,\theta,i)}{\partial\theta}\Delta t

a_{n+j}^T=-\dfrac{\partial L(\tilde{h}_{n+k},h_{n+k})}{\partial\tilde{h}_{n+k}}+\sum\limits_{i=n+k-1}^{n+j+1}a_i^T\dfrac{\partial\tilde{f}(\tilde{h}_i,\theta,i)}{\partial\tilde{h}_i}\Delta t
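
The same scalar linear toy model makes the adjoint recursion easy to check against a finite-difference gradient; it reproduces the value from the unfolding sketch above. All values are illustrative.

import numpy as np

def rollout(h0, theta, k, dt):
    hs = [h0]
    for _ in range(k):
        hs.append(hs[-1] + theta * hs[-1] * dt)    # f~(h, theta) = theta h
    return hs

def adjoint_grad(h0, theta, target, k, dt):
    hs = rollout(h0, theta, k, dt)
    a = -2.0 * (hs[-1] - target)                   # a_{n+k-1}^T = -dL/dh_{n+k}
    grad = 0.0
    for i in reversed(range(k)):                   # march backward through time
        grad += -a * hs[i] * dt                    # accumulate -a_i^T df/dtheta dt
        a = a * (1.0 + theta * dt)                 # a_{i-1}^T = a_i^T (I + df/dh dt)
    return grad

h0, theta, target, k, dt = 1.0, 0.3, 2.0, 10, 0.1
g = adjoint_grad(h0, theta, target, k, dt)
L = lambda th: (rollout(h0, th, k, dt)[-1] - target) ** 2
print(g, (L(theta + 1e-6) - L(theta - 1e-6)) / 2e-6)   # the two should agree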

Backpropagation Through Time

In both methods above, a perturbation in the parameters propagates forward through the layers, resulting in a model error at every time-step. These model errors accumulate across time from the initial to the final time-step, resulting in the loss. Conversely, when we backpropagate the loss, we have to go backward through time, accounting for the model error at each time-step.

This is known as backpropagation through time (BPTT), and it is an essential technique for optimising RNNs. BPTT cannot be parallelised in the same way as backpropagation across layers, and relies on other techniques (e.g. the multigrid method).

Increasing Model Order

Using the above model without modification, we can only model first-order FDEs. We can increase the FDE order by including approximate higher-order derivatives into the state. We can compute the initial state by applying numerical differentiation with an endpoint scheme on past samples. The model will predict the approximate highest-order derivative. E.g. for a 3rd order univariate FDE:

\dfrac{\Delta_{\Delta t}[\tilde{h}_n]}{\Delta t}=\dfrac{\Delta_{\Delta t}}{\Delta t}\left[\left(\begin{array}{cc}\tilde{x}_n\\ \Delta_{\Delta t}[\tilde{x}_n]\\ \Delta^2_{\Delta t}[\tilde{x}_n]\end{array}\right)\right]=\left(\begin{array}{cc}\Delta_{\Delta t}[\tilde{x}_n]\\ \Delta^2_{\Delta t}[\tilde{x}_n]\\ \hat{f}(\tilde{h}_n,\theta,n)\end{array}\right)

However, one major issue with the Forward Euler method and all explicit methods is that they may be unstable, meaning that for certain problems, the prediction diverges from the actual solution over time. We can choose a different time-marching method that is stable, but this would require deriving the loss gradient for that particular method.

Ideally, we would like to have an architecture where we only need to derive the loss gradient once for any time-marching method, and a Neural ODE is the answer.

Neural ODEs

Neural ODEs have the same structure as Recurrent ResNets but in the limit where time becomes continuous. Where we utilised an NN to model the approximate first derivative of the state for Recurrent ResNets, we model the exact first derivative for Neural ODEs:

\dfrac{d\tilde{h}(t)}{dt}=\tilde{f}(\tilde{h}(t),\theta,t)

If we know the state at time t_0, we can perform extrapolation to obtain the state at time t_1 by solving an IVP:

\tilde{h}(t_1)=h(t_0)+\int_{t_0}^{t_1}\tilde{f}(\tilde{h}(t),\theta,t)\,dt
\tilde{h}(t_0)=h(t_0)
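
A minimal sketch of this forward pass, with a tiny randomly initialised network standing in for \tilde{f} and SciPy’s solve_ivp performing the extrapolation; all sizes and values are illustrative.

import numpy as np
from scipy.integrate import solve_ivp

rng = np.random.default_rng(0)
W1, b1 = rng.normal(0, 0.5, (8, 3)), np.zeros(8)   # input: state (2) plus time (1)
W2, b2 = rng.normal(0, 0.5, (2, 8)), np.zeros(2)

def f_tilde(t, h):
    z = np.concatenate([h, [t]])               # explicit time dependence
    return W2 @ np.tanh(W1 @ z + b1) + b2

h0 = np.array([1.0, 0.0])
sol = solve_ivp(f_tilde, (0.0, 1.0), h0, method="RK45")
print(sol.y[:, -1])                            # h~(t1)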

An ODE solver can compute this integral numerically and automatically. However, the main problem is that we need a method to update the parameters based on the loss, i.e. we need an expression for the loss gradient \dfrac{dL}{d\theta} based on quantities we can feasibly compute.

Imagine that we know the state at t_0 and we have the actual solution at t_1. Let us first try to minimise the loss without imposing any constraints. Firstly, we need to find out how the loss changes when we perturb the parameters slightly by \delta\theta:

\delta L=\dfrac{\partial L}{\partial\tilde{h}(t_1)}\delta\tilde{h}(t_1)

In this expression, we know that \delta\tilde{h}(t_1) is related to \delta\theta, but if we attempt to express the relation, we end up with an infinite recursion, because a perturbation in the parameters propagates through infinitely many time-steps before affecting the loss.

State Unrolling Method

What we can do instead is take the equations from the unfolding method in the discrete case, and take the limit as the time-step size goes to 0. What we get is a continuous unfolding of the state, which we can refer to as unrolling. This leads to some nasty limits in the calculation but an ultimately simple result:

\dfrac{dL}{d\theta}(t_1)=\dfrac{\partial L}{\partial\tilde{h}(t_1)}\int_{t_0}^{t_1}\exp\left((t_1-t)\dfrac{\partial\tilde{f}}{\partial\tilde{h}}(\tilde{h}(t),\theta,t)\right)\dfrac{\partial\tilde{f}}{\partial\theta}(\tilde{h}(t),\theta,t)\,dt

As we can see, this method requires us to compute matrix exponentials, which can be computationally expensive.

Adjoint Sensitivity Method

Let us use Lagrange multipliers instead. Our constraint will be the ODE itself, which acts as a connection between neighbouring time-steps, forming something akin to a chain we can perform BPTT on.

We have infinitely many time-steps between t_0 and t_1, and we need to impose the constraint on every time-step. Hence, we shall define our multiplier to be an infinitesimal function of t and integrate over the constraints. Since each constraint is a column vector and our objective is a scalar, we shall let our multiplier be a row vector, a^T(t)\,dt, and left-multiply it to the constraint:

L=L(\tilde{h}(t_1),h(t_1))+\int_{t_0}^{t_1}a^T(t)\left(\dfrac{d\tilde{h}(t)}{dt}-\tilde{f}(\tilde{h}(t),\theta,t)\right)\,dt

We introduce a small perturbation to the parameters to get a small change in the loss:

\delta L=\dfrac{\partial L}{\partial\tilde{h}(t_1)}\delta\tilde{h}(t_1)+\int_{t_0}^{t_1}a^T(t)\left(\dfrac{d\delta\tilde{h}(t)}{dt}-\dfrac{\partial\tilde{f}}{\partial\tilde{h}}(\tilde{h}(t),\theta,t)\delta\tilde{h}(t)-\dfrac{\partial\tilde{f}}{\partial\theta}(\tilde{h}(t),\theta,t)\delta\theta\right)\,dt

If we choose a^T(t) wisely, we can eliminate terms with \delta\tilde{h}(t\ne t_0), thereby decoupling the time-dependence from the loss, so that the loss only depends on the parameters. For this reason, a^T(t) is referred to as the adjoint state.

\dfrac{d\delta\tilde{h}(t)}{dt} is a problematic term to work with, so we need to convert it into \delta\tilde{h}(t) by integrating \int_{t_0}^{t_1}a^T(t)\dfrac{d\delta\tilde{h}(t)}{dt}\,dt by parts:

\int_{t_0}^{t_1}a^T(t)\dfrac{d\delta\tilde{h}(t)}{dt}\,dt=a^T(t_1)\delta\tilde{h}(t_1)-a^T(t_0)\delta\tilde{h}(t_0)-\int_{t_0}^{t_1}\dfrac{da^T(t)}{dt}\delta\tilde{h}(t)\,dt

Note that \delta\tilde{h}(t_0)=0 since the initial state is not affected by a perturbation of the parameters. Substituting this result into the original expression:

\delta L=\dfrac{\partial L}{\partial\tilde{h}(t_1)}\delta\tilde{h}(t_1)+a^T(t_1)\delta\tilde{h}(t_1)-
\int_{t_0}^{t_1}\left(\dfrac{da^T(t)}{dt}\delta\tilde{h}(t)+a^T(t)\dfrac{\partial\tilde{f}}{\partial\tilde{h}}(\tilde{h}(t),\theta,t)\delta\tilde{h}(t)+a^T(t)\dfrac{\partial\tilde{f}}{\partial\theta}(\tilde{h}(t),\theta,t)\delta\theta\right)\,dt

Within this equation, we can require that a^T(t) satisfy certain conditions, so that we separate the loss’ dependence on state and time into these independent equations:

\dfrac{\partial L}{\partial\tilde{h}(t_1)}\delta\tilde{h}(t_1)+a^T(t_1)\delta\tilde{h}(t_1)=0

\dfrac{da^T(t)}{dt}\delta\tilde{h}(t)+a^T(t)\dfrac{\partial\tilde{f}}{\partial\tilde{h}}(\tilde{h}(t),\theta,t)\delta\tilde{h}(t)=0

I will come back to these requirements in a moment. The expression for the change in the loss then simplifies to the following equation, which is only dependent on the model parameters:

\delta L=-\int_{t_0}^{t_1}a^T(t)\dfrac{\partial\tilde{f}}{\partial\theta}(\tilde{h}(t),\theta,t)\delta\theta\,dt

Since our perturbation \delta\theta is independent of time, we can bring it out of the integral. As the perturbation is otherwise arbitrary, its coefficient must be the loss gradient. The resulting equation is the adjoint equation of this problem:

\dfrac{dL}{d\theta}(t_1)=-\int_{t_0}^{t_1}a^T(t)\dfrac{\partial\tilde{f}}{\partial\theta}(\tilde{h}(t),\theta,t)\,dt

We now have the time-derivative of the loss gradient, which is the integrand above. Going back to the requirements we set earlier, if we let a^T satisfy the following properties, the requirements will always be satisfied:

a^T(t_1)=-\dfrac{\partial L}{\partial\tilde{h}(t_1)}
\dfrac{da^T(t)}{dt}=-a^T(t)\dfrac{\partial\tilde{f}}{\partial\tilde{h}}(\tilde{h}(t),\theta,t)

Notice that the second equation is an ODE and the first equation is the final condition of a^T. However, we can reverse time, and the first equation becomes the initial condition. We can then solve for a^T(t) using an ODE solver by marching backward in time.

a^T(t)=-\dfrac{\partial L}{\partial\tilde{h}(t_1)}-\int_{t_1}^{t}a^T(s)\dfrac{\partial\tilde{f}}{\partial\tilde{h}}(\tilde{h}(s),\theta,s)\,ds

Also notice that the expression for the loss gradient is an integral that we can solve for as long as we have the adjoint state; it does not matter which direction we traverse time in. We can traverse backwards by reversing the limits of integration:

\dfrac{dL}{d\theta}(t_1)=\int_{t_1}^{t_0}a^T(t)\dfrac{\partial\tilde{f}}{\partial\theta}(\tilde{h}(t),\theta,t)\,dt

This way, we can have our ODE solver solve for the adjoint state and loss gradient by combining them as an augmented state, then marching backwards in time (a code sketch follows the list):

  • ODE: \dfrac{d}{dt}\left(\begin{array}{cc}a^T\\ \dfrac{dL}{d\theta}\end{array}\right)(t)=\left(\begin{array}{cc}-a^T\dfrac{\partial\tilde{f}}{\partial\tilde{h}}\\ a^T\dfrac{\partial\tilde{f}}{\partial\theta}\end{array}\right)(t)
  • Initial Value: \left(\begin{array}{cc}-\dfrac{\partial L}{\partial\tilde{h}(t_1)}\\ 0\end{array}\right)
  • Start Time: t_1
  • End Time: t_0
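
Here is a minimal sketch of that backward pass for a scalar model \tilde{f}(h,\theta,t)=\tanh(\theta h), whose Jacobians are easy to write by hand; a dense forward solution supplies \tilde{h}(t) during the backward solve, and all values are illustrative.

import numpy as np
from scipy.integrate import solve_ivp

theta, h0, t0, t1, target = 0.8, 1.0, 0.0, 1.0, 2.0

def f_tilde(t, h):
    return np.tanh(theta * h)

# forward pass: keep a dense interpolant of h~(t) for the backward solve
fwd = solve_ivp(f_tilde, (t0, t1), [h0], dense_output=True)
dLdh1 = 2.0 * (fwd.y[0, -1] - target)          # dL/dh~(t1) for L = (h~(t1) - target)^2

def augmented(t, s):
    a = s[0]                                   # s = [a^T, dL/dtheta]
    h = fwd.sol(t)[0]
    sech2 = 1.0 / np.cosh(theta * h) ** 2
    return [-a * theta * sech2,                # da^T/dt = -a^T df~/dh
            a * h * sech2]                     # d(dL/dtheta)/dt = a^T df~/dtheta

bwd = solve_ivp(augmented, (t1, t0), [-dLdh1, 0.0])   # march backward in time
print(bwd.y[1, -1])                            # dL/dtheta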

Ultimately, we can see that the purpose of the adjoint state is to isolate the entire time-dependence of the loss as an explicit ODE we can solve for. In this way, backpropagation through the layers reduces to the local vector-Jacobian products a^T\dfrac{\partial\tilde{f}}{\partial\tilde{h}} and a^T\dfrac{\partial\tilde{f}}{\partial\theta} evaluated at each solver step; we never need to backpropagate through the solver’s own operations.

The results here are almost identical to those for the Recurrent ResNet, except that the loss gradient carries no negative sign. The reason is that when we reverse the limits of integration, dt changes sign; whereas for the Recurrent ResNet, since \Delta t was defined explicitly, the sign does not change when we reverse the order of summation.

Increasing Model Order

As with Recurrent ResNets, the ODE order can be increased by including higher-order derivatives in the state. Since the time-steps may be irregular, the endpoint scheme has to be recomputed for each sample; for a kth-order ODE, this requires solving a k\times k matrix equation per sample. Needless to say, this introduces discretisation errors. E.g. for a 3rd-order univariate ODE:

\dfrac{d\tilde{h}(t)}{dt}=\dfrac{d}{dt}\left(\begin{array}{cc}\tilde{x}(t)\\ \dot{\tilde{x}}(t)\\ \ddot{\tilde{x}}(t)\end{array}\right)=\left(\begin{array}{cc}\dot{\tilde{x}}(t)\\ \ddot{\tilde{x}}(t)\\ \hat{f}(\tilde{h}(t),\theta,t)\end{array}\right)
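
A minimal sketch of this third-order state augmentation, with a placeholder standing in for the learned network \hat{f}; the parameters are illustrative.

import numpy as np
from scipy.integrate import solve_ivp

def f_hat(h, theta, t):
    return np.tanh(theta @ h)                  # placeholder for the learned network

def dynamics(t, h, theta):
    x, xd, xdd = h                             # h = [x, dx/dt, d2x/dt2]
    return [xd, xdd, f_hat(h, theta, t)]       # only the highest derivative is learned

theta = np.array([0.1, -0.2, 0.05])
sol = solve_ivp(dynamics, (0.0, 1.0), [1.0, 0.0, 0.0], args=(theta,))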

I might attempt to apply Neural ODEs to time-series forecasting in due course and hopefully observe interesting results.

Conclusion

The mathematics involved is no more complicated than that of solving ODEs. Applying Lagrange multipliers to the loss function may be rather messy, but the idea itself is simple. The adjoint sensitivity method may be difficult to understand, but it is incredibly elegant in transforming a problem with infinite complexity into a form that we can work with.

The main analogue between Recurrent ResNets and Neural ODEs is that:

  • For Recurrent ResNets, we first discretise the ODE, then differentiate the resulting equations for BPTT (discretise-then-optimise).
  • For Neural ODEs, we first differentiate the continuous-time equations, then discretise the resulting adjoint equations for BPTT (optimise-then-discretise).

This allows the adjoint equation to be agnostic of the discretisation method used. If stability is desired, an implicit method can be used. If stiffness of the ODE is a concern, a dedicated stiff solver (e.g. a BDF method) can be used; if the dynamics conserve energy, a symplectic integrator is an option.
