
Model | Orca Math

  • Related Project: Private
  • Category: Paper Review
  • Date: 2024-03-07

Orca-Math: Unlocking the potential of SLMs in Grade School Math

  • url: https://arxiv.org/abs/2402.14830
  • pdf: https://arxiv.org/pdf/2402.14830
  • html: https://arxiv.org/html/2402.14830v1
  • abstract: Mathematical word problem-solving has long been recognized as a complex task for small language models (SLMs). A recent study hypothesized that the smallest model size, needed to achieve over 80% accuracy on the GSM8K benchmark, is 34 billion parameters. To reach this level of performance with smaller models, researchers often train SLMs to generate Python code or use tools to help avoid calculation errors. Additionally, they employ ensembling, where outputs of up to 100 model runs are combined to arrive at a more accurate result. Result selection is done using consensus, majority vote or a separate verifier model used in conjunction with the SLM. Ensembling provides a substantial boost in accuracy but at a significant cost increase with multiple calls to the model (e.g., Phi-GSM uses top-48 to boost the performance from 68.2 to 81.5). In this work, we present Orca-Math, a 7-billion-parameter SLM based on the Mistral-7B, which achieves 86.81% on GSM8k without the need for multiple model calls or the use of verifiers, code execution or any other external tools. Our approach has the following key elements: (1) A high quality synthetic dataset of 200K math problems created using a multi-agent setup where agents collaborate to create the data, (2) An iterative learning technique that enables the SLM to practice solving problems, receive feedback on its solutions and learn from preference pairs incorporating the SLM solutions and the feedback. When trained with Supervised Fine-Tuning alone, Orca-Math achieves 81.50% on GSM8k pass@1 metric. With iterative preference learning, Orca-Math achieves 86.81% pass@1. Orca-Math surpasses the performance of significantly larger models such as LLAMA-2-70B, WizardMath-70B, Gemini-Pro, ChatGPT-3.5. It also significantly outperforms other smaller models while using much smaller data (hundreds of thousands vs. millions of problems).

Contents

TL;DR


The paper systematically presents a method for generating high-quality synthetic data and using it to iteratively improve the mathematical reasoning ability of small language models. To strengthen mathematical reasoning, it increases the diversity and difficulty of the dataset and refines the model's performance step by step through iterative learning, showing that a small language model can reach performance competitive with much larger models.

  • Proposes a method for improving the mathematical reasoning ability of small language models (SLMs)
  • Builds and uses Orca-Math-200K, a high-quality synthetic dataset
  • Improves performance through an iterative learning procedure, achieving strong results on major benchmarks

1 Problem Setup

Large language models (LLMs) such as GPT-4 show mathematical reasoning abilities that were not seen in small language models (SLMs). These abilities are largely attributed to very large model sizes, dataset sizes, and the amount of compute needed for training. Recent work has focused on improving the reasoning abilities of SLMs, and one promising direction is to use a large model to create tailored, high-quality synthetic data for the SLM. This work focuses on mathematical reasoning, specifically grade school math word problems (the GSM8K benchmark).


2 Dataset Construction: Agent-Instruct

Multiple agents are used to generate data that both diversifies grade school math problems and raises their difficulty.

2.1 Seed Set

A total of 36,217 problems are collected from existing open-source datasets, gathered via the Lila benchmark.

2.2 Agent - Ask Me Anything

New problems are created by transforming each problem in the seed set: the original question is converted into a statement, and a new problem is then created for each number appearing in the converted problem. GPT-4-Turbo is used to generate solutions to these problems.

2.3 Agent - Suggester & Editor

Agents are used to rewrite existing problems into harder ones. The Suggester proposes ways to increase a problem's difficulty, and the Editor revises the problem accordingly. This process yields 37,157 additional problems.

2.4 DMath

An additional 6,216 problems sourced from DMath are included.


3 Training

3.1 Supervised Fine-Tuning Experiment (Iteration #1)

Mistral-7B is fine-tuned on the Orca-Math-200K dataset with a learning rate of $1 \times 10^{-6}$ and a per-GPU batch size of 3, for one epoch on eight A100 nodes.

3.2 Iterative Learning from both Positive and Negative Signals (Iteration #2)

Four responses are sampled from the model fine-tuned in the first iteration to obtain positive and negative solutions. A GPT-4-based exact-match criterion then judges whether each solution agrees with the teacher answer, and the labeled solutions are used to build the training dataset.

3.3 Iterative Learning (Iteration #3)

The model produced in iteration #2 (M2) is used to generate four responses again, and a new training dataset is built from them. Direct Preference Optimization (DPO) and Kahneman-Tversky Optimization (KTO) are used to learn from both positive and negative feedback.


4 Evaluation

Model performance is evaluated with a GPT-4-based exact-match criterion: GPT-4 checks, for example, whether the student's final answer matches the gold answer.


5 Results

Performance on the GSM8K test set is measured across the training procedures. The first iteration (supervised fine-tuning) reaches 79.91% accuracy, the second iteration improves this to 85.06% with KTO, and the third iteration reaches 86.87% with KTO.


6 Related Work

The paper presents a method for improving the mathematical reasoning ability of SLMs through training on synthetic data and iterative learning. In particular, it demonstrates the robustness of KTO and the effectiveness of model-generated positive solutions. The work positions itself as a preliminary step toward iterative learning and self-improvement of small language models.


[Reference 1] Key Sections in Detail

3.2 Iterative Learning from both Positive and Negative Signals

Dataset Construction Iteration #2

To generate additional positive and negative solutions for each problem, four responses are sampled from the model fine-tuned in the first iteration (a sampling sketch follows these settings). Specifically, the following settings are used:

  • top_p: 0.95
  • temperature: 0.7
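
For illustration, the sampling step might look like the sketch below using Hugging Face transformers; the checkpoint name and token limit are placeholders, not details from the paper.

```python
# Minimal sketch of sampling four candidate solutions per problem, assuming the
# iteration-#1 SFT checkpoint is a Hugging Face causal LM.
# "orca-math-sft-iter1" and max_new_tokens are illustrative placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("orca-math-sft-iter1")
model = AutoModelForCausalLM.from_pretrained(
    "orca-math-sft-iter1", torch_dtype=torch.bfloat16, device_map="auto"
)

def sample_solutions(question: str, n: int = 4) -> list[str]:
    prompt = f"USER:\n{question}\n\nASSISTANT:\n"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        do_sample=True,
        top_p=0.95,          # nucleus sampling, as in the paper
        temperature=0.7,     # sampling temperature, as in the paper
        num_return_sequences=n,
        max_new_tokens=512,  # placeholder limit
        pad_token_id=tokenizer.eos_token_id,
    )
    prompt_len = inputs["input_ids"].shape[1]
    return [tokenizer.decode(o[prompt_len:], skip_special_tokens=True) for o in outputs]
```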

This process yields a dataset in which each of the 200,000 problems has one GPT-4-Turbo-generated solution and four student-generated solutions. The prompt defined in GPT4-Based-Exact-Match (see Section 4) is then used to judge whether the teacher's (GPT-4-Turbo) answer and the student's answer agree. Any student-generated solution that does not match the teacher's answer is labeled negative; the rest are labeled positive. The preference dataset is then constructed as follows.

  1. For each question \(q_i\), construct \(q^+_i\), the set of all positive solutions. The teacher's solution is treated as positive, so this set contains at least one element by construction.
  2. For each question \(q_i\), construct \(q^-_i\), the set of all negative solutions. This set can be empty when all four responses match the teacher's solution; in that case, one response is randomly sampled from \(q^-_j\) of a different question \(q_j\) with \(j \neq i\).

The final preference dataset is the union of \(Q_i\) over all \(q_i\).

\[Q_i = \{ (q_i, a^+_i, a^-_i) | a^+_i \in q^+_i \text{ and } a^-_i \in q^-_i \}\]

Dataset Construction Iteration #3

\(M_2\) is used to generate four responses again, and a new training dataset is built from them. Two algorithms are evaluated:

  • Direct Preference Optimization (DPO) [31]
  • Kahneman-Tversky Optimization (KTO) [10]

DPO is a simple and popular approach for efficiently fine-tuning language models to align with preferences. KTO requires only a binary “yes” or “no” response to assess the quality of an output.


4 Evaluation

Exact match is used as the metric. Given a model-generated solution, GPT-4 is prompted to extract the final short answer and match it against the gold answer; this is referred to as GPT4-based-Exact-Match. An example of the prompt template follows.

SYSTEM: As an expert math teacher, your role is to evaluate a student's answer to a word problem. The problem is accompanied by the correct solution provided by the problem setter. Because there can be many ways to solve a word problem, the student's steps may not always match the problem setter's solution. However, the final answer is typically a single number and should match the problem setter's answer. Your task is to analyze the student's solution, identify any mistakes, and determine whether the error can be corrected. If the student's answer is unfixable, consider creating practice problems to improve their understanding.

Error Analysis: Extract the final answer from the problem setter's solution and compare it with the student's answer. Do they match?

Final Verdict: Correct/Incorrect

USER: Question: Billy is volunteering to help people with their taxes. He can help 2 people per hour for 3 hours a day. Billy takes 20% of the days between March 1st and April 19th off and helps people on the remaining days. How many people does he help in total? (March has 31 days.)

Problem Setter’s Answer:

  1. First find the number of days between March 1st and April 19th: 31 days + 19 days = 50 days
  2. Subtract the percentage of days Billy takes off from 100% to find the percentage of days he works: 100% - 20% = 80%
  3. Multiply that percentage by 50 days to find the number of days Billy works: 50 days * 80% = 40 days
  4. Find how many people Billy can help per day: 2 people/hour * 3 hours/day = 6 people/day
  5. Multiply the number of days he works by the number of people he helps per day to find the total: 6 people/day * 40 days = 240 people

Student Answer:

  1. First, calculate the total number of days from March 1st to April 19th.
    • March has 31 days and April has 30 days, so the total number of days from March 1st to April 19th is: 31 (March) + 19 (April) = 50 days
  2. Now calculate how many days Billy takes off. He takes off 20% of the 50 days.
    • 20% of 50 days = 0.20 * 50 = 10 days
  3. So Billy takes 10 days off, which leaves the number of days he works: 50 days - 10 days = 40 days
  4. Billy helps people for 3 hours a day and can help 2 people per hour, so per day he helps:
    • 3 hours/day * 2 people/hour = 6 people/day
  5. Finally, calculate how many people he helps over the 40 days he works.
    • 40 days * 6 people/day = 240 people

Therefore, Billy helps 240 people.

ASSISTANT:

  • Error Analysis: The student's final answer of 240 people matches the problem setter's answer.
  • Final Verdict: Correct

Performance of Training Procedures on GSM8K Test Set

Table 2 shows the performance of several training procedures on the GSM8k test set, which contains 1,319 word problems. Mistral-7B is fine-tuned for up to three iterations. In the first iteration, supervised fine-tuning (SFT) is used to obtain M1. In the second iteration, SFT, DPO [31], and KTO are compared; the KTO-trained model performs best in this group. This model is called M2, and M2 is used to generate the dataset for the third iteration. In the third iteration, DPO and KTO are compared with M2 as the starting point, and both are also compared against three epochs of SFT training on the Orca-Math-200K dataset. All SFT training uses a constant learning rate of $1 \times 10^{-6}$, a per-device batch size of 3, and one epoch. For the DPO and KTO training jobs, beta is set to 0.3, the per-device batch size to 3, gradient accumulation steps to 11, and the number of epochs to 1. DPO and KTO training in the second iteration use a constant learning rate of $1 \times 10^{-6}$; the third iteration uses $1 \times 10^{-7}$.

| Training Procedure | Pass@1 Accuracy on GSM8K Test Set |
| --- | --- |
| SFT (M1) | 79.91 |
| SFT (M1) → SFT | 81.50 |
| SFT (M1) → DPO | 84.23 |
| SFT (M1) → KTO (M2) | 85.06 |
| SFT (M1) → SFT → SFT | 80.44 |
| SFT → KTO (M2) → DPO | 84.91 |
| SFT → KTO (M2) → KTO (Orca-Math) | 86.87 |

Table 2: Performance of various iterative learning experiments and baselines on the GSM8k test set. SFT stands for one epoch of training on the Orca-Math-200K dataset. SFT → SFT stands for two epochs of training on Orca-Math-200K. SFT → DPO (KTO) stands for one epoch of training with DPO (KTO) on the dataset for iteration #2, starting from M1. SFT → SFT → SFT stands for three epochs of training on Orca-Math-200K. SFT → KTO → DPO (KTO) stands for one epoch of training with DPO (KTO) on the dataset for iteration #3, starting from M2. Greedy decoding is used for evaluation.

5.1 Ablation Studies

Model Generated Positives

To study the impact of model-generated positives, the dataset is restricted to teacher-generated solutions only; that is, every model-generated \(a^+_i\) is removed from \(q^+_i\) when building the iteration #2 dataset. Table 3 shows the result of training M1 with DPO and KTO on this dataset for one epoch, reusing the iteration #2 hyperparameters. Regardless of the training algorithm, performance drops significantly.

| Training Procedure | Pass@1 Accuracy on GSM8K Test Set |
| --- | --- |
| M1 → DPO | 81.96 (-2.27) |
| M1 → KTO | 82.79 (-2.27) |

Table 3: Performance drop when using only teacher-generated positive solutions.

Synthetic Negatives

When all four sampled responses are positive, preference dataset creation involves generating a synthetic negative. The impact of these synthetic negatives is studied by ignoring the questions \(q_i\) for which all sampled responses are positive (Table 4). This reduces the number of questions by about 80k for iteration #2 and about 104k for iteration #3.

| Training Procedure | Pass@1 Accuracy on GSM8K Test Set |
| --- | --- |
| M1 → DPO | 60.73 (-23.5) |
| M1 → KTO | 85.22 (+0.17) |
| M1 → KTO → KTO | 85.44 (-1.43) |

Table 4: Shows that including problems where all sampled responses are positive is beneficial.

5.2 Math Benchmarks beyond GSM8K

Table 5 shows the performance of Orca-Math on several other word problem datasets. For ease of evaluation, datasets were selected in which the answer to each problem is a single number. The test sets of these benchmarks are obtained from Lila. The GPT4-based exact-match metric is used, and model responses are generated with greedy decoding.

| Test Set | Orca-Math-SFT (M1) | Orca-Math |
| --- | --- | --- |
| AddSub | 88.99 | 91.74 |
| ASDiv | 91.10 | 91.10 |
| MultiArith | 98.28 | 98.28 |
| SingleOp | 98.74 | 99.37 |
| SingleEq | 97.25 | 99.08 |
| Svamp | 87.63 | 91.30 |

Table 5: Performance of the SFT-trained model from Iteration #1 (M1) and Orca-Math on AddSub, ASDiv, MultiArith, SingleOp, SingleEq, and Svamp.


[Reference 2] Kahneman-Tversky Optimization (KTO)

Kahneman-Tversky Optimization (KTO) is an optimization approach named after the behavioral economics work of psychologists Daniel Kahneman and Amos Tversky. It aims to solve more realistic optimization problems through a model that reflects how people actually make decisions, particularly in environments involving uncertainty and risk.

Theoretical Background

KTO draws mainly on Kahneman and Tversky's Prospect Theory, which describes how people actually make choices, in contrast to the rational model of expected utility theory. Its key elements are:

  1. Reference Point: People evaluate gains and losses relative to a specific reference point.
  2. Value Function: This function is asymmetric, being more sensitive to losses than to gains (loss aversion); a loss of a given size has a larger impact than a gain of the same size.
  3. Probability Weighting Function: People perceive probabilities in a distorted way, overweighting low probabilities and underweighting high probabilities.

KTO as an Optimization Problem

KTO defines an optimization problem that reflects these human decision-making biases by incorporating the elements of prospect theory into the objective function of a traditional optimization problem.

  1. Setting the reference point: Define the reference point the decision-maker uses as a baseline. This reference point can be subjective and may vary with the situation.
  2. Applying the value function: Gains and losses are transformed by a value function. The value function \(v(x)\) typically has the following form.
\[v(x) = \begin{cases} x^\alpha & \text{if } x \geq 0 \\ -\lambda (-x)^\beta & \text{if } x < 0 \end{cases}\]

Here \(\alpha\) and \(\beta\) are curvature (diminishing-sensitivity) parameters, typically \(0 < \alpha, \beta \leq 1\), and \(\lambda > 1\) is the loss-aversion coefficient.

  3. Applying the probability weighting function: The decision-maker perceives a probability \(p\) through a weighting function \(\pi(p)\), which typically has the following form.
\[\pi(p) = \frac{p^\gamma}{(p^\gamma + (1-p)^\gamma)^{1/\gamma}}\]

Here \(0 < \gamma \leq 1\) is the weighting parameter.

The KTO Objective Function

The KTO objective integrates these elements. For example, if the decision-maker has \(n\) alternatives, and alternative \(i\) has expected outcome \(x_i\) with probability \(p_i\), the KTO objective can be defined as follows.

\[\text{KTO Objective} = \sum_{i=1}^n \pi(p_i) v(x_i - r)\]

Here \(r\) is the reference point.

For example, suppose an investor has two investment options \(A\) and \(B\) with the following probabilities of gains and losses:

  • Option \(A\): 50% chance of a $100 gain, 50% chance of a $50 loss
  • Option \(B\): 100% chance of a $20 gain

Setting the investor's reference point \(r\) to $0 and applying the value and probability weighting functions gives the expected value of each option (a code sketch of this computation follows the example).

  • Expected value of option \(A\)

    \(x_1 = 100\), \(p_1 = 0.5\)

    \(x_2 = -50\), \(p_2 = 0.5\)

    \[v(100) = 100^\alpha\] \[v(-50) = -\lambda (50^\beta)\] \[\pi(0.5) = \frac{0.5^\gamma}{(0.5^\gamma + 0.5^\gamma)^{1/\gamma}}\] (this equals 0.5 only when \(\gamma = 1\))

    Therefore,

    \[\text{KTO}(A) = \pi(0.5) \cdot 100^\alpha + \pi(0.5) \cdot (-\lambda (50^\beta))\]
  • Expected value of option \(B\)

    \(x = 20\), \(p = 1\)

    \[v(20) = 20^\alpha\] \[\pi(1) = 1\]

    Therefore,

    \[\text{KTO}(B) = 1 \cdot 20^\alpha\]

    Comparing the two expected values, the better investment can be selected.
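
As an illustration only, the example above can be evaluated numerically as sketched below; the parameter values \(\alpha = \beta = 0.88\), \(\lambda = 2.25\), \(\gamma = 0.61\) are the classic Tversky-Kahneman estimates and are assumptions, not values taken from this post.

```python
# Minimal sketch of the prospect-theory-style evaluation above.
# Parameter values (alpha, beta, lam, gamma) are the classic Tversky-Kahneman
# estimates and are assumptions, not taken from this post.

def value(x: float, alpha: float = 0.88, beta: float = 0.88, lam: float = 2.25) -> float:
    """Asymmetric value function: gains are concave, losses are amplified by lam."""
    return x ** alpha if x >= 0 else -lam * ((-x) ** beta)

def weight(p: float, gamma: float = 0.61) -> float:
    """Probability weighting: overweights small p, underweights large p."""
    return p ** gamma / ((p ** gamma + (1 - p) ** gamma) ** (1 / gamma))

def kto_objective(outcomes, reference: float = 0.0) -> float:
    """Sum of weighted, reference-adjusted values over (outcome, probability) pairs."""
    return sum(weight(p) * value(x - reference) for x, p in outcomes)

option_a = [(100, 0.5), (-50, 0.5)]  # 50% chance of +$100, 50% chance of -$50
option_b = [(20, 1.0)]               # sure gain of $20

print(f"KTO(A) = {kto_objective(option_a):.2f}")
print(f"KTO(B) = {kto_objective(option_b):.2f}")
# With these parameters the sure $20 (option B) scores higher, reflecting loss aversion.
```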

Conclusion

Kahneman-Tversky Optimization (KTO), by incorporating behavioral-economics elements into conventional optimization techniques, supports more realistic decision-making: it can reflect the biases observed in actual human decisions and can effectively handle optimization problems in environments involving uncertainty and risk.


1 Problem Setup

Frontier Language Models such as GPT-4 [1] have demonstrated capabilities previously unseen in smaller models, most notably the remarkable ability to reason (e.g. mathematical reasoning that requires both language comprehension and mathematical understanding). These capabilities have been largely attributed to the very large scale of the model, the dataset, and ultimately the amount of compute needed for training.

Several recent studies have focused on improving the reasoning abilities of small language models (SLMs). Despite this, the extent to which scale is needed for achieving reasoning capabilities is still an open research question.

One of the promising directions for improving the reasoning capabilities of SLMs is using frontier language models, such as GPT-4, to create tailored and high-quality synthetic data that can be used to train the SLM. The high quality of the training data and the ability to elicit richer learning signals (e.g. explanations) have been shown to significantly improve SLMs' abilities in acquiring skills that had previously emerged only at much larger scale.

This paradigm fits under a teacher-student approach where the large model (the teacher) creates demonstrations for the SLM (the student) to learn from. In this work we further explore this direction with a focus on mathematical reasoning on grade school math word problems, using the popular GSM8K benchmark.

Several other studies have recently demonstrated positive results on GSM8K with SLMs, e.g. Phi-GSM [21], OVM [38], etc. However, many of them employ ensembling, where outputs of up to 100 model runs are combined to arrive at a more accurate result. Result selection is done using consensus, majority vote, or a separate verifier model to score/verify the outputs and select the best answer. Ensembling provides a substantial boost in accuracy (e.g., Phi-GSM uses top-48 to boost the performance from 68.2 to 81.5, and [22] uses top-100 to boost LLAMA-2’s performance from 38.6% to 71.9%). However, it comes at a significant increase in cost: with multiple calls to the model, generating and verifying 100 different solutions requires 200 different calls to the models. Additionally, some of these approaches use very large amounts of data (e.g. 12M for Phi-GSM) or use tools or code to avoid calculation errors.

In this work, we extend the teacher-student paradigm to an iterative learning setting with high-quality synthetic training data as follows:

  • We create the Orca-Math dataset, a synthetic dataset of 200K math problems, paired with GPT-4-Turbo solutions. The dataset was generated using an agent-based setup, hereafter referred to as Agent-Instruct, that not only paraphrases existing problems but aims to expand the problem set in both diversity and difficulty.
  • We introduce an iterative learning procedure where we: (1) use the dataset for supervised finetuning to train the SLM on demonstrations, (2) allow the SLM to practice generating multiple solutions and (3) use the teacher to provide feedback to the student. The feedback comes in the form of evaluating the solutions generated by the student or providing a teacher solution.

With supervised fine-tuning alone, we achieve 81.50% on GSM8k on the pass@1 metric. The iterative learning loop further improves pass@1 to 86.81%, without the need for multiple model calls or the use of verifiers, code execution, or any other external tools. The model exceeds much bigger models like LLAMA-2-70B (56.8%), WizardMath-70B (81.6%), Gemini Pro (86.5% with 32 trials) and GPT-3.5 (77.4%). Most notably, it can reach this level with only 200K examples (orders of magnitude less than other datasets).

2 Dataset Construction: Agent-Instruct

The goal of this step is to create a diverse set of grade school math word problems that contains both easy and hard problems. Towards this goal we create a variety of agents.

Seed Set We start by collecting sample math word problems from existing open-source datasets, namely NumGLUE [26], AddSub [13], ALGES [17], ASDiv [24], DRAW [35], GSM8k [7], MATHQA [2], MultiArith [32], SingleOp [33], SingleEQ [16], and SVAMP [30]. We collect a total of 36,217 problems. We utilize the Lila [25] benchmark to collect the datasets. Specifically, we collect problems from the train and validation splits of Lila to construct the seed set. Interested readers may refer to Lila [25].

Agent - Ask Me Anything

We expand the seed set by creating multiple word problems from each problem in the seed set. We use the following prompt for problem creation.

Prompt for Problem Creation:

Your goal is to create multiple word problems from a given word problem and its answer. First convert the question of the word problem into a statement. Then for each number in the converted problem create a new word problem. Here are some examples:

Example 1:

Q: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
Answer: 72

Replacing question with statement: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. Natalia sold altogether 72 clips in April and May.

All questions:

| Target | Question |
| --- | --- |
| 48 | Natalia sold clips to some of her friends in April, and then she sold half as many clips in May. Natalia sold altogether 72 clips in April and May. How many clips did she sell in April? |
| half | Natalia sold clips to 48 of her friends in April, and then she sold some clips in May. Natalia sold altogether 72 clips in April and May. What is the ratio of the number clips sold in April to number clips sold in May? |
| 72 | Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? |

Example 2:

Q: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?
Answer: 10

Replacing question with statement: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. She earned $10.

All questions:

| Target | Question |
| --- | --- |
| 12 | Weng earns a certain amount per hour for babysitting. Yesterday, she just did 50 minutes of babysitting and earned 10. How much does she earn per hour? |
| 50 | Weng earns 12 an hour for babysitting. Yesterday, she just did some babysitting and earned 10. How much time did she spend on babysitting? |
| 10 | Weng earns 12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn? |

Example 3:

Q: Betty is saving money for a new wallet which costs 100. Betty has only half of the money she needs. Her parents decided to give her 15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet?
Answer: 5

Replacing question with statement: Betty is saving money for a new wallet which costs 100. Betty has only half of the money she needs. Her parents decided to give her 15 for that purpose, and her grandparents gave her twice as much as her parents. She needs 5 more to buy the wallet.

All questions:

| Target | Question |
| --- | --- |
| 100 | Betty is saving money for a new wallet. Betty has only half of the money she needs. Her parents decided to give her 15 for that purpose, and her grandparents twice as much as her parents. She needs 5 more to buy the wallet. What is the cost of the wallet? |
| half | Betty is saving money for a new wallet which costs 100. She has some money saved, her parents decided to give her 15, and her grandparents gave her twice as much as her parents. Now, Betty needs 5 more to buy the wallet. What is the ratio of the money Betty have saved initially to the cost of wallet? |
| 15 | Betty is saving money for a new wallet which costs 100. She has half of the money she needs, her parents decided to give her some money, and her grandparents gave her twice as much as her parents. Now, Betty needs 5 more to buy the wallet. How much money did her parents give her? |
| twice | Betty is saving money for a new wallet which costs 100. Betty has only half of the money she needs. Her parents decided to give her 15 for that purpose, and her grandparents also chipped in. Now, Betty needs 5 more to buy the wallet. What is the ratio of the amount given by her grandparents to the amount given by her parents? |
| 5 | Betty is saving money for a new wallet which costs 100. Betty has only half of the money she needs. Her parents decided to give her 15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet? |

Example 4:

Q: Your teacher is giving a test worth 200 points. There is a total of 30 5-point and 10-point questions. How many 5-point questions are on the test?
Answer: 20

Note: The “Ask Me Anything” agent generates problems based on the seed in Example 4. Examples 1 to 3 are provided as few-shot demonstrations. This agent creates a total of 120,445 new problems. All generated problems exhibit a similar narrative to that of the seed word problem. The solutions to these word problems are generated using GPT-4-Turbo.
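
As a rough illustration, such an agent call could be wired up with the OpenAI Python client as sketched below; the model identifier, prompt assembly, and temperature are assumptions, not the authors' implementation.

```python
# Hypothetical sketch of the "Ask Me Anything" expansion step: the few-shot
# prompt above is sent to a GPT-4-class model together with a new seed problem,
# and the model returns rephrased problems. Model name and prompt wiring are
# assumptions, not the authors' code.
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

FEW_SHOT_PROMPT = "..."  # the "Prompt for Problem Creation" shown above, Examples 1-3 included

def expand_seed_problem(question: str, answer: str) -> str:
    seed = f"Example 4:\n\nQ: {question}\nAnswer: {answer}\n\nAll questions:"
    response = client.chat.completions.create(
        model="gpt-4-turbo",  # assumed model identifier
        messages=[{"role": "user", "content": FEW_SHOT_PROMPT + "\n\n" + seed}],
        temperature=0.7,      # assumed sampling setting
    )
    return response.choices[0].message.content
```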

Agent - Suggester & Editor

We further expand the seed set by developing challenging problems. To achieve this, we introduce two new agents, namely Suggester and Editor. This setup involves the establishment of a team of agents working collaboratively to create a dataset geared toward a predefined objective. In this instance, the objective is to modify an existing problem to increase its difficulty.

Example of Iterative Process:

Seed Problem: Every hour Joanne has to collect the coins out of the fountain inside the mall. During the first hour, she collected 15 coins. For the next two hours, she collected 35 coins from the fountain. In the fourth hour, she collected 50 coins from the fountain but she gave 15 of them to her coworker so she could buy a soda. How many coins did she have after the fourth hour?

Suggester: Increase the number of hours, add more variables for each hour, and incorporate additional transactions.

Editor: Every hour, Joanne has to collect the coins from two separate fountains inside the mall - Fountain A and Fountain B. During the first hour, she collected 20 coins from Fountain A and 30 coins from Fountain B. In the next hour, she collected 25 coins more than she collected in Fountain A during the previous hour and 10 coins less than she collected in Fountain B during the previous hour. In the third hour, she collected the same number of coins in Fountain A as she collected in the second hour and additional 45 coins this time from Fountain B. During the fourth hour, she collected 90 coins in total, but she gave 20 coins to her coworker to buy a soda and 40 coins to another coworker to buy a meal.

Suggester: Incorporate more fountains, increase the duration, and add more complex calculations for coins collected and spent or donated.

Editor: Joanne has to collect coins from four separate fountains inside the mall - Fountain A, Fountain B, Fountain C, and Fountain D every hour for 8 hours. In the first hour, she collects 20 coins from Fountain A, 30 coins from Fountain B, 40 coins from Fountain C, and 50 coins from Fountain D. In the subsequent hours, her collections have the following fluctuations:

  • For Fountain A, she collects 10 more coins each hour than she collected in the previous hour.
  • For Fountain B, her collections decrease by 5 coins each hour compared to the previous hour.
  • For Fountain C, she collects double the number of coins each hour compared to the number of coins collected in Fountain A the same hour.
  • For Fountain D, the number of coins collected remains constant at 50 coins each hour.

At the end of the third hour, she donates 20 coins. During the seventh hour, she has to share coins with her coworkers and splits them as follows: 15 coins for coworker one, 20 coins for coworker two, and she keeps 70 coins. After 8 hours, Joanne decides to count her collected coins. How many coins did she have at the end of the eighth hour?

We allow two rounds of iterations per problem and filter out problems where the GPT4-turbo generated answer exceeds 1800 characters (a code sketch of this loop follows). At the end of this process, we collect 37,157 problems.
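
The Suggester and Editor loop described above might be sketched as follows; the role prompts, helper names, and filter wiring are illustrative assumptions rather than the authors' code.

```python
# Hypothetical sketch of the Suggester & Editor loop: two rounds per seed
# problem, then problems whose generated answer exceeds 1800 characters are
# filtered out. Model name, prompts, and helpers are illustrative assumptions.
from openai import OpenAI

client = OpenAI()

def ask(instructions: str, content: str) -> str:
    response = client.chat.completions.create(
        model="gpt-4-turbo",  # assumed model identifier
        messages=[{"role": "system", "content": instructions},
                  {"role": "user", "content": content}],
    )
    return response.choices[0].message.content

def harden_problem(problem: str, rounds: int = 2) -> str | None:
    for _ in range(rounds):
        suggestion = ask("Suggest ways to make this word problem harder.", problem)
        problem = ask("Rewrite the problem following the suggestions.",
                      f"Problem:\n{problem}\n\nSuggestions:\n{suggestion}")
    answer = ask("Solve this word problem step by step.", problem)
    # Keep only problems whose generated answer stays within 1800 characters.
    return problem if len(answer) <= 1800 else None
```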

DMath Furthermore, we include 6,216 problems sourced from DMath [15]. These problems represent a subset of the 7,943 problems present in the DMath training set, in which the solution computed by GPT4-Turbo aligns with the precise gold-standard answer.

3 Training

3.1 Supervised Fine-Tuning Experiment (Iteration #1)

We finetune Mistral-7B on the Orca-Math-200K dataset. We have not used packing. The data is presented in the following instruction format:

USER:\n{question}\n\nASSISTANT:\n{answer}

The loss is computed only on the answer tokens. We employ a constant learning rate of $1 \times 10^{-6}$. The per-device batch size is set to 3. Training is conducted for one epoch on eight A100 nodes, with each node containing eight GPUs.
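
A minimal sketch of the instruction format with answer-only loss masking, assuming a Hugging Face tokenizer; it is illustrative and not the authors' training code.

```python
# Minimal sketch of building one training example in the USER/ASSISTANT format
# with the loss restricted to answer tokens (question tokens get label -100,
# which Hugging Face losses ignore). Illustrative only.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

def build_example(question: str, answer: str, max_len: int = 2048) -> dict:
    prompt = f"USER:\n{question}\n\nASSISTANT:\n"
    prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
    answer_ids = tokenizer(answer, add_special_tokens=False)["input_ids"] + [tokenizer.eos_token_id]
    input_ids = (prompt_ids + answer_ids)[:max_len]
    # Mask the prompt so cross-entropy is computed only on the answer tokens.
    labels = ([-100] * len(prompt_ids) + answer_ids)[:max_len]
    return {"input_ids": input_ids, "labels": labels, "attention_mask": [1] * len(input_ids)}
```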

3.2 Iterative Learning from both Positive and Negative Signals

Dataset Construction Iteration #2

To generate additional positive and negative solutions for each problem, we sample four responses from the SFT-tuned model from iteration #1. Specifically, we utilize:

  • top_p: 0.95
  • temperature: 0.7

This process results in a dataset where each of the 200,000 problems has one GPT4-Turbo generated solution and four student-generated solutions. Subsequently, we employ the prompt defined in GPT4-Based-Exact-Match (see section 4 for details) to assess the alignment between the teacher’s (GPT4-Turbo) answer and the student’s answer. For all solutions where the student-generated answer does not match the teacher’s answer, we label them as negative; otherwise, we label the solution as positive. We then construct the preference dataset as follows (a code sketch of this pairing follows the formula below):

  1. For each question \(q_i\), we construct \(q^+_i\) (the set of all positive solutions for \(q_i\)). We treat the teacher solution as positive, thus this set by construction contains at least one element.
  2. For each question \(q_i\), we also construct \(q^-_i\) (the set of all negative solutions for \(q_i\)). This set can be empty if all four responses are aligned with the teacher’s solution. For such situations, we randomly sample one response from \(q^-_j\) for a different \(q_j\) where \(j \neq i\).

The final preference dataset is created by taking the union of \(Q_i\) for all \(q_i\) in the training dataset.

\[Q_i = \{ (q_i, a^+_i, a^-_i) \mid a^+_i \in q^+_i \text{ and } a^-_i \in q^-_i \}\]
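
The pairing logic above can be sketched as follows; the exact-match judgment is abstracted into a precomputed boolean per student solution, and all names are illustrative.

```python
# Minimal sketch of building the preference pairs Q_i described above.
# `matches` stands in for the GPT4-Based-Exact-Match judgment; names and the
# fallback negative sampling are illustrative, not the authors' code.
import random
from itertools import product

def build_preference_pairs(problems, rng=random.Random(0)):
    per_question = []
    for p in problems:  # p: {"question", "teacher_solution", "student_solutions", "matches"}
        positives = [p["teacher_solution"]] + [
            s for s, ok in zip(p["student_solutions"], p["matches"]) if ok
        ]
        negatives = [s for s, ok in zip(p["student_solutions"], p["matches"]) if not ok]
        per_question.append({"question": p["question"], "pos": positives, "neg": negatives})

    pairs = []
    for i, q in enumerate(per_question):
        negatives = q["neg"]
        if not negatives:
            # All four responses matched: borrow one negative from a different question.
            donors = [r for j, r in enumerate(per_question) if j != i and r["neg"]]
            negatives = [rng.choice(rng.choice(donors)["neg"])] if donors else []
        # Q_i is the cross product of positive and negative solutions for q_i.
        pairs.extend({"prompt": q["question"], "chosen": a_pos, "rejected": a_neg}
                     for a_pos, a_neg in product(q["pos"], negatives))
    return pairs
```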

Dataset Construction Iteration #3

Let \(M_2\) denote the model trained with KTO [10] on the dataset constructed for Iteration #2. We replicate the same procedure for the construction of the dataset for Iteration #3; however, we utilize \(M_2\) to generate the four responses instead of the SFT-tuned model from iteration #1.

To learn from both positive and negative feedback, we evaluate the performance of two algorithms:

  • Direct Preference Optimization (DPO), as described by [31]
  • Kahneman-Tversky Optimization (KTO), introduced by [10]

DPO is a simple and popular approach for efficiently fine-tuning language models to align with preferences. KTO distinguishes itself by requiring only a binary “yes” or “no” response to assess the quality of an output.

4 Evaluation

We use exact match as the metric. Given a model generated answer, we prompt GPT-4 to extract the final short answer and match it with the gold short answer. We will refer to this metric as GPT4-based-Exact-Match. The following figure shows the prompt template:


SYSTEM:

As an expert Math teacher, your role is to evaluate a student’s answer to a word problem. The problem is accompanied by a correct solution provided by the problem setter. It is important to remember that there may be various methods to solve a word problem, so the student’s steps might not always align with those in the problem setter’s solution. However, the final answer, typically a number, should be unique and match the problem setter’s answer. Your task involves analyzing the student’s solution to identify any mistakes and determine whether the answer can be modified to correct the error. If the student’s answer is unfixable, consider creating practice problems to help improve their understanding.


Use the following format:

  • Error Analysis: In one sentence, extract the final answer from the problem setter’s solution and compare it with the student’s answer. Do they match?
  • Final Verdict: Correct/Incorrect

USER: Question: Billy is volunteering his time to help people do their taxes. He can help 2 people per hour for 3 hours a day. If he takes 20% of the days between March 1st and April 19th off, and helps people on all the other days. How many people does he help? (Remember there are 31 days in March.)

Problem Setter’s Answer:

  1. First find the number of days between March 1st and April 19th: 31 days + 19 days = 50 days
  2. Then subtract the percentage of time Billy takes off from 100% to find the percentage of time he works: 100% - 20% = 80%
  3. Then multiply that number by 80% to find the number of days Billy works: 50 days * 80% = 40 days
  4. Then find how many people Billy can help per day: 2 people/hour * 3 hours/day = 6 people/day
  5. Then multiply the number of people he helps per day by the number of days he works to find the total number of people he helps: 6 people/day * 40 days = 240 people

Student Answer:

  1. First, let’s calculate the total number of days between March 1st and April 19th.
    • March has 31 days, and April has 30 days. So, the total number of days from March 1st to April 19th is: 31 (March) + 19 (April) = 50 days
  2. Now, let’s calculate how many days Billy takes off. He takes 20% of the 50 days off.
    • 20% of 50 days = 0.20 * 50 = 10 days
  3. So, Billy takes 10 days off. This leaves him with: 50 days - 10 days = 40 days
  4. Billy helps people for 3 hours a day, and he can help 2 people per hour. So, in one day, he can help:
    • 3 hours/day * 2 people/hour = 6 people/day
  5. Now, let’s calculate how many people he helps over the 40 days he works:
    • 40 days * 6 people/day = 240 people

Therefore, Billy helps 240 people.

ASSISTANT:

  • Error Analysis: The student’s final answer of helping 240 people matches the problem setter’s solution.
  • Final Verdict: Correct
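
As a hypothetical sketch, the judge prompt above could be driven with the OpenAI client and the verdict parsed as below; the model name and parsing are assumptions.

```python
# Hypothetical sketch of the GPT4-based-Exact-Match check: the system prompt
# above is sent with the question, setter answer, and student answer, and the
# "Final Verdict" line is parsed. Model name and parsing are assumptions.
from openai import OpenAI

client = OpenAI()
JUDGE_SYSTEM_PROMPT = "..."  # the SYSTEM prompt shown above

def is_exact_match(question: str, setter_answer: str, student_answer: str) -> bool:
    user_msg = (f"Question: {question}\n\n"
                f"Problem Setter's Answer:\n{setter_answer}\n\n"
                f"Student Answer:\n{student_answer}")
    response = client.chat.completions.create(
        model="gpt-4",  # assumed judge model
        messages=[{"role": "system", "content": JUDGE_SYSTEM_PROMPT},
                  {"role": "user", "content": user_msg}],
        temperature=0,
    )
    text = response.choices[0].message.content
    verdict = next((line for line in text.splitlines() if "Final Verdict" in line), "")
    return "Incorrect" not in verdict and "Correct" in verdict
```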

5 Results

Performance of Training Procedures on GSM8K Test Set

Table 2 captures the performance of several training procedures on the GSM8k test set containing 1319 word problems. We fine-tune Mistral-7B for up to three iterations. In the first iteration, we use supervised fine-tuning (SFT) to obtain M1. For the second iteration, we compare SFT, DPO [31], and KTO. The KTO-trained model performs better in this group. We call this M2 and use M2 to generate the dataset for iteration #3. For the third iteration, we compare DPO and KTO where M2 serves as the starting point. We also compare these against three epochs of SFT training on the Orca-Math-200K dataset. For all SFT training, we employ a constant learning rate of $1 \times 10^{-6}$. The per-device batch size is set to 3 and the number of epochs is set to 1. For DPO and KTO training jobs, we set beta to 0.3, per-device batch size to 3, gradient-accumulation-steps to 11, and number of epochs to 1. For DPO and KTO training in iteration #2, we employ a constant learning rate of $1 \times 10^{-6}$ and for iteration #3, a constant learning rate of $1 \times 10^{-7}$.

| Training Procedure | Pass@1 Accuracy on GSM8K Test Set |
| --- | --- |
| SFT (M1) | 79.91 |
| SFT (M1) → SFT | 81.50 |
| SFT (M1) → DPO | 84.23 |
| SFT (M1) → KTO (M2) | 85.06 |
| SFT (M1) → SFT → SFT | 80.44 |
| SFT → KTO (M2) → DPO | 84.91 |
| SFT → KTO (M2) → KTO (Orca-Math) | 86.87 |

Table 2: Performance of various iterative learning experiments and baselines on the GSM8k test set. SFT stands for one epoch of training on the Orca-Math-200K dataset. SFT → SFT stands for two epochs of training on Orca-Math-200K. SFT → DPO (KTO) stands for one epoch of training on the dataset for iteration #2 with DPO (KTO) starting with M1. SFT → SFT → SFT stands for three epochs of training on Orca-Math-200K. SFT → KTO → DPO (KTO) stands for one epoch of training on the dataset for iteration #3 with DPO (KTO) starting with M2. For evaluation, we employ greedy decoding.
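
As a sketch only, an iteration #2 preference-training run could be configured roughly as below with TRL's DPOTrainer using the hyperparameters reported above; class and argument names follow recent TRL releases and are assumptions, not the authors' training script.

```python
# Rough sketch of an iteration-#2 DPO run with the reported hyperparameters
# (beta 0.3, constant lr 1e-6, per-device batch size 3, gradient accumulation 11,
# one epoch). Checkpoint name and dataset wiring are placeholders; older TRL
# versions take `tokenizer=` instead of `processing_class=`.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_name = "orca-math-sft-iter1"  # placeholder for the M1 checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Preference pairs, e.g. produced by the sketch in Section 3.2.
pairs = [{"prompt": "USER:\n...\n\nASSISTANT:\n", "chosen": "...", "rejected": "..."}]
train_dataset = Dataset.from_list(pairs)

config = DPOConfig(
    output_dir="orca-math-dpo-iter2",
    beta=0.3,
    learning_rate=1e-6,
    lr_scheduler_type="constant",
    per_device_train_batch_size=3,
    gradient_accumulation_steps=11,
    num_train_epochs=1,
)
trainer = DPOTrainer(model=model, args=config, train_dataset=train_dataset,
                     processing_class=tokenizer)
trainer.train()
```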

5.1 Ablation Studies

Model Generated Positives

We study the impact of model-generated positives by limiting the dataset to contain only teacher-generated solutions. In other words, we remove any model-generated \(a^+_i\) from \(q^+_i\) in the creation of the dataset for iteration #2. Table 3 shows the result of training M1 with DPO and KTO on this dataset for one epoch. We reuse the hyperparameters for iteration #2. Irrespective of the training algorithm, we see a significant performance drop.

| Training Procedure | Pass@1 Accuracy on GSM8K Test Set |
| --- | --- |
| M1 → DPO | 81.96 (-2.27) |
| M1 → KTO | 82.79 (-2.27) |

Table 3: Performance drop when using only teacher-generated positives.

Synthetic Negatives

The preference dataset creation involves synthetic negative creation in the situation where all four responses generated from M1 or M2 are positive. We study the impact of these synthetic negatives by ignoring the questions, $q_i$, where all sampled responses are positive (Table 4). This reduces the number of questions for iteration #2 by around 80k and for iteration #3 by around 104k.

| Training Procedure | Pass@1 Accuracy on GSM8K Test Set |
| --- | --- |
| M1 → DPO | 60.73 (-23.5) |
| M1 → KTO | 85.22 (+0.17) |
| M1 → KTO → KTO | 85.44 (-1.43) |

Table 4: Impact of ignoring problems where all sampled responses are positive.

5.2 Math Benchmarks beyond GSM8k

Table 5 presents the performance of Orca-Math on several other word problem datasets. For ease of evaluation, we selected datasets where the answer to each problem is a single number. The test sets of the benchmarks are obtained from Lila. We employ the GPT4-based exact-match metric, and model responses are generated using greedy decoding.

| Test Set | Orca-Math-SFT (M1) | Orca-Math |
| --- | --- | --- |
| AddSub | 88.99 | 91.74 |
| ASDiv | 91.10 | 91.10 |
| MultiArith | 98.28 | 98.28 |
| SingleOp | 98.74 | 99.37 |
| SingleEq | 97.25 | 99.08 |
| Svamp | 87.63 | 91.30 |

Table 5: Performance of SFT trained model from Iteration #1 (M1) and Orca-Math on AddSub, ASDiv, MultiArith, SingleOp, SingleEq, and Svamp.

5.3 Contamination Check

We never use the test split of GSM8K or any other datasets during training or as seeds for synthetic problem generation. Nevertheless, we take the following approach for detecting any potential text contamination (a code sketch follows the list below).

  1. We begin by preprocessing the texts, which includes converting all characters to lowercase, removing punctuation, tokenizing the text into individual words, and removing common English stopwords to ensure uniformity in the data.
  2. We then vectorize our text corpus using the Term Frequency-Inverse Document Frequency (TF-IDF) method and determine the cosine similarity between the test and training sets, from which we select the top-k (k=10) most analogous questions for each test query.
  3. Finally, we evaluate the extent of text contamination by counting the number of test questions with the highest n-gram overlap above a preset threshold of 0.5 with their corresponding training set matches. We calculate the overlap of n-grams between pairs of texts using the Jaccard similarity. To conduct a rigorous contamination check, we set n=1. It is important to note that the n-gram overlap, when measured using Jaccard similarity, is a non-increasing function of n.
  4. Upon executing our algorithm, we determined that the count of test questions exhibiting significant n-gram overlap is eight, thus indicating negligible text contamination within our test set according to the defined threshold. When limiting the train set to contain only the seed problems, the count of test questions exhibiting significant n-gram overlap is seven. Note that, for n ≥ 2, the count of test questions exhibiting significant n-gram overlap is zero.
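
The contamination check above might be sketched as follows; preprocessing and retrieval details are simplified assumptions.

```python
# Minimal sketch of the contamination check described above: TF-IDF cosine
# similarity retrieves the top-k most similar training questions for each test
# question, then 1-gram Jaccard overlap against those candidates is thresholded
# at 0.5. Preprocessing details are simplified assumptions.
import re
from sklearn.feature_extraction.text import TfidfVectorizer, ENGLISH_STOP_WORDS

def preprocess(text: str) -> list[str]:
    tokens = re.findall(r"[a-z0-9]+", text.lower())            # lowercase, strip punctuation
    return [t for t in tokens if t not in ENGLISH_STOP_WORDS]  # drop common stopwords

def ngrams(tokens: list[str], n: int) -> set[tuple[str, ...]]:
    return {tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)}

def jaccard(a: set, b: set) -> float:
    return len(a & b) / len(a | b) if (a | b) else 0.0

def count_contaminated(train_qs: list[str], test_qs: list[str],
                       k: int = 10, threshold: float = 0.5, n: int = 1) -> int:
    corpus = [" ".join(preprocess(q)) for q in train_qs + test_qs]
    tfidf = TfidfVectorizer().fit_transform(corpus)
    train_vecs, test_vecs = tfidf[:len(train_qs)], tfidf[len(train_qs):]
    sims = (test_vecs @ train_vecs.T).toarray()  # cosine similarity: TF-IDF rows are L2-normalized
    contaminated = 0
    for i, test_q in enumerate(test_qs):
        top_k = sims[i].argsort()[::-1][:k]      # indices of the k most similar training questions
        test_grams = ngrams(preprocess(test_q), n)
        overlaps = [jaccard(test_grams, ngrams(preprocess(train_qs[j]), n)) for j in top_k]
        if max(overlaps, default=0.0) >= threshold:
            contaminated += 1
    return contaminated
```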

6 Related Work

The generation of synthetic data through generative artificial intelligence (AI) models has evolved rapidly. Numerous datasets [27, 20, 28, 23, 9, 8, 45, 6, 36] have been proposed for both specialized and generic domains, with math-related datasets [40, 43, 44, 18] being closely related to our work.

Learning from rich signals has also garnered significant attention recently. Several studies [31, 10, 22, 3, 5, 41] have demonstrated the usefulness of preference learning. In this work, we present a detailed analysis of agent-based synthetic data generation and iterative preference learning in the grade school level math domain. Specifically, we demonstrate the robustness of KTO over DPO and the effectiveness of using model-generated positives to improve model training. We believe this is a preliminary step toward iterative learning and self-improvement of small language models in challenging domains.

7 Conclusions

Our study provides compelling evidence that the mathematical reasoning capabilities of Small Language Models (SLMs) can be substantially enhanced. By employing iterative learning techniques and leveraging both positive and negative signals, we have successfully surpassed the previously perceived 80% barrier on the GSM8k benchmark. Our 7B model, trained on 200K examples, achieved an impressive 86.81% accuracy. Furthermore, the incorporation of agents in dataset generation has proven to be a valuable approach, enabling the creation of more diverse and interesting datasets. These findings not only highlight the potential for significant improvements in SLM performance but also underscore the importance of innovative learning strategies and dataset generation methods in advancing the creation of powerful SLMs.
