Feature request: smarter primitive for container_types
I’ve been noticing that `grad()` can be quite memory intensive, due to the tape. As a workaround, I’ve been making some of my subroutines `@primitive` and assigning their derivatives through `defgrad`. This way, the gradient for the subroutine gets computed “off” the tape.
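For context, the workaround pattern looks roughly like this, assuming the current `defgrad` API (the import path may differ between autograd versions, and `logsumexp_stable` is just an example subroutine):

```python
import autograd.numpy as np
from autograd.core import primitive   # path may vary by autograd version

@primitive
def logsumexp_stable(x):
    # Runs on raw arrays: none of these intermediates land on the tape.
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))

# Hand-written reverse rule: defgrad gets the answer and the original
# arguments, and returns a function mapping the output gradient g to the
# gradient with respect to x.
logsumexp_stable.defgrad(lambda ans, x: lambda g: g * np.exp(x - ans))
```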
However, there is a problem if the primitive subroutine takes a container argument, e.g. if I call `subroutine(a_dict)`, where `subroutine` is primitive and `a_dict` is a `dict` containing `Node` entries. Then `not isinstance(a_dict, Node)`, and so `subroutine` acts as if it is not primitive at all: it will not call `self.gradmaker()`, and it will keep passing `Node` objects into `self.fun()`, thus continuing to add to the tape.
My solution is to convert `a_dict` into a `DictNode`, so that `subroutine` truly acts as a primitive. More specifically:
```python
make_dict = primitive(lambda keys, *vals: dict(zip(keys, vals)))
# Gradient w.r.t. the argnum-th value argument is the entry of the output
# gradient under the corresponding key (argnum 0 is the keys argument).
make_dict.gradmaker = lambda argnum, ans, args, kwargs: lambda g: g[args[0][argnum - 1]]

a_dict = make_dict(list(a_dict.keys()), *a_dict.values())  # list() so keys are indexable
ret = subroutine(a_dict)
```
Then `a_dict` is a `DictNode` and none of its entries are `Node`s, so `subroutine(a_dict)` won’t add its internal computations to the tape. (Note that `make_dict` doesn’t work if `a_dict` recursively contains more dicts.)
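One could probably lift that restriction with a recursive wrapper along these lines (untested sketch; `make_dict_recursive` is a hypothetical helper built on the `make_dict` above):

```python
def make_dict_recursive(d):
    # Wrap inner dicts first, so every level of the structure is already a
    # Node by the time the enclosing make_dict primitive sees it.
    vals = [make_dict_recursive(v) if isinstance(v, dict) else v
            for v in d.values()]
    return make_dict(list(d.keys()), *vals)

a_dict = make_dict_recursive(a_dict)
```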
It would be nice if `primitive.__call__()` did this automatically (i.e., not just check whether the arguments are `Node`s, but whether they are container types with `Node`s inside). Then primitives with container-type arguments would be really primitive.
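For illustration, the check I have in mind might look roughly like this (not autograd’s actual dispatch code, just a sketch; `Node` is assumed to be in scope, as it is inside core.py):

```python
def contains_node(x):
    # True if x is a Node, or a container holding a Node anywhere inside.
    if isinstance(x, Node):
        return True
    if isinstance(x, dict):
        return any(contains_node(v) for v in x.values())
    if isinstance(x, (list, tuple)):
        return any(contains_node(v) for v in x)
    return False

# primitive.__call__ could then test contains_node(arg) instead of a bare
# isinstance(arg, Node), wrapping offending containers via make_dict /
# make_tuple before dispatching.
```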
Issue Analytics
- Created 7 years ago
- Comments: 5 (5 by maintainers)
We read you loud and clear; we’ve just gotten busy with thesis defenses and job hunts.
I think “moving computations off the tape” means defining your own primitives and, in so doing, being able to discard some values that aren’t needed on the reverse pass. Defining your own primitives is in the documentation (to the extent that anything here is documented), but there are some more things we could write down and some API adjustments we could make. In particular, sometimes your primitive does need to cache some results for the reverse pass gradfun to use, and the API should support that. I hacked something like that into the API as `primitive_with_aux` and used it for a project where I wrote some primitives and gradients in Cython, but that implementation isn’t a good one because it breaks higher-order differentiation. Now that we understand this issue better, we’re hoping that an updated core design will handle these cases.
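The pattern in question, very roughly (this is not the actual `primitive_with_aux` signature; `my_subroutine` and the cache mechanism here are made up to show the idea):

```python
import autograd.numpy as np
from autograd.core import primitive   # path may vary by autograd version

_aux_cache = {}   # crude side channel, keyed by input identity

@primitive
def my_subroutine(x):
    aux = np.exp(x)   # intermediate worth reusing on the reverse pass
    _aux_cache[id(x)] = aux
    return np.sum(aux)

# The gradient closure pulls the cached intermediate back out instead of
# recomputing it; a real API would manage this storage for the user.
my_subroutine.defgrad(lambda ans, x: lambda g: g * _aux_cache.pop(id(x)))
```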
Sounds reasonable – doing all the extra checking could certainly slow down a lot of common cases. I think requiring power users to call `make_tuple` or `make_dict` to move computations off the tape is pretty reasonable.

Maybe this functionality for moving certain computations off the tape could become an official part of the API/documentation/functionality, with some unit tests? In practice it can be hard to tell whether a computation is happening on or off the tape, so my worry is that future changes to autograd may affect the code behavior in a way that is hard to diagnose.
Anyways, thanks for the response, I was worried my initial rambling post would be impossible to understand 😃