paper: Feng, Xidong, et al. "Alphazero-like tree-search can guide large language model decoding and training." arXiv preprint arXiv:2309.17179 (2023).
link: https://arxiv.org/abs/2309.17179
Alphazero-like Tree-Search can Guide Large Language Model Decoding and Training
Recent works like Tree-of-Thought (ToT) and Reasoning via Planning (RAP) aim to augment the reasoning capabilities of LLMs by using tree-search algorithms to guide multi-step reasoning. These methods rely on prompting a pre-trained model to serve as a valu
arxiv.org
[Abstract]
- 이전의 ToT나 RAP는 tree search algorithm을 활용하여 multi-step reasoning 능력을 올리려고 시도함
- 그러나 이 방법들은 pretrained 모델의 prompting에 의존하고 low depth만 가능했음
- AlphaZero-like tree search learning framework (TS-LLM)은 learned value function을 이용하고 high depth까지 가능함
- 또한 TS-LLM 방식은 inference와 training 모두에 적용 가능함
1. Introduction
- 기존의 방법들은 LLM 프롬프팅을 통해 value를 얻었기에 general applicability가 부족하고 well-desinged prompt에 의존함
- ToT,RAP는 BFS/DFS와 MCTS를 사용했지만 maximum depth가 10이나 7에 지나지 않음
- TS-LLM
- learned value function을 이용하여 평가하기에 더 reliable함
- depth 64까지 tree search 가능함
- Iterative process
- tree search로 policy improve
- policy improvement through policy distillation
- value function improvement through the ground-truth training labels on the tree search trajectories
2. Related work
- Finetuning LLMs with Augmentation
- TS-LLM은 tree search로 augmented sample들을 만들고 LLM과 Value function을 학습시킴
3. Enhancing LLMs with Tree Search
3.1 Problem Formulation
- language generation process as a multi-step MDP
- generation 문제를 high cumulative reward 최적화 문제로 바꿈
- sentence-level action nodes
- depth는 작지만 width가 너무 큼
- Sampled MuZero처럼 w 크기의 노드만을 sample하기로 함
- token-level action nodes
- depth가 너무 깊음
3.2 Guiding LLM Inference Decoding with Tree Search
- Learning an LLM-based Value function
- value function / reward model (ORM)
- value function과 reward model은 공유되고 decoder-only transformer with an MLP 구조로 input token들의 각 위치에서 scalar 값을 output함
- sentence-level일 때는 마지막 token의 scalar 값을 value로 봄. final reward는 prompt와 generated sentence 전체를 넣었을 때 last token으로 얻음
- value network과 reward 모두 TD-lambda 방법으로 학습함
2. Tree Search Algorithms
- MCTS with Value Function Approximation (MCTS-alpha)
- select, expand, evaluate, backup
- leaf node의 값을 value function으로 evaluate하고 backpropagate함
- action을 취하고 난 뒤에는 이전 state로 못 돌아가기에 initial state에서 시작하지 않음
- MCTS-Rollout
- MCTS처럼 initial state에서 항상 시작하되, value function으로 evaluate
- 중간 단계에서도 value function으로 backup 가능함
3. Multiple search and Search aggregation
- multiple tree search에서 나온 N개의 complete answer들을 합쳐서 하나의 정답을 만듦
- aggregation methods with ORM
- Majority-Vote
- ORM-Max: answer f with maximum final reward
- ORM-Vote: answer f with the sum of reward
3.3 Enhancing LLM Training with Tree Search
- Policy Improvement
- 기존의 policy, value, reward와 training set을 가지고 tree search를 하여 augmented dataset D를 얻음
- 그리고 D에서 filtered positive exampled D+를 얻음
- Policy Distillation
- D+를 가지고 supervised training(cross-entropy loss with trajectories' tokens)을 하여 LLM policy를 improve함
- Policy Evaluation
- D를 가지고 value model과 reward model을 다시 학습함
- 위 세 가지 과정을 반복적으로 실행함
4. Experiments
4.1. Task
- mathematical reasoning task: GSM8K
- mathematical planning task: Game24
- logical reasoning task: PrOntoQA
- RLHF alignment task: synthetic RLHF data
- chess endgame
4.2 Baselines
- GPT3.5
- LLaMA
4.3 Model & Training
- reasoning task : rollout policy로 LLaMA2-7B 모델 사용
- RLHF,chess: rollout policy로 GPT-2-small 모델 사용
- value,reward 훈련을 위해 기본 training set으로 SFT한 policy 모델을 rollout하여 training set 만듦
- policy model과 value model의 base model은 같음
5. Results
- Learned value function이 GPT-3.5로 prompt하여 value를 얻는 것보다 나은 성능을 보임
- ORM 훈련에는 diverse dataset이, value function 훈련에는 much data가 중요함
- width를 높이니 성능이 늘어남
Note
- 계산량이나 성능 비교가 명확해보이지 않음
- ORM으로 훈련시켰어야 하나? MATH-Sheperd와 같은 방법론으로 PRM했으면 더 잘했을 수도 있음