00:00:00

Share Your Feedback 🏝️

Reasoning | Iterative Reasoning

Reasoning | Iterative Reasoning

MinWoo(Daniel) Park | Tech Blog

Read more
Previous: Inner Working** Next: FLAME | Factuality-Aware Alignment for Large Language Models

Reasoning | Iterative Reasoning

  • Related Project: Private
  • Category: Paper Review
  • Date: 2024-04-30

Iterative Reasoning Preference Optimization

  • url: https://arxiv.org/abs/2404.19733
  • pdf: https://arxiv.org/pdf/2404.19733
  • html https://arxiv.org/html/2404.19733v1
  • abstract: Iterative preference optimization methods have recently been shown to perform well for general instruction tuning tasks, but typically make little improvement on reasoning tasks (Yuan et al., 2024, Chen et al., 2024). In this work we develop an iterative approach that optimizes the preference between competing generated Chain-of-Thought (CoT) candidates by optimizing for winning vs. losing reasoning steps that lead to the correct answer. We train using a modified DPO loss (Rafailov et al., 2023) with an additional negative log-likelihood term, which we find to be crucial. We show reasoning improves across repeated iterations of this scheme. While only relying on examples in the training set, our approach results in increasing accuracy for Llama-2-70B-Chat from 55.6% to 81.6% on GSM8K (and 88.7% with majority voting out of 32 samples), from 12.5% to 20.8% on MATH, and from 77.8% to 86.7% on ARC-Challenge, which outperforms other Llama-2-based models not relying on additionally sourced datasets.

Contents

TL;DR


Iterative Reasoning Preference Optimization

  1. 선호 최적화 기반의 인퍼런스 향상: 반복적인 선호 최적화를 통해 인퍼런스 성능을 향상.
  2. 연쇄 인퍼런스 단계 생성: 생성된 인퍼런스 단계와 최종 답변을 활용해 선호 쌍을 생성하여 학습.
  3. 모델 성능 향상: 다양한 벤치마크에서 반복 학습을 통해 상당한 정확도 향상.

[서론]

선호 최적화는 기존의 지도 학습에 비해 pre-training된 언어 모델을 휴먼의 요구에 맞추는 데 큰 개선을 가져왔습니다. Direct Preference Optimization (DPO)과 같은 방법은 그 단순함과 효율성 덕분에 인기를 끌고 있습니다. 최근 연구들은 이런 오프라인 절차의 반복 적용이 더 나은 결과를 가져오며, 특히 더 정보가 많은 선호 관계를 구축할 수 있음을 보여주고 있습니다. 이런 반복 방법에는 Iterative DPO, Self-Rewarding LLMs, SPIN 등이 있으며, 이들은 일반적인 지침 조정 작업에서 좋은 성과를 보였지만 인퍼런스 작업에서는 중간 정도의 개선 또는 성능 저하를 보였습니다.

본 연구는 연쇄 인퍼런스(Chain-of-Thought, CoT) 인퍼런스에 중점을 두어 반복적인 선호 최적화를 적용하는 새로운 접근 방식을 제시합니다. 각 반복에서 여러 CoT 인퍼런스 단계와 최종 답변을 샘플링하여, 올바른 답변을 가진 쌍이 승리자로 선택되는 선호 쌍을 구성합니다. 그런 다음 DPO 변형을 사용하여 승리자에 대한 음의 로그 가능도(NLL) 손실 항을 포함한 학습을 진행합니다. 새로운 모델이 생성되면 새로운 쌍을 생성하고 다시 학습하여 성능이 포화 상태에 이를 때까지 반복합니다.


[방법]

[반복적 인퍼런스 선호 최적화]

접근 방식은 기본적으로 pre-training되었거나 지침 조정된 언어 모델, 훈련 입력 세트, 그리고 최종 출력의 정확성을 판단할 수 있는 능력에 의존합니다. 주어진 훈련 입력에 대해 언어 모델은 (i) 일련의 인퍼런스 단계(Chain-of-Thought, CoT)와 (ii) 문제에 대한 최종 답변을 생성해야 합니다. 최종 답변의 정확성을 판단할 수 있는 척도를 갖고 있으며, 인퍼런스 단계의 정확성은 평가하지 않습니다.

각 반복에서, 방법은 두 가지 단계로 구성됩니다. (i) 연쇄 인퍼런스 및 답변 생성, (ii) 선호 최적화.

[초기화]

초기 모델 \(M_0\)와 훈련 세트 \(D = \{(x_i, y_i)\}_i\)가 주어졌다고 가정합니다. \(x_i\)는 질문이고 \(y_i\)는 정답입니다. 모델은 각 반복에서 학습 및 업데이트되어 \(M_0, M_1, \ldots, M_T\)의 모델 시퀀스를 형성합니다.

[연쇄 인퍼런스 및 답변 생성]

현재 모델 \(M_t\)를 사용하여 각 입력에 대해 \(N\)개의 다른 응답을 생성합니다. 각 응답은 CoT 인퍼런스 \(c\)와 최종 답변 \(y\)로 구성됩니다.

\[(c_i^n, y_i^n) \sim M_t(x_i)\]

\(n \in [N]\)는 \(N\)개의 샘플을 나타냅니다. 각 응답에 대해 최종 답변의 정확성을 기반으로 한 보상 \(r_i^n\)을 계산합니다.

\[r_i^n = R(y_i^n, y_i)\]

실험에서는 \(y_i^n\)이 \(y_i\)와 일치하면 \(r_i^n = 1\), 그렇지 않으면 \(r_i^n = 0\)으로 설정합니다. 따라서 생성된 응답 세트 \(G_i\)는 보상으로 보강됩니다.

