
GLoRe

  • Related Project: Private
  • Category: Paper Review
  • Date: 2024-02-18

GLoRe: When, Where, and How to Improve LLM Reasoning via Global and Local Refinements

  • url: https://arxiv.org/abs/2402.10963
  • pdf: https://arxiv.org/pdf/2402.10963
  • abstract: State-of-the-art language models can exhibit impressive reasoning refinement capabilities on math, science or coding tasks. However, recent work demonstrates that even the best models struggle to identify when and where to refine without access to external feedback. Outcome-based Reward Models (ORMs), trained to predict correctness of the final answer indicating when to refine, offer one convenient solution for deciding when to refine. Process Based Reward Models (PRMs), trained to predict correctness of intermediate steps, can then be used to indicate where to refine. But they are expensive to train, requiring extensive human annotations. In this paper, we propose Stepwise ORMs (SORMs) which are trained, only on synthetic data, to approximate the expected future reward of the optimal policy or V*. More specifically, SORMs are trained to predict the correctness of the final answer when sampling the current policy many times (rather than only once as in the case of ORMs). Our experiments show that SORMs can more accurately detect incorrect reasoning steps compared to ORMs, thus improving downstream accuracy when doing refinements. We then train global refinement models, which take only the question and a draft solution as input and predict a corrected solution, and local refinement models which also take as input a critique indicating the location of the first reasoning error. We generate training data for both models synthetically by reusing data used to train the SORM. We find combining global and local refinements, using the ORM as a reranker, significantly outperforms either one individually, as well as a best of three sample baseline. With this strategy we can improve the accuracy of a LLaMA-2 13B model (already fine-tuned with RL) on GSM8K from 53% to 65% when greedily sampled.


TL;DR


Refining and improving the mathematical reasoning ability of LLMs

  • The paper proposes a new approach to improving the mathematical reasoning ability of large language models (LLMs).
  • Instead of relying on expensive process-based reward models (PRMs), it trains stepwise outcome-based reward models (SORMs) on purely synthetic data to judge the correctness of intermediate reasoning steps.
  • Global and local refinement models are compared and shown to be complementary; combining them with ORM reranking gives the largest accuracy gains.

1. Introduction

Large language models (LLMs) have recently shown remarkable progress in solving math, science, and coding problems. These models often generate a chain of thought to reach the final answer, spelling out the intermediate steps of the solution process.

However, their performance is not always robust: models frequently fail to recognize when or where a solution needs to be fixed. The paper addresses this by training outcome-based and stepwise reward models (ORMs and SORMs) to guide refinement, and shows that this improves the models' reasoning accuracy.


2. Background and Related Work

2.1. ORMs and PRMs

  • ORM (Outcome-based Reward Model): predicts whether a generated solution will reach the correct final answer, given the question and the intermediate steps so far. It is used to decide whether a model-generated answer is correct.
  • PRM (Process-based Reward Model): directly evaluates the correctness of each individual step, providing finer-grained feedback that can pinpoint where a solution goes wrong. Training a PRM, however, requires extensive human step-level annotation.

2.2. Decomposing the Refinement Problem

  • When to refine: use the ORM to judge whether a draft solution is correct
  • Where to refine: use the SORM to identify the first erroneous step
  • How to refine: fix the draft with global and local refinement models (see the sketch below)
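A rough sketch of how these three decisions fit together is shown below. The `orm`, `sorm`, and refiner arguments are hypothetical callables standing in for the trained models; this is an illustration of the decomposition, not the paper's implementation.

```python
# Minimal sketch of the when/where/how decomposition (hypothetical callables).

def refine(question, draft_steps, orm, sorm, global_refiner, local_refiner, t=0.5):
    """Return the most promising of {draft, global refinement, local refinement}."""
    draft = "\n".join(draft_steps)

    # When: the ORM scores the whole draft; a high score means no refinement is needed.
    if orm(question, draft) > t:
        return draft

    # Where: the SORM flags the first step whose predicted correctness drops below t.
    step_scores = [sorm(question, draft_steps[: i + 1]) for i in range(len(draft_steps))]
    first_error = next((i for i, s in enumerate(step_scores) if s < t), None)

    # How: propose a global rewrite and, if an error was located, a local fix,
    # then let the ORM rerank all candidates, including the original draft.
    candidates = [draft, global_refiner(question, draft)]
    if first_error is not None:
        candidates.append(local_refiner(question, draft_steps, first_error))
    return max(candidates, key=lambda answer: orm(question, answer))
```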



3. Method

3.1. Mathematical Formulation

  • Value function

    \(V^*(s) = \max_\pi \mathbb{E}[R \mid s, \pi]\), where \(V^*\) is the value function under the optimal policy, \(s\) is the current state, \(R\) is the reward, and \(\pi\) is the policy (the SORM's training target is sketched below).
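The SORM proposed in the paper approximates this optimal value function by rejection sampling the student policy. Paraphrasing the paper's construction, the step-level training label can be written as:

```latex
% Rejection-sampling estimate of V^* at a solution prefix P_i = (S_1, ..., S_i):
% sample K continuations A_1, ..., A_K from the student policy pi starting at P_i
% and check whether any of them reaches the correct final answer.
\[
  \widehat{V}^{*}(Q, P_i) \;=\; \max_{1 \le j \le K} \mathbf{1}\!\left[\,\mathrm{is\_correct}(A_j)\,\right],
  \qquad A_j \sim \pi(\,\cdot \mid Q, P_i\,).
\]
```

A step \(S_i\) is therefore labeled positive exactly when at least one of the \(K\) sampled continuations reaches the correct final answer.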

3.2. Training the Refinement Models

  • Global refinement model: takes only the question and the initial draft as input and generates an entirely new solution.
  • Local refinement model: additionally receives the location of the first error and repairs the solution from that point on, reusing the valid prefix.


4. Experiments and Results

4.1. Datasets

  • The main benchmarks are GSM8K and SVAMP, both of which evaluate math word-problem solving.

4.2. Results

  • Global and local refinements are each effective on different subsets of problems, and combining the two (with the ORM as a reranker) achieves higher accuracy than either alone.


5. Conclusion and Future Work

The paper presents a new approach to improving the mathematical reasoning ability of LLMs. It uses the ORM and the proposed SORM to assess the correctness of intermediate steps and to repair errors made during problem solving, and the authors plan to apply these models to a wider range of problem types in future work to extend their generality.


1 Introduction

State-of-the-art large language models (LLMs) exhibit a wide range of downstream capabilities after pre-training. This includes the ability to refine their reasoning on math, science, or coding problems (OpenAI, 2023; Touvron et al., 2023; Chowdhery et al., 2022). However, under close inspection, this refinement ability is quite brittle, often unable to even identify when a solution needs refinement (Huang et al., 2023). When LLMs do produce successful refinements on hard reasoning tasks this is often due to the incorporation of external forms of feedback, e.g. feedback from humans or code, stronger models, or other tools (Zhou et al., 2023; Gou et al., 2023). In this work, we carefully examine and improve the self-refinement abilities of LLMs on reasoning tasks without any external feedback other than the ground truth answers of the training problems. Notably, this means we make no use of data or feedback from humans or stronger models. To do so we start by heuristically decomposing the refinement problem into three parts: firstly deciding when to refine, then where to refine, and finally how to refine.

Outcome Based Reward Models (ORMs) (Cobbe et al., 2021), first introduced as an estimator of final answer correctness given a question to do solution reranking, are a natural choice for addressing step one. For deciding where to refine, we carefully examine the generalization of ORMs to intermediate steps. We find the accuracy of the underlying data generating policy π directly affects the ORM’s ability to learn correctness of intermediate solution steps. This leads to the ORM often under-estimating the solvability of a problem from an intermediate step Si. The result is high false-negative rates when used to classify steps with errors. Process Based Reward Models (PRMs) instead are trained to directly estimate the correctness of each step. Yet this requires extensive human labeling of model-generated solution steps as valid or invalid. In an effort to improve our ability to give intermediate step feedback, we introduce the Stepwise ORMs (SORMs) which explicitly predict labels at each step indicating the presence of an error. We generate SORM training data by sampling a student policy π many times at a step Si in solution S, labeling Si as valid if we successfully reach the final answer. From an RL perspective, this can be interpreted as learning (a lower bound of) the optimal value function V ∗ of the reasoning task via approximation of the optimal policy π∗ with rejection sampling. The resulting SORM gives better intermediate step-level feedback, allowing us to give information to the refinement model about both when and where to refine. The refinement model must then only decide how to refine.

We initially train global refinement models capable of refining the entire reasoning trace without any feedback beyond an initial draft solution D. The training data is generated synthetically, by pairing correct solutions with incorrect solutions as in Welleck et al. (2022). An evaluation of the global refinement model confirms its inability to correctly identify when to refine, demonstrating the need for an ORM. Reusing the SORM training data, we train a local refinement model which uses the feedback given by the SORM to identify the first incorrect reasoning step. We then compare the performance of global versus local refinements on a test set of incorrect solution drafts, finding similar refinement accuracy but on largely disjoint sets of problems. In this sense the global and local refinement models are complementary, with local refinements often able to solve problems global refinements cannot and vice versa. To obtain our best results we combine both global and local refinements, using the ORM to choose the most promising one by acting as a reranker of both plus the initial draft. Using this strategy, we can improve the accuracy of an already strong RL fine-tuned Llama-2 13B model from 53% to 65% when greedily sampled.

In summary we make the following contributions:

  • Decompose the refinement problem into three parts, namely deciding when, where, and how to refine a solution by leveraging reward models (RMs).
  • Highlight the limitations of ORMs in judging the correctness of intermediate steps, despite their ability to judge the correctness of the final answer.
  • Introduce the step-wise ORM (SORM), trained only on synthetic data, which can more accurately evaluate intermediate steps than the ORM.
  • Propose a new method for refining LLM reasoning that decides when to refine using an ORM, where to refine using a SORM, and how to refine using both global and local refinements. We find the two types of refinement are complementary, each able to solve a large class of problems the other cannot.
  • Demonstrate performance improvements of up to 12% on GSM8K for a 13B LLaMA-2 model using our approach.

2 Background

[Revised]

This section outlines the concepts and methodology the paper builds on, specifically language models and their use in reinforcement learning and reward modeling. It covers the machinery designed to improve the accuracy and utility of language models through these training methods. Here is a breakdown and clarification of the key elements:

  1. Reasoning Task (τ): Defined as a distribution of natural language question/answer pairs. The formulation suggests a structured approach to reasoning, where answers are decomposed into atomic steps, culminating in a final answer. This structure is particularly useful for detailed analysis of the reasoning process and for models to learn not just the final answer but the logical steps leading to it.
  2. Reward Modeling: Discusses how reward models approximate the rewards in a reinforcement learning environment. Here, the rewards are linked to actions taken by a language model in generating responses. Key references include Christiano et al., 2017 and Ouyang et al., 2022, reflecting a focus on sparse rewards and contrastive preference models.
  3. Outcome-based Reward Model (ORM): Introduced by Cobbe et al., 2021, this model estimates the probability of a generated answer being correct based on the question and intermediate steps in the answer. ORM’s function is crucial for improving the accuracy of generated answers by assessing them at various stages of the generation process.
  4. Process-based Reward Models (PRMs): These models, discussed in newer studies by Lightman et al., 2023 and Uesato et al., 2022, extend the concept of ORM by evaluating the correctness of each step within a solution trace. This method allows for more granular feedback and adjustment, potentially leading to more precise learning by the model.
  5. Refinement: This involves generating a refined answer by conditioning on the question and a draft solution, which could be improved either through a global model (considering only the question and draft solution) or a local model (also considering the location of errors).
  6. Notation and Concepts: Base Model: a pre-trained language model fine-tuned for specific tasks; Student Model: a model fine-tuned to generate answers for specific questions, described as a policy \(\pi\) with parameters \(\theta\); \(D_{TASK}\), \(D_{train}\): the dataset and training split for the task \(\tau\); \(A_{GR}\), \(A_{LR}\): global and local refinements of a draft solution; Value functions \(V_\pi\), \(V^*\): the value function of a policy \(\pi\) and the optimal value function, respectively.

This structure and these methodologies are integral to developing more effective and sophisticated language models capable of engaging in complex reasoning tasks and generating high-quality responses. The focus on intermediate validation and step-wise correctness is indicative of the move towards more transparent and understandable AI decision-making processes.

[Original]

Reasoning: We define a reasoning task \(\tau\) as a distribution of (natural language) question/answer pairs \((Q, A) \sim \tau\). The answer could be either a single final answer, typically a numerical value in case of math problems for ease of evaluation, or include a CoT style solution trace justifying a numerical final answer. We often further write the answer \(A\) as consisting of atomic steps \(A = (S_1, ..., S_L)\) with the final answer being given on step \(L\). The notion of a start of a new “step” is problem dependent but in our case always corresponds to a newline token.

Reward Modeling: Given a reinforcement learning (RL) environment, a reward model can be trained to approximate the reward coming from an action \(a\) in state \(s\) (Christiano et al., 2017). In the language setting, reward models are trained to approximate the reward given to a response generated by a LLM (Ouyang et al., 2022). The reward is generally sparse and given at the end of a generation as in the case of RLHF (Christiano et al., 2017; Ziegler et al., 2019) where a contrastive preference model is learned for RL and rejection sampling.

Outcome-based Reward Model (ORM): First proposed as a final answer verifier used to rerank GSM8K solutions (Cobbe et al., 2021). Formally, we say the ORM estimates \(p(\text{is\_correct}(A) \mid Q, A)\) where \(Q\) is a question and \(A\) is a model generated answer. Training data for the ORM is generated by sampling an underlying student model \(\pi\) many times on questions from a reasoning task \(\tau\). The ORM is then trained to predict \(p(\text{is\_correct}(A) \mid Q, P_i)\) where \(P_i\) is a prefix of intermediate steps \((S_1, ..., S_i)\) and \(A\) is any hypothetical continuation of \(P_i\) sampled from \(\pi\). At intermediate steps, we may interpret the ORM as estimating the probability of \(P_i\) leading to the correct final answer. We may sometimes write \(\text{ORM}_\pi\) to emphasize the ORM’s dependence on its data generating student model \(\pi\).

Process-based Reward Models (PRMs): More recently proposed to directly supervise the correctness of each step in a solution \(A = (S_1, ..., S_L)\) (Lightman et al., 2023; Uesato et al., 2022). Formally, we write a PRM predicts \(p(\text{is\_correct}(S_i) \mid P_i, Q)\) where \(S_i\) is the last step of \(P_i\).

Refinement: We define a refinement of a draft solution \(A_D\) and question \(Q\) as a new solution \(A_R\) generated by conditioning on both \(Q\) and \(A_D\). We consider both global refinement models, which take as input only \(Q, A_D\) and predict \(p(A_R \mid Q, A_D)\), and local refinement models, which take as input an extra parameter \(E\) indicating the location of an error in \(A_D\), to predict \(p(A_R \mid Q, A_D, E)\).

Notation: For the rest of the paper, we refer to the pre-trained LLM fine-tuned for downstream tasks as the base model. We fine-tune the base model, either on supervised data or using RL, to produce a student model that generates answers \(A\) given a question \(Q\). Sometimes we may also write the student model as a policy \(\pi\) implicitly depending on learnable parameters \(\theta\). \(D_{TASK}\) will be used to denote a dataset for TASK \(\tau\) with train split \(D_{train}\) being implicit. We will use \(Q\) to denote a question and \(A_1, ..., A_k\) to denote solution traces. Sometimes we will write \(A = (S_1, ..., S_L)\) which decomposes the solution trace \(A\) into intermediate steps \(S_i\). \(P_i = (S_1, ..., S_i)\) will be used to denote the prefix of steps up to \(S_i\). Additionally, we will sometimes use \(A_{GR}\) and \(A_{LR}\) to represent global and local refinements of \(A_D\). \(V_\pi\) denotes the value function of policy \(\pi\). \(V^*\) denotes the optimal value function with dependence on the background task implicit.

LLM Reasoning: State-of-the-art (SOTA) large language models (LLMs) (OpenAI, 2023; Touvron et al., 2023; Bai et al., 2022; Chowdhery et al., 2022) demonstrate increasingly impressive abilities on hard reasoning tasks as studied by a wide range of math, science, and code benchmarks (Cobbe et al., 2021; Hendrycks et al., 2021b; Sawada et al., 2023; Liang et al., 2022; Srivastava et al., 2022; Rein et al., 2023; Mialon et al., 2023; Chollet, 2019; Hendrycks et al., 2021a; Austin et al., 2021; Mishra et al., 2022; Patel et al., 2021; Gao et al., 2021). Chain of thought (CoT) (Wei et al., 2022) and related techniques (Chen et al., 2022; Yao et al., 2023a; Besta et al., 2023) have emerged as dominant methods significantly boosting LLM performance on these types of tasks. CoT methods allow LLMs to defer giving their final answer by first generating a “chain of thought” involving intermediate computations needed to correctly solve the problem.

LLM Refinement: Intimately related to reasoning ability is a model’s ability to refine previous answers. This work studies the ability of large language models to self-refine their CoT solutions to math reasoning tasks. Several works (Yao et al., 2022; Madaan et al., 2023; Zhou et al., 2023) demonstrate SOTA LLM self-refining and self-critiquing abilities on a range of tasks via prompting and/or tool usage. However, recent work (Huang et al., 2023) argues even for the strongest models such techniques struggle on hard, open-ended reasoning tasks where the model itself must decide when to stop refinement.

Other papers use hand-crafted data augmentation (Paul et al., 2023) or gather human data (Wang et al., 2023b; Chen, 2023; Lee et al., 2023; Saunders et al., 2022; Schick et al., 2022) while still others use techniques from reinforcement learning to generate critiques (Akyurek et al., 2023; Yao et al., 2023b) for larger models. Most related to us is (Welleck et al., 2022) which trains global refinement models in an implicit reinforcement learning like manner by pairing low-value rollouts with high-value rollouts.

Process-based reward modeling (PRMs) (Uesato et al., 2022; Lightman et al., 2023) gives a denser, step-by-step reward for the “correctness” of a particular step without explicitly modeling the step’s impact on the correctness of the final answer. Both ORMs and PRMs are most often used as rerankers over large numbers of candidate solutions, with PRMs generally outperforming ORMs (Lightman et al., 2023). However, PRMs are expensive to train, requiring extensive human annotation of each step. Uesato et al. (2022) directly compares the performance of a 70B ORM vs PRM on GSM8K, finding both performing similarly when used as a reward for RL and for reranking. They qualitatively note the ORM appears to somewhat generalize to intermediate steps in a manner similar to a PRM but do not quantitatively ablate this observation over multiple models or tasks. Li et al. (2022) attempt to train synthetic stepwise verifiers similar to a PRM which are then used for Monte Carlo Tree Search. Concurrent work (Wang et al., 2023a) proposes training a synthetic process based reward model in a manner similar to our SORM. They then use the RM downstream for RL fine-tuning and rejection sampling.

In contrast to the above works we conduct a careful comparison of ORM/SORM verification abilities at the step level. We then propose to utilize the ORM/SORM for refinement. We accomplish this by generating fully synthetic stepwise labels which allow us to train both the SORM and refinement models.

4 Method

Decomposition of the Refinement Problem: We start by decomposing the refinement problem into three stages:

  1. Learning when a draft \(D\) is correct and when it needs refinement.
  2. Learning where to begin refinement by identifying the first incorrect step.
  3. Learning how to correct the initial draft.

We can naturally address step one by using the Outcome-based Reward Model (ORM) which is trained to predict the probability of a draft being correct. This alleviates some of the difficulty, now only requiring the refiner to identify where and when to refine. Additionally, when doing local refinement, we propose using the Stepwise ORM (SORM) to localize the position of the first error. This simplifies the task even more, as now the local refiner must only decide how to fix the error and continue from there.

Localizing Errors with Reward Models: To identify errors at the step level, we can leverage the ORM by taking its intermediate prediction \(\text{ORM}_\pi(Q, P_i)\) at a step \(S_i\) where \(P_i = (S_1, ..., S_i)\) is the prefix of all steps up to \(S_i\). Recall the ORM is trained to predict the likelihood that a solution with prefix \(P_i\) results in a correct final answer. Importantly, the likelihood inferred from this training data is heavily dependent on the data-generating policy \(\pi\). For this reason, we sometimes include the subscript \(\text{ORM}_\pi\), omitting it when not needed.

To best understand the behavior of the ORM’s prediction at an intermediate step \(S_i\), we can interpret it as the value function of \(\pi\). Recall the value function \(V_\pi(S)\) of a policy \(\pi\) is computed as \(V_\pi(S) = \mathbb{E}_{\tau \sim \pi(S)}[R(\tau)]\), i.e., the mean return of the policy \(\pi\) from the state \(S\). In the context of reasoning problems, the states we consider are of the form \(S = (Q, S_1, ..., S_i)\) with question \(Q\) and intermediate steps \(S_j\). In our setting, by default there is only a sparse reward of +1 given at the terminal state for a correct final answer. We can write \(\text{ORM}_\pi(Q, P_i) \approx p(\text{is\_correct}(A) \mid Q, P_i, \pi)\) where \(P_i = (S_1, ..., S_i)\) is the prefix of all prior steps and \(\text{is\_correct}(A)\) is the event that a full solution \(A\) sampled from \(\pi\) with prefix \(P_i\) has the correct final answer. We can then write \(\mathbb{E}_{A \sim \pi(Q,P_i)}[R(A)] = \mathbb{E}_{A \sim \pi(Q,P_i)}[\mathbf{1}_{\text{is\_correct}(A)}] = p(\text{is\_correct}(A) \mid Q, P_i, \pi)\). Therefore, an approximation to the value function of a policy \(\pi\) is predicting exactly the same thing as the outcome-based reward model at an intermediate step \(S\). So we may treat the ORM as approximating a value function for the student model \(\pi\) used to generate its training data. Ideally, we might want to use the ORM to identify where a mistake was made by finding the first step \(S_i\) such that \(\text{ORM}(Q, P_i) \leq 0.5\), i.e. \(P_i\) is likely to result in the wrong answer. However, because the ORM is acting as a value function for \(\pi\), it tends to hallucinate error steps simply because it expects the data-generating student \(\pi\) to fail. For example, if \(\pi\) almost always fails problems involving division, the ORM will assign a low probability of success to a division problem even before the student takes its first step. In these cases, we say the ORM is overly pessimistic. This is not ideal when using the ORM to identify the location of mistakes.
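Spelled out as one chain of equalities (a display-form restatement of the identity above, using the paper's own definitions):

```latex
% The ORM's intermediate prediction coincides with the student policy's value function:
% the return is a 0/1 indicator of final-answer correctness, so its expectation is
% exactly the success probability the ORM is trained to predict.
\[
  V_\pi(Q, P_i)
  \;=\; \mathbb{E}_{A \sim \pi(\cdot \mid Q, P_i)}\!\left[R(A)\right]
  \;=\; \mathbb{E}_{A \sim \pi(\cdot \mid Q, P_i)}\!\left[\mathbf{1}_{\mathrm{is\_correct}(A)}\right]
  \;=\; p\!\left(\mathrm{is\_correct}(A) \mid Q, P_i, \pi\right)
  \;\approx\; \mathrm{ORM}_\pi(Q, P_i).
\]
```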

Learning a Step-Wise ORM (SORM): Another natural candidate which could be used to identify mistakes at each step is a Process-Based Reward Model (PRM) (Lightman et al., 2023). A PRM estimates the probability that a step \(S_i\) is correct, \(p(\text{is\_correct}(S_i) \mid Q, S_1, S_2, ..., S_i)\), independently of its impact on the final answer. However, this would be expensive, requiring collecting human-annotated samples. Instead, we propose to approximate the optimal value function \(V^*\) of the reasoning task. \(V^*\) corresponds to the value function of the optimal policy which is able to successfully solve the reasoning task from any logically valid intermediate state \(S_j\). Such an optimal value function would have \(V^*(Q, S_1, ..., S_i) = 1\) for a solution prefix with no mistakes, and \(V^*(Q, S_1, ..., S_i) = 0\) if the prefix already contains a mistake which will result in an incorrect final answer. We call models we train to directly approximate \(V^*\) stepwise ORMs or SORMs.

As discussed in Uesato et al. (2022), the ORM possesses some knowledge of intermediate solution correctness, allowing it to approximate a PRM. However, we find in practice this property is dependent on the size of the base model and the difficulty of the task \(\tau\), with ORMs trained on data from larger students and easier tasks giving better approximations to a PRM. When interpreting the ORM as a value function \(V_\pi\) of the data generating student, this makes sense. A larger, more capable student will better approximate the optimal policy \(\pi^*\), resulting in a better approximation of the ORM to \(V^*\).

4.1 Training pipeline

Recall, we assume no access to data from humans or better models for fine-tuning. Thus we must generate all training data synthetically for both global and local refinement. Additionally, we must generate data for both the ORM and SORM. We divide our proposed training pipeline into three steps. See Figure 1 for a diagram outlining each step.

Step 1: Fine-tuning a student model

To produce base checkpoints from which we can generate ORM/SORM training data and initial refinement drafts AD we fine-tune models using Expert Iteration (EI) (Silver et al., 2017). This is done by sampling the student model K = 96 times per question and filtering out rollouts with incorrect final answers. De-duplication is then performed on the remaining samples to construct a new fine-tuning dataset R1. We then combine this with any available SFT data producing D1 which we use to again fine-tune the pre-trained model. This process is repeated until the maj@1 score of each subsequent fine-tune converges. Note, the fine-tuning dataset used at step i is Di = Ri ∪ Di−1: the union of rollouts generated at the ith step with previously generated training data (D0 = ∅ or SFT). In the case of GSM8K we first fine-tune each pre-trained model on the given supervised fine-tuning (SFT) data. For SVAMP, which has no CoT SFT data, we 1-shot prompted the pretrained model to generate solutions used to construct an initial EI dataset. We call the resulting model the student model or student policy π. For more details of this training process and resulting models see Section B in the appendix.
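A simplified sketch of this Expert Iteration loop is given below; `sample_solutions`, `final_answer`, `fine_tune`, and `evaluate_maj1` are hypothetical helpers, and the convergence check is only schematic.

```python
# Rough sketch of the Expert Iteration loop described above: sample K = 96 solutions
# per question, keep rollouts with the correct final answer, de-duplicate, retrain,
# and repeat until maj@1 stops improving. Dataset entries are (question, solution) pairs.

def expert_iteration(model, questions, gold_answers, sft_data, K=96, max_rounds=10):
    data = list(sft_data)          # D_0: SFT data if available, otherwise an empty list
    prev_score = -1.0
    for _ in range(max_rounds):
        rollouts = []
        for q, gold in zip(questions, gold_answers):
            for sol in sample_solutions(model, q, n=K):
                if final_answer(sol) == gold:           # filter out incorrect final answers
                    rollouts.append((q, sol))
        data = list(dict.fromkeys(data + rollouts))     # D_i = R_i ∪ D_{i-1}, de-duplicated
        model = fine_tune(model, data)
        score = evaluate_maj1(model, questions, gold_answers)
        if score <= prev_score:                         # stop once maj@1 converges
            return model
        prev_score = score
    return model
```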

Step 2: Training the ORM/SORM

We generate ORM training data by sampling the RL fine-tuned student policy π K times per prompt. As usual, we then label each intermediate step Si as correct if the final answer is correct and incorrect otherwise. To generate training data for our SORM we sample an approximation of the optimal policy π∗ at each step Si in a model generated solution and check correctness of the final answer. We aim to approximate π∗ via rejection sampling of our student policy π. Concretely, to produce a training label for a step Si in a model generated rollout S, we sample the student policy π for K rollouts starting from the prefix Pi = (S1, …, Si). This produces verifying traces T1, …, TK with correct final answers indicated by l1, …, lK. We then label Si as positive if maxj lj = 1, i.e. we can find the correct final answer starting from Si, and as negative otherwise. In practice we sample K = 8 rollouts per step, each generating at most 300 tokens. We then train the SORM in exactly the same manner as the ORM, predicting the appropriate label after each step in a solution. See Section G for a comparison of the labels assigned by this process to ground truth human labels.
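The step-labeling procedure can be sketched as follows (a simplified rendering with hypothetical `sample_continuations` and `final_answer` helpers, not the authors' code):

```python
# Sketch of SORM labelling: sample K continuations from each step prefix and mark
# the step positive if any of them reaches the correct final answer.

def sorm_labels(student, question, steps, gold_answer, K=8, max_new_tokens=300):
    """Label each step S_i of a draft: 1 if any of K continuations from prefix P_i
    reaches the correct final answer, else 0."""
    labels = []
    for i in range(len(steps)):
        prefix = "\n".join(steps[: i + 1])
        rollouts = sample_continuations(
            student, question, prefix, n=K, max_new_tokens=max_new_tokens
        )
        # l_i = 1 iff at least one verifying rollout reaches the correct final answer
        labels.append(int(any(final_answer(r) == gold_answer for r in rollouts)))
    return labels
```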

SORM data post-processing: To improve our approximation to the optimal policy via rejection sampling we apply several post-processing steps: 1) If a step Si has a positive label li we set lj = 1 for j ≤ i, i.e. all steps before a positive step are labeled as positive. This accounts for particularly hard problems where the student is able to find the solution with K samples from the step Si but not any prior step Sj, j < i. 2) We enforce a consistency constraint on the verifying rollouts, requiring each intermediate result Ri computed on step Si of the solution to be used later on. This helps prevent false positives by requiring a verification to make full use of the previous steps it’s verifying. In practice we implement this by checking for each Ri as a string in the suffix after Pi. 3) We balance the number of positive and negative labels at each prefix length in the training dataset. This is crucial, as otherwise there is an imbalance of positive labels towards the start of solutions and negative labels towards the end. This imbalance is easy for SORMs to exploit, leading to models which almost always predict a positive label in the first few steps and a negative label towards the end. As an additional baseline we consider the Balanced ORM which simply balances the number of positives and negatives per question in the ORM training dataset. This is done in an attempt to mitigate the overly pessimistic behavior of the ORM described earlier.
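A compact sketch of these three post-processing steps, under the assumption that each training example carries its prefix length and binary label (the data layout here is illustrative, not the paper's):

```python
# Sketch of the SORM post-processing: backward label propagation, consistency
# filtering, and per-prefix-length balancing.

import random
from collections import defaultdict

def propagate_positive(labels):
    # 1) If any step S_i is positive, every earlier step is also labelled positive.
    last_pos = max((i for i, l in enumerate(labels) if l == 1), default=-1)
    return [1 if i <= last_pos else l for i, l in enumerate(labels)]

def is_consistent(intermediate_result: str, verifying_suffix: str) -> bool:
    # 2) A verifying rollout must actually reuse the intermediate result R_i
    #    computed at the step it is verifying (checked as a substring of the suffix).
    return intermediate_result in verifying_suffix

def balance_by_prefix_length(examples):
    # 3) Keep an equal number of positive and negative labels at each prefix length.
    #    Each example is a tuple (prefix_len, label, ...).
    buckets = defaultdict(lambda: {0: [], 1: []})
    for ex in examples:
        buckets[ex[0]][ex[1]].append(ex)
    balanced = []
    for b in buckets.values():
        n = min(len(b[0]), len(b[1]))
        balanced += random.sample(b[0], n) + random.sample(b[1], n)
    return balanced
```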

Figure 2 Example of local and global refinements on a math word problem. Left: The local refinement does poorly with a student which struggles dividing by a fraction. Although all prior steps leading up to the fractional division are valid, the local refinement model is forced to either attempt the difficult operation again or choose the wrong operation entirely. In contrast, the global refinement model may attempt to solve the problem with an entirely new approach. Right: In this draft, the model is very close to the final answer, only making a simple mistake at the end. The local refinement is able to correct this simple mistake. In contrast, the global refinement must start from scratch.

Our SORM approximation is motivated by observations from concurrent work which shows our student π does not need to engage in much exploration, i.e. sampling, to solve most problems that are sufficiently in-distribution with respect to the pretraining data. This suggests rejection sampling is capable of providing a decent approximation to the optimal policy. Additionally, the deterministic dynamics of the reasoning environment allows us to only sample once from the optimal policy π∗ to compute V∗ at a prefix Pi. This further reduces our sampling requirements, while also allowing us to conclude that if rejection sampling can solve the problem from a prefix Pi, then π∗ will also solve the problem from Pi. Note of course rejection sampling will be weaker than π∗, resulting in the SORM being an under-approximation of V∗.

Step 3: Training refinement models To train a local refinement model we need a dataset of the form (Q, AD, AR, E) where Q is a question, AD is an initial draft, E labels the location of the first error in AD indicating where to refine, and AR is a refinement with the correct final answer. In practice, E is communicated to the local refinement as a “[BAD]” token prefixing the incorrect step Si in the draft. Then, at test time, we need a model predicting p(E|Q, AD) to localize errors in the draft. Conveniently, we explicitly train the SORM to predict the correctness of each step in AD. Thus, to produce E we infer the SORM on all steps and return the index of the first step with predicted correctness below a threshold T. Further, we can construct a refinement training dataset with error annotations using the SORM dataset. Given an incorrect model rollout A = (S1, S2, …, SL) we can locate step Si as containing the first error by identifying li = 0 as the first zero label in the trace. We then pair A with a correct verifying trace T from the previous (correct) step Si−1. This creates a training pair (A, T) where we label the first error in A as E = i. See Figure 2 for an example.
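The pairing described above might look roughly like this in code (field names and the skip conditions are illustrative assumptions, not the authors' exact construction):

```python
# Sketch of deriving a local-refinement training example from SORM data: the first
# step with label 0 marks the error E, and a correct verifying continuation sampled
# from the preceding prefix becomes the refinement target.

def build_local_refinement_example(question, steps, step_labels, verifying_continuations):
    """steps: draft solution steps S_1..S_L; step_labels: SORM labels l_i;
    verifying_continuations[i]: a correct continuation sampled from the prefix ending
    at steps[i], or None if none was found."""
    if 0 not in step_labels:
        return None                                     # no located error in this draft
    e = step_labels.index(0)                            # index of the first incorrect step
    if e == 0 or verifying_continuations[e - 1] is None:
        return None                                     # for simplicity, skip drafts failing at step 1
    marked_draft = [f"[BAD] {s}" if i == e else s for i, s in enumerate(steps)]
    refinement = steps[:e] + [verifying_continuations[e - 1]]   # keep valid prefix, then fix
    return {
        "question": question,
        "draft": "\n".join(marked_draft),               # error location E encoded via the [BAD] token
        "refinement": "\n".join(refinement),
    }
```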

We construct a dataset for global refinement similarly using the ORM training dataset. This is done by pairing incorrect rollouts Aincorrect with correct rollouts Acorrect for the same question Q. This constructs a training tuple (Q, Aincorrect, Acorrect). To maintain a format similar to local refinement, we put a [BAD] token at the very start of the incorrect rollout. We combine both refinement datasets to train a model capable of both global and local refinement.
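And the corresponding global-refinement pairing, again as an illustrative sketch (the dictionary fields are assumptions matching the local-refinement sketch above):

```python
# Sketch of global-refinement pair construction from ORM data: pair an incorrect
# rollout with a correct one for the same question and place a single [BAD] token
# at the very start of the draft.

def build_global_refinement_examples(question, rollouts, labels):
    """rollouts: sampled solutions for `question`; labels: 1 if the final answer is correct."""
    correct = [r for r, l in zip(rollouts, labels) if l == 1]
    incorrect = [r for r, l in zip(rollouts, labels) if l == 0]
    return [
        {"question": question, "draft": f"[BAD] {bad}", "refinement": good}
        for bad, good in zip(incorrect, correct)
    ]
```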

4.2 Evaluation

We construct a test set for both the ORM/SORM and refinement models by sampling the student model greedily on test questions Q from the task τ. For each benchmark this gives us a test set with prompts of the form (Q, AD) where Q is the problem and AD is an initial draft. For both benchmarks we refer to this as the (Q, D) test set. To generate intermediate step labels we use the same process as used to generate SORM training data. We evaluate the ORM and SORM on this test set by comparing their predictions to these ground truth labels.

To evaluate the global refinement performance we greedily infer the refiner on each (Q, AD) sample and compare the resulting refinement AGR to the ground truth. To evaluate the local refinement model we first annotate each (Q, AD) pair with the location of its first error using the ORM or SORM. This forms a (Q, AD, E) triplet which we use to greedily sample the local refiner. For our best results, we propose to sample both a global refinement AGR and a local refinement ALR for a draft AD and choose the best solution using the ORM reranker. This strategy stems from our observation that global and local refinements each solve complementary, partially non-overlapping subsets of problems the student initially fails on. Thus combining both refinements with the draft significantly expands the set of problems we can solve. Additionally, using the ORM to rerank refinements allows for a cleaner comparison against a best-of-three baseline from the draft-generating student π. See Figure 3 for a diagram of the evaluation pipeline.

We also highlight more exploratory work in the appendix. In the main body we consider only process-based local refinement, which relies on locating reasoning errors in a solution trace. One drawback of this approach is its agnosticism to the abilities of the student model doing refinement. Alternatively, we consider value-based refinement which relies on feedback identifying the step in a solution from which the model has the best chance of succeeding. A comparison to process-based refinement is done in appendix Section J. Additionally, in appendix Section C, we compare refinement training using expert iteration to other RL algorithms with various reward schemes.

5 Results

We evaluate our refinement pipeline on the GSM8K (Cobbe et al., 2021) and SVAMP (Patel et al., 2021) math word problem benchmarks. We fine-tune Llama-2 7B and 13B to produce all downstream models including the ORM, SORM, and refinement models. Note, the evaluation of each model size is self-contained, not utilizing any data or feedback from models of a different size. maj@1 model scores via greedy sampling will be used to evaluate model performance. Hyperparameters for each phase of training are supplied in Section A of the appendix.

5.1 Evaluating the ORM and SORM

SORMs are better than ORMs at evaluating intermediate answers: On GSM8K the SORM improves over the intermediate step accuracy of the ORM by up to 8%, from 73% to 81% (see Table 1). This confirms the ORM does a reasonable job estimating intermediate step correctness but can still be improved, particularly for smaller models on hard tasks like GSM8K. We’ll see this difference in label accuracy also translates into a difference in refinement final accuracy, where it is critical for the ORM/SORM to reliably identify locations of mistakes. In comparison, the balanced ORM underperforms, having comparable intermediate accuracy to the ORM. This is despite qualitatively appearing to fix the ORM’s over-pessimism, as the balanced ORM assigns roughly 50% chance of success to all questions. We also examine the types of errors models make, finding the SORMs to have a balanced number of false positives and negatives when using 0.5 as the classification threshold.

Figure 3 Evaluation Pipeline for global and local refinement models. We first sample a draft AD from the student model then sample global and local refinements. The ORM is then used to determine which response to select as the final answer among these three candidate solutions.

Table 1 Step-level accuracy of 7B/13B ORM and SORM on test set labels. Note: the test sets are well balanced with positive labels representing 45%-55% of samples. The SORM has better step level accuracy than ORM on the harder GSM8K benchmark but comparable step level accuracy on SVAMP.

Table 2 Final answer accuracy of 7B/13B ORM and SORM on test set labels. Note: the test sets are well balanced with positive labels representing 45%-55% of samples. The ORM has better accuracy than the SORM at predicting final answer correctness.

ORMs better approximate V ∗ on easier tasks: On SVAMP the ORM has better step accuracy than on GSM8K (see Table 1), particularly the 13B model. As a result the SORM offers less improvement. Most questions in GSM8K are relatively more difficult, requiring at least 4 steps to solve. In contrast, most questions in SVAMP require at most three key steps. This small number of steps likely makes it easier for the ORM to generalize. Additionally, the EI models trained on SVAMP reach on average 15% higher accuracy than the same sized model on GSM8K. This makes the base student model a closer approximation to π∗ on SVAMP, making the ORM a closer approximation to V ∗.

The importance of a strong data generating student π is further highlighted by the difference in accuracies between 7B and 13B models on SVAMP. The 7B student EI model gets an accuracy of 58%, whereas the 13B model gets an accuracy of 70%. Correspondingly, the 13B ORM model performs much better on intermediate steps than the 7B model. Yet in contrast the 13B ORM on GSM8K performs slightly worse at intermediate steps than the 7B. This is perhaps partially explained by the performance of the 13B EI student on GSM8K which only improves 5% over the 7B student.

ORMs are better than SORMs at evaluating final answers: Despite the SORM being generally better at predicting intermediate steps, it is slightly worse at predicting final answer correctness compared to the ORM. This is true for both benchmarks, with the 13B SORM on GSM8K lagging by 5% (See Table 2). However, part of this difference is likely due to statistical biases the ORM is able to exploit, improving final answer accuracy at the cost of over-pessimism. For example, if the problem involves division, the ORM knows the student is likely to fail and immediately predicts a low probability of success. In contrast the SORM is forced to be more optimistic, attempting to carefully examine the correctness of each intermediate step.

Unfortunately, the inaccuracy of the SORM as a final answer predictor also makes it slightly worse as a final answer reranker. For this reason we always use the ORM whenever reranking candidate drafts and refinements. A more detailed comparison of reranking accuracies on GSM8K is done in Figure 4. Note, this comparison is done using ORMs and SORMs derived from a student model trained using only supervised fine-tuning on GSM8K. Rerank accuracies are computed by sampling the student K times and scoring each rollout with the ranker. The rollout with the highest score is then chosen as the final answer.
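The rerank-accuracy computation described here amounts to a best-of-K selection; a minimal sketch follows, with the same hypothetical `sample_solutions` and `final_answer` helpers as in the Expert Iteration sketch above:

```python
# Minimal sketch of rerank accuracy: sample the student K times, score each rollout
# with the ranker (ORM, balanced ORM, or SORM), and keep the highest-scoring one.

def rerank_accuracy(student, ranker, test_set, K=8):
    correct = 0
    for question, gold in test_set:
        rollouts = sample_solutions(student, question, n=K)
        best = max(rollouts, key=lambda answer: ranker(question, answer))  # top RM score wins
        correct += int(final_answer(best) == gold)
    return correct / len(test_set)
```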

Figure 4 also plots rerank accuracies for SORM models trained on data without additional post-processing. The best performing SORM uses only consistent verifying rollouts and per-step balanced labels, justifying these as good post-processing choices.

5.2 Evaluating global and local refinements

Now, with a better understanding of our SORMs’ capabilities, we can apply them for refinement. Recall that to decide when to accept a refinement AR we use the ORM as a reranker on the draft AD and refinement AR. When performing local refinement we can additionally use both the ORM and SORM to identify the location of the first mistake in AD. For the ORM we do this by labeling the first step Si such that ORM(Si) ≤ T = 0.5 where T is a threshold hyperparameter. We identify the first error analogously with the SORM. We report results on both GSM8K and SVAMP (Q, D) test sets in Figure 5. Note, we begin evaluation without using the ORM as a reranker. This is done to confirm others’ observations that refiners struggle knowing when to refine on their own.

Figure 4 Plot of ORM, balanced ORM, and SORM rerank accuracies with the same SFT student (maj@1 = 0.36). Note: SORM by itself does not use balanced step labels or consistent verifiers as additional pre-processing steps as described in Section 4. When we add in both steps, reranking performance significantly improves to nearly match the ORM’s performance.

Both global and local refinement models struggle with knowing when to refine: On both benchmarks global and local refinements show little improvement to overall model accuracy. The GSM8K 7B global refinements even decrease overall accuracy, with the other models improving by at most 1%. The local refinements improve overall accuracy more, likely due to the presence of the “[BAD]” token indicating the location (and therefore presence) of the first mistake. This underscores the importance of an ORM for choosing when to refine an incorrect draft. We also note that bigger models produce better refinements.

Global and local refinements fix similar percentages of incorrect drafts: To understand how well our refiners perform when refinement is needed we also report results when applying refinement to only incorrect drafts from the test set in Figure 5. In this case both global and local refinements do much better, improving overall accuracy by an average of 10% on GSM8K and 8% on SVAMP. This demonstrates the refiners have learned how to refine; they simply often do not know when.

It is initially somewhat surprising that global refinements are able to fix a similar percentage of drafts as local refinements. Local refinements receive extra information from E, presumably strictly improving performance over the global refiner. In reality, the provided E is noisy as it must be predicted by an imperfect ORM/SORM. We see that even the difference in label accuracy between the ORM and SORM results in a nontrivial difference in refinement accuracy.

Figure 5 Refinement accuracies on GSM8K and SVAMP. All refinement models struggle identifying correct drafts which do not need refinement. Significant improvements are seen when only refining incorrect drafts.

Table 3 Refinement accuracy on incorrect model answers. Local refinement + SORM denotes using the SORM to highlight the first incorrect reasoning step for the local refinement model. We find refining both globally and locally with the SORM can fix up to 41% of problems the model previously failed.

Additionally, global refinements have the advantage of optionally restarting a solution from scratch. A local refinement model is trained to reuse the prefix of a solution preceding a “[BAD]” token under the assumption this prefix has no errors. However, even if this prefix has valid reasoning, it may be a low-value solution path for the student. For example, a student who often fails to correctly divide may benefit from starting the problem from scratch in a way that doesn’t require any use of division. Global refinements can take advantage of this, whereas local refinements may be committed to valid reasoning with a low chance of successfully completing. See Figure 2 for examples illustrating this point.

Global and local refinements solve partially disjoint, complementary sets of problems: To better understand how global and local refinements compare we examine the overlap between the problems they correctly solve. The last two rows of Table 3 show that, when combined, global and local refinements can fix 41% of incorrect GSM8K drafts from the 13B student. Alone, global refinement and local refinement with the SORM fix only 28% of problems. Yet, when taking the best of both types of refinement for the same question, we significantly improve performance across all combinations of benchmarks and model sizes. This shows local refinement is able to solve a large set of problems global refinement cannot, and vice versa. Best performance at test time can then be achieved if we have a way of selecting which of the two refinements is appropriate.

Fortunately, we can use the ORM as a reranker for exactly the task of choosing between global and local refinements. Additionally, we can consider the initial draft as a third possible option as a way of deciding if we want to refine at all. Figure 6 shows the results of reranking the draft, global, and local refinement for each question. Since we are effectively sampling three times, we include as a baseline the best of three (Bo3) samples from the EI student. We additionally report overall accuracy if we had a perfect reranker capable of always choosing the correct solution.

Reranking the draft + refinements improves over the draft accuracy by on average 8% across models and benchmarks. When comparing with the Bo3 baseline we still see significant improvements of around 8% on GSM8K. On SVAMP, reranked Bo3 is a much more competitive baseline, itself giving a large improvement over the draft accuracy. An even bigger improvement can be seen when using an oracle reranker, with the 13B refiner improving 11% over even Bo3 on GSM8K.

Figure 6 Accuracy of reranked refinements on all drafts compared to greedy and best of 3 samples from the student (Bo3) baselines. On GSM8K, reranking refinements using the ORM improves over the Bo3 baseline by up to 9% and up to 13% with a perfect reranker.

6 Conclusion and Future Work

In this paper we study the use of reward models for identifying both when and where to refine LLM reasoning. We found ORM models generalize to some extent to evaluating the accuracy of intermediate steps on easier reasoning tasks but struggle on harder tasks where the training data generating policy π is further from π∗. We then propose to approximate the optimal policy π∗ via rejection sampling and post-processing, allowing us to generate training labels for intermediate steps Si used to train SORM models. We find the SORM generalizes better on intermediate test steps than the ORM, but at the cost of final answer accuracy. We then reused the ORM/SORM training data to train global and local refinement models. We found each type of refinement strategy helped solve a largely unique set of problems, allowing us to combine both via ORM reranking for best performance.

Future work can be classified as either: 1) improving the reliability and verbosity of local error critiques E by providing more information on how to refine or 2) augmenting the type of information local refiners use to generate correct solutions. Our study of both ORMs and SORMs reveals large room for improvement when verifying step level reasoning. Allowing verifier models to generate chains of thought appears to offer some benefit (Dhuliawala et al., 2023). Further augmenting verifying CoT with tools (Zhou et al., 2023) allows GPT-4 to effectively solve MATH (Hendrycks et al., 2021a). But it remains unclear how much GPT-4 relies on the tool to solve the problem versus actually uses the tool to augment its own understanding of why a step is wrong.

Another promising direction treats iterative refinement as a form of in-context exploration similar in spirit to ideas from algorithm distillation (Laskin et al., 2022). Here, the aim is to minimize the number of in-context model rollouts needed to figure out how to refine. This also closely relates to work aiming to augment the exploration abilities of SOTA LLMs, a direction we believe is critical to future success. The right iterative local self-refinement strategies might hopefully allow models to access complex behaviors previously inaccessible with naive iid repeated sampling.
