대규모 AI 모델 학습을 10배 빠르게: Dion 옵티마이저를 사용한 분산 학습
최근 연구에서 소개된 Dion 옵티마이저는 대규모 AI 모델 학습 시 발생하는 통신 병목 현상을 해결하는 획기적인 접근법을 제시합니다.
핵심
실제 성능
Dion은 분산 환경에서 대규모 AI 모델 학습 시 발생하는 통신 오버헤드를 획기적으로 줄이기 위해 설계된 옵티마이저입니다. 기존 옵티마이저들이 전체 gradient 행렬을 동기화하여 발생하는 높은 통신 비용 대신, Dion은 로우랭크 근사와 에러 피드백 메커니즘을 도입하여 장치별 모멘텀 버퍼를 독립적으로 업데이트합니다. 이로 인해 동기적 업데이트(DDP, FSDP와 동일)를 유지하면서도 통신 비용을 \(O(mn)\)에서 훨씬 낮은 \(O((m+n) \cdot r)\)로 줄일 수 있습니다.
대규모 모델 학습의 필요성
최근 AI 모델의 규모와 복잡성이 커지면서 하나의 가속기만으로는 효율적인 학습을 진행하기 어렵게 되었습니다. 따라서 여러 GPU나 TPU를 활용한 분산 학습이 필수적입니다.
분산 학습 기법
분산 학습에서는 대표적으로 Distributed Data Parallel (DDP), Fully Sharded Data Parallel (FSDP) 등이 사용됩니다. 이들 기법은 모델 파라미터와 데이터를 여러 장치에 분산시켜 학습 속도와 효율성을 높이지만, gradient 동기화 과정에서 통신 오버헤드가 급증하는 문제가 있습니다.
기존 접근 방식의 한계
기존 연구에서는 gradient의 양자화(quantization)나 스파스화(sparsification), 혹은 동기화 빈도를 줄여 통신 비용을 감소시키려 하였으나, 이러한 방법은 학습 정확도나 동기적 업데이트의 이점을 어느 정도 포기하는 경우가 많았습니다.
전체 Gradient 동기화의 부담
분산 학습에서 예를 들어 파라미터 행렬이 \(X \in \mathbb{R}^{m \times n}\) 인 경우, Adam과 같은 옵티마이저는 매 업데이트마다 동일한 크기의 전체 gradient 행렬을 동기화해야 합니다. 이로 인해 통신 비용은
\(O(m \cdot n)\)
에 달하며, 이는 모델 크기가 커질수록 더욱 심각해집니다.
통신 인프라의 문제
고속 인터커넥트 (예: InfiniBand)에 의존하게 되면, 비용이 많이 들고 장치들이 한 데이터 센터 내에 있어야 하는 제약이 따르게 됩니다.
기존 해결책의 단점
Dion(Distributed Orthonormalization)은 동기적 업데이트 방식을 그대로 유지하면서도 분산 환경에서 발생하는 통신 비용을 크게 줄이기 위한 새로운 옵티마이저입니다. 핵심은 로우랭크 근사와 에러 피드백 메커니즘에 있습니다.
Gradient 계산
각 학습 step마다 gradient
\(G_t \in \mathbb{R}^{m \times n}\)
를 계산합니다.
임시 버퍼 구성
기존 모멘텀 버퍼에 현재 gradient를 더해 임시 버퍼를 만듭니다.
\(B_t \leftarrow M_{t-1} + G_t\)
로우랭크 근사 - 단일 Power Iteration 사용
임시 버퍼 \(B_t\)에 대해 단일 Power Iteration을 수행하여,
\(P_t \in \mathbb{R}^{m \times r}\)와
\(R_t \in \mathbb{R}^{n \times r}\)
를 근사해 냅니다.
이는 기존 Muon 옵티마이저에서 사용한 Newton–Schulz 반복법보다 계산 및 통신 비용 측면에서 효율적입니다.
에러 피드백 기반 모멘텀 업데이트
로우랭크 근사로 추출된 성분만큼을 모멘텀 버퍼에서 보정하고, 소실되는 정보를 에러 피드백으로 보존합니다.
\(M_t \leftarrow B_t - (1 - \mu) \cdot P_t R_t^\top\)
여기서 \(\mu \in (0,1)\)는 모멘텀 감쇠 계수입니다.
오른쪽 인자의 재정규화
Power Iteration으로 얻은 \(R_t\)를 열 단위로 정규화하여 새로운 \(Q_t\)를 업데이트합니다.
\(Q_t \leftarrow \text{ColumnNormalize}(R_t)\)
파라미터 업데이트 (Orthonormal 업데이트)
최종적으로 파라미터 행렬은 아래의 규칙에 따라 업데이트됩니다.
\(X_t \leftarrow X_{t-1} - \eta \cdot P_t Q_t^\top\)
여기서 \(\eta\)는 학습률입니다.
모멘텀 업데이트:
\(M_t = B_t - (1 - \mu) \cdot P_t R_t^\top\)
파라미터 업데이트:
\(X_t = X_{t-1} - \eta \cdot P_t Q_t^\top\)
로우랭크 업데이트의 필요성
전체 gradient 행렬을 동기화하는 대신, 낮은 차원 \(r\)의 근사값만을 사용하여 통신해야 할 데이터의 양을 크게 줄입니다. 이렇게 하면 통신 비용은
\(O((m+n) \cdot r)\)
로 감소하게 됩니다.
Dion은 중앙집중형 알고리즘의 원리를 분산 환경에 맞게 확장하여, 여러 가속기 간 동기적 업데이트를 유지하면서도 전체 gradient 및 모멘텀 행렬을 동기화하지 않습니다.
데이터 병렬 축:
각 워커는 먼저 로컬에서 gradient 계산, 모멘텀 업데이트, 단일 Power Iteration 등의 연산을 수행합니다. 이후, 각 워커에서 얻어진 공통 업데이트 요소(예: \(P_t\), 로컬 \(R_t(i)\))를 all-reduce mean 연산을 통해 동기화합니다.
이때 통신 비용은 \((m+n) \cdot r\)입니다.
웨이트 병렬 축:
각 워커의 일부 상태(예: 열 정규화 값 \(c(i)\))는 all-reduce sum을 통해 동기화되며, 통신 비용은 \((m+1) \cdot r\) 정도로 발생합니다.
로컬 Gradient 계산
각 워커는 \(G^t(i) \in \mathbb{R}^{m \times n_i}\)를 계산합니다.
로컬 버퍼 업데이트
각 워커에서
\(B^t(i) \leftarrow M^{t-1}(i) + G^t(i)\)
를 업데이트합니다.
Distributed Power Iteration
각 워커는 단일 Power Iteration을 통해 로컬에서 \(P_t\)와 \(R_t(i)\)를 계산하고, 데이터 병렬 동기화를 통해 전체 공통의 \(P_t\)를 확보합니다.
로컬 모멘텀 업데이트
각 워커는
\(M^t(i) \leftarrow B^t(i) - (1 - \mu) \cdot P_t \big(R_t(i)\big)^\top\)
를 통해 모멘텀을 업데이트합니다.
분산 Column 정규화
각 워커는 \(R_t(i)\)를 열 단위로 정규화하여 \(Q_t(i)\)를 구합니다.
파라미터 업데이트
각 워커에서
\(X^t(i) \leftarrow X^{t-1}(i) - \eta \cdot P_t \big(Q_t(i)\big)^\top\)
와 같이 파라미터를 업데이트합니다.
논문에서는 Dion의 성능을 여러 환경 및 다른 옵티마이저(AdamW, Muon, DeMO)와 비교하며 평가하였습니다.
낮은 Rank 실험:
\(r = \frac{d}{4},\ \frac{d}{8},\ \frac{d}{16} \quad (d=768)\)
인 경우, Dion은 rank가 낮아져도 안정적인 수렴을 보이며 Adam보다 빠른 수렴을 보입니다. 에러 피드백 메커니즘 덕분에 정보 손실 문제도 완화됩니다.
전체 Rank 비교:
\(r = d \quad \text{및} \quad r = \frac{d}{2}\)
인 경우, 이론적으로 Muon과 유사한 성능을 보여야 하나, 실제 실험에서는 배치 크기가 클 때 Dion이 Muon보다 더 우수한 성능을 나타내었습니다.
계산 비용 및 효율성:
단일 Power Iteration 방식이 전면 SVD와 비교해 거의 동일한 수렴 성능을 보이면서 계산 비용은 훨씬 낮은 장점을 보여줍니다.
Dion의 핵심 구성 요소인 에러 피드백 메커니즘과 단일 Power Iteration의 효과를 개별적으로 평가하였습니다.
실험 설계
에러 피드백 없이 단순하게
\(M_t \leftarrow \mu M_{t-1} + G_t\)
로 업데이트하는 baseline variant와 비교.
결과
낮은 rank 설정 시 baseline은 급격한 성능 저하를 보인 반면, Dion은 에러 피드백 메커니즘 덕분에 안정적으로 수렴합니다.
해석
에러 피드백은 로우랭크 근사 과정에서 누락되는 정보를 보완하여, 지속적이고 정확한 모멘텀 업데이트를 가능하게 합니다.
실험 설계
매 step마다 full SVD를 적용하는 방식과 단일 Power Iteration을 이용한 방식을 비교.
결과
두 방식 간의 수렴 및 최종 성능의 차이는 미미하였으나, 단일 Power Iteration 방식은 훨씬 낮은 계산 비용을 소요합니다.
해석
이전 step의 \(Q_{t-1}\)를 초기값으로 사용함으로써, 단일 Power Iteration은 충분한 근사 정확도를 보장하며 실시간 분산 학습에 적합한 방법임을 확인하였습니다.
DeMO 옵티마이저 모멘텀의 빠른 성분을 DCT 기반으로 압축하여 통신 오버헤드를 감소시키는 방식을 사용하였으나, 실험 결과에서는 대규모 배치 환경에서 Dion에 비해 다소 저조한 성능을 보였습니다.
위의 내용을 통해 Dion 옵티마이저의 설계 원리, 알고리즘 절차, 분산 구현 전략, 실험 결과 및 ablation study를 종합적으로 이해할 수 있습니다.
분산 환경에서의 통신 오버헤드 문제를 효과적으로 해결함과 동시에, 기존 동기화 기법의 장점을 유지하는 Dion은 대규모 AI 모델의 효율적인 학습과 배포를 위한 중요한 연구 성과로 평가할 수 있습니다.