Subho's research at your service 🫡

Understanding Lightning Attention: A Breakthrough in Linear Attention Efficiency

We all need attention in our lives, don't we? But it's surprisingly expensive to compute the exponential terms that help us pick the best plausible solution. It goes like this,

$$\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

The main issue here is the $QK^\top$ multiplication, which has a complexity of $O(n^2 d)$. What are $n$ and $d$? They're the sequence length and the inner dimension. So, when the sequence length gets 2x, your attention computation gets 4x, hugee!!
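To put a number on that, here's a tiny sketch that just counts the multiply-adds for the score matrix. The helper name `attention_score_flops` and the sizes are mine, picked only for illustration.

```python
def attention_score_flops(n: int, d: int) -> int:
    """Multiply-adds needed to form the n x n score matrix Q @ K^T."""
    return n * n * d


base = attention_score_flops(n=1024, d=64)
doubled = attention_score_flops(n=2048, d=64)

# Doubling the sequence length quadruples the work.
print(doubled / base)  # 4.0
```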

Man, that's OK, but what does linear attention do to counter this, and how does lightning attention exploit it to get the best performance out of it?

In its basic form, linear attention decomposes the attention mechanism into the inner product of hidden representations, allowing for more efficient computation like this,

$$\mathrm{LinearAttention}(Q,K,V)=\phi(Q)\left(\phi(K)^\top V\right)$$

Don't panic, I will give you the easiest example. You might have heard of the "kernel trick" in SVMs, right? If not, you might even get the intuition here! So think of ϕ as a feature map that stands in for softmax, such that,

$$\mathrm{softmax}(QK^\top)V \approx \phi(Q)\left(\phi(K)^\top V\right)$$

This is the kernel trick: approximating the softmax computation with a function ϕ. Similarly, in an SVM you map points from a lower-dimensional space to a higher-dimensional one to make them separable; that's the kernel trick there.

Theoretically, the kernel trick reduces the complexity to $O(nd^2)$. How so? By computing $\phi(K)^\top V$ first (a $d \times d$ matrix, summed along the sequence dimension $n$) and then taking the product with $\phi(Q)$ along dimension $d$, resulting in a complexity of $O(nd^2)$.
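Here's a minimal NumPy sketch of that reordering. The feature map `phi` below is elu(x) + 1, which is one common choice and purely my assumption here; the check only shows that the two multiplication orders give the same result, not that either one equals softmax attention.

```python
import numpy as np


def phi(x):
    # elu(x) + 1, a positive feature map (an assumption for this sketch)
    return np.where(x > 0, x + 1.0, np.exp(x))


rng = np.random.default_rng(0)
n, d = 128, 16
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))

# Left to right: builds an n x n matrix first, O(n^2 d)
out_quadratic = (phi(Q) @ phi(K).T) @ V

# Right to left: builds a d x d matrix first, O(n d^2)
out_linear = phi(Q) @ (phi(K).T @ V)

print(np.allclose(out_quadratic, out_linear))
```

Same output, but the second form never materialises the $n \times n$ matrix.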

But that's theoretical. It faces a significant challenge in causal (autoregressive) settings when the masking comes into the picture. To compute this efficiently using the linear attention "kernel trick", we need to maintain a running sum of key-value pairs.

$$kv_t = \sum_{s \le t} k_s^\top v_s$$

The sequential nature of cumsum prevents parallel computation, negating the theoretical O(nd²) efficiency advantage of linear attention. Each position must wait for all previous positions' computations to complete.
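A small sketch of that running sum, assuming row-vector keys and values and skipping the feature map and normalisation to keep it short. It checks the token-by-token recurrence against the masked quadratic form.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 64, 8
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))

# Reference: masked quadratic form, [(Q K^T) * causal_mask] V
causal_mask = np.tril(np.ones((n, n)))
out_masked = ((Q @ K.T) * causal_mask) @ V

# Recurrent form: kv_t = kv_{t-1} + k_t^T v_t, then o_t = q_t kv_t
kv = np.zeros((d, d))
out_recurrent = np.zeros((n, d))
for t in range(n):  # step t cannot start before step t-1 finishes
    kv += np.outer(K[t], V[t])
    out_recurrent[t] = Q[t] @ kv

print(np.allclose(out_masked, out_recurrent))
```

The loop is the cumsum: cheap per step, but strictly one token at a time.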

Get to lightning attention already, I'm hyped! Here you go.

But let me give you an intuition for lightning attention. Think of the cumsum problem like trying to maintain a running sum while reading a book - you need to keep updating your "memory" after each word. This is inherently sequential, just like RNNs need to process tokens one by one. Here's how Lightning Attention cleverly works around this,

$$[x_1,\dots,x_B]\;[x_{B+1},\dots,x_{2B}]\;\dots\;[x_{n-B+1},\dots,x_n]$$

Lightning Attention divides the input sequence into blocks. Let's say we have a sequence of length n and divide it into blocks of size B.

For each block t, the output is computed as,

$$O_t = \underbrace{\left[(Q_t K_t^\top)\odot M\right]V_t}_{\text{intra-block}} + \underbrace{\Lambda Q_t (KV)}_{\text{inter-block}}$$

We had two parts to the computation in an RNN, right? Let's bring the same intuition here,

  1. Short-term memory (intra-block): Within each block, we use regular attention - like how you can easily remember and relate words within the same paragraph.
  2. Long-term memory (inter-block): Between blocks, we use a modified linear attention - like maintaining a summary of previous paragraphs without needing every detail.

The key insight is in how the inter-block computation works. Instead of maintaining a running sum for each position, Lightning Attention updates a single KV matrix for each block,

$$KV_t = \lambda^B\,KV_{t-1} + \left(\lambda^B \Lambda^{-1} K_t\right)^\top V_t$$

here λ is a decay factor and Λ is a diagonal matrix for position-aware scaling, might tell you more as we progress further ;)
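To make the intra/inter split concrete, here's a NumPy sketch of the block-wise forward pass, checked against a naive $O(n^2)$ version with the same decay. This is my own reading of the two equations above, not the authors' kernel: the function names are made up, there's no feature map or normalisation, and I assume $n$ is a multiple of $B$.

```python
import numpy as np


def decay_mask(size, lam):
    """M[i, j] = lam**(i - j) if i >= j else 0."""
    idx = np.arange(size)
    diff = idx[:, None] - idx[None, :]
    return np.where(diff >= 0, lam ** np.maximum(diff, 0), 0.0)


def naive_decay_attention(Q, K, V, lam):
    # Quadratic reference: [(Q K^T) * D] V with D the full n x n decay mask
    return ((Q @ K.T) * decay_mask(Q.shape[0], lam)) @ V


def blockwise_decay_attention(Q, K, V, lam, B):
    n, d = Q.shape
    assert n % B == 0, "sketch assumes n is a multiple of B"
    M = decay_mask(B, lam)
    Lam = lam ** np.arange(1, B + 1)   # diagonal of Lambda, stored as a vector
    KV = np.zeros((d, V.shape[1]))     # summary of all previous blocks
    O = np.zeros((n, V.shape[1]))
    for start in range(0, n, B):
        Qt, Kt, Vt = (X[start:start + B] for X in (Q, K, V))
        O_intra = ((Qt @ Kt.T) * M) @ Vt        # regular masked attention
        O_inter = (Lam[:, None] * Qt) @ KV      # Lambda Q_t (KV)
        O[start:start + B] = O_intra + O_inter
        # KV <- lam^B KV + (lam^B Lambda^{-1} K_t)^T V_t
        KV = lam ** B * KV + ((lam ** B / Lam)[:, None] * Kt).T @ Vt
    return O


rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((64, 8)) for _ in range(3))
out_block = blockwise_decay_attention(Q, K, V, lam=0.9, B=16)
out_naive = naive_decay_attention(Q, K, V, lam=0.9)
print(np.allclose(out_block, out_naive))
```

Notice the loop now runs over blocks, not tokens, and everything inside it is a plain matmul.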

And the authors have also highlighted that even when we calculate $kv_t$ with a complexity of $O(nd^2)$, it's actually not GPU optimal. Why? Coz we are not able to parallelise across the head dim (when it comes to theoretical linear attention).

But these guys have got us! They were the first to implement tiling to compute linear attention in a causal setting. They split each of Q, K and V into T blocks $X_1, X_2, \dots, X_T$ of size $B \times d$, where B is the block size (a chunk of the sequence length).
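In NumPy terms the tiling is just a reshape. A toy sketch with made-up sizes, assuming $n$ divides evenly by $B$:

```python
import numpy as np

n, d, B = 12, 4, 3
T = n // B  # number of blocks

X = np.arange(n * d, dtype=float).reshape(n, d)  # stand-in for Q, K or V
blocks = X.reshape(T, B, d)                      # X_1, ..., X_T, each B x d

print(blocks.shape)                        # (4, 3, 4)
print(np.array_equal(blocks[1], X[B:2 * B]))
```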

This tiling helps them to get GPU optimal efficiency for Lightning Attention. Here is how the algorithm operates.

[Figure: the Lightning Attention algorithm (lightning_attn_algo)]

Now let's get to the complexity POV. Is it optimal? Let's find out! For the forward pass as mentioned above, the intra-block computation is similar to regular attention but restricted to a block of B tokens, so it's $O(B^2 d)$, again B being a chunk of the sequence length. And for the inter-block part we calculate by updating KV (yup, it's similar to the KV cache intuition, I get you my friend :p), so it's $O(Bd^2)$. So the computation inside the loop is $O(B^2 d + Bd^2)$.

Since we loop for T = n/B times, the total time complexity becomes,

$$O\left((B^2 d + B d^2)\cdot\frac{n}{B}\right) = O(nd^2 + nBd)$$
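A quick sanity check of that formula with plain arithmetic. The function name and the sizes are mine, and the counts ignore constant factors.

```python
def blockwise_cost(n: int, B: int, d: int) -> int:
    """(B^2 d + B d^2) work per block, repeated for T = n / B blocks."""
    T = n // B
    return T * (B * B * d + B * d * d)


n, B, d = 4096, 64, 64
cost = blockwise_cost(n, B, d)

# Matches the closed form n d^2 + n B d
print(cost == n * d * d + n * B * d)

# Doubling n only doubles the cost (linear in sequence length)
print(blockwise_cost(2 * n, B, d) / cost)  # 2.0
```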

Wow! That was really nice! Though I am yet to figure out how the authors derive the backward pass to have the same complexity; maybe that's my next blog post. Let's see the structural intuition from the paper.

[Figure: structural framework of Lightning Attention (structural_framework_lightning_attn)]

I'm not done yet! Who's gonna discuss lightning attention 2? Obviously me!

What if, during the forward pass, we could model how attention should diminish over distance in the sequence? Think of it like how your attention naturally fades when trying to connect words that are far apart in a sentence.

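Initialize $M \in \mathbb{R}^{B \times B}$, where $M_{ij} = \lambda^{i-j}$ if $i \ge j$, else $0$.

Initialize $\Lambda = \mathrm{diag}\{\lambda, \lambda^2, \dots, \lambda^B\} \in \mathbb{R}^{B \times B}$.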

The above equations initialise the decay mask M and the diagonal matrix Λ that contains position-specific scaling factors. The diagonal elements of Λ control how much weight each position gets in the final attention output.
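Here's what those two matrices look like for a toy case, as a small NumPy sketch. The values $\lambda = 0.5$ and $B = 4$ are just picked so the numbers are easy to read.

```python
import numpy as np

lam, B = 0.5, 4
pos = np.arange(1, B + 1)  # local positions 1..B

# Lambda = diag{lam, lam^2, ..., lam^B}
Lam = np.diag(lam ** pos)

# M[i, j] = lam^(i - j) on and below the diagonal, 0 above it
diff = pos[:, None] - pos[None, :]
M = np.where(diff >= 0, lam ** np.maximum(diff, 0), 0.0)

print(np.diag(Lam))
print(M)

# M is also the lower triangle of Lambda @ ones @ Lambda^{-1}
M_from_Lam = np.tril(Lam @ np.ones((B, B)) @ np.linalg.inv(Lam))
print(np.allclose(M, M_from_Lam))
```

The last check is the link between the two: each entry of M is one diagonal entry of Λ divided by another.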

On chip, compute $O_{\text{inter}} = \Lambda Q_i (KV)$.

On chip, compute $KV = \lambda^B\,KV + \left(\lambda^B \Lambda^{-1} K_i\right)^\top V_i$.

The $\lambda^B$ term defines the decay factor at the block level, and the diagonal matrix Λ scales this interaction based on position. This scaling ensures that attention decays appropriately with distance.
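To see why these two factors combine into the right decay, here's a small worked check. It's my own unpacking of the update above, assuming local positions inside a block are numbered 1 to B. Take a key at local position $j$ in the previous block and a query at local position $i$ in the current block. The key went into KV with weight $(\lambda^B \Lambda^{-1})_{jj}$, and the query gets scaled by $\Lambda_{ii}$, so together:

$$\Lambda_{ii}\,\left(\lambda^{B}\Lambda^{-1}\right)_{jj} = \lambda^{i}\cdot\lambda^{B-j} = \lambda^{(B+i)-j}$$

And $(B+i)-j$ is exactly how many tokens separate the two. Every extra block in between multiplies the stored KV by another $\lambda^B$, which adds another $B$ to the exponent, so the decay always matches the true distance.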

Here is the structural representation of the Lightning attention 2,

[Figure: structural representation of Lightning Attention 2 (lightning_attn_2_structural)]

Since it's getting a bit late, I will publish this blog first and share the implementation on the go, so you'd better keep yourself updated by following me on X

Here you could find the detailed implementation for lightning attn kernel - enjoy 🥂

See ya!! :p