Structured Masked Attention (SMA)

Masked Attention (MA) is given by the relation [1]

$$
Y = \left(L \circ QK^\top\right) V
$$

where $L$ is a mask applied to the matrix $QK^\top$ using the elementwise Hadamard product (denoted by $\circ$). When the mask $L$ is structured, multiplication with $L$ can usually be carried out efficiently, and we refer to the masked attention as Structured Masked Attention (SMA). In the simplest case, where $L$ is a lower-triangular matrix filled with ones, the SMA reduces to

$$
y_i = \sum_{j=1}^{i} \left(q_i^\top k_j\right) v_j,
$$

where $q_i^\top$, $k_j^\top$ and $v_j^\top$ are rows of $Q$, $K$ and $V$. This can be viewed as a weighted cumulative sum. While the above is a structured computation, it still costs $O(n^2)$ operations unless $QK^\top$ itself has some structure. One example is when $QK^\top$ has low rank (i.e. $Q, K \in \mathbb{R}^{n \times p}$ with $p \ll n$). The resulting matrix $L \circ QK^\top$ is then semiseparable, and multiplication with it can be applied in linear time [2].
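
To make the linear-time claim concrete, here is a minimal sketch for the all-ones lower-triangular mask. It assumes $Q, K \in \mathbb{R}^{n\times p}$ and a value matrix $V \in \mathbb{R}^{n\times d}$, and keeps a running state $S_i = \sum_{j \le i} k_j v_j^\top$ so the total cost is $O(npd)$ instead of $O(n^2(p+d))$. The function names `masked_attention_dense` and `masked_attention_scan` are hypothetical, and only the `LinearAlgebra` standard library is used.

```julia
using LinearAlgebra

# Quadratic reference: Y = (L ∘ QKᵀ)V with L the lower-triangular all-ones mask.
masked_attention_dense(Q, K, V) = (tril(ones(size(Q, 1), size(Q, 1))) .* (Q * K')) * V

# Linear-time version: keep the running state S_i = Σ_{j ≤ i} k_j v_jᵀ (a p×d matrix),
# so that y_i = S_iᵀ q_i = Σ_{j ≤ i} (q_iᵀ k_j) v_j.
function masked_attention_scan(Q, K, V)
    n, p = size(Q)
    d = size(V, 2)
    S = zeros(p, d)
    Y = zeros(n, d)
    for i in 1:n
        S .+= K[i, :] * V[i, :]'   # add the outer product k_i v_iᵀ
        Y[i, :] = S' * Q[i, :]      # row i of the output
    end
    return Y
end

n, p, d = 10, 2, 3
Q, K, V = randn(n, p), randn(n, p), randn(n, d)
Y = masked_attention_scan(Q, K, V)   # O(npd) work
```

Comparing `Y` against `masked_attention_dense(Q, K, V)` should agree up to floating-point rounding; the scan never forms an $n \times n$ matrix.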

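In general, if $QK^\top$ is a low-rank matrix (of rank $p$), then the cost of multiplying with $L \circ QK^\top$ scales as $p$ times the cost of multiplying with $L$, since

$$
\left(L \circ QK^\top\right) x = \sum_{i=1}^{p} \operatorname{diag}(Q_{:,i})\, L\, \operatorname{diag}(K_{:,i})\, x,
$$

where $Q_{:,i}$ and $K_{:,i}$ denote the $i$-th columns of $Q$ and $K$. Hence we need $p$ multiplications with $L$ as well as $2p$ diagonal multiplications.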

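using LinearAlgebra, Test
n, p = 10, 2
U, V = randn(n,p), randn(n,p) # low-rank factors, playing the roles of Q and K
B = rand(n,n)                 # dense stand-in for the mask L
M = B .* (U*V')               # M = L ∘ QKᵀ
x = randn(n)
# p multiplications with the mask plus 2p diagonal scalings
@test M*x ≈ sum(i -> Diagonal(U[:,i]) * (B * (Diagonal(V[:,i]) * x)), 1:p)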
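
For the lower-triangular all-ones mask, multiplication with $L$ is just a cumulative sum, so the identity above can be evaluated with $p$ calls to `cumsum` and $2p$ elementwise scalings, without forming any $n \times n$ matrix. A minimal sketch assuming that specific mask (the helper name `lowrank_causal_matvec` is hypothetical):

```julia
using LinearAlgebra

# Computes (L ∘ UVᵀ)x for the lower-triangular all-ones mask L. Each term
# diag(U[:,i]) * L * diag(V[:,i]) * x becomes a cumulative sum.
function lowrank_causal_matvec(U, V, x)
    y = zeros(size(U, 1))
    for i in 1:size(U, 2)
        y .+= U[:, i] .* cumsum(V[:, i] .* x)   # one application of L per rank-one term
    end
    return y
end

n, p = 10, 2
U, V = randn(n, p), randn(n, p)
x = randn(n)
y = lowrank_causal_matvec(U, V, x)          # O(np) work

M = tril(ones(n, n)) .* (U * V')            # dense reference, O(n²) work and storage
```

Here `y` should match `M * x`; a general structured mask would replace `cumsum` by whatever fast product that mask admits.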

State-space models as structured matrices

First we recap that a state-space model is given by the equations
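
$$
h_1 = B_1 x_1, \qquad h_i = A_{i-1} h_{i-1} + B_i x_i \quad (i = 2, \dots, T), \qquad y_i = C_i^\top h_i.
$$

An alternative way to view the state-space model is through the lens of structured matrices. In particular, a state-space model up until time step $T$ can also be written using the block-diagonal matrices $B = \operatorname{blkdiag}(B_1, \dots, B_T)$ and $C = \operatorname{blkdiag}(C_1, \dots, C_T)$ together with a block-bidiagonal matrix $A$ as

$$
\underbrace{\begin{bmatrix} I & & & \\ -A_1 & I & & \\ & \ddots & \ddots & \\ & & -A_{T-1} & I \end{bmatrix}}_{A}
\underbrace{\begin{bmatrix} h_1 \\ h_2 \\ \vdots \\ h_T \end{bmatrix}}_{h}
= B \underbrace{\begin{bmatrix} x_1 \\ x_2 \\ \vdots \\ x_T \end{bmatrix}}_{x},
\qquad y = C^\top h.
$$

We know that by simply iterating forwards in time we can compute the states in linear time. It is therefore no surprise that the inverse of the bidiagonal matrix (i.e. $A^{-1}$) has an explicit structured form that can be applied in linear time:

$$
A^{-1} = \begin{bmatrix}
I & & & & \\
A_1 & I & & & \\
A_2 A_1 & A_2 & I & & \\
\vdots & \vdots & & \ddots & \\
A_{T-1} \cdots A_1 & A_{T-1} \cdots A_2 & \cdots & A_{T-1} & I
\end{bmatrix}.
$$

This type of matrix structure is called semiseparable [2]. Using the explicit inverse of $A$ we can compute the hidden states efficiently as

$$
h = A^{-1} B x, \qquad h_i = \sum_{j=1}^{i} \left(A_{i-1} \cdots A_j\right) B_j x_j,
$$

where the product $A_{i-1} \cdots A_j$ is the identity for $j = i$.
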
Similarly, the output can be computed by applying the output matrices in a blocked fashion to the hidden states, i.e.

$$
y = C^\top h = C^\top A^{-1} B x.
$$

using LinearAlgebra, BlockBandedMatrices, Test, SymSemiseparableMatrices

T = 10 # Sequence length
n = 6 # State size
input_dim = 1 # Dimension of forcing term

A_blks = [rand()*I(n) for _ in 1:T-1] # Scalar-times-identity dynamics as in Mamba-2
D_blks = [I(n)/1.0 for i in 1:T] # Diagonal blocks are identity
C_blks = [rand(n,1) for _ in 1:T] # Measurements are scalars
B_blks = [rand(n,input_dim) for _ in 1:T] # Inputs are scalars
# Defining zeros blocks for the matrices
A_zero_blks = [zeros(n,n) for b in A_blks]
C_zero_blks = [zeros(n,1) for _ in 1:T-1]
B_zero_blks = [zeros(n,input_dim) for _ in 1:T-1]
# Defining the block matrices
A = BlockTridiagonal(-A_blks,D_blks,A_zero_blks)
C = BlockTridiagonal(C_zero_blks,C_blks,C_zero_blks)
B = BlockTridiagonal(B_zero_blks,B_blks,B_zero_blks)

# Computing states by iterating forward
x_blks = [randn(input_dim) for _ in 0:T-1] # input data
h_blks = [B_blks[1]*x_blks[1]] # initial hidden state
for i in 2:T
push!(h_blks, A_blks[i-1]*h_blks[i-1] + B_blks[i]*x_blks[i])
end
y_blks = [C'*h for (C,h) in zip(C_blks,h_blks)]

# Computing states using semiseparable matrices
Ai_blks = [prod(A_blks[1:i-1],init=1.0*I(n)) for i in 1:T]
U = vcat(Ai_blks...)
V = vcat(inv.(Ai_blks)...)
# "SymSemiseparableCholesky" represents the matrix A^{-1} = tril(UV')
Ai = SymSemiseparableCholesky(U',V')
x = vcat(x_blks...) # Collecting input data

@test Ai*(B*x) ≈ vcat(h_blks...) # Checking hidden states
@test C'*(Ai*(B*x)) ≈ vcat(y_blks...) # Checking measurement
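
The `SymSemiseparableCholesky` type above hides how the product with $\operatorname{tril}(UV^\top)$ is carried out. As a dependency-free illustration (not the package's implementation), here is a sketch for a single state with scalar dynamics $a_1, \dots, a_{T-1}$, where a running sum $h_i = u_i \sum_{j \le i} v_j z_j$ replaces the dense product. The helper `tril_uv_matvec` is hypothetical, and the $a_i$ are drawn from $(0.5, 1)$ only to keep the products well scaled.

```julia
using LinearAlgebra

# Computes h = tril(u vᵀ) z in O(T) time with a running sum:
#   h_i = u_i * Σ_{j ≤ i} v_j z_j
function tril_uv_matvec(u, v, z)
    h = zeros(length(z))
    acc = 0.0
    for i in eachindex(z)
        acc += v[i] * z[i]
        h[i] = u[i] * acc
    end
    return h
end

T = 10
a = 0.5 .+ 0.5 .* rand(T - 1)                   # scalar dynamics a_1, …, a_{T-1}
u = [prod(a[1:i-1]; init = 1.0) for i in 1:T]   # u_i = a_{i-1} ⋯ a_1
v = 1 ./ u                                      # v_j = (a_{j-1} ⋯ a_1)^{-1}
z = randn(T)                                    # plays the role of B x for a single state

# Reference: forward recursion h_1 = z_1, h_i = a_{i-1} h_{i-1} + z_i
h_rec = zeros(T)
h_rec[1] = z[1]
for i in 2:T
    h_rec[i] = a[i-1] * h_rec[i-1] + z[i]
end

h = tril_uv_matvec(u, v, z)
```

For long sequences the products $u_i$ can underflow and $v_j = 1/u_j$ overflow, so this particular generator representation is mainly illustrative.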

State-space models as SMAs

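In this section we show that the simplified SSM of the Mamba-2 paper is a special case of Structured Masked Attention (SMA) [1]. That is, the multiplication

$$
y = C^\top A^{-1} B x
$$

can be rewritten as

$$
y = \left(L \circ \tilde{C} \tilde{B}^\top\right) x,
$$

where $L$ is a structured mask and the $i$-th rows of $\tilde{C}, \tilde{B} \in \mathbb{R}^{T \times n}$ are $C_i^\top$ and $B_i^\top$ (with scalar inputs and outputs).

The SSM in the Mamba-2 paper restricts the dynamics $A_i$ to be scalar-times-identity, $A_i = a_i I$, in order for the mask $L$ to be structured [1]. In short, this restriction means that the dynamics of all hidden states are independent but equal.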

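From a practical point of view this means that we can collect each component of the hidden state over time and treat the components separately. The resulting $A$-matrix can be described by a Kronecker product (depending on how we organize the states it is either $a \otimes I_n$ or $I_n \otimes a$). In the following we choose to separate the states, resulting in $A$ having the form

$$
A = I_n \otimes a, \qquad a = \begin{bmatrix} 1 & & & \\ -a_1 & 1 & & \\ & \ddots & \ddots & \\ & & -a_{T-1} & 1 \end{bmatrix} \in \mathbb{R}^{T \times T}.
$$
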
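The two orderings of the states ($a \otimes I_n$ versus $I_n \otimes a$) differ only by a symmetric permutation. A small sketch, assuming random scalar dynamics and building both Kronecker forms densely; the index vector `perm` is constructed here purely for illustration:

```julia
using LinearAlgebra

T, n = 5, 3                                  # sequence length, state size
a = Bidiagonal(ones(T), -rand(T - 1), :L)    # shared scalar dynamics a_1, …, a_{T-1}
Id = Matrix(1.0I, n, n)

A_time = kron(Matrix(a), Id)    # time-major ordering: a ⊗ I_n (n×n blocks per pair of time steps)
A_state = kron(Id, Matrix(a))   # state-major ordering: I_n ⊗ a (one T×T block per state)

# perm[(k-1)*T + t] is the time-major position (t-1)*n + k of state k at time t
perm = [(t - 1) * n + k for k in 1:n for t in 1:T]
```

Reindexing `A_time[perm, perm]` reproduces `A_state`, and since a symmetric permutation commutes with inversion, the Kronecker structure of the inverse carries over to either ordering.
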
Using that the inverse of a Kronecker product is the Kronecker product of the inverses it follows that
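
$$
A^{-1} = I_n \otimes a^{-1}, \qquad
a^{-1} = \begin{bmatrix}
1 & & & & \\
a_1 & 1 & & & \\
a_2 a_1 & a_2 & 1 & & \\
\vdots & \vdots & & \ddots & \\
a_{T-1} \cdots a_1 & a_{T-1} \cdots a_2 & \cdots & a_{T-1} & 1
\end{bmatrix},
$$

where we further used that the inverse of a bidiagonal matrix is semiseparable. Furthermore, we have to rearrange $B$ and $C$, which results in

$$
B = \begin{bmatrix} \operatorname{diag}(b_1) \\ \vdots \\ \operatorname{diag}(b_n) \end{bmatrix}, \qquad
C = \begin{bmatrix} \operatorname{diag}(c_1) \\ \vdots \\ \operatorname{diag}(c_n) \end{bmatrix},
$$

where $b_k, c_k \in \mathbb{R}^T$ collect the $k$-th entries of $B_i$ and $C_i$ over time $i = 1, \dots, T$. The final multiplication therefore reads

$$
C^\top A^{-1} B = \sum_{k=1}^{n} \operatorname{diag}(c_k)\, a^{-1} \operatorname{diag}(b_k).
$$

Finally, using the Hadamard-product identity $\operatorname{diag}(c)\, M \operatorname{diag}(b) = \left(c b^\top\right) \circ M$, we arrive at the Structured Masked Attention form that we were looking for

$$
C^\top A^{-1} B = \sum_{k=1}^{n} \left(c_k b_k^\top\right) \circ a^{-1} = \left(\tilde{C} \tilde{B}^\top\right) \circ a^{-1}, \qquad \tilde{C} = \begin{bmatrix} c_1 & \cdots & c_n \end{bmatrix}, \quad \tilde{B} = \begin{bmatrix} b_1 & \cdots & b_n \end{bmatrix}.
$$

This means that the SSM dynamics can be interpreted as a structured masked attention mechanism with mask $a^{-1}$. Note that if the dynamics were independent but different (i.e. each state $k$ had its own bidiagonal matrix $a^{(k)}$ rather than a shared $a$), the SSM would instead result in a sum of masked attentions, i.e.

$$
C^\top A^{-1} B = \sum_{k=1}^{n} \left(c_k b_k^\top\right) \circ \left(a^{(k)}\right)^{-1}.
$$
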
using LinearAlgebra, Test, SymSemiseparableMatrices

T = 10 # Sequence length
n = 6 # State size

# Here we treat the matrices in terms of their states and not sequence lengths
a_blks = rand(T-1)
a = Bidiagonal(ones(T),-a_blks,:L)
A = kron(I(n),a) # Kronecker product with identity

# Each block now has length T (the sequence length), with one block per state!
B_blks = [rand(T) for _ in 1:n]
C_blks = [rand(T) for _ in 1:n]
# Collecting the blocks into the B and C
B = vcat(Diagonal.(B_blks)...)
C = vcat(Diagonal.(C_blks)...)

# We want to check that the full matrix M = C'*(A\B) can be written as
# structured masked attention, i.e.
# M = (Cn*Bn') .* inv(a). For this we start by computing inv(a)
ai = inv(a) # We ignore here that inv(a) is semiseparable
@test C'*(A\B) ≈ sum(i-> Diagonal(C_blks[i])*ai*Diagonal(B_blks[i]),1:n)
@test C'*(A\B) ≈ sum(i->(C_blks[i]*B_blks[i]') .* ai , 1:n)

# We can collect the terms and write it as Structured Masked Attention!
Cn = hcat(C_blks...)
Bn = hcat(B_blks...)
@test C'*(A\B) ≈ (Cn*Bn').*ai

# We can apply the semiseparable structure of the inverse when multiplying!
ai_blks = [prod(a_blks[1:i-1],init=1.0) for i in 1:T]
u = vcat(ai_blks...)
v = vcat(inv.(ai_blks)...)
# "SymSemiseparableCholesky" represents the matrix a^{-1} = tril(uv')
ais = SymSemiseparableCholesky(u',v')
# Efficient products using the structure of "ais" and diagonal B and C
x = randn(T)
@test C'*(A\(B*x)) ≈ sum(i-> C_blks[i].*(ais*(B_blks[i].*x)),1:n)
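
The code above only covers equal dynamics. For the case of independent but different dynamics, a minimal sketch compares the forward recursion with the sum of masked attentions $\sum_k (c_k b_k^\top) \circ (a^{(k)})^{-1}$. It assumes scalar inputs and outputs, per-state scalar dynamics stored in a hypothetical $(T-1) \times n$ array `a_dyn`, and dense $O(T^2)$ masks purely for checking; the function names are likewise hypothetical.

```julia
using LinearAlgebra

T, n = 8, 3                            # sequence length, state size
a_dyn = 0.5 .+ 0.5 .* rand(T - 1, n)   # column k holds a_1, …, a_{T-1} for state k
Bn = randn(T, n)                       # row i holds B_iᵀ (scalar input per step)
Cn = randn(T, n)                       # row i holds C_iᵀ (scalar output per step)
x = randn(T)

# Forward recursion with state-wise different scalar dynamics
function ssm_recurrence(a_dyn, Bn, Cn, x)
    T, n = size(Bn)
    h = zeros(n)
    y = zeros(T)
    for i in 1:T
        if i > 1
            h .= a_dyn[i-1, :] .* h     # state k is scaled by its own a_{i-1}
        end
        h .+= Bn[i, :] .* x[i]
        y[i] = dot(Cn[i, :], h)
    end
    return y
end

# Sum of masked attentions, one dense mask (a^{(k)})^{-1} per state
function ssm_sum_of_sma(a_dyn, Bn, Cn)
    T, n = size(Bn)
    M = zeros(T, T)
    for k in 1:n
        Lk = inv(Matrix(Bidiagonal(ones(T), -a_dyn[:, k], :L)))
        M .+= (Cn[:, k] * Bn[:, k]') .* Lk
    end
    return M
end

y = ssm_recurrence(a_dyn, Bn, Cn, x)
M = ssm_sum_of_sma(a_dyn, Bn, Cn)
```

Unlike the shared-dynamics case, the $n$ masks no longer factor out of the sum, so the terms cannot be collected into a single $(\tilde{C}\tilde{B}^\top) \circ L$.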

Bibliography

  • [1] T. Dao and A. Gu, “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality,” arXiv, 2024, [Online]. Available: https://arxiv.org/abs/2405.21060
  • [2] M. S. Andersen and T. Chen, “Smoothing Splines and Rank Structured Matrices: Revisiting the Spline Kernel,” SIAM Journal on Matrix Analysis and Applications, vol. 41, no. 2, pp. 389–412, 2020, doi: 10.1137/19M1267349.