bfloat16 pmap cuda11.1 results in error
I received the following error message when using bfloat16 with pmap on jaxlib 0.1.69 + cuda111 (there is no issue without pmap):
RuntimeError: Unimplemented: Requested AllReduce not implemented on GPU; replica_count: 2; partition_count: 1, group_mode: kCrossReplica, operand_count: 2; NCCL support: 1; first operand array element-type: BF16
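The failing operation is the cross-replica all-reduce that jax.lax.pmean lowers to. As a point of reference, here is a minimal sketch of my own (not from the original report, assuming at least two local GPU devices) that exercises the same bfloat16 collective:

# Minimal sketch: pmean over bfloat16 values, which lowers to the
# cross-replica AllReduce named in the error above.
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jnp.ones((n,), dtype=jnp.bfloat16)
print(jax.pmap(lambda v: jax.lax.pmean(v, axis_name='i'), axis_name='i')(x))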
The bug(?) can be reproduced as follows (code copied from the Parallel Evaluation in JAX tutorial in the reference documentation):
import jax
import numpy as np
import jax.numpy as jnp
jax.devices()
from typing import NamedTuple, Tuple
import functools
class Params(NamedTuple):
  weight: jnp.ndarray
  bias: jnp.ndarray

def init(rng) -> Params:
  """Returns the initial model params."""
  weights_key, bias_key = jax.random.split(rng)
  weight = jax.random.normal(weights_key, (), dtype=D_TYPE)
  bias = jax.random.normal(bias_key, (), dtype=D_TYPE)
  return Params(weight, bias)

def loss_fn(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> jnp.ndarray:
  """Computes the least squares error of the model's predictions on x against y."""
  pred = params.weight * xs + params.bias
  return jnp.mean((pred - ys) ** 2)
LEARNING_RATE = 0.005
# So far, the code is identical to the single-device case. Here's what's new:
# Remember that the `axis_name` is just an arbitrary string label used
# to later tell `jax.lax.pmean` which axis to reduce over. Here, we call it
# 'num_devices', but could have used anything, so long as `pmean` used the same.
@functools.partial(jax.pmap, axis_name='num_devices')
def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Tuple[Params, jnp.ndarray]:
  """Performs one SGD update step on params using the given data."""
  # Compute the gradients on the given minibatch (individually on each device).
  loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)
  # Combine the gradient across all devices (by taking their mean).
  grads = jax.lax.pmean(grads, axis_name='num_devices')
  # Also combine the loss. Unnecessary for the update, but useful for logging.
  loss = jax.lax.pmean(loss, axis_name='num_devices')
  # Each device performs its own update, but since we start with the same params
  # and synchronise gradients, the params stay in sync.
  new_params = jax.tree_multimap(
      lambda param, g: param - g * LEARNING_RATE, params, grads)
  return new_params, loss
# Generate true data from y = w*x + b + noise
true_w, true_b = 2, -1
xs = np.random.normal(size=(128, 1))
noise = 0.5 * np.random.normal(size=(128, 1))
ys = xs * true_w + true_b + noise
# Initialise parameters and replicate across devices.
D_TYPE = jnp.bfloat16
params = init(jax.random.PRNGKey(123))
n_devices = jax.local_device_count()
replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params)
def split(arr):
"""Splits the first axis of `arr` evenly across the number of devices."""
return arr.reshape(n_devices, arr.shape[0] // n_devices, *arr.shape[1:])
# Reshape xs and ys for the pmapped `update()`.
x_split = split(xs)
y_split = split(ys)
type(x_split)
def type_after_update(name, obj):
print(f"after first `update()`, `{name}` is a", type(obj))
# Actual training loop.
for i in range(1000):
  # This is where the params and data get communicated to devices:
  replicated_params, loss = update(replicated_params, x_split, y_split)
  # The returned `replicated_params` and `loss` are now both ShardedDeviceArrays,
  # indicating that they're on the devices.
  # `x_split`, of course, remains a NumPy array on the host.
  if i == 0:
    type_after_update('replicated_params.weight', replicated_params.weight)
    type_after_update('loss', loss)
    type_after_update('x_split', x_split)
  if i % 100 == 0:
    # Note that loss is actually an array of shape [num_devices], with identical
    # entries, because each device returns its copy of the loss.
    # So, we take the first element to print it.
    print(f"Step {i:3d}, loss: {loss[0]:.3f}")
# Plot results.
# Like the loss, the leaves of params have an extra leading dimension,
# so we take the params from the first device.
params = jax.device_get(jax.tree_map(lambda x: x[0], replicated_params))
Setting D_TYPE to jnp.float32 eliminates the error.
Thanks for your help
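One possible interim workaround on the affected jaxlib (my suggestion, not from this thread) would be to keep the parameters in bfloat16 but perform the cross-device reduction in float32, which avoids the unimplemented bf16 AllReduce. The pmean_f32 helper below is hypothetical:

import jax
import jax.numpy as jnp

def pmean_f32(tree, axis_name):
  # Hypothetical helper: all-reduce each leaf in float32, then cast back
  # to the leaf's original dtype (bfloat16 here).
  return jax.tree_map(
      lambda x: jax.lax.pmean(x.astype(jnp.float32), axis_name).astype(x.dtype),
      tree)

# Inside update(), the two reductions could then become:
#   grads = pmean_f32(grads, 'num_devices')
#   loss = pmean_f32(loss, 'num_devices')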
bf16 collectives are supported in XLA on all GPUs since tensorflow/tensorflow@43c7026.
On Ampere cards bf16 GeMM is supported since tensorflow/tensorflow@9d1970a and Conv since tensorflow/tensorflow@4731fe5.
I’d suggest installing the latest jaxlib (0.1.70) to pick up these XLA changes and trying again.

Looks like this is already addressed in the latest jaxlib, so I’m going to close the issue. Thanks for the report!
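For reference, after upgrading, the fix can be confirmed by checking the installed jaxlib version and rerunning the minimal sketch from the top of the issue (the wheel tag in the comment below is an assumption for a CUDA 11.1 setup; adjust as needed):

# Assumed install command for a CUDA 11.1 build:
#   pip install -U jaxlib==0.1.70+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
import jaxlib
print(jaxlib.__version__)  # expect 0.1.70 or newer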