[Relay][RFC] Automatic Differentiation
This RFC aims to pave the road toward adding automatic differentiation (AD) to Relay, with both a first-order and a higher-order algorithm. An AD algorithm computes the gradient of a program, which is needed for training (backpropagation). Additionally, higher-order gradients are sometimes needed for other optimization algorithms, such as Newton's method.
Because Relay supports closures and control flow, and there are further plans to add features like algebraic data types (#2175), our implementation of AD must support backpropagation over these constructs. For the implementation, we plan to closely follow "Demystifying Differentiable Programming." At a high level, our approach generates a graph at runtime and runs reverse-mode automatic differentiation on it as if it were a static graph. This algorithm readily supports closures, control flow, and ADTs, and it can be extended to account for further language features.
Differentiating programs that include operators requires us to register implementations of gradients for the different operators, which we will include as attributes. More specifically, for an operator f : <x0, x1, x2...> -> y, the gradient of the operator has the type <x0, x1, x2...> -> <y, y -> <x0, x1, x2...>>: it returns the original output together with a function that maps an output gradient back to gradients for the inputs. Its signature in C++ is runtime::TypedPackedFunc<Expr()>. The exact form of this attribute is open for discussion; we can scale other forms to the higher-order case as well.
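To make the proposed shape concrete, here is a minimal Python sketch (not Relay code; the names are illustrative) of a gradient-augmented operator following the <x0, x1, ...> -> <y, y -> <x0, x1, ...>> pattern: the operator returns its primal output plus a "pullback" mapping an output gradient to input gradients.

```python
def mul_with_grad(x0, x1):
    """Gradient-augmented multiply: returns (y, pullback)."""
    y = x0 * x1

    def pullback(out_grad):
        # d(x0*x1)/dx0 = x1, d(x0*x1)/dx1 = x0
        return (out_grad * x1, out_grad * x0)

    return y, pullback

y, pb = mul_with_grad(3.0, 4.0)
# y == 12.0; pb(1.0) == (4.0, 3.0)
```

This shape composes cleanly under reverse mode: the pullbacks of intermediate operators can be chained from the output back to the inputs.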
First-order AD can be easily optimized and does not require any further language features beyond gradients for operators. Higher-order AD in the manner of "Demystifying Differentiable Programming" would require us to add OCaml-style references to Relay, which we plan to submit soon as a PR.
We would appreciate the community's feedback on this outline for implementing automatic differentiation in Relay, and would be glad to respond to any comments about further implementation details and the steps needed to incorporate AD into Relay. Thanks @slyubomirsky for helping me write this RFC.
Issue Analytics
- Created 5 years ago
- Reactions: 4
- Comments: 19 (18 by maintainers)
Top GitHub Comments
@masahi it is the fastest path.
It uses delimited continuations, which can make it look confusing, but it can be explained without them.
Essentially, for every Double expression d, we transform it into an expression of type (Double, Ref Double): a tuple holding two values, the original value and the gradient of that value.
There is also a global backward of type Ref (() -> ()) (a function that takes no arguments and produces an empty tuple), which is responsible for reading the gradient and writing it upstream.
Take the expression x + y, where x and y are subexpressions, as an example. We will: 0: transform x and y into the pairs (x, xref) and (y, yref); 1: generate code that computes the sum and extends backward so that, when run, it reads the sum's gradient and writes it into xref and yref. There will also be wrapper code which initializes backward to a function that does nothing, and converts between Double and (Double, Ref Double).
It is essentially the same for tensors.
The current AD signature:
Array<Expr> (Expr orig_call, Expr out_grad)
constructs the gradient AST for a given input. Imperative AD can make use of this signature along with a high-level taping structure (possibly attached to the framework's NDArray) to get the gradient, so you don't need another separate signature for imperative AD.
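A toy Python sketch of that callback shape (hypothetical names, not TVM's actual registration API): a per-operator function that, given the original call and the output gradient, returns the gradient expressions for each input. Expressions are modeled here as plain tuples/dicts standing in for AST nodes.

```python
GRAD_REGISTRY = {}

def register_gradient(op_name):
    """Register an (orig_call, out_grad) -> [input grads] callback."""
    def wrap(fgrad):
        GRAD_REGISTRY[op_name] = fgrad
        return fgrad
    return wrap

@register_gradient("multiply")
def multiply_grad(orig_call, out_grad):
    a, b = orig_call["args"]
    # d(a*b)/da = b, d(a*b)/db = a
    return [("multiply", out_grad, b), ("multiply", out_grad, a)]

# an "expression" here is just a nested structure standing in for an AST node
call = {"op": "multiply", "args": ("a", "b")}
grads = GRAD_REGISTRY["multiply"](call, "dz")
# grads == [("multiply", "dz", "b"), ("multiply", "dz", "a")]
```

A tape-based imperative AD can call these callbacks in reverse order over the recorded calls, which is why no separate signature is needed for that use case.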