1. 서론
트랜스포머 모델은 반복적인 시퀀스 모델을 대체하는 유명한 구조로 자리잡았습니다. 이 모델은 attention 레이어를 사용하여 시퀀스 간의 정보를 효과적으로 전달합니다. 그러나, 트랜스포머 모델의 증분 인퍼런스 속도는 Key, Value 텐서를 다시 불러오는 데 필요한 메모리 대역폭에 의해 제한됩니다. 이 연구에서는 Multi-head Attention의 성능을 분석하고, 새로운 구조인 Multi-query Attention를 제안하여 인퍼런스 속도를 개선할 방안을 모색합니다.
2. 이론적 배경 및 선행 연구
2.1 Dot-Product Attention
Dot-Product Attention는 주어진 질의 벡터 \(q\)와 여러 키-값 쌍 \((K, V)\)을 입력으로 받아 출력 벡터 \(y\)를 생성합니다. 이 출력은 값 벡터의 가중합으로 계산되며, 가중치는 질의와 각 키 사이의 점곱을 통해 결정됩니다. 수식으로 표현하면 다음과 같습니다.
\(y = \text{softmax}(qK^T)V\)
상기 수식에서 \(\text{softmax}\)는 키 벡터와의 유사도를 확률로 변환하는 함수입니다.
2.2 Multi-head Attention
트랜스포머 모델은 입력 벡터 \(x\)로부터 \(h\)개의 다른 질의 벡터를 생성하고, 이를 독립적인 attention 메커니즘에 적용합니다. 이를 통해 다양한 표현의 학습이 가능해지며, 모델의 표현력이 향상됩니다. 각 머리는 입력 \(x\), \(M\)에 대한 선형 변환 \(P_q, P_k, P_v\)를 적용하고, 각 머리의 출력은 합산되어 최종 출력 \(y\)를 형성합니다.
\[y = \sum_{i=1}^h \text{softmax}(xP_{q_i} (MP_{k_i})^T)MP_{v_i}\]3. 방법: Multi-query Attention
Multi-query Attention는 기존의 Multi-head Attention와 유사하지만, 모든 머리가 동일한 키와 값 집합을 공유합니다. 이 변경으로 인해 메모리 접근과 계산량이 크게 감소합니다. 구체적으로, 각 단계에서 메모리 접근과 계산의 비율은 다음과 같이 표현할 수 있습니다.
다음을 좀 더 깔끔하게 만들면 다음과 같습니다.
\[\text{Memory-to-Computation Ratio} = O\left(\frac{1}{b}\right)\]메모리 대 연산 비율은 \(O\left(\frac{1}{b}\right)\)로 나타낼 수 있으며, \(b\)는 배치 크기입니다. 이 비율의 감소는 계산 속도의 향상으로 직결됩니다.
4. 실험 및 결과
WMT 2014 영어-독일어 번역 태스크를 사용하여 모델을 평가했습니다. Multi-query Attention 모델은 기존 모델과 비슷한 품질을 유지하면서도 메모리 사용량과 계산 시간을 줄였습니다. 학습과 인퍼런스 속도 모두 기존 대비 개선된 결과를 보였으며, 이는 Multi-query Attention가 트랜스포머 모델의 효율적인 변형임을 입증합니다.
5. 결론
본 연구에서 제안한 Multi-query Attention는 트랜스포머 모델의 계산 및 메모리 효율성을 개선함으로써, 큰 시퀀스나 복잡한 모델에서의 활용 가능성을 높입니다.
1 구글에서 TPU 개발을 하던 개발진들과 NVIDIA 개발진들이 모여 설립한 GROQ은 기존 컴퓨터 하드웨어보다 처리 능력이 더 우수한 칩을 사용하여 AI 계산을 가속화하는 것을 목표로 다양한 서비스를 준비하고 있습니다.
2 신경처리장치(Neural Processing Unit, NPU)와 관련된 다양한 기업들이 있지만, 이것들과 다르게 firework.ai는 FireAttention 그리고 최근 Nitro 모드에서는 더욱 더 좋은 성능들을 보여주고 있습니다.
- Firework Blog at https://fireworks.ai/blog
- Firework Demo at https://fireworks.ai/models
- Firework LLM Inference Benchmark at https://fireworks.ai/blog/TextGenerationLLM-inference-performance-benchmarking-part-1