
Model | Generative Representational Instruction Tuning


MinWoo(Daniel) Park | Tech Blog



  • Related Project: Private
  • Category: Paper Review
  • Date: 2024-02-05

Generative Representational Instruction Tuning

  • url: https://arxiv.org/abs/2402.09906
  • pdf: https://arxiv.org/pdf/2402.09906
  • abstract: All text-based language problems can be reduced to either generation or embedding. Current models only perform well at one or the other. We introduce generative representational instruction tuning (GRIT) whereby a large language model is trained to handle both generative and embedding tasks by distinguishing between them through instructions. Compared to other open models, our resulting GritLM 7B sets a new state of the art on the Massive Text Embedding Benchmark (MTEB) and outperforms all models up to its size on a range of generative tasks. By scaling up further, GritLM 8x7B outperforms all open generative language models that we tried while still being among the best embedding models. Notably, we find that GRIT matches training on only generative or embedding data, thus we can unify both at no performance loss. Among other benefits, the unification via GRIT speeds up Retrieval-Augmented Generation (RAG) by > 60% for long documents, by no longer requiring separate retrieval and generation models. Models, code, etc. are freely available at this https URL.


TL;DR


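  • Unifies representational (embedding) and generative tasks in a single large language model
  • Grounds the method design in explicit mathematical formulations (loss functions)
  • Validates model performance on benchmarks and improves efficiency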

1. Introduction

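Large language models (LLMs) have recently drawn attention as a promising approach to a wide range of problems. The usual goal is a single model that performs well across many tasks. In particular, prior work has argued that text-based language problems can be reduced to generation and thus handled by a single LLM. However, tasks that require embeddings, such as clustering and retrieval, have largely been overlooked from this perspective. This work proposes GRIT (Generative Representational Instruction Tuning), a new method that combines generative instruction tuning and representational instruction tuning.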


2. Theoretical Background and Method

2.1 Mathematical Model and Loss Functions

GRIT is designed so that a single large language model can perform two main tasks: extracting representations of text and generating text. To this end, it uses two loss functions, a representational loss and a generative loss. This lets one model handle both tasks flexibly.

2.1.1 Representational Loss

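The representational loss trains the model to produce effective embeddings of text. When embedding, the model converts the semantic properties of a text into a numeric vector. Embeddings trained with this loss are used mainly for tasks such as document retrieval and similarity evaluation.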

\[\mathcal{L}_{\text{Rep}} = -\log \frac{\exp\left(\sigma(f_{\theta}(q), f_{\theta}(d^+)) / \tau\right)}{\sum_{d \in \{d^+\} \cup \{d^-\}} \exp\left(\sigma(f_{\theta}(q), f_{\theta}(d)) / \tau\right)}\]

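Here \(f_{\theta}\) is the model, parameterized by \(\theta\), that maps a text to its embedding. \(q\) is the query, \(d^+\) is a relevant (positive) document, and \(d^-\) denotes irrelevant (negative) documents. \(\sigma\) is the cosine similarity function, which measures the similarity between two vectors. \(\tau\) is a temperature parameter that controls how sharp the softmax distribution is, and therefore how sharply the model separates the positive from the negatives.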
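For a single query, the loss above is a softmax cross-entropy over temperature-scaled cosine similarities, with the positive document as the correct class. The sketch below assumes the embeddings \(f_{\theta}(q)\), \(f_{\theta}(d^+)\) and \(f_{\theta}(d^-)\) have already been computed as plain lists of floats. The function names, the temperature of 0.05 and the toy vectors are illustrative assumptions, not settings from the paper.

```python
import math


def cosine(u, v):
    """Cosine similarity sigma(u, v) between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def rep_loss(q_emb, pos_emb, neg_embs, tau=0.05):
    """-log of the softmax probability assigned to the positive document.

    tau=0.05 is an illustrative assumption, not a value from the paper.
    """
    # Temperature-scaled similarity logits; index 0 is the positive document.
    logits = [cosine(q_emb, pos_emb) / tau]
    logits += [cosine(q_emb, d) / tau for d in neg_embs]
    # Numerically stable log-sum-exp over the positive and all negatives.
    m = max(logits)
    log_denominator = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_denominator - logits[0]


query = [1.0, 0.0]
positive = [0.9, 0.1]
negatives = [[0.0, 1.0], [-1.0, 0.2]]
easy = rep_loss(query, positive, negatives)                     # positive is the closest vector
hard = rep_loss(query, negatives[0], [positive, negatives[1]])  # a negative is now closest
print(easy, hard)
```

A lower tau sharpens the softmax. The loss then penalizes more strongly any negative that scores close to the positive.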

2.1.2 Generative Loss

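The generative loss trains the model to generate an appropriate continuation of a given text. This capability is essential for natural language generation tasks such as chatbots, machine translation, and summarization.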

\[\mathcal{L}_{\text{Gen}} = -\sum_{i=1}^{|x|} \log P_{\theta, \eta}(x_i \mid x_{<i})\]

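Here \(x_i\) is the token to be predicted and \(x_{<i}\) is the sequence of preceding tokens. \(P_{\theta, \eta}\) is the probability of that token given the preceding context, computed with the model parameters \(\theta\) and the language modeling head \(\eta\). This loss measures how accurately the model predicts the next token from the given text.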
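The generative loss can likewise be computed from the outputs of the language modeling head. This minimal sketch assumes the head's logits for each position are already available as lists, using a toy three-token vocabulary. The function names are hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one position's vocabulary logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def gen_loss(step_logits, targets):
    """Sum over positions of -log P(x_i | x_<i).

    step_logits[i]: LM-head logits produced after reading the prefix x_<i.
    targets[i]: vocabulary index of the true next token x_i.
    """
    return -sum(log_softmax(logits)[t] for logits, t in zip(step_logits, targets))


# Toy vocabulary of 3 tokens and a 2-token target sequence.
loss = gen_loss([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]], [0, 2])
print(loss)
```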

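GRIT combines these two loss functions (as a weighted sum) so that a single model can perform both text embedding and text generation. This makes the model more versatile and efficient, letting it handle a variety of language processing tasks.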

2.2 Datasets and Benchmarks

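The main benchmark used in the study is the Massive Text Embedding Benchmark (MTEB), which evaluates the model across a wide range of embedding tasks. Generative performance is evaluated on a suite of generative tasks. This suite includes the HumanEvalSynthesize variant of HumanEval, which is better suited to instruction-following models. Together, these benchmarks are used to verify that GRITLM performs well at both embedding and generation.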


3. Experimental Results

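Applying GRIT yields strong performance on both embedding and generative tasks. For example, the 7B-parameter GRITLM sets a new state of the art among open models on MTEB. On generative tasks, it also outperforms a 70B-parameter model (Llama 2 70B). This suggests that GRIT can be applied effectively to large language models.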


4. Conclusion and Future Work

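This paper introduces a new method, GRIT, and shows that with it a large language model can perform embedding and generation in a unified way. The results point to a way of increasing the efficiency and versatility of future large language models. However, its applicability across other languages and domains remains to be explored.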


1 Introduction

Creating a single general model that performs well at a wide range of tasks has been a long-standing goal of the field of artificial intelligence [73, 67, 21, 130, 139]. Recently, large language models (LLMs) have emerged as a promising direction for a single multi-task model [125, 13]. Prior work has argued that all text-based language problems can be reduced to generation and thus handled by a single LLM [128, 38].

Figure 1: Performance of various models on text representation (embedding) and generation tasks. GRITLM is the first model to perform best-in-class at both types of tasks simultaneously.

Figure 2: GRIT. The same model handles both text representation and generation tasks based on the given instruction. For representation tasks, instructions ideally contain the target domain, intent, and unit [5]. The representation is a tensor of numbers, while the generative output is text.

However, tasks that use embeddings, such as clustering or retrieval [107], have largely been ignored from this perspective. Today, text embeddings power many critical real-world applications ranging from search engines to user-facing chatbots [63, 144]. While integrating text embeddings into the generative paradigm is possible by generating a sequence of numbers to form the embedding tensor, it becomes impractical due to the high dimensionality and precision requirements of embeddings. Thus, it is more common and much easier to use the hidden state of the model as the embedding representation, which is already a numeric tensor [104, 158, 102]. However, doing so for current generative models leads to poor performance. For example, while the T5 model [128, 134] can handle any generative task in a sequence-to-sequence fashion, it requires fine-tuning to make its hidden state useful for text embedding [111, 112] during which it loses its generative capabilities.

We introduce GRIT (generative representational instruction tuning) which unifies embedding and generative tasks, leading to a model that excels at both tasks as shown in Figure 1. Figure 2 depicts how GRIT combines two previously disjoint training paradigms: (1) Generative instruction tuning, whereby the model is trained to respond to instructions by generating an answer [164, 134]; and (2) Representational instruction tuning, whereby the model is trained to represent a provided input according to an instruction [143, 5]. Via the instructions and separate loss functions the model learns to differentiate the two streams. We test our approach on models with up to 47B parameters and, due to its simplicity, we expect the method to generalize to any LLM, even non-transformers. This unification via GRIT leads to three advantages:

  • Performance: Our unified model matches the performance of embedding-only and generative-only variants, even outperforming them on some tasks. At 7B parameters, GRITLM sets a new state of the art on the Massive Text Embedding Benchmark [107] among open models and at the same time outperforms much larger models on generative tasks, such as Llama 2 70B. By scaling further, GRITLM 8X7B is the best open generative language model on our task average, while only using 13B parameters at inference. Further, as our models use sliding window attention [20, 9] they can handle generative and embedding inputs of arbitrary length.
  • Efficiency: Generative and embedding models are commonly used together to make up for each other’s deficiencies [56, 84]. One such scenario is Retrieval-Augmented Generation (RAG) [84], where an embedding model is used to retrieve context that is provided to the generative model to answer a user query. This requires passing the user query and the context into both the generative and the embedding model for a total of four forward passes. With GRITLM, the embedding and generative model are equivalent, allowing us to cache computations and halve the necessary number of forward passes. We find that this can lead to >60% faster RAG at inference with long documents.
  • Simplicity: Currently, API providers such as OpenAI provide separate generative and embedding endpoints. This requires separate load balancing, additional storage, and more complex serving software. A single model that handles both use cases significantly simplifies infrastructure needs.

The main downside of GRIT is that it requires more compute due to training with two objective functions. However, as fine-tuning is cheap compared to pretraining, we think the benefits vastly outstrip this problem and thus recommend practitioners building instruction-following language models to adopt GRIT during fine-tuning.

2 GRIT

GRIT unifies representational instruction tuning [143, 5, 160] and generative instruction tuning [164, 134, 108] into a single model. We finetune a pretrained large language model [13] with embedding and generative instruction data in a consistent format as depicted in Figure 3. For embedding data, we follow prior work and compute the loss using a contrastive objective with in-batch negatives [18, 51]:

\[\mathcal{L}_{\text{Rep}} = -\log \frac{\exp\left(\sigma(f_{\theta}(q), f_{\theta}(d^+)) / \tau\right)}{\sum_{d \in \{d^+\} \cup \{d^-\}} \exp\left(\sigma(f_{\theta}(q), f_{\theta}(d)) / \tau\right)} \tag{1}\]

where \(f\) is GRITLM parametrized by the model \(\theta\), \(\tau\) is a temperature hyperparameter, and \(\sigma\) corresponds to pooling applied to each output followed by cosine similarity. \(q\) and \(d\) are query and document samples. As depicted in Figure 3, we use bidirectional attention followed by mean pooling, which corresponds to averaging the hidden states across the sequence length. During pooling, we only average the final hidden states of the input sample, ignoring the instruction and format tokens. However, the instruction and format tokens still influence the final representation through the self-attention mechanism [156].
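The pooling step described above can be sketched as a masked mean. The sketch assumes the final hidden states from the bidirectional forward pass are given as per-token lists. It also assumes a hypothetical mask that marks which positions belong to the input sample rather than to the instruction and format tokens. Those excluded tokens have already shaped the hidden states through self-attention, so the mask only changes which positions are averaged.

```python
def mean_pool(hidden_states, sample_mask):
    """Average final hidden states over positions where sample_mask is 1.

    hidden_states: list of per-token vectors (seq_len x dim).
    sample_mask: 1 for tokens of the input sample, 0 for instruction/format tokens.
    """
    dim = len(hidden_states[0])
    kept = [h for h, keep in zip(hidden_states, sample_mask) if keep]
    if not kept:
        raise ValueError("sample_mask selects no tokens")
    return [sum(h[j] for h in kept) / len(kept) for j in range(dim)]


# Two instruction/format tokens followed by three sample tokens.
hidden = [[9.0, 9.0], [9.0, 9.0], [1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
mask = [0, 0, 1, 1, 1]
pooled = mean_pool(hidden, mask)
print(pooled)  # → [3.0, 4.0]
```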

Figure 3: GRITLM architecture and format. Left: GRITLM uses bidirectional attention over the input for embedding tasks. Mean pooling is applied over the final hidden state to yield the final representation. Right: GRITLM uses causal attention over the input for generative tasks. A language modeling head on top of the hidden states predicts the next tokens. The format supports conversations with multiple turns (indicated with “…”).

To compute the loss on generative data, we use the language modeling objective whereby the model needs to predict the next token [124, 125]:

\[\mathcal{L}_{\text{Gen}} = -\sum_{i=1}^{|x|} \log P_{\theta, \eta}(x_i \mid x_{<i}) \tag{2}\]

where \(P_{\theta, \eta}\) is the next-token distribution of GRITLM, parametrized by the model \(\theta\) and the language modeling head \(\eta\), which is only used for generation. \(x\) are generative training samples. We only compute loss over predicted tokens, i.e. “{response}</s>” in Figure 3. A key consideration is whether the generative loss is aggregated at the sample or token level. Aggregating at the sample level corresponds to giving each sample the same weight within a batch regardless of its token count. Such aggregation is commonly used for instruction tuning, as it can boost performance on discriminative tasks [108]. However, Muennighoff et al. [108] also show how this in turn leads to a model biased toward short generations. Meanwhile, aggregation at the token level corresponds to giving each token the same weight, thus samples with many tokens become more important. This usually leads to a model producing longer generations, which can be important for performance on generative tasks. In particular, human- or machine-evaluated generative tasks, such as AlpacaEval [89], are known to be biased toward preferring longer generations [162]. Note that when every sample has the same sequence length such as during pretraining or when the batch size is 1, token and sample level generative loss are equal to each other. One can also mix the two to balance their trade-offs, for example doing token level loss across a subset of the batch and then giving each subset the same weight. We explore the trade-offs in our ablations in §3.3. We sum the objectives with optional loss weights \(\lambda_{\text{Rep}}\) and \(\lambda_{\text{Gen}}\):

\[\mathcal{L}_{\text{GRIT}} = \lambda_{\text{Rep}}\mathcal{L}_{\text{Rep}} + \lambda_{\text{Gen}}\mathcal{L}_{\text{Gen}} \tag{3}\]
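Per-token losses make the difference between sample-level, token-level and mixed aggregation concrete, together with the weighted sum of Equation (3). This sketch assumes the per-token generative losses for response tokens have already been computed. The function names and the default weights of 1.0 are illustrative assumptions.

```python
def token_level(per_sample_token_losses):
    """Every token gets equal weight: total loss divided by total token count."""
    total = sum(sum(s) for s in per_sample_token_losses)
    n_tokens = sum(len(s) for s in per_sample_token_losses)
    return total / n_tokens


def sample_level(per_sample_token_losses):
    """Every sample gets equal weight: mean of the per-sample mean token losses."""
    means = [sum(s) / len(s) for s in per_sample_token_losses]
    return sum(means) / len(means)


def mixed(per_sample_token_losses, samples_per_subbatch):
    """Token-level loss inside each sub-batch, then equal weight per sub-batch.

    With samples_per_subbatch=32 and a batch of 256 this mirrors "(32->8)".
    """
    subs = [per_sample_token_losses[i:i + samples_per_subbatch]
            for i in range(0, len(per_sample_token_losses), samples_per_subbatch)]
    return sum(token_level(sub) for sub in subs) / len(subs)


def grit_loss(l_rep, l_gen, lambda_rep=1.0, lambda_gen=1.0):
    """Equation (3): weighted sum of representational and generative losses."""
    return lambda_rep * l_rep + lambda_gen * l_gen


# A short sample (2 tokens, high loss) and a long sample (6 tokens, low loss).
batch = [[3.0, 3.0], [1.0] * 6]
print(token_level(batch))   # → 1.5 (12 / 8): the long sample dominates
print(sample_level(batch))  # → 2.0 ((3.0 + 1.0) / 2): both samples count equally
```

Equal-length samples make the two aggregations coincide, which matches the note above about pretraining or a batch size of 1.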

Notably, our formulation supports differing numbers of embedding samples (\(M\)) and generative samples/tokens (\(N\)). This allows for significantly increasing the embedding batch size while keeping the generative batch size fixed. A large embedding batch size is often key to well-performing text embedding models [169]. However, it comes at the cost of requiring more compute at each step.

3 Experiments

In this section, we first outline our experimental setup in §3.1. In §3.2, we discuss and benchmark the embedding and generative performance of our models. Finally, in §3.3, we ablate the settings that led to our final models, including training data, precision, pooling, sequence length, and loss weights.

3.1 Setup

We finetune our final models from Mistral 7B [68] and Mixtral 8x7B [69] using adaptations of E5 [160] and the Tülu 2 data [64]. For E5, we adapt it by adding S2ORC [91] to increase its scientific data (“E5S”), while for Tülu 2 we filter out their custom prompts that contain answers related to the origin of their model. For GRITLM 7B, we use a batch size of 2048 for embedding data and 256 for generative data and we train the model for a total of 1253 steps corresponding to one epoch on the generative data and 1.36 epochs on the embedding data. For GRITLM 8X7B, the embedding batch size is 256 due to compute limitations. We use several strategies to reduce the memory required during training including a novel technique to split the embedding triplet into separate forward and backward passes detailed in Appendix G. Other hyperparameters are detailed in the ablation experiments in §3.3 and Appendix H.
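The sample counts implied by the step count and batch sizes for GRITLM 7B can be checked directly. This assumes every step consumes one full batch of each data type. Because 1.36 is a rounded figure, the implied size of the embedding dataset is only approximate.

```python
# Values stated in the setup paragraph for GRITLM 7B.
STEPS = 1253
GEN_BATCH = 256
EMB_BATCH = 2048
EMB_EPOCHS = 1.36  # reported (rounded) passes over the embedding data

gen_samples_seen = STEPS * GEN_BATCH  # one epoch, so roughly the generative dataset size
emb_samples_seen = STEPS * EMB_BATCH
approx_emb_dataset_size = emb_samples_seen / EMB_EPOCHS  # approximate, as 1.36 is rounded

print(gen_samples_seen)                    # → 320768
print(emb_samples_seen)                    # → 2566144
print(round(approx_emb_dataset_size, -4))  # roughly 1.89 million
```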

For embedding performance we evaluate using the 56 main datasets from MTEB [107]. For generative performance, we largely follow the evaluation setup of Ivison et al. [64] except that we use the HumanEvalSynthesize [105] variant of HumanEval, as it is more adequate for instruction-following models. We explain each task in more detail in Appendix D.

3.2 Main Results

GRIT leads to a state-of-the-art embedding and generative model We benchmark GRITLM 7B, GRITLM 8X7B and generative- and embedding-only variants with other models in Table 1 and Table 2. We find that GRITLM 7B outperforms all prior open models on the Massive Text Embedding Benchmark [107] while still outperforming all generative models up to its size of 7 billion parameters.

Table 1: Embedding performance of GRITLM and others. We indicate parameter counts where available (B=billions). See Appendix D for task, metric, and dataset details. Appendix F contains per-dataset results of GRITLM models. LLMs not finetuned for embedding (Llama 2 70B, Mistral 7B (Instruct), GPT-J 6B, Gen.-only) are evaluated with weighted-mean pooling [104].

Results from the MTEB leaderboard (https://hf.co/spaces/mteb/leaderboard)

Table 2: Generative performance of GRITLM and others. We indicate parameter counts where available (B=billions). See Appendix D for dataset, setup, and metric details.

Results from Ivison et al. [64] except for numbers marked with ♦ which are from Touvron et al. [154] and † which are from us. For models that cannot be easily used as chat models, we set Alpaca to 0.

GRIT models are the only ones that can handle both embedding and generation at best-in-class performance (Figure 1). For example, using Llama 2 70B [154] for embedding leads to a score of only 35.6 on MTEB as depicted in Table 1. GRITLM almost doubles that performance on MTEB leading to state-of-the-art performance, while still outperforming Llama 2 70B on generative tasks by more than 20% (Table 2). Scaling even further, GRITLM 8X7B outperforms all openly available models on our generative average. We also train embedding-only and generative-only variants of GRITLM that only use representational or generative instruction tuning but are otherwise equivalent. Benchmarking the embedding-only variant or SGPT BE 5.8B [104] on generative tasks in Table 2 by simply re-adding the language modeling head that was dropped during embedding fine-tuning leads to around random performance (25.0 is the random baseline on MMLU). Similarly, benchmarking the embedding performance of the generative-only model only leads to a score of 41.2 in Table 1. Thus, joint optimization via the GRIT approach is critical to achieve strong performance for both embedding and generation. We note, however, that with 7 billion parameters GRITLM 7B is significantly more costly to run than many other embedding models in Table 1, such as BGE Large with only 335 million parameters [169]. In addition, GRITLM 7B produces representations of 4096 dimensions, which require 4× more storage than the 1024-dimensional embeddings of BGE Large.
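The storage gap can be made concrete with a quick size estimate for a dense index. This sketch assumes one uncompressed FP32 vector (4 bytes per value) per document and a hypothetical corpus of one million documents. Quantization or other compression would change the absolute numbers but not the 4× ratio.

```python
def index_bytes(num_docs, dim, bytes_per_value=4):
    """Raw size of a dense index storing one dim-sized vector per document."""
    return num_docs * dim * bytes_per_value


NUM_DOCS = 1_000_000  # hypothetical corpus size
gritlm_bytes = index_bytes(NUM_DOCS, 4096)  # 4096-dimensional embeddings
bge_bytes = index_bytes(NUM_DOCS, 1024)     # 1024-dimensional embeddings

print(gritlm_bytes / 2**30, bge_bytes / 2**30)  # sizes in GiB
print(gritlm_bytes // bge_bytes)                # → 4
```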

GRITLM matches embedding-only and generative-only variants We find that unifying the two objectives via GRITLM matches both the generative-only and the embedding-only variants. This is similar to observations made for visual models [176]. However, while GRITLM is trained for the same number of steps as the embedding-only and generative-only variants, it requires more compute per training step as it does a forward and backward pass on both embedding and generative data.

Table 3: Reranking (Rerank) using GRITLM as both Bi- and Cross-Encoder.

Datasets in Table 3: ArguAna, ClimateFEVER, CQADupstack, DBPedia, FiQA2018, FEVER, HotpotQA, NFCorpus, NQ, MSMARCO, QuoraRetrieval, SCIDOCS, SciFact, TRECCOVID, Touche2020.

Reranking with GRITLM For retrieval tasks, it is common to follow the embedding-based retrieval stage by a reranking stage [113]. In the reranking stage, for each query, the top-k chosen documents are reranked based on a usually more expensive but more performant method. For LLMs, prior work has shown that this can be done by passing each of the k documents together with the query to the model and scoring the pair with log probabilities [104]. Note that this scales quadratically with the number of documents and queries and is thus usually too expensive for the first stage (“Cross-Encoder”). Meanwhile, using embeddings for the first stage is much cheaper as it only requires passing each query and each document once and thus scales linearly (“Bi-Encoder”). More recent work relies on instructions to use LLMs for reranking [145, 96, 120, 121]. While prior work uses separate models for the embedding and reranking stages, GRITLM can be used for both stages due to its unified capabilities. In Table 3, we display the embedding performance of GRITLM 7B when additionally allowing it to rerank its top 10 documents for each query. For reranking, we use the model’s generative capabilities following the permutation generation approach from Sun et al. [145] and reusing their prompt. We find that reranking via the generative capabilities of GRITLM 7B allows it to improve on its own embedding performance on almost every retrieval dataset. Increasing the top-k documents beyond ten is likely to further improve results, however, at the cost of more compute [104].
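The two-stage setup can be sketched as follows. Stage 1 ranks every document by the cosine similarity of precomputed embeddings (Bi-Encoder). Stage 2 passes only the top-k document ids to a more expensive reranker. Here `rerank_fn` is a hypothetical callable standing in for GRITLM's generative permutation prompt, which is not reproduced. The pass-counting helper only illustrates linear versus pairwise scaling.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def retrieve_then_rerank(query_emb, doc_embs, rerank_fn, k=10):
    """Rank all docs by embedding similarity, then reorder only the top-k ids."""
    ranked = sorted(range(len(doc_embs)),
                    key=lambda i: cosine(query_emb, doc_embs[i]),
                    reverse=True)
    top_k, rest = ranked[:k], ranked[k:]
    return list(rerank_fn(top_k)) + rest


def forward_passes(num_queries, num_docs):
    """Bi-Encoder encodes each text once; a Cross-Encoder scores every query-doc pair."""
    return {"bi_encoder": num_queries + num_docs,
            "cross_encoder": num_queries * num_docs}


docs = [[1.0, 0.0], [0.7, 0.7], [0.0, 1.0], [-1.0, 0.0]]
# Toy reranker that simply reverses the top-k order.
order = retrieve_then_rerank([1.0, 0.0], docs, rerank_fn=lambda ids: ids[::-1], k=2)
print(order)  # → [1, 0, 2, 3]
```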

Few-shot embedding does not work For generative models it has been well-established that providing in-context examples (“few-shots”, FS) improves performance [13]. However, to the best of our knowledge, there has been no work on in-context learning with embedding models. In Table 4, we benchmark the default 0-shot format versus providing a single few-shot example following the task instruction. We take the few-shot example from the respective evaluation dataset (see §O.2 for the prompts). We find that providing few-shot examples overall worsens performance. While there are small gains among PairClassification tasks (SprintDup. and TwitterURL), these are marginal and inconsistent. For the model trained on MEDI2, we even include few-shot embedding samples in the training data for around 5% of training samples. However, the model seems not to have learned to make good use of the few-shot examples.

3.3 Ablations

Attention and pooling We train GRITLM starting from a pretrained decoder language model which has been trained with causal attention. Prior work has shown that while embeddings of causal LLMs are competitive, they are outperformed by BERT-like encoders with bidirectional attention at the same number of parameters [104, 34]. This lines up with intuition, as bidirectional attention allows the model to adjust the representation of the first tokens based on information obtained from future tokens. Meanwhile, causal attention only allows information to propagate one way. Thus, for causal attention early tokens may yield poor representations due to a lack of understanding of the entire sample. To counter this issue, we experiment with adapting the model during fine-tuning to learn to use bidirectional attention. In Table 5 we find that adapting the causally pretrained LLM with bidirectional attention provides the best embedding performance. For fully causal embeddings, we confirm findings from Muennighoff [104] that position-weighted mean pooling (“Wmean”) leads to better embedding performance than taking the embedding of the last token despite recent work finding the opposite [179, 95]. For last token pooling, we follow Zhang et al. [179] and use a special token. We find that adapting the model to be a PrefixLM [128], whereby the attention over the generative instruction is bidirectional but still causal for the response (“Sample”) worsens performance in contrast to prior work [161]. Thus, we stick with fully causal generation. The unified variant significantly outperforms the embedding-only variants, while underperforming the best generative-only variant. However, once we switched from MEDI to the E5 dataset in later ablations the embedding-only variant matched the unified variant. Meanwhile, the worse generative performance of the unified model was due to a suboptimal loss setting that we fixed in the loss ablations.
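The attention variants in this ablation differ only in their attention masks. "Wmean" differs from plain mean pooling only in its weights. The sketch below builds 0/1 masks (1 = may attend) for causal, bidirectional and PrefixLM attention. It implements position-weighted mean pooling with linearly increasing weights \(w_i = i / \sum_{j=1}^{n} j\), assumed here to match the weighting of Muennighoff [104]. Treat that formula and all function names as assumptions.

```python
def causal_mask(n):
    """Position i may attend only to positions j <= i."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]


def bidirectional_mask(n):
    """Every position may attend to every position."""
    return [[1] * n for _ in range(n)]


def prefix_lm_mask(n, prefix_len):
    """Bidirectional among the first prefix_len positions, causal afterwards."""
    return [[1 if (j <= i or j < prefix_len) else 0 for j in range(n)] for i in range(n)]


def weighted_mean_pool(hidden_states):
    """Position-weighted mean: the i-th token (1-based) gets weight i / (1 + 2 + ... + n)."""
    n, dim = len(hidden_states), len(hidden_states[0])
    total = n * (n + 1) / 2
    return [sum((i + 1) / total * h[d] for i, h in enumerate(hidden_states))
            for d in range(dim)]


print(prefix_lm_mask(3, prefix_len=2))     # → [[1, 1, 0], [1, 1, 0], [1, 1, 1]]
print(weighted_mean_pool([[1.0], [4.0]]))  # later tokens weigh more: about [3.0]
```

Under causal attention, later tokens have seen more of the sample. Giving them larger weights is the intuition behind Wmean.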

Table 4: Few-shot embedding. The 12 MTEB datasets (“DS”) are grouped by the 7 main MTEB tasks in the same order as in Table 1.

Base model The GRITLM approach generalizes to any generative language model, thus we ablate initializing from GPT-J 6B [157], Llama 2 7B or Mistral 7B [68]. Using Mistral 7B leads to the best performance for both embedding and generative tasks. For generative tasks, this is expected as the pretrained Mistral 7B performs the best among the three (Table 2). However, for embedding tasks, GPT-J outperforms Mistral 7B (Table 1). Thus, the embedding performance of a pretrained model is not predictive of its embedding performance after fine-tuning. Rather, its generative performance appears to be a more reliable indicator of its embedding performance after fine-tuning.

Generative dataset We benchmark our filtered Tülu 2 introduced in §3.1 [64] with UltraChat [36, 155] and the OpenAssistant version from OctoPack [105, 82, 92]. Using Tülu 2 leads to better performance on every generative task considered (see Appendix E for per-task results). This is likely due to Tülu 2 containing a larger diversity of tasks [64]. Another possible reason is that Tülu 2 may have been carefully tuned on the generative evaluation datasets, as we use largely the same evaluation setup as the creators of Tülu 2 [64].

(k) Loss ablations. \(\mathcal{L}_{\text{Rep}}/\mathcal{L}_{\text{Gen}}\) is the loss ratio at the first step, adjusted via \(\lambda_{\text{Rep}}\) and \(\lambda_{\text{Gen}}\). Mix refers to mixing sample and token level loss, e.g. (32->8) is token level loss across 32 samples and then sample level loss across 8 sub-batches for a total batch size of 256.

Table 5: GRIT ablations. Emb corresponds to the MTEB average, while Gen corresponds to the average across generative tasks (Appendix D). The embedding head variant “-> 1024” corresponds to down-projecting the final hidden state with a linear layer from 4096 to 1024 dimensions, only for embedding tasks. BF16∗ means that some computations are still in FP32 as explained in §3.3. The setting chosen for GRITLM is bold. Once an ablation was successful, we adopted its setting, thus the bold performance slightly varies from one table to the next. For example, the base model ablation (b) is done for just 100 hundred steps with sub-optimal formatting. Full results are in Appendix E.

Embedding dataset We benchmark MEDI [143], a new version of MEDI with better negatives which we build and call MEDI2, and the E5 dataset [160]. While MEDI and MEDI2 always preface instructions with “Represent” (see e.g. Figure 10), the E5 dataset places no constraint on the instruction prefix (see e.g. Figure 11). Thus, when using the E5 dataset the <|embed|> formatting is critical to tell the model that it will be subject to the representation loss, not the generative loss (Figure 3). Further, MEDI and MEDI2 always contain instructions for both queries and documents, which we refer to as two-sided instructions. Meanwhile, the E5 dataset uses one-sided instructions for asymmetric datasets [104], whereby the documents receive no instructions, only the queries. The advantage of not using document instructions is that the document corpus can be encoded once and then cached and reused across a variety of tasks. During training on E5, symmetric tasks are also in a one-sided setting, but we still evaluate them in the two-sided format. This should not be a problem as the cosine similarity function we use during training is transitive: if sentence A with instruction is similar to sentence B without instruction, and sentence B without instruction is similar to sentence C with instruction, then we can confidently say that sentence A with instruction is also similar to sentence C with instruction. As depicted in Table 5, using the E5 dataset performs best by a wide margin. An inspection of samples suggests that this is likely due to its superior hard negatives and diversity of tasks generated by GPT-4 (Appendix N). For our final runs with the E5 dataset, we additionally add scientific data (§3.1).
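One-sided instructions make the document side task-independent, so a document embedding can be computed once and reused. The sketch assumes a prompt template modeled on the Figure 3 layout, with "<|user|>" for the instruction and "<|embed|>" before the text. That exact template, the function names and the in-memory cache are assumptions, and `fake_encode` is a stand-in for a real model call.

```python
def embed_input(text, instruction=None):
    """Format an embedding input; the template is an assumption modeled on Figure 3."""
    if instruction:
        return f"<|user|>\n{instruction}\n<|embed|>\n{text}"
    return f"<|embed|>\n{text}"


_doc_cache = {}


def encode_document(doc_id, text, encode_fn):
    """Documents get no instruction, so one cached embedding serves every task."""
    if doc_id not in _doc_cache:
        _doc_cache[doc_id] = encode_fn(embed_input(text))
    return _doc_cache[doc_id]


def encode_query(text, instruction, encode_fn):
    """Queries always carry their task-specific instruction."""
    return encode_fn(embed_input(text, instruction))


# Stand-in encoder that records its inputs (a real one would run the model).
calls = []


def fake_encode(formatted):
    calls.append(formatted)
    return len(formatted)


doc = "Mistral 7B uses sliding window attention."
encode_document("doc-1", doc, fake_encode)
encode_query("Which attention does Mistral use?", "Retrieve passages that answer the question", fake_encode)
encode_query("Mistral attention", "Retrieve passages about the topic", fake_encode)
encode_document("doc-1", doc, fake_encode)  # served from the cache
print(len(calls))  # → 3 (the document was encoded only once)
```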

Embedding head The cost of caching the embeddings of a large document corpus is directly proportional to the embedding dimensionality. To minimize such costs, we experiment with adding an embedding head consisting of a linear layer with activation that down-projects the embedding [111, 104]. This layer is only used for embedding tasks. Down-projecting the embeddings four-fold (from 4096 to 1024) leads to an embedding performance decrease of around 1%. This may be acceptable for certain use cases where the saved storage is more important. However, for our final model, we do not use such a head to keep it simple and achieve maximum performance. Search techniques [3, 72, 37] or dimensionality reduction techniques such as Principal Component Analysis still allow for reducing the embedding dimension of our final model post-training while maintaining most of the performance.
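The embedding head described here is a linear down-projection followed by an activation, used only on the embedding path. In this hypothetical sketch, random weights stand in for trained ones. The passage does not name the activation, so tanh is an assumption, and a toy 16 -> 4 projection stands in for 4096 -> 1024.

```python
import math
import random


def make_embedding_head(in_dim, out_dim, seed=0):
    """Hypothetical down-projection head: random weights stand in for trained ones."""
    rng = random.Random(seed)
    scale = 1.0 / math.sqrt(in_dim)
    weights = [[rng.gauss(0.0, scale) for _ in range(in_dim)] for _ in range(out_dim)]
    bias = [0.0] * out_dim

    def head(hidden):
        # Linear projection followed by tanh (the activation choice is an assumption).
        return [math.tanh(sum(w * x for w, x in zip(row, hidden)) + b)
                for row, b in zip(weights, bias)]

    return head


# Toy four-fold reduction standing in for 4096 -> 1024.
head = make_embedding_head(in_dim=16, out_dim=4)
z = head([0.1] * 16)
print(len(z))  # → 4
```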

Batch size Due to the utilization of in-batch negatives for contrastive training (§2), a larger batch size provides a more accurate gradient. Thus, scaling up the batch size is a key ingredient in most well-performing embedding models [169, 159]. We experiment with scaling up the embedding batch size to 4096 while keeping it at 256 for generative data. This leads to a 1.0 gain on the embedding average while generative performance remains stable. Especially the 15 retrieval datasets that are part of the embedding average benefit from the increase in batch size (see Table 17). For our final model, we use a batch size of 2048 for embedding and 256 for generative data.

Precision The parameters of the Mistral 7B model are in bfloat16 (BF16) precision as it was pretrained in this format. We experiment with fine-tuning it with float32 (FP32) precision versus keeping the BF16 format and training with mixed precision. FP32 training is more costly, however, the additional precision may result in a better model. Our intuition is that more precision is important for embedding but not as much for generation. This is because for generative tasks evaluated greedily, the model output is a discrete argmax over the predictions of the language modeling head, whereas for embedding tasks it is a continuous representation. Thus, small differences due to a lack of precision may not change the model’s generation but will affect its representation. Hence, for embedding tasks, we always cast the hidden states to FP32 during the pooling operation and keep them this way for the similarity computation. Not keeping them in FP32 after pooling worsens performance slightly, but may be necessary for cheap storage (see Appendix K). In addition, some operations such as layer normalization [7] are also performed in FP32 even for BF16 training due to PyTorch autocast [182]. In Table 5, we find that there is no benefit from doing even more computations in FP32 besides the ones listed above. Thus, we train and evaluate all our other models in BF16 mixed precision to speed up training and inference.
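The intuition that reduced precision leaves a greedy argmax intact but perturbs a continuous similarity can be illustrated without a GPU. This sketch emulates BF16 with the standard `struct` module by keeping the top 16 bits of each FP32 value. That is truncation rather than the round-to-nearest-even used by real conversions. The logits and vectors are made-up toy values, and the reference cosine is computed in Python's double precision.

```python
import math
import struct


def to_bf16(x):
    """Emulate BF16 by keeping the top 16 bits of the FP32 bit pattern.

    This truncates (rounds toward zero); real conversions usually round to nearest even.
    """
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def argmax(xs):
    return max(range(len(xs)), key=lambda i: xs[i])


# Greedy decoding: a discrete choice that survives the precision loss.
logits = [2.31, 2.29, -0.7]
bf16_logits = [to_bf16(x) for x in logits]

# Embedding similarity: a continuous value that shifts slightly.
u = [0.123456, -0.654321, 0.333333]
v = [0.120001, -0.650002, 0.340004]
cos_full = cosine(u, v)
cos_bf16 = cosine([to_bf16(x) for x in u], [to_bf16(x) for x in v])

print(argmax(logits), argmax(bf16_logits))  # → 0 0
print(cos_full - cos_bf16)                  # small but nonzero
```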

In-batch negatives We always use in-batch negatives for embedding training (§2), however, we ablate whether or not they come from the same dataset. We hypothesize that making them all come from the same dataset leads to better negatives as the model needs to distinguish them based on more nuanced differences. In practice, we find that the average embedding performance remains around the same. However, we notice a 1.3 jump on the 15-dataset Retrieval average (Table 19). Thus, we stick with the variant where in-batch negatives stem from the same dataset.
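Drawing in-batch negatives from the same dataset amounts to building each batch from a single dataset. This minimal sketch shuffles the datasets independently and interleaves the resulting batches across steps. Dropping leftover examples that do not fill a batch is a simplifying assumption, and the function name is hypothetical.

```python
import random


def same_dataset_batches(examples, batch_size, seed=0):
    """Return batches whose members, and hence in-batch negatives, share one dataset.

    examples: list of (dataset_name, example) pairs.
    Leftover examples that do not fill a complete batch are dropped (a simplifying choice).
    """
    rng = random.Random(seed)
    by_dataset = {}
    for name, example in examples:
        by_dataset.setdefault(name, []).append(example)

    batches = []
    for name, items in by_dataset.items():
        rng.shuffle(items)
        for start in range(0, len(items) - batch_size + 1, batch_size):
            batches.append((name, items[start:start + batch_size]))

    # Shuffle the order of batches so training steps interleave datasets.
    rng.shuffle(batches)
    return batches


examples = [("A", i) for i in range(5)] + [("B", i) for i in range(4)]
batches = same_dataset_batches(examples, batch_size=2)
print(batches)
```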

Format Our chosen format is depicted in Figure 3, which is equivalent to Tülu 2 [64] for generative tasks. We also benchmark the Zephyr β format [155], which has an additional end-of-sequence token (“</s>”) after each user utterance. We find that it performs worse on generative tasks. The additional end-of-sequence after the user utterance increases the likelihood of the model generating another end-of-sequence token earlier than necessary. This significantly harms HumanEvalSynthesize performance and slightly reduces AlpacaEval, where long generations can be critical (see Appendix E for task-specific performance).
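The difference between the two formats comes down to one extra end-of-sequence token per user turn. The templates below are assumptions reconstructed from this paragraph and Figure 3, not copied from either project, so check the official chat templates before relying on them. In both formats, only "{response}</s>" would contribute to the generative loss.

```python
def tulu_turn(user, response):
    """Tülu-2-style turn: no end-of-sequence token after the user utterance."""
    return f"<|user|>\n{user}\n<|assistant|>\n{response}</s>"


def zephyr_turn(user, response):
    """Zephyr-beta-style turn: an extra </s> closes the user utterance."""
    return f"<|user|>\n{user}</s>\n<|assistant|>\n{response}</s>"


print(tulu_turn("Write a haiku.", "Autumn wind.").count("</s>"))    # → 1
print(zephyr_turn("Write a haiku.", "Autumn wind.").count("</s>"))  # → 2
```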

Max tokens Our base model, Mistral 7B, can handle sequences of arbitrary length due to its sliding window attention [68]. As fine-tuning with longer sequences is more expensive we ablate its benefits. We compare training with a maximum token limit of 512 versus 2048 for embedding documents. For embedding queries, we always use 256, and for generative data, we always use 2048. We find that increasing the embedding document sequence length during training slightly boosts performance on both embedding and generation even though we still evaluate embedding tasks with 512. This boost likely comes from our training data containing many documents beyond 512 tokens, which need to be truncated if the maximum sequence length is 512. Such truncation may remove the critical parts that make two texts a positive or a negative contrastive pair and thus hinder learning. As our embedding evaluation (MTEB) contains few documents longer than 512 tokens there is little truncation happening at evaluation [107, 58, 57]. Note that just like their base models, our final models GRITLM 7B and GRITLM 8X7B can produce embeddings for sequences of arbitrary length. However, due to a lack of benchmarks, we do not know how well the embeddings of our models perform for input sequences longer than 512 tokens.

Loss ablations As detailed in §2, we experiment with both token and sample level generative loss. Further, we ablate the representation and generative loss weights, \(\lambda_{\text{Rep}}\) and \(\lambda_{\text{Gen}}\). For the unified visual model CoCa, the authors find that giving a weight of 2 to generation and 1 to embedding boosts performance on both streams [176]. However, rather than the weights, we argue that the loss ratio, \(\mathcal{L}_{\text{Rep}}/\mathcal{L}_{\text{Gen}}\), is of more interest as it reveals which objective has a larger impact on the optimization of the model. We maintain a ratio of \(\mathcal{L}_{\text{Rep}}/\mathcal{L}_{\text{Gen}} > 1\), i.e. giving more weight to the representation loss. This is because the model has already been pretrained with the generative loss, thus we expect less additional generative training to be necessary. Meanwhile, the contrastive loss for embedding data is new to the model, thus we expect more learning to be needed on the embedding side. Further, the embedding loss drops off extremely quickly as can be seen in the loss graphs in Appendix C. Thus, even though the representation loss has a higher weight at the start, throughout training they have very similar weights with both hovering around a loss of 1.0. We find that mixing sample and token level generative loss leads to the best performance by a small margin. As expected in §2, token level loss to some degree is critical for good performance on AlpacaEval. For “Mix (4 -> 64)” token level loss is applied across only 4 samples and then sample level loss across 64 sub-batches, which leads to a 7-point drop in AlpacaEval performance. This drop is accompanied by a decrease in median AlpacaEval generation length from 941 to 865. Thus, token level loss across many samples is critical to maintaining long generations, which directly impacts the AlpacaEval score.

4 RAG with GRIT

Method By unifying embedding and generation, GRITLM simplifies Retrieval-Augmented Generation (RAG). Figure 4 displays how forward passes can be reduced by caching. Specifically, we break down the caching alternatives into: (a) Query Caching: In traditional RAG, the query needs to be passed both through the embedding model and later through the generative model. In Query Caching, we cache the key-value states from the embedding forward pass and reuse them for the generative pass, exploiting the property that both are the same model: GRITLM. Thus, we save compute equivalent to one forward pass of the query. Equivalently, we can also perform the generative forward pass over the query first and use its representation to retrieve the document on the fly (depicted in Figure 4). Note that Query Caching can be completely equivalent to RAG if the query is placed at the beginning of the prompt such that it only attends to itself through causal attention. (b) Doc Caching: Here we cache the documents, D. When the index is created, we also save the key-value states of every document and add them to the index. Thus, the index consists of the document embeddings and key-value states. Note that the computational cost of creating the index remains the same, as the key-value states have to be computed even if only embeddings are desired. At inference, we still retrieve based on embedding similarity, but the index returns the key-value states instead of the text passage. These key-value states are then provided to the model to avoid having to recompute them. This effectively saves a forward pass for every in-context document at inference. However, this method increases the necessary storage: while the text passages no longer need to be stored, the key-value states now do, and they usually require more storage depending on the model. We note that Document Caching also works for models other than GRITLM. However, for such models, one needs to pass all documents through the generation model ahead of time, thus increasing the cost of creating the index. To maintain equivalence with RAG, the document should be at the beginning of the prompt for Document Caching (the opposite of Query Caching). (c) Query-Doc Caching / Doc-Query Caching: We can also combine Query Caching and Doc Caching to save even more inference costs. However, combining them inevitably leads to discrepancies compared to RAG, as in traditional RAG either the query or the document is conditioned on the other one. Meanwhile, if both are cached, they are not conditioned on one another via the self-attention mechanism. We refer to Query-Doc Caching if the query is followed by the document in the prompt and to Doc-Query Caching if the document comes first.
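Here is a toy sketch of the Doc Caching index. Each entry stores an embedding next to a placeholder for the document's key-value states. Retrieval still ranks by cosine similarity, but it returns the cached states instead of text. `IndexEntry` and `retrieve_kv` are hypothetical names, and the strings stand in for real per-layer KV tensors.

```python
import math
from dataclasses import dataclass
from typing import Any, List

@dataclass
class IndexEntry:
    doc_id: int
    embedding: List[float]
    kv_states: Any  # placeholder for the cached per-layer key/value tensors

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def retrieve_kv(index, query_embedding):
    """Doc Caching: rank by embedding similarity, but return cached KV states, not text."""
    best = max(index, key=lambda e: cosine(e.embedding, query_embedding))
    return best.doc_id, best.kv_states

index = [IndexEntry(0, [1.0, 0.0], "kv(doc0)"), IndexEntry(1, [0.0, 1.0], "kv(doc1)")]
print(retrieve_kv(index, [0.1, 0.9]))  # (1, 'kv(doc1)')
```

In a real system, the returned states would be handed to the generative forward pass as past key-values, so the document is not re-encoded. That step depends on the inference framework and is not shown.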

Figure 4: RAG with GRIT. Left: Traditional Retrieval-Augmented Generation (RAG) relies on a separate embedding model and generative model. Right: GRITLM simplifies RAG as it handles both embedding and generation. Query Caching removes the duplicate forward pass of the query by reusing its representation. Query-Doc Caching also removes the forward pass on the document during inference, as the cached index also stores the document key-value states.

Setup We benchmark the different caching variants using data from Natural Questions [81]. Our implementation is adapted from Izacard et al. [66], however, we use a significantly smaller index of only 2,681,468 documents stemming from the BEIR NQ corpus [152]. We score models using the match score, whereby we check if any of the correct answers are anywhere in the generation. Prior work usually uses exact match, whereby they check if the generation exactly matches the answer. However, as our model is more chatty, it tends to answer in a few sentences and thus exact match often fails to give credit to correct answers. Inspecting the first 20 samples of the “No RAG” baseline, we find that exact match leads to 4 false negatives that are correctly credited by the match metric. We did not find any false positives from choosing the match metric in those samples. We do not use any instructions for embedding, solely the format as presented in Figure 3.
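The difference between exact match and the more lenient match score can be sketched as follows. The normalization (lowercasing, stripping punctuation and English articles) is a common SQuAD-style convention and only an assumption; the text does not say which normalization was applied.

```python
import re
import string

def normalize(text):
    """Lowercase, drop punctuation and articles, collapse whitespace (assumed normalization)."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def exact_match(generation, answers):
    return any(normalize(generation) == normalize(a) for a in answers)

def match(generation, answers):
    """Credit the generation if any gold answer appears anywhere in it."""
    return any(normalize(a) in normalize(generation) for a in answers)

gen = "The Eiffel Tower is located in Paris, France."
print(exact_match(gen, ["Paris"]), match(gen, ["Paris"]))  # False True
```

A chatty but correct answer fails exact match and passes match. The opposite risk is a false positive from plain substring search: `match("A comparison of cities", ["Paris"])` also returns True, because "comparison" contains "paris". A word-boundary check would avoid that.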

Table 6: RAG benchmarking on Natural Questions with GRITLM 7B. For RAG, the retrieved context is simply placed in the context of the language model, in contrast to our caching alternatives (Figure 4). CPU and GPU latencies are measured on an “Intel(R) Xeon(R) Platinum 8481C CPU @ 2.70GHz” and one “NVIDIA H100 80GB HBM3”, respectively. Sample A has a query of 1 token and a document of 4000 tokens, and sample B is the inverse. For each approach, we generate 16 tokens. Storage consists of the index and passages, except for Doc Caching variants, where it is the index and key-value states. The index is stored in float32, while key-value states are stored in bfloat16.

Figure 5: Inference latency of RAG with GRITLM 7B. When benchmarking scaling query length (left), document length is fixed at 1, whereas query length is fixed at 1 when scaling document length (right). In addition to the query/doc lengths, the formatting and prompt take up around 40 tokens. We visualize the standard deviation across 100 runs as the shaded area. For each approach, we generate 16 tokens.

Performance As depicted in Table 6, RAG performs better than the “No RAG” baseline where the model is not provided any context. This validates that despite its small size compared to prior work [90], our index is still valuable. While Query and Doc Caching can theoretically lead to the exact same performance as RAG, we experience differences stemming from two reasons:

  • 1) Attention: Our model is trained to embed with bidirectional attention (§2) and thus we use bidirectional attention when embedding query or document. Meanwhile, the generative model expects causal key-value states. In the Query-Doc/Doc-Query setup, there is an additional mismatch in either the documents or the queries not having attended to the other one, as both need to be embedded and cached separately.
  • 2) Formatting: The query is formatted in the embedding format as depicted in Figure 3, which the model has not seen in conjunction with a generative task during training. This could further lead to a performance drop.

Due to 1) and 2), Query Caching leads to a performance drop compared to traditional RAG. However, the Query Caching performance of 25.46 is still better than not using RAG, so it comes down to a speed-performance trade-off. Formatting the RAG baseline using the embedding format (Figure 3) reduces its score from 30.50 to 29.36 (not depicted). Thus the remaining gap of about four points to Query Caching, i.e. the majority of the damage, is due to the attention issue. Meanwhile, Doc Caching slightly improves performance, resulting in the best match score among all methods considered. This is possibly because, unlike the query, the document does not need to be as thoroughly understood, and skimming it may suffice. Thus, the slightly corrupted key-value states do not result in a performance drop. Query-Doc and Doc-Query Caching only perform near the “No RAG” baseline in our experiments, which may limit their usefulness in practice. This is likely caused by the additional attention mismatch that they introduce. This issue, as well as the formatting issue, could likely be solved by an additional RAG fine-tuning stage on top of GRITLM, which we leave to future work.
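The split of the Query Caching drop into a formatting part and an attention part, using only the match scores quoted above:

```latex
\begin{aligned}
\text{formatting effect} &= 30.50 - 29.36 = 1.14 \\
\text{remaining gap (attention)} &= 29.36 - 25.46 = 3.90 \\
\text{total Query Caching drop} &= 30.50 - 25.46 = 5.04 = 1.14 + 3.90
\end{aligned}
```

Roughly three quarters of the drop (3.90 of 5.04 points) remains after the formatting effect is accounted for. This is why the attention mismatch is named as the main cause.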

Latency In Figure 5, we show how caching leads to significant speed-ups over RAG on both CPUs and GPUs for long sequences. If only 250 tokens are cached, however, we find the speed-up to be negligible. In Table 6, we show that for 4000 tokens, Query Caching is 54% and 33% faster on CPUs and GPUs, respectively (Sample B). For Doc Caching, it is 63% and 31% (Sample A). Beyond 4000 tokens, the speed-ups will be even larger. However, for the opposite samples in Table 6, speed remains around the same. This is because Doc Caching caches 4000 tokens for Sample A but only 1 token for Sample B, which does not provide any speed-up (and vice versa for Query Caching). Thus, Doc Caching should be used when documents are expected to be very long, while Query Caching should be used when queries are expected to be very long. In a production setting, a simple input length check could switch from one caching mode to the other. As Table 6 shows, caching can match or even beat the latency of not using retrieval at all (“No RAG”). This could be due to the embedding forward pass not using the language modeling head. For Query Caching, the language modeling head is only used for the tokens that are generated, while for “RAG” and “No RAG” it is used for the entire input. The matrix multiplication with the language modeling head is computationally expensive due to its high dimensionality, which could explain the slower speed of the no-retrieval baseline. Query-Doc Caching and Doc-Query Caching cache both documents and queries and thus lead to major speed-ups for both Sample A and Sample B in Table 6. Overall, speed-ups are larger on CPUs: GPUs can process the entire sequence in parallel, so the advantage of caching parts of it is smaller. We also note that our RAG baseline uses our 7B parameter model for both the embedding and generative model, but without caching. In practice, it is common to have an embedding model that is much smaller and cheaper than the generative model. Nonetheless, as caching with GRITLM-7B approaches the No RAG latency in Table 6, we still expect it to be faster than setups with smaller embedding models for long sequences. In addition, it would lead to significantly better performance in that case due to the state-of-the-art retrieval performance of GRITLM.
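The "simple input length check" mentioned above could look like the hypothetical heuristic below. It caches whichever side is longer, and it falls back to plain RAG when neither side exceeds a threshold. The 250-token default only echoes the observation that caching about 250 tokens gave negligible speed-ups; it is not a tuned value.

```python
def choose_caching_mode(query_tokens, doc_tokens, min_cached_tokens=250):
    """Hypothetical switch: cache the longer input, if it is long enough to pay off."""
    longest = max(query_tokens, doc_tokens)
    if longest <= min_cached_tokens:
        return "rag"  # caching this few tokens gave negligible speed-ups in the text above
    return "doc_caching" if doc_tokens >= query_tokens else "query_caching"

print(choose_caching_mode(1, 4000))  # doc_caching   (Sample A in Table 6)
print(choose_caching_mode(4000, 1))  # query_caching (Sample B in Table 6)
```

Latency is not the only factor. Query Caching lowered the match score in the experiments above, while Doc Caching did not. Doc Caching also requires the key-value states to have been stored at indexing time. A real switch would need to weigh both.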

Storage In most RAG setups, the embeddings of all documents are computed ahead of time and stored for later use at inference. This is referred to as the index. In traditional RAG, the documents themselves still need to be stored, as the index is only used to find the document ID, which is then used to fetch the document text and pass it to the generative model. For Doc Caching variants, documents no longer need to be stored; however, the key-value states need to be stored together with the index. The key-value states take up a lot of storage, as for each layer they consist of two tensors (keys and values) of shape (batch size, number of heads, sequence length, dimension per head). For our 2,681,468 documents and the 7-billion-parameter GRITLM model, this leads to around 30TB of key-value states. However, unlike the index, the key-value states can be fully offloaded to disk and do not need to be kept in memory. Once the document ID has been determined via the index, the corresponding key-value states can simply be loaded from disk. For a single sample, this corresponds to loading around 12.5MB of key-value states into memory.
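A back-of-the-envelope estimate of key-value storage per document helps put these numbers in context. The defaults are assumptions: the public Mistral 7B configuration (32 layers, 8 key-value heads under grouped-query attention, head dimension 128), which a fine-tune like GRITLM 7B is presumed to keep, and bfloat16 at 2 bytes per value. The 100-token passage length is hypothetical, since the average document length in the index is not stated here.

```python
def kv_cache_bytes(num_tokens, num_layers=32, num_kv_heads=8, head_dim=128, bytes_per_value=2):
    """Keys + values for every layer: 2 * layers * kv_heads * head_dim values per token."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value * num_tokens

print(kv_cache_bytes(1))             # 131072 bytes, i.e. 128 KiB per token
print(kv_cache_bytes(100) / 2**20)   # 12.5 MiB for a hypothetical 100-token passage
```

Under these assumptions, a passage of about 100 tokens lands at the ~12.5MB per sample quoted above. Storage grows linearly with passage length and with the number of documents, which is how the total reaches the terabyte range. With grouped-query attention, "number of heads" in the cached tensor shape refers to the key-value heads.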

5 Discussion

Further unification To the best of our knowledge, GRITLM is the first model to unify text embedding and generation, and thus all text-based language problems, into a single model at strong performance. However, many adjacent directions remain to be improved or unified. (a) Multilinguality: Our model is also capable of embedding and generation in non-English languages, as seen in its TyDi QA performance (Table 2). However, major performance gains on non-English tasks are likely possible through both data [108, 174] and architecture changes [15, 47, 41] targeting multilinguality. (b) Multimodality: Many embedding and generative problems are not purely text-based, such as joint embedding of images and text [123], generative image captioning [62], image-text pair classification [103, 78] or speech versions of every text problem [74]. It remains to be explored whether they can be as easily unified as text embedding and generation in this work.

Why does GRIT work? GRIT unifies embedding and generative tasks into a single model at no performance loss on either one, which may seem surprising. When the embedding dataset is MEDI2, we show that embedding performance even improves once the generative objective is added compared to an otherwise equivalent embedding-only model (§3.3). We think that our results confirm that generative language modeling and text embeddings are two sides of the same coin. Both tasks require a model to have a deep understanding of natural language and only differ in the way that understanding is expressed. Possibly, our unified model contains a small number of parameters that act as a switch to make the final representations either useful for mean pooling and subsequent embedding tasks or primed for the language modeling head and subsequent generative tasks. We are excited about future work exploring what is happening inside of GRITLM. To support such research, we release all our work freely.

Optimizing RAG with GRITLM RAG and the caching variants we have presented in this work operate on a frozen language model. Meanwhile, there has been extensive work on optimizing a generative model specifically for interaction with a retrieval system [52, 185, 4]. These works commonly optimize only the retriever [138] or only the reader [12, 172, 6, 93]. However, recent work has shown that jointly optimizing both models leads to the best performance [90]. With its state-of-the-art retrieval and generative performance, GRITLM can act as both the retriever and reader in a single model. Thus, optimizing either one also changes the parameters of the other. This has the potential to significantly simplify the joint optimization of the retriever and reader. For example, it may suffice to only use the next-token objective (Equation 2) to penalize the retriever for providing irrelevant context and at the same time the reader for poor use of the given context. This is in contrast to separate models and objective functions used in Lin et al. [90].

The story of text embedding and text generation has been a story of unification.

Embedding Models used to focus on word representations [118, 98] that struggled to generalize to entire sentences or passages [28]. InferSent [29], SBERT [131] and similar models [112, 111] then emerged that handle both the embedding of words and sentences at good quality by considering context when present. However, for strong performance, they require separate models for symmetric and asymmetric tasks [107, 104]. Symmetric embedding tasks are ones where the query and document are expected to come from the same distribution, such as STS. Meanwhile, for asymmetric tasks, they come from different distributions and as such could have very different sequence lengths, as in retrieval. For example, the MTEB benchmark [107] revealed that SentT5 [112] only performs well at symmetric tasks, while GTR [111] only performs well at asymmetric tasks, despite both using T5 [128] as their base model. Recent embedding models have been able to unify symmetric and asymmetric tasks into a single model by differentiating them in the prompt [169, 159]. Further, including detailed instructions in the prompt has allowed unifying practically any embedding task into a single model [143].

Generative Models used to be tailored to a single task, such as translation [146] or question answering [173]. McCann et al. [97] cast multiple generative tasks as question answering to unify them within a single model; however, performance was still limited and it did not generalize to arbitrary tasks. Large-scale self-supervised pretraining has enabled the use of a single large language model (LLM) for practically any generative task [13, 22, 126, 11, 135, 54, 141, 1, 86]. However, using an LLM without careful prompting often leads to poor performance [132, 100]. Finetuning LLMs on instructions has emerged as a method that makes the models significantly easier to apply to any generative task, with strong results [164, 134, 99, 163, 101, 108, 65, 187, 140, 184].

The two streams of embedding and generative models have each been unified into a single model that handles any task within its stream. Unifying the two streams into a single model that handles any task for both embedding and generation is the natural next step toward a general multi-task model. Besides generation, LLMs have also shown promise for text embeddings [104, 110, 70, 88, 87]. SGPT [104] was an early work in that direction. SGPT only changes 0.01% of the parameters of a large language model via BitFit [177] to adapt it to produce well-performing embeddings. Thus, one only needs to change this small number of parameters to switch from one stream to the other. However, SGPT still required separate asymmetric and symmetric models and did not consider the full breadth of embedding tasks. GRITLM addresses these deficiencies: it does not require switching out biases, leverages instructions to handle asymmetric or symmetric use cases, and considers the full breadth of embedding and generative tasks.

7 Conclusion

We present GRIT to unify text embedding and generation, and thus all text-based language problems, into a single model: GRITLM. GRITLM 7B achieves state-of-the-art performance on the Massive Text Embedding Benchmark among open models, while at the same time beating all generative models up to its size. Notably, it matches the performance of otherwise equivalent embedding-only and generative-only variants, allowing us to unify the two streams at no performance loss. By adding only 5B parameters at inference, GRITLM 8X7B is the best open generative language model among the many we have tried, including much larger models based on Llama 2 with 70B parameters. Unlike the other generative models, GRITLM 8X7B also boasts very strong embedding performance thanks to the GRIT approach. Due to its unified capabilities, GRITLM can be used as both the Bi-Encoder and Cross-Encoder in a reranking pipeline, leading to performance improvements on 15 out of 16 retrieval datasets. Further, we conduct extensive ablations uncovering key insights for researchers of both embedding and generative models: causal language models for embedding should be finetuned with bidirectional attention and mean pooling, embedding performance of language models before fine-tuning is not predictive of embedding performance after fine-tuning, embedding models can be trained in BF16 mixed precision without performance loss, generative models should be instruction-tuned with some form of token level loss, etc. Finally, we show that GRIT simplifies the field using the example of RAG. By unifying the retriever and reader into a single model, GRITLM allows caching operations leading to inference speed-ups of > 60% for long sequences at no performance loss with GRIT Doc Caching.

Appendix

D Evaluation

For evaluating GRITLM, we select the most commonly used embedding and generative benchmarks:

Embedding To evaluate embedding performance, we use the 7 main tasks from MTEB [107]. (1) Classification (CLF): A logistic regression classifier is trained on embeddings from texts with different labels. The classifier is scored with F1. (2) Clustering (Clust.): K-means clustering is performed on embeddings from different sources. The agreement of the clusters with respect to the source labels is scored with V-measure. (3) Pair Classification (PairCLF): The cosine similarity of two embeddings with a binary label is computed. The optimal similarity threshold across all samples is found and scored with AP (average precision). (4) Reranking (Rerank): A query embedding and reference embeddings are compared with cosine similarity. The similarities are scored against the ground truth ranking of the references via MAP (mean AP). (5) Retrieval: A query embedding and embeddings of references are compared with cosine similarity. The position of the correct reference(s) among the top ten with the highest cosine similarity is scored with nDCG@10 (normalized discounted cumulative gain). (6) STS: The cosine similarity of two embeddings is compared with a ground truth continuous score of their similarity and scored with Spearman correlation. (7) Summarization (Summ.): Human-written and machine-written summaries of the same text are embedded. The cosine similarity of the embeddings is compared to human ratings of the machine summaries and scored with Spearman correlation. Among the tasks, Reranking, Retrieval, and Summarization are asymmetric, i.e. there are two different kinds of embeddings: queries and documents. The others are symmetric, i.e. there is only one kind. We use instructions for every dataset, as specified in §O.1. Notably, for some models, we use different instructions for query and document embeddings when dealing with asymmetric tasks. The datasets within each task cover diverse domains ranging from scientific papers to casual conversations.
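The Retrieval metric can be sketched in a few lines: rank documents by cosine similarity to the query, then compute nDCG@10. This is the common linear-gain form of nDCG written from scratch, not MTEB's evaluation code, which may differ in details such as tie handling. The embeddings and relevance labels below are made up.

```python
import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def ndcg_at_k(query_emb, doc_embs, relevance, k=10):
    """Rank docs by cosine similarity, then score the ranking with linear-gain nDCG@k."""
    ranked = sorted(range(len(doc_embs)), key=lambda i: cosine(query_emb, doc_embs[i]), reverse=True)
    dcg = sum(relevance[i] / math.log2(rank + 2) for rank, i in enumerate(ranked[:k]))
    ideal = sorted(relevance, reverse=True)[:k]
    idcg = sum(rel / math.log2(rank + 2) for rank, rel in enumerate(ideal))
    return dcg / idcg if idcg > 0 else 0.0

docs = [[0.0, 1.0], [1.0, 0.1], [0.7, 0.7]]
print(ndcg_at_k([1.0, 0.0], docs, relevance=[1, 0, 0]))  # 0.5: the relevant doc is ranked third
```

The logarithmic discount means a relevant document ranked third earns only half the credit it would get at rank one.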

Generation For evaluating the generative performance of GRITLM, we largely follow the evaluation setup of Tülu [162, 64] using open-source frameworks [49, 10]. (1) Multiple-Choice Question Answering via MMLU [60]: Models are tasked to answer knowledge-intensive questions from different fields, such as humanities, social sciences, and hard sciences. No few-shot examples are provided, and answers are evaluated with exact match. (2) Problem solving via GSM [26]: Models are tasked to solve a math problem requiring multi-step reasoning. 8 few-shot (FS) examples with chain-of-thought reasoning (CoT) [165] are provided, and exact match is measured.

E Ablations Detailed Results

We display a breakdown of the results from Table 5 in Table 9 to Table 20. For MTEB per-dataset results, we refer to Appendix F, the MTEB leaderboard (https://huggingface.co/spaces/mteb/leaderboard) and our released result files (https://huggingface.co/datasets/GritLM/results).

Table 9: Unified models attention and pooling ablations. The sequence of Cs and Bs refers to the attention mechanism for (from left to right): Emb instruction, Emb sample, Gen instruction, Gen sample, where C=Causal, B=Bidirectional, Emb=Embedding and Gen=Generative. WM, LT and M refer to position-weighted mean, last token and mean pooling, respectively.
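The three pooling variants abbreviated in the table can be sketched on plain lists of per-token vectors, with padding already removed. The position-weighted mean gives token i (1-based) weight i / (1 + 2 + ... + n), following SGPT's weighted mean as we understand it. The function names and the two-token example are hypothetical.

```python
def mean_pool(hidden):
    """M: plain average of the token vectors."""
    n = len(hidden)
    return [sum(tok[d] for tok in hidden) / n for d in range(len(hidden[0]))]

def last_token_pool(hidden):
    """LT: the final token's vector only."""
    return list(hidden[-1])

def weighted_mean_pool(hidden):
    """WM: token i (1-based) gets weight i / (1 + 2 + ... + n), favouring later tokens."""
    n = len(hidden)
    total = n * (n + 1) / 2
    return [sum((i + 1) / total * tok[d] for i, tok in enumerate(hidden))
            for d in range(len(hidden[0]))]

hidden = [[1.0, 0.0], [3.0, 2.0]]  # made-up hidden states for a 2-token input
print(mean_pool(hidden))        # [2.0, 1.0]
print(last_token_pool(hidden))  # [3.0, 2.0]
```

Under causal attention, only later tokens have seen the whole input, which motivates WM and LT. Under bidirectional attention, every token sees everything, so plain mean pooling becomes a natural choice. This fits the Conclusion's recommendation of bidirectional attention with mean pooling.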

Table 11: Generative-only models attention ablations. The sequence of Cs and Bs refers to the attention mechanism for (from left to right): Gen instruction, Gen sample, where C=Causal and B=Bidirectional. IL=interleaved, whereby the bidirectional attention is interleaved with causal attention in multi-turn samples (bidirectional for instructions, causal for answers). This allows for faster generation in multi-turn settings as the kv-cache of the answer can be reused.
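One possible reading of the interleaved (IL) mask is sketched below. Every token attends causally to all earlier tokens. Instruction tokens additionally attend bidirectionally within their own instruction segment. The exact mask used for the ablation is not given here, so this is an assumption. `interleaved_mask` is a hypothetical helper, and the result is a boolean matrix where `allowed[i][j]` means token i may attend to token j.

```python
def interleaved_mask(segments):
    """segments: list of (kind, length), kind in {"instruction", "answer"}."""
    seg_id, kinds = [], []
    for s, (kind, length) in enumerate(segments):
        seg_id += [s] * length
        kinds += [kind] * length
    n = len(seg_id)
    return [[j <= i or (kinds[i] == "instruction" and seg_id[i] == seg_id[j])
             for j in range(n)] for i in range(n)]

# Two turns: a 2-token instruction, 2-token answer, 1-token instruction, 1-token answer.
m = interleaved_mask([("instruction", 2), ("answer", 2), ("instruction", 1), ("answer", 1)])
print(m[0][1], m[0][2], m[2][3])  # True False False
```

Answer rows never look ahead, so their key-value states do not depend on anything that comes later. That is what lets the kv-cache of earlier answers be reused when a new turn is appended.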

Table 12: Base model ablations. Models are only trained for 100 steps and with other sub-optimal settings, such as the Zephyr format, that were rectified through later ablations.

Table 13: Embedding-only models embedding dataset ablations. NNI = No Natural Instructions, corresponding to not including natural instructions in the data. II = evaluating with the Instructor-XL instructions [143]. Other models use our new structure with domain, intent, and unit depicted in Figure 3. Thus, MEDI2 NNI II and MEDI2 NNI are the same model and only differ in the evaluation instruction set.

Table 14: Unified models embedding dataset ablations. The sequence of Cs and Bs refers to the attention mechanism for (from left to right): Emb instruction, Emb sample, where C=Causal, B=Bidirectional, and Emb=Embedding. WM and M refer to position-weighted mean and mean pooling, respectively. MEDI2BGE corresponds to our MEDI2 dataset with negatives coming from the BGE training dataset MTP [169].
