NLP
[Paper Review] Training Verifiers to Solve Math Word Problems (GSM8K)
maotter
2025. 1. 8. 11:30
paper: Cobbe, Karl, et al. "Training verifiers to solve math word problems." arXiv preprint arXiv:2110.14168 (2021).
link: https://arxiv.org/abs/2110.14168
Training Verifiers to Solve Math Word Problems
State-of-the-art language models can match human performance on many tasks, but they still struggle to robustly perform multi-step mathematical reasoning. To diagnose the failures of current models and support research, we introduce GSM8K, a dataset of 8.5
arxiv.org
[Abstract]
- LLM이 multi-step mathematical reasoning 문제를 푸는 데에 어려움을 겪고 있음
- GSM8K라고 하는 높은 품질의 grade school math problem으로 이루어진 8.5K 데이터셋을 만듦
- Verifier을 훈련시켜서 모델의 성능을 높임
1. Introduction
- LLM이 scaling law에 따라 높은 성능을 보이고 있지만 multi-step mathemathical reasoning에서는 어려움을 겪고 있음
- 주요 원인으로는 autoregressive model의 한계로, 이전에 발생한 Error를 수정하는 능력이 없기 때문임
- 본 연구에서는 Verifier를 훈련하여 model이 생성한 답변들의 정확도를 평가하도록 함
- candidate solution들 중에서 verifier가 가장 높게 ranking한 답변을 최종적으로 선택함
- GSM8K 데이터셋
- linguistic diversity는 높으나, 비교적 간단한 grade school math concept만을 가지고 있음
- elementary concept만을 가지고 있기에 모델이 충분히 풀 수 있는 문제들임
- Finetuning보다는 Verifier를 활용하는 것이 더 효율적이라는 것을 실험적으로 증명함
- Dropout은 strong regularizer로써 성능 향상에 효과적임을 실험적으로 증명함
2. Dataset
GSM8K
- 8.5K high quality grade school math problems by "human problem writers"
- 7.5K train & 1K test
- 2 to 8 steps to solve
- 기초적인 연산자 (+ - X /)만을 이용한 기초적인 연산들을 수행해야함
3. Related Work
3.1. Related Datasets
- ASDiv
- 2.3K math word problems 데이터셋으로 high diversity와 high quality
- GSM8K는 ASDiv와 같은 design principle을 취했지만, natural language solution과 더 많은 step을 요구한다는 점에서 다름
- MATH
- GSM8K보다 더 크고 훨씬 challenging한 데이터셋
- CommonsenseQA
- GSM8K는 CommonsenseQA와 같이 basic background knowledge를 요구함
- LogiQA
- GSM8K는 LogiQA와 같이 reading comprehension과 더불어 Logical reasoning을 요구함
3.2. Related Methods
- 본 연구는 Shen et al. (2021a, "Generate & rank")와 가장 비슷하나 세 가지 측면에서 다른 점을 보임
- pure mathematical expression이 아니라 natural language solution에 집중하기에 더 이해하기 쉽고 verbal 분석 능력을 높일 수 있음
- verifier가 baseline method들보다 additional data로부터 더 효과적으로 scale한다는 것을 실험적으로 보임
- generator와 verifier를 분리하여 generator의 overfitting을 막음
4. Methods
- Finetuning과 Verification 두 가지 방식을 이용함
- 두 방식 모두 GPT-3 family를 사용 (6B, 175B)
- arithmetic 실수를 줄이기 위해서 calculation annotation을 이용함
- <<48/2>>라는 식으로 <<.>> 토큰들 사이에 계산해야할 부분 들어감.
- 특별히 처리하지 않고 똑같이 토큰으로 다룸
- test 시에는 eval() 함수 이용해서 계산을 시도하지만, 실패하면 원래 방법대로 autoregressive하게 계산하도록 처리함
4.1. Finetuning
- 6B 모델로 성능을 확인했을 때, Test@1은 에포크가 늘어날수록 성능이 조금씩 증가함
- 그러나 Test@100은 에포크가 늘어나자 Overfitting 문제를 보이며 성능이 떨어짐
- 그래서 최종적으로 verifier를 훈련시키기 위한 샘플 generator로 2 epoch 훈련시킨 모델을 선택함
- 또한 최종 output 전에 natural language solution을 생성하지 않으면 성능이 크게 떨어지는 것을 확인함
4.2. Verification
- Verifier는 generator가 생성한 candidate solution(final answer)를 보고 correct한지 probability를 출력함
- Verifier 훈련 과정
- generator를 training set에 2 epoch 동안 Finetune
- 각 Training sample에 대해 generator로 100 completion을 샘플해서 correct인지 incorrect인지 각각 label
- 1 epoch로 2의 Label된 데이터셋에 대해서 Verifier를 훈련. 훈련 objective는 똑같이 language modeling objective
- Test time
- 각 test problem에 대해서 100 completion을 샘플하고 verifier로 score 및 rank를 매겨서 가장 높은 순위의 결과를 선택
- Comparison with Finetuing method
- 작은 데이터셋에서는 성능 향상을 보이지 않음
- 그러나 큰 데이터셋에서는 Verifier가 커다란 성능 향상을 보임
4.3. Verification Ablations
- solution-level
- solution이 완전히 생성되고 verifier가 scalar prediction
- token-level
- solution의 each token마다 scalar prediction
- 최종적으로는 token-level의 성능이 더 좋았고 solution-level은 금방 overfitting됨
- verifier를 훈련할 때 verification objective만 사용하는 것보다는 language modeling objective도 함께 사용하는 것이 더 좋음
- large generator에 small verifier를 사용하는 게 가장 나은 결과를 보임
5. Additional Experiments
5.1. Test Time Compute
- 400 completion들 중에서 고르는 게 가장 좋은 성능을 보임. 적당한 계산 비용과 성능을 보이는 100 completion을 최종 선택
- single top solution이 아니라 majority vote를 해봄. completion variants 모두 조금씩 성능이 다 올라감
5.2. Regularization
- Finetuning 방법과 Verification 방법 모두 dropout (residual dropout) 시에 큰 효과를 보임
- verifier의 경우에는 overfitting 문제가 있었던 solution-level에서 더 큰 성능 향상을 보임
6. Conclusion
- Verification이 Finetuing보다 더 큰 성능 향상을 가져다줌
- token-level verifier가 sentence-level verifier보다 overfitting에 덜 민감함
- residual dropout이 모든 방법들에 성능 향상을 가져다줌
Note
- GSM8K는 생각보다 쉬운 문제
- Verifier의 구조와 학습 방법을 다르게 할 수 있을 것 같음