
Does it make sense for a computational graph to have entirely non-differentiable functions?

For example, PyTorch can handle non-differentiable functions by marking their outputs as non-differentiable, but I'm asking not about functions whose derivatives are undefined only at some points, but about graphs built entirely from non-differentiable functions.

Related question: "How is back propagation applied in case the activation function is not differentiable?"

Geremia

1 Answer


Entirely non-differentiable functions make sense: one might intentionally include them in parts of a model where gradient updates are either not required or not meaningful, such as the $\text{argmax}$ operator or a simple routing or logging function applied to incoming data. When an entirely non-differentiable function appears as a node in a computational graph, its output is marked accordingly, as in the PyTorch documentation you cite, and gradient-based optimization simply bypasses any branch that passes through that node.

FunctionCtx.mark_non_differentiable(*args)
Mark outputs as non-differentiable.
This should be called at most once, in either the setup_context() or forward() methods, and all arguments should be tensor outputs.

If the function is not a deterministic mapping (i.e. it is not a mathematical function), it will be marked as non-differentiable. This will make it error out in the backward if used on tensors that require grad outside of a no_grad environment.
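As a concrete illustration of the quoted API, here is a minimal sketch (the class name ArgMax and the shapes are my own choices, assuming a recent PyTorch version) of a wholly non-differentiable node wrapping torch.argmax in a custom torch.autograd.Function that marks its output via ctx.mark_non_differentiable:

import torch
from torch.autograd import Function

class ArgMax(Function):
    # A wholly non-differentiable node: the output is an index, not a smooth quantity.

    @staticmethod
    def forward(ctx, x):
        idx = x.argmax(dim=-1)
        ctx.mark_non_differentiable(idx)  # per the FunctionCtx docs quoted above
        return idx

    @staticmethod
    def backward(ctx, grad_idx):
        # No gradient flows to the input through this node.
        return None

x = torch.randn(4, 10, requires_grad=True)
idx = ArgMax.apply(x)
print(idx.requires_grad)  # False: backprop bypasses this branch

Because the output is marked non-differentiable, autograd treats it as not requiring gradients, and the backward of this node never contributes to the updates of anything upstream of it.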

Therefore in practice the affected outputs are treated as constants with respect to the upstream parameters: those parameters receive no gradient through the non-differentiable node and are not updated based on errors flowing back through it. They can still be updated if they reach the loss through other, differentiable paths, and parameters downstream of the node are unaffected in any case.
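To make that behaviour concrete, here is a small sketch (the names w, x, h are my own) showing a parameter that feeds both a differentiable path and a non-differentiable $\text{argmax}$ branch; its gradient comes only from the differentiable path:

import torch

w = torch.randn(10, requires_grad=True)   # an upstream parameter
x = torch.randn(10)

h = w * x                  # differentiable path to the loss
idx = h.argmax()           # non-differentiable branch: a constant as far as autograd is concerned
loss = h.sum() + idx.float()

loss.backward()
print(w.grad)              # equals x: only the differentiable path h.sum() contributed

If h.sum() were removed so that the loss depended only on the non-differentiable branch, calling backward would leave w without any usable gradient, which is exactly the "bypassed branch" behaviour described above.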

cinch