Debugging Multiple GPU Model

I am trying to reproduce a multi-GPU implementation of my Keras model using some of the code from your blog post. I have slightly modified it to take a list of GPUs (in case I want to specify which GPUs to use). I am using the TensorFlow backend of Keras, and both are up to date. I have four NVIDIA Titan X GPUs. Below is a small example using MNIST.

from keras.layers import concatenate
from keras.layers.core import Lambda
from keras.models import Model

import tensorflow as tf

def make_parallel(model, gpu_list):
    def get_slice(data, idx, parts):
        shape = tf.shape(data)
        size = tf.concat([ shape[:1] // parts, shape[1:] ], axis=0)
        stride = tf.concat([ shape[:1] // parts, shape[1:]*0 ], axis=0)
        start = stride * idx
        return tf.slice(data, start, size)

    outputs_all = []
    for i in range(len(model.outputs)):
        outputs_all.append([])

    #Place a copy of the model on each GPU, each getting a slice of the batch
    gpu_count = len(gpu_list)
    for i in range(gpu_count):
        with tf.device('/gpu:%d' % gpu_list[i]):
            with tf.name_scope('tower_%d' % gpu_list[i]) as scope:

                inputs = []
                #Slice each input into a piece for processing on this GPU
                for x in model.inputs:
                    input_shape = tuple(x.get_shape().as_list())[1:]
                    slice_n = Lambda(get_slice, output_shape=input_shape, arguments={'idx':i,'parts':gpu_count})(x)
                    inputs.append(slice_n)

                outputs = model(inputs)
                if not isinstance(outputs, list):
                    outputs = [outputs]
                #Save all the outputs for merging back together later
                for l in range(len(outputs)):
                    outputs_all[l].append(outputs[l])

    # merge outputs on CPU
    with tf.device('/cpu:0'):
        merged = []
        for outputs in outputs_all:
            merged.append(concatenate(outputs, axis=0))
        return Model(inputs=model.inputs, outputs=merged)

if __name__ == "__main__":
    from keras.models import Sequential
    from keras.layers import Dense
    from keras.datasets import mnist
    from keras.utils import to_categorical
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = x_train.reshape(60000, -1)
    x_test = x_test.reshape(10000, -1)
    model = Sequential()
    model.add(Dense(64, input_shape=(784,), activation='relu'))
    model.add(Dense(10, activation='softmax'))
    parallel_model = make_parallel(model , [0,1,2,3])
    y_train = to_categorical(y_train, 10)
    y_test = to_categorical(y_test, 10)
    parallel_model.compile(optimizer='nadam', loss='categorical_crossentropy',
                           metrics=['accuracy'])
    parallel_model.fit(x_train, y_train, batch_size=128,
                       validation_data=(x_test, y_test))

This code works when I select two or four GPUs, but when I select three GPUs I get the following error:

Using TensorFlow backend.
Train on 60000 samples, validate on 10000 samples
Epoch 1/1
Traceback (most recent call last):

  File "<ipython-input-1-524a8053f5a2>", line 1, in <module>
    runfile('/home/rmk6217/Documents/kemker/machine_learning/', wdir='/home/rmk6217/Documents/kemker/machine_learning')

  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/spyder/utils/site/", line 866, in runfile
    execfile(filename, namespace)

  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/spyder/utils/site/", line 102, in execfile
    exec(compile(, filename, 'exec'), namespace)

  File "/home/rmk6217/Documents/kemker/machine_learning/", line 71, in <module>
    validation_data=(x_test, y_test))

  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/Keras-2.0.2-py3.5.egg/keras/engine/", line 1485, in fit

  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/Keras-2.0.2-py3.5.egg/keras/engine/", line 1140, in _fit_loop
    outs = f(ins_batch)

  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/Keras-2.0.2-py3.5.egg/keras/backend/", line 2102, in __call__

  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/", line 767, in run

  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/", line 965, in _run
    feed_dict_string, options, run_metadata)

  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/", line 1015, in _do_run
    target_list, options, run_metadata)

  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/", line 1035, in _do_call
    raise type(e)(node_def, op, message)

InvalidArgumentError: Incompatible shapes: [128] vs. [126]
	 [[Node: Equal = Equal[T=DT_INT64, _device="/job:localhost/replica:0/task:0/cpu:0"](ArgMax, ArgMax_1)]]

Caused by op 'Equal', defined at:
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/spyder/utils/ipython/", line 227, in <module>
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/spyder/utils/ipython/", line 223, in main
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/ipykernel/", line 474, in start
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/zmq/eventloop/", line 177, in start
    super(ZMQIOLoop, self).start()
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/tornado/", line 831, in start
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/tornado/", line 604, in _run_callback
    ret = callback()
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/tornado/", line 275, in null_wrapper
    return fn(*args, **kwargs)
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/ipykernel/", line 258, in enter_eventloop
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/ipykernel/", line 93, in loop_qt5
    return loop_qt4(kernel)
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/ipykernel/", line 87, in loop_qt4
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/IPython/lib/", line 144, in start_event_loop_qt4
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/ipykernel/", line 39, in process_stream_events
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/ipykernel/", line 291, in do_one_iteration
    stream.flush(zmq.POLLIN, 1)
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/zmq/eventloop/", line 352, in flush
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/zmq/eventloop/", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/zmq/eventloop/", line 414, in _run_callback
    callback(*args, **kwargs)
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/tornado/", line 275, in null_wrapper
    return fn(*args, **kwargs)
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/ipykernel/", line 276, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/ipykernel/", line 228, in dispatch_shell
    handler(stream, idents, msg)
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/ipykernel/", line 390, in execute_request
    user_expressions, allow_stdin)
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/ipykernel/", line 196, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/ipykernel/", line 501, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/IPython/core/", line 2717, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/IPython/core/", line 2827, in run_ast_nodes
    if self.run_code(code, result):
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/IPython/core/", line 2881, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-1-524a8053f5a2>", line 1, in <module>
    runfile('/home/rmk6217/Documents/kemker/machine_learning/', wdir='/home/rmk6217/Documents/kemker/machine_learning')
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/spyder/utils/site/", line 866, in runfile
    execfile(filename, namespace)
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/spyder/utils/site/", line 102, in execfile
    exec(compile(, filename, 'exec'), namespace)
  File "/home/rmk6217/Documents/kemker/machine_learning/", line 68, in <module>
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/Keras-2.0.2-py3.5.egg/keras/engine/", line 952, in compile
    append_metric(i, 'acc', masked_fn(y_true, y_pred, mask=masks[i]))
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/Keras-2.0.2-py3.5.egg/keras/engine/", line 479, in masked
    score_array = fn(y_true, y_pred)
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/Keras-2.0.2-py3.5.egg/keras/", line 25, in categorical_accuracy
    K.argmax(y_pred, axis=-1)),
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/Keras-2.0.2-py3.5.egg/keras/backend/", line 1347, in equal
    return tf.equal(x, y)
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/tensorflow/python/ops/", line 721, in equal
    result = _op_def_lib.apply_op("Equal", x=x, y=y, name=name)
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/tensorflow/python/framework/", line 763, in apply_op
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/tensorflow/python/framework/", line 2327, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/home/rmk6217/anaconda3/lib/python3.5/site-packages/tensorflow/python/framework/", line 1226, in __init__
    self._traceback = _extract_stack()

InvalidArgumentError (see above for traceback): Incompatible shapes: [128] vs. [126]
	 [[Node: Equal = Equal[T=DT_INT64, _device="/job:localhost/replica:0/task:0/cpu:0"](ArgMax, ArgMax_1)]]

I have dug through the debugger for a while now, but I can't seem to track down the issue. I can't help feeling that I am doing something stupid, so I was hoping another set of eyes might see something I didn't. Any assistance would be appreciated. Thanks!
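The numbers in the error can be checked with plain arithmetic. `get_slice` gives every tower `batch_size // parts` rows, so when the batch size is not a multiple of the GPU count, the leftover rows never reach any tower and the concatenated predictions end up with fewer rows than the labels they are compared against. The minimal sketch below mirrors that floor division in plain Python, assuming a batch of 128 rows as in the `fit` call. `rows_covered` is a hypothetical helper, not part of Keras or TensorFlow.

```python
def rows_covered(batch_size, parts):
    """Rows handed out by a get_slice-style split of one batch.

    Each of the `parts` towers receives batch_size // parts rows, the same
    floor division get_slice applies to shape[:1]; any remainder is dropped.
    """
    return (batch_size // parts) * parts


for gpu_count in (2, 3, 4):
    print(gpu_count, rows_covered(128, gpu_count))
# 2 128
# 3 126
# 4 128
```

With three GPUs, 128 // 3 is 42, so only 3 * 42 = 126 rows are predicted. This matches the `[128] vs. [126]` in the traceback, while two and four GPUs divide 128 evenly.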

Issue Analytics

  • State: closed
  • Created 6 years ago
  • Comments: 7

Top GitHub Comments

rmkemker commented, Apr 5, 2017

Awesome! There was an error in your code (forgot an import):

From: from keras.layers import Lambda, merge
To: from keras.layers import Lambda, merge, concatenate

And everything worked! I was able to edit your code to take in a list, so I can choose which GPUs I want. Super easy, thanks!

icyblade commented, Apr 5, 2017

If a mini-batch cannot be split evenly across the GPUs, an error will occur. You can try my solution here:
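One possible way to avoid dropping rows (a sketch under stated assumptions, not icyblade's linked solution, which is not reproduced on this page) is to let the last tower absorb the remainder. The hypothetical plain-Python helper `slice_bounds` below computes the `(start, size)` row range each tower would slice. In `get_slice`, the analogous change would be to keep the same `start` but use `shape[:1] - (shape[:1] // parts) * idx` as the row count when `idx == parts - 1`. Since `idx` and `parts` arrive as plain Python ints through the `Lambda` `arguments`, that condition can be an ordinary `if`.

```python
def slice_bounds(batch_size, parts):
    """Return (start, size) row ranges that together cover the whole batch.

    Every tower gets batch_size // parts rows, except the last one,
    which also takes the remainder so that no rows are dropped.
    """
    step = batch_size // parts
    bounds = []
    for idx in range(parts):
        start = step * idx
        if idx == parts - 1:
            size = batch_size - start  # last tower absorbs the leftover rows
        else:
            size = step
        bounds.append((start, size))
    return bounds


print(slice_bounds(128, 3))  # [(0, 42), (42, 42), (84, 44)]
```

Keras can also feed a final batch smaller than `batch_size` when the sample count is not a multiple of it. For example, 10,000 validation samples in batches of 128 leave a 16-row batch, since 10000 % 128 == 16. Handling the remainder is therefore more robust than only picking a batch size divisible by the GPU count. This sketch does not guard against batches with fewer rows than GPUs, where the leading towers would receive empty slices.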
