
Switch Transformers

  • Related Project: Private
  • Category: Paper Review
  • Date: 2023-12-12

Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity

  • url: https://arxiv.org/abs/2101.03961
  • pdf: https://arxiv.org/pdf/2101.03961
  • abstract: In deep learning, models typically reuse the same parameters for all inputs. Mixture of Experts (MoE) defies this and instead selects different parameters for each incoming example. The result is a sparsely-activated model with outrageous numbers of parameters but a constant computational cost. However, despite several notable successes of MoE, widespread adoption has been hindered by complexity, communication costs and training instability. We address these with the Switch Transformer. We simplify the MoE routing algorithm and design intuitive improved models with reduced communication and computational costs. Our proposed training techniques help wrangle the instabilities and we show large sparse models may be trained, for the first time, with lower precision (bfloat16) formats. We design models based off T5-Base and T5-Large to obtain up to 7x increases in pre-training speed with the same computational resources. These improvements extend into multilingual settings where we measure gains over the mT5-Base version across all 101 languages. Finally, we advance the current scale of language models by pre-training up to trillion parameter models on the “Colossal Clean Crawled Corpus” and achieve a 4x speedup over the T5-XXL model.


TL;DR


This paper focuses on the architectural simplification and improved efficiency of the Switch Transformer, presenting a way to reduce the computational cost of processing large, complex datasets. It does so through a careful treatment of efficient model scaling and benchmarking.


1. Introduction

Large-scale training yields flexible and powerful neural language models, and the evidence supports simple architectures over more complicated algorithms. Motivated by this, the authors propose the Switch Transformer, a sparsely activated expert model aimed at large-scale training and efficient computation.


2. Switch Transformer

The Switch Transformer maximizes the parameter count of a Transformer model in a simple, computationally efficient way; the study systematically examines how model size, dataset size, and compute budget affect performance.

  • The Switch Transformer is designed to increase the parameter count of a Transformer model efficiently.
  • Sparse routing maximizes computational efficiency, and performance gains with model size are demonstrated empirically.
  • Experiments confirm that the Switch Transformer outperforms both dense and MoE models.

2.1 Simplifying Sparse Routing

The key feature of the Switch Transformer is sparse routing, which assigns each token to only the single best-suited expert. The mathematical model of this approach is as follows.

For an input token $x$, the router $W_r$ produces a logit vector $h(x)$ via a dot product with $x$:

\[h(x) = W_r \cdot x\]

The logits $h(x)$ are then normalized with a softmax into a probability distribution over the experts,

\[p_i(x) = \frac{\exp(h_i(x))}{\sum_{j=1}^N \exp(h_j(x))}\]

The token is assigned only to the expert with the highest probability, and the output is computed as:

\[y = \sum_{i \in T} p_i(x)E_i(x)\]

Here $T$ is the set of selected experts, which in general contains only the single expert with the highest probability.

Background

This sparse routing scheme greatly reduces computation while preserving model quality, in line with the strategy of growing the parameter count while keeping the FLOPs per example constant. The approach translates directly into efficient hardware usage, and it is especially effective on hardware optimized for dense matrix multiplication such as TPUs.

2.2 Efficient Sparse Routing

The authors implement efficient distributed data and model parallelism with Mesh-TensorFlow. The library abstracts the physical set of cores into a logical mesh, so tensors and computations can be sharded along named dimensions, which greatly improves the scalability and flexibility of the model.

Experiments and application

In the experiments, this distributed setup is used to divide tokens evenly across the experts, with the expert capacity set to minimize load imbalance while maximizing computational efficiency.

2.3 Putting It All Together

To get the most out of the Switch Transformer, pre-training is performed on the C4 dataset, followed by fine-tuning on a variety of downstream tasks.

The key mathematical ingredient in this process is an auxiliary loss, which is minimized when the tokens assigned to the experts are distributed uniformly:

\[\mathcal{L}_\text{aux} = N \cdot \alpha \cdot \vec{f} \cdot \vec{P}^\top\]

Here $\vec{f}$ is the fraction of tokens dispatched to each expert and $\vec{P}$ is the fraction of router probability allocated to each expert; this loss plays a key role in balancing the load across experts during training.

Application and results

The Switch Transformer greatly increases the parameter count while preserving computational efficiency, outperforming dense baselines and other MoE models; the advantage is most pronounced under the memory constraints of the large-model regime.

2.4 Improved Training and Fine-Tuning Techniques

To improve the training stability of sparse expert models, selective precision and a new parameter initialization scheme are introduced. In addition, a larger dropout rate is applied during fine-tuning to prevent overfitting.

  • Training difficulties

    Sparse expert models can be less stable to train than a vanilla Transformer. The hard switching (routing) decision at each layer is a major source of instability, and the low-precision bfloat16 format can exacerbate issues in the router's softmax computation.

  • Selective precision

    To address this, only a localized part of the model is cast to float32 precision, which provides stability without incurring the expensive communication cost of float32 tensors. This is in line with modern mixed-precision training strategies. As a result, the approach keeps nearly the same speed as bfloat16 training while improving training stability.

  • The importance of parameter initialization

    Appropriate initialization is critical for successful training in deep learning, and especially so for the Switch Transformer. Weight matrices are initialized by drawing elements from a truncated normal distribution with mean \(\mu = 0\) and standard deviation \(\sigma = \sqrt{s/n}\), where \(s\) is a scale hyperparameter and \(n\) is the number of input units of the weight tensor. To further reduce instability, the authors recommend reducing the default Transformer initialization scale \(s = 1.0\) by a factor of 10.

  • Regularizing sparse models

    The standard NLP workflow of pre-training on a large corpus and then fine-tuning on smaller downstream tasks such as summarization or question answering is considered. Because the Switch Transformer has far more parameters than a standard Transformer, it can overfit these smaller downstream tasks more severely. To address this, the dropout rate inside the expert layers is increased substantially; this expert dropout improves fine-tuning performance.


3. Scaling Properties of the Switch Transformer

  • The Switch Transformer scales efficiently by increasing the number of experts.
  • Experiments show that model quality improves consistently under a fixed computational budget.
  • Measured against wall-clock time, it delivers faster training and higher sample efficiency.

3.1 Scaling Results on a Step-Basis

A core strategy of the Switch Transformer is to keep the computational cost fixed by selecting only one expert per token; the cost of the routing computation can be expressed as:

\[\text{Cost} = O(d_{\text{model}} \times \text{num experts})\]

Here \(d_{\text{model}}\) is the embedding dimension of the tokens passed between layers; the routing cost grows in proportion to the number of experts but remains a lightweight computation.

The models are trained on the large C4 corpus, which contains roughly 180B tokens.

Mathematical background and rationale

The model is designed around the following principles:

  • Increasing the number of experts keeps the per-step processing cost roughly constant while increasing the amount of information the model can capture.
  • Model quality is assumed to improve steadily as the number of experts grows.
  • Experiments confirm that performance improves each time the number of experts is doubled (e.g., 2, 4, 8 experts).

These results show that the sample efficiency of the Transformer increases.

In other words, increasing the number of experts while keeping the FLOPs per token fixed yields roughly a 7.5x speedup in training on a step basis.

3.2 Scaling Results on a Time-Basis

This section measures performance against wall-clock time and compares to dense models with the same computational budget. Despite the extra routing mechanism and cross-device communication costs, the Switch Transformer still shows performance gains.

\[\text{Training time efficiency} = \frac{\text{Performance}}{\text{Time}}\]

With the same training time and compute budget, the Switch Transformer reaches the quality of the dense Transformer baseline in one-seventh of the time.

3.3 Comparison with a Larger Dense Model

Now consider instead training a larger dense model. Although T5-Large uses 3.5x more FLOPs per token, Switch-Base is still more sample efficient and yields a 2.5x speedup.


4. Downstream Results

  • Analysis of fine-tuning results for T5 and Switch variants
  • Measurement of model quality and compression rates across languages and tasks
  • Study of distilling knowledge from large sparse models into small dense models

4.1 Fine-Tuning

Baseline and Switch models used for fine-tuning

T5-Base (223M parameters) and T5-Large (739M parameters) are pre-trained on the improved C4 corpus of Raffel et al. (2019) and then fine-tuned on a range of NLP tasks. Against these baselines, FLOP-matched Switch Transformer models are designed with many more parameters. A dropout rate of 0.1 is used in all layers except the Switch layers, which use a rate of 0.4. Fine-tuning uses a batch size of 1M for 16k steps per task; model quality is evaluated every 200 steps and the peak performance on the validation set is reported.

\[\text{Dropout rate:}\quad p_{\text{other}} = 0.1, \quad p_{\text{switch}} = 0.4\]

Fine-tuning tasks and datasets

Tasks probing language ability include question answering, summarization, and world knowledge, evaluated via the GLUE and SuperGLUE benchmarks, which cover sub-tasks such as sentiment analysis, word sense disambiguation, sentence similarity, and natural language inference. Summarization is measured on the CNNDM and BBC XSum datasets, and question answering on SQuAD and the ARC Reasoning Challenge.

Fine-tuning results

Improvements are notable on GLUE and SuperGLUE, with the FLOP-matched Switch variants improving over T5-Base and T5-Large by 4.4 and 2 percentage points on SuperGLUE, respectively, alongside large gains on Winogrande, Trivia QA, and XSum. These results demonstrate that the Switch architecture not only pre-trains well but also translates its quality improvements to downstream tasks via fine-tuning.

4.2 Distillation

Distillation here means distilling large sparse models into dense models. In the experiments, a dense model initialized from the sparse model's non-expert layers shows a modest improvement; this is possible because all models are FLOP-matched, so the non-expert layers share the same dimensions. Distillation using a mixture of 0.25 teacher probability and 0.75 ground-truth labels transfers about 30% of the large sparse model's quality gain to the dense model.

\[\text{Teacher probability:}\quad \alpha = 0.25, \quad \text{Actual label:}\quad \beta = 0.75\]

Achievable compression rates

When distilling Switch-Base models into dense models of various sizes, 37% of the quality gain of the 1.1B-parameter model is preserved while compressing it by 82%. At the extreme, even compressing the model by 99% preserves 28% of the teacher's quality improvement.

4.3 Multilingual Learning

Analysis of the quality and speed trade-off

Model quality and speed trade-offs are measured while pre-training on 101 different languages. Building on the recent mT5 work, a multilingual Switch variant, mSwitch-Base, is developed. After 1M steps of pre-training, the final negative log perplexity improves over the baseline on all 101 languages, demonstrating that the Switch Transformer is effective across diverse tasks and in multilingual learning.


5. Designing Models with Data, Model, and Expert-Parallelism

  • Combining data, model, and expert parallelism
  • Analysis of the notation and partitioning strategies
  • Demonstration of improved performance and efficiency

5.1 Data Parallelism

Data parallelism is the standard approach to distributed training: all cores are allocated to the data-parallel dimension (\(n = N\), \(m = 1\)). Its advantage is that no communication is needed until the entire forward and backward pass is finished, and gradients are aggregated only after the full pass.

5.2 Model Parallelism

When all cores are allocated to the model-parallel dimension (\(n = 1\), \(m = N\)), each core keeps the full batch of \(\mathcal{B}\) tokens and holds a unique slice of the weights. This incurs a communication cost on every forward and backward pass: whenever a partitioned dimension must be summed, an all-reduce is added to both passes.

5.3 Model and Data Parallelism

Large-scale models typically mix model and data parallelism. With a total of \(N = n \times m\) cores, each core is now responsible for \(\mathcal{B}/n\) tokens and \(d_\text{ff}/m\) of the weights and intermediate activations. In each pass, each core communicates a tensor of size \([\mathcal{B}/n, d_\text{model}]\) in an all-reduce operation.

5.4 Expert and Data Parallelism

All cores are allocated to the expert- and data-parallel dimension \(n\), which also matches the number of experts in the model. Each core locally computes its assignments to the experts; the result is a binary dispatch matrix of size \([n, \mathcal{B}/n, E, C]\), partitioned along the first dimension. A matrix multiplication of this binary matrix with the input tensor then forms the final tensor of shape \([n, E, C, d_\text{model}]\).

5.5 Expert, Model and Data Parallelism

The best model design balances FLOPs per token against parameter count. Increasing the number of experts \(E\) increases the parameter count but leaves the FLOPs per token unchanged. To increase FLOPs, the \(d_\text{ff}\) dimension must grow, which also increases parameters, though at a slower rate. Growing \(d_\text{ff}\) eventually exhausts the memory per core, which forces \(m\) to increase; but since the number of cores \(N\) is fixed, \(n\) must decrease, forcing a smaller batch size per core.

Notation and reasoning

  • Let \(\mathcal{B}\) be the batch size, \(d_\text{model}\) and \(d_\text{ff}\) the model dimensions, and \(N\) the total number of cores.
  • Data parallelism: \(n = N\), \(m = 1\); no inter-core communication is required until gradient aggregation.
  • Model parallelism: \(n = 1\), \(m = N\); inter-core communication is required on every forward and backward pass.
  • Expert and data parallelism: \(n\) equals the number of experts \(E\), each with capacity \(C\); each core computes its expert assignments locally.
  • Combining model and expert parallelism adds communication costs from token routing as well as the internal model-parallel all-reduces.

All of these strategies are important for training complex models effectively while maximizing compute and memory efficiency, and a clear understanding of each dimension and allocation strategy is essential.


6. Related Work

This section surveys prior work on scaling neural networks and analyzes modern approaches that implement scaling efficiently through conditional computation.

  • Recognition of the importance of scale in neural networks
  • A variety of proposed model-parallelism methods
  • Research on efficient training via conditional computation

Background and prior work

The importance of scale in neural networks has been widely recognized, and several approaches have been proposed for scaling models to billions of parameters.

In particular, model parallelism distributes the computational load by partitioning weights and tensors across many cores (Shazeer et al., 2018; Rajbhandari et al., 2019; Raffel et al., 2019; Brown et al., 2020; Shoeybi et al., 2019).

In addition, Harlap et al. (2018) and Huang et al. (2019) proposed pipeline-based model parallelism, splitting different layers across devices and processing micro-batches in a pipeline.

Problem definition and approach

This work studies a specific class of models that use conditional computation, making computation decisions dynamically based on the input.

In earlier work, Cho and Bengio (2014) proposed selectively using weights based on particular bit patterns occurring in the model's hidden state.

Eigen et al. (2013) stacked expert layers with dense matrix multiplications and ReLU activations, showing promising results on MNIST and monotone speech data.

In computer vision, Puigcerver et al. (2020) manually routed tokens based on semantic classes during upstream pre-training and then selected the relevant experts according to the downstream task.

Methods and analysis

Mixture of Experts (MoE) has proven effective in modern deep learning architectures (Shazeer et al., 2017). That work added an MoE layer between LSTM layers, routing tokens separately to combinations of experts, and achieved state-of-the-art results on language modeling and machine translation benchmarks. The MoE layer was later reintroduced into the Transformer architecture, but without accompanying NLP results (Shazeer et al., 2018). More recently, advances in machine learning infrastructure enabled GShard (Lepikhin et al., 2020) to extend the XLA compiler and improve machine translation across 100 languages. Fan et al. (2021) chose a different, deterministic MoE strategy that splits model parameters into non-overlapping groups of languages.

Findings and implications

Sparsity along the sequence-length dimension of the Transformer has been a successful technique, reducing the \(O(L^2)\) complexity of attention and enabling training on longer sequences than was previously possible (Child et al., 2019; Correia et al., 2019; Sukhbaatar et al., 2019; Kitaev et al., 2020; Zaheer et al., 2020; Beltagy et al., 2020). This version of the Switch Transformer does not use attention sparsity, but the techniques are complementary and could eventually be combined to improve learning on tasks requiring long context.

Overall, the Switch Transformer scales exceptionally well and, despite the additional communication costs and routing computation, delivers faster training and higher quality across many languages and tasks.


1 Introduction

Large scale training has been an effective path towards flexible and powerful neural language models (Radford et al., 2018; Kaplan et al., 2020; Brown et al., 2020). Simple architectures— backed by a generous computational budget, data set size and parameter count—surpass more complicated algorithms (Sutton, 2019). An approach followed in Radford et al. (2018); Raffel et al. (2019); Brown et al. (2020) expands the model size of a densely-activated Transformer (Vaswani et al., 2017). While effective, it is also extremely computationally intensive (Strubell et al., 2019). Inspired by the success of model scale, but seeking greater computational efficiency, we instead propose a sparsely-activated expert model: the Switch Transformer. In our case the sparsity comes from activating a subset of the neural network weights for each incoming example.

Figure 1: Scaling and sample efficiency of Switch Transformers. Left Plot: Scaling properties for increasingly sparse (more experts) Switch Transformers. Right Plot: Negative log perplexity comparing Switch Transformers to T5 (Raffel et al., 2019) models using the same compute budget.

Sparse training is an active area of research and engineering (Gray et al., 2017; Gale et al., 2020), but as of today, machine learning libraries and hardware accelerators still cater to dense matrix multiplications. To have an efficient sparse algorithm, we start with the Mixture-of-Expert (MoE) paradigm (Jacobs et al., 1991; Jordan and Jacobs, 1994; Shazeer et al., 2017), and simplify it to yield training stability and computational benefits. MoE models have had notable successes in machine translation (Shazeer et al., 2017, 2018; Lepikhin et al., 2020), however, widespread adoption is hindered by complexity, communication costs, and training instabilities.

We address these issues, and then go beyond translation, to find that these class of algorithms are broadly valuable in natural language. We measure superior scaling on a diverse set of natural language tasks and across three regimes in NLP: pre-training, finetuning and multi-task training. While this work focuses on scale, we also show that the Switch Transformer architecture not only excels in the domain of supercomputers, but is beneficial even with only a few computational cores. Further, our large sparse models can be distilled (Hinton et al., 2015) into small dense versions while preserving 30% of the sparse model quality gain. Our contributions are the following:

  • The Switch Transformer architecture, which simplifies and improves over Mixture of Experts.
  • Scaling properties and a benchmark against the strongly tuned T5 model (Raffel et al., 2019) where we measure 7x+ pre-training speedups while still using the same FLOPS per token. We further show the improvements hold even with limited computational resources, using as few as two experts.
  • Successful distillation of sparse pre-trained and specialized fine-tuned models into small dense models. We reduce the model size by up to 99% while preserving 30% of the quality gains of the large sparse teacher.
  • Improved pre-training and fine-tuning techniques: (1) selective precision training that enables training with lower bfloat16 precision (2) an initialization scheme that allows for scaling to a larger number of experts and (3) increased expert regularization that improves sparse model fine-tuning and multi-task training.
  • A measurement of the pre-training benefits on multilingual data where we find a universal improvement across all 101 languages and with 91% of languages benefiting from 4x+ speedups over the mT5 baseline (Xue et al., 2020).
  • An increase in the scale of neural language models achieved by efficiently combining data, model, and expert-parallelism to create models with up to a trillion parameters. These models improve the pre-training speed of a strongly tuned T5-XXL baseline by 4x.

2. Switch Transformer

The guiding design principle for Switch Transformers is to maximize the parameter count of a Transformer model (Vaswani et al., 2017) in a simple and computationally efficient way. The benefit of scale was exhaustively studied in Kaplan et al. (2020) which uncovered power-law scaling with model size, data set size and computational budget. Importantly, this work advocates training large models on relatively small amounts of data as the computationally optimal approach.

Heeding these results, we investigate a fourth axis: increase the parameter count while keeping the floating point operations (FLOPs) per example constant. Our hypothesis is that the parameter count, independent of total computation performed, is a separately important axis on which to scale. We achieve this by designing a sparsely activated model that efficiently uses hardware designed for dense matrix multiplications such as GPUs and TPUs. Our work here focuses on TPU architectures, but these class of models may be similarly trained on GPU clusters. In our distributed training setup, our sparsely activated layers split unique weights on different devices. Therefore, the weights of the model increase with the number of devices, all while maintaining a manageable memory and computational footprint on each device.

Figure 2: Illustration of a Switch Transformer encoder block. We replace the dense feed forward network (FFN) layer present in the Transformer with a sparse Switch FFN layer (light blue). The layer operates independently on the tokens in the sequence. We diagram two tokens (x1 = “More” and x2 = “Parameters” below) being routed (solid lines) across four FFN experts, where the router independently routes each token. The switch FFN layer returns the output of the selected FFN multiplied by the router gate value (dotted-line).

2.1 Simplifying Sparse Routing

Mixture of Expert Routing. Shazeer et al. (2017) proposed a natural language Mixture-of-Experts (MoE) layer which takes as an input a token representation $x$ and then routes this to the best determined top-$k$ experts, selected from a set \(\{E_i(x)\}_{i=1}^N\) of \(N\) experts. The router variable \(W_r\) produces logits \(h(x) = W_r \cdot x\) which are normalized via a softmax distribution over the available \(N\) experts at that layer. The gate-value for expert \(i\) is given by,

\[p_i(x) = \frac{\exp(h_i(x))}{\sum_{j=1}^N \exp(h_j(x))}\]

The top-\(k\) gate values are selected for routing the token \(x\). If \(T\) is the set of selected top-\(k\) indices then the output computation of the layer is the linearly weighted combination of each expert’s computation on the token by the gate value,

\[y = \sum_{i \in T} p_i(x)E_i(x)\]

Switch Routing: Rethinking Mixture-of-Experts. Shazeer et al. (2017) conjectured that routing to \(k > 1\) experts was necessary in order to have non-trivial gradients to the routing functions. The authors intuited that learning to route would not work without the ability to compare at least two experts. Ramachandran and Le (2018) went further to study the top-\(k\) decision and found that higher \(k\)-values in lower layers in the model were important for models with many routing layers. Contrary to these ideas, we instead use a simplified strategy where we route to only a single expert. We show this simplification preserves model quality, reduces routing computation and performs better. This \(k = 1\) routing strategy is later referred to as a Switch layer. Note that for both MoE and Switch Routing, the gate value \(p_i(x)\) in Equation 2 permits differentiability of the router.

The benefits for the Switch layer are three-fold: (1) The router computation is reduced as we are only routing a token to a single expert. (2) The batch size (expert capacity) of each expert can be at least halved since each token is only being routed to a single expert.\(^3\) (3) The routing implementation is simplified and communication costs are reduced. Figure 3 shows an example of routing with different expert capacity factors.
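
To make the \(k = 1\) routing above concrete, here is a minimal sketch in JAX-flavored Python; function and variable names such as `switch_route` are illustrative, not the paper's code.

```python
import jax
import jax.numpy as jnp

def switch_route(x, router_weights):
    """Top-1 (Switch) routing for a batch of tokens.

    x:              [num_tokens, d_model] token representations
    router_weights: [d_model, num_experts] router matrix W_r
    Returns the chosen expert index and its gate value per token.
    """
    logits = x @ router_weights                # h(x) = W_r . x, shape [num_tokens, num_experts]
    probs = jax.nn.softmax(logits, axis=-1)    # p_i(x): softmax over the N experts
    expert_index = jnp.argmax(probs, axis=-1)  # route each token to its single best expert
    gate = jnp.take_along_axis(probs, expert_index[:, None], axis=-1)[:, 0]
    return expert_index, gate                  # layer output is gate * E_i(x), keeping the router differentiable

# Toy usage: 8 tokens, d_model = 16, 4 experts.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 16))
w_r = jax.random.normal(jax.random.PRNGKey(1), (16, 4))
idx, gate = switch_route(x, w_r)
```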

Figure 3: Illustration of token routing dynamics. Each expert processes a fixed batch-size of tokens modulated by the capacity factor. Each token is routed to the expert with the highest router probability, but each expert has a fixed batch size of (total tokens / num experts) × capacity factor. If the tokens are unevenly dispatched then certain experts will overflow (denoted by dotted red lines), resulting in these tokens not being processed by this layer. A larger capacity factor alleviates this overflow issue, but also increases computation and communication costs (depicted by padded white/empty slots).

2.2 Efficient Sparse Routing

We use Mesh-Tensorflow (MTF) (Shazeer et al., 2018) which is a library, with similar semantics and API to Tensorflow (Abadi et al., 2016) that facilitates efficient distributed data and model parallel architectures. It does so by abstracting the physical set of cores to a logical mesh of processors. Tensors and computations may then be sharded per named dimensions, facilitating easy partitioning of models across dimensions. We design our model with TPUs in mind, which require statically declared sizes. Below we describe our distributed Switch Transformer implementation.

Distributed Switch Implementation. All of our tensor shapes are statically determined at compilation time, but our computation is dynamic due to the routing decisions at training and inference. Because of this, one important technical consideration is how to set the expert capacity. The expert capacity—the number of tokens each expert computes—is set by evenly dividing the number of tokens in the batch across the number of experts, and then further expanding by a capacity factor. A capacity factor greater than 1.0 creates additional buffer to accommodate for when tokens are not perfectly balanced across experts. If too many tokens are routed to an expert (referred to later as dropped tokens), computation is skipped and the token representation is passed directly to the next layer through the residual connection. Increasing the expert capacity is not without drawbacks, however, since high values will result in wasted computation and memory. This trade-off is explained in Figure 3. Empirically we find ensuring lower rates of dropped tokens are important for the scaling of sparse expert-models. Throughout our experiments we didn't notice any dependency on the number of experts for the number of tokens dropped (typically < 1%). Using the auxiliary load balancing loss (next section) with a high enough coefficient ensured good load balancing. We study the impact that these design decisions have on model quality and speed in Table 1.
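
A small sketch of how the expert capacity described above might be computed and how over-capacity tokens end up dropped; the default `capacity_factor` and the helper names are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

def expert_capacity(tokens_per_batch, num_experts, capacity_factor=1.25):
    # Each expert processes at most (tokens / experts) * capacity_factor tokens.
    return int(jnp.ceil(tokens_per_batch / num_experts * capacity_factor))

def position_in_expert(expert_index, num_experts, capacity):
    """Assign each token a slot within its chosen expert; tokens beyond the
    capacity are 'dropped' and simply pass through the residual connection."""
    one_hot = jax.nn.one_hot(expert_index, num_experts)   # [tokens, experts]
    position = jnp.cumsum(one_hot, axis=0) * one_hot       # running count per expert (1-based)
    keep = (position <= capacity) & (one_hot > 0)           # only within-capacity tokens are processed
    return position, keep
```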

A Differentiable Load Balancing Loss.

To encourage a balanced load across experts we add an auxiliary loss (Shazeer et al., 2017, 2018; Lepikhin et al., 2020). As in Shazeer et al. (2018); Lepikhin et al. (2020), Switch Transformers simplifies the original design in Shazeer et al. (2017) which had separate load-balancing and importance-weighting losses. For each Switch layer, this auxiliary loss is added to the total model loss during training. Given \(N\) experts indexed by \(i = 1\) to \(N\) and a batch \(B\) with \(T\) tokens, the auxiliary loss is computed as the scaled dot-product between vectors \(\vec{f}\) and \(\vec{P}\),

\[\mathcal{L}_\text{aux} = N \cdot \alpha \cdot \vec{f} \cdot \vec{P}^\top\]

where \(f_i\) is the fraction of tokens dispatched to expert \(i\),

\[f_i = \frac{1}{T} \sum_{t=1}^T \mathbb{1}[\text{expert}(t) = i]\]

Since we seek uniform routing of the batch of tokens across the \(N\) experts, we desire both vectors to have values of \(1/N\). The auxiliary loss of Equation 4 encourages uniform routing since it is minimized under a uniform distribution. The objective can also be differentiated as

\[\nabla_{\vec{P}} \mathcal{L}_\text{aux} = N \cdot \alpha \cdot \vec{f}\]

the \(\vec{P}\)-vector is differentiable, but the \(\vec{f}\)-vector is not. The final loss is multiplied by expert count \(N\) to keep the loss constant as the number of experts varies since under uniform routing \(\langle\vec{f}, \vec{1}\rangle = N/N\). Finally, a hyper-parameter \(\alpha\) is a multiplicative coefficient for these auxiliary losses; throughout this work we use an \(\alpha = 10^{-2}\) which was sufficiently large to ensure load balancing while small enough not to overwhelm the primary cross-entropy objective. We swept hyper-parameter ranges of \(\alpha\) from \(10^{-1}\) to \(10^{-5}\) in powers of 10 and found \(10^{-2}\) balanced load quickly without interfering with training loss.
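
Under the definitions above, the auxiliary loss of Equation 4 can be sketched as follows (illustrative names, not the authors' implementation).

```python
import jax
import jax.numpy as jnp

def load_balancing_loss(router_probs, expert_index, alpha=1e-2):
    """Auxiliary loss  N * alpha * (f . P)  from the text.

    router_probs: [num_tokens, num_experts] softmax router probabilities
    expert_index: [num_tokens] chosen expert per token (argmax routing)
    """
    num_experts = router_probs.shape[-1]
    # f_i: fraction of tokens dispatched to expert i (not differentiable).
    f = jnp.mean(jax.nn.one_hot(expert_index, num_experts), axis=0)
    # P_i: fraction of router probability allocated to expert i (differentiable).
    P = jnp.mean(router_probs, axis=0)
    return num_experts * alpha * jnp.sum(f * P)
```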

3 See Section 2.2 for a technical description.

2.3 Putting It All Together: The Switch Transformer

Our first test of the Switch Transformer starts with pre-training on the “Colossal Clean Crawled Corpus” (C4), introduced in (Raffel et al., 2019). For our pre-training objective, we use a masked language modeling task (Taylor, 1953; Fedus et al., 2018; Devlin et al., 2018) where the model is trained to predict missing tokens. In our pre-training setup, as determined in Raffel et al. (2019) to be optimal, we drop out 15% of tokens and then replace the masked sequence with a single sentinel token. To compare our models, we record the negative log perplexity.\(^4\) Throughout all tables in the paper, ↑ indicates that a higher value for that metric is better and vice-versa for ↓. A comparison of all the models studied in this work is in Table 9.

A head-to-head comparison of the Switch Transformer and the MoE Transformer is presented in Table 1. Our Switch Transformer model is FLOP-matched to ‘T5-Base’ (Raffel et al., 2019) (same amount of computation per token is applied). The MoE Transformer, using top-2 routing, has two experts which each apply a separate FFN to each token and thus its FLOPS are larger. All models were trained for the same number of steps on identical hardware. Note that the MoE model going from capacity factor 2.0 to 1.25 actually slows down (840 to 790) in the above experiment setup, which is unexpected.5

We highlight three key findings from Table 1: (1) Switch Transformers outperform both carefully tuned dense models and MoE Transformers on a speed-quality basis. For a fixed amount of computation and wall-clock time, Switch Transformers achieve the best result. (2) The Switch Transformer has a smaller computational footprint than the MoE counterpart. If we increase its size to match the training speed of the MoE Transformer, we find this outperforms all MoE and Dense models on a per step basis as well. (3) Switch Transformers perform better at lower capacity factors (1.0, 1.25). Smaller expert capacities are indicative of the scenario in the large model regime where model memory is very scarce and the capacity factor will want to be made as small as possible.

2.4 Improved Training and Fine-Tuning Techniques

Sparse expert models may introduce training difficulties over a vanilla Transformer. Instability can result because of the hard-switching (routing) decisions at each of these layers. Further, low precision formats like bfloat16 (Wang and Kanwar, 2019) can exacerbate issues in the softmax computation for our router. We describe training difficulties here and the methods we use to overcome them to achieve stable and scalable training.

4 We use log base-e for this metric so the units are nats.

5 Note that speed measurements are both a function of the algorithm and the implementation details. Switch Transformer reduces the necessary computation relative to MoE (algorithm), but the final speed differences are impacted by low-level optimizations (implementation).

Table 1: Benchmarking Switch versus MoE. Head-to-head comparison measuring per step and per time benefits of the Switch Transformer over the MoE Transformer and T5 dense baselines. We measure quality by the negative log perplexity and the time to reach an arbitrary chosen quality threshold of Neg. Log Perp.=-1.50. All MoE and Switch Transformer models use 128 experts, with experts at every other feed-forward layer. For Switch-Base+, we increase the model size until it matches the speed of the MoE model by increasing the model hidden-size from 768 to 896 and the number of heads from 14 to 16. All models are trained with the same amount of computation (32 cores) and on the same hardware (TPUv3). Further note that all our models required pre-training beyond 100k steps to achieve our level threshold of -1.50. † T5-Base did not achieve this negative log perplexity in the 100k steps the models were trained.

Selective precision with large sparse models. Model instability hinders the ability to train using efficient bfloat16 precision, and as a result, Lepikhin et al. (2020) trains with float32 precision throughout their MoE Transformer. However, we show that by instead selectively casting to float32 precision within a localized part of the model, stability may be achieved, without incurring expensive communication cost of float32 tensors. This technique is inline with modern mixed precision training strategies where certain parts of the model and gradient updates are done in higher precision Micikevicius et al. (2017). Table 2 shows that our approach permits nearly equal speed to bfloat16 training while conferring the training stability of float32.

To achieve this, we cast the router input to float32 precision. The router function takes the tokens as input and produces the dispatch and combine tensors used for the selection and recombination of expert computation (refer to Code Block 15 in the Appendix for details). Importantly, the float32 precision is only used within the body of the router function—on computations local to that device. Because the resulting dispatch and combine tensors are recast to bfloat16 precision at the end of the function, no expensive float32 tensors are broadcast through all-to-all communication operations, but we still benefit from the increased stability of float32.
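
A simplified, single-device sketch of this selective-precision pattern: the router body runs in float32 and its outputs are recast to bfloat16 before any communication. It omits the dispatch and combine tensors of the real Mesh-TensorFlow router, and the names are illustrative.

```python
import jax
import jax.numpy as jnp

def router(token_inputs_bf16, router_weights_bf16):
    """Router body computed in float32 for stability; outputs recast to bfloat16."""
    x = token_inputs_bf16.astype(jnp.float32)            # cast the router input up to float32
    w = router_weights_bf16.astype(jnp.float32)
    probs = jax.nn.softmax(x @ w, axis=-1)                # stable softmax in float32
    expert_index = jnp.argmax(probs, axis=-1)
    gate = jnp.take_along_axis(probs, expert_index[:, None], axis=-1)[:, 0]
    # Recast before any (expensive) cross-device communication, so no float32 tensors are broadcast.
    return expert_index, gate.astype(jnp.bfloat16)
```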

Table 2: Selective precision. We cast the local routing operations to float32 while preserving bfloat16 precision elsewhere to stabilize our model while achieving nearly equal speed to (unstable) bfloat16-precision training. We measure the quality of a 32 expert model after a fixed step count early in training as well as its speed performance. For both Switch-Base in float32 and with Selective precision we notice similar learning dynamics.

Smaller parameter initialization for stability. Appropriate initialization is critical to successful training in deep learning and we especially observe this to be true for Switch Transformer. We initialize our weight matrices by drawing elements from a truncated normal distribution with mean \(\mu = 0\) and standard deviation \(\sigma = \sqrt{s/n}\) where \(s\) is a scale hyper-parameter and \(n\) is the number of input units in the weight tensor (e.g. fan-in).\(^6\)

As an additional remedy to the instability, we recommend reducing the default Transformer initialization scale \(s = 1.0\) by a factor of 10. This both improves quality and reduces the likelihood of destabilized training in our experiments. Table 3 measures the improvement of the model quality and reduction of the variance early in training. We find that the average model quality, as measured by the Neg. Log Perp., is dramatically improved and there is a far reduced variance across runs. Further, this same initialization scheme is broadly effective for models spanning several orders of magnitude. We use the same approach to stably train models as small as our 223M parameter baseline to enormous models in excess of one trillion parameters.
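
A sketch of this initializer under the stated assumptions (truncation at two standard deviations, default scale \(s = 1.0\) reduced by a factor of 10); the helper name is mine.

```python
import jax
import jax.numpy as jnp

def switch_init(key, shape, scale=0.1):
    """Truncated normal init: mean 0, std sqrt(scale / fan_in), truncated at 2 sigma.

    scale=0.1 is the default Transformer scale s=1.0 reduced by a factor of 10.
    """
    fan_in = shape[0]                                    # number of input units of the weight tensor
    std = jnp.sqrt(scale / fan_in)
    # Standard truncated normal on [-2, 2], then scaled to the target std.
    return std * jax.random.truncated_normal(key, -2.0, 2.0, shape)

w_in = switch_init(jax.random.PRNGKey(0), (768, 3072))  # e.g. a d_model x d_ff FFN weight
```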

Table 3: Reduced initialization scale improves stability. Reducing the initialization scale results in better model quality and more stable training of Switch Transformer. Here we record the average and standard deviation of model quality, measured by the negative log perplexity, of a 32 expert model after 3.5k steps (3 random seeds each).

6 Values greater than two standard deviations from the mean are resampled.

Regularizing large sparse models. Our paper considers the common NLP approach of pre-training on a large corpus followed by fine-tuning on smaller downstream tasks such as summarization or question answering. One issue that naturally arises is overfitting since many fine-tuning tasks have very few examples. During fine-tuning of standard Transformers, Raffel et al. (2019) use dropout (Srivastava et al., 2014) at each layer to prevent overfitting. Our Switch Transformers have significantly more parameters than the FLOP matched dense baseline, which can lead to more severe overfitting on these smaller downstream tasks.

Table 4: Fine-tuning regularization results. A sweep of dropout rates while fine-tuning Switch Transformer models pre-trained on 34B tokens of the C4 data set (higher numbers are better). We observe that using a lower standard dropout rate at all non-expert layers, with a much larger dropout rate on the expert feed-forward layers, performs the best.

We thus propose a simple way to alleviate this issue during fine-tuning: increase the dropout inside the experts, which we name as expert dropout. During fine-tuning we simply increase the dropout rate by a significant amount only at the interim feed-forward computation at each expert layer. Table 4 has the results for our expert dropout protocol. We observe that simply increasing the dropout across all layers leads to worse performance. However, setting a smaller dropout rate (0.1) at non-expert layers and a much larger dropout rate (0.4) at expert layers leads to performance improvements on four smaller downstream tasks.
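
A minimal sketch of where the expert dropout of Table 4 is applied: a higher rate (0.4) only on the intermediate activation inside each expert's feed-forward computation, with the standard rate (0.1) kept elsewhere. The dropout helper is a simplified stand-in, not the framework's own.

```python
import jax
import jax.numpy as jnp

def dropout(key, x, rate, train=True):
    # Inverted dropout: zero activations with probability `rate` during training.
    if not train or rate == 0.0:
        return x
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

def expert_ffn(key, x, w_in, w_out, expert_dropout=0.4):
    # Expert dropout (0.4) is applied only to the interim feed-forward computation.
    h = jax.nn.relu(x @ w_in)
    h = dropout(key, h, expert_dropout)
    return h @ w_out

# Non-expert layers keep the smaller standard rate, e.g. dropout(key, x, 0.1).
```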

3. Scaling Properties

We present a study of the scaling properties of the Switch Transformer architecture during pre-training. Per Kaplan et al. (2020), we consider a regime where the model is not bottlenecked by either the computational budget or amount of data. To avoid the data bottleneck, we use the large C4 corpus with over 180B target tokens (Raffel et al., 2019) and we train until diminishing returns are observed.

The number of experts is the most efficient dimension for scaling our model. Increasing the experts keeps the computational cost approximately fixed since the model only selects one expert per token, regardless of the number of experts to choose from. The router must compute a probability distribution over more experts, however, this is a lightweight computation of cost \(O(d_\text{model} \times \text{num experts})\) where \(d_\text{model}\) is the embedding dimension of tokens passed between the layers. In this section, we consider the scaling properties on a step-basis and a time-basis with a fixed computational budget.

3.1 Scaling Results on a Step-Basis

Figure 4 demonstrates consistent scaling benefits with the number of experts when training all models for a fixed number of steps. We observe a clear trend: when keeping the FLOPS per token fixed, having more parameters (experts) speeds up training. The left Figure demonstrates consistent scaling properties (with fixed FLOPS per token) between sparse model parameters and test loss. This reveals the advantage of scaling along this additional axis of sparse model parameters. Our right Figure measures sample efficiency of a dense model variant and four FLOP-matched sparse variants. We find that increasing the number of experts leads to more sample efficient models. Our Switch-Base 64 expert model achieves at step 60k the same performance that the T5-Base model reaches at step 450k, which is a 7.5x speedup in terms of step time. In addition, consistent with the findings of Kaplan et al. (2020), we find that larger models are also more sample efficient—learning more quickly for a fixed number of observed tokens.

Figure 4: Scaling properties of the Switch Transformer. Left Plot: We measure the quality improvement, as measured by perplexity, as the parameters increase by scaling the number of experts. The top-left point corresponds to the T5-Base model with 223M parameters. Moving from top-left to bottom-right, we double the number of experts from 2, 4, 8 and so on until the bottom-right point of a 256 expert model with 14.7B parameters. Despite all models using an equal computational budget, we observe consistent improvements scaling the number of experts. Right Plot: Negative log perplexity per step sweeping over the number of experts. The dense baseline is shown with the purple line and we note improved sample efficiency of our Switch-Base models.

3.2 Scaling Results on a Time-Basis

Figure 4 demonstrates that on a step basis, as we increase the number of experts, the performance consistently improves. While our models have roughly the same amount of FLOPS per token as the baseline, our Switch Transformers incurs additional communication costs across devices as well as the extra computation of the routing mechanism. Therefore, the increased sample efficiency observed on a step-basis doesn’t necessarily translate to a better model quality as measured by wall-clock.

This raises the question: For a fixed training duration and computational budget, should one train a dense or a sparse model?

Figure 5: Speed advantage of Switch Transformer. All models trained on 32 TPUv3 cores with equal FLOPs per example. For a fixed amount of computation and training time, Switch Transformers significantly outperform the dense Transformer baseline. Our 64 expert Switch-Base model achieves the same quality in one-seventh the time of the T5-Base and continues to improve.

Figures 5 and 6 address this question. Figure 5 measures the pre-training model quality as a function of time. For a fixed training duration and computational budget, Switch Transformers yield a substantial speed-up. In this setting, our Switch-Base 64 expert model trains in one-seventh the time that it would take the T5-Base to get similar perplexity.

3.3 Scaling Versus a Larger Dense Model

The above analysis shows that a computationally-matched dense model is outpaced by its Switch counterpart. Figure 6 considers a different scenario: what if we instead had allocated our resources to a larger dense model? We do so now, measuring Switch-Base against the next strong baseline, T5-Large. But despite T5-Large applying 3.5x more FLOPs per token, Switch-Base is still more sample efficient and yields a 2.5x speedup. Furthermore, more gains can be had simply by designing a new, larger sparse version, Switch-Large, which is FLOP-matched to T5-Large. We do this and demonstrate superior scaling and fine-tuning in the following section.

Figure 6: Scaling Transformer models with Switch layers or with standard dense model scaling. Left Plot: Switch-Base is more sample efficient than both the T5-Base, and T5-Large variant, which applies 3.5x more FLOPS per token. Right Plot: As before, on a wall-clock basis, we find that Switch-Base is still faster, and yields a 2.5x speedup over T5-Large.

Section 3 demonstrated the superior scaling properties while pre-training, but we now validate that these gains translate to improved language learning abilities on downstream tasks. We begin by fine-tuning on a diverse set of NLP tasks. Next we study reducing the memory footprint of our sparse models by over 90% by distilling into small—and easily deployed—dense baselines. Finally, we conclude this section measuring the improvements in a multi-task, multilingual setting, where we show that Switch Transformers are strong multi-task learners, improving over the multilingual T5-base model across all 101 languages.

4 Downstream Results

4.1 Fine-Tuning

Baseline and Switch models used for fine-tuning. Our baselines are the highly-tuned 223M parameter T5-Base model and the 739M parameter T5-Large model (Raffel et al., 2019). For both versions, we design a FLOP-matched Switch Transformer, with many more parameters, which is summarized in Table 9.\(^7\) Our baselines differ slightly from those in Raffel et al. (2019) because we pre-train on an improved C4 corpus which removes intra-example text duplication and thus increases the efficacy as a pre-training task (Lee et al., 2021). In our protocol we pre-train with \(2^{20}\) (1,048,576) tokens per batch for 550k steps amounting to 576B total tokens. We then fine-tune across a diverse set of tasks using a dropout rate of 0.1 for all layers except the Switch layers, which use a dropout rate of 0.4 (see Table 4). We fine-tune using a batch-size of 1M for 16k steps and for each task, we evaluate model quality every 200-steps and report the peak performance as computed on the validation set.

7 FLOPS are calculated for the forward pass as done in Kaplan et al. (2020).

Fine-tuning tasks and data sets. We select tasks probing language capabilities including question answering, summarization and knowledge about the world. The language benchmarks GLUE (Wang et al., 2018) and SuperGLUE (Wang et al., 2019) are handled as composite mixtures with all the tasks blended in proportion to the amount of tokens present in each. These benchmarks consist of tasks requiring sentiment analysis (SST2), word sense disambiguation (WIC), sentence similarity (MRPC, STS-B, QQP), natural language inference (MNLI, QNLI, RTE, CB), question answering (MultiRC, RECORD, BoolQ), coreference resolution (WNLI, WSC) and sentence completion (COPA) and sentence acceptability (CoLA). The CNNDM (Hermann et al., 2015) and BBC XSum (Narayan et al., 2018) data sets are used to measure the ability to summarize articles. Question answering is probed with the SQuAD data set (Rajpurkar et al., 2016) and the ARC Reasoning Challenge (Clark et al., 2018). And as in Roberts et al. (2020), we evaluate the knowledge of our models by fine-tuning on three closed-book question answering data sets: Natural Questions (Kwiatkowski et al., 2019), Web Questions (Berant et al., 2013) and Trivia QA (Joshi et al., 2017). Closed-book refers to questions posed with no supplemental reference or context material. To gauge the model’s common sense reasoning we evaluate it on the Winogrande Schema Challenge (Sakaguchi et al., 2020). And finally, we test our model’s natural language inference capabilities on the Adversarial NLI Benchmark (Nie et al., 2019).

Fine-tuning metrics. The following evaluation metrics are used throughout the paper: We report the average scores across all subtasks for GLUE and SuperGLUE. The Rouge-2 metric is used for both the CNNDM and XSum. In SQuAD and the closed book tasks (Web, Natural, and Trivia Questions) we report the percentage of answers exactly matching the target (refer to Roberts et al. (2020) for further details and deficiency of this measure). Finally, in ARC Easy, ARC Challenge, ANLI, and Winogrande we report the accuracy of the generated responses.

Fine-tuning results. We observe significant downstream improvements across many natural language tasks. Notable improvements come from SuperGLUE, where we find FLOP-matched Switch variants improve by 4.4 and 2 percentage points over the T5-Base and T5-Large baselines, respectively as well as large improvements in Winogrande, closed book Trivia QA, and XSum.8 In our fine-tuning study, the only tasks where we do not observe gains are on the AI2 Reasoning Challenge (ARC) data sets where the T5-Base outperforms Switch-Base on the challenge data set and T5-Large outperforms Switch-Large on the easy data set. Taken as a whole, we observe significant improvements spanning both reasoning and knowledge-heavy tasks. This validates our architecture, not just as one that pre-trains well, but can translate quality improvements to downstream tasks via fine-tuning.

8 Our T5 and Switch models were pre-trained with \(2^{20}\) tokens per batch for 550k steps on a revised C4 data set for fair comparisons.

Table 5: Fine-tuning results. Fine-tuning results of T5 baselines and Switch models across a diverse set of natural language tests (validation sets; higher numbers are better). We compare FLOP-matched Switch models to the T5-Base and T5-Large baselines. For most tasks considered, we find significant improvements of the Switch variants. We observe gains across both model sizes and across both reasoning and knowledge-heavy language tasks.

4.2 Distillation

Deploying massive neural networks with billions, or trillions, of parameters is inconvenient. To alleviate this, we study distilling (Hinton et al., 2015) large sparse models into small dense models. Future work could additionally study distilling large models into smaller sparse models.

Distillation techniques.

In Table 6 we study a variety of distillation techniques. These techniques are built off of Sanh et al. (2019), who study distillation methods for BERT models. We find that initializing the dense model with the non-expert weights yields a modest improvement. This is possible since all models are FLOP matched, so non-expert layers will have the same dimensions. Since expert layers are usually only added at every or every other FFN layer in a Transformer, this allows for many of the weights to be initialized with trained parameters. Furthermore, we observe a distillation improvement using a mixture of 0.25 for the teacher probabilities and 0.75 for the ground truth label. By combining both techniques we preserve ≈ 30% of the quality gains from the larger sparse models with only ≈ 1/20th of the parameters. The quality gain refers to the percent of the quality difference between Switch-Base (Teacher) and T5-Base (Student). Therefore, a quality gain of 100% implies the Student equals the performance of the Teacher.
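
The paper does not spell out the exact loss form, so the sketch below assumes the standard soft/hard mixture: cross-entropy against the teacher's softmax weighted 0.25, plus cross-entropy against the ground-truth label weighted 0.75. Names are illustrative.

```python
import jax
import jax.numpy as jnp

def distillation_loss(student_logits, teacher_logits, labels,
                      teacher_weight=0.25, label_weight=0.75):
    """Mix the teacher's soft targets with the ground-truth one-hot label."""
    vocab = student_logits.shape[-1]
    log_p_student = jax.nn.log_softmax(student_logits, axis=-1)
    # Soft loss: cross-entropy against the teacher's probability distribution.
    soft = -jnp.sum(jax.nn.softmax(teacher_logits, axis=-1) * log_p_student, axis=-1)
    # Hard loss: standard cross-entropy against the ground-truth token.
    hard = -jnp.sum(jax.nn.one_hot(labels, vocab) * log_p_student, axis=-1)
    return jnp.mean(teacher_weight * soft + label_weight * hard)
```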

Table 6: Distilling Switch Transformers for Language Modeling. Initializing T5-Base with the non-expert weights from Switch-Base and using a loss from a mixture of teacher and ground-truth labels obtains the best performance. We can distill 30% of the performance improvement of a large sparse model with 100x more parameters back into a small dense model. For a final baseline, we find no improvement of T5-Base initialized with the expert weights, but trained normally without distillation.

Achievable compression rates. Using our best distillation technique described in Table 6, we distill a wide variety of sparse models into dense models. We distill Switch-Base versions, sweeping over an increasing number of experts, which corresponds to varying between 1.1B to 14.7B parameters. Through distillation, we can preserve 37% of the quality gain of the 1.1B parameter model while compressing 82%. At the extreme, where we compress the model 99%, we are still able to maintain 28% of the teacher’s model quality improvement.

Distilling a fine-tuned model. We conclude this with a study of distilling a finetuned sparse model into a dense model. Table 8 shows results of distilling a 7.4B parameter Switch-Base model, fine-tuned on the SuperGLUE task, into the 223M T5-Base. Similar to our pre-training results, we find we are able to preserve 30% of the gains of the sparse model when distilling into a FLOP matched dense variant. One potential future avenue, not considered here, may examine the specific experts being used for fine-tuning tasks and extracting them to achieve better model compression.

4.3 Multilingual Learning

In our final set of downstream experiments, we measure the model quality and speed tradeoffs while pre-training on a mixture of 101 different languages. We build and benchmark off the recent work of mT5 (Xue et al., 2020), a multilingual extension to T5. We pre-train on the multilingual variant of the Common Crawl data set (mC4) spanning 101 languages introduced in mT5, but due to script variants within certain languages, the mixture contains 107 tasks.

In Figure 7 we plot the quality improvement in negative log perplexity for all languages of a FLOP-matched Switch model, mSwitch-Base to the T5 base variant, mT5-Base. After pre-training both versions for 1M steps, we find that on all 101 languages considered, Switch Transformer increases the final negative log perplexity over the baseline. In Figure 8, we present a different view and now histogram the per step speed-up of using Switch Transformer over the mT5-Base.9 We find a mean speed-up over mT5-Base of 5x and that 91% of languages achieve at least a 4x speedup. This presents evidence that Switch Transformers are effective multi-task and multi-lingual learners.

Table 7: Distillation compression rates. We measure the quality when distilling large sparse models into a dense baseline. Our baseline, T5-Base, has a -1.636 Neg. Log Perp. quality. In the right columns, we then distill increasingly large sparse models into this same architecture. Through a combination of weight-initialization and a mixture of hard and soft losses, we can shrink our sparse teachers by 95%+ while preserving 30% of the quality gain. However, for significantly better and larger pre-trained teachers, we expect larger student models would be necessary to achieve these compression rates.

Table 8: Distilling a fine-tuned SuperGLUE model. We distill a Switch-Base model finetuned on the SuperGLUE tasks into a T5-Base model. We observe that on smaller data sets our large sparse model can be an effective teacher for distillation. We find that we again achieve 30% of the teacher’s performance on a 97% compressed model.

5. Designing Models with Data, Model, and Expert-Parallelism

Arbitrarily increasing the number of experts is subject to diminishing returns (Figure 4). Here we describe complementary scaling strategies. The common way to scale a Transformer is to increase dimensions in tandem, like \(d_\text{model}\) or \(d_\text{ff}\). This increases both the parameters and computation performed and is ultimately limited by the memory per accelerator. Once it exceeds the size of the accelerator’s memory, single program multiple data (SPMD) model-parallelism can be employed. This section studies the trade-offs of combining data, model, and expert-parallelism.

9 The speedup on a step basis is computed as the ratio of the number of steps for the baseline divided by the number of steps required by our model to reach that same quality.

Figure 7: Multilingual pre-training on 101 languages.

Improvements of Switch T5 Base model over dense baseline when multi-task training on 101 languages. We observe Switch Transformers to do quite well in the multi-task training setup and yield improvements on all 101 languages.

Figure 8: Multilingual pre-training on 101 languages. We histogram for each language, the step speedup of Switch Transformers over the FLOP matched T5 dense baseline to reach the same quality. Over all 101 languages, we achieve a mean step speedup over mT5-Base of 5x and, for 91% of languages, we record a 4x, or greater, speedup to reach the final perplexity of mT5-Base.

Reviewing the Feed-Forward Network (FFN) Layer. We use the FFN layer as an example of how data, model and expert-parallelism works in Mesh TensorFlow (Shazeer et al., 2018) and review it briefly here. We assume \(\mathcal{B}\) tokens in the batch, each of dimension \(d_\text{model}\). Both the input (\(x\)) and output (\(y\)) of the FFN are of size \([\mathcal{B}, d_\text{model}]\) and the intermediate (\(h\)) is of size \([\mathcal{B}, d_\text{ff}]\) where \(d_\text{ff}\) is typically several times larger than \(d_\text{model}\). In the FFN, the intermediate is \(h = xW_\text{in}\) and then the output of the layer is \(y = \text{ReLU}(h)W_\text{out}\). Thus \(W_\text{in}\) and \(W_\text{out}\) are applied independently to each token and have sizes \([d_\text{model}, d_\text{ff}]\) and \([d_\text{ff}, d_\text{model}]\).
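
Restating these shapes as a minimal sketch (illustrative names):

```python
import jax
import jax.numpy as jnp

def ffn(x, w_in, w_out):
    """Dense FFN:  x [B, d_model] -> h [B, d_ff] -> y [B, d_model], per token."""
    h = x @ w_in                    # W_in: [d_model, d_ff]
    return jax.nn.relu(h) @ w_out   # W_out: [d_ff, d_model]
```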

We describe two aspects of partitioning: how the weights and batches of data divide over cores, depicted in Figure 9. We denote all cores available as \(N\) which Mesh Tensorflow may then remap into a logical multidimensional mesh of processors. Here we create a two-dimensional logical mesh, with one dimension representing the number of ways for data-parallel sharding (\(n\)) and the other, the model-parallel sharding (\(m\)). The total cores must equal the ways to shard across both data and model-parallelism, e.g. \(N = n \times m\). To shard the layer across cores, the tensors containing that batch of \(\mathcal{B}\) tokens are sharded across \(n\) data-parallel cores, so each core contains \(\mathcal{B}/n\) tokens. Tensors and variables with \(d_\text{ff}\) are then sharded across \(m\) model-parallel cores. For the variants with experts-layers, we consider \(E\) experts, each of which can process up to \(C\) tokens.

Where:

  • \(\mathcal{B}\) is the number of tokens in the batch.
  • \(N\) is the number of total cores.
  • \(n\) is the number of ways for data-parallelism sharding.
  • \(m\) is the number of ways for model-parallelism sharding.
  • \(E\) is the number of experts in Switch layers.
  • \(C\) is the expert capacity, the batch size of each expert.
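
Tying the notation together, a small, purely illustrative sketch of the per-core shard shapes implied by the strategies discussed in the following subsections (assuming \(N = n \times m\) and, for the expert case, one expert per data-parallel core):

```python
def per_core_shapes(N, n, m, B, d_model, d_ff, E=None, C=None):
    """Per-core shard shapes for the FFN under data/model/expert parallelism."""
    assert N == n * m, "total cores must factor into data (n) x model (m) shards"
    shapes = {
        "tokens per core":      (B // n, d_model),       # batch sharded n ways
        "W_in shard":           (d_model, d_ff // m),     # d_ff sharded m ways
        "intermediate h shard": (B // n, d_ff // m),
    }
    if E is not None and C is not None:
        # Expert parallelism: each of the n data-parallel cores hosts one expert (n == E).
        shapes["dispatch mask"] = (n, B // n, E, C)
        shapes["expert inputs"] = (n, E, C, d_model)
    return shapes

# e.g. 32 cores split 8-way data-parallel x 4-way model-parallel:
print(per_core_shapes(N=32, n=8, m=4, B=4096, d_model=768, d_ff=3072))
```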

5.1 Data Parallelism

When training data parallel models, which is the standard for distributed training, all cores are allocated to the data-parallel dimension, or \(n = N\), \(m = 1\). This has the advantage that no communication is needed until the entire forward and backward pass is finished and the gradients need to be then aggregated across all cores. This corresponds to the left-most column of Figure 9.

5.2 Model Parallelism

We now consider a scenario where all cores are allocated exclusively to the model-parallel dimension and so \(n = 1, m = N\). Now all cores must keep the full \(\mathcal{B}\) tokens and each core will contain a unique slice of the weights. For each forward and backward pass, a communication cost is now incurred. Each core sends a tensor of \([\mathcal{B}, d_\text{model}]\) to compute the second matrix multiplication \(\text{ReLU}(h)W_\text{out}\) because the \(d_\text{ff}\) dimension is partitioned and must be summed over. As a general rule, whenever a dimension that is partitioned across cores must be summed, then an all-reduce operation is added for both the forward and backward pass. This contrasts with pure data parallelism where an all-reduce only occurs at the end of the entire forward and backward pass.

Figure 9: Data and weight partitioning strategies. Each 4×4 dotted-line grid represents 16 cores and the shaded squares are the data contained on that core (either model weights or batch of tokens). We illustrate both how the model weights and the data tensors are split for each strategy. First Row: illustration of how model weights are split across the cores. Shapes of different sizes in this row represent larger weight matrices in the Feed Forward Network (FFN) layers (e.g larger df f sizes). Each color of the shaded squares identifies a unique weight matrix. The number of parameters per core is fixed, but larger weight matrices will apply more computation to each token. Second Row: illustration of how the data batch is split across cores. Each core holds the same number of tokens which maintains a fixed memory usage across all strategies. The partitioning strategies have different properties of allowing each core to either have the same tokens or different tokens across cores, which is what the different colors symbolize.

5.3 Model and Data Parallelism

It is common to mix both model and data parallelism for large scale models, which was done in the largest T5 models (Raffel et al., 2019; Xue et al., 2020) and in GPT-3 (Brown et al., 2020). With a total of \(N = n \times m\) cores, now each core will be responsible for \(\mathcal{B}/n\) tokens and \(d_\text{ff}/m\) of both the weights and intermediate activation. In the forward and backward pass each core communicates a tensor of size \([\mathcal{B}/n, d_\text{model}]\) in an all-reduce operation.

5.4 Expert and Data Parallelism

Next we describe the partitioning strategy for expert and data parallelism. Switch Transformers will allocate all of their cores to the data partitioning dimension \(n\), which will also correspond to the number of experts in the model. For each token per core a router locally computes assignments to the experts. The output is a binary matrix of size \([n, \mathcal{B}/n, E, C]\) which is partitioned across the first dimension and determines expert assignment. This binary matrix is then used to do a gather via matrix multiplication with the input tensor of \([n, \mathcal{B}/n, d_\text{model}]\) resulting in the final tensor of shape \([n, E, C, d_\text{model}]\), which is sharded across the first dimension. Because each core has its own expert, we do an all-to-all communication of size \([E, C, d_\text{model}]\) to now shard the \(E\) dimension instead of the \(n\)-dimension. There are additional communication costs of \(\mathbf{bfloat16}\) tensors of size \(E \times C \times d_\text{model}\) in the forward pass to analogously receive the tokens from each expert located on different cores. See Appendix F for a detailed analysis of the expert partitioning code.

Where:

  • \(n\) is the number of ways for data-parallelism sharding.
  • \(\mathcal{B}\) is the number of tokens in the batch.
  • \(E\) is the number of experts.
  • \(C\) is the expert capacity, the batch size of each expert.
  • \(d_\text{model}\) is the model dimension size.
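
Below is a minimal, single-process sketch of the dispatch step described above: a top-1 router choice per token is turned into a binary dispatch tensor \([n, \mathcal{B}/n, E, C]\), and an einsum gather produces the \([n, E, C, d_\text{model}]\) tensor that the all-to-all would then re-shard. Capacity handling is simplified (overflowing tokens are silently dropped) and all sizes are toy values, not the paper's.

```python
import numpy as np

# Toy sizes: n data shards (one expert per shard), B tokens, expert capacity C.
n = E = 4
B, d_model, C = 32, 8, 4
tokens_per_shard = B // n

rng = np.random.default_rng(0)
x = rng.normal(size=(n, tokens_per_shard, d_model))              # [n, B/n, d_model]
expert_choice = rng.integers(0, E, size=(n, tokens_per_shard))   # top-1 router choice

# Build the binary dispatch tensor [n, B/n, E, C]; tokens beyond capacity are dropped here.
dispatch = np.zeros((n, tokens_per_shard, E, C))
position_in_expert = np.zeros((n, E), dtype=int)
for shard in range(n):
    for t in range(tokens_per_shard):
        e = expert_choice[shard, t]
        pos = position_in_expert[shard, e]
        if pos < C:
            dispatch[shard, t, e, pos] = 1.0
            position_in_expert[shard, e] += 1

# Gather via matrix multiplication: [n, B/n, E, C] x [n, B/n, d_model] -> [n, E, C, d_model].
expert_inputs = np.einsum("stec,std->secd", dispatch, x)
print(expert_inputs.shape)   # (4, 4, 4, 8); an all-to-all would now re-shard the E dimension
```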

5.5 Expert, Model and Data Parallelism

In the design of our best model, we seek to balance the FLOPs per token and the parameter count. When we scale the number of experts \(E\), we increase the number of parameters, but do not change the FLOPs per token. In order to increase FLOPs, we must also increase the \(d_\text{ff}\) dimension (which also increases parameters, but at a slower rate). This presents a trade-off: as we increase \(d_\text{ff}\) we will run out of memory per core, which then necessitates increasing \(m\). But since we have a fixed number of cores \(N\), and \(N = n \times m\), we must decrease \(n\), which forces use of a smaller batch size (in order to hold tokens per core constant).

Where:

  • \(E\) is the number of experts
  • \(d_\text{ff}\) is the feed-forward dimension
  • \(N\) is the total number of cores
  • \(n\) is the number of ways for data-parallelism sharding
  • \(m\) is the number of ways for model-parallelism sharding

The key trade-off is that increasing \(d_\text{ff}\) to gain more FLOPs requires increasing \(m\), which in turn forces a decrease in \(n\) (since \(N\) is fixed) and therefore a smaller batch size per core.
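
A toy calculation under an assumed per-core weight-memory budget (the budget and all sizes are invented for illustration) shows how growing \(d_\text{ff}\) pushes \(m\) up and therefore forces \(n\), and the global batch, down:

```python
# Back-of-the-envelope view of the trade-off described above (illustrative numbers only).
N = 64                             # fixed number of cores
tokens_per_core = 128              # held constant to keep per-core activation memory fixed
max_ff_params_per_core = 2 ** 24   # assumed per-core weight-memory budget (toy value)
d_model = 4096

for d_ff in (8192, 16384, 32768):
    # FFN weights are roughly 2 * d_model * d_ff parameters; m shards split the d_ff axis.
    m = 1
    while 2 * d_model * (d_ff // m) > max_ff_params_per_core:
        m *= 2
    n = N // m                          # fewer data-parallel ways as m grows
    global_batch = n * tokens_per_core  # so the global batch shrinks
    print(f"d_ff={d_ff:6d}  m={m:2d}  n={n:2d}  global_batch={global_batch}")
# d_ff=  8192  m= 4  n=16  global_batch=2048
# d_ff= 16384  m= 8  n= 8  global_batch=1024
# d_ff= 32768  m=16  n= 4  global_batch=512
```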

When combining both model and expert-parallelism, we will have all-to-all communication costs from routing the tokens to the correct experts along with the internal all-reduce communications from the model parallelism. Balancing the FLOPs, communication costs and memory per core becomes quite complex when combining all three methods, and the best mapping is empirically determined. See our further analysis in Section 5.6 for how the number of experts affects downstream performance as well.

5.6 Towards Trillion Parameter Models

Combining expert, model and data parallelism, we design two large Switch Transformer models, with 395 billion and 1.6 trillion parameters, respectively. We study how these models perform on both upstream pre-training as language models and downstream fine-tuning. The parameters, FLOPs per sequence and hyper-parameters of the two different models are listed below in Table 9. Standard hyper-parameters of the Transformer, including \(d_\text{model}\), \(d_\text{ff}\), \(d_\text{kv}\), number of heads and number of layers are described, as well as a less common feature, \(\text{FFN}_\text{GEGLU}\), which refers to a variation of the FFN layer where the expansion matrix is substituted with two sets of weights which are non-linearly combined (Shazeer, 2020).

The Switch-C model is designed using only expert-parallelism, and no model-parallelism, as described earlier in Section 5.4. As a result, the hyper-parameters controlling the width, depth, number of heads, and so on, are all much smaller than the T5-XXL model. In contrast, the Switch-XXL is FLOP-matched to the T5-XXL model, which allows for larger dimensions of the hyper-parameters, but at the expense of additional communication costs induced by model-parallelism (see Section 5.5 for more details).

Table 9: Switch model design and pre-training performance. We compare the hyperparameters and pre-training performance of the T5 models to our Switch Transformer variants. The last two columns record the pre-training model quality on the C4 data set after 250k and 500k steps, respectively. We observe that the Switch-C Transformer variant is 4x faster to a fixed perplexity (with the same compute budget) than the T5-XXL model, with the gap increasing as training progresses.

Sample efficiency versus T5-XXL. In the final two columns of Table 9 we record the negative log perplexity on the C4 corpus after 250k and 500k steps, respectively. After 250k steps, we find both Switch Transformer variants improve over the T5-XXL version’s negative log perplexity by over 0.061 (see footnote 10 below). To contextualize the significance of a gap of 0.061, we note that the T5-XXL model had to train for an additional 250k steps to increase by 0.052. The gap continues to increase with additional training, with the Switch-XXL model out-performing the T5-XXL by 0.087 by 500k steps.

Training instability. However, as described in the introduction, large sparse models can be unstable, and as we increase the scale, we encounter some sporadic issues. We find that the larger Switch-C model, with 1.6T parameters and 2048 experts, exhibits no training instability at all. Instead, the Switch XXL version, with nearly 10x larger FLOPs per sequence, is sometimes unstable. As a result, though this is our better model on a step-basis, we do not pre-train for a full 1M steps, in-line with the final reported results of T5 (Raffel et al., 2019).

Footnote 10: This reported quality difference is a lower bound, and may actually be larger. The T5-XXL was pre-trained on an easier C4 data set which included duplicated, and thus easily copied, snippets within examples.

Reasoning fine-tuning performance. As a preliminary assessment of the model quality, we use a Switch-XXL model partially pre-trained on 503B tokens, or approximately half the text used by the T5-XXL model. Using this checkpoint, we conduct multi-task training for efficiency, where all tasks are learned jointly, rather than individually fine-tuned. We find that SQuAD accuracy on the validation set increases to 89.7, versus a state-of-the-art of 91.3. Next, the average SuperGLUE test score is recorded at 87.5, versus 89.3 for the T5 version and a state-of-the-art of 90.0 (Wang et al., 2019). On ANLI (Nie et al., 2019), Switch-XXL improves over the prior state-of-the-art to reach 65.7 accuracy versus the prior best of 49.4 (Yang et al., 2020). We note that while the Switch-XXL has state-of-the-art negative log perplexity on the upstream pre-training task, its gains have not yet fully translated to SOTA downstream performance. We study this issue more in Appendix E.

Knowledge-based fine-tuning performance. Finally, we also conduct an early examination of the model’s knowledge with three closed-book knowledge-based tasks: Natural Questions, WebQuestions and TriviaQA, without additional pre-training using Salient Span Masking (Guu et al., 2020). In all three cases, we observe improvements over the prior state-of-the-art T5-XXL model (without SSM). Natural Questions exact match increases to 34.4 versus the prior best of 32.8, WebQuestions increases to 41.0 over 37.2, and TriviaQA increases to 47.5 versus 42.9.

Summing up, despite training on less than half the data of other models, we already find comparable, and sometimes state-of-the-art, model quality. Currently, the Switch Transformer translates its substantial upstream gains to knowledge-based tasks better than to reasoning tasks (see Appendix E). Extracting stronger fine-tuning performance from large expert models is an active research question, and the pre-training perplexity indicates future improvements should be possible.

6. Related Work

The importance of scale in neural networks is widely recognized and several approaches have been proposed. Recent works have scaled models to billions of parameters through model parallelism (e.g. splitting weights and tensors across multiple cores) (Shazeer et al., 2018; Rajbhandari et al., 2019; Raffel et al., 2019; Brown et al., 2020; Shoeybi et al., 2019). Alternatively, Harlap et al. (2018); Huang et al. (2019) propose pipeline-based model parallelism, where different layers are split across devices and micro-batches are pipelined to the different layers. Finally, Product Key networks (Lample et al., 2019) were proposed to scale up the capacity of neural networks by doing a lookup for learnable embeddings based on the incoming token representations to a given layer.

Our work studies a specific model in a class of methods that do conditional computation, where computation decisions are made dynamically based on the input. Cho and Bengio (2014) proposed adaptively selecting weights based on certain bit patterns occurring in the model hidden-states. Eigen et al. (2013) built stacked expert layers with dense matrix multiplications and ReLU activations and showed promising results on jittered MNIST and monotone speech. In computer vision Puigcerver et al. (2020) manually route tokens based on semantic classes during upstream pre-training and then select the relevant experts to be used according to the downstream task.

Mixture of Experts (MoE), in the context of modern deep learning architectures, was proven effective in Shazeer et al. (2017). That work added an MoE layer which was stacked between LSTM (Hochreiter and Schmidhuber, 1997) layers, and tokens were separately routed to combinations of experts. This resulted in state-of-the-art results in language modeling and machine translation benchmarks. The MoE layer was reintroduced into the Transformer architecture by the Mesh TensorFlow library (Shazeer et al., 2018), where MoE layers were introduced as a substitute for the FFN layers; however, there were no accompanying NLP results. More recently, through advances in machine learning infrastructure, GShard (Lepikhin et al., 2020), which extended the XLA compiler, used the MoE Transformer to dramatically improve machine translation across 100 languages. Finally, Fan et al. (2021) choose a different deterministic MoE strategy to split the model parameters into non-overlapping groups of languages.

Sparsity along the sequence length dimension (\(L\)) in the Transformer attention patterns has been a successful technique to reduce the attention complexity from \(O(L^2)\) (Child et al., 2019; Correia et al., 2019; Sukhbaatar et al., 2019; Kitaev et al., 2020; Zaheer et al., 2020; Beltagy et al., 2020). This has enabled learning longer sequences than previously possible. This version of the Switch Transformer does not employ attention sparsity, but these techniques are complementary, and, as future work, they could be combined to potentially improve learning on tasks requiring long contexts.

7. Discussion

We pose and discuss questions about the Switch Transformer, and sparse expert models generally, where sparsity refers to the weights, not the attention patterns.

Isn’t Switch Transformer better due to sheer parameter count? Yes, and by design! Parameters, independent of the total FLOPs used, are a useful axis to scale neural language models. Large models have been exhaustively shown to perform better (Kaplan et al., 2020). But in this case, our model is more sample efficient and faster while using the same computational resources.

I don’t have access to a supercomputer—is this still useful for me? Though this work has focused on extremely large models, we also find that models with as few as two experts improve performance while easily fitting within the memory constraints of commonly available GPUs or TPUs (details in Appendix D). We therefore believe our techniques are useful in small-scale settings.

Do sparse models outperform dense models on the speed-accuracy Pareto curve? Yes. Across a wide variety of different model sizes, sparse models outperform dense models per step and on wall-clock time. Our controlled experiments show that for a fixed amount of computation and time, sparse models outperform dense models.

I can’t deploy a trillion parameter model—can we shrink these models? We cannot fully preserve the model quality, but compression rates of 10 to 100x are achievable by distilling our sparse models into dense models while achieving ≈30% of the quality gain of the expert model.

Why use Switch Transformer instead of a model-parallel dense model? On a time basis, Switch Transformers can be far more efficient than dense models with sharded parameters (Figure 6). Also, we point out that this decision is not mutually exclusive: we can, and do, use model-parallelism in Switch Transformers, increasing the FLOPs per token, but incurring the slowdown of conventional model-parallelism.

Why aren’t sparse models widely used already? The motivation to try sparse models has been stymied by the massive success of scaling dense models (the success of which is partially driven by co-adaptation with deep learning hardware as argued in Hooker (2020)). Further, sparse models have been subject to multiple issues including (1) model complexity, (2) training difficulties, and (3) communication costs. Switch Transformer makes strides to alleviate these issues.

8. Future Work

This paper lays out a simplified architecture, improved training procedures, and a study of how sparse models scale. However, there remain many open future directions which we briefly describe here:

  1. A significant challenge is further improving training stability for the largest models. While our stability techniques were effective for our Switch-Base, Switch-Large and Switch-C models (no observed instability), they were not sufficient for Switch-XXL. We have taken early steps towards stabilizing these models, which we think may be generally useful for large models, including using regularizers for improving stability and adapted forms of gradient clipping, but this remains unsolved.
  2. Generally we find that improved pre-training quality leads to better downstream results (Appendix E), though we sometimes encounter striking anomalies. For instance, despite similar perplexities modeling the C4 data set, the 1.6T parameter Switch-C achieves only an 87.7 exact match score on SQuAD, which compares unfavorably to 89.6 for the smaller Switch-XXL model. One notable difference is that the Switch-XXL model applies ≈10x the FLOPs per token of the Switch-C model, even though it has ≈4x fewer unique parameters (395B vs 1.6T). This suggests a poorly understood dependence between fine-tuning quality, FLOPs per token and number of parameters.
  3. Perform a comprehensive study of scaling relationships to guide the design of architectures blending data, model and expert-parallelism. Ideally, given the specs of a hardware configuration (computation, memory, communication) one could more rapidly design an optimal model. And, vice versa, this may also help in the design of future hardware.
  4. Our work falls within the family of adaptive computation algorithms. Our approach always used identical, homogeneous experts, but future designs (facilitated by more flexible infrastructure) could support heterogeneous experts. This would enable more flexible adaptation by routing to larger experts when more computation is desired— perhaps for harder examples.
  5. Investigating expert layers outside the FFN layer of the Transformer. We find preliminary evidence that this can similarly improve model quality. In Appendix A, we report quality improvements from adding these inside Self-Attention layers, where our layer replaces the weight matrices which produce Q, K, V. However, due to training instabilities with the bfloat16 format, we instead leave this as an area for future work.
  6. Examining Switch Transformer in new modalities and across modalities. We have thus far only considered language, but we believe that model sparsity can similarly provide advantages in new modalities, as well as multi-modal networks.

9. Conclusion

Switch Transformers are scalable and effective natural language learners. We simplify Mixture of Experts to produce an architecture that is easy to understand, stable to train and vastly more sample efficient than equivalently-sized dense models. We find that these models excel across a diverse set of natural language tasks and in different training regimes, including pre-training, fine-tuning and multi-task training. These advances make it possible to train models with hundreds of billions to over a trillion parameters that achieve substantial speedups relative to dense T5 baselines. We hope our work motivates sparse models as an effective architecture and that this encourages researchers and practitioners to consider these flexible models in natural language tasks, and beyond.

Appendix

A. Switch for Attention

Shazeer et al. (2018); Lepikhin et al. (2020) designed MoE Transformers (Shazeer et al., 2017) by adding MoE layers into the dense feedforward network (FFN) computations of the Transformer. Similarly, our work also replaced the FFN layer in the Transformer, but we briefly explore here an alternate design. We add Switch layers into the Transformer Self-Attention layers. To do so, we replace the trainable weight matrices that produce the queries, keys and values with Switch layers, as seen in Figure 10.
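
As a rough sketch of what such a layer might look like (not the exact configuration in Figure 10, and ignoring expert capacity and load-balancing losses), a top-1 switch layer can stand in for an attention projection as follows; all names and sizes are illustrative:

```python
import numpy as np

def switch_projection(x, router_w, expert_ws):
    """Route each token to one expert's projection matrix (top-1 switch routing).

    x:         [tokens, d_model]
    router_w:  [d_model, n_experts]
    expert_ws: [n_experts, d_model, d_out]
    Returns:   [tokens, d_out], gated by the router probability as in switch routing.
    """
    logits = x @ router_w
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    top1 = probs.argmax(axis=-1)
    gate = probs[np.arange(x.shape[0]), top1]
    out = np.einsum("td,tdo->to", x, expert_ws[top1])   # per-token expert matmul
    return gate[:, None] * out

# Toy usage: switch layers in place of the Q/K/V projections of one attention head.
rng = np.random.default_rng(0)
tokens, d_model, d_head, n_experts = 6, 16, 8, 4
x = rng.normal(size=(tokens, d_model))
q = switch_projection(x, rng.normal(size=(d_model, n_experts)),
                      rng.normal(size=(n_experts, d_model, d_head)))
k = switch_projection(x, rng.normal(size=(d_model, n_experts)),
                      rng.normal(size=(n_experts, d_model, d_head)))
v = switch_projection(x, rng.normal(size=(d_model, n_experts)),
                      rng.normal(size=(n_experts, d_model, d_head)))
print(q.shape, k.shape, v.shape)   # (6, 8) each
```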

Table 10 records the quality after a fixed number of steps as well as training time for several variants. Though we find improvements, we also found these layers to be more unstable when using bfloat16 precision and thus we did not include them in the final variant.

Figure 10: Switch layers in attention. We diagram how to incorporate the Switch layer into the Self-Attention transformer block. For each token (here we show two tokens, x1 = “More” and x2 = “Parameters”), one set of weights produces the query and the other set of unique weights produces the shared keys and values. We experimented with each expert being a linear operation, as well as a FFN, as was the case throughout this work. While we found quality improvements using this, we found this to be more unstable when used with low precision number formats, and thus leave it for future work.

However, when these layers do train stably, we believe the preliminary positive results suggest a promising future direction.

Table 10: Switch attention layer results. All models have 32 experts and train with 524k tokens per batch. Experts FF is when experts replace the FFN in the Transformer, which is our standard setup throughout the paper. Experts FF + Attention is when experts are used to replace both the FFN and the Self-Attention layers. When training with bfloat16 precision, the models with expert attention layers diverge.

B. Preventing Token Dropping with No-Token-Left-Behind

Due to software constraints on TPU accelerators, the shapes of our Tensors must be statically sized. As a result, each expert has a finite and fixed capacity to process token representations. This, however, presents an issue for our model which dynamically routes tokens at run-time and may result in an uneven distribution over experts. If the number of tokens sent to an expert is less than the expert capacity, then the computation may simply be padded – an inefficient use of the hardware, but mathematically correct. However, when the number of tokens sent to an expert is larger than its capacity (expert overflow), a protocol is needed to handle this. Lepikhin et al. (2020) adapt a Mixture-of-Experts model and address expert overflow by passing the token representation to the next layer, without processing it, through the residual connection, an approach which we also follow.

We suspected that having no computation applied to tokens could be very wasteful, especially since, if one expert overflows, another expert must have spare capacity. With this intuition we create No-Token-Left-Behind, which iteratively re-routes any tokens that are at first routed to an expert that is overflowing. Figure 11 shows a graphical description of this method, which allows us to guarantee that almost no tokens are dropped during training and inference. We hypothesized that this could improve performance and further stabilize training, but we found no empirical benefits. We suspect that once the network learns associations between different tokens and experts, if this association is changed (e.g. sending a token to its second highest expert) then performance could be degraded.
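
A minimal sketch of this iterative re-routing, assuming a simple per-token preference ordering and a fixed per-expert capacity (all names and values are illustrative, not the paper's implementation):

```python
import numpy as np

def route_with_reroute(probs, capacity, n_rounds=2):
    """Sketch of No-Token-Left-Behind routing: tokens that overflow their best
    expert are re-routed to their next-best expert, for a few rounds.

    probs:    [tokens, experts] router probabilities
    capacity: max tokens each expert may accept
    Returns:  array of assigned expert ids, with -1 for tokens still dropped.
    """
    tokens, experts = probs.shape
    ranked = np.argsort(-probs, axis=-1)        # experts ordered by preference per token
    load = np.zeros(experts, dtype=int)
    assignment = np.full(tokens, -1, dtype=int)
    for round_idx in range(n_rounds):
        for t in range(tokens):
            if assignment[t] != -1:
                continue
            e = ranked[t, round_idx]            # round 0: top choice, round 1: second, ...
            if load[e] < capacity:
                assignment[t] = e
                load[e] += 1
    return assignment

# Toy usage with made-up probabilities: 8 tokens, 2 experts, capacity 4 each.
rng = np.random.default_rng(0)
p = rng.dirichlet(np.ones(2), size=8)
print(route_with_reroute(p, capacity=4))
```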

C. Encouraging Exploration Across Experts

At each expert-layer, the router determines to which expert to send the token. This is a discrete decision over the available experts, conditioned on information about the token’s representation. Based on the incoming token representation, the router determines the best expert; however, it receives no counterfactual information about how well it would have done selecting an alternate expert. As in reinforcement learning, a classic exploration-exploitation dilemma arises (Sutton and Barto, 2018). These issues have been similarly noted and addressed differently by Rosenbaum et al. (2017), which demonstrated success in multi-task learning. This particular setting most closely matches that of a contextual bandit (Robbins, 1952). Deterministically selecting the top expert always amounts to an exploitative strategy – we consider balancing exploration to seek better expert assignment.

To introduce exploration, we consider several approaches: 1) deterministic selection of the argmax, 2) sampling from the softmax distribution, 3) input dropout on the incoming representation, and 4) multiplicative jitter noise on the incoming representation.

The resulting impact on model quality is reported in Table 11. Throughout this work, we use input jitter to inject noise as we have found it to empirically perform the best.
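
A minimal sketch of multiplicative jitter on the router input, assuming a uniform noise range of \([1-\epsilon, 1+\epsilon]\) applied only at training time (the epsilon value, sizes and names are illustrative):

```python
import numpy as np

def jitter(x, epsilon=1e-2, rng=None):
    """Multiplicative jitter noise on the incoming token representation (a sketch).
    Each element is scaled by a uniform sample from [1 - epsilon, 1 + epsilon]
    before the router logits are computed; used only during training."""
    rng = rng or np.random.default_rng()
    noise = rng.uniform(1.0 - epsilon, 1.0 + epsilon, size=x.shape)
    return x * noise

# The router then sees the jittered representation:
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 16))                 # [tokens, d_model]
W_r = rng.normal(size=(16, 8))               # router weights for 8 experts
logits = jitter(x, epsilon=1e-2, rng=rng) @ W_r
probs = np.exp(logits - logits.max(-1, keepdims=True))
probs /= probs.sum(-1, keepdims=True)        # per-token distribution over experts
```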

D. Switch Transformers in Lower Compute Regimes

Switch Transformer is an effective architecture not only in regimes with thousands of cores and trillions of parameters, but also at small scales. Many of our prior experiments were at the scale of 10B+ parameter models, but we show in Figure 12 that as few as 2 experts produce compelling gains over a FLOP-matched counterpart. Even if a supercomputer is not readily available, training Switch Transformers with 2, 4, or 8 experts (as we typically recommend one expert per core) results in solid improvements over T5 dense baselines.

Figure 11: Diagram of the No-Token-Left-Behind Routing. Stage 1 is equivalent to Switch routing where tokens are routed to the expert with the highest probability from the router. In Stage 2 we look at all tokens that have overflowed and route them to the expert with the second highest probability. Tokens can still overflow if their second highest expert also has too many tokens, but this allows most of the tokens to be routed. This process can be iterated to guarantee virtually no tokens are dropped at all.

Table 11: Router Exploration Strategies. Quality of the Switch Transformer, measured by the negative log perplexity, under different randomness-strategies for selecting the expert (lower is better). There is no material speed performance difference between the variants.

Figure 12: Switch Transformer with few experts. Switch Transformer improves over the baseline even with very few experts. Here we show scaling properties at very small scales, where we improve over the T5-Base model using 2, 4, and 8 experts.

E. Relation of Upstream to Downstream Model Performance

There is no guarantee that a model’s quality on a pre-training objective will translate to downstream task results. Figure 13 presents the correlation of the upstream model quality, for both dense and Switch models, on the C4 pre-training task with two downstream task measures: average SuperGLUE performance and TriviaQA score. We choose these two tasks as one probes the model’s reasoning and the other factual knowledge.

Figure 13: Upstream pre-trained quality to downstream model quality. We correlate the upstream performance with downstream quality on both SuperGLUE and TriviaQA (SOTA recorded without SSM), reasoning and knowledge-heavy benchmarks, respectively (validation sets). We find that, as with the baseline, the Switch model scales with improvements in the upstream pre-training task. For SuperGLUE, we find a loosely linear relation between negative log perplexity and the average SuperGLUE score. However, the dense model often performs better for a fixed perplexity, particularly in the large-scale regime. Conversely, on the knowledge-heavy task, TriviaQA, we find that the Switch Transformer may follow an improved scaling relationship – for a given upstream perplexity, it does better than a dense counterpart. Further statistics (expensive to collect and left to future work) would be necessary to confirm these observations.

We find a consistent correlation, indicating that for both baseline and Switch models, improved pre-training leads to better downstream results. Additionally, for a fixed upstream perplexity we find that both Switch and dense models perform similarly in the small to medium model size regime. However, in the largest model regime (T5-11B/T5-XXL) our largest Switch models, as mentioned in Section 5.6, do not always translate their upstream perplexity well to downstream fine-tuning on the SuperGLUE task. This warrants future investigation and study to fully realize the potential of sparse models. Understanding the fine-tuning dynamics with expert-models is very complicated and is dependent on regularization, load-balancing, and fine-tuning hyper-parameters.
