본문 바로가기

RL

[Paper Review] Beyond Human Data: Scaling Self-Training for Problem-Solving with Language Models (ReST^EM)

paper: Singh, Avi, et al. "Beyond human data: Scaling self-training for problem-solving with language models." arXiv preprint arXiv:2312.06585 (2023).

link: https://arxiv.org/abs/2312.06585

[Beyond Human Data: Scaling Self-Training for Problem-Solving with Language Models

Fine-tuning language models~(LMs) on human-generated data remains a prevalent practice. However, the performance of such models is often limited by the quantity and diversity of high-quality human data. In this paper, we explore whether we can go beyond hu

arxiv.org](https://arxiv.org/abs/2312.06585)


[Abstract]

  • LLM을 human data에 fine-tuning시키는 것은 quantity나 diversity에 제약이 있음
  • ReST^EM이라는 EM 알고리즘에 기반한 self-training 방법을 제시함
  • human data에만 fine-tuning한 모델보다 더 나은 성능을 보임

1. Introduction

  • prior work인 ReST 방법과 비슷한 접근
  • ReST^EM
    • Generate (E-step): LM이 multiple output들을 만들어내고 binary reward를 이용하여 filter해서 training dataset 구축
    • Improve (M-step): original LM은 E-step에서 만들어낸 dataset으로 supervised fine-tuned

2. Preliminaries

  • SFT

  • RL

3. Expectation-Maximization for Reinforced Self-Training

Expectation-Maximization (EM) for RL

  • Variational Inference : ELBO를 Maximize

  • E-step: q를 q*로 만들어서 q^(t+1)로 사용
  • M-step: maximize하는 theta를 만들어서 업데이트. weighted negative log-likelihood loss를 minimize하는 문제와 같음
  • 위 과정을 반복하면 monotonic improvement가 보장됨 \[L(p_{\theta^{t+1}},q^{t+1})\geq L(p_{\theta^{t}},q^{t+1})\geq L(p_{\theta^{t}},q^{t})\]

EM with non-negative rewards

  • 일반적인 RL objective ($L_{RL}$)과 다른 점이 존재함
  • 일반적인 RL과 달리 EM-based RL은 이전 iteration의 policy로 샘플링한 데이터를 이용해서 업데이트함.
  • 즉, data collection과 policy optimization을 decoupling함 => 덕분에 LLM과 같은 large policy network로 쉽게 scaling 가능함

$ReST^{EM}$

  • Decouple data collection (E-step) and policy optimization (M-step) in a typical RL pipeline
  • Generate (E-step)
    • 현재 policy인 $p_{\theta}$를 가지고 데이터셋 $D_{i}$를 만듦
    • 보상 함수 $r(x,y)$를 가지고 데이터셋에 점수를 매김
  •  Improve (M-step)
    • 새로운 데이터셋 $D_{i}$를 가지고 policy $p_{\theta}$를 fine-tune함. 이때 fine-tune 함수는 이전 iteration 함수가 아니라 base model임
    • objective function은 reward-weighted negative log-likelihood loss임
    • 이렇게 업데이트된 policy를 가지고 다시 E-step을 가서 better quality sample을 만듦

 

4. Experiments

  • Training Datasets
    • MATH : mathematical problem solving
    • APPS : code generation
  • Models
    • PaLM 2 variants: PaLM 2-S, PaLM 2-S*, PaLM 2-L
  • Evaluation
    • test splits of MATH and APPS 
    • GSM8K
    • Hungarian HS finals
    • HumalEval
    • Big-Bench Hard

 

5. Results

  • MATH and APPS
    • Human-written solution보다 좋은 성능들을 보임
    • Model이 Scaling할수록 성능이 좋아짐
  • Impact on pass@k and majority-voting performance
    • $ReST^{EM}$으로 훈련한 모델이 두 metric에서 모두 나은 성능을 보임
  • Ablation studies
    • 여러 번 iteration하는 게 좋음
    • model-generated data는 하나의 질문에 여러 정답을 생성할 수 있기에 더 좋은 성능을 보임

Note

  • EM algorithm을 통해서 할 수 있는 방법들이 많아 보인다.
  • MM algorithm을 적용해서 할 수는 없을까? (가칭 $ReST^{MM}$)