\[G_i = \{(c_i^n, y_i^n, r_i^n)\}_{n \in [N]}\]

[선호 최적화]

생성된 \(G_i\)를 기반으로 응답 쌍 데이터셋 \(D_{\text{pairs}}\)를 구성합니다. 선택된(승리한) 응답은 높은 보상을 받고, 거부된 응답은 낮은 보상을 받습니다. 일반적으로 동일한 입력에 대해 보상이 더 높은 응답을 승리자로 설정하고, 보상이 낮은 응답을 패자로 설정합니다. 이진 보상의 경우, 생성된 응답 \(G_i\)를 보상에 따라 두 집합으로 나눕니다.

\[G_i^w = \{(c_i^n, y_i^n) | r_i^n = 1\}, \quad G_i^l = \{(c_i^n, y_i^n) | r_i^n = 0\}\]

다음으로, 승리 응답과 패자 응답 쌍을 선택하여 선호 쌍 데이터셋을 구성합니다.

\[D_{\text{pairs}} = \{(c_i^{w_k}, y_i^{w_k}), (c_i^{l_k}, y_i^{l_k}) | \forall i \in D \text{ and } k \in [K]\}\]

선호 쌍이 주어지면, 다음 모델 \(M_{\theta}\)를 학습합니다. \(\theta\)는 모델 \(M_t\)에서 초기화된 파라미터입니다. 학습 손실 함수는 DPO 손실과 승리 응답의 NLL 손실을 결합한 것입니다.

\[L_{\text{DPO+NLL}} = - \log \sigma \left( \frac{M_{\theta}(c_i^{w_k} | x_i)}{M_t(c_i^{w_k} | x_i)} \right) + \alpha \frac{L_{\text{NLL}}(c_i^{w_k}, y_i^{w_k})}{|c_i^{w_k}| + |y_i^{w_k}|}\]

\(\sigma\)는 시그모이드 함수입니다. NLL 항은 응답 길이로 정규화되며, \(\alpha\)는 두 손실 항의 균형을 맞추는 하이퍼파라미터입니다. 최종적으로, 새로운 모델 \(M_{t+1} = M_{\theta}\)를 얻고, 이 모델을 사용해 다음 반복을 위한 데이터를 구축합니다.

[반복적 학습]

이 절차는 모델 시퀀스 \(M_1, \ldots, M_T\)를 학습하며, 각 연속 모델 \(M_{t+1}\)는 t번째 모델에 의해 생성된 선호 데이터 \(D_{\text{pairs}}\)로 학습됩니다.


[실험]

  • GSM8K
    • 데이터셋: 초등학교 수학 문제
    • 결과: 정확도 55.6% → 88.7% (4회 반복)
  • ARC-Challenge
    • 데이터셋: 과학 선택 문제
    • 결과: 정확도 77.8% → 86.7% (3회 반복)
  • MATH
    • 데이터셋: 수학 문제
    • 결과: 정확도 12.5% → 20.8% (3회 반복)


[결론]

반복적 인퍼런스 선호 최적화는 다양한 인퍼런스 작업에서 언어 모델의 성능을 향상시켰습니다. 각 반복을 통해 성능이 점진적으로 향상되었으며, 이는 CoT 생성과 최종 답변의 정확성을 기반으로 한 선호 쌍의 효과를 입증합니다.


Iterative Reasoning Preference Optimization for Enhanced Chain-of-Thought Reasoning in Large Language Models

This work explores the challenge of limited reasoning ability in large language models (LLMs) despite vast training datasets. We propose Iterative Reasoning Preference Optimization (Iterative RPO), a novel approach that leverages preference learning and chain-of-thought (CoT) generation to significantly improve LLM performance on reasoning tasks.

Background and Motivation:

While supervised fine-tuning (SFT) offers some improvement over zero-shot methods, it falls short in achieving robust reasoning capabilities. Existing offline preference optimization techniques, such as DPO (Difference in Predicted Outcomes), demonstrate promise but struggle to achieve high reasoning accuracy.

Iterative RPO Methodology:

Iterative RPO addresses these limitations through an iterative training process:

  1. CoT and Answer Generation: The current model generates multiple reasoning steps (CoT) and corresponding final answers for training prompts. A reward function evaluates the final answer’s correctness.
  2. Preference Pair Construction: Winning responses (with higher rewards) and losing responses (with lower rewards) are selected from the generated data to construct preference pairs.
  3. Iterative Training with Combined Loss: A new model is trained using a combined loss function:
    • DPO loss: Learns from the constructed preference pairs.
    • Negative Log-Likelihood (NLL) loss: Guides the model towards assigning higher probabilities to winning responses.
  4. Iteration: Steps 1-3 are repeated with the newly trained model, iteratively refining the reasoning ability.

Results and Significance:

Iterative RPO demonstrates significant improvements over baseline methods on various reasoning benchmarks (GSM8K, ARC-Challenge, MATH). This highlights the efficacy of learning from the generated CoT data across training iterations. The inclusion of the NLL loss term is crucial, preventing the model from favoring incorrect answer sequences. Notably, Iterative RPO exhibits robustness to noisy data, performing well even on multiple-choice tasks susceptible to random guessing (ARC-Challenge).

Conclusion:

Iterative RPO presents a compelling and practical approach for enhancing the reasoning capabilities of LLMs across diverse reasoning tasks. This method leverages the power of preference learning and CoT generation to achieve superior performance compared to existing techniques. Future research directions include exploring the application of Iterative RPO to more complex reasoning domains and investigating the impact of different preference learning algorithms within the framework.

Previous: Inner Working** Next: FLAME | Factuality-Aware Alignment for Large Language Models

post contain ""

    No matching posts found containing ""