1 Introduction
Large Language Models (LLMs) have attracted considerable attention because of their potential applications, such as extended conversations across professional domains and the processing of large numbers of tokens. However, the computational complexity of these models grows quadratically with the input sequence length, namely $O(n^2)$. Linear attention has been proposed to address this problem: it removes the softmax operation and exploits the associativity of matrix multiplication to reduce the complexity to $O(n)$. In this work, we introduce Lightning Attention-2 to overcome the practical computational limitations of linear attention.
2 Related Work
2.1 Linear Attention
Linear Transformer architectures replace the conventional softmax attention mechanism with various approximations. For example, Katharopoulos et al. (2020) employ the 1 + elu activation function, and Qin et al. (2022) approximate the properties of softmax with a cosine function. Although these methods have a theoretical complexity of $O(nd^2)$, their practical computational efficiency drops considerably because linear attention requires a cumsum operation in the causal setting.
2.2 IO-aware Attention
The FlashAttention series focuses on system-level optimizations for efficiently implementing the standard attention operator on GPU platforms; the approach uses tiling strategies to minimize the volume of memory reads and writes.
2.3 Long Sequence Handling
A common strategy for handling long sequences is to incorporate relative positional encoding techniques; for example, Roformer introduces a rotary position embedding method, while ALiBi adopts a strategy that mitigates the influence of distant tokens in the attention mechanism.
These methods can extend the sequence length mainly at the fine-tuning or testing stage, whereas the method proposed in this work makes it possible to train models on long sequences from scratch.
3 Method
3.1 Preliminaries
We first recall the basic form of linear attention and then introduce Lightning Attention-2. NormAttention omits the costly softmax and scaling operations of the conventional Transformer architecture.
\[O = \text{Norm}((QK^\top)V) \tag{1}\]where \(Q, K, V \in \mathbb{R}^{n \times d}\) are the query, key, and value matrices, \(n\) is the sequence length, and \(d\) is the feature dimension. Using the associativity of matrix multiplication, the above equation can be transformed into a mathematically equivalent linear variant.
\[O = \text{Norm}(Q(K^\top V)) \tag{2}\]This linear form provides an efficient \(O(nd^2)\) complexity for recurrent prediction, and linear attention guarantees a constant \(O(d^2)\) computational cost per token regardless of sequence length, enabling inference over arbitrarily long sequences.
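To make the constant-state property concrete, the following minimal PyTorch sketch maintains a single \(d \times d\) state for \(K^\top V\) and produces each token's output from it, so the per-token cost stays \(O(d^2)\) no matter how long the sequence grows. The helper name `decode_step` is illustrative only, and the \(\text{Norm}(\cdot)\) operator of Eq. (2) is omitted.

```python
import torch

d = 64                        # feature dimension
kv_state = torch.zeros(d, d)  # running K^T V state; its size never depends on sequence length


def decode_step(q_t, k_t, v_t, kv_state):
    """One autoregressive step of (un-normalized) linear attention following Eq. (2)."""
    kv_state = kv_state + k_t[:, None] @ v_t[None, :]  # accumulate K^T V recurrently
    o_t = q_t[None, :] @ kv_state                      # 1 x d output; Norm(.) omitted
    return o_t, kv_state


# Toy usage: stream a few tokens through the recurrence at constant memory.
for _ in range(5):
    q_t, k_t, v_t = torch.randn(3, d)
    o_t, kv_state = decode_step(q_t, k_t, v_t, kv_state)
```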
However, in causal prediction tasks the effectiveness of the right product degrades because of the required cumsum operation, so Lightning Attention-1 resorts to the conventional left matrix product to handle this setting.
3.2 Lightning Attention-2
Lightning Attention-2 applies a tiling methodology throughout the entire computation. In each iteration, the \(Q_i, K_i, V_i\) matrices are split into blocks and transferred to SRAM for computation; the left product is used within blocks (intra) and the right product across blocks (inter).
This approach optimizes computational and memory efficiency and improves overall execution speed. The intermediate activation \(KV\) is iteratively stored and accumulated in SRAM; the intra- and inter-block outputs are summed in SRAM, and the result is finally written back to HBM.
This method exploits the distinct advantages of each memory component to optimize the computational workflow.
Algorithm 1: Lightning Attention-2 Forward Pass
\[\begin{align*} \textbf{Algorithm:} & \ \text{Lightning Attention-2 Forward Pass} \\ \textbf{Input:} & \ Q, K, V \in \mathbb{R}^{n \times d}, \text{ decay rate } \lambda \in \mathbb{R}^+, \text{ block size } B \\ 1: & \ \text{Divide } X \text{ into } T = \frac{n}{B} \text{ blocks } X_1, X_2, \ldots, X_T \text{ of size } B \times d \text{ each, where } X \in \{Q, K, V, O\}. \\ 2: & \ \text{Initialize mask } M \in \mathbb{R}^{B \times B}, \text{ where } M_{ij} = \lambda^{i-j} \text{ for } i \geq j, \text{ otherwise } 0. \\ 3: & \ \text{Initialize } \Lambda = \text{diag}\{\lambda, \lambda^2, \ldots, \lambda^B\} \in \mathbb{R}^{B \times B} \\ 4: & \ \text{Initialize } KV = 0 \in \mathbb{R}^{d \times d} \\ 5: & \ \textbf{for } i = 1 \textbf{ to } T \textbf{ do} \\ 6: & \ \quad \text{Load } Q_i, K_i, V_i \text{ of size } B \times d \text{ from HBM to on-chip SRAM} \\ 7: & \ \quad \text{On-chip, compute } O_{\text{intra}} = [(Q_i K_i^\top) \odot M] V_i \\ 8: & \ \quad \text{On-chip, compute } O_{\text{inter}} = \Lambda Q_i (KV) \\ 9: & \ \quad \text{On-chip, update } KV = \lambda^B KV + (\lambda^B \Lambda^{-1} K_i)^\top V_i \\ 10: & \ \quad \text{Write } O_i = O_{\text{intra}} + O_{\text{inter}} \text{ to HBM for the } i\text{-th block} \\ 11: & \ \textbf{end for} \\ 12: & \ \textbf{return } O \end{align*}\]The forward pass of Lightning Attention-2 can be derived as follows.
During the forward pass of Lightning Attention-2, the \(t\)-th output can be computed as
\[o_t = q_t \sum_{s \leq t} \lambda^{t-s} k_s^\top v_s \tag{3}\]Writing the equation in block form, given the total sequence length \(n\) and block size \(B\), \(X\) is divided into \(T = n/B\) blocks \(\{X_1, X_2, \ldots, X_T\}\) of size \(B \times d\) each.
\[O_{t+1} = \underbrace{[(Q_{t+1} K_{t+1}^\top) \odot M] V_{t+1}}_{\text{Intra Block}} + \underbrace{\Lambda Q_{t+1} (KV_t)}_{\text{Inter Block}} \tag{4}\]where \(KV_t = \sum_{s \leq tB} \lambda^{tB-s} k_s^\top v_s\) is the key-value state accumulated over the first \(t\) blocks.

3.2.2 Backward Pass
Algorithm 2: Lightning Attention-2 Backward Pass
\[\begin{aligned} \textbf{Algorithm:} & \ \text{Lightning Attention-2 Backward Pass} \\ \textbf{Input:} & \ Q, K, V, dO \in \mathbb{R}^{n \times d}: \text{input tensors (query, key, value, output gradient)} \\ & \ \text{decay rate } \lambda \in \mathbb{R}^+ \\ & \ \text{block size } B \\ & \ \textbf{Input Division} \\ 1: & \ \text{Divide } X \in \{Q, K, V\} \text{ into } T = \frac{n}{B} \text{ blocks: } X_1, \ldots, X_T, \text{ each of size } B \times d. \\ 2: & \ \text{Divide } dX \in \{dQ, dK, dV, dO\} \text{ into } T = \frac{n}{B} \text{ blocks: } dX_1, \ldots, dX_T, \text{ each of size } B \times d. \\ & \ \textbf{Initialization} \\ 3: & \ \text{Initialize mask } M \in \mathbb{R}^{B \times B}, \text{ where } M_{ij} = \lambda^{i-j} \text{ if } i \geq j, \text{ else } 0. \\ 4: & \ \text{Initialize } \Lambda = \text{diag}(\lambda, \lambda^2, \ldots, \lambda^B) \in \mathbb{R}^{B \times B} \\ 5: & \ \text{Initialize } KV = 0, dKV = 0 \in \mathbb{R}^{d \times d} \\ & \ \textbf{Loop 1 (left to right, computes } dQ\textbf{)} \\ 6: & \ \textbf{for } i = 1 \textbf{ to } T \textbf{ do} \\ 7: & \ \quad \text{Load } K_i, V_i, O_i, dO_i \text{ of size } B \times d \text{ from HBM to on-chip SRAM} \\ 8: & \ \quad \text{On-chip computations:} \\ & \ \quad \quad dQ_{\text{intra}} = [(dO_i V_i^\top) \odot M] K_i \\ & \ \quad \quad dQ_{\text{inter}} = \Lambda dO_i (KV)^\top \\ 9: & \ \quad \text{On-chip update: } KV = \lambda^B KV + (\lambda^B \Lambda^{-1} K_i)^\top V_i \\ 10: & \ \quad \text{Write } dQ_i = dQ_{\text{intra}} + dQ_{\text{inter}} \text{ to HBM for the } i\text{-th block} \\ 11: & \ \textbf{end for} \\ & \ \textbf{Loop 2 (right to left, computes } dK, dV\textbf{)} \\ 12: & \ \textbf{for } i = T \textbf{ to } 1 \textbf{ do} \\ 13: & \ \quad \text{Load } Q_i, K_i, V_i, O_i, dO_i \text{ of size } B \times d \text{ from HBM to on-chip SRAM} \\ 14: & \ \quad \text{On-chip computations:} \\ & \ \quad \quad dK_{\text{intra}} = [(dO_i V_i^\top) \odot M]^\top Q_i \\ & \ \quad \quad dK_{\text{inter}} = \lambda^B (\Lambda^{-1} V_i) (dKV)^\top \\ & \ \quad \quad dV_{\text{intra}} = [(Q_i K_i^\top) \odot M]^\top dO_i \\ & \ \quad \quad dV_{\text{inter}} = \lambda^B (\Lambda^{-1} K_i) dKV \\ 15: & \ \quad \text{On-chip update: } dKV = \lambda^B dKV + (\Lambda Q_i)^\top dO_i \\ 16: & \ \quad \text{Write } dK_i = dK_{\text{intra}} + dK_{\text{inter}}, dV_i = dV_{\text{intra}} + dV_{\text{inter}} \text{ to HBM for the } i\text{-th block} \\ 17: & \ \textbf{end for} \\ 18: & \ \textbf{return } dQ, dK, dV \end{aligned}\]To explain the logic of the backward pass, consider a given \(do_t\) and the following equations.
\[\begin{aligned} dq_t &= do_t (kv_t)^\top \in \mathbb{R}^{1 \times d}, \\ dk_t &= v_t (dkv_t)^\top \in \mathbb{R}^{1 \times d}, \\ dv_t &= k_t (dkv_t) \in \mathbb{R}^{1 \times d}, \\ dkv_t &= \sum_{s \geq t} \lambda^{s-t} q_s^\top do_s \in \mathbb{R}^{d \times d}. \end{aligned} \tag{5}\]Writing \(dkv_t\) in recursive form gives
\[\begin{aligned} dkv_{n+1} &= 0 \in \mathbb{R}^{d \times d}, \\ dkv_{t-1} &= \lambda dkv_t + q_{t-1}^\top do_{t-1}. \end{aligned} \tag{6}\]Given the total sequence length \(n\) and block size \(B\), \(X\) is divided into blocks \(\{X_1, X_2, \ldots, X_T\}\) of size \(B \times d\) each. On this basis, the block-style equations can be written as follows.
\[\begin{aligned} dKV_{T+1} &= 0 \in \mathbb{R}^{d \times d}, \\ dKV_t &= \sum_{s > tB} \lambda^{s-tB} q_s^\top do_s. \end{aligned} \tag{7}\]Then, for the \((t+1)\)-th block, i.e., positions \(tB+r\) with \(1 \leq r \leq B\),
\[\begin{aligned} dq_{tB+r} &= do_{tB+r} \sum_{s \leq tB+r} \lambda^{tB+r-s} v_s^\top k_s \\ &= do_{tB+r} \left( \sum_{s=tB+1}^{tB+r} \lambda^{tB+r-s} v_s^\top k_s + \lambda^r \sum_{s \leq tB} \lambda^{tB-s} v_s^\top k_s \right) \\ &= do_{tB+r} \sum_{s=tB+1}^{tB+r} \lambda^{tB+r-s} v_s^\top k_s + \lambda^r do_{tB+r} kv_{tB}^\top. \end{aligned} \tag{8}\]Expressed in matrix form,
\[dQ_{t+1} = \underbrace{[(dO_{t+1} V_{t+1}^\top) \odot M] K_{t+1}}_{\text{Intra Block}} + \underbrace{\Lambda dO_{t+1} (KV_t^\top)}_{\text{Inter Block}} \tag{9}\]In this way, each step of the backward pass propagates and accumulates gradients across blocks, enabling the overall memory and computation efficiency of Lightning Attention-2.
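As a sanity check on Eq. (5) and the recursion in Eq. (6), the short PyTorch script below (an illustrative sketch; the toy sizes, decay value, and variable names are arbitrary) compares the closed-form gradients with autograd applied to a naive implementation of the decayed causal linear attention.

```python
import torch

torch.manual_seed(0)
n, d, lam = 6, 4, 0.9
q = torch.randn(n, d, dtype=torch.float64, requires_grad=True)
k = torch.randn(n, d, dtype=torch.float64, requires_grad=True)
v = torch.randn(n, d, dtype=torch.float64, requires_grad=True)

# Naive decayed causal linear attention: kv_t = lam * kv_{t-1} + k_t^T v_t, o_t = q_t kv_t.
kv, kvs, outs = torch.zeros(d, d, dtype=torch.float64), [], []
for t in range(n):
    kv = lam * kv + k[t, :, None] @ v[t, None, :]
    kvs.append(kv)
    outs.append(q[t, None, :] @ kv)
o = torch.cat(outs)

do = torch.randn(n, d, dtype=torch.float64)  # upstream gradient dO
o.backward(do)

# Closed-form gradients of Eq. (5), using the reverse recursion of Eq. (6) for dkv_t.
with torch.no_grad():
    dkv = torch.zeros(d, d, dtype=torch.float64)
    dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)
    for t in reversed(range(n)):
        dq[t] = do[t] @ kvs[t].T                          # dq_t = do_t (kv_t)^T
        dkv = lam * dkv + q[t, :, None] @ do[t, None, :]  # dkv_t = lam * dkv_{t+1} + q_t^T do_t
        dk[t] = v[t] @ dkv.T                              # dk_t = v_t (dkv_t)^T
        dv[t] = k[t] @ dkv                                # dv_t = k_t (dkv_t)

assert torch.allclose(dq, q.grad) and torch.allclose(dk, k.grad) and torch.allclose(dv, v.grad)
```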
For further reading on improvements to linear attention, see Zoology (Blogpost 2): Simple, Input-Dependent, and Sub-Quadratic Sequence Mixers.
The Transformer architecture has become the foundation of large language models (LLM) (Brown et al., 2020; Touvron et al., 2023a;b; Peng et al., 2023; Qin et al., 2023b) and multi-modal models (Li et al., 2022; 2023a; Liu et al., 2023; Radford et al., 2021; Li et al., 2023b; Lu et al., 2022; Mao et al., 2023; Shen et al., 2023; Zhou et al., 2023; Sun et al., 2023a; Hao et al., 2024). However, its computational complexity grows quadratically with the length of the input sequence, making it challenging to model extremely long sequences.
Unlimited sequence length stands out as a noteworthy aspect within the realm of LLM, attracting considerable attention from researchers who seek intelligent solutions. The potential applications of LLM with unlimited sequence length are diverse, encompassing extended conversations in various professional domains and handling a vast number of tokens in multimodal modeling tasks.
In response to the quadratic complexity challenge, a promising resolution emerges in the form of linear attention. This method involves the elimination of the softmax operation and capitalizes on the associativity property of matrix products. Consequently, it significantly accelerates both training and inference procedures. To elaborate, linear attention reduces the computational complexity from $O(n^2)$ to $O(n)$ by leveraging the kernel trick (Katharopoulos et al., 2020b; Choromanski et al., 2020; Peng et al., 2021; Qin et al., 2022b) to compute the attention matrices, where $n$ represents the sequence length. This avenue holds substantial promise for augmenting the efficiency of transformer-style models across a broad spectrum of applications.
It is important to note that the notable reduction in complexity from $O(n^2)$ to $O(n)$ in linear attention is only theoretical and may not directly translate to a proportional improvement in computational efficiency on hardware in practice. The realization of practical wall-clock speedup faces challenges, primarily stemming from two issues: 1) the dominance of memory access (I/O) on the GPU could impact the overall computation speed of attention. 2) the cumulative summation (cumsum) needed by the linear attention kernel trick prevents it from reaching its theoretical training speed in the causal setting.
The first issue has been successfully addressed by Lightning Attention-1 (Qin et al., 2023b). In this paper, we introduce Lightning Attention-2 to solve the second issue. The key idea is to leverage the concept of “divide and conquer” by separately handling the intra block and inter block components in linear attention calculation. Specifically, for the intra blocks, we maintain the use of conventional attention computation mechanism to compute the product of QKV, while for the inter blocks, we employ the linear attention kernel trick (Katharopoulos et al., 2020b). Tiling techniques are implemented in both forward and backward procedures to fully leverage GPU hardware capabilities. As a result, the Lightning Attention-2 can train LLMs with unlimited sequence length without extra cost1, as its computational speed remains constant with increasing sequence length under fixed memory consumption.
Figure 1. Speed Showdown: FlashAttention vs. Lightning Attention in Expanding Sequence Lengths and Model Sizes. The diagram above provides a comparative illustration of training speed, in tokens per GPU per second (TGS), for LLaMA with FlashAttention-2, TransNormerLLM with Lightning Attention-1, and TransNormerLLM with Lightning Attention-2, implemented across three model sizes: 400M, 1B, and 3B from left to right. It is strikingly evident that Lightning Attention-2 manifests a consistent training speed irrespective of the increasing sequence length. Conversely, the training speed of the other methods declines significantly as the sequence length expands.
We performed a comprehensive evaluation of Lightning Attention-2 across a diverse range of sequence lengths to assess its accuracy and compare its computational speed and memory utilization with FlashAttention-2 (Dao, 2023) and Lightning Attention-1. The findings indicate that Lightning Attention-2 exhibits a notable advantage in computational speed, attributed to its innovative intra-inter separation strategy. Additionally, Lightning Attention-2 demonstrates a reduced memory footprint compared to its counterparts without compromising performance.
Linear Transformer architectures discard the Softmax Attention mechanism, replacing it with distinct approximations (Katharopoulos et al., 2020a; Choromanski et al., 2020; Peng et al., 2021; Qin et al., 2022b;a). The key idea is to leverage the “kernel trick” to accelerate the attention matrix computation, i.e., compute the product of keys and values first to circumvent the $n \times n$ matrix multiplication. Multiple methods have been proposed to replace the softmax operation. For instance, Katharopoulos et al. (2020a) employ the 1 + elu activation function, Qin et al. (2022b) utilize the cosine function to approximate softmax properties, and Ke et al. (2021); Zheng et al. (2022; 2023) leverage sampling strategies to directly mimic softmax operation. Despite having a theoretical complexity of $O(nd^2)$, the practical computational efficiency of linear attention diminishes notably in causal attention scenarios, primarily due to the necessity for cumsum operations (Hua et al., 2022).
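As a concrete illustration of this reassociation (a toy check, not code from the cited works), the snippet below verifies that without softmax the two multiplication orders give identical outputs while incurring very different costs: the left-to-right order materializes an $n \times n$ matrix ($O(n^2 d)$), whereas the right-to-left order only materializes a $d \times d$ one ($O(n d^2)$).

```python
import torch

n, d = 1024, 64
Q, K, V = torch.randn(3, n, d, dtype=torch.float64)

left = (Q @ K.T) @ V    # (QK^T)V: builds an n x n attention matrix, O(n^2 d)
right = Q @ (K.T @ V)   # Q(K^T V): builds only a d x d state, O(n d^2)

assert torch.allclose(left, right)  # identical when there is no softmax and no causal mask
```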
1 However, the sequence length may still be limited by hardware constraints, such as the GPU memory.
The FlashAttention series (Dao et al., 2022; Dao, 2023) focuses on system-level optimizations for the efficient implementation of the standard attention operator on GPU platforms. Extensive validation has demonstrated its effectiveness. The approach employs tiling strategies to minimize the volume of memory reads/writes between the GPU’s high bandwidth memory (HBM) and on-chip SRAM.
To address the issue of slow computation for Linear Attention in the causal setting, Lightning Attention-1 (Qin et al., 2023b) employs the approach of FlashAttention-1/2, which involves segmenting the inputs Q, K, V into blocks, transferring them from slow HBM to fast SRAM, and then computing the attention output with respect to these blocks. Subsequently, the final results are accumulated. Although this method is much more efficient than the PyTorch implementation, it does not take advantage of the computational characteristics inherent to Linear Attention, and the theoretical complexity remains $O(n^2 d)$.
A widely adopted strategy to tackle challenges related to length extrapolation involves the integration of Relative Positional Encoding (RPE) techniques (Su et al., 2021; Qin et al., 2023c), strategically directing attention towards neighboring tokens. ALiBi (Press et al., 2022) utilizes linear decay biases in attention mechanisms to mitigate the impact of distant tokens. Roformer (Su et al., 2021) introduces a novel Rotary Position Embedding (RoPE) method, widely embraced in the community, effectively leveraging positional information for transformer-based language model learning. Kerple (Chi et al., 2022) explores shift-invariant conditionally positive definite kernels within RPEs, introducing a suite of kernels aimed at enhancing length extrapolation properties, with ALiBi recognized as one of its instances. Furthermore, Sandwich (Chi et al., 2023) postulates a hypothesis elucidating the mechanism behind ALiBi, empirically validating it by incorporating the hypothesis into sinusoidal positional embeddings. Qin et al. (2024) explore the sufficient conditions for additive relative position encoding to have extrapolation capabilities.
Instead of investigating the length extrapolation capability of transformers, some works also attempt to directly increase the context window sizes. Chen et al. (2023) introduces Position Interpolation (PI), extending context window sizes of RoPE-based pretrained Large Language Models (LLMs) such as LLaMA models to up to 32768 with minimal fine-tuning (within 1000 steps). StreamingLLM (Xiao et al., 2023) proposes leveraging the attention sink phenomenon, maintaining the Key and Value information of initial tokens to substantially recover the performance of window attention. As the sequence grows longer, the performance degrades. These methods can only extend sequence length in fine-tuning or testing phases, while our method allows training models in long sequence lengths from scratch with no additional cost.
We first recall the formulation of linear attention and then introduce our proposed Lightning Attention-2. In the case of NormAttention within TransNormer (Qin et al., 2022a), attention computation deviates from the conventional Transformer structure (Vaswani et al., 2017) by eschewing the costly softmax and scaling operations. The NormAttention mechanism can be expressed as follows:
\[O = \text{Norm}((QK^\top)V),\]where $Q$, $K$, and $V \in \mathbb{R}^{n \times d}$ are the query, key, and value matrices, respectively, with $n$ denoting sequence length and $d$ representing feature dimension. To leverage the computational efficiency inherent in right matrix multiplication, the above equation can be seamlessly and mathematically equivalently transformed into its linear variant, as dictated by the properties of matrix multiplication:
\[O = \text{Norm}(Q(K^\top V)).\]This linear formulation facilitates recurrent prediction with a commendable complexity of $O(nd^2)$, rendering it efficient during training relative to sequence length. Furthermore, employing linear attention ensures a constant computation complexity of $O(d^2)$ irrespective of sequence length, thereby enabling inference over unlimited long sequences. This achievement is realized by updating $K^\top V$ recurrently without the need for repeated computation of the entire attention matrix.
In contrast, the standard softmax attention entails a computational complexity of $O(md^2)$ during the inference process, where $m$ denotes the token index. Nevertheless, when dealing with causal prediction tasks, the effectiveness of the right product is compromised, leading to the requirement for the computation of cumsum (Hua et al., 2022). This impediment hinders the potential for highly efficient parallel computation. Consequently, we persist with the conventional left matrix multiplication in Lightning Attention-1. This serves as the motivation behind the introduction of Lightning Attention-2, specifically crafted to address the challenges associated with the right product in such contexts.
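The toy snippet below (again an illustrative sketch) makes the causal obstacle explicit: once a causal mask is applied, the global reassociation no longer holds, and reproducing the masked left product with the right product requires a per-position prefix state, i.e., a cumulative sum of $d \times d$ outer products.

```python
import torch

n, d = 8, 4
Q, K, V = torch.randn(3, n, d, dtype=torch.float64)

# Causal left product: build the n x n score matrix, mask it, then multiply by V.
M = torch.tril(torch.ones(n, n, dtype=torch.float64))
o_left = ((Q @ K.T) * M) @ V

# Causal "right product": every position t needs its own prefix state
# kv_t = sum_{s<=t} k_s^T v_s, obtained here as a cumsum of d x d outer products.
kv = torch.cumsum(K[:, :, None] * V[:, None, :], dim=0)  # n x d x d prefix states
o_right = torch.einsum('td,tde->te', Q, kv)

assert torch.allclose(o_left, o_right)
```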
Lightning Attention-2 employs a tiling methodology throughout its whole computation process. Given the huge variance in memory bandwidth between HBM and SRAM within GPU, Lightning Attention-2 applies a distinct strategy for leveraging them. In each iteration $i$, matrices $Q_i$, $K_i$, $V_i$ undergo segmentation into blocks, subsequently transferred to SRAM for computation. The intra- and inter-block operations are segregated, with intra-blocks employing the left product and inter-blocks utilizing the right product. This approach optimally exploits the computational and memory efficiencies associated with the right product, enhancing overall execution speed.
The intermediate activation $KV$ is iteratively saved and accumulated within SRAM. Subsequently, the outputs of intra-blocks and inter-blocks are summed within SRAM, and the results are written back to HBM. This method aims to capitalize on the distinct advantages of each memory component, optimizing the computational workflow.
The structural framework of Lightning Attention-2 is well illustrated in Fig. 2. The intricate details of the Lightning Attention-2 implementation are explicated through Algorithm 1 (forward pass) and Algorithm 2 (backward pass). These algorithms serve to encapsulate the nuanced computational procedures integral to Lightning Attention-2. Additionally, we provide a comprehensive derivation to facilitate a more profound comprehension of Lightning Attention-2. The derivations are systematically presented for both the forward pass and the backward pass, contributing to a thorough understanding of the underlying mechanisms.
Algorithm 1: Lightning Attention-2 Forward Pass
Input: $Q, K, V \in \mathbb{R}^{n \times d}$, decay rate $\lambda \in \mathbb{R}^+$, block size $B$.
Divide $X$ into $T = n/B$ blocks $X_1, X_2, \ldots, X_T$ of size $B \times d$ each, where $X \in \{Q, K, V, O\}$.
Initialize mask $M \in \mathbb{R}^{B \times B}$, where $M_{ij} = \lambda^{i-j}$ if $i \geq j$, else $0$.
Initialize $\Lambda = \text{diag}\{\lambda, \lambda^2, \ldots, \lambda^B\} \in \mathbb{R}^{B \times B}$.
Initialize $KV = 0 \in \mathbb{R}^{d \times d}$.
for $1 \leq i \leq T$ do
Load $Q_i, K_i, V_i \in \mathbb{R}^{B \times d}$ from HBM to on-chip SRAM.
On chip, compute $O_{\text{intra}} = [(Q_i K_i^\top) \odot M] V_i$.
On chip, compute $O_{\text{inter}} = \Lambda Q_i (KV)$.
On chip, compute $KV = \lambda^B KV + (\lambda^B \Lambda^{-1} K_i)^\top V_i$.
Write $O_i = O_{\text{intra}} + O_{\text{inter}}$ to HBM as the $i$-th block of $O$.
end for
return $O$.
We ignore the $\text{Norm}(\cdot)$ operator in eq. (2) to simplify the derivations. During the forward pass of Lightning Attention-2, the $t$-th output can be formulated as
\[o_t = q_t \sum_{s \leq t} \lambda^{t-s} k_s^\top v_s.\]In a recursive form, the above equation can be rewritten as
\[\begin{aligned} kv_0 &= 0 \in \mathbb{R}^{d \times d}, \\ kv_t &= \lambda kv_{t-1} + k_t^\top v_t, \\ o_t &= q_t(kv_t), \end{aligned}\]where
\[kv_t = \sum_{s \leq t} \lambda^{t-s} k_s^\top v_s.\]To perform tiling, let us write the equations in block form. Given the total sequence length $n$ and block size $B$, $X$ is divided into $T = \frac{n}{B}$ blocks $\{X_1, X_2, \ldots, X_T\}$ of size $B \times d$ each, where $X \in \{Q, K, V, O\}$.
We first define
\[\begin{aligned} KV_0 &= 0 \in \mathbb{R}^{d \times d}, \\ KV_t &= \sum_{s \leq tB} \lambda^{tB-s} k_s^\top v_s. \end{aligned}\]Given $KV_t$, the output of $(t+1)$-th block, i.e., $tB+r$, with $1 \leq r \leq B$ is
\[\begin{aligned} o_{tB+r} &= q_{tB+r} \sum_{s \leq tB+r} \lambda^{tB+r-s} k_s^\top v_s \\ &= q_{tB+r} \left( \sum_{s=tB+1}^{tB+r} \lambda^{tB+r-s} k_s^\top v_s + \lambda^r \sum_{s \leq tB} \lambda^{tB-s} k_s^\top v_s \right) \\ &= q_{tB+r} \sum_{s=tB+1}^{tB+r} \lambda^{tB+r-s} k_s^\top v_s + \lambda^r q_{tB+r} kv_{tB}. \end{aligned}\]Rewritten in matrix form, we have
\[O_{t+1} = \underbrace{[(Q_{t+1} K_{t+1}^\top) \odot M] V_{t+1}}_{\text{Intra Block}} + \underbrace{\Lambda Q_{t+1}(KV_t)}_{\text{Inter Block}},\]where
\[M_{st} = \begin{cases} \lambda^{s-t} & s \geq t \\ 0 & s < t \end{cases}, \quad \Lambda = \text{diag}\{\lambda, \lambda^2, \ldots, \lambda^B\}.\]And the $KV$ at the $(t+1)$-th block can be written as
\[\begin{aligned} KV_{t+1} &= \sum_{s \leq (t+1)B} \lambda^{(t+1)B-s} k_s^\top v_s \\ &= \lambda^B \sum_{s \leq tB} \lambda^{tB-s} k_s^\top v_s + \sum_{s=tB+1}^{(t+1)B} \lambda^{(t+1)B-s} k_s^\top v_s \\ &= \lambda^B KV_t + \left(\text{diag}\{\lambda^{B-1}, \ldots, 1\} K_{t+1}\right)^\top V_{t+1} \\ &= \lambda^B KV_t + (\lambda^B \Lambda^{-1} K_{t+1})^\top V_{t+1}. \end{aligned}\]The complete expression of the forward pass of Lightning Attention-2 can be found in Algorithm 1.
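The PyTorch sketch below mirrors Algorithm 1 at the block level and checks it against the naive $O(n^2 d)$ decayed causal attention. It is a reference illustration under assumed names (e.g. `lightning_attn2_fwd`); the actual implementation runs the loop body on-chip as a Triton kernel and applies $\text{Norm}(\cdot)$ afterwards.

```python
import torch


def lightning_attn2_fwd(Q, K, V, lam, B):
    """Block-level sketch of Algorithm 1 (forward pass); assumes n is divisible by B."""
    n, d = Q.shape
    r = torch.arange(1, B + 1, dtype=Q.dtype)                 # local positions 1..B
    M = torch.where(r[:, None] >= r[None, :],
                    lam ** (r[:, None] - r[None, :]),
                    torch.zeros((), dtype=Q.dtype))           # M_ij = lam^(i-j) for i >= j
    Lam = lam ** r                                            # diag(Lambda) = (lam, ..., lam^B)
    decay = lam ** (B - r)                                    # diag(lam^B * Lambda^{-1})
    KV = torch.zeros(d, d, dtype=Q.dtype)                     # accumulated inter-block state
    O = torch.empty_like(Q)
    for i in range(n // B):
        s = slice(i * B, (i + 1) * B)
        Qi, Ki, Vi = Q[s], K[s], V[s]
        O_intra = ((Qi @ Ki.T) * M) @ Vi                      # left product inside the block
        O_inter = (Lam[:, None] * Qi) @ KV                    # right product against past blocks
        O[s] = O_intra + O_inter
        KV = (lam ** B) * KV + (decay[:, None] * Ki).T @ Vi   # KV = lam^B KV + (lam^B Lam^-1 K_i)^T V_i
    return O


# Sanity check against the naive O(n^2 d) decayed causal attention.
n, d, B, lam = 32, 8, 8, 0.95
Q, K, V = torch.randn(3, n, d, dtype=torch.float64)
p = torch.arange(1, n + 1, dtype=torch.float64)
M_full = torch.where(p[:, None] >= p[None, :], lam ** (p[:, None] - p[None, :]),
                     torch.zeros((), dtype=torch.float64))
assert torch.allclose(lightning_attn2_fwd(Q, K, V, lam, B), ((Q @ K.T) * M_full) @ V)
```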
Algorithm 2: Lightning Attention-2 Backward Pass
Input:
* $Q, K, V, dO \in \mathbb{R}^{n \times d}$: input tensors (queries, keys, values, output gradients)
* decay rate $\lambda \in \mathbb{R}^+$
* block size $B$ for memory chunking
Divide Inputs:
* Divide $X \in \{Q, K, V\}$ into $T = n/B$ blocks $X_1, \ldots, X_T$ of size $B \times d$ each.
* Divide $dX \in \{dQ, dK, dV, dO\}$ into $T = n/B$ blocks $dX_1, \ldots, dX_T$ of size $B \times d$ each.
Initialization:
* Mask $M \in \mathbb{R}^{B \times B}$, where $M_{ij} = \lambda^{i-j}$ if $i \geq j$, else $0$.
* $\Lambda = \text{diag}\{\lambda, \lambda^2, \ldots, \lambda^B\} \in \mathbb{R}^{B \times B}$: diagonal matrix of decay rates.
* $KV = 0, dKV = 0 \in \mathbb{R}^{d \times d}$: accumulators for the key-value state and its gradient.
Loop 1 (left to right, computes dQ):
1. For $i = 1$ to $T$:
* Load $K_i, V_i, O_i, dO_i \in \mathbb{R}^{B \times d}$ from HBM to on-chip SRAM.
* On-chip computation:
* $dQ_{\text{intra}} = [(dO_i V_i^\top) \odot M] K_i$
* $dQ_{\text{inter}} = \Lambda dO_i (KV)^\top$
* On-chip update: $KV = \lambda^B KV + (\lambda^B \Lambda^{-1} K_i)^\top V_i$
* Write $dQ_i = dQ_{\text{intra}} + dQ_{\text{inter}}$ to HBM as the $i$-th block of $dQ$.
Loop 2 (right to left, computes dK and dV):
2. For $i = T$ down to $1$:
* Load $Q_i, K_i, V_i, O_i, dO_i \in \mathbb{R}^{B \times d}$ from HBM to on-chip SRAM.
* On-chip computation:
* $dK_{\text{intra}} = [(dO_i V_i^\top) \odot M]^\top Q_i$
* $dK_{\text{inter}} = (\lambda^B \Lambda^{-1} V_i) (dKV)^\top$
* $dV_{\text{intra}} = [(Q_i K_i^\top) \odot M]^\top dO_i$
* $dV_{\text{inter}} = (\lambda^B \Lambda^{-1} K_i) dKV$
* On-chip update: $dKV = \lambda^B dKV + (\Lambda Q_i)^\top dO_i$
* Write $dK_i = dK_{\text{intra}} + dK_{\text{inter}}$, $dV_i = dV_{\text{intra}} + dV_{\text{inter}}$ to HBM as the $i$-th blocks of $dK$, $dV$.
Output:
* Return $dQ, dK, dV$: gradient tensors for queries, keys, and values.
For the backward pass, let us consider the reverse process. First, given $do_t$, we have
\[\begin{aligned} dq_t &= do_t (kv_t)^\top \in \mathbb{R}^{1 \times d}, \\ dk_t &= v_t (dkv_t)^\top \in \mathbb{R}^{1 \times d}, \\ dv_t &= k_t (dkv_t) \in \mathbb{R}^{1 \times d}, \\ dkv_t &= \sum_{s \geq t} \lambda^{s-t} q_s^\top do_s \in \mathbb{R}^{d \times d}. \end{aligned}\]By writing $dkv_t$ in a recursive form, we get
\[\begin{aligned} dkv_{n+1} &= 0 \in \mathbb{R}^{d \times d}, \\ dkv_{t-1} &= \lambda dkv_t + q_{t-1}^\top do_{t-1}. \end{aligned}\]To facilitate the understanding of tiling, let us consider the above equations in block style. Given the total sequence length $n$ and block size $B$, $X$ is divided into $T = \frac{n}{B}$ blocks $\{X_1, X_2, \ldots, X_T\}$ of size $B \times d$ each, where $X \in \{Q, K, V, O, dO\}$.
We first define
\[\begin{aligned} dKV_{T+1} &= 0 \in \mathbb{R}^{d \times d}, \\ dKV_t &= \sum_{s > tB} \lambda^{s-tB} q_s^\top do_s. \end{aligned}\]Then, for the $(t+1)$-th block, i.e., positions $tB+r$ with $1 \leq r \leq B$, we have
\[\begin{aligned} dq_{tB+r} &= do_{tB+r} \sum_{s \leq tB+r} \lambda^{tB+r-s} v_s^\top k_s \\ &= do_{tB+r} \left( \sum_{s=tB+1}^{tB+r} \lambda^{tB+r-s} v_s^\top k_s + \lambda^r \sum_{s \leq tB} \lambda^{tB-s} v_s^\top k_s \right) \\ &= do_{tB+r} \sum_{s=tB+1}^{tB+r} \lambda^{tB+r-s} v_s^\top k_s + \lambda^r do_{tB+r} kv_{tB}^\top. \end{aligned}\]In matrix form, we have
\[dQ_{t+1} = \underbrace{[(dO_{t+1} V_{t+1}^\top) \odot M] K_{t+1}}_{\text{Intra Block}} + \underbrace{\Lambda dO_{t+1} (KV_t^\top)}_{\text{Inter Block}}.\]Since the recursion of $dKV_t$ steps from $t+1$ to $t$, given $dKV_t$, the gradient $dK_t$ for the $t$-th block, i.e., at positions $(t-1)B+r$ with $0 < r \leq B$, is
\[\begin{aligned} dk_{(t-1)B+r} &= v_{(t-1)B+r} \sum_{s \geq (t-1)B+r} \lambda^{s-(t-1)B-r} do_s^\top q_s \\ &= v_{(t-1)B+r} \left( \sum_{s=(t-1)B+r}^{tB} \lambda^{s-(t-1)B-r} do_s^\top q_s \right) \\ &\quad + v_{(t-1)B+r} \lambda^{B-r} \sum_{s > tB} \lambda^{s-tB} do_s^\top q_s \\ &= v_{(t-1)B+r} \sum_{s=(t-1)B+r}^{tB} \lambda^{s-(t-1)B-r} do_s^\top q_s + \lambda^{B-r} v_{(t-1)B+r} dKV_t^\top. \end{aligned}\]In matrix form, we get
\[dK_t = \underbrace{[(dO_t V_t^\top) \odot M]^\top Q_t}_{\text{Intra Block}} + \underbrace{\lambda^B \Lambda^{-1} V_t (dKV_t^\top)}_{\text{Inter Block}}.\]Considering $dV_t$ for the $t$-th block, i.e., at positions $(t-1)B+r$, $0 < r \leq B$, we have
\[\begin{aligned} dv_{(t-1)B+r} &= k_{(t-1)B+r} \sum_{s \geq (t-1)B+r} \lambda^{s-(t-1)B-r} q_s^\top do_s \\ &= k_{(t-1)B+r} \left( \sum_{s=(t-1)B+r}^{tB} \lambda^{s-(t-1)B-r} q_s^\top do_s \right) \\ &\quad + k_{(t-1)B+r} \lambda^{B-r} \sum_{s > tB} \lambda^{s-tB} q_s^\top do_s \\ &= k_{(t-1)B+r} \sum_{s=(t-1)B+r}^{tB} \lambda^{s-(t-1)B-r} q_s^\top do_s + \lambda^{B-r} k_{(t-1)B+r} dKV_t. \end{aligned}\]In matrix form, we get
\[dV_t = \underbrace{[(Q_t K_t^\top) \odot M]^\top dO_t}_{\text{Intra Block}} + \underbrace{\lambda^B \Lambda^{-1} K_t (dKV_t)}_{\text{Inter Block}}.\]Finally, the recursive relation for $dKV_t$ is
\[\begin{aligned} dKV_t &= \sum_{s > tB} \lambda^{s-tB} q_s^\top do_s \\ &= \lambda^B \sum_{s > (t+1)B} \lambda^{s-(t+1)B} q_s^\top do_s + \sum_{s=tB+1}^{(t+1)B} \lambda^{s-tB} q_s^\top do_s \\ &= \lambda^B dKV_{t+1} + (\Lambda Q_{t+1})^\top dO_{t+1}. \end{aligned}\]
Algorithm 2 describes the backward pass of Lightning Attention-2 in more detail.
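In the same spirit as the forward sketch, the following PyTorch code mirrors Algorithm 2 at the block level and checks the resulting gradients against autograd on the naive decayed causal attention. The helper name `lightning_attn2_bwd` is assumed, and plain tensors stand in for the on-chip SRAM buffers of the Triton kernel.

```python
import torch


def lightning_attn2_bwd(Q, K, V, dO, lam, B):
    """Block-level sketch of Algorithm 2 (backward pass); assumes n is divisible by B."""
    n, d = Q.shape
    r = torch.arange(1, B + 1, dtype=Q.dtype)
    M = torch.where(r[:, None] >= r[None, :], lam ** (r[:, None] - r[None, :]),
                    torch.zeros((), dtype=Q.dtype))
    Lam = lam ** r                    # diag(Lambda)
    decay = lam ** (B - r)            # diag(lam^B * Lambda^{-1})
    dQ, dK, dV = torch.empty_like(Q), torch.empty_like(K), torch.empty_like(V)

    # Loop 1: dQ depends on the *past* state KV_t, so it scans left to right.
    KV = torch.zeros(d, d, dtype=Q.dtype)
    for i in range(n // B):
        s = slice(i * B, (i + 1) * B)
        Ki, Vi, dOi = K[s], V[s], dO[s]
        dQ[s] = ((dOi @ Vi.T) * M) @ Ki + (Lam[:, None] * dOi) @ KV.T
        KV = (lam ** B) * KV + (decay[:, None] * Ki).T @ Vi

    # Loop 2: dK and dV depend on the *future* state dKV_t, so it scans right to left.
    dKV = torch.zeros(d, d, dtype=Q.dtype)
    for i in reversed(range(n // B)):
        s = slice(i * B, (i + 1) * B)
        Qi, Ki, Vi, dOi = Q[s], K[s], V[s], dO[s]
        dK[s] = ((dOi @ Vi.T) * M).T @ Qi + (decay[:, None] * Vi) @ dKV.T
        dV[s] = ((Qi @ Ki.T) * M).T @ dOi + (decay[:, None] * Ki) @ dKV
        dKV = (lam ** B) * dKV + (Lam[:, None] * Qi).T @ dOi
    return dQ, dK, dV


# Check against autograd on the naive decayed causal attention.
n, d, B, lam = 32, 8, 8, 0.95
Q = torch.randn(n, d, dtype=torch.float64, requires_grad=True)
K = torch.randn(n, d, dtype=torch.float64, requires_grad=True)
V = torch.randn(n, d, dtype=torch.float64, requires_grad=True)
dO = torch.randn(n, d, dtype=torch.float64)
p = torch.arange(1, n + 1, dtype=torch.float64)
M_full = torch.where(p[:, None] >= p[None, :], lam ** (p[:, None] - p[None, :]),
                     torch.zeros((), dtype=torch.float64))
(((Q @ K.T) * M_full) @ V).backward(dO)
dQ, dK, dV = lightning_attn2_bwd(Q.detach(), K.detach(), V.detach(), dO, lam, B)
assert all(torch.allclose(a, b) for a, b in zip((dQ, dK, dV), (Q.grad, K.grad, V.grad)))
```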
Discussion
A recent method, GLA (Yang et al., 2023), models sequences using linear attention with data-dependent decay. Its chunk-wise Block-Parallel Algorithm employs tiling and IO-aware concepts. However, unlike Lightning Attention-2, it uses parallel computations for each block, which leads to higher memory usage. RetNet (Sun et al., 2023b) is very similar in structure to TransNormerLLM (Qin et al., 2023b) and uses the chunk-wise retention algorithm. This algorithm is comparable to the forward pass of Lightning Attention-2, but it considers neither IO-awareness nor the backward pass.
To comprehensively assess Lightning Attention-2’s performance, speed, and memory utilization, we conducted extensive experiments on the TransNormerLLM model, with Lightning Attention-2 integrated. Our implementation utilizes the Metaseq framework (Zhang et al., 2022), a PyTorch-based sequence modeling framework (Paszke et al., 2019). All experiments are executed on the GPU cluster featuring 128 A100 80G GPUs. The deployment of Lightning Attention-2 is implemented in Triton (Tillet et al., 2019).
We conducted a comparison of speed and memory usage among the attention modules Lightning Attention-1, Lightning Attention-2, and FlashAttention-2, all on a single A100 80G GPU. As depicted in Figure 3, the analysis focuses on the runtime, measured in milliseconds, of the forward and backward passes separately. The baseline runtime demonstrates a quadratic growth relative to the sequence length. In contrast, Lightning Attention-2 exhibits markedly superior performance with linear growth. Notably, as the sequence length increases, this disparity in runtime becomes increasingly apparent. In addition to speed enhancements, our method also maintains a significant advantage in memory usage as the sequence length increases.
Table 2. Language Modeling Comparison between TransNormerLLM with Lightning Attention-1 and Lightning Attention-2.
Performance Evaluation In Table 2, we evaluated the performance of the TransNormerLLM-0.4B model under 2K contexts, comparing two variants: one equipped with Lightning Attention-1 and the other with Lightning Attention-2. These experiments were carried out using 8×A100 80G GPUs. After 100,000 iterations on a corpus sampled from our 300B-token corpus with the same initial seed, we observed a marginal performance difference: the variant with Lightning Attention-2 showed a performance decrement of 0.001 compared to its counterpart with Lightning Attention-1.
Furthermore, our analysis extended to benchmarking top-tier efficient large language models, including LLaMA-FA2 (Touvron et al., 2023a; Dao, 2023), TNL-LA2, HGRN (Qin et al., 2023d), and TNN (Qin et al., 2023a). This benchmarking focused on training loss using a 30B subset of our uniquely assembled corpus, scaling from 1 to 3 billion parameters. As depicted in Figure 4, the TNL-LA2 model achieved marginally lower loss than the other models under review at both the 1B and 3B parameter scales.
Figure 3. Comparative Analysis of Speed and Memory Usage: FlashAttention vs. Lightning Attention. Upper Section: Runtime in milliseconds for the forward and backward pass across varying sequence lengths. Lower Section: Memory utilization during the forward and backward pass at different sequence lengths.
Table 1. Efficiency Comparison of LLaMA with FlashAttention2, TransNormerLLM with Lightning Attention-1, and TransNormerLLM with Lightning Attention-2. The statistical analysis was performed using 2×A100 80G GPUs. The table reports Tokens per GPU per Second (TGS) across three different model sizes, within context ranges spanning from 1K to 92K. OOM stands for out of GPU memory.
Efficiency Evaluation In Table 1, we present a comparative analysis of training speeds under the same corpora and hardware setups. This comparison encompasses three variants: TransNormerLLM with Lightning Attention-2 (TNL-LA2), TransNormerLLM with Lightning Attention-1 (TNL-LA1), and LLaMA with FlashAttention2 (LLaMA-FA2). Our findings show that during both the forward and backward passes, the TGS (tokens per GPU per second) for TNL-LA2 remains consistently high, while the other two models exhibit a rapid decline when the sequence length is scaled from 1K to 92K. This pattern suggests that Lightning Attention-2 offers a significant advancement in managing unlimited sequence lengths in LLM.
Language Model
To evaluate the performance of Lightning Attention-2, we conducted an analysis of TransNormerLLM-15B (Qin et al., 2023b), a model comprising 15 billion parameters. TransNormerLLM-15B is characterized by its 42 layers, 40 attention heads, and an overall embedding dimension of 5120. The model is being trained on a corpus of more than 1.3 trillion tokens with a sequence length of 6,144. Notably, the model achieved a processing speed of 1,620 tokens per GPU per second. Given that the comprehensive pre-training phase is scheduled to span three months, we hereby present the most recent results from the latest checkpoint in Table 3.
Figure 4. Performance Comparison of HGRN, TNN, LLaMA with FlashAttention2 and TransNormerLLM with Lightning Attention-2. For the 1B model, we used 16×A800 80G GPUs with a batch size of 12 per GPU; for the 3B model, we scaled up to 32×A800 80G GPUs and a batch size of 30 per GPU. The training context length was set to 2K.
Table 3. Performance Comparison on Commonsense Reasoning and Aggregated Benchmarks. TNL-LA2: TransNormerLLM with Lightning Attention-2. PS: parameter size (billion). T: tokens (billion). HS: HellaSwag. WG: WinoGrande.
This evaluation is conducted using the lm-evaluation-harness framework (Gao et al., 2023). Our benchmark focuses on two key areas: Commonsense Reasoning (CSR) and Multiple Choice Questions (MCQ). For comparative analysis, we also evaluated the Pythia-12B (Biderman et al., 2023) model under the same benchmarks.
Commonsense Reasoning We report BoolQ (Clark et al., 2019), PIQA (Bisk et al., 2019), SIQA (Sap et al., 2019), HellaSwag (Zellers et al., 2019), WinoGrande (Sakaguchi et al., 2019), ARC easy and challenge (Clark et al., 2018), OpenBookQA (Mihaylov et al., 2018) and their average. In all CSR tasks, the performance of TransNormerLLM-15B surpassed Pythia-12B by about 2%. Furthermore, TransNormerLLM-15B-100B showed an approximate 3.5% improvement over its 50 billion-token stage, especially in the HellaSwag task, with over a 5% performance increase.
Aggregated Benchmarks We report the overall results for MMLU (Hendrycks et al., 2021) and C-Eval (Huang et al., 2023) with both 0-shot and 5-shot settings. In the C-Eval tasks, TransNormerLLM-15B is about 2% higher than Pythia-12B. In the 0-shot and 5-shot tests in both Chinese (C-Eval) and English (MMLU), TransNormerLLM-15B’s performance also exceeded the 25% baseline (the probability of random selection in a 4 choice scenario). We also noticed fluctuations in the 5-shot MCQ tasks, with an average MCQ score of around 26.5%.
In this paper, we introduced Lightning Attention-2, a pioneering implementation of linear attention that effectively harnesses its theoretical computational advantages, particularly in the causal setting. Our approach, which adopts the concepts of “divide and conquer” and tiling techniques, successfully addresses the limitations of current linear attention algorithms, especially the challenges associated with cumulative summation. By separating the computation into intra-block and inter-block components, we effectively leverage GPU hardware to its fullest potential, ensuring efficiency. Our extensive experiments across various model sizes and sequence lengths demonstrate that Lightning Attention-2 not only maintains consistent training speeds regardless of input sequence length but also outperforms existing state-of-the-art attention mechanisms in terms of speed and accuracy. This breakthrough has profound implications for the future of large language models, particularly those requiring the processing of long sequences. Looking ahead, we intend to introduce sequence parallelism in conjunction with Lightning Attention-2, which aims to facilitate the training of extra-long sequences, effectively overcoming existing hardware constraints.