RWKV: Reinventing RNNs for the Transformer Era

August 2023

tl;dr: Linear attention (AFT) that allows efficient, parallelizable training like a Transformer AND efficient inference like an RNN.

Overall impression

Transformers suffer from memory and computational complexity that scales quadratically with sequence length. The overarching motivation behind RWKV is to bridge the gap between computational efficiency and expressive capacity in neural network architectures. RWKV paves the way toward next-gen sustainable and computationally efficient AI models for sequence processing tasks.

RWKV stretches the notion of attention to the point that it is NOT really attention but rather AFT (Attention Free Transformer). AFT can be seen as multi-head attention (MHA) where each feature dimension corresponds to a head (n_channels == n_heads). Note that the R in RWKV is essentially Q in AFT, rebranded as receptance.

RWKV is very similar to RetNet, achieving the impossible triangle of parallelizable training AND efficient inference AND Transformer-level language modeling quality.

Efficient, RNN-style inference means it’s possible to run an int8 14B parameter RWKV model on sequences of any length with a constant memory requirement of 3GB VRAM. This opens up opportunities for language model-powered cognitive features in tightly-constrained edge environments with streaming inputs, like robotics, even if RWKV turns out, like other Transformer alternatives, to fall off the scaling laws eventually.

Key ideas

\[\begin{aligned} \text{Attn}^+(W, K, V)_t &= \frac{\sum_{i=1}^T \exp(w_{t, i} + k_i) \odot v_i}{\sum_{i=1}^T \exp(w_{t, i} + k_i)} = \sum_{i=1}^T \frac{\exp(w_{t, i} + k_i)}{\sum_{j=1}^T \exp(w_{t, j} + k_j)} \odot v_i \\ &= \sum_{i=1}^T \text{softmax}_i(w_{t, i} + k_i) \odot v_i \end{aligned}\]
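
To make the formula concrete, here is a minimal NumPy sketch of Attn^+ in its parallel form, next to a recurrent form that gives the same result with a constant-size state. The recurrent form rests on two assumptions that are not stated in these notes: the attention is causal (w_{t,i} = -inf for i > t), and the position bias is a single scalar decay, w_{t,i} = -(t - i) * decay. The function names are made up for illustration, and RWKV's actual WKV operator differs in details (for example, its decay is per channel), so treat this as a toy illustration rather than the paper's implementation.

```python
import numpy as np

def aft_attention(w, k, v):
    """Parallel form of Attn^+(W, K, V).
    w: (T, T) pairwise position biases w[t, i]
    k, v: (T, d) keys and values
    Returns an array of shape (T, d).
    """
    # logits[t, i, c] = w[t, i] + k[i, c]; each channel c acts as its own head
    logits = w[:, :, None] + k[None, :, :]
    # subtract the per-(t, c) max for numerical stability; softmax is unchanged
    logits = logits - logits.max(axis=1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=1, keepdims=True)  # softmax over i
    return (weights * v[None, :, :]).sum(axis=1)

def causal_decay_bias(T, decay):
    """Assumed bias: w[t, i] = -(t - i) * decay for i <= t, -inf otherwise."""
    t = np.arange(T)[:, None]
    i = np.arange(T)[None, :]
    return np.where(i <= t, -(t - i) * decay, -np.inf)

def aft_recurrent(k, v, decay):
    """RNN-style evaluation under the assumed decay bias.
    The state is just a running numerator and denominator, each of size d,
    so memory does not grow with sequence length.
    """
    T, d = k.shape
    num = np.zeros(d)
    den = np.zeros(d)
    out = np.empty((T, d))
    for t in range(T):
        # decay the past by one step, then add the current token's contribution
        num = np.exp(-decay) * num + np.exp(k[t]) * v[t]
        den = np.exp(-decay) * den + np.exp(k[t])
        out[t] = num / den
    return out

# Usage: compare the two forms on random inputs
rng = np.random.default_rng(0)
T, d = 8, 4
k = rng.normal(size=(T, d))
v = rng.normal(size=(T, d))
parallel = aft_attention(causal_decay_bias(T, 0.5), k, v)
recurrent = aft_recurrent(k, v, 0.5)
print(np.allclose(parallel, recurrent))
```

The parallel form materializes a (T, T, d) tensor, which is fine for a toy check but not how one would train at scale. The recurrent form omits the max-subtraction trick for brevity, so it can overflow for large keys.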

Technical details

Notes

Raw notes from Yannic’s video