
This old question has no definitive answer yet, which is why I am asking it here again. I also asked the same question here.

If I'm doing policy gradient in Keras, using a loss of the form:

rewards*cross_entropy(action_pdf, selected_action_one_hot)

How do I manage negative rewards?

I've had success with this form when the reward is always positive, but it does not train with negative rewards. The failure mode is that it drives itself to very confident predictions all the time, which results in very large negative losses due to the deviation induced for exploration. I can get it to train by clipping rewards at zero, but this throws away a lot of valuable information (only carrots, no sticks).

nbro
Mastiff

1 Answer

You don't need to manage negative rewards separately: if you implemented the algorithm correctly, it will work regardless of whether the rewards are negative. You seem to be using rewards in the loss, but you should be using the return, which is the (possibly discounted) sum of the rewards from a given state-action pair until the end of the trajectory.
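A minimal plain-Python sketch of turning one trajectory's per-step rewards into returns. The function name and the discount factor `gamma` are illustrative assumptions; `gamma=1.0` gives the plain undiscounted sum described above.

```python
def returns_from_rewards(rewards, gamma=1.0):
    """Compute G_t = r_t + gamma * G_{t+1} for every step of one trajectory (hypothetical helper)."""
    returns = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each step accumulates all the rewards that follow it.
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

print(returns_from_rewards([1.0, -2.0, 3.0]))  # [2.0, 1.0, 3.0]
```

Note that each step's return mixes positive and negative rewards, so individual returns can be of either sign; nothing about the update changes because of that.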

You also seem to be missing the $-$ sign in the loss. The objective function for the vanilla policy gradient algorithm (REINFORCE), which we want to maximize, is
\begin{equation} J = \sum_a \pi(a|s) q_{\pi}(s, a) \end{equation}
It can be shown that a sample estimate of the gradient for this policy gradient method is
\begin{equation} \nabla J \approx G_t \nabla \log \pi(A_t|S_t) \end{equation}
so in TensorFlow you should define your loss as
\begin{equation} L = - G_t \log \pi(A_t|S_t) \end{equation}
where $G_t$ is treated as a constant. We need the $-$ because TensorFlow optimizers minimize, but we want to maximize the objective, and minimizing this loss is the same as maximizing $J$. In conclusion, code similar to what you wrote should be
-return * cross_entropy(action_pdf, selected_action_one_hot)
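To illustrate why negative returns need no special handling, here is a minimal plain-Python sketch of the loss $L = -G_t \log \pi(A_t|S_t)$ for a softmax policy over logits, with its analytic gradient and one gradient-descent step. The helper names, the three-action example and the learning rate are hypothetical; in TensorFlow the gradient would be computed automatically.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def reinforce_loss_and_grad(logits, action, G):
    """L = -G * log pi(action) and dL/dlogits; G is a constant (no gradient through the return)."""
    probs = softmax(logits)
    loss = -G * math.log(probs[action])
    # d(-log softmax(z)[a]) / dz_i = probs[i] - 1{i == a}, scaled by G.
    grad = [G * (p - (1.0 if i == action else 0.0)) for i, p in enumerate(probs)]
    return loss, grad

def sgd_step(logits, action, G, lr=0.1):
    _, grad = reinforce_loss_and_grad(logits, action, G)
    return [z - lr * g for z, g in zip(logits, grad)]

logits = [0.0, 0.0, 0.0]
p_before = softmax(logits)[0]
p_pos = softmax(sgd_step(logits, action=0, G=+1.0))[0]  # positive return: a carrot
p_neg = softmax(sgd_step(logits, action=0, G=-1.0))[0]  # negative return: a stick
print(p_pos > p_before > p_neg)  # True
```

The same update rule raises the probability of the selected action when the return is positive and lowers it when the return is negative; the sign of $G_t$ alone sets the direction.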

EDIT

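As pointed out in the comment, we don't actually need the $-$, because it is already included in the cross_entropy function: with a one-hot target, the cross-entropy is $-\log \pi(A_t|S_t)$, so return * cross_entropy(action_pdf, selected_action_one_hot) already equals $-G_t \log \pi(A_t|S_t)$.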
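A quick plain-Python check of that claim. `cross_entropy` here is a hypothetical stand-in for the function in the question, taking (probabilities, one-hot target) in the question's argument order; note that Keras's `categorical_crossentropy` takes `(y_true, y_pred)` instead. The probabilities and the return are illustrative values.

```python
import math

def cross_entropy(probs, one_hot):
    """Categorical cross-entropy -sum_i y_i * log(p_i) for one sample (hypothetical helper)."""
    return -sum(y * math.log(p) for p, y in zip(probs, one_hot) if y > 0)

probs = [0.7, 0.2, 0.1]    # policy output pi(.|s), illustrative values
one_hot = [0.0, 1.0, 0.0]  # selected action A_t = 1
G = -3.0                   # a negative return

loss_ce = G * cross_entropy(probs, one_hot)  # no extra minus sign
loss_explicit = -G * math.log(probs[1])      # -G_t * log pi(A_t|S_t)
print(math.isclose(loss_ce, loss_explicit))  # True
```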

Brale