
In RLHF, the reward function is a neural network. This means we can compute its gradients cheaply and accurately through backpropagation. We want to find a policy that maximizes reward (see https://arxiv.org/abs/2203.02155). So why do we need PPO to find that policy? Since the reward function is differentiable and cheap to evaluate, wouldn't a gradient-based optimizer such as SGD or Adam be enough?

Rexcirus
DeltaIV


Recall that RLHF essentially has three stages:

  • Step 0: Pretrain a text-generating LLM.
  • Step 1: Collect human ratings of text and, from them, learn a reward model $r_{\theta}$ that can rate any unseen text. This model may use parts of the LLM. (This is where you backpropagate through the reward model!)
  • Step 2: Fine-tune the parameters $\Phi$ of your text-generating model $LLM_{\Phi}$ against a frozen reward function $r_{\theta}$ to maximize the rewards.

Q1: In Step 2, why don't we apply gradients through the (frozen) reward function?

It seems there is some confusion about what it means to maximize the rewards. In RL, you think of the reward function $r_\theta$ as something dictated by the environment: you can't change it, and it makes no sense to change the reward itself, however differentiable it may be. What does make sense is to take actions with the agent's policy, see how the environment rewards them, and change the agent's policy to "maximize the rewards".

Analogously, the idea in RLHF is to use a (fixed) reward function $r_{\theta}(y)$ that provides useful ratings of text $y$ generated by $LLM_{\Phi}$. These ratings are used to update the parameters $\Phi$ of $LLM_{\Phi}$ so that it assigns more likelihood to higher-quality text. Concretely, Step 2 involves:

  • (a) making your agent take various independent actions according to its current policy, i.e., making $LLM_{\Phi}$ generate several text completions $y_1, y_2, \dots$ given a prompt $p$
  • (b) then looking at the rewards $r_\theta(y_i)$
  • (c) then adjusting $\Phi$ (NOT $\theta$) to increase or decrease the probability of each (series of) action(s) according to its reward.
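
Here is a minimal sketch of steps (a) to (c) using REINFORCE, the simplest policy-gradient method (not PPO itself). It assumes a hypothetical toy "LLM" that picks one of three whole completions for a single prompt, parameterized by logits `phi`, and a hypothetical `frozen_reward` function standing in for $r_\theta$. Only `phi` is ever updated; the reward function is just called.

```python
import math
import random

# Hypothetical toy "LLM": a categorical policy over a few whole completions
# for one fixed prompt, parameterized by logits phi.
COMPLETIONS = ["helpful answer", "vague answer", "rude answer"]

def frozen_reward(y: str) -> float:
    """Stand-in for the frozen reward model r_theta: called, never updated."""
    return {"helpful answer": 1.0, "vague answer": 0.2, "rude answer": -1.0}[y]

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def reinforce_step(phi, lr=0.5, num_samples=16, rng=random):
    probs = softmax(phi)
    # (a) sample completions y_1, y_2, ... from the current policy
    samples = rng.choices(range(len(phi)), weights=probs, k=num_samples)
    # (b) look at the rewards r_theta(y_i)
    rewards = [frozen_reward(COMPLETIONS[i]) for i in samples]
    baseline = sum(rewards) / len(rewards)  # mean reward, to reduce variance
    # (c) adjust phi (not the reward function) so above-average completions
    #     become more likely and below-average ones less likely
    grad = [0.0] * len(phi)
    for i, r in zip(samples, rewards):
        for k in range(len(phi)):
            # d log softmax(phi)[i] / d phi[k] = 1[k == i] - probs[k]
            grad[k] += (r - baseline) * ((1.0 if k == i else 0.0) - probs[k])
    return [p + lr * g / num_samples for p, g in zip(phi, grad)]

rng = random.Random(0)
phi = [0.0, 0.0, 0.0]
for _ in range(200):
    phi = reinforce_step(phi, rng=rng)
print({y: round(p, 3) for y, p in zip(COMPLETIONS, softmax(phi))})
```

After training, most of the probability should sit on the highest-reward completion. The mean-reward baseline is a common variance-reduction trick, not something the steps above require.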

Hopefully this makes clear why it does not make sense to differentiate through the reward model in order to change it: you want to use its signal to improve your generator.

Q2: Why do we need a neural network for the reward function in the first place?

Ideally, a human would provide feedback on all generated text, but that is infeasible. So instead we train a reward model that learns from human feedback and generalizes its ratings to unseen text. At this stage of RLHF we do differentiate through $r_{\theta}$, but afterwards $\theta$ is frozen.
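
A minimal sketch of that training stage, assuming the human feedback arrives as pairwise comparisons (which of two texts was preferred), as in the InstructGPT paper linked in the question, and using the pairwise loss $-\log \sigma(r_\theta(y_w) - r_\theta(y_l))$, where $y_w$ is the preferred and $y_l$ the rejected text. The "reward model" here is a hypothetical linear function of made-up feature vectors, and its gradient is written out by hand in place of backpropagation; the point is only that at this stage gradients update $\theta$.

```python
import math

def reward(theta, features):
    """Hypothetical linear reward model: r_theta(y) = theta . features(y)."""
    return sum(t * f for t, f in zip(theta, features))

def pairwise_loss_and_grad(theta, preferred, rejected):
    """-log sigmoid(r(preferred) - r(rejected)) and its gradient w.r.t. theta."""
    margin = reward(theta, preferred) - reward(theta, rejected)
    # Numerically stable softplus(-margin), which equals -log sigmoid(margin)
    loss = math.log1p(math.exp(-abs(margin))) + max(-margin, 0.0)
    # d loss / d margin = -(1 - sigmoid(margin)) = -1 / (1 + exp(margin))
    dloss_dmargin = -1.0 / (1.0 + math.exp(margin))
    grad = [dloss_dmargin * (p - r) for p, r in zip(preferred, rejected)]
    return loss, grad

# Made-up comparisons: (features of preferred text, features of rejected text)
comparisons = [
    ([1.0, 0.0, 0.5], [0.0, 1.0, 0.5]),
    ([0.8, 0.1, 0.0], [0.2, 0.9, 0.0]),
    ([0.9, 0.2, 1.0], [0.1, 0.7, 1.0]),
]

theta = [0.0, 0.0, 0.0]
lr = 0.5
for epoch in range(100):
    for preferred, rejected in comparisons:
        _, grad = pairwise_loss_and_grad(theta, preferred, rejected)
        theta = [t - lr * g for t, g in zip(theta, grad)]  # update theta here

frozen_theta = tuple(theta)  # from now on theta is frozen
```

Once the loop finishes, `theta` is frozen (here copied into a tuple), and Step 2 only ever calls `reward` on new text.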

You may still ask:

Q3: Given that we have $r_{\theta}(x)$ where $x$ is the output of $LLM_{\Phi}$, why not pass gradients through $r_{\theta}(x)$ to update $\Phi$?

This is difficult to do because $x$ is discrete: it is sampled token by token from $LLM_{\Phi}$, and that sampling step has no useful gradient with respect to $\Phi$. Instead, we change the distribution over $x$ so as to maximize $r_{\theta}(x)$. This does not require "backpropagating through" $r_{\theta}(x)$; it only requires updating the likelihood of $x$ under $LLM_{\Phi}$ in proportion to $r_{\theta}(x)$.
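
One standard way to make this precise is the score-function (REINFORCE) identity below. The name and the notation $p_\Phi(x)$ for the probability that $LLM_{\Phi}$ assigns to text $x$ are introduced here, not taken from the answer; the sum runs over all possible texts, and the second line uses $\nabla_\Phi p_\Phi(x) = p_\Phi(x)\,\nabla_\Phi \log p_\Phi(x)$:

$$
\begin{aligned}
\nabla_\Phi \, \mathbb{E}_{x \sim p_\Phi}\big[r_\theta(x)\big]
&= \nabla_\Phi \sum_x p_\Phi(x)\, r_\theta(x) \\
&= \sum_x p_\Phi(x)\, r_\theta(x)\, \nabla_\Phi \log p_\Phi(x) \\
&= \mathbb{E}_{x \sim p_\Phi}\big[r_\theta(x)\, \nabla_\Phi \log p_\Phi(x)\big]
\end{aligned}
$$

Only $\log p_\Phi(x)$ is differentiated; $r_\theta(x)$ enters as a plain number, so no gradient flows through the reward model, and the expectation can be estimated by averaging over texts sampled from $LLM_{\Phi}$.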


Note: The rest is my older answer, which may not have addressed your original question.

You may still ask, within Step (c) above:

Q4. Why PPO? Why not simple SGD to update the likelihood?

(though I realize this may not be your question)

PPO is in essence doing simple SGD (more precisely, gradient ascent on the expected reward), but with some fancy penalties for the sake of stability. The "simple SGD" part of PPO takes care of updating your "policy" so as to maximize your "reward". This update on $\Phi$ (NOT on $\theta$) increases the likelihood under $LLM_{\Phi}$ of text that $r_{\theta}$ rates more highly.

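Meanwhile, the "fancy penalty" makes sure that in doing your simple update, your policy stays sufficiently close to the previous policy, i.e., the one that generated the samples; PPO does this by clipping the probability ratio between the new and old policies. Hence the name, Proximal Policy Optimization. RLHF setups such as InstructGPT additionally penalize the KL divergence from the original (supervised fine-tuned) model, which keeps the policy close to your "original policy" overall. Without these constraints, the agent may become unstable and start producing garbage, as explained in the next question. See https://huggingface.co/blog/deep-rl-ppo for a nice technical discussion of the algorithm.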
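
A minimal sketch of the two kinds of penalty as plain per-sample Python functions (hypothetical helper names; no RL library assumed). `ppo_clipped_objective` is PPO's clipped surrogate: it compares a sampled text's log-probability under the policy being updated with its log-probability under the policy that generated it, and caps how much credit a single update can take. `advantage` stands for the sample's reward relative to a baseline; how it is estimated is outside this sketch. `kl_shaped_reward` subtracts a penalty proportional to the log-ratio between the current policy and a frozen reference model, in the spirit of InstructGPT's KL penalty. `epsilon=0.2` and `beta=0.02` are placeholder values, not recommendations.

```python
import math

def ppo_clipped_objective(logp_new, logp_old, advantage, epsilon=0.2):
    """Per-sample PPO clipped surrogate objective (to be maximized).

    logp_new: log-prob of a sampled text under the policy being updated
    logp_old: log-prob of the same text under the policy that sampled it
    advantage: how much better than expected the sample's reward was
    """
    ratio = math.exp(logp_new - logp_old)
    clipped_ratio = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon)
    # The min makes the objective pessimistic: once the ratio has moved past
    # 1 + epsilon (positive advantage) or below 1 - epsilon (negative
    # advantage), pushing it further earns no extra credit in this update.
    return min(ratio * advantage, clipped_ratio * advantage)

def kl_shaped_reward(reward, logp_policy, logp_reference, beta=0.02):
    """Reward-model score minus beta times a one-sample estimate of
    KL(policy || reference), i.e. the log-ratio for the sampled text."""
    return reward - beta * (logp_policy - logp_reference)
```

For example, with a positive advantage and a new policy that has already doubled the sample's probability (ratio 2), the clipped objective credits only `1.2 * advantage`, so there is no incentive to push that probability further in this update.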

Q5. What exactly are these instabilities?

Recall that your reward model was trained on a specific kind of data (e.g., meaningful text from a pre-trained LLM). So it may have no clue how to score text from other distributions; as an extreme case, imagine complete gibberish. In fact, the global maximum of the learned reward model may even lurk in some gibberish or otherwise strange region of language that the reward model was never trained on. If your LLM happens to produce gibberish that the reward model thinks is great, the LLM may quickly devolve into fooling the reward model by producing only that gibberish.
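
A toy illustration of this failure mode, with entirely made-up numbers: one prompt, two possible outputs, and a hypothetical learned reward that happens to score the out-of-distribution "gibberish" output highest. `unconstrained_ascent` runs exact policy-gradient ascent on the expected learned reward, starting from the pretrained (reference) distribution; `kl_regularized_optimum` computes the closed-form maximizer of expected reward minus $\beta \cdot \mathrm{KL}(\pi \,\Vert\, \pi_{\mathrm{ref}})$, which is $\pi(x) \propto \pi_{\mathrm{ref}}(x) \exp(r(x)/\beta)$.

```python
import math

# Hypothetical toy: one prompt, two possible outputs. The reward model was
# trained only on sensible text and happens to over-score gibberish.
OUTPUTS = ["sensible text", "gibberish"]
LEARNED_REWARD = [1.0, 3.0]     # 3.0 is the reward model's blind spot
REFERENCE_PROBS = [0.99, 0.01]  # the pretrained LLM rarely emits gibberish

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def unconstrained_ascent(steps=200, lr=1.0):
    """Exact gradient ascent on E_pi[r], starting from the reference policy."""
    logits = [math.log(p) for p in REFERENCE_PROBS]
    for _ in range(steps):
        probs = softmax(logits)
        mean_reward = sum(p * r for p, r in zip(probs, LEARNED_REWARD))
        # For a softmax policy, dE[r]/dlogit_k = p_k * (r_k - E[r])
        logits = [z + lr * p * (r - mean_reward)
                  for z, p, r in zip(logits, probs, LEARNED_REWARD)]
    return softmax(logits)

def kl_regularized_optimum(beta=2.0):
    """Closed-form maximizer of E_pi[r] - beta * KL(pi || pi_ref)."""
    weights = [p * math.exp(r / beta)
               for p, r in zip(REFERENCE_PROBS, LEARNED_REWARD)]
    total = sum(weights)
    return [w / total for w in weights]

for name, probs in [("unconstrained", unconstrained_ascent()),
                    ("KL-regularized", kl_regularized_optimum())]:
    print(name, {o: round(p, 3) for o, p in zip(OUTPUTS, probs)})
```

With these numbers, unconstrained ascent ends up putting almost all probability on "gibberish", while the KL-regularized solution still nudges gibberish up a little but stays anchored near the reference distribution. Real LLMs cannot enumerate all possible texts, so RLHF optimizes such a penalized objective with sampled updates rather than this closed form.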

Q6. Can't you simulate this with a small learning rate with simple SGD?

I do not have a good answer to this, as I'm not an RL expert. My guess would be either that (i) this may require a learning rate so small that learning becomes too slow or impossible (perhaps precluding any kind of feature learning), or (ii) it may still be possible to get caught in bad local maxima, since this corresponds to purely maximizing the reward objective alone.

Now, going by the title of your question, perhaps you may still ask (or maybe you don't; I did wonder about this when I first came across RLHF for LLMs):

Q7. Isn't this just SGD with some supervision and a fancy penalty? Why is this called RL?

Remember that in supervised learning you're given an immutable set of labeled input-output pairs, while in RL you interact with your environment by producing your own inputs and collecting output values in response. Analogously, in standard LLM training you are given an immutable set of human-generated text on which you maximize likelihood. In RLHF, however, the LLM has the freedom to generate whatever text it pleases, as in step (a), and gather supervision on that text, as in step (b). This freedom is exactly why it can devolve into a gibberish state, which in turn motivates the need to tether your policy to the initial policy.
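
A toy sketch of that contrast, using a softmax over a few outputs as a stand-in for an LLM (all names and data are hypothetical). Both functions return a gradient on the logits; the only difference is where the outputs come from: a fixed dataset in the supervised case, the model's own samples scored by a reward function in the RL case.

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def grad_log_prob(logits, y):
    """Gradient of log softmax(logits)[y] with respect to the logits."""
    probs = softmax(logits)
    return [(1.0 if k == y else 0.0) - p for k, p in enumerate(probs)]

def average(vectors):
    return [sum(col) / len(vectors) for col in zip(*vectors)]

def supervised_grad(logits, dataset):
    """Maximum likelihood: the outputs come from a fixed, human-written dataset."""
    return average([grad_log_prob(logits, y) for y in dataset])

def rl_grad(logits, reward_fn, num_samples, rng):
    """RL: the model samples its own outputs (step a) and gets them scored (step b)."""
    probs = softmax(logits)
    samples = rng.choices(range(len(logits)), weights=probs, k=num_samples)
    return average([[reward_fn(y) * g for g in grad_log_prob(logits, y)]
                    for y in samples])

def hypothetical_reward(y):
    return 1.0 if y == 2 else 0.0  # made-up reward that favours output 2

logits = [0.0, 0.0, 0.0]
human_data = [0, 0, 1]  # output 2 never appears in the dataset
print(supervised_grad(logits, human_data))
print(rl_grad(logits, hypothetical_reward, num_samples=1000, rng=random.Random(0)))
```

Here the supervised gradient lowers the logit of output 2 because it never appears in the data, while the RL gradient raises it because the model sampled it and the reward function liked it.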

Enigman