Author: Jian Hu
First published on: 2024/8/4
In late 2022, OpenAI's paper [1] on InstructGPT sparked widespread interest in Reinforcement Learning from Human Feedback (RLHF), now often referred to as post-training. The core concept involves training a reward model (RM) on a pairwise preference dataset with a ranking loss. This RM is then used with the Proximal Policy Optimization (PPO) algorithm to fine-tune a GPT model. The objective is to align the model for enhanced safety, mathematical reasoning, and other capabilities. Although the reward comes from an RM trained on labeled data as a proxy, and a KL penalty constrains how far the policy can drift from the reference model, the overall process does not deviate significantly from traditional PPO.
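As a rough illustration of the ranking loss mentioned above, the sketch below computes the usual pairwise form, the negative log-sigmoid of the score margin between the preferred and the dispreferred response. The function name and the plain-float inputs are assumptions made for this example; a real RM would produce these scores from a neural network and compute the loss on tensors.

```python
import math


def pairwise_ranking_loss(chosen_scores, rejected_scores):
    """Mean of -log(sigmoid(r_chosen - r_rejected)) over preference pairs.

    Hypothetical helper for illustration: inputs are plain lists of floats,
    one RM score per chosen/rejected response.
    """
    losses = []
    for r_chosen, r_rejected in zip(chosen_scores, rejected_scores):
        margin = r_chosen - r_rejected
        # -log(sigmoid(m)) == log(1 + exp(-m)), written in a numerically stable form
        losses.append(max(-margin, 0.0) + math.log1p(math.exp(-abs(margin))))
    return sum(losses) / len(losses)
```

The loss shrinks as the RM scores the chosen response further above the rejected one, and equals log 2 when the two scores are tied.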
The objective function of RLHF
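A common way to write this objective is the KL-regularized reward maximization used in InstructGPT [1], shown here without its optional pretraining-mix term. The notation is an assumption of this sketch rather than something defined in this post: $\pi_\theta$ is the policy being trained, $\pi_{\text{ref}}$ the frozen reference (SFT) model, $r_\phi$ the reward model, $\mathcal{D}$ the prompt dataset, and $\beta$ the KL penalty coefficient.

```latex
\max_{\pi_\theta} \;
\mathbb{E}_{x \sim \mathcal{D},\; y \sim \pi_\theta(\cdot \mid x)}
\left[ r_\phi(x, y) - \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\text{ref}}(y \mid x)} \right]
```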
<aside> 💡
A major issue is that PPO requires loading four large models simultaneously (the actor, critic, reward model, and reference model), which complicates system architecture. The training loop needs both fast inference (for PPO sample generation) and fast training, and this dual demand poses significant hurdles for AI infrastructure optimization, particularly with very large models such as Llama 3.1 (405 billion parameters).
</aside>
From 2023 to 2024, numerous new algorithms resembling RLHF emerged, including Direct Preference Optimization (DPO) [2] and its variants, REINFORCE Leave-One-Out (RLOO) [4], Group Relative Policy Optimization (GRPO) [5], and REINFORCE. The primary goal of these algorithms is to simplify the RLHF process while reducing training costs and improving efficiency.
Among these algorithms, DPO stands out as a pivotal innovation. Its fundamental premise is that since the RM in RLHF is trained on labeled data, there is no need to separate RM training from RLHF training into two stages. Instead, it proposes merging the losses from both training steps into a single loss function. This results in a loss that resembles performing supervised fine-tuning (SFT) on positive samples while applying reverse SFT on negative samples (assuming no KL penalty constraint).
DPO Loss Function
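For reference, the loss as given in the DPO paper [2] can be written as follows. The notation is assumed for this sketch: $y_w$ and $y_l$ are the chosen and rejected responses for prompt $x$, $\pi_\theta$ is the policy, $\pi_{\text{ref}}$ the frozen reference model, $\sigma$ the logistic function, and $\beta$ the strength of the implicit KL constraint.

```latex
\mathcal{L}_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}}) =
- \mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}
\left[ \log \sigma\!\left(
\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)}
- \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}
\right) \right]
```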
The DPO loss is straightforward: it requires only the policy (initialized from the SFT model) and a reference model, without the complexities of implementing PPO. If we interpret DPO through the lens of traditional reinforcement learning, it resembles an offline REINFORCE algorithm in which positive samples receive a reward of +1 and negative samples a reward of -1. The KL penalty can be incorporated into this reward or treated as an additional KL divergence loss.
The objective function of DPO
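The REINFORCE reading becomes more concrete when looking at the gradient of the DPO loss as derived in the DPO paper [2]. The symbols are assumptions of this sketch: $y_w$ and $y_l$ are the chosen and rejected responses, $\pi_\theta$ the policy, $\pi_{\text{ref}}$ the reference model, and $\hat{r}_\theta(x, y) = \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\text{ref}}(y \mid x)}$ the implicit reward.

```latex
\nabla_\theta \mathcal{L}_{\text{DPO}} =
- \beta \, \mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}
\Big[ \sigma\big( \hat{r}_\theta(x, y_l) - \hat{r}_\theta(x, y_w) \big)
\big( \nabla_\theta \log \pi_\theta(y_w \mid x) - \nabla_\theta \log \pi_\theta(y_l \mid x) \big) \Big]
```

The chosen response is pushed up and the rejected response pushed down with equal and opposite weight, scaled by how strongly the implicit reward currently misorders the pair.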
However, DPO inherits several drawbacks from conventional offline RL algorithms, such as the lack of importance sampling for gradient correction and a distribution mismatch between the training data and the model's own outputs (out-of-distribution, or OOD, samples). These factors can bias the training trajectory. Additionally, during training the log-probabilities of both the chosen and the rejected responses may decline at the same time; Llama 3.1 mitigates this by adding a negative log-likelihood (NLL) loss on the chosen samples.
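A minimal sketch of the per-pair DPO loss with an optional NLL term on the chosen response is shown below. It assumes the sequence log-probabilities have already been computed and are passed in as plain floats; the function name, the `nll_weight` and `chosen_num_tokens` parameters, and the default values are hypothetical choices for this example, not settings taken from any particular recipe.

```python
import math


def _log_sigmoid(z):
    # log(sigmoid(z)) in a numerically stable form
    return min(z, 0.0) - math.log1p(math.exp(-abs(z)))


def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp,
             beta=0.1, nll_weight=0.0, chosen_num_tokens=1):
    """DPO loss for one preference pair, plus an optional NLL regularizer.

    All *_logp arguments are summed sequence log-probabilities (assumed inputs).
    """
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    preference = -_log_sigmoid(beta * (chosen_logratio - rejected_logratio))
    # Length-normalized NLL on the chosen response; penalizes its likelihood dropping
    nll = -policy_chosen_logp / chosen_num_tokens
    return preference + nll_weight * nll
```

With `nll_weight=0.0`, lowering both policy log-probabilities by the same amount leaves the loss unchanged, which is one way to see why the preference term alone does not prevent the simultaneous decline described above.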
A straightforward improvement, known as Iterative DPO, is to train an RM as in RLHF; this RM can be either a pair RM or a standard Bradley-Terry model. The language model samples N responses for each prompt, and the RM scores are used to rank them. The highest-scoring response becomes the chosen sample and the lowest-scoring one the rejected sample, and these pairs are used for DPO training. Because the samples are generated by the language model itself, this alleviates some of the OOD issues prevalent in offline RL.
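The pair construction described above can be sketched as follows. `generate_fn` and `score_fn` are hypothetical callables standing in for the language model and the RM, and skipping prompts whose samples all tie is a design choice of this example rather than part of the method.

```python
def build_preference_pairs(prompts, generate_fn, score_fn, n_samples=4):
    """Sample N responses per prompt and pair the best with the worst.

    generate_fn(prompt) -> str and score_fn(prompt, response) -> float are
    placeholders for the policy model and the reward model.
    """
    pairs = []
    for prompt in prompts:
        responses = [generate_fn(prompt) for _ in range(n_samples)]
        scores = [score_fn(prompt, response) for response in responses]
        best = max(range(n_samples), key=scores.__getitem__)
        worst = min(range(n_samples), key=scores.__getitem__)
        if scores[best] == scores[worst]:
            continue  # all samples tied, so there is no preference signal
        pairs.append({
            "prompt": prompt,
            "chosen": responses[best],
            "rejected": responses[worst],
        })
    return pairs
```

The resulting list has the prompt/chosen/rejected shape that DPO training consumes, so it can be fed into the next training phase.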
Typically, this sample-then-train cycle runs for 3 to 10 iterations, with each iteration generating 10K to 20K samples. That is significantly fewer rounds than the 50 to 100 policy iterations required by PPO. Iterative DPO therefore occupies a middle ground between online and offline RL algorithms.
The primary advantage of Iterative DPO is that it balances engineering convenience against convergence quality. Because training proceeds in distinct phases (sample inference, then model training), there is often no need to load all models onto GPUs simultaneously, which sidesteps many of the infrastructure challenges associated with RLHF. The inference phase can be accelerated with offline batch-inference frameworks such as vLLM or TensorRT-LLM.
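The phase separation can be sketched as a simple outer loop. Everything here is a placeholder: `sample_pairs_fn` stands for the inference phase (for example, batch generation plus RM scoring) and `train_dpo_fn` for the training phase. Because the two phases never overlap, the inference engine and the trainer do not need to share GPU memory.

```python
def iterative_dpo(policy, prompts, sample_pairs_fn, train_dpo_fn, num_iterations=3):
    """Alternate between an inference-only phase and a training-only phase.

    sample_pairs_fn(policy, prompts) -> list of preference pairs (hypothetical)
    train_dpo_fn(policy, pairs) -> updated policy (hypothetical)
    """
    pairs_per_iteration = []
    for _ in range(num_iterations):
        # Phase 1: generate and score samples with the current policy (no training state loaded)
        pairs = sample_pairs_fn(policy, prompts)
        # Phase 2: run DPO training on the fresh pairs (no inference engine loaded)
        policy = train_dpo_fn(policy, pairs)
        pairs_per_iteration.append(len(pairs))
    return policy, pairs_per_iteration
```

Each iteration samples from the most recent policy, which is what moves the procedure away from purely offline DPO.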