
Model | Giraffe

MinWoo(Daniel) Park | Tech Blog

  • Related Project: Private
  • Category: Paper Review
  • Date: 2023-08-27

Giraffe: Adventures in Expanding Context Lengths in LLMs

  • url: https://arxiv.org/abs/2308.10882
  • pdf: https://arxiv.org/pdf/2308.10882
  • github: https://github.com/abacusai/Long-Context
  • abstract: Modern large language models (LLMs) that rely on attention mechanisms are typically trained with fixed context lengths which enforce upper limits on the length of input sequences that they can handle at evaluation time. To use these models on sequences longer than the train-time context length, one might employ techniques from the growing family of context length extrapolation methods - most of which focus on modifying the system of positional encodings used in the attention mechanism to indicate where tokens or activations are located in the input sequence. We conduct a wide survey of existing methods of context length extrapolation on a base LLaMA or LLaMA 2 model, and introduce some of our own design as well - in particular, a new truncation strategy for modifying the basis for the position encoding. We test these methods using three new evaluation tasks (FreeFormQA, AlteredNumericQA, and LongChat-Lines) as well as perplexity, which we find to be less fine-grained as a measure of long context performance of LLMs. We release the three tasks publicly as datasets on HuggingFace. We discover that linear scaling is the best method for extending context length, and show that further gains can be achieved by using longer scales at evaluation time. We also discover promising extrapolation capabilities in the truncated basis. To support further research in this area, we release three new 13B parameter long-context models which we call Giraffe: 4k and 16k context models trained from base LLaMA-13B, and a 32k context model trained from base LLaMA2-13B. We also release the code to replicate our results.


TL;DR


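  • Context length extension experiments: a range of approaches for extending the context length of Llama-based models
  • Datasets and benchmarks: fine-tuning on the RedPajama dataset and instruction fine-tuning on the Vicuna dataset
  • RoPE encoding variants: linear scaling, frequency (power) scaling, and Fourier-basis truncation, aimed at improving long-context handling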

1. Introduction

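This study explores ways to extend the context length of transformer models, focusing on modifications to the positional encoding. The standard Rotary Position Embedding (RoPE) helps the model learn positional information in the input sequence, but its performance degrades at context lengths beyond those seen in training. To overcome this, the authors experiment with a range of scaling and encoding techniques.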


2. Background and Related Work

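The transformer architecture encodes positional information for each element of the input sequence, giving the model information about token order. Several positional encoding schemes exist; this study builds on RoPE, which rotates the query and key vectors by an angle proportional to their position, so that attention scores preserve the relative position between tokens.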

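\[\text{RoPE}(p)\begin{bmatrix} x_{2i} \\ x_{2i+1} \end{bmatrix} = \begin{bmatrix} \cos(p\theta_i) & -\sin(p\theta_i) \\ \sin(p\theta_i) & \cos(p\theta_i) \end{bmatrix} \begin{bmatrix} x_{2i} \\ x_{2i+1} \end{bmatrix}, \qquad \theta_i = 10000^{-2i/d},\]

where \(p\) is the position index, \(d\) is the head dimension, and \(i = 0, \dots, d/2 - 1\) indexes pairs of dimensions, each rotated at its own frequency \(\theta_i\). This encoding works well up to the context length seen during training but degrades on longer sequences, because the model then encounters rotation angles (relative distances) it has never seen.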
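A minimal NumPy sketch of the standard RoPE rotation, assuming the LLaMA base of 10000 and an even head dimension `d`. The function names `rope_frequencies` and `rope_rotate` are illustrative, not taken from the Giraffe repository. The check at the end shows the property that makes RoPE relative: the dot product of a rotated query and key depends only on the offset between their positions.

```python
import numpy as np

def rope_frequencies(d: int, base: float = 10000.0) -> np.ndarray:
    """theta_i = base^(-2i/d) for i = 0 .. d/2 - 1, ordered fast to slow."""
    return base ** (-np.arange(0, d, 2) / d)

def rope_rotate(x: np.ndarray, p: float, theta: np.ndarray) -> np.ndarray:
    """Rotate each pair (x[2i], x[2i+1]) of x by the angle p * theta[i]."""
    angles = p * theta
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x_even * cos - x_odd * sin
    out[1::2] = x_even * sin + x_odd * cos
    return out

rng = np.random.default_rng(0)
d = 8
theta = rope_frequencies(d)
q, k = rng.standard_normal(d), rng.standard_normal(d)

# Same offset (m - n = 3) at two different absolute positions gives the same score.
score_a = rope_rotate(q, 10, theta) @ rope_rotate(k, 7, theta)
score_b = rope_rotate(q, 1010, theta) @ rope_rotate(k, 1007, theta)
print(np.isclose(score_a, score_b))  # True
```

Because each pair is rotated by a pure rotation matrix, the vector norm is unchanged and only the angle between query and key, which depends on \(m - n\), affects the score.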


3. Methods

This study tried three main approaches to modifying the standard RoPE encoding:

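  1. Linear scaling: dividing the position index by a scale factor, so that a longer sequence maps back onto the range of positions seen in pretraining.
  2. Frequency (power) scaling: transforming the frequencies of the RoPE Fourier basis with a power, so that low frequencies are stretched more than high frequencies.
  3. Fourier-basis truncation: removing (setting to zero) basis frequencies that are too slow to complete a full cycle within the training context, so that the model only relies on frequencies it has fully seen during training.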
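The three modifications can be viewed as transformations of the RoPE frequency vector θ. The sketch below is illustrative only. The linear and truncation forms follow the descriptions in this post: positions divided by a scale, and frequencies that cannot complete one full cycle within the training context set to zero. The power form, θ_i^k with k ≥ 1, is an assumed parameterisation chosen because it stretches low frequencies more than high ones; it is not necessarily the one used in the Giraffe code. All function names are hypothetical.

```python
import numpy as np

def rope_frequencies(d: int, base: float = 10000.0) -> np.ndarray:
    return base ** (-np.arange(0, d, 2) / d)

def linear_scaled(theta: np.ndarray, scale: float) -> np.ndarray:
    # Dividing every frequency by `scale` is equivalent to using position p / scale.
    return theta / scale

def power_scaled(theta: np.ndarray, k: float) -> np.ndarray:
    # Hypothetical form: theta_i ** k with k >= 1. Since 0 < theta_i <= 1, the
    # wavelength stretch factor theta_i ** (1 - k) grows as theta_i shrinks,
    # so slow (low) frequencies are stretched more than fast ones.
    return theta ** k

def truncated(theta: np.ndarray, train_ctx: int) -> np.ndarray:
    # Keep only frequencies whose angle reaches a full cycle (2*pi)
    # within the training context; zero out the slower ones.
    completes_cycle = theta * train_ctx >= 2 * np.pi
    return np.where(completes_cycle, theta, 0.0)

theta = rope_frequencies(128)
kept = np.count_nonzero(truncated(theta, 2048))
print(kept, "of", theta.size, "frequencies kept")  # 41 of 64 frequencies kept
```

A zeroed frequency means no rotation at all for that pair of dimensions, so those dimensions contribute identically at every context length.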

These modifications aim to let the model handle longer contexts effectively. Each was combined with fine-tuning on the RedPajama dataset and instruction fine-tuning on the Vicuna dataset.


4. Experiments and Results

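Fine-tuning on the RedPajama dataset at a context length of 4096 gave the expected improvement up to 4096 tokens, but no further. Among the scaling and encoding variants, linear scaling gave the most robust results. Linear scaling divides each position index \(p\) by a scale factor \(s\) before applying RoPE, so that a context of \(2048 \cdot s\) positions maps back onto the 2048 positions seen in pretraining:

\[\text{RoPE}_{\text{scaled}}(p) = \text{RoPE}(p / s)\]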

Models whose context length was extended in this way worked particularly well on the WikiQA dataset, showing that they can use information from across a much wider context window.


5. Conclusion

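This study developed and evaluated methods for effectively extending the context length of transformer models. Modifying the positional encoding, in particular, substantially improved long-context performance, which should make these methods useful for the many NLP tasks that require long contexts.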


Overview

We conducted a wide variety of experiments to try to extend the context length of the models. First, we tried simply using the base Llama model zero-shot. As expected, this performed well up to 2048 context length but deteriorated very rapidly afterwards.

We next investigated fine tuning approaches where we trained the model on the RedPajama dataset at context lengths of 4096. This led to expected improvements in performance up to 4096 context but again, no further.

Another approach to extending context length is to modify the RoPE encoding in some way. Here, we tried many different ideas:

  • Linear scaling, as described by kaiokendev.github.io.
  • Scaling the Fourier basis of RoPE by a power, such that low frequencies are stretched more than high frequencies.
  • Applying truncation to the Fourier basis. Our idea here was that we wanted the model to see only frequencies that were fast enough so that it got at least one full cycle during training; any slower frequencies were set to 0 (equivalent to no rotation at all, i.e. equally important at all context lengths).
  • Randomising the position vector.
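The post does not say how the position vector was randomised. One published approach, the randomized positional encodings of Ruoss et al. (2023), samples a sorted random subset of positions from a range larger than the training sequence, so that the model sees large position values during training while token order is preserved. The hypothetical sketch below illustrates that idea; it is not necessarily what was done in these experiments.

```python
import numpy as np

def random_sorted_positions(seq_len: int, max_pos: int, rng: np.random.Generator) -> np.ndarray:
    """Draw seq_len distinct positions from [0, max_pos) and sort them,
    preserving token order while spreading positions over a larger range."""
    if seq_len > max_pos:
        raise ValueError("max_pos must be at least seq_len")
    positions = rng.choice(max_pos, size=seq_len, replace=False)
    return np.sort(positions)

rng = np.random.default_rng(0)
pos = random_sorted_positions(seq_len=2048, max_pos=16384, rng=rng)
print(pos[:5], pos.max())
```

The sampled positions would then replace `0, 1, ..., seq_len - 1` as the \(p\) values fed into the RoPE rotation.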

In particular, we combined the above approaches with fine-tuning on the RedPajama dataset and instruction fine-tuning on the Vicuna dataset. These combinations led to the most fruitful results.

Finally, we implemented and tried the approach described in the xPos paper. This approach adds decaying amplitude penalty terms that cause fast frequencies to have less impact at long distances than slow frequencies in the Fourier basis (see our blog post for similarity heatmaps that show this).
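A sketch of the xPos decay factor, assuming the formulation commonly used in open-source xPos implementations: a per-frequency base ζ_i = (2i/d + γ)/(1 + γ) with γ = 0.4 and a scale base B = 512. Queries are multiplied by ζ_i^{n/B} and keys by ζ_i^{-m/B}, so a query at position n attending to a key at position m picks up a factor ζ_i^{(n-m)/B}. Because ζ_i is smallest for i = 0, the fastest RoPE frequency, fast frequencies lose influence with distance sooner than slow ones. Function names are illustrative.

```python
import numpy as np

def xpos_zeta(d: int, gamma: float = 0.4) -> np.ndarray:
    """Per-frequency decay base zeta_i = (2i/d + gamma) / (1 + gamma), i = 0 .. d/2 - 1."""
    return (np.arange(0, d, 2) / d + gamma) / (1 + gamma)

def xpos_decay(d: int, distance: int, scale_base: float = 512.0) -> np.ndarray:
    """Factor multiplying each frequency's contribution to q.k when the query
    is `distance` positions after the key: zeta_i ** (distance / scale_base)."""
    return xpos_zeta(d) ** (distance / scale_base)

decay = xpos_decay(d=128, distance=4096)
print(decay[0], decay[-1])  # the fastest frequency decays far more than the slowest
```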

Extending LLM Context Length

The choice of how to encode positional information for transformers has been one of the key components of LLM architectures.

An area that has been interesting to us and others in the community recently is whether LLMs can be extended to longer contexts.

We have conducted a range of experiments with different schemes for extending the context length capabilities of Llama, which was pretrained on a 2048 context length with the RoPE (Rotary Position Embedding) encoding. Here we share some of the results, as well as the training and evaluation scripts, in the hope that they will be useful to the community. For our best-performing models (linear scaling with instruction fine-tuning, or IFT, at scales 4 and 16), we are also sharing the weights in case others wish to use them or to conduct their own tests. We believe the scale 16 model should perform well on real-world tasks up to 16k context lengths, and potentially even up to about 20-24k context lengths.

Highlighted Results

Perhaps the most pointed observation we made is that different evaluation methodologies/tasks lead to different rankings of the approaches detailed above. This will be described in further detail below.

That said, we made the following general observations:

  • Linear interpolation/scaling seems to be the most robust approach for increasing model context length.
  • Using a linear scale of N does not necessarily lead to a model context length increase by a factor of N. For example, our scale 16 experiments generally stopped performing well after a context length of 16000, not 32000 (~2048 * 16). We have ideas for how to ameliorate this effect planned for future work.
  • Truncation and randomisation both seem to have great perplexity scores but perform less well on the retrieval task.
  • Instruction fine tuning with the Vicuna dataset improves accuracy in the retrieval context significantly at lengths which the base model is capable of handling, but cannot ‘fix’ the base model at lengths where it fails.

Evaluation Tasks

For evaluation we used two different datasets:

  • LMSys datasets (the ‘lines’ task) for locating a substring in the context
  • Our own open-book question answering dataset, WikiQA, which is based on other open-source QA datasets

In addition, we looked at the log loss of the train and eval sets during training.

For the LMSys task, we generated new and longer testcases, up to a context length of about 25000, beyond the 16000 context testcases in the original dataset.
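Longer testcases for a "lines"-style retrieval task can be generated by adding more lines. The sketch below uses a format modelled loosely on the LongChat lines task (labelled lines holding a number, plus a question about one label); the exact wording of the released dataset may differ, and the helper name and word list are hypothetical.

```python
import random

WORDS = ["swift", "torpid", "quiet", "bright", "hollow", "amber", "river", "lamp"]

def make_lines_testcase(num_lines: int, seed: int = 0) -> tuple[str, str, str]:
    """Return (prompt, question, expected_answer) for a synthetic 'lines' task.

    Each line pairs a unique label with a random number; longer contexts are
    produced simply by increasing num_lines."""
    rng = random.Random(seed)
    records = []
    for idx in range(num_lines):
        # The index suffix keeps labels unique even when word pairs repeat.
        label = f"{rng.choice(WORDS)}-{rng.choice(WORDS)}-{idx}"
        value = rng.randint(10000, 99999)
        records.append((label, value))
    target_label, target_value = rng.choice(records)
    prompt = "\n".join(f"line {label}: REGISTER_CONTENT is <{value}>" for label, value in records)
    question = f"What is the <REGISTER_CONTENT> in line {target_label}?"
    return prompt, question, str(target_value)

prompt, question, answer = make_lines_testcase(num_lines=600)
print(question, "->", answer)
```

The token count grows roughly linearly with `num_lines`, so the line count can be tuned until the tokenized prompt reaches the desired context length.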

The WikiQA task is to answer a question based on the information given in a Wikipedia document. We built our QA task on the short-answer data in Google Natural Questions. Each example is formatted as a document and a question, and the answer is always short: either a single word or a short phrase copied verbatim from the document. With the task structured this way, we can pinpoint exactly where in the context the LLM is supposed to “look” for the answer, and can therefore evaluate every part of the expanded context length by carefully placing the answer at different locations.

We selected large Wikipedia documents and truncated them to obtain multiple versions of the same document, with sizes varying from 2000 to 16000 tokens. For each document size, we also have multiple versions that place the answer text at different locations, i.e. in the first 10%, the bulk (middle), or the last 10% of the document. Having multiple versions of the same document allows an exhaustive and fair evaluation across model sizes, and across context positions within one model, since we are intrinsically asking for the same information.
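The post does not show the truncation code. The hypothetical helper below is one way to cut a tokenized document to a target length while steering where the answer lands (first 10%, middle, or last 10% of the window). The region names and the 5%/50%/95% targets are assumptions for illustration.

```python
def truncate_around_answer(tokens: list[str], answer_start: int, answer_len: int,
                           target_len: int, region: str) -> list[str]:
    """Cut a window of target_len tokens that contains the answer span, placing
    the answer near the start, middle, or end of the window."""
    # Assumed fraction of the window that should precede the answer.
    frac = {"start": 0.05, "middle": 0.5, "end": 0.95}[region]
    if target_len >= len(tokens):
        return list(tokens)
    offset = int(frac * (target_len - answer_len))
    start = answer_start - offset
    # Clamp so the window stays inside the document; the answer span remains inside.
    start = max(0, min(start, len(tokens) - target_len))
    return tokens[start:start + target_len]

doc = [f"t{i}" for i in range(20000)]
window = truncate_around_answer(doc, answer_start=10000, answer_len=5, target_len=4000, region="end")
print(len(window), window.index("t10000"))  # 4000 3795
```

When the answer is too close to a document boundary, clamping moves it away from the requested region, so such documents would need to be filtered or re-assigned to a different region.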

A potential issue with a Wikipedia-based dataset is that the model could answer correctly from its pretraining corpus rather than from the context. To address this, we created a second, “altered” dataset consisting only of questions with numerical answers. Here, we change the answer, and every occurrence of the answer in the document, to a different number. This ensures that if the LLM recalls the answer from its pretraining corpus, it gives a wrong answer. The modification is made as follows:

  • If the answer is a year, which is quite frequent (i.e. it is between 1000 and 2100), we change it to a different random value within +/- 10 of the original value. We treat years as a special case so as not to make the document absurd by scrambling its chronological information.
  • If the answer is any other number, we change it to a different random number which has the same number of digits
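The two rules above can be sketched as follows, assuming non-negative integer answers written as plain digits in the document (no thousands separators). Helper names are hypothetical, and the regex-based replacement is one possible way to change "every occurrence" without touching longer numbers that merely contain the answer's digits.

```python
import random
import re

def alter_number(answer: int, rng: random.Random) -> int:
    """Pick a replacement value following the year / same-digit-count rules."""
    if 1000 <= answer <= 2100:
        # Years: a different value within +/- 10, keeping chronology plausible.
        candidates = [y for y in range(answer - 10, answer + 11) if y != answer]
        return rng.choice(candidates)
    digits = len(str(answer))
    low = 0 if digits == 1 else 10 ** (digits - 1)
    high = 10 ** digits - 1
    while True:
        new = rng.randint(low, high)
        if new != answer:
            return new

def alter_document(document: str, answer: int, rng: random.Random) -> tuple[str, int]:
    """Replace every standalone occurrence of the answer with the new value."""
    new_answer = alter_number(answer, rng)
    # Lookarounds stop "1932" from matching inside e.g. "19320".
    pattern = rf"(?<!\d){answer}(?!\d)"
    return re.sub(pattern, str(new_answer), document), new_answer

rng = random.Random(0)
doc = "The bridge opened in 1932. By 1932 traffic had doubled."
new_doc, new_year = alter_document(doc, 1932, rng)
print(new_year, new_doc)
```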

We call our original QA task Free Form QA (FFQA) and the altered task Altered Numeric QA (AltQA).

We evaluate success on every example in both versions of our QA task by measuring “Presence Accuracy”, i.e. whether or not the answer is present as a substring in the model’s generated answer. To run inference for our models on WikiQA and compute metrics, refer to run_inference_WikiQA.py and compute_metrics_WikiQA.ipynb in the GitHub repository.
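A minimal sketch of Presence Accuracy as described above. Lowercasing both strings before matching is an assumption, not confirmed by the post, and the function name is hypothetical; the official metric lives in the repository notebook mentioned above.

```python
def presence_accuracy(predictions: list[str], answers: list[str], lowercase: bool = True) -> float:
    """Fraction of examples whose gold answer appears as a substring of the
    model's generated text. Lowercasing is an assumption."""
    if len(predictions) != len(answers):
        raise ValueError("predictions and answers must have the same length")
    if not answers:
        return 0.0
    hits = 0
    for pred, gold in zip(predictions, answers):
        if lowercase:
            pred, gold = pred.lower(), gold.lower()
        hits += gold.strip() in pred
    return hits / len(answers)

preds = ["The bridge opened in 1927.", "I am not sure.", "It has 3 spans"]
golds = ["1927", "Paris", "3"]
print(presence_accuracy(preds, golds))  # 0.6666666666666666
```

Substring matching is lenient: a very short numeric answer such as "3" also counts as present if that digit appears anywhere in a longer generation.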

We are releasing these datasets on HuggingFace so that others can use them to run their own long-context experiments.

Results

LMSys Eval

As a general point regarding the results below, the authors believe that small differences in accuracy on this task are not particularly indicative of model ranking quality. We would generally look at the broadest trends here in interpreting the results.

Also, as a baseline, standard Llama-13b only has non-zero accuracy up to 2048 context length (as does its Vicuna instruction-fine-tuned version).

Comparison of different scaling approaches

In the above we compare the different scaling approaches. ‘Scale’ refers to linear interpolation with the designated scaling value. We see that linear interpolation with a scale of 16 is the only one to achieve a non-zero accuracy at context lengths greater than 9000. However, this seems to come with a sacrifice of some accuracy on shorter contexts.

The power = 0.5 basis seems to work particularly well for this task at shorter contexts but has the sharpest drop off in accuracy as context length increases.

It’s interesting to note that scale=16 doesn’t generalise quite as far as one would hope. Scale=4 is non-zero up to 8192, which is reasonable: the original context length is 2048 and 8192 = 2048 * 4, and beyond this the model sees relative distances between keys and queries that it has never encountered. Following the same trend, one would naively expect scale=16 to be non-zero all the way up to 2048 * 16 = 32768.
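The naive expectation is simple arithmetic. With linear scaling by s, position p is encoded as p / s, so positions stay inside the pretrained range [0, 2048) only while p < 2048 · s. The small sketch below (helper name hypothetical) reproduces the limits quoted in this post, including the evaluation-time scale of 8 used for the 16k WikiQA runs.

```python
PRETRAIN_CTX = 2048  # Llama pretraining context length

def naive_max_context(scale: float, pretrain_ctx: int = PRETRAIN_CTX) -> int:
    """Longest context whose scaled positions p / scale stay below pretrain_ctx."""
    return int(pretrain_ctx * scale)

for s in (1, 4, 8, 16):
    print(f"scale={s:>2}: naive limit {naive_max_context(s)} tokens")
# scale= 1: naive limit 2048 tokens
# scale= 4: naive limit 8192 tokens
# scale= 8: naive limit 16384 tokens
# scale=16: naive limit 32768 tokens
```

In these experiments the scale-16 model stopped producing non-zero accuracy well before 32768, so this bound behaves as an upper limit rather than a prediction.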

Impact of IFT (Instruction Fine Tuning)

In the above we display the impact of IFT via training on the Vicuna instruction set using LoRA. IFT does improve accuracy by a small but non-negligible margin. However, it is not sufficient to change the overall shape of the accuracy curve, and it does not extend the range of context lengths at which the model achieves non-zero accuracy on this task.

Evaluating Zero Shot at different scales than Training

In the above, we display various experiments that use different scale values (for linear interpolation) at evaluation time than the model was trained on. The green curve corresponds to taking a base model (trained on 2048 context) and applying a scale value to it. This does extend the non-zero range from 2048 to 4096, but with low accuracy throughout. In general, however, once a model has been trained with a scale > 1, it can zero-shot to a larger scale at evaluation time quite well, greatly increasing the range of coherent context lengths (e.g. compare Train=4, Eval=8 being non-zero here at 16k context length vs being 0 for anything above 8k two graphs above). However, this does come at the cost of an accuracy drop-off, particularly for Train=16, Eval=32.

The Train=16, Eval=12 run has the longest non-zero accuracy context length we have seen. It achieves a non-zero score at a context length of around 20000.

WikiQA Eval

In the table below, both models are evaluated with scale=4. However, the ‘no scaling’ model was not fine-tuned (i.e. received no training) at a scale > 1, whereas the Scale=4 model was fine-tuned at that expanded scale.

Presence Accuracy:

Context Length | FFQA, IFT with Scale=4 | FFQA, IFT no scaling | AltQA, IFT with Scale=4 | AltQA, IFT no scaling
--- | --- | --- | --- | ---
2048 | 0.3233 | 0.2217 | 0.7281 | 0.2982
4096 | 0.3783 | 0.2467 | 0.7018 | 0.2829
8192 | 0.4434 | 0.2406 | 0.6582 | 0.2401
16384 | 0.3933 | 0.0 | 0.5363 | 0.0

Note: for the 16k context length, we use a scale factor of 8 during inference, which expands the original 2k context to 2k * 8 = 16k. Interestingly, even though the scaled model was trained with a scale factor of 4, it can zero-shot interpolate to 16k (a scale of 8) during inference without losing much performance. This does not hold for the non-scaled model, whose accuracy drops to 0 on the 16k datapoints, which indicates that our scaling and context length interpolation does work.

Input Context Length Stats

As mentioned previously, we truncate and modify the documents to create different versions of the WikiQA data. Each version is meant to extensively test the model’s performance up to and at a certain context length, as indicated by the version name.

FFQA

File | Mean Context Length (tokens) | Max Context Length (tokens)
--- | --- | ---
ffqa_2k.json | 1936.71 | 3228
ffqa_4k.json | 3805.06 | 5793
ffqa_8k.json | 7598.98 | 9963
ffqa_16k.json | 15000.54 | 16178

AltQA

File | Mean Context Length (tokens) | Max Context Length (tokens)
--- | --- | ---
altqa_2k.json | 1953.73 | 2698
altqa_4k.json | 3737.39 | 5172
altqa_8k.json | 7481.37 | 9619
altqa_16k.json | 15013.44 | 16173

Performance Robust to Increasing Context Length

As seen above, our technique of fine-tuning interpolated embeddings gives models that are robust to increasing input context length on the WikiQA task. We demonstrate this on both versions of the task. Since we fine-tune with a scale factor of 4, we expect accuracy not to drop until inputs of 4 * 2048 = 8192 tokens. Even beyond this limit, we see some reasonable performance. This seems to be a consequence of the periodicity of RoPE embeddings, which makes some characteristics extrapolatable to positions beyond the limit set by the scale factor.

Impact of Scaling Context

(Figures: Impact of Scaling Context on FFQA and on AltQA)

We contrast models instruction fine-tuned with and without a scaled context to show that IFT with a scaled context leads to a significant jump in performance. Note that for both models, we still use a scaled context (scale=4) during evaluation. Interestingly, even the zero-shot performance of the scaled RoPE embedding gives non-trivial accuracy. However, explicitly fine-tuning the embeddings brings considerable gains: we see almost a 2x improvement on FFQA and a 2.5x improvement on AltQA at all positions interpolated by the scale factor.

Location of Information

Loss curves

We trained models for all the experiments described in the Overview, but not all of them were promising enough for a full evaluation: some were abandoned because their loss curves did not look promising. In some cases, we also found that the evaluation results did not always align with the losses we observed during training.

The images below show curves from a subset of the experiments we ran:

For example, the XPOS loss never converged towards the losses seen in the other runs. Initially we suspected that fp16 lacked sufficient precision to handle the XPOS coefficients, so we adjusted the implementation to use fp32 for the core attention dot product. This did improve convergence, but not enough for the losses to match the other models. Our hypothesis is that XPOS is too different from the base positional embeddings to fine-tune into the embedding. This is a bit surprising, since XPOS is just RoPE with a scaling factor that is a function of relative distance. One experiment we started but have not completed is to start with a factor of 1.0 and slowly shift to the XPOS function over iterations.
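The unfinished curriculum experiment could be sketched as a linear blend of each per-frequency xPos base from 1.0 (plain RoPE, no decay) towards its full value over a warm-up period. The linear schedule, the warm-up length, and the function names are all assumptions; ζ_i uses the γ = 0.4 form found in common xPos implementations.

```python
import numpy as np

def xpos_zeta(d: int, gamma: float = 0.4) -> np.ndarray:
    """Full xPos decay base per frequency: (2i/d + gamma) / (1 + gamma)."""
    return (np.arange(0, d, 2) / d + gamma) / (1 + gamma)

def blended_zeta(d: int, step: int, warmup_steps: int) -> np.ndarray:
    """Interpolate each decay base from 1.0 (pure RoPE) at step 0 to the full
    xPos value once step >= warmup_steps (hypothetical linear schedule)."""
    t = min(max(step / warmup_steps, 0.0), 1.0)
    return 1.0 + t * (xpos_zeta(d) - 1.0)

for step in (0, 500, 1000):
    print(step, blended_zeta(64, step, warmup_steps=1000)[:3])
```

At step 0 every base equals 1, so ζ_i^{Δ/B} = 1 for any distance Δ and the attention reduces exactly to RoPE, which is the starting point the experiment describes.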
