00:00:00

Share Your Feedback 🏝️

Flex Attention

Flex Attention

MinWoo(Daniel) Park | Tech Blog

Read more
Previous: Google | Compute Optimal Next: LLM Format Impact

Flex Attention

  • Related Project: Private
  • Category: Paper Review
  • Date: 2024-08-08

Flex Attention

  • url: https://pytorch.org/blog/flexattention/
  • abstract: In theory, the transformative paper “Attention is All You Need” posits that simple attention mechanisms suffice for complex tasks. In practice, however, the demand for optimized implementations such as FlashAttention is paramount to handle extensive contexts efficiently. Despite the performance gains from these optimized implementations, they often restrict flexibility, locking users into predefined kernels and preventing easy experimentation with novel attention variants. For instance, integrating variants like Causal, Relative Positional Embeddings, or Alibi into existing systems necessitates extensive custom coding, which could deter innovation due to the high entry barrier and potential computational inefficiencies. To address these limitations, we introduce FlexAttention, a novel PyTorch API designed to revolutionize how attention mechanisms are implemented and extended. FlexAttention allows researchers to define a wide range of attention behaviors using idiomatic PyTorch code that integrates seamlessly into existing models without the need for custom kernels. This flexibility is achieved through a dynamic compilation approach that optimizes runtime and memory usage without sacrificing the versatility needed for cutting-edge research. By empowering developers to explore new attention mechanisms effortlessly, FlexAttention paves the way for more innovative and adaptable machine learning models. Our extensive benchmarks, available in the Attention Gym repository, demonstrate that FlexAttention not only supports a diverse array of attention mechanisms but also maintains competitive performance compared to traditional implementations. Visit the Attention Gym for examples and to contribute: Attention Gym GitHub Repository.

  1. PyTorch 기반의 유연한 API를 통해 다양한 어텐션 변형을 구현할 수 있습니다.
  2. 기존 구현 대비 메모리 사용량 감소와 실행 시간 단축을 달성했습니다.
  3. 다양한 어텐션 변형 및 마스킹을 지원하며, 사용자 정의 가능한 점수 수정 기능을 제공합니다.

1. 서론

FlexAttention은 기존 어텐션 메커니즘의 한계를 극복하고, 다양한 어텐션 변형을 손쉽게 구현할 수 있는 PyTorch API입니다. 본 논문에서는 기존에 최적화된 어텐션 구현이 가진 유연성의 결여를 해결하고자 하며, 이를 통해 새로운 어텐션 변형을 실험하는 데에 있어 필요한 개발 부담을 줄이고자 합니다.


2. 관련 연구 및 배경

어텐션 메커니즘은 주로 Query(Q), Key(K), Value(V) 세 요소의 상호작용을 통해 정보를 가중치에 따라 집중시키는 방식으로 작동합니다. 기존의 구현에서는 다음과 같은 수식을 통해 계산이 진행됩니다.

\[\text{score} = \frac{QK^T}{\sqrt{d_k}}\]

$d_k$는 헤드의 차원을 나타냅니다. 이 점수는 softmax 함수를 거쳐 확률로 변환되며, 이를 기반으로 최종 출력이 계산됩니다.

\[\text{output} = \text{softmax}(\text{score})V\]

이런 표준 어텐션 외에도 다양한 변형이 연구되었으나, 각각의 구현이 별도의 최적화가 필요하다는 점에서 한계를 가지고 있었습니다.


3. FlexAttention 메커니즘

FlexAttention은 사용자가 정의할 수 있는 score_mod 함수를 통해 어텐션 점수를 수정할 수 있게 하여, 다양한 어텐션 변형을 간단히 구현할 수 있도록 합니다. 기본적인 어텐션 점수 계산 후에 score_mod를 적용하면 다음과 같습니다. 코드에서 정의된 과정을 반영하여 수식을 수정하면 다음과 같습니다.

\[\begin{align*} \text{Q, K, V} &: \text{Tensor}[batch\_size, num\_heads, sequence\_length, head\_dim] \\ \text{score} &= \frac{QK^T}{\sqrt{\text{head_dim}}} \\ \text{modified_scores} &= \text{score_mod}(\text{score}) \\ \text{probabilities} &= \text{softmax}(\text{modified_scores}, \text{dim}=-1) \\ \text{output} &= \text{probabilities} \times V \end{align*}\]

이 수정된 점수는 softmax를 거쳐 최종 출력값을 계산하는 데 사용되며, PyTorch의 자동 미분 기능을 사용하여 역전파가 자동으로 생성되며, sparsity를 활용하여 성능을 개선할 수 있습니다.

FlexAttention API를 사용하여 다양한 어텐션 변형을 구현하고 성능을 평가한 결과, 기존 구현 대비 메모리 사용량이 감소하고 실행 시간이 단축되는 등의 개선이 확인되었습니다. 특히, 문서 마스킹, 슬라이딩 윈도우 어텐션, ALiBi 같은 다양한 변형을 손쉽게 구현할 수 있었습니다.

FlexAttention은 어텐션의 다양한 변형을 간편하게 실험할 수 있는 유연성을 제공하며, 기존의 구현보다 향상된 성능을 보인다고 언급하며 이런 구현을 통해 머신러닝 연구에서의 소프트웨어 복잡성을 감소시킬 수 있을 수 있습니다. 향후 연구에서는 이 API를 확장하여 더 다양한 어텐션 변형을 지원, 더욱 최적화 하는 방향을 탐색할 예정이라고 합니다.


서론

이론적으로는 “어텐션(Attention)만이 필요하다”고 알려져 있지만, 실제로는 FlashAttention 같은 최적화된 어텐션 구현이 필요합니다.

