Structured Linear CDEs - Part 1
Introduction
This is the first in a three-part blog series on Structured Linear Controlled Differential Equations (SLiCEs), a new framework for sequence models that balance maximal expressivity with efficient parallel-in-time computation by using input-dependent structured state-transition matrices. These models were introduced in our recent preprint. This series of blog posts will cover:
- The What: Understanding what SLiCEs are and how they fit within the broader sequence modelling landscape.
- The Why: Exploring their significance from theoretical and empirical perspectives.
- The How: A practical implementation in JAX.
Let's dive into Part 1 - The What.
Time Series Modelling
Given observations from a time series \( \{x_i\}_{i=1}^n \), a generic shallow learning method involves selecting a function \( g \) such that
\[ h_i = g(x_1, x_2, \ldots, x_{i-1}), \]
and making predictions via a fitted linear map \( C_\theta \),
\[ y_i \approx C_\theta h_i. \]
Q: Theoretically, can we match \( y_i \) as closely as we want?
A: Yes.
Spoiler: let \( g \) compute the signature; see here.
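As a concrete illustration of the shallow approach, here is a minimal JAX sketch of \( g \) as the signature truncated at depth 2. The full signature is an infinite collection of iterated integrals, so the truncation depth is our assumption. The sketch treats the observations as points of a piecewise-linear path and adds one linear segment at a time using Chen's identity; `depth2_signature` is a hypothetical helper, not a library function.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper; a sketch, not a reference implementation.


def depth2_signature(points):
    """Signature, truncated at depth 2, of the piecewise-linear path through `points`.

    points: array of shape (m, d), e.g. the observations x_1, ..., x_{i-1}.
    Returns the level-1 terms (d,) followed by the flattened level-2 terms (d * d,).
    """
    d = points.shape[1]
    increments = jnp.diff(points, axis=0)  # (m - 1, d)

    def chen_step(carry, delta):
        s1, s2 = carry
        # Chen's identity: append a linear segment whose signature is (1, delta, delta ⊗ delta / 2).
        s2 = s2 + jnp.outer(s1, delta) + 0.5 * jnp.outer(delta, delta)
        s1 = s1 + delta
        return (s1, s2), None

    init = (jnp.zeros(d), jnp.zeros((d, d)))
    (s1, s2), _ = jax.lax.scan(chen_step, init, increments)
    return jnp.concatenate([s1, s2.ravel()])


# An L-shaped path: one step right, then one step up.
points = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]])
print(depth2_signature(points))  # level 1: [1, 1]; level 2: [0.5, 1, 0, 0.5]
```

The two off-diagonal level-2 entries (1 and 0) record that the path moved right before moving up; reversing that order swaps them. Fitting a linear map \( C_\theta \) on features like these gives the shallow model described above.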
However, the signature captures everything about the path, whereas most problems depend on only a small part of that information. In order to tell if someone's ECG means they are dying, you don’t need to know if their heart rate matches the beat of The Imperial March. Therefore, it usually makes sense to learn a \( g_\theta \) that extracts only the information useful for the task at hand.
The general deep learning approach is to learn a function \( g_\theta \) and linear map \( C_\theta \) such that
\[ \begin{aligned} h_i &= g_\theta(x_1, x_2, \ldots, x_{i-1}), \\ y_i &\approx C_\theta h_i. \end{aligned} \]
Linear Recurrent Neural Networks
Recurrent Neural Networks (RNNs) take the generic form
\[ h_{i+1} = h_i + g_\theta(h_i, x_i), \]
and linear RNNs further restrict to
\[ h_{i+1} = h_i + A_\theta h_i + B_\theta x_i, \]
where \( A_\theta \in \mathbb{R}^{d_h \times d_h} \) and \( B_\theta \in \mathbb{R}^{d_h \times d_x} \). Due to their similarity to classical state-space models, a certain class of linear RNNs is known as structured state-space models (SSMs). Because \( A_\theta \) and \( B_\theta \) do not depend on the input, this simple recurrence can be parallelised via associative scans or convolutions. However, linear RNNs are not particularly expressive, as will be explored further in Part 2 - The Why.
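Because \( A_\theta \) and \( B_\theta \) do not depend on the input, unrolling the recurrence from \( h_1 = 0 \) gives \( h_{i+1} = \sum_{t=1}^{i} (I + A_\theta)^{i-t} B_\theta x_t \), which is a causal convolution with kernel \( K_m = (I + A_\theta)^m B_\theta \). The minimal JAX sketch below checks this against the sequential recurrence. The zero initial state, toy dimensions, and function names are illustrative assumptions, and the convolution is written as a dense quadratic sum for readability rather than the FFT-based form an efficient implementation would more likely use.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper names; a sketch, not a reference implementation.


def linear_rnn_recurrent(A, B, xs):
    """Sequential form h_{i+1} = h_i + A h_i + B x_i, starting from a zero state.

    hs[i] is the state after reading xs[0], ..., xs[i].
    """
    def step(h, x):
        h_next = h + A @ h + B @ x
        return h_next, h_next

    _, hs = jax.lax.scan(step, jnp.zeros(A.shape[0]), xs)
    return hs


def linear_rnn_convolution(A, B, xs):
    """Same states via the causal kernel K_m = (I + A)^m B."""
    n = xs.shape[0]
    M = jnp.eye(A.shape[0]) + A

    def next_kernel(K, _):
        return M @ K, K  # emit M^m B, carry M^{m+1} B

    _, kernels = jax.lax.scan(next_kernel, B, None, length=n)  # (n, d_h, d_x)
    # hs[i] = sum_{t <= i} K_{i - t} xs[t]; dense O(n^2) form for readability.
    lag = jnp.arange(n)[:, None] - jnp.arange(n)[None, :]
    K_full = kernels[jnp.maximum(lag, 0)]  # (n, n, d_h, d_x)
    causal = (lag >= 0)[:, :, None, None]
    return jnp.einsum("ithx,tx->ih", jnp.where(causal, K_full, 0.0), xs)


key_a, key_b, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
d_h, d_x, n = 4, 3, 16
A = -0.1 * jnp.eye(d_h) + 0.05 * jax.random.normal(key_a, (d_h, d_h))
B = jax.random.normal(key_b, (d_h, d_x))
xs = jax.random.normal(key_x, (n, d_x))
print(jnp.max(jnp.abs(linear_rnn_recurrent(A, B, xs) - linear_rnn_convolution(A, B, xs))))
```

The two forms agree up to floating-point round-off. The convolutional view relies on the same matrix \( I + A_\theta \) being applied at every step, which is exactly what input-dependent transitions give up.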
Linear Neural Controlled Differential Equations
Maximal expressivity (also known as universality) can be achieved by allowing the state-transition matrix to depend linearly on \( x_i \),
\[ h_{i+1} = h_i + \sum_{j=1}^{d_x} A^j_\theta h_i x^j_i + B_\theta x_i. \]
Ignoring the \( B_\theta x_i \) term, this is a discretisation of a linear neural controlled differential equation (LNCDE),
\[ \mathrm{d}h_s = \sum_{j=1}^{d_x} A^j_\theta h_s \, \mathrm{d}\omega^{x, j}_s, \]
where \( \omega^{x,j}_{i+1} - \omega^{x,j}_i = x^j_i \). LNCDEs are maximally expressive (see here) and can still be computed using parallel associative scans. However, their cost and parameter count grow as \( \mathcal{O}(d_\omega d_h^2) \), where \( d_\omega \) is the dimension of \( \omega^x \) (here \( d_\omega = d_x \)), which limits their scalability.
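The discrete recurrence is an affine map \( h_{i+1} = M_i h_i + c_i \) with \( M_i = I + \sum_{j} A^j_\theta x^j_i \) and \( c_i = B_\theta x_i \). The composition of two affine maps is another affine map, and composition is associative, which is what allows a parallel associative scan even though \( M_i \) now changes at every step. Below is a minimal JAX sketch, assuming dense \( A^j_\theta \) stacked into one array of shape \( (d_x, d_h, d_h) \), that compares the sequential recurrence with `jax.lax.associative_scan`; the helper names are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper names; a sketch assuming dense A^j, not a reference implementation.


def transitions(A, B, xs):
    """Affine step maps h -> M_i h + c_i, with M_i = I + sum_j A^j x_i^j and c_i = B x_i.

    A: (d_x, d_h, d_h) stacks the matrices A^j; B: (d_h, d_x); xs: (n, d_x).
    """
    Ms = jnp.eye(A.shape[1]) + jnp.einsum("jab,nj->nab", A, xs)
    cs = xs @ B.T
    return Ms, cs


def lncde_sequential(A, B, xs, h0):
    def step(h, Mc):
        M, c = Mc
        h_next = M @ h + c
        return h_next, h_next

    _, hs = jax.lax.scan(step, h0, transitions(A, B, xs))
    return hs


def compose(earlier, later):
    """Apply `earlier`, then `later`; works on the batched leading axes associative_scan passes in."""
    M1, c1 = earlier
    M2, c2 = later
    return M2 @ M1, jnp.einsum("...ab,...b->...a", M2, c1) + c2


def lncde_parallel(A, B, xs, h0):
    # Prefix compositions (M_i ... M_1, accumulated offsets) with O(log n) depth.
    Ps, qs = jax.lax.associative_scan(compose, transitions(A, B, xs))
    return jnp.einsum("nab,b->na", Ps, h0) + qs


keys = jax.random.split(jax.random.PRNGKey(0), 4)
d_h, d_x, n = 4, 3, 32
A = 0.05 * jax.random.normal(keys[0], (d_x, d_h, d_h))
B = jax.random.normal(keys[1], (d_h, d_x))
xs = jax.random.normal(keys[2], (n, d_x))
h0 = jax.random.normal(keys[3], (d_h,))
print(jnp.max(jnp.abs(lncde_sequential(A, B, xs, h0) - lncde_parallel(A, B, xs, h0))))
```

Both functions return the same states up to round-off. Building each \( M_i \) costs \( \mathcal{O}(d_x d_h^2) \), and every composition inside the scan is a dense \( d_h \times d_h \) matrix product, which is where the cost of dense LNCDEs comes from.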
Quick Tangent on Names
Since this modification can equivalently be viewed as making \( A_\theta \) at each time step a linear function of the input, these models are also known as input-dependent linear recurrent neural networks, and a large number of alternative names are listed in Table 2 of this paper. However, a more accurate name would be linear multiplicative RNNs or linear 2-RNNs: multiplicative RNNs and 2-RNNs are non-linear RNNs in which the matrix multiplying the hidden state depends linearly on the input, and these models are their linear counterparts. See here and here for examples. Personally, I'll stick to linear NCDEs, as being a mathematician inclines me towards viewing these as discretisations of continuous models.
Structured Linear Neural Controlled Differential Equations
Mamba is an LNCDE with diagonal state-transition matrices. However, diagonal \( A^j_\theta \) are not maximally expressive, and empirically they fail to length-generalise on state-tracking tasks. SLiCEs are a framework of input-dependent structured state-transition matrices that preserve maximal expressivity whilst being cheaper than dense matrices.
Our paper proposes several SLiCEs:
- DPLR-LNCDEs: Take \( A^j_\theta = D^j_\theta + \sum_{l=1}^r u^{j,l}_\theta (v^{j,l}_\theta)^\top \), where \( D^j_\theta \) is diagonal. Examples include DeltaNet, DeltaProduct, and Gated DeltaNet.
- BD-LNCDEs: Take \( A^j_\theta = \mathrm{BlockDiag}(B^j_{\theta,1}, \ldots, B^j_{\theta,k}) \). The block-diagonal input-dependent linear RNN is an example.
- S-LNCDEs: Take each \( A^j_\theta \) to be a sparse matrix with \( \mathcal{O}(d_h^{1 + \epsilon}) \) non-zero entries for some \( \epsilon > 0 \).
- WH-LNCDEs: Take \( A^j_\theta = H D^j_\theta \), where \( H \) is a Hadamard matrix (entries \( \pm 1 \), mutually orthogonal rows) and \( D^j_\theta \) is diagonal.
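The sketch below shows, in JAX, how a single channel's matrix \( A^j_\theta \) could be assembled under each of the four structures, with random values standing in for learned parameters. The helper names, block sizes \( (2, 2, 4) \), \( \epsilon = 0.5 \), the random sparsity pattern, and the Sylvester construction of \( H \) (which requires \( d_h \) to be a power of two) are illustrative assumptions rather than the paper's configuration, and the sparse and Walsh–Hadamard matrices are stored densely only for readability.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import block_diag

# Hypothetical helper names and sizes; illustrative only.


def dplr_matrix(diag, U, V):
    """Diagonal plus rank-r: diag(diag) + sum_l u_l v_l^T, where U and V have shape (d_h, r)."""
    return jnp.diag(diag) + U @ V.T


def sparse_matrix(key, values, d_h):
    """Scatters `values` into a fixed random set of distinct positions (stored densely here)."""
    flat_idx = jax.random.choice(key, d_h * d_h, shape=values.shape, replace=False)
    return jnp.zeros(d_h * d_h).at[flat_idx].set(values).reshape(d_h, d_h)


def hadamard(d_h):
    """Sylvester construction of a Hadamard matrix; d_h must be a power of two."""
    H = jnp.ones((1, 1))
    while H.shape[0] < d_h:
        H = jnp.block([[H, H], [H, -H]])
    return H


def walsh_hadamard_matrix(diag):
    """H D with D = diag(diag); only the diagonal is learnable."""
    return hadamard(diag.shape[0]) @ jnp.diag(diag)


d_h, r, block_sizes, eps = 8, 2, (2, 2, 4), 0.5
nnz = round(d_h ** (1 + eps))  # O(d_h^{1 + eps}) non-zero entries
keys = jax.random.split(jax.random.PRNGKey(0), 7)

diag_dplr = jax.random.normal(keys[0], (d_h,))
A_dplr = dplr_matrix(diag_dplr, jax.random.normal(keys[1], (d_h, r)), jax.random.normal(keys[2], (d_h, r)))
block_keys = jax.random.split(keys[3], len(block_sizes))
A_bd = block_diag(*[jax.random.normal(k, (b, b)) for k, b in zip(block_keys, block_sizes)])
A_sparse = sparse_matrix(keys[4], jax.random.normal(keys[5], (nnz,)), d_h)
A_wh = walsh_hadamard_matrix(jax.random.normal(keys[6], (d_h,)))

# Learnable parameters per channel; multiply by d_omega for a full layer.
param_counts = {"DPLR": d_h + 2 * r * d_h, "BD": sum(b * b for b in block_sizes), "S": nnz, "WH": d_h}
print(param_counts)  # {'DPLR': 40, 'BD': 24, 'S': 23, 'WH': 8}
```

For comparison, a dense \( A^j_\theta \) with \( d_h = 8 \) has 64 entries. Multiplying these per-channel counts by \( d_\omega \) gives the parameter counts compared in the next section.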
All four of these choices have been shown to be maximally expressive: DPLR-LNCDEs in this paper and the other three in our paper. In all likelihood, there are other choices with the same theoretical guarantees, and maybe better empirical results, so please reach out if you have any ideas!
Comparison of SLiCEs
The figure at the start of this blog post is a visual comparison of dense LNCDEs (DE-LNCDEs), diagonal LNCDEs (D-LNCDEs), diagonal-plus-low-rank LNCDEs (DPLR-LNCDEs), sparse LNCDEs (S-LNCDEs), Walsh–Hadamard LNCDEs (WH-LNCDEs), and block-diagonal LNCDEs (BD-LNCDEs). The table below compares the models on parameter count, computational cost, and whether they are maximally expressive (Max. Exp.). Here, \( d_h \) is the hidden dimension, \( n \) is the sequence length, \( b_j \) are BD-LNCDE's block sizes, \( r \) is DPLR-LNCDE's rank, \( \epsilon \) sets S-LNCDE's sparsity, and for simplicity we have taken \( d_\omega = d_h \). Parallel cost is given as \( \mathcal{O}(\text{scan depth}, \text{cost per composition}) \) when applying a parallel associative scan.
| Model | Parameters | Recurrent Cost | Parallel Cost | Max. Exp. |
|---|---|---|---|---|
| DE-LNCDEs | \( \mathcal{O}(d_h^3) \) | \( \mathcal{O}(n d_h^3) \) | \( \mathcal{O}(\log(n), d_h^3) \) | Yes |
| D-LNCDEs | \( \mathcal{O}(d_h^2) \) | \( \mathcal{O}(n d_h^2) \) | \( \mathcal{O}(\log(n), d_h^2) \) | No |
| DPLR-LNCDEs | \( \mathcal{O}(r d_h^2) \) | \( \mathcal{O}(n r d_h^2) \) | \( \mathcal{O}(\log(n), d_h^3) \) | Yes |
| S-LNCDEs | \( \mathcal{O}(d_h^{2 + \epsilon}) \) | \( \mathcal{O}(n d_h^{2 + \epsilon}) \) | \( \mathcal{O}(\log(n), d_h^3) \) | Yes |
| WH-LNCDEs | \( \mathcal{O}(d_h^2) \) | \( \mathcal{O}(n d_h^2) \) | \( \mathcal{O}(\log(n), d_h^3) \) | Yes |
| BD-LNCDEs | \( \mathcal{O}(d_h \sum_j b_j^2) \) | \( \mathcal{O}(n d_h \sum_j b_j^2) \) | \( \mathcal{O}(\log(n), d_h \sum_j b_j^2) \) | Yes |
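The Parallel Cost column depends on what happens when the scan composes two transition matrices. The minimal JAX sketch below, with random matrices standing in for learned ones, checks two cases: the product of two block-diagonal matrices (here four \( 1 \times 1 \) blocks followed by one dense \( 4 \times 4 \) block) is again block-diagonal and can be computed block by block, whereas each product of diagonal-plus-rank-1 matrices adds another rank-1 term to the low-rank correction. The block sizes, scales, and the hypothetical `numerical_rank` and `random_blocks` helpers (including the relative threshold) are our own choices.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import block_diag


def numerical_rank(X, rel_tol=1e-4):
    """Hypothetical helper: counts singular values above rel_tol times the largest one."""
    s = jnp.linalg.svd(X, compute_uv=False)  # sorted in descending order
    return int(jnp.sum(s > rel_tol * s[0]))


def random_blocks(key, sizes):
    """Hypothetical helper: one random square block per size."""
    return [jax.random.normal(k, (b, b)) for k, b in zip(jax.random.split(key, len(sizes)), sizes)]


key_x, key_y, key_dplr = jax.random.split(jax.random.PRNGKey(0), 3)

# Block-diagonal: a mostly diagonal structure followed by one small dense block.
sizes = (1, 1, 1, 1, 4)
X_blocks, Y_blocks = random_blocks(key_x, sizes), random_blocks(key_y, sizes)
product = block_diag(*X_blocks) @ block_diag(*Y_blocks)
blockwise = block_diag(*[x @ y for x, y in zip(X_blocks, Y_blocks)])  # O(sum_j b_j^3) flops
print(jnp.allclose(product, blockwise, atol=1e-5))  # structure is preserved under composition

# Diagonal-plus-rank-1: every composition adds another rank-1 term to the correction.
d_h = 8
P, D_prod = jnp.eye(d_h), jnp.eye(d_h)
ranks = []
for key in jax.random.split(key_dplr, 4):
    k_d, k_u, k_v = jax.random.split(key, 3)
    D = jnp.diag(1.0 + 0.1 * jax.random.normal(k_d, (d_h,)))
    u = 0.3 * jax.random.normal(k_u, (d_h, 1))
    v = 0.3 * jax.random.normal(k_v, (d_h, 1))
    P = (D + u @ v.T) @ P  # composed transition so far
    D_prod = D @ D_prod    # product of the diagonal parts alone
    ranks.append(numerical_rank(P - D_prod))
print(ranks)
```

For generic parameters the printed ranks grow by one per composition, so after roughly \( d_h \) steps of the scan the composed DPLR matrix is effectively dense and each further composition costs \( \mathcal{O}(d_h^3) \), matching the table.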
Parallel Computation
As can be seen, block-diagonal LNCDEs are the only maximally expressive structure whose parameter count, recurrent cost, and parallel associative scan cost are all strictly lower than those of dense LNCDEs. This is because block-diagonal matrices are closed under matrix multiplication, whereas the other maximally expressive structures are not. For large hidden dimensions, parallel associative scans can incur high I/O costs, reducing their practical benefit. DeltaNet avoids this by using a chunk-wise algorithm specifically tailored to diagonal-plus-low-rank matrices; see this blog post for details. Given that these chunk-wise algorithms can be applied to diagonal matrices, a block-diagonal LNCDE with a predominantly diagonal structure (\( b_i = 1 \) for \( i = 1, \ldots, k-1 \)) followed by a single small dense block (\( b_k = b \)) emerges as an attractive solution when large hidden states are necessary. This structure can leverage an efficient chunk-wise algorithm for the diagonal portion and parallel associative scans for the small dense portion, whilst still significantly boosting expressivity, as will be explored further in Part 2 - The Why.
Next Time
Now that we've established what SLiCEs are, the next post will explore why they are powerful, covering both theoretical and empirical results.