paper: Singh, Avi, et al. "Beyond human data: Scaling self-training for problem-solving with language models." arXiv preprint arXiv:2312.06585 (2023).
link: https://arxiv.org/abs/2312.06585
[Beyond Human Data: Scaling Self-Training for Problem-Solving with Language Models
Fine-tuning language models~(LMs) on human-generated data remains a prevalent practice. However, the performance of such models is often limited by the quantity and diversity of high-quality human data. In this paper, we explore whether we can go beyond hu
arxiv.org](https://arxiv.org/abs/2312.06585)
[Abstract]
- LLM을 human data에 fine-tuning시키는 것은 quantity나 diversity에 제약이 있음
- ReST^EM이라는 EM 알고리즘에 기반한 self-training 방법을 제시함
- human data에만 fine-tuning한 모델보다 더 나은 성능을 보임
1. Introduction
- prior work인 ReST 방법과 비슷한 접근
- ReST^EM
- Generate (E-step): LM이 multiple output들을 만들어내고 binary reward를 이용하여 filter해서 training dataset 구축
- Improve (M-step): original LM은 E-step에서 만들어낸 dataset으로 supervised fine-tuned
2. Preliminaries
- SFT
- RL
3. Expectation-Maximization for Reinforced Self-Training
Expectation-Maximization (EM) for RL
- Variational Inference : ELBO를 Maximize
- E-step: q를 q*로 만들어서 q^(t+1)로 사용
- M-step: maximize하는 theta를 만들어서 업데이트. weighted negative log-likelihood loss를 minimize하는 문제와 같음
- 위 과정을 반복하면 monotonic improvement가 보장됨 \[L(p_{\theta^{t+1}},q^{t+1})\geq L(p_{\theta^{t}},q^{t+1})\geq L(p_{\theta^{t}},q^{t})\]
EM with non-negative rewards
- 일반적인 RL objective ($L_{RL}$)과 다른 점이 존재함
- 일반적인 RL과 달리 EM-based RL은 이전 iteration의 policy로 샘플링한 데이터를 이용해서 업데이트함.
- 즉, data collection과 policy optimization을 decoupling함 => 덕분에 LLM과 같은 large policy network로 쉽게 scaling 가능함
$ReST^{EM}$
- Decouple data collection (E-step) and policy optimization (M-step) in a typical RL pipeline
- Generate (E-step)
- 현재 policy인 $p_{\theta}$를 가지고 데이터셋 $D_{i}$를 만듦
- 보상 함수 $r(x,y)$를 가지고 데이터셋에 점수를 매김
- Improve (M-step)
- 새로운 데이터셋 $D_{i}$를 가지고 policy $p_{\theta}$를 fine-tune함. 이때 fine-tune 함수는 이전 iteration 함수가 아니라 base model임
- objective function은 reward-weighted negative log-likelihood loss임
- 이렇게 업데이트된 policy를 가지고 다시 E-step을 가서 better quality sample을 만듦
4. Experiments
- Training Datasets
- MATH : mathematical problem solving
- APPS : code generation
- Models
- PaLM 2 variants: PaLM 2-S, PaLM 2-S*, PaLM 2-L
- Evaluation
- test splits of MATH and APPS
- GSM8K
- Hungarian HS finals
- HumalEval
- Big-Bench Hard
5. Results
- MATH and APPS
- Human-written solution보다 좋은 성능들을 보임
- Model이 Scaling할수록 성능이 좋아짐
- Impact on pass@k and majority-voting performance
- $ReST^{EM}$으로 훈련한 모델이 두 metric에서 모두 나은 성능을 보임
- Ablation studies
- 여러 번 iteration하는 게 좋음
- model-generated data는 하나의 질문에 여러 정답을 생성할 수 있기에 더 좋은 성능을 보임
Note
- EM algorithm을 통해서 할 수 있는 방법들이 많아 보인다.
- MM algorithm을 적용해서 할 수는 없을까? (가칭 $ReST^{MM}$)