Contents
1. 서론
트랜스포머 모델에서 컨텍스트 길이를 확장하는 것은 계산 비용과 메모리 요구 사항이 입력 시퀀스 길이에 따라 제곱으로 증가하기 때문에 큰 도전이 됩니다. 이런 문제를 해결하기 위해 Dao 등은 FlashAttention 알고리즘을 제안하여 표준 attention의 계산과 메모리 사용을 줄이는 방법을 모색하였습니다. 본 논문에서는 FlashAttention의 개선 버전인 FlashAttention-2를 소개하고, 효율성을 더욱 향상시키기 위한 방법을 제안합니다.
2. 배경
2.1 하드웨어 특성
GPU는 높은 병렬 처리 능력을 가지고 있으며, 특히 행렬 곱셈을 위한 텐서 코어 같은 전용 유닛이 탑재되어 있습니다. 이런 특성은 행렬 곱셈 작업을 빠르게 수행할 수 있게 해 주며, FlashAttention과 같은 최적화된 알고리즘에서 중요한 역할을 합니다.
2.2 표준 Attention 구현
기본적인 Attention 메커니즘은 입력 시퀀스 \(Q\), \(K\), \(V\)에 대해 다음과 같은 수식을 계산합니다.
\(S = QK^T\) \(P = \text{softmax}(S)\) \(O = PV\)
이 과정은 메모리와 계산량이 많이 필요한 작업으로, 특히 \(S\)와 \(P\)를 메모리에 저장하는 것이 큰 부담이 됩니다.
3. FlashAttention-2: 알고리즘, 병렬 처리 및 작업 분할
FlashAttention-2는 기존 FlashAttention의 비행렬 곱셈 연산(FLOPs)을 줄이는 것을 목표로 합니다. 이는 GPU에서 행렬 곱셈이 훨씬 빠르게 수행될 수 있도록 하기 위함입니다.
3.1 알고리즘 개선
FlashAttention-2는 online softmax 기법을 사용하여 불필요한 메모리 접근을 최소화합니다. 이를 통해 각 블록의 softmax를 계산하고 최종 결과를 조정하여 정확한 출력을 도출할 수 있습니다. 예를 들어, 다음과 같이 계산을 수행합니다.
\[\begin{align*} &\text{For each block } i: \\ &m_i = \max(S_i) \\ &\ell_i = \sum \exp(S_i - m_i) \\ &P_i = \frac{\exp(S_i - m_i)}{\ell_i} \\ &O_i = P_i V \\ \end{align*}\]3.2 병렬 처리 및 작업 분할
FlashAttention-2는 시퀀스 길이 차원뿐만 아니라 배치 및 헤드 수 차원에서도 계산을 병렬화하여 GPU 자원의 활용도를 높이며, 특히 긴 시퀀스를 처리할 때 효과적입니다.
또, 각 스레드 블록 내에서 다른 워프들 사이의 작업을 분할하여 통신과 공유 메모리 접근을 최소화해서 전반적인 속도를 높입니다.
3.3 FlashAttention-2의 포워드 및 백워드 패스
FlashAttention-2는 기존의 FlashAttention 알고리즘을 개선하여 비행렬 곱셈 연산을 줄이고, GPU의 병렬 처리 능력을 극대화하기 위한 방법을 도입했습니다.
3.3.1 포워드 패스
FlashAttention-2의 포워드 패스는 입력 시퀀스 \(Q\), \(K\), \(V\)를 처리하여 출력 \(O\)를 계산합니다. 주요 개선 사항은 메모리 접근을 최소화하고, 계산 효율성을 극대화하는 것입니다.
3.3.2 백워드 패스
백워드 패스는 학습 과정에서 그래디언트를 계산하여 모델의 가중치를 업데이트하는 데 사용됩니다. FlashAttention-2는 백워드 패스를 최적화하여 메모리 사용량을 줄이고 계산 속도를 향상시킵니다.
4. 실증적 검증
FlashAttention-2는 다양한 설정에서 기존 FlashAttention 대비 약 2배의 속도 향상을 보였습니다. 특히, 이런 개선은 GPT 스타일 모델의 학습 속도를 증가시켜, 최대 225 TFLOPs/s의 성능을 A100 GPU에서 달성하였습니다.
TinyLlama-1.1B
The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens. With some proper optimization, we can achieve this within a span of “just” 90 days using 16 A100-40G GPUs. The training has started on 2023-09-01.
We adopted exactly the same architecture and tokenizer as Llama 2. This means TinyLlama can be plugged and played in many open-source projects built upon Llama. Besides, TinyLlama is compact with only 1.1B parameters. This compactness allows it to cater to a multitude of applications demanding a restricted computation and memory footprint.
Releases Schedule
We will be rolling out intermediate checkpoints following the below schedule. We also include some baseline models for comparison.
Date | HF Checkpoint | Tokens | Step | HellaSwag Acc_norm |
---|---|---|---|---|
Baseline | StableLM-Alpha-3B | 800B | – | 38.31 |
Baseline | Pythia-1B-intermediate-step-50k-105b | 105B | 50k | 42.04 |
Baseline | Pythia-1B | 300B | 143k | 47.16 |
2023-09-04 | TinyLlama-1.1B-intermediate-step-50k-105b | 105B | 50k | 43.50 |
Meanwhile, you can track the live cross entropy loss here.
Potential Usecase
Tiny but strong language models are useful for many applications. Here are some potential usecases:
Moreover, our code can be a reference for enthusiasts keen on pretraining language models under 5 billion parameters without diving too early into Megatron-LM.
Training Details
Below are some details of our training setup:
Setting | Description |
---|---|
Parameters | 1.1B |
Attention Variant | Grouped Query Attention |
Model Size | Layers: 22, Heads: 32, Query Groups: 4, Embedding Size: 2048, Intermediate Size (Swiglu): 5632 |
Sequence Length | 2048 |
Batch Size | 2 million tokens (2048 * 1024) |
Learning Rate | 4e-4 |
Learning Rate Schedule | Cosine with 2000 warmup steps |
Training Data | Slimpajama & Starcoderdata |
Data Preprocessing | Excluded GitHub subset of Slimpajama; Sampled all code from Starcoderdata |
Combined Dataset Size | Around 950B tokens |
Total Tokens During Training | 3 trillion (slightly more than 3 epochs/1430k steps) |
Natural Language to Code Ratio | 7:3 |
Hardware | 16 A100-40G GPUs |
Blazingly Fast
Our codebase supports the following features:
Thanks to those optimizations, we achieve a throughput of 24k tokens per second per A100-40G GPU, which translates to 56% model flops utilization without activation checkpointing (We expect the MFU to be even higher on A100-80G). It means you can train a chinchilla-optimal TinyLlama (1.1B param, 22B tokens) in 32 hours with 8 A100. Those optimizations also greatly reduce the memory footprint, allowing us to stuff our 1.1B model into 40GB GPU RAM and train with a per-gpu batch size of 16k tokens. You can also pretrain TinyLlama on 3090/4090 GPUs with a smaller per-gpu batch size. Below is a comparison of the training speed of our codebase with that of Pythia and MPT.
Model | A100 GPU hours taken on 300B tokens |
---|---|
TinyLlama-1.1B | 3456 |
Pythia-1.0B | 4830 |
MPT-1.3B | 7920 |
The Pythia number comes from their paper. The MPT number comes from here, in which they say MPT-1.3B “ was trained on 440 A100-40GBs for about half a day” on 200B tokens.
The fact that TinyLlama is a relatively small model with grouped query attention means it is also fast during inference. Below are some throughputs that we measure:
Framework | Device | Settings | Throughput (tokens/sec) |
---|---|---|---|
Llama.cpp | Mac M2 16GB RAM | batch_size=1; 4-bit inference | 71.8 |
vLLM | A40 GPU | batch_size=100, n=10 | 7094.5 |