Authors: Jian Hu and Weixun Wang

First published on: 2024/10/18

Online Policy Improvement

Reinforcement Learning from Human Feedback (RLHF) [1]

The RLHF process can be broken down into several key stages:

  1. Supervised Fine-Tuning (SFT): Initially, a base language model is fine-tuned using supervised learning techniques on a dataset of prompts and human-written responses.
  2. Creating the Reward Model: After fine-tuning, a reward model is constructed based on human rankings of model outputs. This model predicts which responses are preferred by humans.
  3. Reinforcement Learning (RL) Fine-Tuning: The final stage involves using reinforcement learning algorithms, such as Proximal Policy Optimization (PPO), to further refine the language model based on feedback from the reward model.


Let $x_t$ denote the input at time step $t$, which includes both the query and the tokens generated up to that point; at the sequence level, $x$ denotes the prompt and $y$ the full generated response. The RL optimization goal can be mathematically expressed as follows:

$\text{Optimization Goal} = \mathbb{E}_{(x,y)}\left[R(x,y) - \beta\, D_{KL}(\pi \,\|\, \pi_0)\right]$

Where:

  * $R(x,y)$ is the score assigned by the reward model to response $y$ for prompt $x$.
  * $\pi$ is the policy (language model) being optimized, and $\pi_0$ is the reference policy obtained from SFT.
  * $\beta$ is a coefficient controlling the strength of the KL penalty.

To mitigate issues such as "hallucinations" or nonsensical outputs resulting from over-optimization of the reward model's learned preferences, KL divergence penalties are incorporated into the optimization objective. This balance helps maintain diversity in outputs while ensuring alignment with human feedback.
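In practice, this objective is often implemented as a per-token reward signal: the KL penalty is applied at every generated token, and the scalar reward-model score is added at the final token. The sketch below illustrates this, assuming PyTorch and the simple $\log \pi - \log \pi_0$ estimator for the KL term; the function name and tensor shapes are illustrative, not taken from any specific library.

```python
import torch

def kl_penalized_rewards(
    rm_score: torch.Tensor,      # (batch,) scalar reward-model score per sequence
    logprobs: torch.Tensor,      # (batch, seq_len) log pi(y_t | x, y_<t) under the current policy
    ref_logprobs: torch.Tensor,  # (batch, seq_len) log pi_0(y_t | x, y_<t) under the reference policy
    beta: float = 0.1,
) -> torch.Tensor:
    """Combine the reward-model score with a per-token KL penalty.

    Uses log(pi) - log(pi_0) as the per-token KL estimate; the sequence-level
    reward-model score is added only at the last generated token.
    """
    kl = logprobs - ref_logprobs          # per-token KL estimate
    rewards = -beta * kl                  # KL penalty applied at every token
    rewards[:, -1] += rm_score            # sequence-level reward on the final token
    return rewards

# Usage with dummy tensors:
batch, seq_len = 2, 8
rewards = kl_penalized_rewards(
    rm_score=torch.randn(batch),
    logprobs=torch.randn(batch, seq_len),
    ref_logprobs=torch.randn(batch, seq_len),
    beta=0.05,
)
print(rewards.shape)  # torch.Size([2, 8])
```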

Pair-wise Reward Model

  1. Training Objective: The reward model is trained to minimize a loss function that encourages it to predict rewards that align closely with human evaluations. This is often done using a cross-entropy loss function, defined as: $\mathcal{L}(\theta) = -\frac{1}{\binom{K}{2}} \mathbb{E}_{(x,y_w,y_l)} \left[ \log\left(\sigma\left(r_\theta(x,y_w) - r_\theta(x,y_l)\right)\right) \right]$

    Here, $r_\theta(x,y)$ denotes the predicted reward for the response $y$ given prompt $x$, $\sigma$ is the sigmoid function, $y_w$ and $y_l$ are the preferred and less-preferred responses in a comparison, and $K$ is the number of responses ranked per prompt, giving $\binom{K}{2}$ pairs; a code sketch of this loss follows the list below.

  2. Reward Normalization: After training, the outputs of the reward model are normalized so that the mean score of reference responses is zero, ensuring consistency in scoring.
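As a concrete illustration of the pairwise objective above, the following sketch computes the loss from scalar reward scores. It assumes PyTorch; the function and argument names are illustrative rather than drawn from a specific library, and averaging over the batch of pairs plays the role of the $\frac{1}{\binom{K}{2}}$ normalization when all pairs for each prompt are included.

```python
import torch
import torch.nn.functional as F

def pairwise_reward_loss(
    chosen_rewards: torch.Tensor,    # (batch,) r_theta(x, y_w) for preferred responses
    rejected_rewards: torch.Tensor,  # (batch,) r_theta(x, y_l) for less-preferred responses
) -> torch.Tensor:
    """Pairwise ranking loss: -log sigma(r_w - r_l), averaged over comparison pairs."""
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()

# Usage with dummy scores from a (hypothetical) reward-model head:
chosen = torch.tensor([1.2, 0.3, 0.8])
rejected = torch.tensor([0.4, 0.5, -0.2])
print(pairwise_reward_loss(chosen, rejected))
```

Using `F.logsigmoid` rather than `log(sigmoid(...))` is the usual choice here because it is numerically stable for large score differences.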

Proximal Policy Optimization (PPO) [2]