이런 통합된 어텐션 구현은 성능을 크게 향상시키고 긴 컨텍스트를 가능하게 했지만, 유연성의 손실이라는 대가가 따릅니다. 이제 몇 줄의 PyTorch 연산자로 새로운 어텐션 변형을 시도하는 것이 아니라 새로운 사용자 정의 커널을 작성해야 합니다! 이는 ML 연구자들에게 “소프트웨어 로또”와 같은 상황을 만들어냅니다 - 만약 당신의 어텐션 변형이 기존의 최적화된 커널 중 하나에 맞지 않는다면, 느린 런타임과 CUDA OOM(Out Of Memory)이 예정되어 있습니다.

여러 어텐션 변형에는 다음과 같은 것들이 있습니다.

  • 인과적 (Causal)
  • 상대 위치 임베딩 (Relative Positional Embeddings)
  • Alibi
  • 슬라이딩 윈도우 어텐션 (Sliding Window Attention)
  • PrefixLM
  • 문서 마스킹/샘플 패킹/불규칙 텐서 (Document Masking/Sample Packing/Jagged Tensors)
  • 탄젠트 소프트-캡핑 (Tanh Soft-Capping)
  • PagedAttention 등


FlexAttention 소개

이런 ‘하이퍼큐브’ 문제를 한 번에 해결하기 위해, 새로운 PyTorch API인 FlexAttention을 소개합니다.

  • 유연한 API: 많은 어텐션 변형을 몇 줄의 관용적인 PyTorch 코드로 구현할 수 있습니다.
  • 통합된 FlashAttention 커널: torch.compile을 통해 하나의 통합된 FlashAttention 커널로 낮추어, 추가 메모리를 발생시키지 않고 수작업으로 작성된 것과 경쟁력 있는 성능을 제공합니다.
  • 자동 역전파 생성: PyTorch의 autograd 기계를 이용해 자동으로 역전파를 생성합니다.
  • 어텐션 마스크의 희소성 활용: 어텐션 마스크의 희소성을 활용하여 표준 어텐션 구현보다 상당한 개선을 달성할 수 있습니다.

FlexAttention의 예제는 Attention Gym에서 확인할 수 있습니다.


어텐션 방정식과 코드 구현

기본 어텐션 방정식

기본적인 어텐션 방정식은 다음과 같습니다.

\[\text{score} = \frac{QK^T}{\sqrt{\text{head_dim}}}\]

$Q, K, V$는 각각 쿼리, 키, 값 텐서입니다.

Q, K, V = Tensor[batch_size, num_heads, sequence_length, head_dim]
score = (Q @ K) / sqrt(head_dim)
probabilities = softmax(score, dim=-1)
output = probabilities @ V

사용자 정의 함수 score_mod를 통한 점수 수정

FlexAttention은 사용자 정의 함수 score_mod를 허용하여 softmax 이전에 어텐션 점수를 수정할 수 있습니다. 이 기능은 대부분의 어텐션 변형에 충분합니다.

Q, K, V = Tensor[batch_size, num_heads, sequence_length, head_dim]
score = (Q @ K) / sqrt(head_dim)
modified_scores = score_mod(score)
probabilities = softmax(modified_scores, dim=-1)
output = probabilities @ V

FlexAttention의 표현력

이 API는 예상치 못한 방식으로 표현력이 뛰어나며, 많은 어텐션 변형을 간단하고 효율적으로 구현할 수 있습니다.


[참고자료 1] 하이퍼 큐브 문제

위에서 언급된 “하이퍼 큐브 문제”는 어텐션 메커니즘을 구현할 때 발생하는 복잡성을 비유적으로 설명한 것으로, 하이퍼 큐브는 수학에서 4차원 이상의 공간에서의 큐브를 의미하며, 이 비유를 통해 다양한 어텐션 변형들을 동시에 고려하고 최적화해야 하는 어려움을 설명합니다.

문제의 핵심

모델에서 다양한 어텐션 변형(e.g., 인과적 어텐션, 상대 위치 임베딩, 슬라이딩 윈도우 어텐션 등)을 지원하려고 할 때, 이를 효율적으로 구현하기가 어렵습니다. 각각의 변형은 자체적인 특성과 제약 조건을 가지고 있어서, 하나의 공통된 커널로 최적화하기가 힘들기 때문입니다. 이런 문제는 “소프트웨어 로또”처럼 느껴질 수 있는데, 특정 어텐션 변형이 기존의 최적화된 커널에 잘 맞지 않는다면 성능 저하나 메모리 부족(CUDA OOM) 같은 문제가 발생할 수 있기 때문입니다.

하이퍼 큐브 문제의 시각적 비유

하이퍼 큐브 문제를 시각적으로 표현한다면, 차원이 증가할수록 기하급수적으로 복잡해지는 구조를 생각할 수 있습니다. 각각의 어텐션 변형이 이 하이퍼 큐브의 한 축을 담당하며, 이 모든 축을 동시에 다루어야 하는 문제가 바로 하이퍼 큐브 문제입니다.

하이퍼 큐브를 이해하기 위해 3차원 큐브와 4차원 이상의 하이퍼 큐브를 비교하는 Figure을 생각할 수 있습니다. 이 Figure에서 각 변형은 큐브의 한 변을 나타내며, 4차원 이상의 공간에서는 그 복잡성이 더욱 커집니다.

alt text alt text alt text alt text

*출처: Visualizing Higher Dimensions

Previous: Google | Compute Optimal Next: LLM Format Impact

post contain ""

    No matching posts found containing ""