5

What is grad_fn for a non-differentiable function like slicing (grad_fn=<SliceBackward0>), view (grad_fn=<ViewBackward0>), etc.? Is grad_fn simply the function's inverse operation?

Where in the source code can I see the implementation of SliceBackward0, ViewBackward0, etc.? I assume it's in their backward() static methods somewhere.
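
For reference, a minimal snippet that reproduces the grad_fn values in question (my illustration, not from the original post):

```python
import torch

x = torch.randn(4, requires_grad=True)
print(x[1:3].grad_fn)        # <SliceBackward0 object at 0x...>
print(x.view(2, 2).grad_fn)  # <ViewBackward0 object at 0x...>
```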

related: "Does it make sense for a computational graph to have entirely non-differentiable functions?"

Geremia

1 Answer

4

A non-differentiable node is not inverted during backward(); it is simply treated as a constant. The backpropagated gradient stops at such a node, so any parameter whose only connection to the loss runs through it cannot receive gradient updates. The grad_fn objects you listed are therefore not inverses of the forward operations (slicing and view are, in fact, differentiable); each is an instance of a subclass of PyTorch's internal Function class that holds the logic for computing the gradients in the backward pass.
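
A minimal sketch of both behaviors, assuming a recent PyTorch build (the tensors are just for illustration):

```python
import torch

# Slicing is differentiable: SliceBackward0 scatters the output's
# gradient back into a zero tensor shaped like the input.
x = torch.randn(4, requires_grad=True)
y = x[1:3]                 # grad_fn=<SliceBackward0>
y.sum().backward()
print(x.grad)              # tensor([0., 1., 1., 0.])

# A genuinely non-differentiable op such as argmax produces a tensor
# with no grad_fn, so the graph (and the gradient flow) stops there.
z = torch.randn(4, requires_grad=True)
print(z.argmax().grad_fn)  # None
```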

Functions such as your <SliceBackward0> are not hand-written backward() static methods. Their derivative formulas are declared in tools/autograd/derivatives.yaml and code-generated into torch/csrc/autograd/generated/Functions.cpp (or a related generated file, which exists only in a local source build), since PyTorch's autograd engine is implemented primarily in C++ with code generation for many functions.
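
You can poke at these generated classes from Python; a small sketch (exact output varies by PyTorch version):

```python
import torch

x = torch.randn(5, requires_grad=True)
y = x.view(5, 1)[2:4]

# grad_fn is an instance of a code-generated C++ node class exposed
# to Python, not an inverse of the forward op. next_functions links
# each backward node to the nodes of the ops that produced its inputs.
print(type(y.grad_fn).__name__)  # SliceBackward0
print(y.grad_fn.next_functions)  # ((<ViewBackward0 object at 0x...>, 0),)
```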

cinch