Contents
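Category | Description | Advantages | Disadvantages |
---|---|---|---|
Basic idea | Introduces a global-to-local self-attention structure to improve efficiency over the traditional transformer | Reduced KV cache IO, lower computational cost | Potentially complex architecture |
Structure | Block Transformer: combines global self-attention and local self-attention | Handles both full-context understanding and fine-grained interactions | Relatively complex to implement |
Efficiency and cost reduction | Reduces the key-value loads required for self-attention computation | Higher cost and memory efficiency | Block size and parameters must be tuned to reach high efficiency |
Decoding method | Block decoder models global context; token decoder performs local decoding | Faster processing through local self-attention | Performance can vary with block length and parameter allocation |
Performance | Block Transformer performance is comparable to or slightly better than existing models | Perplexity and zero-shot accuracy comparable to vanilla models (at two to three times the parameter count) | Initial configuration and tuning require more experimentation |
Throughput and memory | Global-to-local modeling optimizes inference throughput and memory usage | Higher throughput and memory efficiency, especially for larger models | Requires an appropriate block length |
Applications and scalability | Enables efficient language modeling building on pretrained models (uptraining) | Applicable to various datasets and tasks | Structural complexity may make it less accessible to new users or researchers |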
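[Changes to the Basic Structure]
Traditional transformer
Block Transformer
[Mathematical Changes and Efficiency]
Cost of the traditional transformer: \(\text{Self-Attention Cost} = \text{Number of Tokens} \times (\text{Key-Value Load per Token})\)
Improvement in the Block Transformer
Mathematical expression:
\[\text{Reduced KV Cache IO} = \frac{\text{Original KV Cache IO}}{\text{Block Length}}\] This indicates that KV cache IO decreases in inverse proportion to the block length.
[Changes in the Decoding Scheme]
Traditional transformer
Block Transformer
[Performance and Efficiency Analysis]
The Block Transformer addresses existing problems of transformer models through mathematical and architectural modifications that effectively reduce the cost of the self-attention mechanism. This improves inference speed and memory efficiency, while block-level processing is optimized to maintain the overall performance of the model.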
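1 Introduction
Transformer-based autoregressive language models (LMs) are costly because their self-attention mechanism attends to all previous tokens. To reduce this cost, it is common to cache the key-value (KV) states of all tokens during decoding. However, although each decoding step computes the KV state of only a single token, it still has to load the KV states of all previous tokens, so this KV cache IO accounts for most of the inference cost. To address this problem, we introduce a global-to-local architecture that aims to nearly eliminate the cost of self-attention. Coarse global modeling reduces the overall cost, and local self-attention removes the need to compute, store, and retrieve the KV cache of past tokens beyond the small local context.
\[\text{Self-Attention Cost} = \text{Number of Tokens} \times (\text{Key-Value Load per Token})\] According to this expression, the total cost falls when the key-value load per token is reduced. In the global-to-local architecture, the block decoder attends over coarse blocks rather than individual tokens, and the token decoder computes self-attention only over the tokens within each block; together these greatly reduce the required key-value loads and therefore lower the overall cost of self-attention.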
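2 Block Transformer
The Block Transformer combines global and local attention mechanisms in a hierarchical paradigm, handling comprehension of the full context and fine-grained interactions in two separate stages.
2.1 Principle of Efficiency
The Block Transformer reduces costs by making global modeling cheaper and performing local modeling within independent blocks. Traditional transformers incur large memory overhead during inference because of self-attention over all previous tokens; the Block Transformer mitigates this problem.
\[\text{Reduced KV Cache IO} = \frac{\text{Original KV Cache IO}}{\text{Block Length}}\] This expression indicates that KV cache IO is reduced by a factor equal to the block length.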
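As a worked instance of this formula, take the main experimental configuration used later (context length \(L = 2048\), block length \(L_B = 4\)). At full context, each block-decoder step loads cached states for \(L / L_B\) block positions per layer instead of \(L\) token positions:

\[\frac{\text{Original KV Cache IO}}{\text{Block Length}} \;\propto\; \frac{L}{L_B} = \frac{2048}{4} = 512 \text{ cached positions per layer (versus } L = 2048\text{)}\]

Because the block decoder also runs only once per \(L_B\) generated tokens, the later analysis of the block decoder counts the total KV cache IO reduction during decoding as \(L_B^2\) (here \(16\)).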
2.2 Embedder
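The embedder aggregates the tokens of each block into an input block embedding. It primarily uses a lookup table to retrieve and concatenate trainable token embeddings.
\[E_{\text{emb}} \in \mathbb{R}^{V \times D_{\text{emb}}}\] where \(V\) is the vocabulary size and \(D_{\text{emb}}\) is the token embedding dimension.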
2.3 Block Decoder
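The block decoder models the global context by applying self-attention over preceding blocks. In doing so it produces an output block embedding (the context embedding), which contains the information needed to autoregressively decode the token contents of the next block. \(\text{Context Embedding} = \text{Self-Attention}(\text{Input Block Embeddings})\)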
2.4 Token Decoder
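The token decoder uses the output block embedding to locally decode the individual tokens of the next block. It does so by applying local self-attention, with the context block embedding as its only source of global context information.
\[\text{Token Output} = \text{Self-Attention}(\text{Context Block Embedding})\] In this way, the token decoder nearly eliminates KV cache IO during inference, which greatly improves efficiency.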
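3 Experiments
3.1 Experimental Setup
In these experiments, we use the transformer architecture of Pythia and train both vanilla and Block Transformer models on the Pile dataset. The models are pretrained on 300B tokens, which corresponds to about 1.5 epochs. We use the HuggingFace training framework, with eight A100 GPUs for training and an H100 GPU for inference wall-time measurements. Experimental details of each section are summarized in Appendix G.
3.2 Main Results
Table 2 reports the language modeling performance of the Block Transformer. Block models are scaled to match the non-embedding parameter counts of the vanilla model variants, and a Block Transformer with two to three times more parameters than a given vanilla model achieves comparable perplexity and accuracy. This is expected because the two separate decoders spend fewer FLOPs per forward pass, reducing attention complexity by a factor of \(1/L_B^2\) at the block level and by roughly \(L_B/L\) at the token level.
The actual inference throughput and memory efficiency of the Block Transformer are higher than those of vanilla models. Maximum throughput was measured using the largest batch size that memory allows for each model variant. Our models achieve Pareto-optimality in both scenarios considered (prefill-heavy and decode-heavy), with throughput increases of up to 25 times.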
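3.3 Analysis of Parameter Allocation Ratio and Block Length
We examine how the parameter allocation ratio affects language modeling performance. The training loss follows a U-shaped pattern across allocation ratios. A larger block decoder lowers the loss at initial positions within a block, whereas a larger token decoder improves prediction accuracy for later tokens.
3.4 Ablations on the Components of the Block Transformer
Among the embedder strategies, the lookup table strategy was the most effective. It keeps the architecture simple without adding computational overhead.
3.5 Analysis of Global-to-Local Language Modeling
We optimize the efficiency of language modeling by adjusting the block length. As block length increases, training loss changes log-linearly while throughput increases exponentially, clearly demonstrating the effectiveness of global-to-local modeling.
3.6 IsoFLOP Analysis under Inference Throughput Constraints
Previous studies focused on maximizing performance within a training FLOPs budget, but recent trends also emphasize models that account for inference throughput constraints. Our IsoFLOP analysis shows that our models can effectively balance training efficiency and inference throughput.
3.7 Uptraining from Vanilla Transformers to Block Transformers
Starting from an existing vanilla transformer and using its pretrained weights for initialization enables effective training with only a small amount of data. This approach recovers performance better than a random initialization strategy.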
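4 Discussion
4.1 Comparison with Related Work
Performance comparison: MEGABYTE model
The MEGABYTE model [74] adopts a global-to-local structure but focuses on efficient pretraining rather than inference. Accordingly, within a training FLOPs budget, its authors argue for a larger block decoder based on a 6:1 ratio that they deem optimal. As shown in Figure 5b, we reimplemented token-level MEGABYTE models, which also achieve significantly higher throughput than vanilla models through global-to-local modeling. However, consistent with the insights in Section 3.3, our models with enhanced local computational capacity show an additional throughput increase of more than $1.5 \times$ over MEGABYTE. See Appendix O for details.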
Relation to KV cache compression
Global-to-local modeling can be viewed through the lens of KV cache compression, where past sequences are entirely pruned in the new layers.
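Prior studies have introduced algorithms that retain only meaningful tokens, as determined by accumulated attention scores [67, 77]. It has also been observed that most attention tends to concentrate on the first token [72, 28]. In Figure 5c, our models show a similar pattern. This observation suggests that performance could be improved by including not only the current context embedding but also the global embeddings or context embeddings of the previous window. See Appendix P for details.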
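4.2 Contextual Information Encapsulated in Context Block Embeddings
Since the input tokens and context embeddings share the same latent space in the token decoder, we analyze the tokens closest to these block embeddings. Interestingly, Table 5 in Appendix Q shows that the context embeddings compress the global context rather than outlining the next block. The second prefix often contains information about the last token of the current block, which helps predict the first token of the next block. Meanwhile, the first prefix is usually non-intuitive or matches the EOS token, suggesting that it carries more general information. Taken together, the block decoder effectively compresses the past global context, and the token decoder leverages it for local language modeling.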
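4.3 Techniques for Higher Throughput
Block autoregressive model with parallel token decoding
If the block decoder is pretrained to predict the next input block embedding, the token decoder can decode all blocks in parallel, provided the block decoder's predictions are accurate.
Mujika [44] improved pretraining efficiency by directly predicting the embedding matrix, but we find that using MSE or contrastive losses [16] at the block decoder actually degrades performance. Moreover, error accumulation at the block level needs to be addressed, since block embeddings cannot be discretized.
Nevertheless, using pretrained text embeddings [68, 36] as the ground truth, instead of jointly training the embedder, may be beneficial.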
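Predicting multiple blocks at once
If the model is trained to predict two or three blocks at once, throughput increases proportionally. For example, with an input block length of four, the token decoder can be pretrained to predict eight tokens, corresponding to two blocks. One efficient training approach may be to uptrain the original Block Transformer model. To ensure performance, the prediction length can be adjusted adaptively based on the confidence of subsequent blocks, or such drafts can be verified, similar to speculative decoding [37, 15, 39].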
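The draft-then-verify idea can be sketched with the greedy acceptance rule commonly used in speculative decoding. This is a generic illustration, not the method of this work: the function name `accept_draft` is hypothetical, and we assume a verifier has already produced its own greedy token for each drafted position in one parallel pass.

```python
from typing import List


def accept_draft(draft: List[int], verified: List[int]) -> List[int]:
    """Greedy draft verification (generic speculative-decoding step).

    draft:    tokens proposed for the next k positions (e.g. two blocks at once)
    verified: the verifier's greedy token at each of those k positions,
              computed in one parallel pass conditioned on the draft prefix
    Returns the longest matching prefix of the draft plus the verifier's
    correction at the first mismatch.
    """
    if len(draft) != len(verified):
        raise ValueError("draft and verified must have equal length")
    accepted: List[int] = []
    for d, v in zip(draft, verified):
        if d != v:
            accepted.append(v)  # replace the first wrong token and stop
            break
        accepted.append(d)
    return accepted


# Two drafted blocks of length 4; the verifier disagrees at the third position
print(accept_draft([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 9, 4, 5, 6, 7, 8]))  # [1, 2, 9]
```

Full speculative decoding implementations usually also emit one extra verifier token when every drafted token is accepted; that step is omitted here for brevity.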
Generating tokens with transformer-based autoregressive language models (LMs) is costly due to the self-attention mechanism that attends to all previous tokens [6, 66]. To alleviate the cost of the self-attention, it is common to cache the key-value (KV) states of all tokens across all layers during the autoregressive decoding. However, while each decoding step only computes the KV state of a single token, it still has to load the KV states of all previous tokens for computing self-attention scores. Consequently, this KV cache IO dominates most of the inference cost of serving LMs. While several techniques have been proposed for reducing the inference cost of the attention component [20, 35, 69], developing effective transformer-based LM architectures that inherently avoid the attention overhead is still an ongoing challenge.
Hierarchical global-to-local architectures [49, 31] have shown significant potential to effectively model large-scale data by addressing global dependencies in coarse detail and capturing fine details within local regions. Inspired by these frameworks, we identify a unique opportunity to mitigate key bottlenecks in autoregressive transformer inference: (1) coarse global modeling can reduce overall costs by its granularity; but more importantly, (2) localized self-attention can nearly eliminate the costs of attention as there is no need to compute, store, and retrieve KV-cache of past tokens beyond the small local context.
Figure 1: An overview of the Block Transformer architecture, demonstrated with a block length of four (each letter represents one token from the vocabulary). The shaded parts indicate prompt tokens, which do not need to be prefilled for the token decoder during inference. The receptive field of the last token is illustrated with a green line, demonstrating how global-to-local language modeling efficiently covers the full context in the receptive field.
This paper presents the Block Transformer architecture, which models global dependencies through self-attention between coarse blocks (each representing multiple tokens) at lower layers, and decodes fine-grained tokens within each local block at upper layers, as shown in Figure 1. Specifically, a lightweight module called (1) the embedder first embeds each block of \(L_B\) input tokens into an input block embedding. These become the input units of (2) the block decoder, an autoregressive transformer that applies self-attention between blocks to decode a context block embedding, which contains information for predicting the next block. Finally, (3) the token decoder autoregressively decodes the token contents of the next block, applying local self-attention between only the \(L_B\) tokens within the block. While this leaves the token decoder to rely solely on the output block embedding for global context information, it drastically reduces self-attention costs, making them linear in the total context length, and eliminates the need to prefill prompt tokens during inference.
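The three-stage pipeline can be summarized as a decoding loop. This is a schematic sketch, not the authors' implementation: `block_decoder` and `token_decoder` are hypothetical stand-ins for the trained modules (the first also absorbs the embedder), and the prompt is assumed to be a whole number of blocks. It makes the call pattern explicit: one global (block-level) step per \(L_B\) local (token-level) steps, with prompt tokens seen only by the block decoder.

```python
from typing import Callable, List

Embedding = List[float]


def generate_blocks(
    prompt: List[int],
    block_len: int,
    num_new_blocks: int,
    block_decoder: Callable[[List[List[int]]], Embedding],
    token_decoder: Callable[[Embedding, List[int]], int],
) -> List[int]:
    """Toy global-to-local decoding loop with hypothetical callables.

    block_decoder: all complete blocks so far -> one context embedding.
    token_decoder: (context embedding, tokens already decoded in this block)
                   -> next token id.
    """
    if len(prompt) % block_len != 0:
        raise ValueError("this sketch assumes a block-aligned prompt")
    tokens = list(prompt)
    for _ in range(num_new_blocks):
        blocks = [tokens[i:i + block_len] for i in range(0, len(tokens), block_len)]
        context = block_decoder(blocks)        # global step: once per block
        new_block: List[int] = []
        for _ in range(block_len):             # local steps: L_B per block
            new_block.append(token_decoder(context, new_block))
        tokens.extend(new_block)
    return tokens
```

For clarity the sketch re-runs `block_decoder` over every block at each step; a practical implementation would instead cache the block-level KV states and feed only the newest block embedding.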
While analogous transformer architectures have been proposed to handle long sequences comprised of raw bytes [74], prior work considers the global module to be the primary model, benefiting from coarse processing, while the embedder and local module simply map between coarse and fine representations to reduce context length. Our approach to global-to-local modeling in LMs challenges these prior beliefs, and uncovers substantial inference-time benefits that have been overlooked in previous work. In detail, we propose that both the global block decoder and local token decoder can play vital roles in language modeling, hence the term global-to-local language modeling. Our ablations reveal that a more balanced parameter allocation across the global and local modules enhances performance, and also results in higher throughput due to significantly shortened context lengths in the local module.
Extensive experiments on models up to 1.4 billion parameters show that Block Transformers notably improve inference throughput for both prefill- and decode-intensive scenarios, achieving 10–20× gains in throughput compared to vanilla transformers with equivalent perplexity or zero-shot task performance. Despite the architectural restriction of global attention in Block Transformers, our models show similar ability to utilize global context compared to their vanilla transformer counterparts. In addition, we show that it is possible to uptrain a pretrained vanilla model into a Block Transformer, closely approaching the performance of those pretrained from scratch, using just 10% of the training budget for adaptation.
Our main contributions are summarized below:
• We are the first to recognize the central role and inference-time benefits of both global and local modeling in autoregressive transformers, particularly the significance of local modules.
• We leverage these insights to optimize inference throughput in our architecture to significantly extend the Pareto frontier of performance to throughput compared to vanilla transformers.
Table 1: Comparison of relative compute and memory costs for our block and token decoders compared to vanilla transformers, and overview of principal bottlenecks for each inference stage. The number of layers is represented by \(N\), the dimension by \(D\), the batch size by \(B\), the context length by \(L\), and the block length by \(L_B\). † The token decoder is not used during the prefill stage, so its complexity is zero. The details about inference efficiency are summarized in Appendix E.
The Block Transformer employs global and local attention mechanisms in a hierarchical paradigm, separating the comprehension of the full context and detailed interactions into two distinct stages. Precisely, global context is captured at lower layers at coarse block-level granularity, where each block consists of a fixed number of tokens aggregated into a single embedding. Local dependencies are resolved at upper layers, where multiple subword tokens are decoded autoregressively while relying solely on the context block embedding from the block decoder for global context. The Block Transformer consists of three components:
The main goal of our architecture design is to minimize the wall-clock bottlenecks during inference. In vanilla transformers, the global treatment of self-attention to all previous tokens significantly hinders batch decoding throughput, mainly due to memory overhead of retrieving previous KV cache [20, 25]. This also necessitates all prompt tokens, which are typically quite lengthy, to be fully prefilled prior to decoding the first token, contributing to increased latency [1, 25].
A global-to-local approach can mitigate these costs by isolating the expensive bottlenecks of global modeling to the lower layers and performing local modeling within independent blocks at the upper layers. Coarse-grained global modeling (block-level decoding) alleviates KV cache bottlenecks by a factor of the block length, while maintaining the ability to account for the full context. Local decoding is free of the cost of prefill and nearly removes KV cache overhead, and thus benefits from significantly higher utilization of the compute units on inference hardware. This allows the token decoder to use more FLOPs for fine-grained language modeling with minimal impact on inference throughput. Table 1 outlines the principal wall-time bottlenecks at the prefill and decode stages, and summarizes the efficiency gains of our block and token decoders.
Although our models require more parameters than vanilla transformers to maintain comparable performance, the actual throughput bottleneck is the KV cache overhead, so our models still achieve higher speed. Therefore, we focus on production systems such as cloud platforms, which can accommodate the higher parameter demands. Edge devices are constrained by memory [3] and typically use small batches [61]. Since parameter IO is a critical bottleneck there [51], we leave the optimization of the Block Transformer for on-device scenarios to future work.
Our embedder design prioritizes simplicity given the small block length (2–8) in our study. We primarily use a lookup table \(E_{\text{emb}} \in \mathbb{R}^{V \times D_{\text{emb}}}\) to retrieve and concatenate trainable token embeddings, where the token embedding dimension \(D_{\text{emb}}\) is set to \(D/L_B\), with \(D\) being the dimension of block representations used throughout the network. While we explored variants such as small encoder transformers (Appendix F), these did not yield performance improvements (Section 3.4).
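The lookup-table embedder can be sketched in plain Python. This is a minimal illustration under stated assumptions: the table is randomly initialized here (in the real model it is trained), the helper names `make_embedding_table` and `embed_blocks` are hypothetical, and the sequence length is assumed to be a multiple of \(L_B\). With \(D_{\text{emb}} = D / L_B\), concatenating the \(L_B\) token embeddings of a block yields a block embedding of dimension \(D\).

```python
import random
from typing import List


def make_embedding_table(vocab_size: int, dim: int, seed: int = 0) -> List[List[float]]:
    """Placeholder table E_emb of shape (V, D_emb); values stand in for trained weights."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 0.02) for _ in range(dim)] for _ in range(vocab_size)]


def embed_blocks(token_ids: List[int], table: List[List[float]], block_len: int) -> List[List[float]]:
    """Lookup-table embedder: concatenate the L_B token embeddings of each block."""
    if len(token_ids) % block_len != 0:
        raise ValueError("sequence length must be a multiple of block_len")
    blocks = []
    for start in range(0, len(token_ids), block_len):
        vec: List[float] = []
        for tok in token_ids[start:start + block_len]:
            vec.extend(table[tok])  # look up, then concatenate
        blocks.append(vec)
    return blocks


# Example: V = 10, D = 8, L_B = 4, so D_emb = D // L_B = 2
V, D, L_B = 10, 8, 4
E_emb = make_embedding_table(V, D // L_B)
block_embs = embed_blocks([3, 1, 4, 1, 5, 9, 2, 6], E_emb, L_B)
print(len(block_embs), len(block_embs[0]))  # 2 8
```

Keeping the embedder a pure lookup adds no attention or MLP work, which is consistent with the observation that encoder-based embedders lowered generation throughput.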
The block decoder aims to contextualize block representations by attending to preceding blocks, utilizing the embedder's output as input. This autoregressive transformer operates at the block level, producing output block embeddings (also called context embeddings) that enable the token decoder to autoregressively decode the subsequent block's token contents. Given input block embeddings from the embedder, derived from input tokens \(x_{0:(i \times L_B - 1)}\), the block decoder outputs a context embedding which contains the information to predict \(x_{(i \times L_B):((i+1) \times L_B - 1)}\). This approach mitigates the quadratic costs of self-attention by using coarse-grained block inputs instead of individual tokens, while preserving global modeling capabilities and ease of hardware acceleration of dense attention [75]. This reduces the context length of a given sequence by a factor of \(L_B\) compared to a vanilla transformer. In terms of FLOPs (the main bottleneck during prefill), all position-wise computations are reduced by a factor of \(L_B\), and attention score computation is reduced by \(L_B^2\) [74]. During decoding, KV cache usage and KV cache IO (the main bottleneck during batch decoding) are reduced by \(L_B\) and \(L_B^2\), respectively, allowing for larger batch sizes and higher compute utilization.
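The indexing can be made concrete with a small helper that lists, for each block step \(i\), which tokens the block decoder has consumed and which tokens its context embedding is meant to predict. This is a hypothetical sketch (`block_targets` is not from the source), and it ignores the left padding used for the first block during evaluation.

```python
from typing import List, Tuple


def block_targets(seq_len: int, block_len: int) -> List[Tuple[range, range]]:
    """Pair each context embedding with the tokens it is used to predict.

    At step i >= 1 the block decoder has seen input blocks 0..i-1, i.e. tokens
    x_0 .. x_{i*L_B - 1}, and its context embedding is used to predict
    x_{i*L_B} .. x_{(i+1)*L_B - 1}, the next block.
    """
    if seq_len % block_len != 0:
        raise ValueError("seq_len must be a multiple of block_len")
    num_blocks = seq_len // block_len
    pairs = []
    for i in range(1, num_blocks):
        visible = range(0, i * block_len)
        target = range(i * block_len, (i + 1) * block_len)
        pairs.append((visible, target))
    return pairs


for visible, target in block_targets(seq_len=12, block_len=4):
    print(f"sees x_0..x_{visible[-1]}  ->  predicts x_{target[0]}..x_{target[-1]}")
# sees x_0..x_3  ->  predicts x_4..x_7
# sees x_0..x_7  ->  predicts x_8..x_11
```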
The token decoder locally decodes the individual tokens of the next block using the context block embedding as the sole source of global context information. The token decoder is also a standard autoregressive transformer, featuring its own embedding table \(E_{\text{tok}} \in \mathbb{R}^{V \times D_{\text{tok}}}\) and classifier. The key to designing the token decoder lies in how to incorporate the context embedding into the decoding process, in a way that effectively leverages the high compute density of the token decoder.
The token decoder eliminates prefill (necessary only in the block decoder), as context information is provided by the output block embedding, hence the term context embedding. Additionally, KV cache IO, a major bottleneck during batch decoding, is nearly removed. While vanilla attention's KV cache IO is quadratic to the full context length (\(L^2\)), the token decoder's local attention costs \(L_B^2\) per block over \(L/L_B\) blocks, resulting in a linear cost to the full context length and a reduction factor of \(L/L_B\) (e.g., 2048/4 = 512 in our main models). This allows for significantly higher compute unit utilization compared to vanilla transformers, which have \(\sim 1\%\) model FLOPs utilization (MFU) [51], making the inference wall-time cost of extra FLOPs relatively cheap.
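The linear-versus-quadratic argument can be checked with a few lines of arithmetic. This sketch uses only the leading-order cost model stated in the text (vanilla KV cache IO proportional to \(L^2\); token decoder proportional to \(L_B^2\) per block over \(L/L_B\) blocks) and ignores the prefix tokens; the function names are illustrative.

```python
def kv_io(context_len: int) -> int:
    """Leading-order KV cache IO for decoding a sequence of this length with full
    attention: grows as context_len**2 (constant factors such as 1/2 are dropped
    because they cancel in the ratios below)."""
    return context_len ** 2


def vanilla_kv_io(L: int) -> int:
    return kv_io(L)


def token_decoder_kv_io(L: int, L_B: int) -> int:
    """L / L_B independent blocks, each with local attention over L_B tokens."""
    if L % L_B != 0:
        raise ValueError("L must be a multiple of L_B")
    return (L // L_B) * kv_io(L_B)


L, L_B = 2048, 4
print(vanilla_kv_io(L) // token_decoder_kv_io(L, L_B))                 # 512, i.e. L / L_B
print(token_decoder_kv_io(2 * L, L_B) // token_decoder_kv_io(L, L_B))  # 2: linear in L
```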
To incorporate the context embedding and leverage this low-cost compute, we project the context block embedding into prefix tokens, enabling further refinement of the global context. Expanding the number of prefix tokens (prefix length) broadens the token decoder's computation width and allows for finer attention to context information, similar to pause tokens [29]. Owing to parallel processing and small local context, these extra prefix tokens do not incur significant wall-time overhead. While we also considered summation and cross-attention based variants (Appendix F), these proved less effective than our main method (Section 3.4).
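The prefix mechanism can be sketched as a linear projection followed by a reshape. This is a minimal sketch under assumptions the source does not state: a single bias-free linear layer, a plain causal mask over the prefix and block tokens, and the hypothetical helper names `context_to_prefix` and `local_causal_mask`. Whatever the total context length, the token decoder's attention window stays bounded by the prefix length plus \(L_B\).

```python
from typing import List

Matrix = List[List[float]]


def linear(x: List[float], w: Matrix) -> List[float]:
    """y = x @ w for a weight matrix w of shape (len(x), out_dim)."""
    out_dim = len(w[0])
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(out_dim)]


def context_to_prefix(context: List[float], w: Matrix, prefix_len: int) -> Matrix:
    """Project a context embedding (dim D) into prefix_len tokens of dim D_tok each."""
    flat = linear(context, w)  # length prefix_len * D_tok
    if len(flat) % prefix_len != 0:
        raise ValueError("projection width must be divisible by prefix_len")
    d_tok = len(flat) // prefix_len
    return [flat[k * d_tok:(k + 1) * d_tok] for k in range(prefix_len)]


def local_causal_mask(prefix_len: int, block_len: int) -> List[List[bool]]:
    """Causal mask over [prefix tokens | block tokens] (assumed plain causal attention)."""
    n = prefix_len + block_len
    return [[j <= i for j in range(n)] for i in range(n)]


# Example with an identity projection: D = 8, prefix length 2, so D_tok = 4
D = 8
identity = [[1.0 if i == j else 0.0 for j in range(D)] for i in range(D)]
prefix = context_to_prefix([float(v) for v in range(1, D + 1)], identity, prefix_len=2)
print(prefix)                        # [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
print(len(local_causal_mask(2, 4)))  # 6 positions: 2 prefix + L_B = 4 block tokens
```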
We use the transformer architecture of Pythia [8], and train both vanilla and Block Transformer models on the Pile [26, 7] with a context length of 2048. The models are pretrained on 300B tokens, which corresponds to about 1.5 epochs. We employ the HuggingFace training framework [70]. Eight A100 GPUs with 40 GiB of VRAM are used for training, while an H100 GPU is used for inference wall-time measurements. Experimental details of each subsection are summarized in Appendix G.
Table 2: Performance comparison between vanilla and Block Transformer models. For a clear comparison, we highlight an example where the vanilla and our models achieve comparable levels of training loss. We measure the perplexity on LAMBADA [48] and WikiText [42], and the accuracy on HellaSwag [76], PIQA [9], and ARC-easy [18] benchmarks. Memory refers to the amount of memory allocated per sample, measured in megabytes, while throughput is measured in units of 1K tokens per second. * refers to variants trained with random-length padding (see footnote 2).
Figure 2: Pareto frontier of throughput to language modeling performance. Throughput denotes the number of generated tokens per second, and the numbers next to each point represent the number of non-embedding parameters. (a) Pareto frontier in the prefill-heavy setting. (b) Pareto frontier in the decode-heavy setting. (c) Throughput in the prefill-heavy setting with varying prompt lengths. Each point corresponds to the same order of model sizes as in the left figures.
In Table 2, we measure the language modeling performance of the Block Transformer. Block models are scaled to have the same number of non-embedding parameters as the vanilla model variants. Our models, when having two or three times more parameters, achieve perplexity and accuracy on five zero-shot evaluation tasks comparable to the vanilla models. This is an expected result because the two separate decoders spend fewer FLOPs per forward pass, reducing the attention complexity by a factor of \(1/L_B^2\) at the block level and by roughly \(L_B/L\) at the token level.
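The block-level factor, together with the position-wise factor from the block decoder description, can be reproduced with exact fractions. This is a sketch under a simplified cost model assumed for illustration (position-wise work proportional to positions times \(D^2\), attention-score work proportional to positions squared times \(D\)); it compares a block decoder with a vanilla decoder of the same depth and width, which is not the parameter-matched comparison of Table 2.

```python
from fractions import Fraction


def block_decoder_cost_factors(L: int, L_B: int, D: int) -> dict:
    """Leading-order FLOP ratios (block decoder / vanilla decoder) at equal depth
    and width D. Hypothetical cost model; constant factors cancel in the ratios."""
    if L % L_B != 0:
        raise ValueError("L must be a multiple of L_B")
    n_blocks = L // L_B
    return {
        # projections and MLP run once per position
        "positionwise": Fraction(n_blocks * D * D, L * D * D),
        # QK^T scores and the weighted sum scale with positions squared
        "attention_scores": Fraction(n_blocks * n_blocks * D, L * L * D),
    }


print(block_decoder_cost_factors(L=2048, L_B=4, D=1024))
# {'positionwise': Fraction(1, 4), 'attention_scores': Fraction(1, 16)}
```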
The actual inference throughput and memory efficiency of the Block Transformer are significantly higher than those of vanilla models. We measure the maximum throughput [60], which uses the maximum batch size allowed by memory for each model variant. As shown in Figure 2a and Figure 2b, our models achieve Pareto-optimality, with up to a 25-fold increase in throughput, under two scenarios: prefill-heavy, with input and output sequence lengths of 2048 and 128, and decode-heavy, with input and output lengths of 128 and 2048. This efficiency improvement is due to effective reductions in KV cache memory, which allow batch sizes about six times larger, as summarized by memory per sample in Table 2. The Block Transformer further reduces latency in the prefill-heavy setting, as past KV states of prompts need to be cached only in the block decoder, without forwarding them to the token decoder.
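The link between KV cache memory per sample and maximum batch size can be illustrated with a toy calculation. Everything here is a hypothetical configuration, not one of the paper's models: fp16 values (2 bytes), width \(D = 1024\), \(L = 2048\), \(L_B = 4\), a 24-layer vanilla model versus 12 block-decoder plus 12 token-decoder layers, two prefix tokens, and a 16 GiB memory budget reserved for the KV cache. Parameter and activation memory are ignored, so the resulting ratio only shows the mechanism; the measured figures in Table 2 are the authoritative ones.

```python
def kv_cache_bytes(num_layers: int, cached_positions: int, d_model: int,
                   bytes_per_value: int = 2) -> int:
    """Per-sample KV cache size: keys and values (factor 2) per layer and cached position."""
    return 2 * num_layers * cached_positions * d_model * bytes_per_value


def max_batch_size(budget_bytes: int, per_sample_bytes: int) -> int:
    """Largest batch whose KV cache fits in the budget (weights/activations ignored)."""
    return budget_bytes // per_sample_bytes


# Hypothetical configurations: fp16, D = 1024, L = 2048, L_B = 4
L, L_B, D = 2048, 4, 1024
vanilla = kv_cache_bytes(num_layers=24, cached_positions=L, d_model=D)
block = (kv_cache_bytes(num_layers=12, cached_positions=L // L_B, d_model=D)    # block decoder
         + kv_cache_bytes(num_layers=12, cached_positions=2 + L_B, d_model=D))  # token decoder
budget = 16 * 2**30  # hypothetical 16 GiB reserved for the KV cache
print(max_batch_size(budget, vanilla), max_batch_size(budget, block))  # 85 674
```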
During evaluation, we add left padding of length \(L_B - 1\) to the first block. To allow internal padding within blocks during inference, we apply random-length padding when packing documents for pretraining (see Appendix H). Without this technique, performance drops significantly on certain tasks such as LAMBADA.
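The evaluation-time left padding is simple to express. The pretraining-side random-length padding is only summarized here (its details are in Appendix H), so the packing function below is an assumed scheme for illustration, not the paper's procedure: it inserts between 0 and \(L_B - 1\) pad tokens before each document so that documents start at varying offsets inside blocks. The pad id `PAD_ID` and both helper names are hypothetical.

```python
import random
from typing import List, Sequence

PAD_ID = 0  # hypothetical padding-token id


def left_pad_for_eval(tokens: Sequence[int], block_len: int, pad_id: int = PAD_ID) -> List[int]:
    """Prepend L_B - 1 pads so the first block holds the first real token in its last slot."""
    return [pad_id] * (block_len - 1) + list(tokens)


def pack_with_random_padding(docs: Sequence[Sequence[int]], block_len: int,
                             rng: random.Random, pad_id: int = PAD_ID) -> List[int]:
    """Pack documents into one stream, inserting 0..L_B-1 pads before each document
    (assumed scheme) so that documents start at varying offsets inside blocks."""
    packed: List[int] = []
    for doc in docs:
        packed.extend([pad_id] * rng.randrange(block_len))
        packed.extend(doc)
    return packed


print(left_pad_for_eval([5, 6, 7], block_len=4))  # [0, 0, 0, 5, 6, 7]
stream = pack_with_random_padding([[1, 2, 3], [4, 5]], block_len=4, rng=random.Random(0))
print([t for t in stream if t != PAD_ID])         # [1, 2, 3, 4, 5]
```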
Figure 3 panels: (a) Loss by allocation ratio; (b) Loss by block length; (c) Embedder ablations; (d) Position loss by ratio; (e) Position loss by length; (f) Token decoder ablations.
Figure 3: (Left: (a), (d)) Average and position-wise loss by the ratio of parameter allocation between block and token decoders. The ratio is represented as block to token decoders. (Center: (b), (e)) Average and position-wise loss in relation to block length \(L_B\). (Right: (c), (f)) Training loss curve for variants of the embedder and token decoder. We consider four different lengths for the prefix-based token decoder. We use models with 302M non-embedding parameters and one-to-one ratio trained on 8 billion tokens.
The Pareto frontiers for various fixed batch sizes, i.e., 1, 32, and 256, are illustrated in Appendix I. We find that as both the model size and batch size increase, the throughput of the Block Transformer scales exponentially. Considering that LLMs typically used in real-world applications have billions of parameters, and taking into account the strategy of aggregating multiple user requests to optimize batch inference [35, 50, 60], the results suggest that our proposed architecture will demonstrate even more benefits in practical multi-tenant deployment scenarios.
In Figure 2c, we observe that the throughput of the Block Transformer with an 8K prompt length surpasses that of the vanilla model with a 2K prompt length. This is reasonable because the context length of the block decoder is reduced by a factor of 4, and the token decoder is nearly free of KV-cache overheads. Given the rising interest in enabling longer context lengths, even over one million tokens [13, 57, 46], the Block Transformer has potential to enhance throughput even further.
Perplexity shows a U-shaped pattern across different allocation ratios. We explore the impact of different allocation ratios between the block and token decoders on language modeling performance, while keeping the total number of non-embedding parameters constant. Figure 3a illustrates the training loss across five distinct ratios for three model sizes. Interestingly, there is a clear U-shaped trade-off at all three model sizes. We find that a one-to-one ratio is optimal for models with \(L_B = 4\). If either side is too small, there is a noticeable decline in performance consistently across all model sizes. This demonstrates the synergistic effect and the equal importance of the block and token decoders in language modeling.
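A fixed non-embedding budget split by an allocation ratio can be sketched as follows. This is a hypothetical helper, assuming both decoders share the same width \(D\) so that non-embedding parameters are proportional to layer counts, and using the common estimate of about \(12 D^2\) parameters per transformer layer (\(4D^2\) for attention, \(8D^2\) for an MLP with 4x expansion). The 6:1 example corresponds to the MEGABYTE ratio discussed in Section 4.1.

```python
from typing import Tuple


def split_layers(total_layers: int, block_share: int, token_share: int) -> Tuple[int, int]:
    """Split a fixed layer budget between block and token decoders by a ratio
    (hypothetical helper; assumes equal width, so parameters follow layer counts)."""
    block_layers = round(total_layers * block_share / (block_share + token_share))
    return block_layers, total_layers - block_layers


def non_embedding_params(num_layers: int, d_model: int) -> int:
    """Common ~12 * D^2 per-layer estimate (4D^2 attention + 8D^2 MLP)."""
    return 12 * num_layers * d_model ** 2


for share in [(1, 1), (6, 1)]:
    b, t = split_layers(24, *share)
    print(share, (b, t), non_embedding_params(b, 1024) + non_embedding_params(t, 1024))
# (1, 1) (12, 12) 301989888
# (6, 1) (21, 3) 301989888
```

The total stays constant across ratios, which is what makes the U-shaped comparison a fair one.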
Larger block and token decoders reduce perplexity at initial and later positions, respectively. We measure average loss at each position within a block, depicted in Figure 3d. The position-wise loss typically exhibits a U-shaped pattern, aligning with findings from a previous multiscale language model [74] and blockwise parallel decoding methods [62, 14, 34]. This trend stems from the lack of global context in context embeddings, which escalates uncertainty at later positions. Moreover, perplexity at specific positions correlates with the parameter sizes of the two decoders. A larger block decoder significantly lowers initial position loss due to predictions solely based on the context embedding. In contrast, a larger token decoder improves prediction accuracy for later tokens by better leveraging local context. These interdependent effects dictate the optimal parameter ratio, with similar patterns evident in models of various sizes, detailed in Appendix J.
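The position-wise measurement can be sketched as grouping per-token losses by their offset inside a block. This is a minimal sketch assuming the loss list starts at the first position of a block; the function name is hypothetical and the example numbers are toy values, not results.

```python
from typing import List


def positionwise_loss(token_losses: List[float], block_len: int) -> List[float]:
    """Average per-token loss grouped by position inside each block (0 .. L_B-1)."""
    sums = [0.0] * block_len
    counts = [0] * block_len
    for idx, loss in enumerate(token_losses):
        pos = idx % block_len  # offset of this token within its block
        sums[pos] += loss
        counts[pos] += 1
    return [s / c if c else float("nan") for s, c in zip(sums, counts)]


# Toy per-token losses for two blocks of length 4
print(positionwise_loss([4.0, 2.0, 1.0, 3.0, 2.0, 2.0, 1.0, 5.0], block_len=4))
# [3.0, 2.0, 1.0, 4.0]
```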
Shorter block length favors a larger block decoder, whereas longer block length favors a larger token decoder. Figure 3b demonstrates that training loss still follows a U-shaped pattern across different allocation ratios, regardless of block length. Optimal ratios shift with block length: shorter blocks benefit from a larger block decoder, while longer blocks perform better with more parameters in the token decoder. This is due to the inverse relationship between block length and FLOPs of the block decoder, which influences model capacity [22, 23, 29]. As Figure 3e shows, first position loss significantly decreases with shorter blocks, reflecting increased capacity in the block decoder. While the token decoder shows minimal differences in FLOPs across block lengths, it has more opportunity to improve the likelihood of later tokens as block length increases, favoring a larger token decoder. These trends are consistent across different model sizes and allocation ratios, detailed in Appendix K.
Larger token decoder and longer block length are beneficial for achieving high throughput. We evaluate the allocation ratio and block length from a throughput perspective, summarizing the Pareto frontier in Appendix L. Models with larger token decoders reach Pareto-optimality by achieving higher throughput at a minor performance compromise. Since KV cache IO significantly influences inference time, allocating more parameters to the token decoder is advantageous because the local context length is bounded by the block length. Additionally, increasing the block length improves throughput as KV cache length in the block decoder reduces proportionally. Therefore, although our main configuration uses a one-to-one ratio and a block length of four, opting for a longer block length and a larger token decoder could result in a higher-throughput model.
Lookup strategy is the most effective approach for the embedder. In Figure 3c, we experiment with three embedder strategies to bundle block tokens into a single embedding. Surprisingly, a complex transformer encoder like RoBERTa [40] does not outperform a simpler lookup table strategy. Moreover, the encoder-based embedder lowers generation throughput due to additional computational overhead. As a result, we opt for the lookup strategy to streamline the Block Transformer architecture. Although the CLS token approach allows flexibility in block length, we leave it for future work as it compromises language modeling performance.
Prefix token decoder with longer prefixes enhances performance with minimal overhead. Figure 3f shows the training loss curve for three token decoder strategies. Using a cross-attention module with key and value sequences equal to the block length considerably diminishes performance. In contrast, forwarding context embeddings through self-attention operations enhances performance, with prefix decoding surpassing other methods. Furthermore, extending the prefix beyond four tokens markedly improves perplexity, effectively broadening the computation width of the token decoder. Since longer prefixes add minimal inference overhead, we select a prefix length of two by balancing performance with FLOPs. This approach offers new insights into global-to-local modeling, diverging from previous studies [74] which overlook the potential of local computational capacity in the token decoder. Detailed results across various model sizes are summarized in Appendix M.
Global-to-local language modeling efficiently optimizes throughput relative to performance. In Figure 4a, we transition from vanilla to Block Transformers by adjusting block lengths. As block length increases, training loss changes log-linearly and throughput increases exponentially, clearly demonstrating the efficiency of global-to-local modeling. Using a lookup embedder and token decoder with one prefix token, our model with \(L_B = 1\) differs from the vanilla model only by removing global attention in the upper layers. Notably, this model achieves loss equivalent to that of the vanilla model after training on 70% of the tokens, while doubling throughput. Despite pruning all past sequences, this robust performance shows that the context embedding can retain relevant information, enabling the effective use of local computations in global-to-local language modeling.
Figure 4: (a) Training loss curve with varying block lengths. The numbers in the brackets represent the maximum throughput, measured in 1K tokens per second, for prefill-heavy and decode-heavy settings, respectively. (b) The loss at different token positions within context length on the PG19 test set. We average over every 128 sequences for smoothing. (c) Training loss curves under the same budget for both training FLOPs and inference throughput.
Block Transformer can effectively leverage full context. Since the token decoder depends solely on the context embedding, there could be a concern about whether the Block Transformer fully utilizes context information. To address this, we evaluate the loss of token positions within a 2K context window using the test set of the PG19 dataset [52]. Figure 4b indicates that later tokens are consistently predicted with higher likelihood, suggesting that our architecture, which distinguishes between block-level and token-level decoders, effectively leverages at least 2K tokens of context.
Previous studies have focused on compute-optimal models to maximize performance within training FLOPs budgets [33, 32], while typically overlooking inference throughput. Recent trends, however, emphasize models that also consider inference throughput constraints, either by overtraining smaller models [65, 64] or by reducing FLOPs of the model itself [55]. In Figure 4c, an optimal Block Transformer model achieves superior perplexity and triples the throughput when using the training FLOPs and throughput of the vanilla model as budget constraints. This illustrates that our models can effectively balance training efficiency and inference throughput.
Unlike previous studies [74], our subword-level global-to-local architecture can leverage the initialization from a pretrained vanilla transformer. This enables efficient training that requires only a small amount of data. As shown in Figure 5a, this uptraining strategy can lead to near-full performance recovery with just 10% of the original training steps, outperforming the random initialization strategy. Consistent with previous studies [2], investigating deliberate weight initialization techniques can further enhance the performance convergence. We summarize details in Appendix N.
Performance comparison to MEGABYTE. The MEGABYTE model [74] adopts a global-to-local structure but prioritizes pretraining efficiency over inference efficiency. Accordingly, under a training FLOPs budget, its authors argue for a larger block decoder, based on a 6:1 block-to-token decoder ratio that they deem optimal. As shown in Figure 5b, we reimplement token-level MEGABYTE models, which also achieve significantly higher throughput than vanilla models thanks to global-to-local modeling. Nevertheless, consistent with our insights in Section 3.3, our models, which allocate more computational capacity to the local token decoder, achieve more than 1.5 times the throughput of MEGABYTE. See Appendix O for more details.
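A rough way to see why extra token-decoder capacity is cheap at inference is to count the KV-cache entries attended while generating a sequence. The counting model below is an assumption made for illustration: it ignores FLOPs, batching, and prefill, treats the context embedding as `prefix_len` extra entries, and uses hypothetical function names. Under it, vanilla layers grow quadratically with sequence length, block-decoder layers also grow quadratically but are reduced by roughly the square of the block length (each step reads fewer entries and runs only once per block), and token-decoder layers grow only linearly.

```python
def vanilla_kv_reads(seq_len: int, num_layers: int) -> int:
    """Token t (1-indexed) attends to t cached entries in every layer."""
    return num_layers * sum(range(1, seq_len + 1))


def block_transformer_kv_reads(seq_len: int, block_len: int,
                               block_layers: int, token_layers: int,
                               prefix_len: int = 1) -> int:
    """Attended KV entries under a simplified, hypothetical cost model."""
    if seq_len % block_len:
        raise ValueError("seq_len must be a multiple of block_len")
    num_blocks = seq_len // block_len
    # Block decoder: runs once per block; block j attends to j block entries.
    block_cost = block_layers * sum(range(1, num_blocks + 1))
    # Token decoder: its cache restarts every block; position i attends to
    # the prefix entries plus the i tokens of the current block so far.
    per_block = sum(prefix_len + i for i in range(1, block_len + 1))
    token_cost = token_layers * num_blocks * per_block
    return block_cost + token_cost
```

For instance, comparing `vanilla_kv_reads(2048, 12)` with `block_transformer_kv_reads(2048, 4, 6, 6)` shows the Block Transformer attending to far fewer cached entries; the layer counts are arbitrary, not the configurations evaluated in the paper.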
Relation to KV cache compression. Global-to-local modeling can be viewed through the lens of KV cache compression: in the token decoder layers, the KV cache of past blocks is pruned entirely, and past context is carried only by the context embedding.
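The pruning view can be made concrete with the token decoder's attention mask. In this sketch (hypothetical helpers; the prefix entries derived from the context embedding are omitted for brevity), a token may attend only to earlier tokens in its own block, whereas a vanilla causal mask would keep every earlier position.

```python
from typing import List


def token_decoder_mask(seq_len: int, block_len: int) -> List[List[bool]]:
    """mask[q][k] is True if token q may attend to token k in the token decoder.

    A vanilla causal mask keeps every k <= q; here keys from earlier blocks
    are dropped, which is equivalent to pruning their KV cache entries.
    """
    return [[k <= q and k // block_len == q // block_len for k in range(seq_len)]
            for q in range(seq_len)]


def pruned_fraction(seq_len: int, block_len: int) -> float:
    """Share of vanilla causal (query, key) pairs that the token decoder drops."""
    mask = token_decoder_mask(seq_len, block_len)
    kept = sum(row.count(True) for row in mask)
    causal = seq_len * (seq_len + 1) // 2
    return 1 - kept / causal
```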
Figure 5: (a) Training loss curve with uptraining strategy. The red horizontal line refers to the training loss of a full pretrained model. (b) Throughput comparison to MEGABYTE. We compare to three sizes of MEGABYTE in the prefill-heavy setting. (c) Visualization of heatmap for attention scores in block decoder. We visualize only the first 64 sequences for clarity.
Prior studies have proposed algorithms that retain only important tokens, as determined by accumulated attention scores [67, 77], after observing that most attention tends to sink into the first token [72, 28]. As Figure 5c shows, the block decoder in our models exhibits a similar pattern. This observation suggests that performance could be improved by leveraging not only the current context embedding but also global embeddings or context embeddings from the previous window. See Appendix P for more details.
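Retention by accumulated attention, in the spirit of the methods cited above, can be sketched as scoring each cached entry by the total attention it has received and keeping the top entries under a fixed budget. This is a generic, simplified illustration, not a reproduction of [67] or [77]: it uses a single attention map rather than per-layer or per-head scores, and pinning the first position (to respect the attention sink) is an optional assumption.

```python
from typing import List, Sequence


def keep_heavy_hitters(attn: Sequence[Sequence[float]], budget: int,
                       always_keep_first: bool = True) -> List[int]:
    """Indices of cached entries to retain under a fixed cache budget.

    attn[q][k] is the attention weight from query q to key k. Each key is
    scored by its attention accumulated over all queries; position 0 is
    optionally pinned because it often acts as an attention sink.
    """
    num_keys = len(attn[0])
    scores = [sum(row[k] for row in attn) for k in range(num_keys)]
    kept = [0] if always_keep_first and budget > 0 else []
    ranked = sorted((k for k in range(num_keys) if k not in kept),
                    key=lambda k: scores[k], reverse=True)
    kept += ranked[:budget - len(kept)]
    return sorted(kept)
```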
Since the input tokens and context embeddings share the same latent space in the token decoder, we analyze the tokens nearest to these block embeddings. Interestingly, Table 5 in Appendix Q reveals that context embeddings compress the global context rather than outlining the next block. The second prefix often contains information about the last token of the current block, which helps predict the first token of the next block. Meanwhile, the first prefix typically matches non-intuitive tokens or the EOS token, suggesting that it carries more general information. Taken together, these observations indicate that the block decoder effectively compresses the past global context, which the token decoder then leverages for local language modeling.
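The nearest-token analysis reduces to a similarity search over the token embedding table. The sketch assumes cosine similarity (the text does not state the metric) and represents embeddings as plain lists; `nearest_tokens` is a hypothetical helper.

```python
import math
from typing import Dict, List, Sequence


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def nearest_tokens(query: Sequence[float],
                   vocab: Dict[str, Sequence[float]], k: int = 3) -> List[str]:
    """The k tokens whose input embeddings are most cosine-similar to `query`.

    `query` would be a context (prefix) embedding fed to the token decoder.
    """
    ranked = sorted(vocab, key=lambda tok: cosine(query, vocab[tok]), reverse=True)
    return ranked[:k]
```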
Block autoregressive model with parallel token decoding. If we pretrain the block decoder to predict the next input block embeddings, the token decoder can decode all blocks in parallel, provided the block decoder's predictions are precise. While Mujika [44] improves pretraining efficiency by directly predicting the embedding matrix, we find that applying MSE or contrastive losses [16] at the block decoder actually degrades performance. Moreover, error accumulation at the block level needs to be addressed, since block embeddings, unlike tokens, cannot be discretized. Nevertheless, using pretrained text embeddings [68, 36] as ground truth, instead of jointly training the embedder, could be beneficial.
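For reference, the two block-level objectives mentioned above can be written as follows, comparing predicted block embeddings against ground-truth ones. The dot-product similarity, in-batch negatives, and temperature in the contrastive variant are assumptions, not the exact setup of [16]; the paragraph reports that such losses degraded performance, so this only shows what the objectives compute.

```python
import math
from typing import Sequence

Vector = Sequence[float]


def mse_loss(pred: Sequence[Vector], target: Sequence[Vector]) -> float:
    """Mean squared error between predicted and ground-truth block embeddings."""
    total = sum((p - t) ** 2 for pv, tv in zip(pred, target) for p, t in zip(pv, tv))
    return total / (len(pred) * len(pred[0]))


def info_nce_loss(pred: Sequence[Vector], target: Sequence[Vector],
                  temperature: float = 0.1) -> float:
    """Hypothetical in-batch contrastive loss: pred[i] should match target[i]
    more closely than any other target in the batch."""
    loss = 0.0
    for i, p in enumerate(pred):
        logits = [sum(a * b for a, b in zip(p, t)) / temperature for t in target]
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_denominator = m + math.log(sum(math.exp(x - m) for x in logits))
        loss += log_denominator - logits[i]
    return loss / len(pred)
```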
Predicting multiple blocks at once with longer output length. If the model is trained to predict two or three blocks simultaneously, throughput will increase proportionally. For example, if the input block length is four, the token decoder can be pretrained to predict eight tokens, i.e., two blocks. One efficient way to train such a model could be to uptrain the original Block Transformer models. To preserve performance, we can adaptively adjust the prediction length based on the confidence of subsequent blocks, or verify those drafts as in speculative decoding [37, 15, 39].
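The two safeguards mentioned above can be sketched as a confidence cutoff and a greedy verification step. Both functions are hypothetical: `adaptive_length` assumes a per-token confidence score is available for the drafted tokens, and `accept_draft` assumes some verifier model supplies its own greedy choice at each drafted position plus one extra token. Sampling-based speculative decoding uses a rejection-sampling acceptance rule instead, which this greedy variant omits.

```python
from typing import List, Sequence


def adaptive_length(confidences: Sequence[float], threshold: float) -> int:
    """Number of leading drafted tokens whose confidence stays at or above `threshold`."""
    n = 0
    for c in confidences:
        if c < threshold:
            break
        n += 1
    return n


def accept_draft(draft: Sequence[int], verified: Sequence[int]) -> List[int]:
    """Greedy verification of a drafted multi-block output.

    verified[i] is the verifier's greedy token at position i given the draft
    before i; it has len(draft) + 1 entries. The longest matching prefix is
    kept, and the verifier's token replaces the first mismatch (or is
    appended as a bonus token when the whole draft matches).
    """
    accepted: List[int] = []
    for i, token in enumerate(draft):
        if token != verified[i]:
            return accepted + [verified[i]]
        accepted.append(token)
    if len(verified) > len(draft):
        accepted.append(verified[len(draft)])
    return accepted
```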
We introduced the Block Transformer architecture, which highlights the inference-time advantages of global-to-local modeling in autoregressive transformers. Our empirical findings demonstrate that both the global and local components play vital roles, and we identify the inference benefits of the token decoder, which previous work overlooked. Through deliberate architectural design, we significantly improve throughput over vanilla transformers of equal performance. Refer to Appendix A for limitations, Appendix B for future work, and Appendix C for broader impacts.