Contents
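1. Introduction
Language models have driven substantial progress in several areas, including natural language processing, code generation, quantitative reasoning, and theorem proving. One of the central challenges with language models is the effective incorporation of extensive new knowledge. This work introduces the Focused Transformer (FOT), a technique developed to address the distraction issue, a major obstacle to extending the context length of transformer models in multi-document scenarios.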
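2. Related work
Prior work has developed a variety of approaches for increasing the context length of transformers, mostly focused on alleviating the quadratic complexity of the attention computation (Transformer-XL, Longformer, BigBird, LongT5, and others).
Transformer-XL caches the previous context, while Longformer, BigBird, and LongT5 use sparse attention to handle long sequences. This work follows the kNN-lookup line of work (Memorizing Transformer) and focuses on how to incorporate new knowledge through a longer context.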
3. FOT: Focused Transformer
FOT is a plug-and-play extension that can be integrated into existing transformer models. It can be used either to train new models or to fine-tune existing large models with a longer context.
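3.1 Memory attention layers
Memory attention layers let the model retrieve information from an additional context at inference time. Specifically, each query attends to the preceding keys from the local context and to the top k best-matching keys retrieved from memory with the k-nearest neighbors (kNN) algorithm. The underlying attention operation is:
\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]Here $Q$, $K$, and $V$ denote the query, key, and value matrices, respectively, and $d_k$ is the key dimension. Memory keys are ranked by their inner product with the query, and the kNN search is applied on that basis.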
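3.2 Crossbatch training procedure
Crossbatch training is a new way of training transformer architectures that improves the structure of the ($K$, $V$) space so that memory attention layers can easily focus on relevant information. Schematically, the procedure can be written as
\[\text{Crossbatch}(\text{pos}, \text{neg}) = \frac{e^{\text{pos}}}{e^{\text{pos}} + \sum e^{\text{neg}}}\]Here $\text{pos}$ denotes the attention score for ($K$, $V$) pairs from the current document and its previous local context, and $\text{neg}$ denotes the scores for ($K$, $V$) pairs from unrelated documents. The larger this ratio, the more the model attends to relevant information.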
3.3 The distraction issue
The distraction issue can be described mathematically as follows.
\[r_d = \frac{\sum_j w_{1j}}{\sum_{i=1}^{d} \sum_j w_{ij}}\]Here $w_{ij}$ are the softmax weights for key $j$ of context $i$, with $i = 1$ being the current document, and $r_d$ is the attention mass on the positive keys. When $r_d \approx \frac{1}{d}$, attention is spread evenly over all $d$ contexts, meaning it is distracted rather than properly focused.
Crossbatch training aims to alleviate this problem and focus the attention.
Paper related to cross-batch: https://arxiv.org/abs/1912.06798
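4. Long-llama
One promising aspect of this work is that existing large models can be fine-tuned to extend their context length. Experiments demonstrate this by showing that Long-llama models extrapolate beyond their training context length. The authors also report performance gains from increased context length on question answering and several other tasks.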
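5. FOT analysis
This section analyzes and validates FOT on a variety of datasets. The main questions are whether FOT can extend the context length at inference time, whether it can be used to extend the context length of existing pre-trained models, and how this translates into improved performance on language modeling tasks. In conclusion, FOT is shown to improve over existing transformer baselines.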
Language models have served as a catalyst for substantial advancements in several areas, including natural language processing [Radford et al., 2019, Brown et al., 2020], code generation [Chen et al., 2021, Li et al., 2022], quantitative reasoning [Lewkowycz et al., 2022] and theorem proving [Polu and Sutskever, 2020, Jiang et al., 2022, Mikuła et al., 2023]. One of the central challenges with language models is the effective incorporation of extensive new knowledge. The common practice of fine-tuning the model is not only resource-intensive and complex to manage, but it also does not always clearly indicate how to incorporate new knowledge. For example, fine-tuning on a text such as “Alice in Wonderland” does not equip the model to answer questions about the story itself, but rather it trains the model to predict the next token or complete masked sentences. A promising alternative – integrating the new knowledge within the context – doesn’t require training but is considerably restricted by the model’s effective context length. For this method to work with large knowledge databases (like large code repositories), the model needs to manage a context length extending to millions of tokens.
Figure 1: Accuracy of Long-llama 3B on passkey retrieval compared to the original OpenLLaMA model. Our method extrapolates beyond the training length, achieving 94.5% accuracy at a context length of 100k and 73% at 256k tokens, while the baseline is unable to handle context longer than its training length (2k).
In this research, we highlight one of the primary obstacles in augmenting the context length: as the number of documents increases, the ratio of pertinent to irrelevant tokens diminishes. The standard training procedure frequently results in overlaps between keys connected with irrelevant values and those related to relevant ones, exacerbating the model’s task of differentiating between them. We term this challenge the distraction issue.
We propose the Focused Transformer (FOT), an innovative technique developed explicitly to address this issue. The Focused Transformer permits a subset of attention layers to access an additional context of ($K$, $V$) pairs through the k-nearest neighbors (kNN) algorithm, akin to the method used in [Wu et al., 2022]. This mechanism effectively extends the total context length. The distinctive aspect of the Focused Transformer is its training procedure, which draws on contrastive learning. This method addresses the distraction issue and facilitates larger context capacities. Specifically, during the training phase, we deliberately expose the chosen subset of attention layers to both relevant and irrelevant keys (like negative samples from unrelated documents). This strategy incentivizes the model to differentiate keys connected with semantically diverse values, thereby enhancing their structure.
Notably, Long-llamas show significant improvements on tasks necessitating long-context modeling. In particular, they can manage a 256k context length on the passkey retrieval task [Mohtashami and Jaggi, 2023].
Our research contributions are the following:
1. We further scrutinize FOT’s capabilities across various datasets and model sizes. We show that a FOT trained with a total context of 512 tokens can extrapolate to 16 million tokens in a benchmark dictionary lookup task. We also assess FOT on long-context language modeling tasks such as books (PG-19), mathematics (arXiv), code (GitHub), and formal proofs (Isabelle), where it exhibits improvements in perplexity over baselines.
Long-context transformer architectures A multitude of approaches have been developed to increase the context length of transformers, mostly focusing on alleviating the quadratic complexity of the attention computation. For instance, Transformer-XL [Dai et al., 2019] caches the previous context and enables the linear extension of context with the number of layers. Longformer [Beltagy et al., 2020] employs an attention mechanism that allows tokens to attend to distant tokens sparsely, reducing the computational complexity. BigBird [Zaheer et al., 2020], LongT5 [Guo et al., 2021], and [Dao et al., 2022] also use sparse attention to handle long sequences. Different efficiency considerations have been studied in [Kaddour et al., 2023], showing that they lead to limited gains. Hierarchical transformers [Nawrot et al., 2021, 2023] downsample activations in intermediate layers to reduce computation and enable longer contexts. COLT5 [Ainslie et al., 2023] proposes conditional computation to save memory and enable larger contexts. Memorizing Transformer [Wu et al., 2022] uses kNN lookup to pick up the most relevant tokens, which might also be seen as a way to reduce the computational complexity of attention. Our work adheres to this approach and aims to train a key space that handles longer attention context length (e.g., by mitigating the distraction issue) and, thus, has better long-context capabilities.
Fine-tuning LLMs for longer retrieval Prior works such as RETRO [Borgeaud et al., 2022] (RETROfitting) and Memorizing Transformer [Wu et al., 2022] have demonstrated a promising path for fine-tuning existing LMs to add new capabilities without the need to retrain the entire model. In contrast to those approaches, our method is framed not as retrieval but as a way of extending the context of the model. In contrast to RETRO, we propose a single-stage method for context extension instead of a two-stage retrieve-then-embed approach. We provide a more detailed comparison with the Memorizing Transformer in Appendix C.3. More recently, a number of works have explored fine-tuning LLaMA to extend its context length. Landmark attention [Mohtashami and Jaggi, 2023] proposes a compression scheme of the LLM’s context into landmarks, increasing the context length of LLaMA-7B to 32K. Position Interpolation (PI, [Chen et al., 2023] and [kaiokendev, 2023]) introduces a modification to the rotary positional encoding scheme that enables fine-tuning for 32K context. In contrast to these works, our method does not rely on positional encodings, following the findings from [Haviv et al., 2022]. Removing positional encoding in the additional context allows us to extrapolate to 256k tokens, although the model was only trained on sequences up to 8K, yielding a theoretically unbounded context length.
Zero-shot methods KNN-LM [Khandelwal et al., 2019] shows that one can improve the performance of an LLM by combining two probability distributions: one created by a pre-trained model, and one based on the similarity between the embedding of the currently processed token and the embeddings of tokens retrieved from a large database. In contrast, we extend the model context in a subset of attention layers, potentially allowing for reasoning within this extended context. Parallel Context Windows for Large Language Models [Ratner et al., 2023] introduces a method for extending the context of language models without training. They achieve this by embedding several context windows independently in parallel and allowing only a subset of tokens to attend to all windows. We, on the other hand, fine-tune existing models and allow all tokens to attend to all previous tokens, but only in a subset of layers. Additionally, our method allows us to improve the structure of the key-value space of the existing models.
Contrastive learning Contrastive learning aims to learn good representations by comparing positive and negative examples. CLIP [Radford et al., 2021] and SimCLR [Chen et al., 2020] are two popular contrastive learning methods that have achieved state-of-the-art performance in the image domain. During contrastive pre-training, negative examples are kept in the same batch to learn to distinguish them from positive examples. Scaling the batch size in contrastive learning has been demonstrated to enhance the quality of representations, as shown in [Gao et al., 2021b]. It has been suggested [Gao et al., 2019] that the embedding space in language modeling suffers from degeneracy, where embeddings are tightly packed in a narrow cone, making it difficult to distinguish between them. TRIME [Zhong et al., 2022] proposes a training approach designed for training LMs with memory augmentation, which uses negatives to improve the quality of representations. The main difference between this and our approach is that we incorporate negatives into the chosen subset of attention layers instead of interpolating in the output layer and use the standard language modeling loss. TRIME [Zhong et al., 2022] also focuses on retrieval from large databases, whereas we focus on extending the context of the model. ContraCLM [Jain et al., 2023] applies contrastive losses at both the token and sequence levels during training to promote more uniformly distributed, isotropic representations. It is shown to enhance the discrimination of representations on textual semantic similarity benchmarks. While ContraCLM focuses on improving the general expressiveness of representations, our work introduces contrastive-inspired techniques designed specifically for training the attention mechanism to handle longer context lengths. Nonetheless, exploring other contrastive learning objectives could be beneficial for further improving the key structure in future work.
Figure 2: The Focused Transformer overview. During inference, a memory attention layer (green) uses additional context of ($K$, $V$) pairs via kNN lookup, which effectively extends its context length. This layer is trained using crossbatch. Namely, the tokens from the current context $C_{\mathrm{curr}}$ attend in a differentiable way (Att + ∇) to the previous context $C_{\mathrm{prev}}$ of the same document and, importantly, to $d - 1$ contexts of other documents. The latter serve as 'negative' examples intended to better shape the ($K$, $V$) space.
Our method, the Focused Transformer (FOT), is a simple plug-and-play extension of transformer models and can be used both to train new models or fine-tune existing, possibly large, models with longer context. To this end, FOT uses memory attention layers and the crossbatch training procedure. Memory attention layers enable the model to retrieve information from the additional context at inference time, effectively extending the context. The crossbatch training procedure biases the model to learn ($K$, $V$) representations, which are easy to use by a memory attention layer. See Figure 2 for an overview of the FOT architecture and Appendix L for pseudocode.
Memory attention layers L are endowed with access to an additional context during inference. Namely, each query in ℓ ∈ L attends to preceding keys from the local context and the top k most matching keys (i.e. having the largest inner product with the query) from memory. The memory keys are ranked by the inner product with the query and retrieved using the kNN search algorithm. We use the exact kNN search implemented in FAISS [Johnson et al., 2017]. The memory is populated incrementally with ($K$, $V$) pairs processed by ℓ beforehand. Our memory attention layer design is closely related to [Wu et al., 2022]; we follow most of its design choices, except for the gating, which we replace with a simpler mechanism that turns out to be more effective in our applications. See details in Appendix C.3 and Appendix B.2. We remove positional encodings in memory layers in all our models except Long-llamas. Keeping them in Long-llamas allows Long-llama checkpoints to be a drop-in replacement for LLaMA checkpoints. We treat the kNN search algorithm as an approximation of full dense attention, which opens the door to future speed-ups.
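To make the lookup concrete, the following is a minimal pure-Python sketch of a single query in a memory attention layer. It is an illustration under stated assumptions, not the authors' implementation: the function names (`memory_attention`, `softmax`, `dot`) are hypothetical, a sorted exact top-k stands in for the FAISS index, and the simpler mechanism that replaces gating is assumed here to be one softmax over the union of retrieved and local keys.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))

def memory_attention(query, local_keys, local_values, mem_keys, mem_values, k):
    """Attend to all preceding local keys plus the top-k memory keys.

    Exact top-k by inner product stands in for the kNN index (assumption).
    Returns the attention output and the indices of the retrieved memory keys.
    """
    # Rank memory keys by inner product with the query and keep the k best.
    ranked = sorted(range(len(mem_keys)),
                    key=lambda i: dot(query, mem_keys[i]), reverse=True)[:k]
    keys = [mem_keys[i] for i in ranked] + list(local_keys)
    values = [mem_values[i] for i in ranked] + list(local_values)

    # Scaled dot-product attention over the combined key set.
    scale = math.sqrt(len(query))
    weights = softmax([dot(query, key) / scale for key in keys])
    dim = len(values[0])
    output = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]
    return output, ranked
```

Only `k` memory entries enter the softmax, so the cost of the attention step itself does not grow with the memory size; the ranking step is where a real system would use an index.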
Our training procedure is a novel way of training (or fine-tuning) transformer-based architectures in order to improve the structure of the ($K$, $V$) space. The main motivation is to shape this space so that a memory attention layer ℓ ∈ L can easily focus on relevant information. The key idea, inspired by contrastive learning, is to expose ℓ to ($K$, $V$) pairs from the current and previous local context of the given document (positives) and d − 1 contexts from unrelated documents (negatives). Importantly, this is done in a differentiable way.
To achieve this, we use a data pipeline in which each element of the batch corresponds to a different document. We embed the previous ($C_{\mathrm{prev}}$) and the current ($C_{\mathrm{curr}}$) local context for each of the processed documents. The overview of our procedure can be found in Figure 2. Specifically, for each document δ in $C_{\mathrm{curr}}$ we create a set $\{p^\delta_i\}_{i \in \{1,\dots,d\}}$ consisting of the ($K$, $V$) pairs from the previous local context of δ (positives), along with pairs from $d - 1$ other contexts coming from $C_{\mathrm{prev}}$ (negatives). We also experiment with varying the number of previous contexts and negatives for different batch elements. The operation is fully differentiable, and thus, we improve all the ($K$, $V$) pairs in $p^\delta$. Moreover, the procedure is easy to implement; it does not require any additional loss (i.e., it uses the standard transformer training objective) and is done on the level of the data loading pipeline and a minor self-attention change. The only new hyperparameter is $d$, which prescribes the ratio of positive to negative samples. Typically, we find it beneficial to start with small $d \le 8$ (otherwise, the model tends to ignore the previous local context) and later switch to bigger values, say $d \ge 64$. Appendix B.3 provides more details about the method. Listing 1 outlines an implementation of the crossbatch.
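The selection of positives and negatives described above can be sketched as follows. This is not the authors' Listing 1 but a simplified illustration. The names `crossbatch_contexts` and `gather_crossbatch_keys` are hypothetical, and picking negatives by cycling through the batch is an assumption made for brevity. In a real model the gathered keys and values would be tensors that stay in the autograd graph, which is what makes the operation differentiable; plain lists cannot show that part.

```python
def crossbatch_contexts(batch_size, d):
    """For each document index, list the d previous contexts it attends to.

    Entry 0 is the document's own previous context (positive); the remaining
    d - 1 entries are previous contexts of other documents (negatives),
    chosen here by cycling through the batch (assumption).
    """
    if not 1 <= d <= batch_size:
        raise ValueError("d must be between 1 and the batch size")
    return [[(doc + offset) % batch_size for offset in range(d)]
            for doc in range(batch_size)]

def gather_crossbatch_keys(prev_keys, prev_values, d):
    """Build, per document, the (K, V) pairs exposed to the memory layer.

    prev_keys[b] and prev_values[b] hold the keys and values computed for
    the previous local context of document b.
    """
    plan = crossbatch_contexts(len(prev_keys), d)
    keys = [[key for ctx in row for key in prev_keys[ctx]] for row in plan]
    values = [[val for ctx in row for val in prev_values[ctx]] for row in plan]
    return keys, values
```

With `d = 1` every document sees only its own previous context, and raising `d` adds more negatives per positive. That is the knob the text suggests increasing over the course of training.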
In this section, we conceptualize what we call the distraction issue and hypothesize it is one of the key problems in dealing with long multi-document contexts (like large code repositories). Namely, during the standard training, the model is not incentivized to distinguish the keys from different documents. We measure that the attention mass is evenly spread on the related and unrelated documents; see Figure 3. More precisely, for a document δ, let $w_{ij}$ be the softmax weights related to $p^\delta_{ij}$ constructed as described in Section 3.2. We define the positive attention mass as $r_d := \sum_j w_{1j} / \sum_{i=1}^{d} \sum_j w_{ij}$. We observe that $r_d \approx 1/d$, which can be interpreted as the fact that the attention is equally distracted by the positive (coming from the current document at $i = 1$) and negative keys. This is an undesirable property since when scaling the memory, the attention becomes increasingly distracted. We show that the crossbatch mostly alleviates the distraction issue, resulting in a focused attention. More information can be found in Appendix B.4. In Section 5.3, we also show that the distraction issue has a harmful effect on metrics like perplexity.
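The positive attention mass can be computed for a single query as in the sketch below. The function name and the data layout (a list of d key lists, with the positive context first) are assumptions for illustration; the measurement code behind Figure 3 is not shown in the text.

```python
import math

def positive_attention_mass(query, contexts):
    """Fraction of softmax attention mass placed on contexts[0].

    contexts is a list of d key lists: contexts[0] holds keys from the
    current document (positive), the others hold keys from unrelated
    documents (negatives).
    """
    scale = math.sqrt(len(query))
    scores = [[sum(q * k for q, k in zip(query, key)) / scale for key in ctx]
              for ctx in contexts]
    # Subtract the global maximum before exponentiating, for stability.
    top = max(s for row in scores for s in row)
    masses = [sum(math.exp(s - top) for s in row) for row in scores]
    return masses[0] / sum(masses)
```

If keys from all d contexts score alike, the function returns 1/d, which is the distracted regime described above. A key space in which the positive keys score clearly higher pushes the value toward 1.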
Figure 3: Distraction issue. We compare FOT trained with different values of parameter d to the standard Transformer baseline. During the evaluation, both models see the previous local context and some contexts from other documents in the chosen layer (as in crossbatch training procedure). For a document δ we measure the distribution of attention mass on pδ. Scale x: the number of contexts from documents that the model can see. Scale y: avg attention mass to the previous local context of the current document.
One of the promises of our work is that FOT can be used to fine-tune already existing large models to extend their context length. In this section, we show that this is indeed the case. We use OpenLLaMA-3B and OpenLLaMA-7B models trained for 1T tokens as starting points and fine-tune them with FOT. We show that the resulting models, which we call Long-llamas, are capable of extrapolating beyond their training context length (even up to 256K) and retain the performance on short-context tasks. We release the inference code on GitHub: https://github.com/CStanKonrad/long_llama and the Long-llama-3B checkpoint on Hugging Face: https://huggingface.co/syzymon/long_llama_3b. We note that our checkpoint is backward compatible, i.e. can be used with any existing LLaMA inference code (both in Hugging Face and other implementations), albeit without long-context capabilities.
The architecture of the models is the same as OpenLLaMAs, see Geng and Liu [2023] and Appendix A.1. We use L = {6, 12, 18} (resp. L = {8, 16, 24}) as the memory layers for 3B (resp. 7B) Long-llama model. We fine-tune the models on 10B (resp. 3B) tokens using FOT, 8k context length and our dataset mixture based on RedPajama [TogetherComputer, 2023], see Appendix A.3.
There are three minor differences from the standard FOT procedure. First, we retain the positional encodings in the local context of the memory layers (this is not necessary for FOT, but makes our checkpoints fully compatible with any existing LLaMA inference codebase). To be more precise, queries and keys from the local context (up to 2K tokens) receive the standard LLaMA rotary positional encoding, whereas memory keys are encoded as if they had position 0 in the local context window. Second, we use dense attention instead of the kNN retrieval, as we found only marginal performance differences, and it is simpler to implement. Third, we modify the crossbatch training procedure to have more fine-grained control over the number of additional contexts and the ratio of positive to negative samples. All these differences are detailed in Appendix A.2.
We first measure the effective context length of Long-llama, namely the distance over which tokens can effectively attend to each other. We use passkey retrieval introduced in [Mohtashami and Jaggi, 2023], a synthetic task designed to measure this property. In this task, the model has to retrieve a passkey placed randomly in a long prompt. Results are shown in Figure 1. Importantly, our 3B model is capable of solving this task far beyond its training context length of 8K, achieving 94.5% accuracy for prompts of length 100k and 73% for 256k.
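A passkey retrieval prompt of the kind described above could be generated roughly as follows. The filler sentence, the exact wording of the passkey line and the question, and the name `make_passkey_prompt` are illustrative assumptions. They are not necessarily the strings used in the benchmark.

```python
import random

def make_passkey_prompt(n_filler, rng):
    """Build a prompt that hides a random passkey among filler sentences."""
    passkey = rng.randint(10000, 99999)
    filler = "The grass is green. The sky is blue."  # hypothetical filler text
    lines = [filler] * n_filler
    # Insert the passkey sentence at a random position (either end allowed).
    position = rng.randint(0, n_filler)
    lines.insert(position, f"The pass key is {passkey}. Remember it.")
    prompt = "\n".join(lines) + "\nWhat is the pass key? The pass key is"
    return prompt, passkey
```

Prompt length is controlled by `n_filler`; accuracy is then the fraction of prompts for which the model's continuation starts with the hidden number.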
In Table 6 we present the performance on the validation set of Qasper [Dasigi et al., 2021] from SCROLLS [Shaham et al., 2022] and compare our results to LongChat 7B [Ma and Zhang, 2023] and two baseline short-context models. We note that our model shows gains from increased context length.
We measure long-context capabilities of these models on two downstream tasks, TREC question classification [Li and Roth, 2002, Hovy et al., 2001] and WebQS question answering [Berant et al., 2013]. We follow the experimental setup of [Hao et al., 2022]. Namely, we few-shot prompt the models with as many demonstration examples as possible up to the given context length. We do not use structured prompting as in [Hao et al., 2022]; instead, we directly provide all demonstrations in context.
We observe significant accuracy gains from longer contexts on TREC and some improvements on WebQS (see Table 1). The TREC dataset consists of 50 classes. A model is tasked to predict the class label given in-context examples. Only 100 examples fit the standard context length (2K); it is not unusual that no class example is present for a given question, making the task impossible. Increasing the context length and the number of examples mitigates this risk. Moreover, having more demonstrations of the given class is also likely to be beneficial.
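The "as many demonstrations as fit" setup can be sketched as a simple packing loop. The prompt template, the function name `pack_demonstrations`, and the `count_tokens` callback are hypothetical; a real run would count tokens with the model's own tokenizer.

```python
def pack_demonstrations(demos, question, budget, count_tokens):
    """Prepend as many (text, label) demonstrations as fit in the budget.

    count_tokens is any callable mapping a string to its token count.
    Returns the prompt and the number of demonstrations used.
    """
    query = f"Question: {question}\nType:"
    used = count_tokens(query)
    chosen = []
    for text, label in demos:
        block = f"Question: {text}\nType: {label}\n\n"
        cost = count_tokens(block)
        if used + cost > budget:
            break  # the next demonstration would overflow the context
        chosen.append(block)
        used += cost
    return "".join(chosen) + query, len(chosen)
```

Doubling `budget` roughly doubles the number of demonstrations, which raises the chance that the class of the test question appears in context at least once.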
Table 1: Few-shot in-context learning performance of Long-llama; accuracy on TREC and WebQS. We see significant gains from the additional context on the TREC dataset. To calculate the results, we average over 20 trials for sampling in-context demonstrations from the train set; the resulting confidence intervals for TREC and WebQS are smaller than 1% and 0.1%, respectively.
Table 2: Few-shot in-context learning performance comparison between standard fine-tuning on 4K context (baseline) and FOT fine-tuning on the same context length for 1B tokens. On TREC, FOT is able to utilize additional examples beyond its training context length to achieve higher accuracy at 8K context length, which is not possible for the baseline since its context is bounded to 4K.
In this section, we compare FOT to standard long-context fine-tuning, showing that it already achieves better performance for the context length used for fine-tuning and, importantly, that it can extrapolate beyond this context length, which is not the case for the baseline.
For comparisons, we fine-tune two models, one trained with FOT and another one (baseline) with standard fine-tuning (done similarly to [MosaicML, 2023, Nijkamp et al., 2023]). In both cases, we use 3B models fine-tuned on 1B tokens using the 4K context length. We evaluate both models on a number of few-shot downstream tasks in the setting described in Section 4.4.
In most cases, see Table 2, we observe accuracy improvements when more few-shot demonstrations are provided in the extended context (from 2K used by OpenLLaMA to 4K used in our fine-tuning). On TREC, the gains from additional context are significant for both models, while on WebQS, the standard fine-tuning baseline does not provide any improvement from extended context. Notably, the model fine-tuned with FOT enjoys further accuracy gains when evaluated with context lengths beyond its training length (6K and 8K). This shows extrapolation capabilities of FOT, which are not present in the baseline (see e.g. Figure 1).
Fine-tuning for longer contexts could hurt performance on the original context length (2K), as the training data distribution changes. We show that this is not the case for the Long-llama models by evaluating them using the LM Evaluation Harness library [Gao et al., 2021a]. On most tasks, the performance is kept intact; see Appendix A.4 for details. This also confirms that Long-llamas could be used as a drop-in replacement of LLaMA models as they are compatible with the original LLaMA inference code.
In this section, we perform extensive experiments on smaller models to analyze and further validate our approach. In particular, we answer the following questions: (1) How does FOT perform when scaling the context length at inference time? (2) Can FOT be used to extend the context length of an existing, pre-trained model? (3) How effectively can it handle distractions, and how does this capability translate to enhanced performance in long-context language modeling tasks? Moreover, we provide ablation studies of our method and additional analysis.
Architecture For experiments described in this section we use decoder-only Transformer [Vaswani et al., 2017] models with 12 layers and 184M parameters (unless stated otherwise). Following Wu et al. [2022], we pick ℓ = 8 as the memory attention layer. We tune k = 128, the number of top keys retrieved by kNN. In most experiments, we start training with a small crossbatch dimension d ≤ 8 and switch to d ≥ 64 after some training. For more details about the architecture and hyperparameters, see Appendix B and Appendix E.
Evaluation We distinguish two evaluation settings: single-document (abbreviated to single-doc) and multi-document (abbreviated to multi-doc). The single-doc setting is typically used for evaluating models that process long contexts. Here, we clear the memory for each new document, ensuring that only the current document is available in the context. The multi-doc setting retains memory across multiple documents without resets. This scenario tests whether the model can ignore irrelevant information and focus on the relevant data, which can be useful in setups like repository-level code generation.
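The difference between the two settings comes down to when the memory is cleared. A minimal sketch, with a hypothetical `KeyValueMemory` class standing in for the store behind a memory attention layer:

```python
class KeyValueMemory:
    """Append-only store of (key, value) pairs for a memory layer."""

    def __init__(self, multi_doc):
        self.multi_doc = multi_doc
        self.pairs = []

    def start_document(self):
        # Single-doc: forget everything from earlier documents.
        # Multi-doc: keep earlier documents, which then act as distractions.
        if not self.multi_doc:
            self.pairs.clear()

    def add(self, keys, values):
        """Store the (key, value) pairs of a processed context window."""
        self.pairs.extend(zip(keys, values))
```

In the multi-doc setting the store keeps growing across documents, so the share of entries relevant to the current document shrinks as evaluation proceeds.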
Datasets We evaluate on the following long-context language modeling datasets: PG-19 (English books), arXiv (mathematical papers), GitHub (code), and Isabelle (formal proofs). PG-19 [Rae et al., 2019] is a large dataset of English-language books published prior to 1919, sourced from the Project Gutenberg archive. This dataset is a well-established benchmark for evaluating long-context language models [Sun et al., 2021]. The arXiv dataset contains LaTeX source of papers labeled as “Mathematics” that were obtained by downloading articles through the arXiv Bulk Data Access. The token count per paper in this dataset is comparable to that of a book in PG-19. For details on the remaining datasets, refer to Appendix H.
FOT is a minimal modification to the standard transformer architecture; therefore, it is possible to fine-tune existing models to endow them with a longer context length via the memory attention layer, as we already demonstrated in Section 4. In this section, we deepen this analysis (on a smaller model) by studying perplexity improvements on various datasets.
As a base model, we use a standard transformer model pre-trained for 100k steps with context of 1K tokens using the standard objective and fine-tune with the FOT objective (i.e. crossbatch). The data used for both fine-tuning and pre-training is the C4 dataset [Raffel et al., 2019a] (we omit documents shorter than 2K tokens). The fine-tuning phase takes 10k steps. We use the crossbatch dimension d = 128 and local context of 1K tokens (context is 2K during training). We evaluate models in a zero-shot way on 4 language modeling datasets that require long context: arXiv, PG-19, GitHub and Isabelle; see Section 5.1 and Appendix E for details.
In Table 3, we observe that FOT enjoys steady perplexity gains up to 64K tokens, although it was fine-tuned only with the 2K total differentiable context length. We compare the model perplexity to the following baselines: Memorizing Transformer (MT) [Wu et al., 2022] fine-tuned with the local context of 1K and memory size of 16K, and Transformer-XL [Dai et al., 2019] fine-tuned with both local context and window length of 1K. To ensure a fair comparison, all three models are fine-tuned from the same base checkpoint. When evaluated with a context of 2K, our method achieves results on par with the Transformer-XL baseline, which has access to the previous context in all layers, unlike MT and FOT. Compared to the MT baseline, we achieve better scaling when evaluated with 64K context length and significantly better perplexity values. Unlike MT, our method does not require training on long sequences, which is reflected by the lower perplexities of FOT when evaluated in the zero-shot setting. For more details, see Appendix G.
We also confirm the context extrapolation abilities using a synthetic dictionary lookup task. In this task, the model is first provided with ki : vi mappings and then asked what value is associated with a particular key. We train 37M parameter models using documents of length 512. Figure 10 shows that FOT, after 5k steps of training, can effectively utilize memory consisting of 16M tokens achieving accuracy above 92%. Details can be found in Appendix F.
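An example for the dictionary lookup task could be generated along these lines. The `key:value` text format, integer keys and values, and the name `make_lookup_example` are assumptions for illustration; the exact format is described in Appendix F, which is not reproduced here.

```python
import random

def make_lookup_example(n_pairs, rng, vocab_size=1000):
    """Generate key:value mappings followed by a query for one key."""
    keys = rng.sample(range(vocab_size), n_pairs)  # distinct keys
    mapping = {key: rng.randrange(vocab_size) for key in keys}
    context = " ".join(f"{key}:{value}" for key, value in mapping.items())
    target_key = rng.choice(keys)
    # The model is expected to continue the prompt with the stored value.
    prompt = f"{context} {target_key}:"
    return prompt, mapping[target_key]
```

Because the answer depends on a single mapping that may sit arbitrarily far back, accuracy on this task directly reflects whether the model can locate one relevant key among many irrelevant ones.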
Table 3: Perplexity for different context lengths after fine-tuning a standard transformer model. The model is fine-tuned using the FOT objective (i.e., crossbatch) on C4 and evaluated zero-shot varying the context size. Transformer-XL [Dai et al., 2019] and Memorizing Transformer [Wu et al., 2022] fine-tuned in the same setting are used as baselines.
In this section, we measure how handling distractions in the multi-document setting helps in language modeling. We pick the PG-19 dataset [Rae et al., 2019] and measure the perplexity of the next token prediction (language modeling task) when varying the size of multi-doc memory (in this case consisting of books). Intuitively, the memory tokens corresponding to the current book might be beneficial (which is also confirmed in [Wu et al., 2022]), while the ones from the other books are unlikely to be useful and thus are distractions.
We observe, see Figure 8, that higher values of the crossbatch dimension d lead to better perplexity. This aligns with the observations in Section 3.3, indicating that by mitigating the distraction issue, we experience benefits in language modeling.
Moreover, all versions of FOT are able to utilize memory and achieve much better perplexity than the standard Transformer (no memory). Unsurprisingly, perplexity increases with memory size, but we stress that this happens gracefully. In the standard variant of FOT (bold line), the perplexity increases only by 0.18 when scaling to > 500k tokens. Importantly, the perplexity of FOT is close to that of the Memorizing Transformer with the single-doc memory, which we treat as a soft lower bound since it is not exposed to distractions from unrelated books.
The original motivation behind FOT is to improve the multi-doc setting performance by handling distractions. Interestingly, our method also helps to extrapolate to longer contexts, even when evaluated in the single-doc setting.
To study this, we perform FOT fine-tuning (as in Section 5.2) and evaluate the perplexity of the resulting model on the PG-19 dataset with different context lengths in a zero-shot fashion. To deepen the analysis, we introduce an additional parameter w (the number of previous contexts used in the crossbatch training procedure). We provide results for w = 1 (the standard setting for FOT, which corresponds to a total differentiable context of 2 · 1024) and w = 2 (corresponding to a total differentiable context of 3 · 1024).
We observe, see Figure 9, improvements when context grows, even far beyond the training context length, which reaffirms the hypothesis that FOT helps with extrapolation to longer contexts. Moreover, d = 2 is significantly better than d = 1. When comparing d = 1 and w = 2 to d = 2 and w = 1, we observe that the former is slightly better. This is natural, as the former has longer training context.
In Appendix C we present ablations on our design choices. In particular, we note the importance of differentiability and the inclusion of negatives. We also discuss the relation to Memorizing Transformer. We note that due to the limited resources we have followed the Memorizing Transformer in the choice of memory layers.
Our research opens a few avenues for future work. We list them as well as challenges and limitations.
Scaling up context This is by far the most important future research direction. The challenges start with pure engineering: storing more than 16M ($K$, $V$) pairs will require a distributed multi-node system. In our experiments, we use the exact kNN search, which is not scalable to large memory. Using approximate kNN search will require a lot of engineering effort, as well as careful evaluation of the impact of the approximation on the model performance.
Scaling up crossbatch We observed that increasing d is beneficial. In our experiments, we used d = 64 or d = 128, which is the maximum value that fits into the memory of a single TPUv3/TPUv2 machine, see also Appendix I. In future work, we want to further increase d as well as test on devices with bigger memory or utilize multi-node training. We also note that crossbatch increases the training cost, but only in a subset of layers.
Exploring contrastive learning The FOT training is inspired by rather basic contrastive learning (CL) techniques. We show that this improves the key structure so that the distraction issue is mitigated. We expect that other CL methods could be beneficial, for example, hard negative mining to utilize a larger memory during training (see [Lindgren et al., 2021]). We leave this for future work.
Combining with other methods Developing long-context methods is an active research field, see Section 2. We believe that some of these methods could be combined with FOT, resulting in mutually beneficial interactions.
Listing 1: Possible implementation of cross-batch. To simplify the code we assume that each document occupies two consecutive elements of the batch. A more detailed version is in Appendix L.