
I work with a few different automatic differentiation (AD) frameworks, including PyTorch, JAX, and Flux in Julia. Periodically I run some code and get errors about mutations or operations occurring "in-place," and these errors generally cause the program to fail. My question is: what do these errors mean in terms of the AD algorithm itself, rather than any specific programming framework?

So I want to understand, at a low level, what the problem is and how to get around it. I believe most AD frameworks focus on reverse-mode AD because it is more efficient when there are a large number of parameters.

I am not sure how current AD algorithms are implemented. Older AD systems were based on Wengert lists (tapes), while more recent implementations use source-to-source transformations. So, to elaborate on my original question: does in-place mutation create problems for both source-to-source AD packages and Wengert-list (tape-based) AD packages?
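To make the tape-based picture concrete, here is a minimal sketch of a Wengert list for scalar reverse-mode AD in plain Python. The class `Tape` and its methods are hypothetical, not any framework's API. The thing to notice is that `backward` reads the forward values stored in `self.values`; a source-to-source system generally needs the same forward values in its generated backward code, so it faces the same question of whether they might be overwritten.

```python
import math

class Tape:
    """Hypothetical minimal Wengert list for scalar reverse-mode AD."""

    def __init__(self):
        self.values = []   # forward value of every intermediate, by index
        self.entries = []  # (output index, op name, input indices) in execution order

    def var(self, value):
        self.values.append(value)
        return len(self.values) - 1

    def mul(self, a, b):
        out = self.var(self.values[a] * self.values[b])
        self.entries.append((out, "mul", (a, b)))
        return out

    def sin(self, a):
        out = self.var(math.sin(self.values[a]))
        self.entries.append((out, "sin", (a,)))
        return out

    def backward(self, out):
        grads = [0.0] * len(self.values)
        grads[out] = 1.0
        # Walk the list in reverse, applying the chain rule. Every local
        # derivative reads a stored forward value from self.values.
        for o, op, ins in reversed(self.entries):
            g = grads[o]
            if op == "mul":
                a, b = ins
                grads[a] += g * self.values[b]
                grads[b] += g * self.values[a]
            elif op == "sin":
                (a,) = ins
                grads[a] += g * math.cos(self.values[a])
        return grads

t = Tape()
x = t.var(2.0)
y = t.sin(t.mul(x, x))  # y = sin(x^2)
grads = t.backward(y)
print(grads[x])
```

Here `grads[x]` should equal 4·cos(4), the derivative of sin(x²) at x = 2.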

Here is my guess at the issue, though I cannot confirm it. The mutation issue seems to have something to do with computing either the forward pass or the backward pass of the AD algorithm. To compute the gradient, the value each variable had at that point in the forward pass needs to be known. If we change the value of a variable in the chain in-place, then we compute the wrong gradient at that variable, at that time step. That makes sense to me, but I am not sure my understanding is correct, so any clarification is appreciated.
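The guess can be checked with a tiny experiment. In this hypothetical sketch, `square_forward` computes y = x² and returns a backward function that either aliases the input buffer or saves a copy of it. Mutating the buffer in place after the forward pass changes the gradient only in the aliasing case.

```python
def square_forward(buf, save_copy):
    """Hypothetical op: y = x**2 for the scalar in buf[0]; returns (y, backward)."""
    saved = list(buf) if save_copy else buf  # copy of the value vs. alias to the buffer
    y = buf[0] ** 2

    def backward(grad_out):
        # dy/dx = 2x, where x should be the value seen during the forward pass.
        return grad_out * 2 * saved[0]

    return y, backward

x = [3.0]
y, back_alias = square_forward(x, save_copy=False)
_, back_copy = square_forward(x, save_copy=True)

x[0] = 10.0  # in-place mutation after the forward pass

print(back_alias(1.0))  # 20.0 (wrong: uses the overwritten value)
print(back_copy(1.0))   # 6.0  (correct: 2 * 3.0)
```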

After that, how does one write code that avoids this kind of problem? I suppose the answer depends on the AD framework being used, but does every operation need to create a new variable instead of reusing an existing one? And how smart are these frameworks at handling these kinds of mutations?
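One common pattern, sketched here with hypothetical helper names, is to perform updates out-of-place: build a new container instead of overwriting one whose contents a backward rule still references. JAX builds this style into its API, since its arrays are immutable and an indexed update is written as `x.at[i].set(v)`, which returns a new array.

```python
def set_inplace(xs, i, v):
    xs[i] = v       # overwrites memory that a backward rule may still reference
    return xs

def set_functional(xs, i, v):
    ys = list(xs)   # new buffer; the original values stay intact
    ys[i] = v
    return ys

def make_backward(xs):
    # Hypothetical backward rule for y = xs[0] * xs[1]; it needs both forward values.
    return lambda g: (g * xs[1], g * xs[0])

a = [2.0, 5.0]
back = make_backward(a)
b = set_functional(a, 1, 100.0)  # a is untouched
print(back(1.0))                 # (5.0, 2.0)   correct

c = [2.0, 5.0]
back_c = make_backward(c)
set_inplace(c, 1, 100.0)         # c is overwritten
print(back_c(1.0))               # (100.0, 2.0) wrong
```

The out-of-place version costs an extra allocation, which is the usual trade-off: memory for a backward pass that can trust its saved values.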

krishnab

1 Answer


Here is my sense of the answer. In reverse-mode automatic differentiation (or backpropagation), the automatic differentiation library has to keep track of the values computed at each stage of the forward pass, because the backward pass needs them to evaluate the local derivatives. These values are stored in memory.
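A small worked example shows why the stored values are needed. Take $y = \sin(x^2)$, recorded as $v_1 = x$, $v_2 = v_1^2$, $v_3 = \sin(v_2)$, and let $\bar v_i = \partial y / \partial v_i$ denote the adjoints computed in the backward pass:

$$
\begin{aligned}
\bar v_3 &= 1,\\
\bar v_2 &= \bar v_3 \cos(v_2),\\
\bar v_1 &= \bar v_2 \cdot 2 v_1 .
\end{aligned}
$$

Each adjoint after the first uses a forward value ($v_2$ in the second line, $v_1$ in the third), so if $v_1$ were overwritten in place after the forward pass, $\bar v_1$ would be computed from the wrong number.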

When values are changed in-place, the values stored during the forward pass (as mentioned above) are overwritten. This corrupts the gradient computation, because the backward pass then works with altered data instead of the original values from the forward pass. Hence automatic differentiation libraries take steps to keep these memory locations from being overwritten during the execution of the code: for example, by recording each loop iteration as a separate step rather than reusing storage, by saving copies of the values the backward pass needs, or by detecting (or forbidding) in-place updates to those values and raising an error.
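One way a library can keep the backward pass from silently using overwritten data is to detect the mutation with a version counter: every in-place write bumps a counter, and a backward rule refuses to run if a value it depends on has changed since the forward pass. This is similar in spirit to how PyTorch produces its "modified by an inplace operation" error, but the `Buffer` class and `square` function below are hypothetical.

```python
class Buffer:
    """Hypothetical mutable value with a version counter."""

    def __init__(self, value):
        self.value = value
        self.version = 0

    def add_(self, other):
        # In-place update: mutate the storage and bump the version.
        self.value += other
        self.version += 1
        return self

def square(x):
    saved_version = x.version  # remember which version the backward rule depends on
    y = Buffer(x.value ** 2)

    def backward(grad_out):
        if x.version != saved_version:
            raise RuntimeError(
                "a value needed for gradient computation "
                "was modified by an in-place operation")
        return grad_out * 2 * x.value

    return y, backward

x = Buffer(3.0)
y, backward = square(x)
print(backward(1.0))  # 6.0

x.add_(1.0)           # in-place mutation after the forward pass
try:
    backward(1.0)
except RuntimeError as e:
    print("error:", e)
```

Raising an error is a deliberate choice: it turns a silently wrong gradient into a loud failure, which is the kind of error described in the question.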

If anyone wants to add some additional detail to this response, please feel free.

krishnab