
Very slow compile: CNN, per-sample gradient


I encountered a “Very slow compile” warning when computing per-sample gradients with a small CNN. The warning appears once the number of samples exceeds about 300, which is not a large batch:

2019-10-30 16:12:13.561535: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:55] 
********************************
Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
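The dump hint in this log only takes effect if XLA_FLAGS is already in the environment when JAX initializes its XLA backend, for example before `import jax` runs. A minimal sketch, assuming a current JAX release, that runs a small jitted function in a child process with the flag set from the start and lists the files XLA writes. The child program, its function `f`, and the helper `run_with_hlo_dump` are hypothetical illustrations:

```python
import os
import subprocess
import sys
import tempfile

# Hypothetical child program: compiling any jitted function produces a dump.
CHILD = """
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sum(jnp.log(jnp.cosh(x)))

f(jnp.ones(16)).block_until_ready()
"""


def run_with_hlo_dump(dump_dir):
    """Run CHILD with XLA_FLAGS set before jax is imported; return the dumped file names."""
    env = dict(os.environ, XLA_FLAGS=f"--xla_dump_to={dump_dir}")
    subprocess.run([sys.executable, "-c", CHILD], env=env, check=True)
    return sorted(os.listdir(dump_dir))


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as dump_dir:
        for name in run_with_hlo_dump(dump_dir):
            print(name)
```

Using a fresh process guarantees the variable is visible when the backend starts, which is the same effect as exporting it in the shell before launching the script.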

The compilation time grows with the number of samples, which makes it very hard to scale the code up for practical purposes.
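One way to check the scaling claim in isolation is to time compilation separately from execution with JAX's ahead-of-time `lower(...).compile()` API. A minimal sketch, assuming a current JAX release: it re-creates the model with `jax.lax.conv_general_dilated` instead of stax, and the names `evaluate`, `per_sample_grad`, and `compile_seconds` are illustrative assumptions rather than the issue's code:

```python
import time

import jax
import jax.numpy as jnp
from jax import lax

L = 4       # linear image size, as in the issue
N_SYMM = 8  # images per sample, as in the issue


def evaluate(w, sample):
    # one sample of shape (N_SYMM, L, L) -> NCHW batch with a single input channel
    x = sample.reshape(-1, 1, L, L)
    a = lax.conv_general_dilated(x, w, window_strides=(1, 1), padding="VALID",
                                 dimension_numbers=("NCHW", "OIHW", "NCHW"))
    return jnp.sum(jnp.log(jnp.cosh(a)))


# per-sample gradients: weights shared (in_axes=None), samples mapped over axis 0
per_sample_grad = jax.jit(jax.vmap(jax.grad(evaluate), in_axes=(None, 0)))


def compile_seconds(n_points):
    """Wall-clock time to lower and compile per_sample_grad for n_points samples."""
    w = jnp.ones((1, 1, 2, 2))                  # (O, I, H, W) 2x2 filter
    batch = jnp.ones((n_points, N_SYMM, L, L))
    start = time.perf_counter()
    per_sample_grad.lower(w, batch).compile()   # compiles without running
    return time.perf_counter() - start


if __name__ == "__main__":
    for n in (8, 16, 32):
        print(f"N_points={n}: compile took {compile_seconds(n):.3f} s")
```

Because each new `N_points` is a new input shape, every value in the loop triggers a fresh compilation, so the printed times track how compile cost depends on the batch size.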

Here’s the code; N_points controls the number of samples:

import os
import time
from functools import partial

# XLA_FLAGS is read when the XLA backend initializes, so set it before importing jax
os.environ['XLA_FLAGS'] = '--xla_dump_to=/tmp/foo'

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import jit, grad, random, vmap
from jax.example_libraries.stax import GeneralConv  # formerly jax.experimental.stax
from jax.nn.initializers import glorot_normal

import numpy as np

seed = 1
np.random.seed(seed)
rng = random.PRNGKey(seed)

####################################################################

L = 4  # linear image dimension
dtype = jnp.float64
N_symm = 2*2*2  # number of L x L images per sample

dimension_numbers = ('NCHW', 'OIHW', 'NCHW')  # default
out_chan = 1
filter_shape = (2, 2)
strides = (1, 1)

input_shape = (1, 1, L, L)  # NCHW input format (np.int no longer exists in NumPy)

lhs_spec, rhs_spec, out_spec = dimension_numbers
W_init = glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))
W_init = partial(W_init, dtype=dtype)

init_params, apply_layer = GeneralConv(dimension_numbers, out_chan, filter_shape,
                                       strides=strides, padding='VALID', W_init=W_init)

# initialize parameters
_, params = init_params(rng, input_shape)


@jit
def evaluate(params, batch):
    # reshaping is required inside evaluate because of per-sample gradients:
    # each call sees a single sample of shape (N_symm, L, L)
    batch = batch.reshape(-1, 1, L, L)

    # apply layer
    a = apply_layer(params, batch)

    # apply logcosh nonlinearity
    z = jnp.log(jnp.cosh(a))

    return jnp.sum(z)


@jit
def compute_grad_log_psi(params, batch):
    # per-sample gradients: vmap over the leading (sample) axis of batch
    return vmap(partial(jit(grad(evaluate)), params))(batch)
    # the inner jit is redundant under the outer jit; equivalent version:
    # return vmap(partial(grad(evaluate), params))(batch)

###########################

# define data
N_points = 300
batch = np.ones((N_points, N_symm, L, L), dtype=dtype)

for _ in range(10):

    ti = time.time()
    # JAX dispatches asynchronously; without jax.block_until_ready(d_psi),
    # calls after the first mostly measure dispatch rather than execution
    d_psi = compute_grad_log_psi(params, batch)
    tf = time.time()

    print("gradients took {0:.4f} secs.".format(tf-ti))

which produced the following output:

/Users/mgbukov/jax/jax/lib/xla_bridge.py:115: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
2019-10-30 16:21:03.691593: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:55] 
********************************
Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
gradients took 163.3030 secs.
gradients took 0.0002 secs.
gradients took 0.0002 secs.
gradients took 0.0002 secs.
gradients took 0.0003 secs.
gradients took 0.0003 secs.
gradients took 0.0002 secs.
gradients took 0.0002 secs.
gradients took 0.0002 secs.
gradients took 0.0002 secs.
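The gap between the first and later iterations is the one-time trace and XLA compilation; later calls reuse the cached executable. JAX also dispatches work asynchronously, so the later timings above most likely measure dispatch rather than finished execution. A minimal sketch of a timing helper that waits for the result with `jax.block_until_ready`, assuming a current JAX release; `step` is a hypothetical stand-in workload and `timed_call` a hypothetical helper:

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # hypothetical stand-in workload
    return (x @ x.T).sum()


def timed_call(fn, *args):
    """Call fn, wait until its result is ready, and return (result, seconds)."""
    start = time.perf_counter()
    out = jax.block_until_ready(fn(*args))
    return out, time.perf_counter() - start


if __name__ == "__main__":
    x = jnp.ones((256, 256))
    _, first = timed_call(step, x)   # includes tracing and XLA compilation
    _, second = timed_call(step, x)  # reuses the cached executable
    print(f"first call: {first:.4f} s, second call: {second:.4f} s")
```

Wrapping `compute_grad_log_psi(params, batch)` the same way would make the per-iteration numbers reflect actual gradient computation.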

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Comments: 8 (3 by maintainers)

Top GitHub Comments

2 reactions
levskaya commented, Oct 31, 2019

I just reproduced this. I also see the slow compilation on CPU; however, it’s very fast to compile and run on the GPU backend (a few seconds). I suspect this is an XLA:CPU-specific compilation issue.

The pre-optimized XLA HLO is identical on CPU and GPU. On the GPU backend, things are fine:

   2.6K Oct 30 23:00 ../gpu/module_0024.after_optimizations-buffer-assignment.txt
   6.2K Oct 30 23:00 ../gpu/module_0024.after_optimizations.txt
     11K Oct 30 23:00 ../gpu/module_0024.before_optimizations.txt

On CPU, the optimization passes blow up the code size:

   143K Oct 30 23:05 module_0024.after_optimizations-buffer-assignment.txt
   1.2M Oct 30 23:05 module_0024.after_optimizations.txt
     11K Oct 30 23:05 module_0024.before_optimizations.txt

We’ll probably need to loop in the XLA devs to look further into this issue.

For future reference, the problematic pre-optimized XLA HLO is attached: module_0024.before_optimizations.txt
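The size comparison above can also be made without XLA_FLAGS by asking JAX for the module text before and after backend optimization. A minimal sketch, assuming a current JAX release: `hlo_sizes` and the stand-in `f` are hypothetical, and for the issue's model one would pass `compute_grad_log_psi` with `params` and `batch` instead. Character counts are only a rough proxy for the dump file sizes:

```python
import jax
import jax.numpy as jnp


def f(x):
    # hypothetical stand-in; any jittable function works here
    return jnp.sum(jnp.log(jnp.cosh(x)))


def hlo_sizes(fn, *args):
    """Return (chars of HLO before optimization, chars after) on the default backend."""
    lowered = jax.jit(fn).lower(*args)
    before = lowered.as_text(dialect="hlo")  # module as handed to XLA
    after = lowered.compile().as_text()      # HLO after the backend's optimization passes
    return len(before), len(after)


if __name__ == "__main__":
    print(hlo_sizes(f, jnp.ones(8)))
```

Running the same call on a CPU-only and a GPU machine would show whether the optimized module grows much more on one backend, as in the listings above.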

1 reaction
levskaya commented, Oct 31, 2019

Hmm, it seems this might be happening because this code creates a grouped convolution, and Eigen doesn’t support grouped convolutions, so XLA has to lower it into 300 separate convolutions. Is there a reason you’re running on CPU and not on GPU, where such convolutions are supported and reasonably fast?

I’m not sure there’s an easy workaround here on CPU for initial compile times, as the current CPU implementation has already been optimized around the lack of underlying primitives.
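The grouped convolution can be made visible in the traced program. As far as JAX's convolution batching rule goes, when the per-sample gradient is vmapped, both operands of the weight-gradient convolution carry the sample axis, and that axis is folded into `feature_group_count`, giving one group per sample. A minimal sketch, assuming a current JAX release and re-creating the model with `jax.lax.conv_general_dilated`; `evaluate` and `conv_group_counts` are hypothetical names:

```python
import jax
import jax.numpy as jnp
from jax import lax


def evaluate(w, sample):
    # one sample of shape (images, H, W) -> NCHW with a single input channel
    x = sample[:, None, :, :]
    a = lax.conv_general_dilated(x, w, window_strides=(1, 1), padding="VALID",
                                 dimension_numbers=("NCHW", "OIHW", "NCHW"))
    return jnp.sum(jnp.log(jnp.cosh(a)))


def conv_group_counts(n_points, images=8, size=4):
    """feature_group_count of every top-level convolution in vmap(grad(evaluate))."""
    w = jnp.ones((1, 1, 2, 2))
    batch = jnp.ones((n_points, images, size, size))
    per_sample = jax.vmap(jax.grad(evaluate), in_axes=(None, 0))
    jaxpr = jax.make_jaxpr(per_sample)(w, batch).jaxpr  # traced only, not compiled
    return [eqn.params["feature_group_count"]
            for eqn in jaxpr.eqns
            if eqn.primitive.name == "conv_general_dilated"]


if __name__ == "__main__":
    print(conv_group_counts(300))
```

The forward convolution should report a group count of 1, while the weight-gradient convolution should report one group per sample, which lines up with the 300 separate convolutions described above for N_points=300.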
