This post systematically presents a method for generating high-quality synthetic data and using it to iteratively improve the mathematical reasoning ability of a small language model. To improve mathematical reasoning, the dataset's diversity and difficulty are increased and the model's performance is improved step by step through iterative training. The result demonstrates that a small language model can reach performance competitive with much larger models.
1 Problem Setup
Large language models (LLMs) such as GPT-4 exhibit mathematical reasoning capabilities not seen in small language models (SLMs). These capabilities are largely attributed to model scale, dataset size, and the amount of compute used for training. Recent work has focused on improving the reasoning abilities of SLMs, and one promising direction is to use a large model to create tailored, high-quality synthetic data for the SLM. This work focuses on mathematical reasoning, specifically grade school math problems, using the GSM8K benchmark.
2 Dataset Generation: Agent-Instruct
A variety of agents are used to generate data that increases both the diversity and the difficulty of grade school math problems.
2.1 Seed Dataset
A total of 36,217 problems are collected from existing open-source datasets, using the Lila benchmark.
2.2 Agent - Ask Me Anything
New problems are generated by transforming each problem in the seed set: the original question is converted into a statement, and a new problem is then created for each number in the converted problem. GPT-4-Turbo is used to generate the solutions to these problems.
2.3 Agent - Suggester & Editor
A pair of agents makes existing problems harder: the Suggester proposes ways to increase a problem's difficulty, and the Editor rewrites the problem accordingly. This process produces 37,157 additional problems.
2.4 DMath Dataset
An additional 6,216 problems from DMath are included.
3 Training
3.1 Supervised Fine-Tuning Experiment (Iteration #1)
Mistral-7B is fine-tuned on the Orca-Math-200K dataset with a learning rate of $1 \times 10^{-6}$ and a per-device batch size of 3. Training runs for one epoch on eight A100 nodes.
3.2 Iterative Learning from both Positive and Negative Signals (Iteration #2)
Four responses are sampled from the model fine-tuned in the first iteration to obtain positive and negative solutions. A GPT-4-based exact-match criterion is then used to judge whether each solution matches the teacher's answer, and the resulting labels are used to build the training dataset.
3.3 Iterative Learning (Iteration #3)
The model produced in the second iteration (M2) is used to generate four responses again, and a new training dataset is built from them. Direct Preference Optimization (DPO) and Kahneman-Tversky Optimization (KTO) are used to learn from both positive and negative feedback.
4 Evaluation
Model performance is evaluated with a GPT-4-based exact-match criterion: for example, GPT-4 checks whether the student's answer matches the gold answer.
5 Results
Performance on the GSM8k test set is evaluated across the different training procedures. The first iteration reaches 79.91% accuracy with supervised fine-tuning, the second iteration improves this to 85.06% with KTO, and the third iteration reaches 86.87% with KTO.
6 Related Work
This work presents a method for improving the mathematical reasoning of SLMs through training on synthetic data and iterative learning, demonstrating in particular the strong performance of KTO and the effectiveness of model-generated positive solutions. It is positioned as a preliminary step toward iterative learning and self-improvement of small language models.
[Reference 1] A Closer Look at the Key Sections
3.2 Iterative Learning from both Positive and Negative Signals
Dataset Construction Iteration #2
To generate additional positive and negative solutions for each problem, four responses are sampled from the model fine-tuned in the first iteration; specifically, sampling uses top_p = 0.95 and temperature = 0.7.
This yields a dataset in which each of the 200,000 problems has one GPT-4-Turbo-generated solution and four student-generated solutions. The prompt defined in GPT4-Based-Exact-Match (see Section 4) is then used to assess whether the teacher's (GPT-4-Turbo) answer and the student's answer agree: student-generated solutions that do not match the teacher's answer are labeled negative, and the rest are labeled positive. The preference dataset is then constructed by forming, for each question \(q_i\), the set of positive solutions \(q^+_i\) (which always contains the teacher's solution) and the set of negative solutions \(q^-_i\); when all four sampled responses are positive, negatives are borrowed from other questions.
The final preference dataset is created by taking the union of \(Q_i\) over all \(q_i\):
\[Q_i = \{ (q_i, a^+_i, a^-_i) \mid a^+_i \in q^+_i \text{ and } a^-_i \in q^-_i \}\]
Dataset Construction Iteration #3
\(M_2\) is used to generate four responses again, and a new training dataset is constructed from them. The performance of two algorithms is evaluated.
DPO is a simple and popular approach for efficiently fine-tuning language models to align with preferences. KTO requires only a binary “yes” or “no” response to assess the quality of an output.
4 Evaluation
Exact match is used as the metric: given a model-generated solution, GPT-4 is prompted to extract the final short answer and match it against the gold answer. This is referred to as GPT4-based-Exact-Match. An example of the prompt template follows.
SYSTEM: As an expert Math teacher, your role is to evaluate a student's answer to a word problem. The problem is accompanied by a correct solution provided by the problem setter. It is important to remember that there may be various methods to solve a word problem, so the student's steps might not always align with those in the problem setter's solution. However, the final answer, typically a number, should be unique and match the problem setter's answer. Your task involves analyzing the student's solution to identify any mistakes and determine whether the answer can be modified to correct the error. If the student's answer is unfixable, consider creating practice problems to help improve their understanding.
Error Analysis: Extract the final answer from the problem setter's solution and compare it with the student's answer. Do they match?
Final Verdict: Correct/Incorrect
USER: Question: Billy is volunteering his time to help people do their taxes. He can help 2 people per hour for 3 hours a day. If he takes 20% of the days between March 1st and April 19th off, and helps people on all the other days, how many people does he help? (Remember there are 31 days in March.)
Problem Setter’s Answer:
Student Answer:
Therefore, Billy helps 240 people.
ASSISTANT:
Performance of Training Procedures on GSM8K Test Set
Table 2 shows the performance of several training procedures on the GSM8k test set, which contains 1319 word problems. Mistral-7B is fine-tuned for up to three iterations. In the first iteration, supervised fine-tuning (SFT) is used to obtain M1. In the second iteration, SFT, DPO [31], and KTO are compared; the KTO-trained model performs best in this group, is called M2, and is used to generate the dataset for the third iteration. In the third iteration, DPO and KTO are compared with M2 as the starting point, alongside a baseline of three epochs of SFT on the Orca-Math-200K dataset. All SFT runs use a constant learning rate of $1 \times 10^{-6}$, a per-device batch size of 3, and one epoch. The DPO and KTO runs use beta = 0.3, a per-device batch size of 3, 11 gradient-accumulation steps, and one epoch; DPO and KTO training in the second iteration uses a constant learning rate of $1 \times 10^{-6}$, and the third iteration uses $1 \times 10^{-7}$.
Training Procedure | Pass@1 Accuracy on GSM8K Test Set |
---|---|
SFT (M1) | 79.91 |
SFT (M1) → SFT | 81.50 |
SFT (M1) → DPO | 84.23 |
SFT (M1) → KTO (M2) | 85.06 |
SFT (M1) → SFT → SFT | 80.44 |
SFT → KTO (M2) → DPO | 84.91 |
SFT → KTO (M2) → KTO (Orca-Math) | 86.87 |
Table 2: Performance of various iterative learning experiments and baselines on the GSM8k test set. SFT denotes one epoch of training on the Orca-Math-200K dataset. SFT → SFT denotes two epochs of training on Orca-Math-200K. SFT → DPO (KTO) denotes one epoch of DPO (KTO) training on the iteration #2 dataset starting from M1. SFT → SFT → SFT denotes three epochs of training on Orca-Math-200K. SFT → KTO → DPO (KTO) denotes one epoch of DPO (KTO) training on the iteration #3 dataset starting from M2. Greedy decoding is used for evaluation.
5.1 Ablation Studies
Model Generated Positives
To study the impact of model-generated positives, the dataset is restricted to teacher-generated solutions only; that is, every model-generated \(a^+_i\) is removed from \(q^+_i\) when building the iteration #2 dataset. Table 3 shows the result of training M1 with DPO and KTO on this dataset for one epoch, reusing the iteration #2 hyperparameters. Regardless of the training algorithm, there is a significant drop in performance.
Training Procedure | Pass@1 Accuracy on GSM8K Test Set |
---|---|
M1 → DPO | 81.96 (-2.27) |
M1 → KTO | 82.79 (-2.27) |
Table 3: Performance drop when only teacher-generated positive solutions are used.
Synthetic Negatives
When all four sampled responses are positive, preference-dataset creation involves generating synthetic negatives. The impact of these synthetic negatives is studied by ignoring any question \(q_i\) for which all sampled responses are positive (Table 4). This reduces the number of questions by about 80k for iteration #2 and about 104k for iteration #3.
Training Procedure | Pass@1 Accuracy on GSM8K Test Set |
---|---|
M1 → DPO | 60.73 (-23.5) |
M1 → KTO | 85.22 (+0.17) |
M1 → KTO → KTO | 85.44 (-1.43) |
Table 4: Shows that including problems where all sampled responses are positive is beneficial.
5.2 Math Benchmarks beyond GSM8K
Table 5 shows Orca-Math's performance on several other word-problem datasets. For ease of evaluation, datasets whose answers are a single number were selected. The test sets of the benchmarks are taken from Lila. The GPT4-based exact-match metric is used, and model responses are generated with greedy decoding.
Test Set | Orca-Math-SFT (M1) | Orca-Math |
---|---|---|
AddSub | 88.99 | 91.74 |
ASDiv | 91.10 | 91.10 |
MultiArith | 98.28 | 98.28 |
SingleOp | 98.74 | 99.37 |
SingleEq | 97.25 | 99.08 |
Svamp | 87.63 | 91.30 |
Table 5: Performance of the SFT-trained model from Iteration #1 (M1) and Orca-Math on AddSub, ASDiv, MultiArith, SingleOp, SingleEq, and Svamp.
[Reference 2] Kahneman-Tversky Optimization (KTO)
Kahneman-Tversky Optimization (KTO) is an optimization approach grounded in the behavioral-economics work of psychologists Daniel Kahneman and Amos Tversky. It aims to solve more realistic optimization problems by using a model of how people actually make decisions, particularly in environments involving uncertainty and risk.
Theoretical Background
KTO is primarily based on Kahneman and Tversky's Prospect Theory, which describes how people actually make choices, in contrast to the rational model of expected-utility theory. Its key elements are reference dependence, loss aversion, diminishing sensitivity, and probability weighting.
KTO as an Optimization Problem
KTO defines an optimization problem that reflects these human decision-making biases by incorporating elements of prospect theory into the objective function of a traditional optimization problem.
\(\alpha\) and \(\beta\) are curvature (diminishing-sensitivity) exponents, typically \(0 < \alpha, \beta \leq 1\), and \(\lambda\) is the loss-aversion coefficient with \(\lambda > 1\).
\(0 < \gamma \leq 1\) is the probability-weighting exponent.
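The value function \(v\) and probability-weighting function \(\pi\) that these coefficients parameterize are not written out above; a standard formulation, assumed here so that it matches how \(v\) and \(\pi\) are used in the worked example below, is

\[v(x) = \begin{cases} x^{\alpha}, & x \geq 0 \\ -\lambda (-x)^{\beta}, & x < 0 \end{cases} \qquad \pi(p) = \frac{p^{\gamma}}{\left(p^{\gamma} + (1 - p)^{\gamma}\right)^{1/\gamma}},\]

where \(x\) is measured relative to the reference point \(r\).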
The KTO Objective Function
The KTO objective combines these elements. For example, if a decision maker has \(n\) alternatives and alternative \(i\) has outcome \(x_i\) with probability \(p_i\), the KTO objective can be defined as:
\[\text{KTO Objective} = \sum_{i=1}^n \pi(p_i)\, v(x_i - r)\]
where \(r\) is the reference point.
For example, suppose an investor has two investment options, \(A\) and \(B\): option \(A\) yields a gain of $100 with probability 0.5 and a loss of $50 with probability 0.5, while option \(B\) yields a certain gain of $20.
In this case, the investor's reference point \(r\) is set to $0, and the value function and probability-weighting function are applied to compute the expected (prospect) value of each option.
Expected value of option \(A\)
\(x_1 = 100\), \(p_1 = 0.5\)
\(x_2 = -50\), \(p_2 = 0.5\)
\[v(100) = 100^\alpha\]
\[v(-50) = -\lambda (50^\beta)\]
\[\pi(0.5) = \frac{0.5^\gamma}{(0.5^\gamma + 0.5^\gamma)^{1/\gamma}} = 0.5 \quad (\text{taking } \gamma = 1)\]
Therefore,
\[\text{KTO}(A) = 0.5 \cdot 100^\alpha + 0.5 \cdot (-\lambda (50^\beta))\]
Expected value of option \(B\)
\(x = 20\), \(p = 1\)
\[v(20) = 20^\alpha\]
\[\pi(1) = 1\]
Therefore,
\[\text{KTO}(B) = 1 \cdot 20^\alpha\]
Comparing the two expected values, the investor can then select the better option.
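To make the comparison concrete, the sketch below evaluates both options numerically. The parameter values are illustrative assumptions only: \(\alpha = \beta = 0.88\) and \(\lambda = 2.25\) are the commonly cited Tversky-Kahneman (1992) estimates, and \(\gamma = 1\) matches the simplification \(\pi(0.5) = 0.5\) used above.

```python
# Numeric check of the option A / option B example; parameter values are
# illustrative assumptions, not part of the original text.
ALPHA = BETA = 0.88   # diminishing-sensitivity exponents
LAM = 2.25            # loss-aversion coefficient (lambda > 1)
GAMMA = 1.0           # probability-weighting exponent (pi(0.5) = 0.5)

def value(x: float) -> float:
    """Prospect-theory value relative to the reference point r = 0."""
    return x ** ALPHA if x >= 0 else -LAM * ((-x) ** BETA)

def weight(p: float) -> float:
    """Probability weighting pi(p) = p^g / (p^g + (1 - p)^g)^(1/g)."""
    return p ** GAMMA / (p ** GAMMA + (1 - p) ** GAMMA) ** (1 / GAMMA)

# Option A: gain $100 with probability 0.5, lose $50 with probability 0.5.
kto_a = weight(0.5) * value(100) + weight(0.5) * value(-50)
# Option B: gain $20 with certainty.
kto_b = weight(1.0) * value(20)

print(f"KTO(A) = {kto_a:.2f}")  # about -6.4: the possible loss looms large
print(f"KTO(B) = {kto_b:.2f}")  # about 14.0: the certain gain is preferred
```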
Conclusion
Kahneman-Tversky Optimization (KTO) integrates behavioral-economics elements into traditional optimization techniques to support more realistic decision making. It can reflect the biases observed in real human decision making and can be applied effectively to optimization problems in environments involving uncertainty and risk.
Frontier Language Models such as GPT-4 [1] have demonstrated capabilities previously unseen in smaller models, most notably the remarkable ability to reason (e.g. mathematical reasoning that requires both language comprehension and mathematical understanding). These capabilities have been largely attributed to the very large scale of the model size, the dataset size, and ultimately the amount of compute needed for training.
Several recent studies have focused on improving the reasoning abilities of small language models (SLMs). Despite that, the extent to which scale is needed for achieving reasoning capabilities is still an open research question.
One of the promising directions for improving the reasoning capabilities of SLMs is using frontier language models, such as GPT-4, to create tailored and high-quality synthetic data that can be used to train the SLM. The high quality of the training data and the ability to elicit richer learning signals (e.g. explanations) have been shown to significantly improve SLMs' abilities in acquiring skills that had previously emerged only at much larger scale.
This paradigm fits under a teacher-student approach where the large model (the teacher) creates demonstrations for the SLM (the student) to learn from. In this work we further explore this direction with a focus on mathematical reasoning on grade school math word problems, using the popular GSM8K benchmark.
Several other studies have recently demonstrated positive results on GSM8K with SLMs, e.g. Phi-GSM [21], OVM [38], etc. However, many of them employ ensembling, where outputs of up to 100 model runs are combined to arrive at a more accurate result. Result selection is done using consensus, majority vote, or a separate verifier model to score/verify the outputs and select the best answer. Ensembling provides a substantial boost in accuracy (e.g., Phi-GSM uses top-48 to boost the performance from 68.2 to 81.5, and [22] uses top-100 to boost LLAMA-2's performance from 38.6% to 71.9%). However, it comes at a significant increase in cost with multiple calls to the model; generating and verifying 100 different solutions requires 200 different calls to the model. Additionally, some of them use very large amounts of data (e.g. 12M for Phi-GSM) or use tools or code to avoid calculation errors.
In this work, we extend the teacher-student paradigm to an iterative learning setting with high-quality synthetic training data as follows: the SLM is first trained with supervised fine-tuning on teacher demonstrations (iteration #1); its own sampled solutions are then labeled positive or negative against the teacher's answers and used for preference learning (iteration #2); and the preference-learning step is repeated with the improved model (iteration #3).
With supervised fine-tuning alone, we achieve 81.50% on GSM8k at the pass@1 metric. The iterative learning loop further improves pass@1 to 86.81%, without the need for multiple model calls or the use of verifiers, code execution, or any other external tools. The model exceeds much bigger models like LLAMA-2-70B (56.8%), WizardMath-70B (81.6%), Gemini Pro (86.5% with 32 trials), and GPT-3.5 (77.4%). Most notably, it can reach this level with only 200K examples (orders of magnitude less than other datasets).
The goal of this step is to create a diverse set of grade school math word problems that contains both easy and hard problems. Towards this goal we create a variety of agents.
Seed Set We start by collecting sample math word problems from existing open-source datasets, namely NumGLUE [26], AddSub [13], ALGES [17], ASDiv [24], DRAW [35], GSM8k [7], MATHQA [2], MultiArith [32], SingleOP [33], SingleEQ [16], and SVAMP [30]. We collect a total of 36,217 problems. We utilize the Lila [25] benchmark to collect the datasets. Specifically, we collect problems from the train and validation splits of Lila to construct the seed set. Interested readers may refer to Lila [25].
Agent - Ask Me Anything
We expand the seed set by creating multiple word problems from each problem in the seed set. We utilize the subsequent prompt for problem creation.
Prompt for Problem Creation:
Your goal is to create multiple word problems from a given word problem and its answer. First convert the question of the word problem into a statement. Then for each number in the converted problem create a new word problem. Here are some examples:
Example 1:
Q: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
Answer: 72
Replacing question with statement: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. Natalia sold altogether 72 clips in April and May.
All questions:
Target | Question |
---|---|
48 | Natalia sold clips to some of her friends in April, and then she sold half as many clips in May. Natalia sold altogether 72 clips in April and May. How many clips did she sell in April? |
half | Natalia sold clips to 48 of her friends in April, and then she sold some clips in May. Natalia sold altogether 72 clips in April and May. What is the ratio of the number clips sold in April to number clips sold in May? |
72 | Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? |
Example 2:
Q: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?
Answer: 10
Replacing question with statement: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. She earned $10.
All questions:
Target | Question |
---|---|
12 | Weng earns a certain amount per hour for babysitting. Yesterday, she just did 50 minutes of babysitting and earned 10. How much does she earn per hour? |
50 | Weng earns 12 an hour for babysitting. Yesterday, she just did some babysitting and earned 10. How much time did she spend on babysitting? |
10 | Weng earns 12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn? |
Example 3:
Q: Betty is saving money for a new wallet which costs 100. Betty has only half of the money she needs. Her parents decided to give her 15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet?
Answer: 5
Replacing question with statement: Betty is saving money for a new wallet which costs 100. Betty has only half of the money she needs. Her parents decided to give her 15 for that purpose, and her grandparents gave her twice as much as her parents. She needs 5 more to buy the wallet.
All questions:
Target | Question |
---|---|
100 | Betty is saving money for a new wallet. Betty has only half of the money she needs. Her parents decided to give her 15 for that purpose, and her grandparents twice as much as her parents. She needs 5 more to buy the wallet. What is the cost of the wallet? |
half | Betty is saving money for a new wallet which costs 100. She has some money saved, her parents decided to give her 15, and her grandparents gave her twice as much as her parents. Now, Betty needs 5 more to buy the wallet. What is the ratio of the money Betty have saved initially to the cost of wallet? |
15 | Betty is saving money for a new wallet which costs 100. She has half of the money she needs, her parents decided to give her some money, and her grandparents gave her twice as much as her parents. Now, Betty needs 5 more to buy the wallet. How much money did her parents give her? |
twice | Betty is saving money for a new wallet which costs 100. Betty has only half of the money she needs. Her parents decided to give her 15 for that purpose, and her grandparents also chipped in. Now, Betty needs 5 more to buy the wallet. What is the ratio of the amount given by her grandparents to the amount given by her parents? |
5 | Betty is saving money for a new wallet which costs 100. Betty has only half of the money she needs. Her parents decided to give her 15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet? |
Example 4:
Q: Your teacher is giving a test worth 200 points. There is a total of 30 5-point and 10-point questions. How many 5-point questions are on the test?
Answer: 20
Note: The “Ask Me Anything” agent generates problems based on the seed in Example 4. Examples 1 to 3 are provided as few-shot demonstrations. This agent creates a total of 120,445 new problems. All generated problems exhibit a similar narrative to that of the seed word problem. The solutions to these word problems are generated using GPT-4-Turbo.
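Below is a minimal sketch of how this expansion step might be driven programmatically. The OpenAI client usage, the `gpt-4-turbo` model identifier, and the abridged prompt wiring are illustrative assumptions, not the authors' actual pipeline.

```python
# Hypothetical driver for the "Ask Me Anything" agent: format the few-shot
# prompt shown above around a seed problem and ask the teacher model for
# the expanded set of word problems.
from openai import OpenAI

client = OpenAI()

ASK_ME_ANYTHING_PROMPT = """Your goal is to create multiple word problems from \
a given word problem and its answer. First convert the question of the word \
problem into a statement. Then for each number in the converted problem create \
a new word problem. Here are some examples:

{few_shot_examples}

Example 4:
Q: {seed_question}
Answer: {seed_answer}
"""

def expand_seed_problem(seed_question: str, seed_answer: str,
                        few_shot_examples: str) -> str:
    """Return the teacher model's expansion of one seed problem."""
    response = client.chat.completions.create(
        model="gpt-4-turbo",  # assumed teacher model identifier
        messages=[{"role": "user",
                   "content": ASK_ME_ANYTHING_PROMPT.format(
                       few_shot_examples=few_shot_examples,
                       seed_question=seed_question,
                       seed_answer=seed_answer)}],
    )
    return response.choices[0].message.content
```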
Agent - Suggester & Editor
We further expand the seed set by developing challenging problems. To achieve this, we introduce two new agents, namely Suggester and Editor. This setup involves the establishment of a team of agents working collaboratively to create a dataset geared toward a predefined objective. In this instance, the objective is to modify an existing problem to increase its difficulty.
Example of Iterative Process:
Seed Problem: Every hour Joanne has to collect the coins out of the fountain inside the mall. During the first hour, she collected 15 coins. For the next two hours, she collected 35 coins from the fountain. In the fourth hour, she collected 50 coins from the fountain but she gave 15 of them to her coworker so she could buy a soda. How many coins did she have after the fourth hour?
Suggester: Increase the number of hours, add more variables for each hour, and incorporate additional transactions.
Editor: Every hour, Joanne has to collect the coins from two separate fountains inside the mall - Fountain A and Fountain B. During the first hour, she collected 20 coins from Fountain A and 30 coins from Fountain B. In the next hour, she collected 25 coins more than she collected in Fountain A during the previous hour and 10 coins less than she collected in Fountain B during the previous hour. In the third hour, she collected the same number of coins in Fountain A as she collected in the second hour and additional 45 coins this time from Fountain B. During the fourth hour, she collected 90 coins in total, but she gave 20 coins to her coworker to buy a soda and 40 coins to another coworker to buy a meal.
Suggester: Incorporate more fountains, increase the duration, and add more complex calculations for coins collected and spent or donated.
Editor: Joanne has to collect coins from four separate fountains inside the mall - Fountain A, Fountain B, Fountain C, and Fountain D every hour for 8 hours. In the first hour, she collects 20 coins from Fountain A, 30 coins from Fountain B, 40 coins from Fountain C, and 50 coins from Fountain D. In the subsequent hours, her collections have the following fluctuations:
At the end of the third hour, she donates 20 coins. During the seventh hour, she has to share coins with her coworkers and splits them as follows: 15 coins for coworker one, 20 coins for coworker two, and she keeps 70 coins. After 8 hours, Joanne decides to count her collected coins. How many coins did she have at the end of the eighth hour?
We allow two rounds of iterations per problem and filter out problems where the GPT4-Turbo-generated answer exceeds 1800 characters. At the end of this process, we collect 37,157 problems.
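A hedged sketch of this Suggester-Editor loop (two rounds per seed problem, filtering out problems whose generated solution exceeds 1800 characters) is given below; the prompts, the model identifier, and the helper functions are illustrative assumptions.

```python
# Hypothetical Suggester & Editor loop over one seed problem.
from openai import OpenAI

client = OpenAI()

def chat(prompt: str) -> str:
    """Single teacher-model call; gpt-4-turbo is an assumed model identifier."""
    response = client.chat.completions.create(
        model="gpt-4-turbo",
        messages=[{"role": "user", "content": prompt}],
    )
    return response.choices[0].message.content

def harden_problem(seed_problem: str, rounds: int = 2,
                   max_answer_chars: int = 1800) -> list[str]:
    """Run Suggester -> Editor for a fixed number of rounds and keep problems
    whose generated solution is not overly long."""
    kept, current = [], seed_problem
    for _ in range(rounds):
        suggestion = chat("Suggest ways to make this math word problem harder, "
                          "without solving it:\n" + current)
        current = chat("Rewrite the problem below to apply the suggestions. "
                       "Return only the new problem.\n"
                       f"Problem:\n{current}\nSuggestions:\n{suggestion}")
        answer = chat("Solve the following problem step by step:\n" + current)
        if len(answer) <= max_answer_chars:   # filter overly long solutions
            kept.append(current)
    return kept
```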
DMath Furthermore, we include 6,216 problems sourced from DMath [15]. These problems represent a subset of the 7,943 problems present in the DMath training set, in which the solution computed by GPT4-Turbo aligns with the precise gold-standard answer.
We finetune Mistral-7B on the Orca-Math-200K dataset. We have not used packing. The data is presented in the following instruction format:
USER:\n{question}\n\nASSISTANT:\n{answer}
The loss is computed only on the answer tokens. We employ a constant learning rate of $1 \times 10^{-6}$. The per-device batch size is set to 3. Training is conducted for one epoch on eight A100 nodes, with each node containing eight GPUs.
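A minimal sketch of this instruction format and the answer-only loss masking, assuming a Hugging Face tokenizer (the checkpoint name is a placeholder), is shown below.

```python
# Build one SFT example: tokenize "USER:\n{q}\n\nASSISTANT:\n{a}" and mask the
# prompt tokens so that loss is computed only on the answer tokens.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # placeholder

def build_example(question: str, answer: str) -> dict:
    prompt = f"USER:\n{question}\n\nASSISTANT:\n"
    full = prompt + answer + tokenizer.eos_token
    prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
    full_ids = tokenizer(full, add_special_tokens=False)["input_ids"]
    # Assumes the prompt tokenization is a prefix of the full tokenization.
    labels = [-100] * len(prompt_ids) + full_ids[len(prompt_ids):]
    return {"input_ids": full_ids, "labels": labels}
```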
## 3.2 Iterative Learning from both Positive and Negative Signals
### Dataset Construction Iteration #2
To generate additional positive and negative solutions for each problem, we sample four responses from the SFT-tuned model from iteration #1. Specifically, we utilize:
- **top_p**: 0.95
- **temperature**: 0.7
This process results in a dataset where each of the 200,000 problems has one GPT4-Turbo generated solution and four student-generated solutions. Subsequently, we employ the prompt defined in GPT4-Based-Exact-Match (see section 4 for details) to assess the alignment between the teacher’s (GPT4-Turbo) answer and the student’s answer. For all solutions where the student-generated answer does not match the teacher’s answer, we label them as negative; otherwise, we label the solution as positive. We then construct the preference dataset as follows:
1. For each question, $$ q_i $$, we construct $$ q^+_i $$ (the set of all positive solutions for $$ q_i $$). We treat the teacher solution as positive, thus this set by construction contains at least one element.
2. For each question, $$ q_i $$, we also construct $$ q^-_i $$ (the set of all negative solutions for $$ q_i $$). This set can be empty if all the four responses are aligned with the teacher’s solution. For such situations, we randomly sample one response from $$ q^-_j $$ for 4 different $$ q_j $$ where $$ j \neq i $$.
The final preference dataset is created by taking the union of $$ Q_i $$ for all $$ q_i $$ in the training dataset.
$$ Q_i = \{ (q_i, a^+_i, a^-_i) \mid a^+_i \in q^+_i \text{ and } a^-_i \in q^-_i \} $$
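A sketch of this pairing logic, including the synthetic-negative fallback from step 2, is given below; the data structures and helper names are assumptions.

```python
# Build (question, positive, negative) triples, i.e. the union of Q_i.
# `solutions[q]` is assumed to map each question to a list of
# (solution_text, is_positive) pairs, with the teacher solution included.
import random
from itertools import product

def build_preference_pairs(solutions: dict[str, list[tuple[str, bool]]],
                           borrow_from: int = 4) -> list[tuple[str, str, str]]:
    pairs = []
    questions = list(solutions)
    for q in questions:
        positives = [a for a, ok in solutions[q] if ok]
        negatives = [a for a, ok in solutions[q] if not ok]
        if not negatives:
            # All sampled responses matched the teacher: borrow one negative
            # from each of `borrow_from` other questions (synthetic negatives).
            others = random.sample([x for x in questions if x != q], borrow_from)
            negatives = [random.choice([a for a, ok in solutions[o] if not ok])
                         for o in others
                         if any(not ok for _, ok in solutions[o])]
        pairs.extend((q, pos, neg) for pos, neg in product(positives, negatives))
    return pairs
```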
### Dataset Construction Iteration #3
Let $$ M_2 $$ denote the model trained with KTO [10] on the dataset constructed for Iteration #2. We replicate the same procedure for the construction of dataset for Iteration #3; however, we utilize $$ M_2 $$ to generate the four responses instead of the SFT-tuned model from iteration #1.
To learn from both positive and negative feedback, we evaluate the performance of two algorithms:
- **Direct Preference Optimization (DPO)** as described by [31]
- **Kahneman-Tversky Optimization (KTO)** introduced by [10]
DPO is a simple and popular approach for efficiently fine-tuning language models to align with preferences. KTO distinguishes itself by requiring only a binary “yes” or “no” response to assess the quality of an output.
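For reference, here is a minimal PyTorch-style sketch of the standard DPO objective (with beta = 0.3, as used in the training runs reported below). This is the published loss form rather than code from this work; KTO, by contrast, consumes individually labeled desirable/undesirable outputs instead of (chosen, rejected) pairs.

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps: torch.Tensor,
             policy_rejected_logps: torch.Tensor,
             ref_chosen_logps: torch.Tensor,
             ref_rejected_logps: torch.Tensor,
             beta: float = 0.3) -> torch.Tensor:
    """DPO: -log sigmoid(beta * (chosen log-ratio - rejected log-ratio)).

    Each argument is the summed log-probability of a complete solution under
    the policy being trained or under the frozen reference (SFT) model.
    """
    chosen_logratio = policy_chosen_logps - ref_chosen_logps
    rejected_logratio = policy_rejected_logps - ref_rejected_logps
    return -F.logsigmoid(beta * (chosen_logratio - rejected_logratio)).mean()
```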
## 4 Evaluation
We use exact match as the metric. Given a model generated answer, we prompt GPT-4 to extract the final short answer and match it with the gold short answer. We will refer to this metric as GPT4-based-Exact-Match. The following figure shows the prompt template:
**SYSTEM:**
As an expert Math teacher, your role is to evaluate a student’s answer to a word problem. The problem is accompanied by a correct solution provided by the problem setter. It is important to remember that there may be various methods to solve a word problem, so the student’s steps might not always align with those in the problem setter’s solution. However, the final answer, typically a number, should be unique and match the problem setter’s answer. Your task involves analyzing the student’s solution to identify any mistakes and determine whether the answer can be modified to correct the error. If the student’s answer is unfixable, consider creating practice problems to help improve their understanding.
Use the following format:
USER: Question: Billy is volunteering his time to help people do their taxes. He can help 2 people per hour for 3 hours a day. If he takes 20% of the days between March 1st and April 19th off, and helps people on all the other days. How many people does he help? (Remember there are 31 days in March.)
Problem Setter’s Answer:
Student Answer:
Therefore, Billy helps 240 people.
**ASSISTANT:**
- **Error Analysis:** The student’s final answer of helping 240 people matches the problem setter’s solution.
- **Final Verdict:** Correct
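A hedged sketch of how this judge could be invoked and its verdict parsed is shown below; the model identifier, the abridged system prompt, and the parsing heuristic are assumptions rather than the authors' exact setup.

```python
# Hypothetical driver for GPT4-based-Exact-Match: send the system prompt
# above together with the question, the problem setter's solution, and the
# student's solution, then read the "Final Verdict" line.
from openai import OpenAI

client = OpenAI()

SYSTEM_PROMPT = "As an expert Math teacher, your role is to evaluate ..."  # abridged

def is_exact_match(question: str, setter_solution: str, student_solution: str) -> bool:
    user_msg = (f"Question: {question}\n"
                f"Problem Setter's Answer: {setter_solution}\n"
                f"Student Answer: {student_solution}")
    response = client.chat.completions.create(
        model="gpt-4",  # assumed judge model identifier
        messages=[{"role": "system", "content": SYSTEM_PROMPT},
                  {"role": "user", "content": user_msg}],
        temperature=0,
    )
    verdict = response.choices[0].message.content.split("Final Verdict:")[-1]
    return verdict.strip().lower().startswith("correct")
```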
Performance of Training Procedures on GSM8K Test Set
Table 2 captures the performance of several training procedures on the GSM8k test set containing 1319 word problems. We fine-tune Mistral-7B for up to three iterations. In the first iteration, we use supervised fine-tuning (SFT) to obtain M1. For the second iteration, we compare SFT, DPO [31], and KTO. The KTO-trained model performs better in this group. We call this M2 and use M2 to generate the dataset for iteration #3. For the third iteration, we compare DPO and KTO where M2 serves as the starting point. We also compare these against three epochs of SFT training on the Orca-Math-200K dataset. For all SFT training, we employ a constant learning rate of $1 \times 10^{-6}$. The per-device batch size is set to 3 and the number of epochs is set to 1. For DPO and KTO training jobs, we set beta to 0.3, per-device batch size to 3, gradient-accumulation-steps to 11, and number of epochs to 1. For DPO and KTO training in iteration #2, we employ a constant learning rate of $1 \times 10^{-6}$ and for iteration #3, a constant learning rate of $1 \times 10^{-7}$.
Training Procedure | Pass@1 Accuracy on GSM8K Test Set |
---|---|
SFT (M1) | 79.91 |
SFT (M1) → SFT | 81.50 |
SFT (M1) → DPO | 84.23 |
SFT (M1) → KTO (M2) | 85.06 |
SFT (M1) → SFT → SFT | 80.44 |
SFT → KTO (M2) → DPO | 84.91 |
SFT → KTO (M2) → KTO (Orca-Math) | 86.87 |
Table 2: Performance of various iterative learning experiments and baselines on the GSM8k test set. SFT stands for one epoch of training on the Orca-Math-200K dataset. SFT → SFT stands for two epochs of training on Orca-Math-200K. SFT → DPO (KTO) stands for one epoch of training on the dataset for iteration #2 with DPO (KTO) starting with M1. SFT → SFT → SFT stands for three epochs of training on Orca-Math-200K. SFT → KTO → DPO (KTO) stands for one epoch of training on the dataset for iteration #3 with DPO (KTO) starting with M2. For evaluation, we employ greedy decoding.
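For concreteness, the hyperparameters reported above could be written as Hugging Face `TrainingArguments` roughly as follows. This is a sketch under the assumption of a standard `transformers` Trainer setup (the authors' actual training stack is not specified here); beta = 0.3 is a parameter of the DPO/KTO loss itself, not a `TrainingArguments` field.

```python
# Sketch of the reported hyperparameters; output_dir names and bf16 are assumptions.
from transformers import TrainingArguments

sft_args = TrainingArguments(
    output_dir="orca-math-sft",
    learning_rate=1e-6,
    lr_scheduler_type="constant",
    per_device_train_batch_size=3,
    num_train_epochs=1,
    bf16=True,  # assumed mixed-precision setting for A100s
)

# Iteration #2 preference training (DPO or KTO); iteration #3 uses learning_rate=1e-7.
pref_args = TrainingArguments(
    output_dir="orca-math-pref-iter2",
    learning_rate=1e-6,
    lr_scheduler_type="constant",
    per_device_train_batch_size=3,
    gradient_accumulation_steps=11,
    num_train_epochs=1,
    bf16=True,
)
```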
Model Generated Positives
We study the impact of model-generated positives by limiting the dataset to contain only teacher-generated solutions. In other words, we remove from $q^+_i$ any model-generated $a^+_i$ in the creation of the dataset for iteration #2. Table 3 shows the result of training M1 with DPO and KTO on this dataset for one epoch. We reuse the hyperparameters for iteration #2. Irrespective of the training algorithm, we see a significant performance drop.
Training Procedure | Pass@1 Accuracy on GSM8K Test Set |
---|---|
M1 → DPO | 81.96 (-2.27) |
M1 → KTO | 82.79 (-2.27) |
Table 3: Performance drop when using only teacher-generated positives.
Synthetic Negatives
The preference dataset creation involves synthetic negative creation in the situation where all four responses generated from M1 or M2 are positive. We study the impact of these synthetic negatives by ignoring the questions, $q_i$, where all sampled responses are positive (Table 4). This reduces the number of questions for iteration #2 by around 80k and for iteration #3 by around 104k.
Training Procedure | Pass@1 Accuracy on GSM8K Test Set |
---|---|
M1 → DPO | 60.73 (-23.5) |
M1 → KTO | 85.22 (+0.17) |
M1 → KTO → KTO | 85.44 (-1.43) |
Table 4: Impact of ignoring problems where all sampled responses are positive.
Table 5 presents the performance of Orca-Math on several other word problem datasets. For ease of evaluation, we selected datasets where the answer to each problem is a single number. The test sets of the benchmarks are obtained from Lila. We employ the GPT4-based exact-match metric, and model responses are generated using greedy decoding.
Test Set | Orca-Math-SFT (M1) | Orca-Math |
---|---|---|
AddSub | 88.99 | 91.74 |
ASDiv | 91.10 | 91.10 |
MultiArith | 98.28 | 98.28 |
SingleOp | 98.74 | 99.37 |
SingleEq | 97.25 | 99.08 |
Svamp | 87.63 | 91.30 |
Table 5: Performance of SFT trained model from Iteration #1 (M1) and Orca-Math on AddSub, ASDiv, MultiArith, SingleOp, SingleEq, and Svamp.
We never use the test split of GSM8K or any other datasets during training or as seeds for synthetic problem generation. Nevertheless, we take the following approach for detecting any potential text contamination.
The generation of synthetic data through generative artificial intelligence (AI) models has evolved rapidly. Numerous datasets [27, 20, 28, 23, 9, 8, 45, 6, 36] have been proposed for both specialized and generic domains, with math-related datasets [40, 43, 44, 18] being closely related to our work.
Learning from rich signals has also garnered significant attention recently. Several studies [31, 10, 22, 3, 5, 41], have demonstrated the usefulness of preference learning. In this work, we present a detailed analysis of agent-based synthetic data generation and iterative preference learning in the grade school level math domain. Specifically, we demonstrate the robustness of KTO over DPO and the effectiveness of using model-generated positives to improve model training. We believe this is a preliminary step toward iterative learning and self improvement of small language models in challenging domains.
Our study provides compelling evidence that the mathematical reasoning capabilities of Small Language Models (SLMs) can be substantially enhanced. By employing iterative learning techniques and leveraging both positive and negative signals, we have successfully surpassed the previously perceived 80% barrier on the GSM8k benchmark. Our 7B model, trained on 200K examples, achieved an impressive 86.81% accuracy. Furthermore, the incorporation of agents in dataset generation has proven to be a valuable approach, enabling the creation of more diverse and interesting datasets. These findings not only highlight the potential for significant improvements in SLM performance but also underscore the importance of innovative learning strategies and dataset generation methods in advancing the creation of powerful SLMs.