본문 바로가기

RL

[Paper Review] ReFT: Reasoning with Reinforced Fine-Tuning

paper: ReFT: Reasoning with Reinforced Fine-Tuning (Trung et al., ACL 2024)

 

ReFT: Reasoning with Reinforced Fine-Tuning

Luong Trung, Xinbo Zhang, Zhanming Jie, Peng Sun, Xiaoran Jin, Hang Li. Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024.

aclanthology.org


[Abstract]

  • CoT data에 SFT만 하면 Generalization 능력이 떨어짐
  • ReFT: SFT로 warmup한 뒤, online RL로 further fine-tune
  • extra or augmented training question 없이 improvement

1. Introduction

  • 하나의 질문에 대해서 multiple valid CoT annotation, 경로가 가능함에도 현재의 CoT SFT는 하나의 경로만을 학습함
  • ReFT
    • SFT로 warm-up stage를 가져서 어느 정도의 성능을 확보
    • Online RL (PPO)를 이용하여 further refine

2. Related Work

  • Math Problem Solving
    • natural language CoT보다 program-based CoT가 더 정확한 reasoning step을 보임
  • Reinforcement Learning
    • 이전 연구들처럼 alignment에 RL을 사용하는 것이 아니라 전통적인 SFT 방식보다 더 나은 성능을 이끌어내기 위해 본 연구는 RL을 적용

3. Method

3.1. Reinforced Fine-Tuning (ReFT)

  • Warm-up
    • $(question, CoT)$로 이루어진 데이터셋에 대해 몇 epoch 정도 fine-tune됨
    • 모델이 적절한 답변 생성을 위한 기본적인 problem-solving 기술을 갖추도록 함
    • 일반적인 SFT Loss로 학습 : $L_{SFT}({\theta}) = -\mathop{\mathbb{E}}_{e\sim D}\left [ \sum_{t=1}^{L}log({\pi}_{\theta}(a_t|s_t)) \right ]$
  1. RL
    • $(question, answer)$로 이루어진 데이터셋을 가지고 스스로 학습
    • policy 모델이 반복적으로 response들을 sampling하고 response가 맞는지 평가하고, online 방식으로 update
    • PPO 사용
    • Reward 
      • non-terminal state에게는 모두 0
      • terminal state에서 CoT로부터 뽑아낸 answer와 ground-truth answer를 비교하여 reward
      • partial reward를 활용하여 정답이 틀렸지만 null이 아닌 numeric type이면 0.1 부여
      • Total reward: $r_{total}(s_t,a_t,s_{t+1}_ = r(s_t,a_t, s_{t+1})-{\beta}KL({\pi}_{\theta}(\cdot |s_t), {\pi}^{(0)}_{\theta}(\cdot|s_t))$
    • Advantage로는 GAE 사용 
    • Total Loss: $L_{RL}({\theta},{\phi})=L_{policy} + {\alpha}L_{value}$

3. Experiments

  • Benchmark
    • GSM8K / SVAMP / MathQA
    • CoT annotation을 얻기 위해서 GPT-3.5-turbo를 이용한 few-shot prompt
  • Model
    • Galactica-6.7B
    • CodeLLAMA-7B
  • Baselines : Expert Iteration 방법을 사용
    • Offline Self-Training: SFT 모델로 CoT 샘플하고 ground truth에 맞는 answer만 포함해서 다시 SFT 
    • Online Self-Training: SFT 모델로 warmup하고 CoT 샘플하고 ground truth에 맞는 answer 포함해서 계속 update
  • Reward Model Reranking
    • SFT checkpoint에서 best model로 initialize한 LM
    • correct 또는 incorrect를 predict하도록 학습

 

4. Notes

  • Long-horizon에 사용할 수 있는 방법이 없을까
  • Entropy를 추가해서 해도 되지 않을까