and its state-space models
Recently, a plethora of machine learning methods have made use of state-space models (SSMs) to model sequences.
is associative
The (parallel) associative scan, as the name suggests, is a way to apply an associative operator in parallel. The simplest such operator is addition, which is associative because we can change the order of computation (i.e. $a + (b + c) = (a + b) + c$). In JAX the associative scan is implemented in the aptly named lax.associative_scan function, while in Julia the corresponding functionality is provided by the accumulate function. If we supply the add operator, the resulting computation is equivalent to the cumulative sum, e.g. in Python
>>> lax.associative_scan(jnp.add, jnp.arange(0, 4))
Array([0, 1, 3, 6], dtype=int32)
or in Julia
julia> accumulate(+, 0:3)
4-element Vector{Int64}:
0
1
3
6
While the above computation is trivial and easily performed by looping from the first to the last element of the vector, the main idea of the associative scan is that the operator can be applied pairwise in parallel. This means that, given enough processors, the computation can be done in $O(\log n)$ time rather than the $O(n)$ of the sequential implementation.
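To make the pairwise idea concrete, here is a minimal pure-Python sketch of a recursive inclusive scan (the function name and structure are illustrative, not JAX's actual implementation): combine adjacent pairs, scan the half-length sequence, then fill in the even positions. The combinations at each level are independent of one another, so with enough processors the depth of the computation is $O(\log n)$.

```python
import operator

def assoc_scan(op, xs):
    """Inclusive scan of xs under the associative operator op."""
    n = len(xs)
    if n < 2:
        return list(xs)
    # combine adjacent pairs -- these ops are independent,
    # so on parallel hardware they all run simultaneously
    pairs = [op(xs[i], xs[i + 1]) for i in range(0, n - 1, 2)]
    # recursively scan the half-length sequence (O(log n) levels in total)
    sub = assoc_scan(op, pairs)
    # interleave: odd positions are done; even positions need one more op
    out = [xs[0]]
    for i in range(1, n):
        out.append(sub[i // 2] if i % 2 == 1 else op(sub[i // 2 - 1], xs[i]))
    return out

print(assoc_scan(operator.add, [0, 1, 2, 3]))  # [0, 1, 3, 6]
```

Note that the sketch never reorders operands, only regroups them, so it works for any associative operator, commutative or not.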
The first step in showing the associativity of the state-space model is to express the state transition using matrix multiplication (which is associative) by embedding the transition into a larger matrix $\bs_k$ as follows
\[\bs_k = \begin{bmatrix} \bA_k & \bB_k \bx_k \\ \bzero & 1 \end{bmatrix}, \quad \bs_0 = \begin{bmatrix} \bzero & \bh_0 \\ \bzero & 1 \end{bmatrix}\]Using the definition of $\bs_k$, the transition from state $k-1$ to state $k$ can be computed using matrix multiplication as
\[\bs_k\bs_{k-1} = \begin{bmatrix} \bA_k\bA_{k-1} & \bA_k(\bB_{k-1}\bx_{k-1}) + \bB_k\bx_k \\ \bzero & 1 \end{bmatrix}.\]Using this we can compute the $i$th state of the state-space model as
\[\begin{equation} \bh_i = \begin{bmatrix}\bI & \bzero \end{bmatrix} \left(\prod_{k=i}^0 \bs_k\right) \begin{bmatrix}\bzero \\ 1\end{bmatrix}. \end{equation}\]Given that the cumulative product can be computed using the associative scan operator, the full dynamics can be computed as
\[\begin{equation} \begin{aligned} \bp_i &= \text{associative_scan}(\bs_i, \text{init} = \bs_0)\\ \by_i &= \bC \bh_i = \begin{bmatrix}\bC & \bzero \end{bmatrix} \bp_i \begin{bmatrix}\bzero \\ 1\end{bmatrix}. \end{aligned} \end{equation}\]While the above works, it can be simplified slightly. As we are really only interested in the top row (the top right block contains $\bh_i$), we can instead represent each element by just its top row, i.e. define the states as $\bs_k = \begin{bmatrix}\bA_k & \bB_k \bx_k \end{bmatrix}$ and then define the associative operator (denoted by $\bullet$) by how the top row propagates, i.e.
\[\begin{equation} \bs_k \bullet \bs_{k-1} = \begin{bmatrix} \bA_k \bA_{k-1} & \bA_k (\bB_{k-1} \bx_{k-1}) + \bB_k \bx_k\end{bmatrix}. \end{equation}\]The SSM states can then be computed as $\bp_i = \bs_i \bullet \bp_{i-1}$ with $\bp_0 = \bs_0 = \begin{bmatrix} \bzero & \bh_0 \end{bmatrix}$.
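The pair representation can be checked numerically with a small JAX sketch (dimensions, names, and the random inputs are illustrative): each element is the pair $(\bA_k, \bB_k \bx_k)$, the $\bullet$ operator is supplied to lax.associative_scan, and the result is compared against the sequential recursion $\bh_k = \bA_k \bh_{k-1} + \bB_k \bx_k$.

```python
import jax
import jax.numpy as jnp
from jax import lax

d, n = 2, 6  # state dimension and sequence length (arbitrary)
kA, kb, kh = jax.random.split(jax.random.PRNGKey(0), 3)
A = 0.3 * jax.random.normal(kA, (n, d, d))  # A_1 .. A_n
Bx = jax.random.normal(kb, (n, d))          # B_k x_k, precomputed
h0 = jax.random.normal(kh, (d,))

# elements s_0 .. s_n as pairs (A_k, B_k x_k), with s_0 = (0, h_0)
elems = (jnp.concatenate([jnp.zeros((1, d, d)), A]),
         jnp.concatenate([h0[None], Bx]))

def bullet(prev, new):
    # (A2, b2) • (A1, b1) = (A2 A1, A2 b1 + b2); associative_scan passes
    # the earlier partial result first, hence the argument order
    A1, b1 = prev
    A2, b2 = new
    return A2 @ A1, jnp.einsum('...ij,...j->...i', A2, b1) + b2

_, h = lax.associative_scan(bullet, elems)  # h[k] = h_k for k = 0..n

# sequential reference
h_ref = h0
for k in range(n):
    h_ref = A[k] @ h_ref + Bx[k]
assert jnp.allclose(h[-1], h_ref, atol=1e-5)
```

The operator works on whole batches of elements at once (the leading axis is the scan axis), which is what lets JAX combine them in the pairwise tree described above.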
As a final remark, note that while the associative scan is parallelizable, it performs matrix-matrix products of the form $\bA_k \bA_{k-1}$, which will be computationally prohibitive unless $\bA_k$ has structure (e.g. diagonal or low rank). This is one of the reasons why e.g. Mamba-2 utilizes a scaled identity as its $\bA_k$.
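To illustrate why structure helps (a generic sketch, not any particular model's implementation): if $\bA_k$ is diagonal it can be stored as a length-$d$ vector, and the matrix products in the $\bullet$ operator reduce to elementwise multiplications, so each combination costs $O(d)$ instead of $O(d^3)$.

```python
import jax
import jax.numpy as jnp
from jax import lax

d, n = 4, 8  # illustrative sizes
ka, kb = jax.random.split(jax.random.PRNGKey(1))
a = jax.random.uniform(ka, (n, d), minval=0.5, maxval=1.0)  # diagonals of A_1..A_n
b = jax.random.normal(kb, (n, d))                           # B_k x_k terms

def bullet_diag(prev, new):
    # diagonal case: all matrix products become elementwise vector products
    a1, b1 = prev
    a2, b2 = new
    return a1 * a2, a2 * b1 + b2

# with no s_0 element the scan implicitly starts from h_0 = 0
h = lax.associative_scan(bullet_diag, (a, b))[1]  # h[k-1] = h_k

# sequential reference
h_ref = jnp.zeros(d)
for k in range(n):
    h_ref = a[k] * h_ref + b[k]
assert jnp.allclose(h[-1], h_ref, atol=1e-5)
```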