Contents
편의상 Frquency와 Wavelength는 프리퀀시와 파장으로 번역
프로젝트 깃 허브 레포지토리에서 최대 128k 컨텍스트 길이까지의 모델을 사용할 수 있음.
1 서론
Transformer 기반 대규모 언어모델(LLMs)은 많은 자연어 처리(NLP) 작업에서 거의 필수적인 선택이 되었습니다. 이런 작업들에서, pre-training된 LLM의 최대 시퀀스 길이(context window)는 주요 제한 중 하나입니다. 훈련 과정에서 결정된 context window을 동적으로 확장할 수 있는 능력은 점점 더 바람직해지고 있습니다. 이를 위해, 트랜스포머의 위치 인코딩이 중요한 논의의 중심이 되었습니다.
기존의 Transformer 아키텍처는 절대적인 사인 위치 인코딩을 사용했으며, 이는 학습 가능한 절대 위치 인코딩으로 개선되었습니다. 이후 상대적 위치 인코딩 스키마가 Transformer의 성능을 더욱 향상시켰습니다. 현재 가장 유명한 상대적 위치 인코딩은 T5 상대적 편향, RoPE, XPos, ALiBi 등입니다.
위치 인코딩의 반복되는 제한 사항은 훈련 중에 본 context window을 넘어서 일반화하는 능력의 부족입니다. 일부 방법들은 제한된 일반화를 할 수 있지만, ALiBi와 같은 방법은 pre-trained 길이보다 훨씬 긴 시퀀스로 일반화할 수 없습니다.
이런 제한을 극복하기 위해 몇 가지 작업이 수행되었습니다. [9]와 [21]은 위치 보간(PI)을 통해 RoPE를 약간 수정하고 소량의 데이터에 대해 파인튜닝을 함으로써 컨텍스트 길이를 확장할 것을 제안했습니다. 대안으로, [6]은 하이-프리퀀시의 손실을 고려한 “NTK-aware” 보간을 제안했습니다. 이후 “NTK-aware” 보간의 두 가지 개선이 제안되었으며, 각기 다른 강조점을 두고 있습니다.
“NTK-aware” 보간과 “Dynamic NTK” 보간은 이미 Code Llama [31] (NTK-aware 보간 사용) 및 Qwen 7B [2] (Dynamic NTK 사용)와 같은 오픈 소스 모델에서 그 존재감을 드러내고 있습니다.
이 논문에서는 이전에 발표되지 않은 “NTK-aware”, “Dynamic NTK”, 및 “NTK-by-parts” 보간에 대한 전체적인 설명을 제공하고, RoPE를 사용하여 훈련된 모델의 context window을 효율적으로 확장할 수 있는 개선된 방법, YaRN (Yet another RoPE extensioN method)을 제시합니다.
YaRN은 원래 pre-training 데이터의 약 0.1% 미만으로 파인튜닝한 후 context window 확장에서 최고의 성능을 달성합니다. 동시에, 인퍼런스 시간 기법인 Dynamic Scaling과 결합함으로써, Dynamic-YaRN은 어떠한 파인튜닝 없이도 2배 이상의 context window 확장을 가능하게 합니다.
2 배경 및 관련 작업
2.1 Rotary Position Embeddings
선행연구와의 비교
항목 | 선행 연구(RoPE 및 기타 확장) | YaRN 방법의 개선점 |
---|---|---|
문제점 | context window 확장 시 일반화 부족 | 효율적인 context window 확장 및 외삽 가능성 향상 |
기술적 접근 | 위치 보간(PI), “NTK-aware” 보간 사용 | “NTK-by-parts” 보간 및 동적 스케일링 사용 |
성능 | 제한된 컨텍스트에서의 일반화 | 파인튜닝된 데이터셋 제한을 뛰어넘어 훨씬 더 긴 컨텍스트에서 효과적인 외삽 |
효율성 | 하이-프리퀀시 정보의 손실 및 training dataset 요구량 | 훈련 단계 및 데이터 요구량 감소로 더 높은 효율성 달성 |
본 연구의 기초는 [34]에서 소개된 Rotary Position Embedding (RoPE)입니다. 숨겨진 뉴런의 집합이 \(D\)로 표시된 숨겨진 계층에서 작업합니다. 벡터 시퀀스 \(x_1, \cdots, x_L \in \mathbb{R}^{\\|D\\|}\)가 주어졌을 때, [34]의 표기법을 따라, attention 계층은 먼저 벡터를 쿼리 벡터와 키 벡터로 변환합니다.
\[q_m = f_q(x_m, m) \in \mathbb{R}^{\\|D\\|}, \quad k_n = f_k(x_n, n) \in \mathbb{R}^{\\|D\\|} \tag{1}\]다음으로, attention 가중치는 다음과 같이 계산됩니다.
\[\text{softmax}\left(\text{q_m^T k_n}{\sqrt{\\|D\\|}}\right) \tag{2}\]\(q_m, k_n\)은 열 벡터로 간주되어 \(q_m^T k_n\)은 단순히 유클리드 내적입니다. RoPE에서는 \(\\|D\\|\)가 짝수라고 가정하고, 임베딩 공간과 hidden state를 2-복소 벡터 공간으로 식별합니다. \(\mathbb{R}^{\\|D\\|} \sim \mathbb{C}^{\\|D\\|/2}\), 내적 \(q^T k\)는 표준 헤르미트 내적의 실수 부분 \(\Re(q^* k)\)가 됩니다. 좀 더 구체적으로, 동형사상은 실수 부분과 복소 부분을 교차시킵니다.
\[(x_m)_1, \cdots, (x_m)_{\\|D\\|} \to (x_m)_1 + i (x_m)_2, \cdots, (x_m)_{\\|D\\| - 1} + i (x_m)_{\\|D\\|} \tag{3}\] \[(q_m)_1, \cdots, (q_m)_{\\|D\\|} \to (q_m)_1 + i (q_m)_2, \cdots, (q_m)_{\\|D\\| - 1} + i (q_m)_{\\|D\\|} \tag{4}\]임베딩 \(x_m, x_n\)을 쿼리 및 키 벡터로 변환하기 위해, 먼저 \(R\)-선형 연산자 \(W_q, W_k: \mathbb{R}^{\\|D\\|} \to \mathbb{R}^{\\|D\\|}\)를 제공받습니다. 복소 좌표에서, 함수 \(f_q, f_k\)는 다음과 같이 주어집니다.
\[f_q(x_m, m) = e^{im\theta} W_q x_m, \quad f_k(x_n, n) = e^{in\theta} W_k x_n \tag{5}\]\(\theta = \text{diag}( ext_1, \cdots, \theta_{\\|D\\|/2})\)는 \(\theta_d = b^{-2d/\\|D\\|}\) 및 \(b = 10000\)을 갖는 대각 행렬입니다. 이 방식으로 RoPE는 각각의 (복소수-값) 숨겨진 뉴런을 별도의 프리퀀시 \(\theta_d\)와 연관짓습니다. 이렇게 함으로써, 쿼리 벡터와 키 벡터 사이의 내적은 단순히 상대적 거리 \(m - n\)에만 의존하게 됩니다.
\[\langle f_q(x_m, m), f_k(x_n, n) \rangle_{\mathbb{R}} \tag{6}\] \[= \Re \left( \langle f_q(x_m, m), f_k(x_n, n) \rangle_{\mathbb{C}} \right) \tag{7}\] \[= \Re \left( x_m^* W_q^* W_k x_n e^{i \theta (m - n)} \right) \tag{8}\] \[= g(x_m, x_n, m - n) \tag{9}\]실수 좌표에서, RoPE는 다음과 같은 함수를 사용하여 작성할 수 있습니다.
\[f_W(x_m, m, \theta_d) = \begin{pmatrix} \cos(m \theta_1) & -\sin(m \theta_1) & 0 & 0 & \cdots & 0 & 0 \\ \sin(m \theta_1) & \cos(m \theta_1) & 0 & 0 & \cdots & 0 & 0 \\ 0 & 0 & \cos(m \theta_2) & -\sin(m \theta_2) & \cdots & 0 & 0 \\ 0 & 0 & \sin(m \theta_2) & \cos(m \theta_2) & \cdots & 0 & 0 \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & 0 & 0 & \cdots & \cos(m \theta_l) & -\sin(m \theta_l) \\ 0 & 0 & 0 & 0 & \cdots & \sin(m \theta_l) & \cos(m \theta_l) \end{pmatrix} W x_m\]따라서 \(f_q = f_{W_q}\), \(f_k = f_{W_k}\)입니다.
2.2 위치 보간(Position Interpolation)
언어 모델은 일반적으로 고정된 컨텍스트 길이로 사전 훈련됩니다. 따라서 RoPE를 위치 인코딩으로 사용하는 언어 모델의 경우, 상대적으로 적은 양의 데이터에 파인튜닝을 함으로써 컨텍스트 길이를 확장하는 방법이 자연스럽습니다. Chen et al. [9] 및 kaiokendev [21]은 위치 보간(PI)을 제안하여 pre-trained 한계를 넘어 컨텍스트 길이를 확장했습니다. 직접적인 외삽은 \(w_1, \cdots, w_L\)과 같은 시퀀스에서 \(L\)보다 큰 길이에서 잘 수행되지 않지만, 위치 지표를 pre-trained 한계 내에서 보간하는 것이 소량의 파인튜닝을 통해 잘 작동함을 발견했습니다. 구체적으로, RoPE를 사용하는 pre-trained 언어 모델이 주어졌을 때, 그들은 RoPE를 다음과 같이 수정합니다.
\[f'_W(x_m, m, \theta_d) = f_W\left( x_m, m \frac{L}{L'}, \theta_d \right) \tag{10}\]\(L' > L\)는 pre-trained 한계를 넘어서는 새로운 context window입니다. 원래의 pre-trained 모델과 수정된 RoPE 공식을 사용하여, 그들은 몇 오더의 크기 더 적은 토큰(Chen et al. [9]의 몇 십억)에 대해 언어 모델을 추가로 파인튜닝하여 context window 확장을 성공적으로 달성했습니다.
2.3 추가 표기법(Additional Notation)
확장된 컨텍스트 길이와 원래의 컨텍스트 길이 사이의 비율은 특별한 중요성을 가지며, 다음과 같이 \(s\)를 정의하고, 스케일 인자로 정의합니다.
\[s = \frac{L'}{L} \tag{11}\]또한 Eq. 10을 다음과 같은 일반 형식으로 재작성하고 간소화합니다.
\[f'_W(x_m, m, \theta_d) = f_W(x_m, g(m), h( ext_d)) \tag{12}\]\(g(m), h( ext_d)\)는 방법에 따라 다른 함수인데, PI의 경우, \(g(m) = \frac{m}{s}\), \(h( ext_d) = \theta_d\)입니다.
후속 섹션에서 새로운 보간 방법을 소개할 때, 함수 \(g(m)\) 및 \(h( ext_d)\)만을 명시하여 설명하고, 다음과 같이 \(\lambda_d\)를 \(d\)번째 향상된에서의 RoPE 임베딩의 파장으로 정의합니다.
\[\lambda_d = \frac{2\pi}{\theta_d} = 2\pi b^{2d/|D|} \tag{13}\]파장은 RoPE 임베딩이 \(d\)번째 차원에서 전체 rotation(\(2\pi\))을 수행하는 데 필요한 토큰의 길이를 설명합니다.
일부 보간 방법(e.g., PI)은 차원의 파장을 고려하지 않지만, 다른 방법들은 고려합니다(e.g., YaRN). 이런 방법들을 “맹목적” 보간 방법과 “표적” 보간 방법으로 분류할 것입니다.
3 방법
PI는 모든 RoPE 차원을 동등하게 늘리지만, PI가 설명하는 이론적인 보간 경계는 RoPE와 LLM의 내부 임베딩 사이의 복잡한 state를 예측하는 데 불충분하다는 것을 발견했습니다.
3.1 Loss of High Frequency Information - “NTK-aware” Interpolation
RoPE를 정보 인코딩의 관점에서만 본다면, [36]에서 사용된 Neural Tangent Kernel (NTK) 이론을 통해, 입력 차원이 낮고 해당 임베딩이 하이-프리퀀시 구성요소를 결여할 경우 심층 신경망이 하이-프리퀀시 정보 학습에 어려움을 겪는다는 것이 입증되었습니다. 유사점을 볼 수 있습니다. 토큰의 위치 정보는 일차원이며, RoPE는 이를 \(n\)차원 복소 벡터 임베딩으로 확장합니다. RoPE는 여러 면에서 Fourier Features [36]와 유사하며, RoPE를 Fourier Feature의 특수한 1D 사례로 정의할 수 있습니다. RoPE 임베딩을 무차별적으로 늘리면 네트워크가 유사하고 가까운 토큰을 구별하는 데 필요한 중요한 하이-프리퀀시 세부 정보가 손실됩니다(가장 작은 거리를 설명하는 rotation은 네트워크가 감지할 수 있을 정도로 작아서는 안 됩니다).
PI [9]에서 더 큰 컨텍스트 크기에 대한 파인튜닝 후 짧은 컨텍스트 크기에 대한 복잡도가 약간 증가하는 것과 관련이 있을 수 있다고 가설을 세웠습니다. 이상적인 상황에서 더 큰 컨텍스트 크기에 대한 파인튜닝이 더 작은 컨텍스트 크기의 성능을 저하시킬 이유는 없습니다.
RoPE 임베딩을 보간할 때 하이-프리퀀시 정보 손실 문제를 해결하기 위해 [6]에서 “NTK-aware” 보간이 개발되었습니다. 모든 RoPE 차원을 \(s\) 요인으로 동등하게 확장하는 대신, 하이-프리퀀시를 덜 확장하고 로우프리퀀시를 더 확장함으로써 보간 압력을 여러 차원에 걸쳐 분산시킵니다. 이런 변환은 여러 방법으로 얻을 수 있지만, 가장 간단한 방법은 \(\theta\)의 값을 기준으로 기본 변경을 수행하는 것입니다. 구체적으로, 2.3절에서 설정한 표기법을 따라 “NTK-aware” 보간 체계를 다음과 같이 정의합니다(인퍼런스의 세부 사항은 Appendix A.1 참조).
정의 1: “NTK-aware” 보간은 Eq. 12를 사용하여 RoPE을 수정한 것입니다.
\[g(m) = m \tag{14}\] \[h( ext_d) = b'^{-2d/|D|} \tag{15}\] \[b' = b \cdot s \cdot \frac{|D|}{|D| - 2} \tag{16}\][6]의 결과에 따르면, 이 방법은 PI [9]에 비해 파인튜닝되지 않은 모델의 컨텍스트 크기를 확장하는 데 훨씬 더 나은 성능을 보입니다. 그러나 이 방법의 주요 단점 중 하나는 단순한 보간 체계가 아니라 일부 차원이 “범위를 벗어난” 값으로 약간 외삽되기 때문에, “NTK-aware” 보간 [6]으로 파인튜닝하면 PI [9]에 비해 열등한 결과를 낳습니다. 또한 “범위를 벗어난” 값으로 인해 이론적인 스케일 인자 \(s\)는 실제 컨텍스트 확장 스케일을 정확하게 설명하지 못합니다. 실제로, 주어진 컨텍스트 길이 확장을 위해 예상보다 더 높은 스케일 값 \(s\)를 설정해야 합니다. 이 글이 발표되기 직전에 Code Llama [31]가 출시되었으며, 기본 \(b\)를 1M으로 수동 조정하여 “NTK-aware” 스케일링을 사용합니다.
3.2 Loss of Relative Local Distances - “NTK-by-parts” Interpolation
PI와 “NTK-aware” 보간과 같은 맹목적인 보간 방법에서는 모든 RoPE 향상된을 동등하게 처리합니다(즉, 네트워크에 동일한 영향을 미칩니다). 그러나, 표적 보간 방법이 필요하다는 강력한 단서가 있습니다. 이 절에서는 RoPE의 공식에서 정의된 파장 \(\lambda_d\)를 중점적으로 생각합니다. 간단히하기 위해 \(\lambda_d\)의 첨자 \(d\)를 생략하고 독자들이 \(\lambda\)를 임의의 주기 함수의 파장으로 생각하도록 권장합니다.
RoPE 임베딩의 흥미로운 관찰 중 하나는 주어진 컨텍스트 크기 \(L\)에 대해 일부 차원 \(d\)에서 파장이 사전 훈련 중에 본 최대 컨텍스트 길이보다 길다는 것입니다(\(\lambda > L\)). 이는 일부 차원의 임베딩이 rotation 도메인에서 고르게 분포되지 않을 수 있음을 시사합니다. 이런 경우, 모든 고유 위치 쌍을 가지고 있다고 가정하면 절대적 위치 정보가 그대로 유지됩니다. 반면, 파장이 짧으면 네트워크가 상대적 위치 정보만 접근할 수 있습니다.
또한, 모든 RoPE 차원을 스케일 \(s\) 또는 기본 변경 \(b'\)로 늘릴 때, 모든 토큰이 서로 더 가까워지면서 두 벡터가 덜 rotation된 상태에서의 내적이 커집니다. 이런 스케일링은 LLM이 내부 임베딩 간의 작고 지역적인 관계를 이해하는 능력을 심각하게 저해합니다. 이런 압축이 모델이 근접한 토큰의 위치 순서를 혼동하게 하여 결국 모델의 능력을 해치게 된다고 가설을 세웠습니다.
이 문제를 해결하기 위해, 이전에 발견한 두 가지 관찰을 고려하여, 높은 프리퀀시 차원은 전혀 보간하지 않으면서 낮은 프리퀀시 차원은 항상 보간하기로 결정했습니다. 특히,
결과적으로, 원래 컨텍스트 크기 \(L\)와 파장 \(\lambda\) 사이의 비율 \(r = \frac{L}{\lambda}\)을 도입하는 것이 더 편리합니다. \(d\)번째 hidden state에서, 비율 \(r\)은 다음과 같이 \(d\)에 따라 달라집니다.
\[r(d) = \frac{L}{\lambda_d} = \frac{L}{2\pi b'^{2d/|D|}} \tag{17}\]위에서 설명한 다양한 보간 전략의 경계를 정의하기 위해, 추가적으로 두 개의 파라미터 \(\alpha\)와 \(\beta\)를 도입합니다. 모든 향상된 \(d\)에서 \(r(d) < \alpha\)인 경우에는 스케일 \(s\)로 선형 보간합니다(PI와 같이 외삽을 피하면서), 그리고 \(r(d) > \beta\)인 경우에는 전혀 보간하지 않습니다. 램프 함수 \(\gamma\)를 다음과 같이 정의합니다.
\[\gamma(r) = \begin{cases} 0 & \text{if } r < \alpha \\ 1 & \text{if } r > \beta \\ \frac{r - \alpha}{\beta - \alpha} & \text{otherwise} \end{cases} \tag{18}\]램프 함수의 도움으로 “NTK-by-parts” 방법은 다음과 같이 설명될 수 있습니다.
정의 2: “NTK-by-parts” 보간은 Eq. 12를 사용하여 RoPE을 수정한 것입니다.
\[g(m) = m \tag{19}\] \[h( ext_d) = \left(1 - \gamma(r(d))\right) \frac{\theta_d}{s} + \gamma(r(d)) \theta_d \tag{20}\]\(\alpha\)와 \(\beta\)의 값은 사례별로 조정해야 합니다. 예를 들어, Llama 모델군의 경우, \(\alpha = 1\), \(\beta = 32\)가 좋은 값으로 실험적으로 밝혀졌습니다. 이 섹션에서 설명한 기법을 사용하여, 결과적인 방법의 변형이 “NTK-by-parts” 보간 [7]이라는 이름으로 출시되었습니다. 이 개선된 방법은 파인튜닝되지 않은 모델과 파인튜닝된 모델 모두에서 이전의 PI [9] 및 “NTK-aware” 보간 방법보다 더 나은 성능을 확인합니다.
3.3 Dynamic Scaling - “Dynamic NTK” Interpolation
다양한 사용 사례에서, 1부터 최대 컨텍스트 크기까지 다양한 시퀀스 길이를 가진 여러 순방향 패스가 수행됩니다. 자동 회귀 생성과 같은 전형적인 예에서 시퀀스 길이는 각 단계마다 1씩 증가합니다. 스케일 인자 \(s\)를 사용하는 보간 방법에는 두 가지 접근 방식이 있습니다(PI, “NTK-aware” 및 “NTK-by-parts” 포함):
첫 번째 방법의 문제점은 모델이 \(L\)보다 작은 길이에서 성능 저하를 경험하고, 시퀀스 길이가 \(L'\)를 초과할 때 갑작스럽게 성능이 저하될 수 있다는 것입니다. 그러나 동적 스케일링을 통해 두 번째 방법을 적용하면, 훈련된 컨텍스트 한계 \(L'\)에 도달했을 때 모델이 즉시 실패하는 대신 점진적으로 성능이 저하되도록 할 수 있습니다. 이 인퍼런스 시간 방법을 동적 스케일링 방법이라고 부르며, “NTK-aware” 보간과 결합될 때는 “Dynamic NTK” 보간이라고 합니다. 이 방법은 처음으로 [14]의 reddit 게시물에서 공개되었습니다.
특히 주목할 만한 사실은 “Dynamic NTK” 보간이 파인튜닝 없이 \(L\)에 pre-trained 모델에서 향상된 성능을 보인다는 것입니다(\(L' = L\)). 이는 Appendix B.3의 실험에서 뒷받침됩니다.
반복되는 순방향 패스에서는 kv-caching [8]이 적용되어 이전 키-값 벡터를 재사용하여 전체 효율성을 개선합니다. RoPE 임베딩이 캐시될 때는 동적 스케일링을 위해 RoPE를 수정하기 전에 kv-임베딩을 캐시해야 합니다. 이는 \(s\)가 변경될 때마다 모든 토큰의 RoPE 임베딩이 변경되기 때문입니다.
3.4 YaRN
이전의 보간 기술에 더해, 주의 깊게 관찰한 결과 로짓에 온도 \(t\)를 도입하는 것이 확장된 context window에서 데이터 샘플과 토큰 위치에 관계없이 복잡도에 균일한 영향을 미친다는 것을 발견했습니다. (Appendix A.2 참조)
보다 정확하게는, Eq. 2 대신에 attention weights 계산을 수정하여 다음과 같이 표현합니다.
\[A_{mn} = \text{softmax}\left(\text{q_m^T k_n}{\sqrt{|D| t}}\right) \tag{21}\]RoPE를 2D 행렬 집합으로 재구성하는 것은 이런 attention scaling의 구현에 분명한 이점을 제공합니다. 단순히 복소 RoPE 임베딩을 동일한 비율로 스케일링함으로써 \(q_m\)과 \(k_n\)을 상수 인자 \(\sqrt{1/t}\)로 스케일링하는 “길이 스케일링” 기법을 사용할 수 있습니다. 이를 통해 YaRN은 코드를 수정하지 않고도 어텐션 메커니즘을 효과적으로 변형할 수 있습니다. 또한, RoPE 임베딩이 사전에 생성되고 모든 순방향 패스에 재사용되므로 인퍼런스 및 훈련 모두에서 추가적인 오버헤드가 없습니다. 이를 “NTK-by-parts” 보간과 결합함으로써, YaRN 방법을 완성합니다.
정의 3: “YaRN 방법”은 Eq. 21에서 소개된 attention scaling과 3.2절에서 소개된 “NTK-by-parts” 보간의 결합을 의미합니다.
LLaMA 및 Llama 2 모델에 대해서는 다음과 같은 값을 권장합니다.
\[t = \sqrt{\frac{1}{s}} \tag{22}\]위의 방정식은 “NTK-by-parts” 방법(3.2절)을 사용하여 LLaMA 7b, 13b, 33b 및 65b 모델에서 다양한 요인 \(s\)에 의해 확장된 스케일에서 가장 낮은 복잡도에 맞춰 \(\sqrt{1/t}\)를 계산함으로써 도출되었습니다. 이와 같은 \(t\) 값은 Llama 2 모델(7b, 13b 및 70b)에도 상당히 잘 적용됩니다. 이는 증가된 엔트로피와 온도 상수 \(t\)가 일정 정도의 “보편성”을 가지며 일부 모델과 training dataset 간에 일반화될 수 있음을 시사합니다.
YaRN 방법은 모든 발견을 결합하고 파인튜닝 및 파인튜닝되지 않은 시나리오 모두에서 이전의 모든 방법들을 능가할 수 있음을 확인합니다. YaRN은 Flash Attention 2 [13]과 같은 attention mechanism을 수정하는 라이브러리와 직접 호환됩니다.
4 실험
YaRN 모델은 RoPE 위치 임베딩을 활용하여 언어 모델의 context window을 확장하는 것을 성공적으로 보여주며, 이는 단지 모델의 원래 전처리 코퍼스의 0.1%에 해당하는 400 훈련 단계로 달성되었습니다. 이는 Rozière 등[31]의 연구보다 10배, Chen 등[9]의 연구보다 2.5배 줄인 결과입니다. 결과 모델을 평가하기 위해 장문의 문서에 대한 퍼플렉시티를 계산하고 설정된 벤치마크에서 점수를 매기며, 모든 기타 context window 확장 방법들을 초과하는 성능을 발견하였습니다.
4.1 훈련
LLaMA 2 모델의 7B 및 13B 파라미터 버전을 사용하여 훈련을 진행하였습니다. LLaMA 모델 아키텍처에는 다른 변경 사항 없이, 3.4에서 설명한 대로 임베딩 프리퀀시의 계산만 추가되었습니다. 학습률은 \(2 \times 10^{-5}\), 가중치 감소는 없이, 20단계의 선형 웜업과 함께 AdamW[24]의 \(\beta_1 = 0.9\) 및 \(\beta_2 = 0.95\)를 사용하였습니다. \(s = 16\)에 대해 글로벌 배치 크기 64로 400 단계를 fine-tuned하였으며, PyTorch[26]의 Fully Sharded Data Parallelism[42] 및 Flash Attention 2[13]를 사용하여 PG19 데이터셋[29]의 64k 세그먼트로 청크된 데이터를 사용하였습니다. \(s = 32\)의 경우에는 같은 절차를 따랐으나, 완성된 \(s = 16\) 체크포인트에서 시작하여 추가적으로 200 단계를 훈련하였습니다.
4.2 외삽 및 전이 학습
Code Llama[31]는 16k 컨텍스트의 데이터셋을 사용하고, 스케일 인자를 \(s \approx 88.6\)로 설정하여 실질적으로 355k의 context window을 갖게 되며, 네트워크가 훈련 중 해당 컨텍스트 크기를 본 적이 없음에도 불구하고 최대 100k 컨텍스트까지 외삽할 수 있다는 것을 보입니다.
YaRN 또한 3.1 및 Rozière 등[31]과 유사하게 데이터셋 길이보다 더 높은 스케일 인자 \(s\)로 훈련을 수행하였습니다.
컴퓨트 제약으로 인해 \(s = 32\)만을 추가 fine-tuning하여 200 단계 동안 64k 컨텍스트를 사용하여 테스트하였습니다.
\(s = 32\) 모델은 훈련 중 64k 컨텍스트만을 사용하여 최대 128k 컨텍스트까지 성공적으로 외삽하여, “blind” 내삽 방법들보다 훨씬 더 효율적인 전이 학습을 가능하게하고, \(s = 32\) 모델은 \(s = 16\) 모델과 전체 컨텍스트 크기에서 동일함에도 불구하고 단 200 단계의 훈련으로 완성됩니다. 이는 네트워크가 내삽된 임베딩을 다시 학습할 필요가 없음을 의미할 수 있습니다.
[참고자료 1] 선행 연구와 YaRN 방법 비교
내삽과 외삽은 모델이 보지 못한 새로운 데이터나 컨텍스트 길이에 대응하는 능력을 개선하는 데 기여하는 걸로 알려져있습니다.
내삽(Interpolation)은 모델이 training dataset 범위 내에서 미묘한 패턴이나 특징을 더 잘 이해하도록 돕는 반면, 외삽(extrapolation)은 모델이 training dataset 범위를 넘어서는 새로운 상황에 대응하게 합니다.
선행 연구의 접근 방법
선행 연구들은 주로 Rotary Position Embeddings(RoPE)를 사용하고, 위치 인코딩을 보간(PI) 및 “NTK-aware” 방법으로 조정하여 언어 모델의 context window을 확장하는 방법을 개발하였습니다.
YaRN 방법의 접근 방법
YaRN은 이런 기존 방법을 통합하고, 보다 정교하게 위치 인코딩을 조정합니다. 다음은 YaRN 방법에서 사용되는 수학적 변환들입니다.
이는 프리퀀시별로 다음과 같이 표현됩니다.
\[\theta_d' = \theta_d \cdot \begin{cases} s, & \text{if } \lambda_d \text{ is high} \\ s^2, & \text{if } \lambda_d \text{ is low} \end{cases}\]\(\lambda_d\)는 d번째 차원의 파장, \(s\)는 스케일링 인자입니다.
스케일링 인자 \(s\)는 입력 시퀀스 길이에 따라 변경됩니다.
\[s = \frac{l'}{L}\]\(l'\)는 현재 입력 길이, \(L\)은 훈련 중 사용된 최대 길이입니다.
성능 향상의 근거
선행연구와의 비교
기법 | 프리퀀시 처리 | 데이터 범위 | 성능 향상의 근거 |
---|---|---|---|
PI | 일관된 스케일링 | 내삽 | 컨텍스트 길이 확장 가능 |
NTK-aware 보간 | 프리퀀시별 스케일링 | 내삽 및 외삽 | 고주파 정보 보존 |
YaRN (“NTK-by-parts”) | 프리퀀시별 차별 조정 | 내삽 및 외삽 | 고주파 및 로우프리퀀시 정보의 최적화 및 동적 스케일링 |
YaRN 방법의 수학적 설명 및 연계 이론
YaRN은 기존의 RoPE 방식을 개선하여, 효율적으로 context window을 확장하는 새로운 방법들을 제시하였는데, 그 중 “NTK-by-parts” 보간을 위주로 살펴보면 다음과 같습니다.
\(\theta\)는 위치 정보의 주기성을 나타내며, 변형된 임베딩은 실수 및 복소수의 상호 작용을 통해 내적을 계산할 때 위치 차이 \(m - n\)에 대한 고유 정보를 제공합니다.
\(g(m) = m\)는 기본 위치 함수이고, \(h( ext_d)\)는 스케일 조정 함수입니다.
YaRN은 이런 수학적 접근을 통해 기존 방법에 비해 더 넓은 범위의 컨텍스트에서 모델을 일반화할 수 있으며, 계산 효율성도 향상시켰다고 보고합니다.
[참고자료 2] 외삽(extrapolation) 및 프리퀀시(frequency) 컨셉의 정의와 응용
1. 문제 정의 및 기존 기술의 한계
2. 프리퀀시 개념의 도입
3. 외삽 기법의 필요성 및 개발
4. 외삽과 프리퀀시 조정의 구체적 방법
5. 논문의 주장과 그 근거
Transformer-based Large Language Models[40] (LLMs) have become the near-ubiquitous choice for many natural language processing (NLP) tasks where long-range abilities such as in-context learning (ICL) has been crucial. In performing the NLP tasks, the maximal length of the sequences (the context window) determined by its training processes has been one of the major limits of a pretrained LLM. Being able to dynamically extend the context window via a small amount of fine-tuning (or without fine-tuning) has become more and more desirable. To this end, the position encodings of transformers are the center of the discussions.
The original Transformer architecture used an absolute sinusoidal position encoding, which was later improved to a learnable absolute position encoding [15]. Since then, relative positional encoding schemes [32] have further increased the performance of Transformers. Currently, the most popular relative positional encodings are T5 Relative Bias [30], RoPE [34], XPos [35], and ALiBi [27].
One reoccurring limitation with positional encodings is the inability to generalize past the context window seen during training. While some methods such as ALiBi are able to do limited generalization, none are able to generalize to sequences significantly longer than their pre-trained length [22].
Some works have been done to overcome such limitation. [9] and concurrently [21] proposed to extend the context length by slightly modifying RoPE via Position Interpolation (PI) and fine-tuning on a small amount of data. As an alternative, [6] proposed the “NTK-aware” interpolation by taking the loss of high frequency into account. Since then, two improvements of the “NTK-aware” interpolation have been proposed, with different emphasis:
The “NTK-aware” interpolation and the “Dynamic NTK” interpolation have already seen their presence in the open-source models such as Code Llama [31] (using “NTK-aware” interpolation) and Qwen 7B [2] (using “Dynamic NTK”).
In this paper, in addition to making a complete account of the previous unpublished works on the “NTK-aware”, the “Dynamic NTK” and the “NTK-by-part” interpolations, we present YaRN (Yet another RoPE extensioN method), an improved method to efficiently extend the context window of models trained with Rotary Position Embeddings (RoPE) including the LLaMA [38], the GPT-NeoX [5], and the PaLM [10] families of models.
YaRN reaches state-of-the-art performances in context window extensions after fine-tuning on less than ∼0.1% of the original pre-training data. In the meantime, by combining with the inference-time technique called Dynamic Scaling, the Dynamic-YaRN allows for more than 2x context window extension without any fine-tuning.
The basis of our work is the Rotary Position Embedding (RoPE) introduced in [34]. We work on a hidden layer where the set of hidden neurons are denoted by \(D\). Given a sequence of vectors \(x_1, \cdots, x_L \in \mathbb{R}^{\\|D\\|}\), following the notation of [34], the attention layer first converts the vectors into the query vectors and the key vectors:
\[q_m = f_q(x_m, m) \in \mathbb{R}^{\\|D\\|}, \quad k_n = f_k(x_n, n) \in \mathbb{R}^{\\|D\\|} \tag{1}\]Next, the attention weights are calculated as
\[\text{softmax}\left(\text{q_m^T k_n}{\sqrt{\\|D\\|}}\right) \tag{2}\]where \(q_m, k_n\) are considered as column vectors so that \(q_m^T k_n\) is simply the Euclidean inner product. In RoPE, we first assume that \(\\|D\\|\) is even and identify the embedding space and the hidden states as 2-complex vector spaces: \(\mathbb{R}^{\\|D\\|} \sim \mathbb{C}^{\\|D\\|/2}\) where the inner product \(q^T k\) becomes the real part of the standard Hermitian inner product \(\Re(q^* k)\). More specifically, the isomorphisms interleave the real part and the complex part:
\[(x_m)_1, \cdots, (x_m)_{\\|D\\|} \to (x_m)_1 + i (x_m)_2, \cdots, (x_m)_{\\|D\\| - 1} + i (x_m)_{\\|D\\|} \tag{3}\] \[(q_m)_1, \cdots, (q_m)_{\\|D\\|} \to (q_m)_1 + i (q_m)_2, \cdots, (q_m)_{\\|D\\| - 1} + i (q_m)_{\\|D\\|} \tag{4}\]To convert embeddings \(x_m, x_n\) into query and key vectors, we are first given \(R\)-linear operators \(W_q, W_k: \mathbb{R}^{\\|D\\|} \to \mathbb{R}^{\\|D\\|}\). In complex coordinates, the functions \(f_q, f_k\) are given by
\[f_q(x_m, m) = e^{im\theta} W_q x_m, \quad f_k(x_n, n) = e^{in\theta} W_k x_n \tag{5}\]where \(\theta = \text{diag}( ext_1, \cdots, \theta_{\\|D\\|/2})\) is the diagonal matrix with \(\theta_d = b^{-2d/\\|D\\|}\) and \(b = 10000\). This way, RoPE associates each (complex-valued) hidden neuron with a separate frequency \(\theta_d\). The benefit of doing so is that the dot product between the query vector and the key vector only depends on the relative distance \(m - n\) as follows:
\[\langle f_q(x_m, m), f_k(x_n, n) \rangle_{\mathbb{R}} \tag{6}\] \[= \Re \left( \langle f_q(x_m, m), f_k(x_n, n) \rangle_{\mathbb{C}} \tag{7}\right)\] \[= \Re \left( x_m^* W_q^* W_k x_n e^{i \theta (m - n)} \right) \tag{8}\] \[= g(x_m, x_n, m - n) \tag{9}\]In real coordinates, the RoPE can be written using the following function
\[f_W(x_m, m, \theta_d) = \begin{pmatrix} \cos(m \theta_1) & -\sin(m \theta_1) & 0 & 0 & \cdots & 0 & 0 \\ \sin(m \theta_1) & \cos(m \theta_1) & 0 & 0 & \cdots & 0 & 0 \\ 0 & 0 & \cos(m \theta_2) & -\sin(m \theta_2) & \cdots & 0 & 0 \\ 0 & 0 & \sin(m \theta_2) & \cos(m \theta_2) & \cdots & 0 & 0 \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & 0 & 0 & \cdots & \cos(m \theta_l) & -\sin(m \theta_l) \\ 0 & 0 & 0 & 0 & \cdots & \sin(m \theta_l) & \cos(m \theta_l) \end{pmatrix} W x_m\]so that \(f_q = f_{W_q}\), \(f_k = f_{W_k}\).
As language models are usually pre-trained with a fixed context length, it is natural to ask how to extend the context length by fine-tuning on a relatively small amount of data. For language models using RoPE as the position embedding, Chen et al. [9], and concurrently kaiokendev [21] proposed the Position Interpolation (PI) to extend the context length beyond the pre-trained limit. While a direct extrapolation does not perform well on sequences \(w_1, \cdots, w_L\) with \(L\) larger than the pre-trained limit, they discovered that interpolating the position indices within the pre-trained limit works well with the help of a small amount of fine-tuning. Specifically, given a pre-trained language model with RoPE, they modify the RoPE by
\[f'_W(x_m, m, \theta_d) = f_W\left( x_m, m \frac{L}{L'}, \theta_d \right) \tag{10}\]where \(L' > L\) is a new context window beyond the pre-trained limit. With the original pre-trained model plus the modified RoPE formula, they fine-tuned the language model further on several orders of magnitude fewer tokens (a few billion in Chen et al. [9]) and successfully achieved context window extension.
The ratio between the extended context length and the original context length has been of special importance, and we introduce the notation \(s\) defined by
\[s = \frac{L'}{L} \tag{11}\]and we call \(s\) the scale factor.
We also rewrite and simplify Eq. 10 into the following general form:
\[f'_W(x_m, m, \theta_d) = f_W(x_m, g(m), h( ext_d)) \tag{12}\]where \(g(m), h( ext_d)\) are method-dependent functions. For PI, we have \(g(m) = \frac{m}{s}\), \(h( ext_d) = \theta_d\). In the subsequent sections, when we introduce a new interpolation method, we sometimes only specify the functions \(g(m)\) and \(h( ext_d)\).
Additionally, we define \(\lambda_d\) as the wavelength of the RoPE embedding at \(d\)-th hidden dimension:
\[\lambda_d = \frac{2\pi}{\theta_d} = 2\pi b^{2d/|D|} \tag{13}\]The wavelength describes the length of tokens needed in order for the RoPE embedding at dimension \(d\) to perform a full rotation (\(2\pi\)).
Given that some interpolation methods (e.g., PI) do not care about the wavelength of the dimensions, we will refer to those methods as “blind” interpolation methods, while others do (e.g., YaRN), which we will classify as “targeted” interpolation methods.
ReRoPE [33] also aims to extend the context size of existing models pre-trained with RoPE, and claims “infinite” context length without needing any fine-tuning. This claim is backed by a monotonically decreasing loss with increasing context length up to 16k on the Llama 2 13B model. It achieves context extension by modifying the attention mechanism and thus is not purely an embedding interpolation method. Since it is currently not compatible with Flash Attention 2 [13] and requires two attention passes during inference, we do not consider it for comparison.
Concurrently with our work, LM-Infinite [16] proposes similar ideas to YaRN, but focuses on “on-the-fly” length generalization for non-fine-tuned models. Since they also modify the attention mechanism of the models, it is not an embedding interpolation method and is not immediately compatible with Flash Attention 2.
Whereas PI stretches all RoPE dimensions equally, we find that the theoretical interpolation bound described by PI [9] is insufficient at predicting the complex dynamics between RoPE and the LLM’s internal embeddings. In the following subsections, we describe the main issues with PI we have individually identified and solved, so as to give the readers the context, origin, and justifications of each method which we use in concert to obtain the full YaRN method.
If we look at RoPE only from an information encoding perspective, it was shown in [36], using Neural Tangent Kernel (NTK) theory, that deep neural networks have trouble learning high-frequency information if the input dimension is low and the corresponding embeddings lack high-frequency components. Here we can see the similarities: a token’s positional information is one-dimensional, and RoPE expands it to an \(n\)-dimensional complex vector embedding. RoPE closely resembles Fourier Features [36] in many aspects, as it is possible to define RoPE as a special 1D case of a Fourier Feature. Stretching the RoPE embeddings indiscriminately results in the loss of important high-frequency details which the network needs in order to resolve tokens that are both very similar and very close together (the rotation describing the smallest distance needs to not be too small for the network to be able to detect it).
We hypothesize that the slight increase of perplexity for short context sizes after fine-tuning on larger context sizes seen in PI [9] might be related to this problem. Under ideal circumstances, there is no reason that fine-tuning on larger context sizes should degrade the performance of smaller context sizes.
In order to resolve the problem of losing high-frequency information when interpolating the RoPE embeddings, the “NTK-aware” interpolation was developed in [6]. Instead of scaling every dimension of RoPE equally by a factor \(s\), we spread out the interpolation pressure across multiple dimensions by scaling high frequencies less and low frequencies more. One can obtain such a transformation in many ways, but the simplest would be to perform a base change on the value of \(\theta\). More precisely, following the notations set out in Section 2.3, we define the “NTK-aware” interpolation scheme as follows (see Appendix A.1 for the details of the deduction).
Definition 1: The “NTK-aware” interpolation is a modification of RoPE by using Eq. 12 with the following functions:
\[g(m) = m \tag{14}\] \[h( ext_d) = b'^{-2d/|D|} \tag{15}\] \[b' = b \cdot s \cdot \frac{|D|}{|D| - 2} \tag{16}\]Given the results from [6], this method performs much better at extending the context size of non-fine-tuned models compared to PI [9]. However, one major disadvantage of this method is that given it is not just an interpolation scheme, some dimensions are slightly extrapolated to “out-of-bound” values, thus fine-tuning with “NTK-aware” interpolation [6] yields inferior results to PI [9]. Furthermore, due to the “out-of-bound” values, the theoretical scale factor \(s\) does not accurately describe the true context extension scale. In practice, the scale value \(s\) has to be set higher than the expected scale for a given context length extension. We note that shortly before the release of this article, Code Llama [31] was released and uses “NTK-aware” scaling by manually scaling the base \(b\) to 1M.
In the case of blind interpolation methods like PI and “NTK-aware” interpolation, we treat all the RoPE hidden dimensions equally (as in they have the same effect on the network). However, there are strong clues that point us towards the need for targeted interpolation methods. In this section, we think heavily in terms of the wavelengths \(\lambda_d\) defined in Eq. 13 in the formula of RoPE. For simplicity, we omit the subscript \(d\) in \(\lambda_d\) and the reader is encouraged to think about \(\lambda\) as the wavelength of an arbitrary periodic function.
One interesting observation of RoPE embeddings is that given a context size \(L\), there are some dimensions \(d\) where the wavelength is longer than the maximum context length seen during pretraining \(\lambda > L)\). This suggests that some dimensions’ embeddings might not be distributed evenly in the rotational domain. In such cases, we presume having all unique position pairs implies that the absolute positional information remains intact. On the contrary, when the wavelength is short, only relative positional information is accessible to the network.
Moreover, when we stretch all the RoPE dimensions either by a scale \(s\) or using a base change \(b'\), all tokens become closer to each other, as the dot product of two vectors rotated by a lesser amount is bigger. This scaling severely impairs a LLM’s ability to understand small and local relationships between its internal embeddings. We hypothesize that such compression leads to the model being confused on the positional order of close-by tokens, and consequently harming the model’s abilities.
In order to remedy this issue, given the two previous observations that we have found, we choose not to interpolate the higher frequency dimensions at all while always interpolating the lower frequency dimensions. In particular,
As a result, it is more convenient to introduce the ratio \(r = \frac{L}{\lambda}\) between the original context size \(L\) and the wavelength \(\lambda\). In the \(d\)-th hidden state, the ratio \(r\) depends on \(d\) in the following way:
\[r(d) = \frac{L}{\lambda_d} = \frac{L}{2\pi b'^{2d/|D|}} \tag{17}\]In order to define the boundary of the different interpolation strategies as above, we introduce two extra parameters \(\alpha\) and \(\beta\). All hidden dimensions \(d\) where \(r(d) < \alpha\) are those where we linearly interpolate by a scale \(s\) (exactly like PI, avoiding any extrapolation), and the \(d\) where \(r(d) > \beta\) are those where we do not interpolate at all. Define the ramp function \(\gamma\) to be
\[\gamma(r) = \begin{cases} 0 & \text{if } r < \alpha \\ 1 & \text{if } r > \beta \\ \frac{r - \alpha}{\beta - \alpha} & \text{otherwise} \end{cases} \tag{18}\]With the help of the ramp function, the “NTK-by-parts” method can be described as follows.
Definition 2: The “NTK-by-parts” interpolation is a modification of RoPE by using Eq. 12 with the following functions:
\[g(m) = m \tag{19}\] \[h( ext_d) = \left(1 - \gamma(r(d))\right) \frac{\theta_d}{s} + \gamma(r(d)) \theta_d \tag{20}\]The values of \(\alpha\) and \(\beta\) should be tuned on a case-by-case basis. For example, we have found experimentally that for the Llama family of models, good values for \(\alpha\) and \(\beta\) are \(\alpha = 1\) and \(\beta = 32\). Using the techniques described in this section, a variant of the resulting method was released under the name “NTK-by-parts” interpolation [7]. This improved method performs better than the previous PI [9] and “NTK-aware” interpolation methods, both with non-fine-tuned models and with fine-tuned models, as shown in [7].
In a lot of use cases, multiple forward-passes are performed with varying sequence lengths from 1 to the maximal context size. A typical example is the autoregressive generation where the sequence lengths increment by 1 after each step. There are two ways of applying an interpolation method that uses a scale factor \(s\) (including PI, “NTK-aware” and “NTK-by-parts”):
The problem of (1) is that the model may experience a performance discount at a length less than \(L\) and an abrupt degradation when the sequence length is longer than \(L'\). But by doing Dynamic Scaling as (2), it allows the model to gracefully degrade instead of immediately breaking when hitting the trained context limit \(L'\). We call this inference-time method the Dynamic Scaling method. When it is combined with “NTK-aware” interpolation, we call it “Dynamic NTK” interpolation. It first appeared in public as a reddit post in [14].
One notable fact is that the “Dynamic NTK” interpolation works exceptionally well on models pre-trained on \(L\) without any fine-tuning \((L' = L)\). This is supported by the experiment in Appendix B.3.
Often in the repeated forward-passes, the kv-caching [8] is applied so that we can reuse the previous key-value vectors and improve the overall efficiency. We point out that in some implementations when the RoPE embeddings are cached, some care has to be taken in order to modify it for Dynamic Scaling with kv-caching. The correct implementation should cache the kv-embeddings before applying RoPE, as the RoPE embedding of every token changes when \(s\) changes.
In addition to the previous interpolation techniques, we also observe that introducing a temperature \(t\) on the logits before the attention softmax has a uniform impact on perplexity regardless of the data sample and the token position over the extended context window (See Appendix A.2). More precisely, instead of Eq. 2, we modify the computation of attention weights into
\[A_{mn} = \text{softmax}\left(\text{q_m^T k_n}{\sqrt{|D| t}}\right) \tag{21}\]The reparametrization of RoPE as a set of 2D matrices has a clear benefit on the implementation of this attention scaling: we can instead use a “length scaling” trick which scales both \(q_m\) and \(k_n\) by a constant factor \(\sqrt{1/t}\) by simply scaling the complex RoPE embeddings by the same amount. With this, YaRN can effectively alter the attention mechanism without modifying its code. Furthermore, it has zero overhead during both inference and training, as RoPE embeddings are generated in advance and are reused for all forward passes. Combining it with the “NTK-by-parts” interpolation, we have the YaRN method.
Definition 3: By the “YaRN method”, we refer to a combination of the attention scaling in Eq. 21 and the “NTK-by-parts” interpolation introduced in Section 3.2.
For LLaMA and Llama 2 models, we recommend the following values:
\[t = \sqrt{\frac{1}{s}} \tag{22}\]The equation above is found by fitting \(\sqrt{1/t}\) at the lowest perplexity against the scale extension by various factors \(s\) using the “NTK-by-parts” method (Section 3.2) on LLaMA 7b, 13b, 33b and 65b models without fine-tuning. We note that the same values of \(t\) also apply fairly well to Llama 2 models (7b, 13b and 70b). It suggests that the property of increased entropy and the temperature constant \(t\) may have a certain degree of “universality” and may be generalizable across some models and training data.
The YaRN method combines all our findings and surpasses all previous methods in both fine-tuned and non-fine-tuned scenarios. Thanks to its low footprint, YaRN allows for direct compatibility with libraries that modify the attention mechanism such as Flash Attention 2 [13].
We show that YaRN successfully achieves context window extension of language models using RoPE as its position embedding. Moreover, this result is achieved with only 400 training steps, representing approximately 0.1% of the model’s original pre-training corpus, a 10x reduction from Rozière et al. [31] and 2.5x reduction in training steps from Chen et al. [9], making it highly compute-efficient for training with no additional inference costs. We calculate the perplexity of long documents and score on established benchmarks to evaluate the resulting models, finding that they surpass all other context window extension methods.
We broadly followed the training and evaluation procedures as outlined in [9].
For training, we extended the Llama 2 [39] 7B and 13B parameter models. No changes were made to the LLaMA model architecture other than the calculation of the embedding frequencies as described in 3.4 with \(s = 16\) and \(s = 32\). We used a learning rate of \(2 \times 10^{-5}\) with no weight decay and a linear warmup of 20 steps along with AdamW [24] \(\beta_1 = 0.9\) and \(\beta_2 = 0.95\). For \(s = 16\) we fine-tuned for 400 steps with global batch size 64 using PyTorch [26] Fully Sharded Data Parallelism [42] and Flash Attention 2 [13] on the PG19 dataset [29] chunked into 64k segments bookended with the BOS and EOS token. For \(s = 32\) we followed the same procedure, but started from the finished \(s = 16\) checkpoint and trained for an additional 200 steps.
In Code Llama [31], a dataset with 16k context was used with a scale factor set to \(s \approx 88.6\), which corresponds to a context size of 355k. They show that the network extrapolates up to 100k context without ever seeing those context sizes during training. Similar to 3.1 and Rozière et al. [31], YaRN also supports training with a higher scale factor \(s\) than the length of the dataset. Due to compute constraints, we test only \(s = 32\) by further fine-tuning the \(s = 16\) model for 200 steps using the same dataset with 64k context.
We show in 4.3.1 that the \(s = 32\) model successfully extrapolates up to 128k context using only 64k context during training. Unlike previous “blind” interpolation methods, YaRN is much more efficient at transfer learning when increasing the scale \(s\). This demonstrates successful transfer learning from \(s = 16\) to \(s = 32\) without the network needing to relearn the interpolated embeddings, as the \(s = 32\) model is equivalent to the \(s = 16\) model across the entire context size, despite only being trained on \(s = 32\) for 200 steps.
The evaluations focus on three aspects:
To evaluate the long sequence language modeling performances, we use the GovReport [18] and Proof-pile [4] datasets, both of which contain many long sequence samples. For all evaluations, the test splits of both datasets were used exclusively. All perplexity evaluations were calculated using the sliding window method from Press et al. [27] with \(S = 256\).
Firstly, we evaluated how the model performed as the context window increased. We selected 10 random samples from Proof-pile with at least 128k tokens each and evaluated the perplexity of each of these samples when truncated at 2k steps from a sequence length of 2k tokens through 128k tokens.
Table 1 shows a side-by-side comparison of Llama-2 model extended from 4096 to 8192 context length via PI (LLongMA-2 7b), “NTK-aware” and YaRN. Note that PI and “NTK-aware” models were trained using the methodology in Chen et al. [9], while YaRN used the same methodology but 2.5x less training steps and data, as described in 4.
LLongMA-2 7b [28] is fine-tuned from Llama-2 7b, trained at 8k context length with PI using the RedPajama dataset [12].
Extension Method | Trained Tokens | Context Window | 2048 | 4096 | 6144 | 8192 | 10240 |
---|---|---|---|---|---|---|---|
PI(s=2) | 1B | 8k | 3.92 | 3.51 | 3.51 | 3.34 | 8.07 |
NTK(θ=20k) | 1B | 8k | 4.20 | 3.75 | 3.74 | 3.59 | 6.24 |
YaRN(s=2) | 400M | 8k | 3.91 | 3.50 | 3.51 | 3.35 | 6.04 |
Table 1: Sliding window perplexity (S=256) often 128k Proof-pile documents over Llama-2 extended via PI, NTK and YaRN
We further evaluated YaRN at the scale factor s=16, 32 and compared them against a few open source models fine-tuned from Llama-2 and extended to more than 32k context window such as Together.ai[37] and “NTK-aware” CodeLlama[31]. The results are summarized in Table 2 (with a more detailed plot in Figure 1).
Model Size | Model Name | Context Window | Extension Method | 8192 | 32768 | 65536 | 98304 | 131072 |
---|---|---|---|---|---|---|---|---|
7B | Together | 32k | PI | 3.50 | 2.64 | >10^2 | >10^3 | >10^4 |
7B | CodeLlama | 100k | NTK | 3.71 | 2.74 | 2.55 | 2.54 | 2.71 |
7B | YaRN(s=16) | 64k | YaRN | 3.51 | 2.65 | 2.42 | >10^1 | >10^1 |
7B | YaRN(s=32) | 128k | YaRN | 3.56 | 2.70 | 2.45 | 2.36 | 2.37 |
13B | CodeLlama | 100k | NTK | 3.54 | 2.63 | 2.41 | 2.37 | 2.54 |
13B | YaRN(s=16) | 64k | YaRN | 3.25 | 2.50 | 2.29 | >10^1 | >10^1 |
13B | YaRN(s=32) | 128k | YaRN | 3.29 | 2.53 | 2.31 | 2.23 | 2.24 |
Table 2: Sliding window perplexity (S=256) often 128k Proof-pile documents truncated to evaluation context window size
We observe that the model exhibits strong performance across the entire targeted context size, with YaRN interpolation being the first method to successfully extend the effective context size of Llama 2 to 128k. Of particular note are the YaRN(s=32) models, which show continued declining perplexity through 128k, despite the fine-tuning data being limited to 64k tokens in length, demonstrating that the model is able to generalize to unseen context lengths.
Furthermore, in Appendix B.1, we show the results of the average perplexity on 50 untruncated GovReport documents with at least 16k tokens per sample evaluated on the setting of 32k maximal context window without Dynamic Scaling in Table 4. Similar to the Proof-pile results, the GovReport results show that fine-tuning with YaRN achieves good performance on long sequences.
The passkey retrieval task as defined in [25] measures a model’s ability to retrieve a simple passkey (i.e., a five-digit number) from amongst a large amount of otherwise meaningless text. For our evaluation of the models, we performed 10 iterations of the passkey retrieval task with the passkey placed at a random location uniformly distributed across the evaluation context window on different context window sizes ranging from 8k to 128k. Both 7b and 13b models fine-tuned using YaRN at 128k context size passes the passkey retrieval task with very high accuracy (>99%) within the entire context window size. We show detailed results in Appendix B.2.
The Hugging Face Open LLM Leaderboard [19] compares a multitude of LLMs across a standardized set of four public benchmarks. Specifically, we use 25-shot ARC-Challenge [11], 10-shot HellaSwag [41], 5-shot MMLU [17], and 0-shot TruthfulQA [23]. To test the degradation of model performance under context extension, we evaluated our models using this suite and compared it to established scores for the Llama 2 baselines as well as publicly available PI and “NTK-aware” models. The results are summarized in Table 3.
Model Size | Model Name | Context Window | Extension Method | ARC-c | Hellaswag | MMLU | TruthfulQA |
---|---|---|---|---|---|---|---|
7B | Llama2 | 4k | None | 53.1 | 77.8 | 43.8 | 39.0 |
7B | Together | 32k | PI | 47.6 | 76.1 | 43.3 | 39.2 |
7B | CodeLlama | 100k | NTK | 39.9 | 60.8 | 31.1 | 37.8 |
7B | YaRN(s=16) | 64k | YaRN | 52.3 | 78.8 | 42.5 | 38.2 |
7B | YaRN(s=32) | 128k | YaRN | 52.1 | 78.4 | 41.7 | 37.3 |
13B | Llama2 | 4k | None | 59.4 | 82.1 | 55.8 | 37.4 |
13B | CodeLlama | 100k | NTK | 40.9 | 63.4 | 32.8 | 43.8 |
13B | YaRN(s=16) | 64k | YaRN | 58.1 | 82.3 | 52.8 | 37.8 |
13B | YaRN(s=32) | 128k | YaRN | 58.0 | 82.2 | 51.9 | 37.3 |
Table 3: Performance of context window extensions methods on the Hugging Face Open LLM benchmark suite compared with original Llama 2 baselines
We observe that there is minimal performance degradation between the YaRN models and their respective Llama 2 baselines. We also observe that there was on average a 0.49% drop in scores between the YaRN s=16 and s=32 models. From this we conclude that the iterative extension from 64k to 128k results in negligible performance loss.
In conclusion, we have shown that YaRN improves upon all existing RoPE interpolation methods and can act as a drop-in replacement to PI, with no downsides and minimal implementation effort. The fine-tuned models preserve their original abilities on multiple benchmarks while being able to attend to a very large context size. Furthermore, YaRN allows efficient extrapolation with fine-tuning on shorter datasets and can take advantage of transfer learning for faster convergence, both of which are crucial under compute-constrained scenarios. Finally, we have shown the effectiveness of extrapolation with YaRN where it is able to “train short, and test long”.
To aid in reproducibility, we provide, as supplementary material, the entirety of of the code used to train the YaRN models in Table 2, as well as the evaluation code that produced Figure 1 and Tables 1, 2, 3, 4, and 5. The code also contains implementations of various extension methods referenced throughout the paper. For training YaRN, we used the publicly available PG19 dataset [29] tokenized to 64k tokens.