paper: ReFT: Reasoning with Reinforced Fine-Tuning (Trung et al., ACL 2024)
ReFT: Reasoning with Reinforced Fine-Tuning
Luong Trung, Xinbo Zhang, Zhanming Jie, Peng Sun, Xiaoran Jin, Hang Li. Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024.
aclanthology.org
[Abstract]
- CoT data에 SFT만 하면 Generalization 능력이 떨어짐
- ReFT: SFT로 warmup한 뒤, online RL로 further fine-tune
- extra or augmented training question 없이 improvement
1. Introduction
- 하나의 질문에 대해서 multiple valid CoT annotation, 경로가 가능함에도 현재의 CoT SFT는 하나의 경로만을 학습함
- ReFT
- SFT로 warm-up stage를 가져서 어느 정도의 성능을 확보
- Online RL (PPO)를 이용하여 further refine
2. Related Work
- Math Problem Solving
- natural language CoT보다 program-based CoT가 더 정확한 reasoning step을 보임
- Reinforcement Learning
- 이전 연구들처럼 alignment에 RL을 사용하는 것이 아니라 전통적인 SFT 방식보다 더 나은 성능을 이끌어내기 위해 본 연구는 RL을 적용
3. Method
3.1. Reinforced Fine-Tuning (ReFT)
- Warm-up
- $(question, CoT)$로 이루어진 데이터셋에 대해 몇 epoch 정도 fine-tune됨
- 모델이 적절한 답변 생성을 위한 기본적인 problem-solving 기술을 갖추도록 함
- 일반적인 SFT Loss로 학습 : $L_{SFT}({\theta}) = -\mathop{\mathbb{E}}_{e\sim D}\left [ \sum_{t=1}^{L}log({\pi}_{\theta}(a_t|s_t)) \right ]$
- RL
- $(question, answer)$로 이루어진 데이터셋을 가지고 스스로 학습
- policy 모델이 반복적으로 response들을 sampling하고 response가 맞는지 평가하고, online 방식으로 update
- PPO 사용
- Reward
- non-terminal state에게는 모두 0
- terminal state에서 CoT로부터 뽑아낸 answer와 ground-truth answer를 비교하여 reward
- partial reward를 활용하여 정답이 틀렸지만 null이 아닌 numeric type이면 0.1 부여
- Total reward: $r_{total}(s_t,a_t,s_{t+1}_ = r(s_t,a_t, s_{t+1})-{\beta}KL({\pi}_{\theta}(\cdot |s_t), {\pi}^{(0)}_{\theta}(\cdot|s_t))$
- Advantage로는 GAE 사용
- Total Loss: $L_{RL}({\theta},{\phi})=L_{policy} + {\alpha}L_{value}$
3. Experiments
- Benchmark
- GSM8K / SVAMP / MathQA
- CoT annotation을 얻기 위해서 GPT-3.5-turbo를 이용한 few-shot prompt
- Model
- Galactica-6.7B
- CodeLLAMA-7B
- Baselines : Expert Iteration 방법을 사용
- Offline Self-Training: SFT 모델로 CoT 샘플하고 ground truth에 맞는 answer만 포함해서 다시 SFT
- Online Self-Training: SFT 모델로 warmup하고 CoT 샘플하고 ground truth에 맞는 answer 포함해서 계속 update
- Reward Model Reranking
- SFT checkpoint에서 best model로 initialize한 LM
- correct 또는 incorrect를 predict하도록 학습
4. Notes
- Long-horizon에 사용할 수 있는 방법이 없을까
- Entropy를 추가해서 해도 되지 않을까