본문 바로가기

RL

[Paper Review] AlphaZero-Like Tree-Search can GuideLarge Language Model Decoding and Training (TS-LLM)

paper: Feng, Xidong, et al. "Alphazero-like tree-search can guide large language model decoding and training." arXiv preprint arXiv:2309.17179 (2023).

link: https://arxiv.org/abs/2309.17179

 

Alphazero-like Tree-Search can Guide Large Language Model Decoding and Training

Recent works like Tree-of-Thought (ToT) and Reasoning via Planning (RAP) aim to augment the reasoning capabilities of LLMs by using tree-search algorithms to guide multi-step reasoning. These methods rely on prompting a pre-trained model to serve as a valu

arxiv.org


[Abstract]

- 이전의 ToT나 RAP는 tree search algorithm을 활용하여 multi-step reasoning 능력을 올리려고 시도함

- 그러나 이 방법들은 pretrained 모델의 prompting에 의존하고 low depth만 가능했음

- AlphaZero-like tree search learning framework (TS-LLM)은 learned value function을 이용하고 high depth까지 가능함

- 또한 TS-LLM 방식은 inference와 training 모두에 적용 가능함

1. Introduction

  • 기존의 방법들은 LLM 프롬프팅을 통해 value를 얻었기에 general applicability가 부족하고 well-desinged prompt에 의존함
  • ToT,RAP는 BFS/DFS와 MCTS를 사용했지만 maximum depth가 10이나 7에 지나지 않음
  • TS-LLM
    • learned value function을 이용하여 평가하기에 더 reliable함
    • depth 64까지 tree search 가능함
    • Iterative process
      • tree search로 policy improve
      • policy improvement through policy distillation
      • value function improvement through the ground-truth training labels on the tree search trajectories

2. Related work

  • Finetuning LLMs with Augmentation
    • TS-LLM은 tree search로 augmented sample들을 만들고 LLM과 Value function을 학습시킴

3. Enhancing LLMs with Tree Search

3.1 Problem Formulation

  • language generation process as a multi-step MDP
    • generation 문제를 high cumulative reward 최적화 문제로 바꿈
  • sentence-level action nodes
    • depth는 작지만 width가 너무 큼
    • Sampled MuZero처럼 w 크기의 노드만을 sample하기로 함
  • token-level action nodes
    • depth가 너무 깊음

3.2 Guiding LLM Inference Decoding with Tree Search

  1. Learning an LLM-based Value function
    • value function / reward model (ORM)
    • value function과  reward model은 공유되고 decoder-only transformer with an MLP 구조로 input token들의 각 위치에서 scalar 값을 output함
    • sentence-level일 때는 마지막 token의 scalar 값을 value로 봄. final reward는 prompt와 generated sentence 전체를 넣었을 때 last token으로 얻음
    • value network과 reward 모두 TD-lambda 방법으로 학습함

 

2. Tree Search Algorithms

  • MCTS with Value Function Approximation (MCTS-alpha)
    • select, expand, evaluate, backup
    • leaf node의 값을 value function으로 evaluate하고 backpropagate함
    • action을 취하고 난 뒤에는 이전 state로 못 돌아가기에 initial state에서 시작하지 않음
  • MCTS-Rollout
    • MCTS처럼 initial state에서 항상 시작하되, value function으로 evaluate
    • 중간 단계에서도 value function으로 backup 가능함

3. Multiple search and Search aggregation

  • multiple tree search에서 나온 N개의 complete answer들을 합쳐서 하나의 정답을 만듦
  • aggregation methods with ORM
    • Majority-Vote
    • ORM-Max: answer f with maximum final reward
    • ORM-Vote: answer f with the sum of reward

3.3 Enhancing LLM Training with Tree Search

  • Policy Improvement
    • 기존의 policy, value, reward와 training set을 가지고 tree search를 하여 augmented dataset D를 얻음
    • 그리고 D에서 filtered positive exampled D+를 얻음
  • Policy Distillation
    • D+를 가지고 supervised training(cross-entropy loss with trajectories' tokens)을 하여 LLM policy를 improve함
  • Policy Evaluation
    • D를 가지고 value model과 reward model을 다시 학습함
  • 위 세 가지 과정을 반복적으로 실행함

4. Experiments

4.1. Task

  • mathematical reasoning task: GSM8K
  • mathematical planning task: Game24
  • logical reasoning task: PrOntoQA
  • RLHF alignment task: synthetic RLHF data
  • chess endgame

4.2 Baselines

  • GPT3.5
  • LLaMA

4.3 Model & Training

  • reasoning task : rollout policy로 LLaMA2-7B 모델 사용
  • RLHF,chess: rollout policy로 GPT-2-small 모델 사용
  • value,reward 훈련을 위해 기본 training set으로 SFT한 policy 모델을 rollout하여 training set 만듦
  • policy model과 value model의 base model은 같음

5. Results

  • Learned value function이 GPT-3.5로 prompt하여 value를 얻는 것보다 나은 성능을 보임
  • ORM 훈련에는 diverse dataset이, value function 훈련에는 much data가 중요함
  • width를 높이니 성능이 늘어남

Note

  • 계산량이나 성능 비교가 명확해보이지 않음
  • ORM으로 훈련시켰어야 하나? MATH-Sheperd와 같은 방법론으로 PRM했으면 더 잘했을 수도 있음