import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

policy.eval(); critic.eval()  # BN eval mode for rollout
with torch.no_grad():
    mean, std = policy(actor_critic_input)
    dist = TransformedDistribution(Normal(mean, std), [TanhTransform()])
    action_tensor = dist.sample()  # (1, action_dim)
    log_prob = dist.log_prob(action_tensor).sum(dim=-1)

I get NaN when I print log_prob, which is the log probability of the action sampled from the distribution dist.

1 Answer


The log_prob calculation for a TransformedDistribution has two parts:

The log probability of the untransformed value x under the base distribution (the Normal distribution).

The log absolute determinant of the Jacobian of the transform, evaluated at the untransformed value. For TanhTransform, |dy/dx| = 1 - tanh(x)^2, so this term is often written in terms of the transformed value (y, which is your action_tensor) as log(1 - y^2). A sketch of this decomposition follows.
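A minimal sketch of that decomposition, assuming a toy standard-Normal base and an arbitrary placeholder value for x:

import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

base = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
dist = TransformedDistribution(base, [TanhTransform()])

x = torch.tensor([0.5])   # untransformed value (placeholder)
y = torch.tanh(x)         # transformed value, the "action"

# Part 1 minus Part 2: log p(y) = log p_base(x) - log|dy/dx|, with |dy/dx| = 1 - y^2
manual = base.log_prob(x) - torch.log(1.0 - y**2)
print(manual, dist.log_prob(y))  # the two agree up to floating-point error

(Internally PyTorch uses a numerically stabler softplus-based formula for the Jacobian term, but it is mathematically equivalent to log(1 - y^2).)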

The problem was in the second part: log(1 - y^2).

The TanhTransform maps values from (-inf, inf) to (-1, 1).

When you sample from dist = TransformedDistribution(Normal(mean, std), [TanhTransform()]), PyTorch first draws an untransformed value x from the Normal distribution and then computes y = tanh(x) to get the final action_tensor.

Due to limited floating-point precision (or possibly explicit clamping inside the sampling code), the values in action_tensor (y) can end up exactly 1.0 or -1.0, or very slightly outside the open interval (-1.0, 1.0).
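You can see this saturation directly: tanh of a moderately large input already rounds to exactly 1.0 in floating point:

import torch

x = torch.tensor(20.0)   # float32
y = torch.tanh(x)        # mathematically 1 - ~8.5e-18, but...
print(y.item() == 1.0)   # True: the nearest representable float is exactly 1.0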

If y is exactly 1.0 or -1.0, then y^2 is exactly 1.0, so 1 - y^2 is exactly 0. Taking the log of zero (log(0)) results in -Inf.

If y is slightly outside (-1.0, 1.0) (e.g., 1.0000001), then y^2 is slightly greater than 1.0. 1 - y^2 is a negative number. Taking the log of a negative number (log(-epsilon)) results in NaN.

Even when y is exactly 1.0 or -1.0 (rather than slightly outside the interval), you still end up with NaN. To compute log_prob, PyTorch first inverts the transform: x = atanh(y), which is +Inf or -Inf at the boundary. The base Normal's log_prob at +/-Inf is -Inf, while subtracting the Jacobian term log(1 - y^2) = -Inf contributes +Inf. The sum -Inf + Inf is NaN.
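A quick way to reproduce both failure modes, assuming a toy standard-Normal base and a recent PyTorch (whose TanhTransform inverts via atanh without clamping):

import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

dist = TransformedDistribution(Normal(torch.tensor(0.0), torch.tensor(1.0)),
                               [TanhTransform()])

print(torch.log(torch.tensor(0.0)))      # -inf: log(1 - y^2) when y = 1.0 exactly
print(torch.log(torch.tensor(-1e-7)))    # nan:  log(1 - y^2) when y is slightly > 1
print(dist.log_prob(torch.tensor(1.0)))  # nan:  -inf (base) plus +inf (Jacobian correction)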

So what worked for me is clamping action_tensor to (-1.0 + epsilon, 1.0 - epsilon) with epsilon = 0.01, by adding one line right after sampling the action:

epsilon = 0.01
action_tensor = torch.clamp(action_tensor, -1.0 + epsilon, 1.0 - epsilon)  # clamp the action to avoid NaN
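As a quick sanity check, again assuming the toy standard-Normal base from above and epsilon = 0.01, the clamp turns the NaN boundary case into a finite value:

import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

epsilon = 0.01
dist = TransformedDistribution(Normal(torch.tensor(0.0), torch.tensor(1.0)),
                               [TanhTransform()])

y = torch.tensor(1.0)                              # boundary action: log_prob would be NaN
y = torch.clamp(y, -1.0 + epsilon, 1.0 - epsilon)  # the one-line fix
print(dist.log_prob(y))                            # finite, no NaN

Note that with the clamp in place, log_prob is evaluated at the clamped action rather than the raw boundary sample; that small bias is the price of keeping the value finite.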