Very slow compile: CNN, per-sample gradients
I encountered a "Very slow compile" warning when computing per-sample gradients with a small CNN. The warning appears once the number of samples exceeds about 300 (which is not particularly large):
2019-10-30 16:12:13.561535: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:55]
********************************
Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
The compilation time scales with the number of samples, which makes it very hard to scale the code up for practical purposes.
Here's the code, with N_points controlling the number of samples:
import sys, os
from jax.config import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import jit, grad, random, vmap
import numpy as np
from functools import partial

seed = 1
np.random.seed(seed)
np.random.RandomState(seed)
rng = random.PRNGKey(seed)

from jax.experimental.stax import GeneralConv, relu
from jax.nn.initializers import glorot_normal
import time

####################################################################
os.environ['XLA_FLAGS'] = '--xla_dump_to=/tmp/foo'

L = 4                       # linear image dimension
dtype = jnp.float64
N_symm = 2*2*2              # 8 symmetry copies per sample
dimension_numbers = ('NCHW', 'OIHW', 'NCHW')  # default
out_chan = 1
filter_shape = (2, 2)
strides = (1, 1)
input_shape = np.array((1, 1, L, L), dtype=np.int)  # NCHW input format

lhs_spec, rhs_spec, out_spec = dimension_numbers
W_init = glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))
W_init = partial(W_init, dtype=dtype)

init_params, apply_layer = GeneralConv(dimension_numbers, out_chan, filter_shape,
                                       strides=strides, padding='VALID', W_init=W_init)

# initialize parameters
_, params = init_params(rng, input_shape)

@jit
def evaluate(params, batch):
    # reshaping required inside the evaluate function because of per-sample gradients
    batch = batch.reshape(-1, 1, L, L)
    # apply convolutional layer
    a = apply_layer(params, batch)
    # apply log-cosh nonlinearity
    z = jnp.log(jnp.cosh(a))
    return jnp.sum(z)

@jit
def compute_grad_log_psi(params, batch):
    # per-sample gradients w.r.t. params: vmap over the leading batch axis
    return vmap(partial(jit(grad(evaluate)), params))(batch)
    # return vmap(partial(grad(evaluate), params))(batch)

###########################
# define data
N_points = 300
batch = np.ones((N_points, N_symm, L, L), dtype=dtype)

for _ in range(10):
    ti = time.time()
    d_psi = compute_grad_log_psi(params, batch)
    tf = time.time()
    print("gradients took {0:.4f} secs.".format(tf - ti))
which produced the following output:
/Users/mgbukov/jax/jax/lib/xla_bridge.py:115: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
2019-10-30 16:21:03.691593: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:55]
********************************
Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
gradients took 163.3030 secs.
gradients took 0.0002 secs.
gradients took 0.0002 secs.
gradients took 0.0002 secs.
gradients took 0.0003 secs.
gradients took 0.0003 secs.
gradients took 0.0002 secs.
gradients took 0.0002 secs.
gradients took 0.0002 secs.
gradients took 0.0002 secs.
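A side note on the timings: JAX dispatches work asynchronously, so once the compiled function is cached, the numbers above mostly measure dispatch rather than the actual gradient computation. A stricter timing sketch, reusing the names from the snippet above and blocking on the result, would look like this:

import jax

for _ in range(10):
    ti = time.time()
    d_psi = compute_grad_log_psi(params, batch)
    # block until the device actually finishes, so the timing is meaningful
    for leaf in jax.tree_util.tree_leaves(d_psi):
        leaf.block_until_ready()
    tf = time.time()
    print("gradients took {0:.4f} secs.".format(tf - ti))

Either way, the first call is clearly dominated by compilation.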
I just reproduced this: I also see the slow compilation on CPU, but it's very fast to compile and run on the GPU backend (a few seconds). I suspect this is an XLA:CPU-specific compilation issue.
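A quick way to confirm which backend JAX actually picked up (not from the thread, just standard JAX introspection):

import jax
print(jax.devices())              # e.g. [CpuDevice(id=0)] or a list of GPU devices
print(jax.devices()[0].platform)  # 'cpu', 'gpu', or 'tpu'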
The pre-optimized XLA HLO is identical on CPU and GPU, but while things are fine on the GPU backend, on CPU the optimization passes blow up the code size.
We'll probably need to loop in the XLA devs to look further into this issue.
For future reference, the pre-optimized problematic XLA HLO is attached: module_0024.before_optimizations.txt
Hmm - it seems this might be happening because this code creates a grouped convolution, and Eigen doesn't support grouped convolutions, so XLA has to lower it into 300 separate convs. Is there a reason you're running on CPU rather than GPU, where such convs are supported and reasonably fast?
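For context, here is a minimal sketch (names and sizes chosen for illustration, not taken from the issue) of what a grouped convolution looks like at the lax level; the vmapped per-sample gradient lowers to an op of this form, presumably with one group per sample:

import jax.numpy as jnp
from jax import lax

groups = 4                                  # illustrative; the issue effectively has 300
x = jnp.ones((1, groups, 8, 8))             # NCHW input, one channel per group
w = jnp.ones((groups, 1, 2, 2))             # OIHW kernel, I = C / groups = 1
y = lax.conv_general_dilated(
    x, w,
    window_strides=(1, 1),
    padding='VALID',
    dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
    feature_group_count=groups)             # this is what makes it a grouped conv
print(y.shape)                              # (1, 4, 7, 7)

Since Eigen has no native grouped-conv primitive, XLA on CPU has to rewrite such an op into feature_group_count separate convolutions, which matches the 300 convs mentioned above.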
I'm not sure there's an easy workaround here on CPU for the initial compile time, as the current CPU implementation has already been optimized around the lack of underlying primitives.
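One thing that might sidestep the grouped conv on CPU (purely a speculative sketch, not suggested in the thread, and untested against these timings) is to loop over samples with lax.map instead of vmap, so the gradient body is traced and compiled once and then reused in a compiled loop of ordinary convolutions. This reuses evaluate, params, jit, grad, and partial from the snippet above:

from jax import lax

@jit
def compute_grad_log_psi_loop(params, batch):
    # sequentially maps grad(evaluate) over the leading sample axis;
    # trades the vectorized grouped conv for a compiled loop of ordinary convs
    return lax.map(partial(grad(evaluate), params), batch)

Whether this actually helps the compile time here, and what it costs at runtime, would need to be measured.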