
Google Tandem Transformers

  • Related Project: Private
  • Category: Paper Review
  • Date: 2024-02-14

Tandem Transformers for Inference Efficient LLMs

  • url: https://arxiv.org/abs/2402.08644
  • pdf: https://arxiv.org/pdf/2402.08644
  • abstract: The autoregressive nature of conventional large language models (LLMs) inherently limits inference speed, as tokens are generated sequentially. While speculative and parallel decoding techniques attempt to mitigate this, they face limitations: either relying on less accurate smaller models for generation or failing to fully leverage the base LLM’s representations. We introduce a novel architecture, Tandem transformers, to address these issues. This architecture uniquely combines (1) a small autoregressive model and (2) a large model operating in block mode (processing multiple tokens simultaneously). The small model’s predictive accuracy is substantially enhanced by granting it attention to the large model’s richer representations. On the PaLM2 pretraining dataset, a tandem of PaLM2-Bison and PaLM2-Gecko demonstrates a 3.3% improvement in next-token prediction accuracy over a standalone PaLM2-Gecko, offering a 1.16x speedup compared to a PaLM2-Otter model with comparable downstream performance. We further incorporate the tandem model within the speculative decoding (SPEED) framework where the large model validates tokens from the small model. This ensures that the Tandem of PaLM2-Bison and PaLM2-Gecko achieves substantial speedup (around 1.14x faster than using vanilla PaLM2-Gecko in SPEED) while maintaining identical downstream task accuracy.


TL;DR


  • Introduces Tandem Transformers, a new architecture that separates natural language understanding (NLU) from natural language generation (NLG) in large language models (LLMs).
  • Uses two models (ML and MS) so that the capacity of each can be allocated where it is needed, improving efficiency.
  • Validates the efficiency and accuracy of the architecture on benchmark datasets such as SuperGLUE.

1. Introduction

Inference efficiency is a key concern for language models, and especially so for large ones. The conventional autoregressive decoding scheme incurs high computational cost during generation. Observing that natural language understanding (NLU) and natural language generation (NLG) have different compute requirements, this work proposes Tandem Transformers, a new architecture that handles the two tasks separately.


2. Method

2.1 Tandem Transformers Architecture Design

Architecture Components

Tandem Transformers consists of two main parts:

  1. ML (Large Model): processes the input prompt and produces intermediate representations.
  2. MS (Small Model): autoregressively generates γ tokens, conditioned on the representations produced by ML.

With ML mainly handling NLU and MS handling NLG, this structure lets each model's capacity be used where it matters most.

Mathematical Formulation

The Tandem Transformers setup can be written as:

\[\text{ML}_{\text{out}} = \text{Encoder}(x_{1:n})\] \[\text{MS}_{\text{out}} = \text{Decoder}(\text{ML}_{\text{out}}, x_{n+1:n+\gamma})\]

where \(x_{1:n}\) is the input token sequence and \(\gamma\) is the number of tokens MS generates per block. MS refers to the representations produced by ML when generating the next \(\gamma\) tokens.
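To make the data flow concrete, here is a minimal runnable sketch of this two-part setup. Everything in it is illustrative: ToyLarge and ToySmall are toy stand-ins (not the PaLM2 models), and the mean-pooled conditioning is a simplification of the cross-attention the paper actually uses.

```python
import torch
import torch.nn as nn

# Toy stand-ins for ML and MS; the real models are full transformer stacks.
d_large, d_small, vocab = 64, 32, 100

class ToyLarge(nn.Module):
    """ML: encodes the prompt into per-token representations."""
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(vocab, d_large)
        self.mix = nn.Linear(d_large, d_large)
    def forward(self, token_ids):                 # (n,) -> (n, d_large)
        return torch.relu(self.mix(self.emb(token_ids)))

class ToySmall(nn.Module):
    """MS: predicts the next token while consuming ML's representations."""
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(vocab, d_small)
        self.proj = nn.Linear(d_large, d_small)   # aligns ML's width to MS's width
        self.head = nn.Linear(d_small, vocab)
    def forward(self, ml_out, token_ids):
        ctx = self.proj(ml_out).mean(dim=0)       # crude pooling instead of attention
        h = self.emb(token_ids[-1]) + ctx         # condition the last token on ML's context
        return self.head(h)                       # logits for the next token

ML, MS = ToyLarge(), ToySmall()
prompt = torch.randint(0, vocab, (8,))
ml_out = ML(prompt)                               # ML_out = Encoder(x_{1:n})
gamma, generated = 3, [int(prompt[-1])]
for _ in range(gamma):                            # MS generates the next γ tokens autoregressively
    logits = MS(ml_out, torch.tensor(generated))
    generated.append(int(logits.argmax()))
print(generated[1:])                              # the γ tokens produced by MS
```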

2.2 Training and Inference Process

Training Process

ML and MS are initialized from independently pretrained checkpoints, and training minimizes the following loss:

\[L = \sum_{i=1}^{N} \text{CrossEntropy}(\text{MS}_{\text{out}}, y_i)\]

where \(y_i\) is the ground-truth token label. This encourages MS to make effective use of ML's representations when predicting tokens.

Inference Process

At inference time, ML processes the input prompt and passes its representations to MS. MS then generates the first \(\gamma\) tokens from this information, and the process repeats until the response is complete.


3. Experiments

3.1 Datasets and Benchmarks

The performance of Tandem Transformers is evaluated on a range of benchmark datasets, including SuperGLUE and TydiQA, to analyze how the interaction between ML and MS affects accuracy and efficiency.

3.2 Performance Evaluation and Analysis

In the experiments, Tandem Transformers achieves roughly 1.16× faster inference than the comparably accurate PaLM2-Otter baseline while remaining competitive in accuracy, suggesting that the tandem architecture can deliver efficiency and accuracy at the same time.


4. Conclusion and Future Work

Tandem Transformers, as proposed in this work, is a meaningful approach to improving the inference efficiency of large language models. Future work includes more refined model designs and additional experiments on a wider range of datasets to further strengthen the model's generality and reliability.


1 Introduction

Despite significant advancements in inference optimization techniques (Leviathan et al., 2023; Du et al., 2022; Liu et al., 2023), the widespread deployment of very large language models (LLMs) remains hindered by their substantial computational costs. A key factor contributing to high inference latency is the autoregressive generation process, where tokens are produced sequentially. This inherent limitation restricts the full utilization of ML accelerators (GPUs/TPUs), which are optimized for matrix-matrix multiplications rather than the matrix-vector operations prevalent in LLMs. Consequently, prompt processing (where all tokens are handled simultaneously) is significantly more efficient than autoregressive response generation.

On the other hand, it is not well understood how much capacity is required to understand the prompt/query/prefill (natural language understanding aka NLU) vs the capacity required to generate a response (natural language generation aka NLG). Current decoder-only LLM architectures tightly couple both these tasks.

Tandem Transformers. In this work, we investigate this fundamental question from an efficiency perspective. We propose Tandem Transformers, a novel architecture that allocates significantly more model capacity to prefill processing (NLU) compared to response generation (NLG). Our goal is to understand whether high-quality response generation can be maintained under this design. Concretely, Tandem transformers consists of two models – a small model MS and a large model ML, where:

  1. ML processes the prompt/query.
  2. MS generates the first γ tokens (called a block) autoregressively, while attending to the prompt/query representations generated by ML.
  3. ML processes the γ tokens generated by MS together (i.e., in a non-autoregressive fashion) and computes their representations.
  4. MS then generates the next γ tokens autoregressively, while attending to representations of all tokens up to the previous prefill block generated by ML.
  5. This process is repeated until the response generation is complete.

Tandem Transformer Training. We introduce a projection layer to align the potentially higher-dimensional representation space of ML with that of MS. For efficiency, we initialize ML and MS as independently trained, standard decoder-only models.

Experiments with Tandem (PaLM2-Bison, PaLM2-Gecko) (where PaLM2-Gecko < PaLM2-Otter < PaLM2-Bison in terms of model size) demonstrate that the capacity needed for the NLU vs NLG aspects of LLMs can indeed be decoupled, leading to a more efficient architecture without significant accuracy loss. Evaluation on benchmark datasets shows that Tandem (PaLM2-Bison, PaLM2-Gecko) with block length γ = 3 is substantially more accurate than PaLM2-Gecko and comparable to PaLM2-Otter, while achieving approximately 1.16× lower inference latency than PaLM2-Otter. For example, on SuperGLUE (Wang et al., 2019), the tandem model is 3% less accurate than PaLM2-Bison, 16% more accurate than PaLM2-Gecko, and 0.2% less accurate than PaLM2-Otter, with a 1.16× speedup over PaLM2-Otter.

Encoder-Decoder. In contrast to an encoder-decoder architecture, which would only process the query/prefix through an encoder and then generate the entire response through a decoder, Tandem generates only block-size γ (say, 3) tokens through the secondary model MS and then refreshes the entire prefill representations using the primary model ML, which is critical to maintaining high accuracy. That is, setting γ = 0 makes Tandem mimic the decoder-only ML model, while setting γ → ∞ leads to the decoder-only MS model.

Tandem + SPEED. For applications requiring output identical to the primary model, we propose Tandem + SPEED. The speculative decoding (SPEED) framework (Leviathan et al., 2023) leverages the small model MS in Tandem to generate draft tokens, which are then verified by the large model ML. Crucially, the ability of MS in Tandem to attend to ML’s representations significantly improves draft quality, reducing verification overhead compared to standard SPEED. For example, on the Reddit Posts dataset, using the MS in Tandem as the drafter model in SPEED leads to about 11.24% higher per-block acceptance rate compared to a vanilla secondary model. Finally, we show that Tandem transformers can be further improved using logit distillation and their efficacy within SPEED can be improved using an adaptive block length parameter.

Contrast with Parallel Decoding and Distillation. Recently, multiple speculative or parallel decoding style techniques have been proposed in the literature (Leviathan et al., 2023; Kim et al., 2023; Stern et al., 2018). These techniques attempt to generate a draft of tokens using a relatively inexpensive drafter model. Parallel decoding attempts to generate multiple draft tokens in parallel by learning classifiers on top of the output of the primary model ML, while speculative decoding can provide significantly better drafts by using a small but autoregressive model. In contrast, Tandem is a standalone model in its own right and does not natively require verification by ML to generate reasonable outputs (see benchmark numbers in Table 3). Furthermore, Tandem + SPEED is able to use the representations of ML while still generating tokens autoregressively, which provides an overall much better tradeoff between token quality and model latency for the drafter. Finally, recent works have also shown the efficacy of logit distillation for training better drafter models within SPEED (Zhou et al., 2023). Our approach is complementary and can be combined with distillation.

Empirical Results for Tandem + SPEED. Finally, we conduct extensive latency evaluations on TPUv5e for both stand-alone and SPEED versions of Tandem (PaLM2-Bison, PaLM2-Gecko), with PaLM2-Bison and PaLM2-Gecko being the primary ML and secondary MS models, respectively. In particular, on multiple datasets, we observe that Tandem + SPEED with distillation can be at least 2.19× faster than the baseline PaLM2-Bison model while ensuring the same output quality. Furthermore, compared to standard SPEED with MS as the secondary model, our model is 1.11× to 1.17× faster. An adaptive block length in SPEED further helps reduce Tandem's latency by 1.04× to 1.09× on multiple datasets. Finally, we demonstrate that our results also hold for practical settings like batch-size > 1.

Contributions. In summary, following are the key contributions of the work:

  1. Tandem architecture: A novel architecture to disaggregate prompt/prefill processing capacity from response generation.
  2. Tandem + SPEED: Improved speculative decoding leveraging Tandem’s superior drafting for guaranteed output equivalence with lower latency.
  3. Adaptive Block Length: Enhances Tandem + SPEED by dynamically adjusting drafted token count.
  4. TPUv5e evaluation: End-to-end evaluation on TPUv5e with PaLM2-Bison being the primary model. A distilled Tandem + SPEED is 2.4x faster compared to vanilla PaLM2-Bison model and 1.11 − 1.17× faster compared to distilled MS + SPEED (Leviathan et al., 2023) applied in the same setting.

Outline of the paper: The rest of the paper is organized as follows. We briefly review related work in Section 2. In Section 3, we present the main ideas and the design of Tandem transformers architecture. Section 4 presents the experimental results on Tandem transformers. We then conclude with some future directions in Section 6.

2 Related Work

Encoder-Decoder models: Encoder-decoder transformer architectures are widely used for specific tasks such as machine translation (Vaswani et al., 2017). Given the computational inefficiency of autoregressive decoding, several works have explored using a large encoder with a small decoder. Our work can be seen as extending these ideas to use an encoder-decoder model for the decoder itself.

Mixture of experts (MoE)/Sparsity based approaches: Mixture of experts (Du et al., 2022) and sparsity based approaches (Li et al., 2022) have also been studied for optimizing inference cost of LLMs. However these approaches are complementary to the approaches proposed in our paper. For example, either or both the large model ML and small model MS can be an MoE or sparse model.

Distillation: Since the seminal paper (Hinton et al., 2015), distilling the knowledge of a large model to a smaller model by using the logits of large model as a training target has been widely used in several settings. Our work can be seen as a more general version of distillation for transformers, where the small model can directly refer to large model representations for tokens from previous blocks. Furthermore, our experiments (see Section 4) show that our techniques are complementary to logit distillation, and provide additional gains on top of vanilla logit distillation.

Speculative decoding (SPEED): Speculative decoding (Leviathan et al., 2023; Kim et al., 2023) is a framework to reduce inference latency of LLMs without affecting their quality, which has shown substantial improvements in LLM inference. We demonstrate that Tandem transformers can be used within the SPEED framework, improving the efficacy of SPEED. While multiple drafters have been explored in the context of SPEED such as a stand alone model (Leviathan et al., 2023), retrieval based (He et al., 2023), distillation based (Zhou et al., 2023), as of now distillation based drafters seem to perform the best. As we demonstrate in Section 4, Tandem is able to provide significantly more powerful drafter thus providing better draft of tokens leading to lower latency.

3 Tandem Transformers

In this section, we describe the Tandem transformers architecture, its training, and its inference.

Standard (decoder) transformer: Given a sequence \(t_1, t_2, \cdots, t_S\) of S tokens as input, where \(t_i\) corresponds to the ith token id, a standard decoder transformer with L layers executes as follows:

\[x^{(0)}_i = \text{Emb}(t_i), \qquad \hat{x}^{(j)}_i = \text{Atn}^{(j)}\big(x^{(j-1)}_i \,\big\|\, x^{(j-1)}_{1:i}\big), \qquad x^{(j)}_i = \text{FF}^{(j)}\big(\hat{x}^{(j)}_i\big), \qquad j = 1, \dots, L, \tag{1}\]

where \(x^{(j)}_i\) is the representation of the ith token after the jth layer, and \(\text{Atn}^{(j)}(\cdot\|\cdot)\) and \(\text{FF}^{(j)}(\cdot)\) are the jth attention and feedforward layers, respectively (Vaswani et al., 2017). Note that the attention is purely causal (i.e., the ith token attends only to tokens \(t_k\) for \(k \leq i\)) since we are considering a decoder-only transformer.
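For concreteness, here is a minimal single-head PyTorch sketch of this computation; layer norms, multi-head attention, and residual connections are intentionally omitted, and all sizes are made up.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class DecoderLayer(nn.Module):
    """One simplified decoder layer: causal single-head attention (Atn) then feedforward (FF)."""
    def __init__(self, d: int):
        super().__init__()
        self.q, self.k, self.v = (nn.Linear(d, d) for _ in range(3))
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))
    def forward(self, x):                                      # x: (S, d)
        S, d = x.shape
        scores = self.q(x) @ self.k(x).T / math.sqrt(d)
        causal = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(causal, float("-inf"))     # token i attends only to k <= i
        x_hat = F.softmax(scores, dim=-1) @ self.v(x)          # x̂^(j)
        return self.ff(x_hat)                                  # x^(j) = FF^(j)(x̂^(j))

d_model, n_layers, S, vocab = 32, 4, 10, 1000
emb = nn.Embedding(vocab, d_model)
layers = nn.ModuleList(DecoderLayer(d_model) for _ in range(n_layers))
x = emb(torch.randint(0, vocab, (S,)))                         # x^(0)_i = Emb(t_i)
for layer in layers:                                           # j = 1, ..., L
    x = layer(x)
print(x.shape)                                                 # (S, d_model): one vector per token
```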

Tandem transformer: A Tandem transformer model comprises a primary model ML and a secondary model MS, where typically SIZEOF(ML) ≫ SIZEOF(MS). Given a sequence of tokens \(t_1, t_2, \cdots, t_S\) as input, the primary model ML processes these tokens just like a standard (decoder) transformer, as in Equation (1).

Let γ be the block length parameter, and let \(L_S\) and \(L_L\) be the number of layers of the secondary and primary models, respectively. Let \(\ell : [L_S] \to [L_L]\) be a layer assignment function from the secondary model to the primary model. The secondary model attends to the primary model's representations for all tokens from the previous blocks.

Training: Given a block length parameter γ, we partition the training sequence into blocks, each consisting of γ consecutive tokens. Consider the autoregressive prediction of the jth token (for some j ≤ γ) within the ith block. The input to the secondary model MS is the previous token. Crucially, within the attention blocks of MS:

  • Key/value pairs for all tokens up to the jth token in the current block are computed by MS itself.
  • Key/value pairs for tokens in previous blocks are computed by the primary model ML. A projection/tandem feedforward layer then aligns the representational dimensions from ML to MS, as described in Equation (2) and sketched in the example below.
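A minimal sketch of this key/value mixing follows. The dimensions, the single attention head, and the use of MS's own key/value projections on the projected ML representations are assumptions made for illustration, not the paper's exact parameterization.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

d_large, d_small = 64, 32
proj = nn.Linear(d_large, d_small)               # tandem projection/feedforward layer
k_s, v_s, q_s = (nn.Linear(d_small, d_small) for _ in range(3))   # MS's attention projections

ml_prev_blocks = torch.randn(12, d_large)        # ML's representations of all previous blocks
ms_current = torch.randn(3, d_small)             # MS's activations up to token j of this block

# Keys/values for previous blocks come from ML (after projecting to MS's width);
# keys/values for the current block come from MS itself.
prev = proj(ml_prev_blocks)
keys = torch.cat([k_s(prev), k_s(ms_current)], dim=0)
values = torch.cat([v_s(prev), v_s(ms_current)], dim=0)
query = q_s(ms_current[-1:])                     # query for the j-th (latest) token

attn = F.softmax(query @ keys.T / math.sqrt(d_small), dim=-1)
out = attn @ values                              # what MS feeds to its next sub-layer
print(out.shape)                                 # (1, d_small)
```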

We explore multiple training configurations for Tandem transformers:

  • Primary Model Frozen: Only the secondary model parameters MS and the tandem feedforward layer \(\text{FF}^{(j)}_S\) are updated. Loss is applied solely to the secondary model’s output \(y^{(L_S)}\) (Equation (2)).
  • Both Models Trained, Loss on Secondary Outputs: Similar to the above, loss is applied to the secondary model’s output. However, both ML and MS, along with \(\text{FF}^{(j)}_S\), are trained.
  • Both Models Trained, Loss on Both Outputs: The combined loss incorporates both the primary model’s outputs \(x^{(L_L)}\) and the secondary model’s outputs \(y^{(L_S)}\).

For training efficiency, we initialize the primary and secondary models with high-quality pretrained checkpoints, and then continue pretraining the tandem architecture for a small number of additional steps. In particular, we use the pretrained PaLM2-Bison and PaLM2-Gecko checkpoints to initialize ML and MS, respectively. In this setting, we found that the Primary Model Frozen approach provides the best accuracy. Our Tandem-CE model is obtained by using cross-entropy (CE) loss on the output of the secondary model, as described above.

Tandem-Distil: To further enhance MS’s quality, we apply a distillation loss on its predictions, using the logits of the pretrained ML as targets with CE loss. This aligns naturally with the Tandem architecture, as MS already incorporates representations from ML. The Tandem-Distil model follows a two-stage training setup: it is first trained to minimize the CE loss with respect to the ground-truth labels, and in the second stage a weighting factor of λ = 0.5 balances the CE loss with respect to the ground-truth labels and the CE logit-distillation loss with respect to the outputs of the PaLM2-Bison model. We note that Tandem-Distil generally performs better than Tandem-CE.
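A minimal sketch of what such a logit-distillation objective could look like is shown below. The λ = 0.5 second-stage weighting is taken from the text above, while the soft-target cross entropy written here is an assumed (standard) formulation rather than the paper's exact loss.

```python
import torch
import torch.nn.functional as F

# Dummy logits standing in for one training block of 5 tokens over a 1000-word vocab.
ms_logits = torch.randn(5, 1000, requires_grad=True)   # secondary model MS
ml_logits = torch.randn(5, 1000)                        # frozen primary model ML (teacher)
labels = torch.randint(0, 1000, (5,))                   # ground-truth tokens

teacher = F.softmax(ml_logits, dim=-1)
distill = -(teacher * F.log_softmax(ms_logits, dim=-1)).sum(dim=-1).mean()  # CE vs ML's soft targets
ce = F.cross_entropy(ms_logits, labels)                 # CE vs ground truth
loss = 0.5 * ce + 0.5 * distill                         # λ = 0.5 weighting in the second stage
print(float(loss))
```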

Inference. The inference process begins with the primary model (ML) processing the prompt and generating representations for all prompt tokens. The secondary model (MS) then autoregressively generates the first block of γ response tokens. Crucially, MS attends to the primary model’s representations, aligned via the projection layer.

Once the first response block is generated, the primary model (ML) processes these tokens and computes their representations. We consider two inference configurations:

  • Representation Generation + Token Prediction (Figure 2): ML additionally predicts the next token.
  • Representation Generation Only (Appendix B, Figure 4): ML solely generates representations for the response block.

In both configurations, the representations generated by ML are used by the secondary model (MS) to generate the subsequent block of γ response tokens. Also note that, as in training, MS attends to its own representations for all previous tokens within the current block.

Figure 2. Inference of Tandem transformers with free token from the primary model ML. (left) First block prediction. (right) Second block prediction. Given the query The Himalayas are a mountain range separating the, ML first processes this query and produces the first response token plains. When we use this prediction from ML, this is directly fed as an input to the secondary model MS, which autoregressively produces of India for the first block with γ = 2. In the second block, the entire response from the first block plains of India is fed to the primary model ML, which again produces the next response token from, and then the secondary model MS produces the next two tokens of the block the Tibetan autoregressively. The eventual output of the model will be plains of India from the Tibetan ….

To disaggregate query and response generation, we use Representation Generation Only for processing the input query/prefix. However, for subsequent blocks where the prefill (query+generated response till this point) is processed, we use Representation Generation + Token Prediction from ML.

Depending on the training protocol – specifically, whether primary model outputs are reliable – we may optionally allow the primary model (ML) to generate the first token of the subsequent block (processing γ + 1 tokens). Crucially, in this scenario, we must ensure the following: the keys and values associated with the next block’s first token, computed by ML, are not overwritten when the secondary model (MS) executes its attention layers.

Inference-Time Block Length Flexibility. While we train Tandem transformers with a fixed block length γ, the architecture supports arbitrary γ values during inference. Larger γ values generally improve efficiency by maximizing the primary model’s (ML) utilization of accelerator hardware. Although Tandem is trained with a fixed γ, in SPEED evaluations we find that the optimal γ is often much larger, indicating the robustness of Tandem to changes in γ at inference time.

3.1. Tandem + SPEED: Tandem in the speculative decoding framework

SPEED mitigates the inefficiency of autoregressive generation by using a smaller drafter/secondary model to generate tokens and a larger verifier/primary model to confirm them. SPEED guarantees output quality matching the verifier, but its efficacy hinges on the drafter’s ability to generate long, accurate draft sequences. Tandem transformers are uniquely suited to this role, since the secondary model MS drafts tokens while attending to the primary model ML’s representations, which substantially improves draft quality.

Given a Tandem model, we use ML to process the query/prefix and generate representations for them. MS uses these and produces a draft for the first γ tokens autoregressively. ML then verifies this entire block simultaneously and identifies the first location i where the draft token is deemed incorrect by ML (i = γ + 1, if all the draft tokens are verified successfully). We take the output of the large model for the ith token, and the small model MS then continues to generate draft tokens from the (i + 1)th position onwards, while using the representations of all the previous tokens from the large model ML. This process continues until a full response is generated.
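The draft-and-verify loop just described can be sketched as follows. The stand-in functions and the 20% per-token rejection rate are placeholders, not SPEED's actual acceptance rule; only the control flow (draft γ tokens, verify the block at once, take ML's token at the first rejection, continue from there) mirrors the description above.

```python
import random
from typing import List

def ms_draft(context: List[int], gamma: int) -> List[int]:
    """MS drafts γ tokens autoregressively, attending to ML's cached representations."""
    return [(context[-1] + i + 1) % 100 for i in range(gamma)]      # placeholder drafts

def ml_verify(context: List[int], draft: List[int]) -> int:
    """ML scores the whole draft block in one forward pass; returns the index of the
    first rejected token (len(draft) if every draft token is accepted)."""
    for i, _ in enumerate(draft):
        if random.random() < 0.2:                                   # placeholder rejection rule
            return i
    return len(draft)

def ml_token(context: List[int]) -> int:
    """ML's own next-token prediction, used at the first rejected position."""
    return (context[-1] * 7 + 1) % 100

def tandem_speed_generate(prompt: List[int], gamma: int, max_tokens: int) -> List[int]:
    context, response = list(prompt), []
    while len(response) < max_tokens:
        draft = ms_draft(context, gamma)
        i = ml_verify(context, draft)                          # verify the whole block at once
        block = draft[:i] + [ml_token(context + draft[:i])]    # accepted tokens + ML's token
        response.extend(block)
        context.extend(block)                                  # MS resumes drafting after position i
    return response[:max_tokens]

random.seed(0)
print(tandem_speed_generate([5, 9], gamma=5, max_tokens=12))
```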

The above process can be generalized to the setting where we generate multiple full responses for the same query (we refer to this as num-samples), for example to eventually rank these responses and select the “best” response (Mudgal et al., 2023). In this case, the location of the rejected token can vary across the different samples being generated.

Similarly, the above approach generalizes to larger batch sizes as well, when we are simultaneously processing multiple queries together. Practical systems potentially use both num-samples and batch-size > 1, but latency gains for Tandem + SPEED depend on the overall batch size, which is num-samples × batch-size. So, for simplicity, we focus only on num-samples > 1 and fix batch-size to 1.¹

Adaptive Block Length: While standard SPEED uses a fixed block length γ, we introduce an adaptive approach. We train a relatively small 2-layer multi-layer perceptron – the router MLP – to predict whether the current draft token from MS is likely to be accepted by the primary model ML. At each timestep, we compare the prediction of this small model to a threshold τ, deciding whether to (a) verify with ML, or (b) continue drafting with MS.

¹ Note that it is more challenging to obtain latency improvements with increasing num-samples than with increasing batch size: even without optimizations such as SPEED, larger num-samples obtains better efficiency on all layers, whereas a larger batch size obtains better efficiency only on the feedforward and softmax layers, and not on the attention layer.

Table 1. Accuracy and cross entropy (CE) loss of Tandem transformers with respect to ground truth labels as well as the predictions of the primary model ML, PaLM2-Bison. As is clear from the results, the Tandem model of PaLM2-Gecko and PaLM2-Bison substantially outperforms the stand alone PaLM2-Gecko model.

Input features to the router MLP are: MS’s entropy over the current token’s vocabulary distribution, the top-k probabilities for the current token (for an appropriate k), and MS’s model embeddings corresponding to these top-k most probable tokens. We train the router MLP to predict the probability of disagreement using a cross-entropy loss, with the ground truth given by \(TV(y^S_j)\), the total variation (TV) distance between the output distributions of MS and ML for the jth token.
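A hypothetical sketch of such a router and its thresholding rule is given below: the feature sizes, the sigmoid output, and the direction of the comparison against τ are assumptions for illustration; only the default τ = 0.8 mirrors the setting reported later in Section 4.

```python
import torch
import torch.nn as nn

top_k, d_emb = 4, 32                       # made-up feature sizes
feat_dim = 1 + top_k + top_k * d_emb       # entropy + top-k probs + top-k token embeddings

router = nn.Sequential(                    # the small 2-layer router MLP
    nn.Linear(feat_dim, 64), nn.ReLU(), nn.Linear(64, 1))

def keep_drafting(entropy, topk_probs, topk_embs, tau: float = 0.8) -> bool:
    """True: MS drafts another token. False: hand the block to ML for verification."""
    feats = torch.cat([entropy.view(1), topk_probs.view(-1), topk_embs.view(-1)])
    p_accept = torch.sigmoid(router(feats)).item()   # router's estimate of acceptance
    return p_accept >= tau                           # comparison direction is an assumption

# Toy usage with random features standing in for MS's outputs at one decode step.
entropy = torch.rand(1)
topk_probs = torch.rand(top_k)
topk_embs = torch.randn(top_k, d_emb)
print(keep_drafting(entropy, topk_probs, topk_embs))
```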

4 Experiments

In this section, we present experimental results evaluating Tandem transformer models. Except for the new architecture of Tandem transformers, we generally follow the same training protocols as described in (Anil et al., 2023), including the training dataset, optimizer, etc.

Further Training Details. For both Tandem-CE and Tandem-Distil, we initialize the secondary model MS with the pretrained PaLM2-Gecko, while freezing the primary model ML as the pretrained PaLM2-Bison (Anil et al., 2023). The projection/Tandem feedforward layers are chosen to be linear layers and initialized randomly. Both Tandem models – Tandem-CE and Tandem-Distil – are trained with a block length of γ = 2. For our evaluation within the SPEED framework, we also consider a logit-distillation version of PaLM2-Gecko, called PaLM2-Gecko-Distil, which is initialized with the PaLM2-Gecko model and then trained using logit distillation, similar to the second phase of training of the Tandem-Distil model, since distillation has been shown to improve secondary models in SPEED (Zhou et al., 2023).

Adaptive block length in SPEED. We train a small, 2-layer MLP model to predict whether the current drafter token from MS is likely to be accepted by primary model ML. We set τ = 0.8 as the threshold to determine if MS can continue generating more tokens.

4.1 Performance Evaluation

We compare the performance of Tandem-CE and Tandem-Distil against PaLM2-Gecko, PaLM2-Gecko-Distil, PaLM2-Otter and PaLM2-Bison on several downstream tasks as well as in terms of latency.

For downstream task evaluation, we compare on SuperGLUE (Wang et al., 2019), TydiQA (Clark et al., 2020), a large collection of generation tasks which we call Gen-tasks (comprising SQuADv2 (Rajpurkar et al., 2018), Natural Questions (Kwiatkowski et al., 2019), TriviaQA (Joshi et al., 2017), WebQuestions (Berant et al., 2013) and Lambada (Paperno et al., 2016)), MBPP (Austin et al., 2021), and WMT22 (Zerva et al., 2022). WMT22 results are averaged over x → en translations for different languages x. For TydiQA, we pass the gold passage as part of the input and report the average F1-score over all languages. For SuperGLUE and Gen-tasks, we follow the experimental settings described in (Anil et al., 2023) and report the average results. We report 1-shot evaluations for all performance evaluation experiments.

4.2 Latency Evaluation

We perform latency evaluation in two different settings. In the first setting, we use Tandem-CE and Tandem-Distil as secondary models within SPEED, with PaLM2-Bison as the primary model. Note that the SPEED framework guarantees that the outputs will be of the same quality as the primary model PaLM2-Bison. For comparison, we use PaLM2-Bison as a stand-alone model, as well as SPEED with PaLM2-Bison as primary and PaLM2-Gecko-Distil as secondary, as our baselines. In the second setting, we evaluate the latency of Tandem-CE and Tandem-Distil as stand-alone models against PaLM2-Gecko, PaLM2-Otter and PaLM2-Bison. All the evaluations are performed on TPUv5e (Cloud).

We evaluate latency on the test sets of CNN/DailyMail (Hermann et al., 2015) and Reddit Posts summarization (Kim et al., 2018), and on 1000 prompts from the 1 Billion Word Benchmark (LM1B) dataset.

Table 2. End-to-end latency gain of various secondary models when used within the SPEED framework with PaLM2-Bison as the primary model. The secondary models we consider are PaLM2-Gecko-Distil and Tandem-Distil. Tandem-Distil has a better acceptance rate than PaLM2-Gecko-Distil; e.g., for γ = 5, Tandem-Distil has, on average, 11.24% more tokens accepted. For each secondary model, and on each dataset, we use the optimal block length γ. We consider two settings: one where we generate a single response and another where we generate 4 responses for the given query. The third and fourth columns provide the speedup obtained by using PaLM2-Gecko-Distil and the Tandem model, respectively, with respect to the PaLM2-Bison model. The last column indicates the relative gain of using the Tandem model as the secondary model in SPEED instead of PaLM2-Gecko-Distil. The results clearly demonstrate the additional improvements Tandem obtains on top of logit distillation.

Table 3. Standalone evaluation of the Tandem model. The first five rows present downstream evaluations of the Tandem transformers on a variety of generative and ranking tasks. We see that the Tandem model substantially improves upon the performance of the stand-alone PaLM2-Gecko model, and is on par with the PaLM2-Otter model. On the other hand, the latency evaluations in the last row demonstrate that the Tandem model is about 1.16x faster than the PaLM2-Otter model.

4.3 Evaluation Results

We now present results of our evaluation of tandem transformers.

Pretraining metrics: Table 1 presents a comparison of accuracy and cross-entropy (CE) loss of various baselines as well as tandem models, with respect to both the ground-truth labels and the primary model ML’s predictions. As we can see, tandem transformers perform better than logit distillation, and combining logit distillation with tandem transformers further improves performance.

Latency within SPEED: Table 2 presents results on the latency of Tandem transformers within the SPEED framework. Specifically, we compare the speedup obtained over the PaLM2-Bison model by using SPEED with PaLM2-Gecko-Distil as the secondary model vs Tandem-Distil as the secondary model. The results clearly demonstrate the improvements obtained by tandem on top of distillation. Table 8 in Appendix A presents the speedups computed only over the decode time (i.e., excluding the query processing time). Note that since the SPEED framework guarantees that the outputs are of the same quality as those of the primary model, PaLM2-Bison, the latency improvements given by the tandem model do not involve any quality tradeoffs.

Evaluation as a standalone model: We evaluate the Tandem model as a stand-alone model in its own right. Table 3 presents a comparison of both downstream evaluations on standard benchmarks, as well as latency evaluations. As can be seen, the Tandem model substantially improves upon the downstream performance of the baseline model and is almost on par with the PaLM2-Otter model. Detailed results presented in Tables 10 and 11 in Appendix A show that, in some cases, the tandem model is closer to the PaLM2-Bison model itself. At the same time, the tandem model is about 1.16x faster than the PaLM2-Otter model, making it a compelling candidate for stand-alone deployment as well.

Table 4. End-to-end latency speedup obtained by Tandem-Distil + SPEED + Adaptive γ on different evaluation datasets. The second and third columns show the speedup over the stand-alone PaLM2-Bison model and the Tandem-Distil + SPEED model, respectively. The latency is evaluated for generating a single response. Adaptive γ enables us to use much larger block lengths without losing performance. For example, on the Reddit dataset, the optimal γ for the tandem model in the standard SPEED setup is 7, while adaptive γ obtains better results with γmax = 17.

Adaptive block length: We now present a way to improve the performance of SPEED with adaptive block lengths (Adaptive γ or AG), where after every token predicted by the secondary model, we use a small, inexpensive router to determine whether to continue predicting with the secondary model, or verify the tokens generated so far with the primary model. Table 4 presents the speedup obtained by Tandem-Distil + SPEED + AG compared with the PaLM2-Bison model as well as the Tandem-Distil + SPEED model. Table 9 in Appendix A presents the speedup as measured only over the decode component of the latency i.e., excluding query processing time.

In Table 5, we present the number of primary model, and secondary model runs for Tandem-Distil + SPEED and Tandem-Distil + SPEED + Adaptive γ. The results put forth the benefits of using an adaptive block length, since it drastically reduces the number of secondary model runs while slightly increasing the number of primary model runs.

5 Deep Tandem Transformers

In tandem transformers, we used the large model ML to process tokens in blocks, so that the small model MS can use the large model’s representations for all the tokens from previous blocks. In this section, we present a different approach to using ML and MS in tandem, where ML predicts a sketch of the next block of tokens in parallel, while MS does the actual sampling in an autoregressive manner.

Table 5. Primary model and secondary model runs for Tandem-Distil and Tandem-Distil + AG on the LM1B benchmark. Note that these results are obtained for num-samples = 1. We can see that the number of secondary model runs has come down by 90, whereas the number of large model runs has gone up only by 3. The results clearly showcase that an adaptive block length can significantly cut down the number of secondary model runs and give non-trivial latency gains.

The eventual output of the model is \(y^{(L_S)}\), which is its prediction of the ith token in the input sequence. This is pictorially depicted in Figure 3.

5.1 Experimental results for deep tandem transformers

In this section, we present preliminary experimental results on deep tandem transformers compared with the standard architecture. For this section, we consider the LaMDA models along with the training protocol described in (Thoppilan et al., 2022). In particular, we consider the 1B-parameter model from the LaMDA family and construct a deep tandem version of it by splitting the 16 layers equally between ML and MS (so each of them has 8 layers), with a fixed block length.

Table 7. Pretraining log perplexity of an autoregressive model compared to a block prediction model with block length γ = 8. Both models are taken to have the same architecture as LaMDA-1B, except for the difference between block prediction and autoregressive prediction.

5.2 Importance of the small autoregressive component

In this section, we present the log perplexity achieved by a block prediction model similar to (Stern et al., 2018), where we predict the next block of γ = 8 tokens simultaneously. In other words, we directly train the output \(x^{(L_L)}\) of the large model in Equation (3) to predict the ith token \(x[i]\). The CE loss of the resulting model, and its comparison with a fully autoregressive model, is presented in Table 7. As we can see, the cross-entropy loss of such a model is much higher than that of the original, fully autoregressive model.
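A small sketch of what a block-prediction head of this kind might look like (in the spirit of Stern et al., 2018): one shared representation feeds γ independent output heads, so all γ tokens are predicted without intermediate autoregressive feedback. The head structure and sizes here are assumptions for illustration.

```python
import torch
import torch.nn as nn

d_model, vocab, gamma = 64, 1000, 8
# One independent output head per position in the block, all fed by the same representation.
heads = nn.ModuleList(nn.Linear(d_model, vocab) for _ in range(gamma))

x_last = torch.randn(d_model)                                 # representation after the last layer
block_logits = torch.stack([head(x_last) for head in heads])  # (γ, vocab)
block_tokens = block_logits.argmax(dim=-1)                    # all γ tokens chosen at once,
print(block_tokens.shape)                                     # with no autoregressive feedback
```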

6 Conclusions and Discussion

In this work, we introduce a novel architecture, Tandem transformers, which combines a small autoregressive model with a large model operating in block mode. Tandem transformers substantially boost the small model’s predictive accuracy by allowing it to attend to representations from the large model. In our experiments, a Tandem model comprising PaLM2-Bison and PaLM2-Gecko substantially improves over a standalone PaLM2-Gecko and gives performance comparable to the PaLM2-Otter model, while being 1.16× faster than PaLM2-Otter. When used within the SPEED setup as a secondary model, the distilled Tandem PaLM2-Gecko model gives around a 1.14× speedup over a distilled PaLM2-Gecko model. We further improve our Tandem model through an adaptive block length procedure in SPEED and obtain around a 1.22× speedup over using PaLM2-Gecko-Distil as the secondary model.

Limitations and Future directions

  • Other variants of tandem: In our current approach, we use the large model only through its representations of the past tokens. Is it possible to use the large model to also generate a plan for the future γ tokens along the lines of deep tandem transformers?
  • Alternative to LoRA for fine-tuning: The current approach for fine-tuning a base model for multiple downstream applications is through low rank adaptation (LoRA) (Hu et al., 2021). It will be interesting to explore whether tandem with block length 0 can be an effective alternative to LoRA, while reducing the training cost substantially since backpropagation needs to be done only for the small model.
  • Adaptive γ for larger num-samples/batch-size: While we see promising results with adaptive γ in SPEED for num-samples = 1, extending it to larger num-samples seems challenging. Identifying an effective way of determining when to continue generating with the small model vs verifying with the large model, in the larger num-samples setting, is an interesting direction for future work.
  • Smaller drafter models in SPEED: Finally, we hope that tandem can enable using even smaller drafter models in SPEED, compared to the ones currently being pursued, leading to both memory as well as latency improvements.

7 Broader Impact Statement

Our work provides a more computationally efficient solution for large language model inference, which we hope can bring down the carbon emissions associated with LLM inference. It also enables easier deployment of LLMs, which could have potential societal consequences that seem difficult to predict.

