Decreasing autograd memory usage
I don’t mean “memory leak” in the sense of unreachable memory after the Python process quits; I mean memory that is allocated during the backward pass when it should be freed. I mentioned this problem in #199, but I think it deserves its own issue.
For a simple function
import autograd.numpy as np
from autograd import grad

def F(x, z):
    for i in range(100):
        z = np.dot(x, z)
    return np.sum(z)

dF = grad(F)
and a procedure to measure memory usage
from memory_profiler import memory_usage

def make_data():
    np.random.seed(0)
    D = 1000
    x = np.random.randn(D, D)
    x = np.dot(x, x.T)
    z = np.random.randn(D, D)
    return x, z

def m():
    from time import sleep
    x, z = make_data()
    gx = dF(x, z)
    sleep(0.1)
    return gx

mem_usage = np.array(memory_usage(m, interval=0.01))
mem_usage -= mem_usage[0]
and a manual gradient of the same function
def grad_dot_A(g, A, B):
    # Gradient of np.dot with respect to its first argument.
    ga = np.dot(g, B.T)
    ga = np.reshape(ga, np.shape(A))
    return ga

def grad_dot_B(g, A, B):
    # Gradient of np.dot with respect to its second argument.
    gb = np.dot(A.T, g)
    gb = np.reshape(gb, np.shape(B))
    return gb

def repeat_to_match_shape(g, A):
    # Broadcast the scalar gradient of np.sum back to the shape of A
    # (a simple stand-in so this snippet runs on its own).
    return g * np.ones(np.shape(A))

def dF(x, z):
    # Forward pass, saving each intermediate z on a stack.
    z_stack = []
    for i in range(100):
        z_stack.append(z)
        z = np.dot(x, z)
    retval = np.sum(z)

    # Begin backward pass
    g_retval = 1
    g_x = 0
    # Reverse of: retval = np.sum(z)
    g_z = repeat_to_match_shape(g_retval, z)
    for i in reversed(range(100)):
        # Reverse of: z = np.dot(x, z)
        z = z_stack.pop()
        tmp_g0 = grad_dot_A(g_z, x, z)
        tmp_g1 = grad_dot_B(g_z, x, z)
        g_z = 0  # drop the previous gradient so its memory can be reclaimed
        g_x += tmp_g0
        g_z += tmp_g1
    return g_x
I get the following memory usage profile:
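For scale, a back-of-envelope estimate (my arithmetic, not a number read off the plot): each intermediate z is a 1000×1000 float64 array, i.e. 8 MB, and the forward pass stores 100 of them, so the tape alone holds roughly 800 MB; during the backward pass each stored intermediate should become free as soon as its np.dot gradients have been evaluated.

D, n_steps, bytes_per_float64 = 1000, 100, 8

# Memory held by the saved intermediates (the forward "tape") in the dot-product chain above.
tape_bytes = n_steps * D * D * bytes_per_float64
print(tape_bytes / 1e6)  # ~800 MB, which the backward pass should release step by step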
If I replace autograd’s dot-product gradients with the ones used in the manual code, I get the same memory profile; nothing improves.
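For concreteness, this is roughly how that replacement can be expressed with autograd’s current extension API (a sketch using autograd.extend.defvjp; the original experiment may have made the substitution differently):

from autograd.extend import defvjp
import autograd.numpy as anp

# Override autograd's registered VJPs for np.dot with the manual ones above.
defvjp(anp.dot,
       lambda ans, A, B: lambda g: grad_dot_A(g, A, B),  # gradient w.r.t. the first argument
       lambda ans, A, B: lambda g: grad_dot_B(g, A, B))  # gradient w.r.t. the second argument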
If I replace the dot product with an element-wise multiply, I get a different memory profile, but still not what I would expect:
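Concretely, the element-wise variant is just the following (F_mul is a name I am using for this sketch):

def F_mul(x, z):
    # Same loop as F, but with element-wise multiply instead of np.dot.
    for i in range(100):
        z = x * z
    return np.sum(z)

dF_mul = grad(F_mul)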
I would love to help figure this out, but I’m not sure where to start. The first thing, of course, is to document the problem.
On the issue of float32 (and float16) support, I believe that arrays are generally already handled correctly by most of the VJPs; it’s only when primitives are applied to scalar values that things get messy. This is a reflection of NumPy’s type system, which respects the dtype of arrays and seems to largely ignore the dtype of scalar values (even if they are wrapped as an ndarray with shape ()).
I think this needs to be thoroughly tested, but assuming the above is correct, we could say in the docs that float32 is supported, but only for arrays.
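A small illustration of the dtype behaviour described above (my example, under the value-based promotion rules NumPy had at the time; NEP 50 in NumPy 2.x changes the shape-() case):

import numpy as np

a32 = np.ones(3, dtype=np.float32)

# Array dtypes are respected through most operations.
print((a32 + a32).dtype)            # float32
print(np.dot(a32, a32).dtype)       # float32

# Scalar dtypes are largely ignored when mixed with arrays,
# even when the scalar is wrapped as an ndarray with shape ().
print((a32 * 2.0).dtype)            # float32
print((a32 + np.array(2.0)).dtype)  # float32 under value-based promotion (float64 under NEP 50)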
Here’s a gist on a more “real-world” benchmark, taking the gradients of an RNN.
https://gist.github.com/alexbw/7b2a0682f65dd1bcb7120ca2d47a2823
Here’s the memory usage: