bfloat16 pmap cuda11.1 results in error

I received the following error when using bfloat16 with pmap on jaxlib 0.1.69 + cuda111 (there is no issue without pmap):

RuntimeError: Unimplemented: Requested AllReduce not implemented on GPU; replica_count: 2; partition_count: 1, group_mode: kCrossReplica, operand_count: 2; NCCL support: 1; first operand array element-type: BF16
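The message refers to the cross-replica all-reduce behind `jax.lax.pmean`, so the collective can probably be exercised on its own. Below is a minimal sketch, assuming a bare `pmean` on a bfloat16 array is enough to trigger the same failure. The helper names `pmean_across_devices` and `check_dtype` are hypothetical, and the snippet has not been run on the affected jaxlib build.

```python
import jax
import jax.numpy as jnp


def pmean_across_devices(dtype):
    """Runs a single `pmean` over all local devices on an array of `dtype`."""
    n_devices = jax.local_device_count()
    # One scalar per device; pmap maps over the leading axis.
    xs = jnp.arange(n_devices, dtype=jnp.float32).astype(dtype)
    mean_fn = jax.pmap(lambda x: jax.lax.pmean(x, axis_name="i"), axis_name="i")
    return mean_fn(xs)


def check_dtype(dtype):
    """Returns (succeeded, error_message) for a pmean on the given dtype."""
    try:
        pmean_across_devices(dtype).block_until_ready()
        return True, ""
    except RuntimeError as err:  # the error reported above is a RuntimeError
        return False, str(err)


for dtype in (jnp.float32, jnp.bfloat16):
    ok, message = check_dtype(dtype)
    status = "ok" if ok else f"failed: {message}"
    print(f"pmean with {jnp.dtype(dtype).name}: {status}")
```

If only the bfloat16 line reports a failure, the problem is in the collective itself and not in the rest of the training code.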

The bug(?) can be reproduced as follows (code copied from "Parallel Evaluation in JAX", included in the reference documentation):

import jax
import numpy as np
import jax.numpy as jnp
jax.devices()

from typing import NamedTuple, Tuple
import functools

class Params(NamedTuple):
    weight: jnp.ndarray
    bias: jnp.ndarray


def init(rng) -> Params:
    """Returns the initial model params."""
    weights_key, bias_key = jax.random.split(rng)
    weight = jax.random.normal(weights_key, (), dtype=D_TYPE)
    bias = jax.random.normal(bias_key, (), dtype=D_TYPE)
    return Params(weight, bias)


def loss_fn(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> jnp.ndarray:
    """Computes the least squares error of the model's predictions on x against y."""
    pred = params.weight * xs + params.bias
    return jnp.mean((pred - ys) ** 2)

LEARNING_RATE = 0.005

# So far, the code is identical to the single-device case. Here's what's new:


# Remember that the `axis_name` is just an arbitrary string label used
# to later tell `jax.lax.pmean` which axis to reduce over. Here, we call it
# 'num_devices', but could have used anything, so long as `pmean` used the same.
@functools.partial(jax.pmap, axis_name='num_devices')
def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Tuple[Params, jnp.ndarray]:
    """Performs one SGD update step on params using the given data."""

    # Compute the gradients on the given minibatch (individually on each device).
    loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)

    # Combine the gradient across all devices (by taking their mean).
    grads = jax.lax.pmean(grads, axis_name='num_devices')

    # Also combine the loss. Unnecessary for the update, but useful for logging.
    loss = jax.lax.pmean(loss, axis_name='num_devices')

    # Each device performs its own update, but since we start with the same params
    # and synchronise gradients, the params stay in sync.
    new_params = jax.tree_multimap(
        lambda param, g: param - g * LEARNING_RATE, params, grads)

    return new_params, loss

# Generate true data from y = w*x + b + noise
true_w, true_b = 2, -1
xs = np.random.normal(size=(128, 1))
noise = 0.5 * np.random.normal(size=(128, 1))
ys = xs * true_w + true_b + noise

# Initialise parameters and replicate across devices.
D_TYPE = jnp.bfloat16
params = init(jax.random.PRNGKey(123))
n_devices = jax.local_device_count()
replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params)

def split(arr):
    """Splits the first axis of `arr` evenly across the number of devices."""
    return arr.reshape(n_devices, arr.shape[0] // n_devices, *arr.shape[1:])

# Reshape xs and ys for the pmapped `update()`.
x_split = split(xs)
y_split = split(ys)

type(x_split)

def type_after_update(name, obj):
    print(f"after first `update()`, `{name}` is a", type(obj))

# Actual training loop.
for i in range(1000):

    # This is where the params and data get communicated to devices:
    replicated_params, loss = update(replicated_params, x_split, y_split)

    # The returned `replicated_params` and `loss` are now both ShardedDeviceArrays,
    # indicating that they're on the devices.
    # `x_split`, of course, remains a NumPy array on the host.
    if i == 0:
        type_after_update('replicated_params.weight', replicated_params.weight)
        type_after_update('loss', loss)
        type_after_update('x_split', x_split)

    if i % 100 == 0:
        # Note that loss is actually an array of shape [num_devices], with identical
        # entries, because each device returns its copy of the loss.
        # So, we take the first element to print it.
        print(f"Step {i:3d}, loss: {loss[0]:.3f}")


# Plot results.

# Like the loss, the leaves of params have an extra leading dimension,
# so we take the params from the first device.
params = jax.device_get(jax.tree_map(lambda x: x[0], replicated_params))

Setting D_TYPE to jnp.float32 eliminates the error.
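Since float32 works, one possible stopgap is to keep the parameters in bfloat16 and cast only the values passed to the collective. This is a sketch under that assumption, not something verified on jaxlib 0.1.69. `pmean_in_float32` is a hypothetical helper, not a JAX API.

```python
import functools

import jax
import jax.numpy as jnp


def pmean_in_float32(tree, axis_name):
    """Averages every leaf across devices in float32, then restores its dtype."""
    def reduce_leaf(leaf):
        # Upcast so the all-reduce never sees a bfloat16 operand.
        reduced = jax.lax.pmean(leaf.astype(jnp.float32), axis_name=axis_name)
        return reduced.astype(leaf.dtype)
    return jax.tree_util.tree_map(reduce_leaf, tree)


@functools.partial(jax.pmap, axis_name="num_devices")
def averaged(grads):
    return pmean_in_float32(grads, "num_devices")


n_devices = jax.local_device_count()
# Stand-in gradients with one bfloat16 entry per device.
grads = {
    "weight": jnp.ones((n_devices,), dtype=jnp.bfloat16),
    "bias": jnp.zeros((n_devices,), dtype=jnp.bfloat16),
}
out = averaged(grads)
print(out["weight"].dtype, out["bias"].dtype)
```

In the `update()` function above, this would replace the two direct `jax.lax.pmean` calls. The cost is that the values travel between devices at float32 width.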

Thanks for your help.

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 6 (3 by maintainers)

Top GitHub Comments

tomhennigan commented, Aug 10, 2021 (2 reactions)

bf16 collectives are supported in XLA on all GPUs since tensorflow/tensorflow@43c7026.

On Ampere cards bf16 GeMM is supported since tensorflow/tensorflow@9d1970a and Conv since tensorflow/tensorflow@4731fe5.

I’d suggest installing the latest jaxlib (0.1.70) to pick up these XLA changes and trying again.
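To confirm which jaxlib an environment actually has before retrying, a check along these lines may help. It is a minimal sketch: the 0.1.70 threshold comes from the suggestion above, while `parse_release` and `has_suggested_jaxlib` are hypothetical helper names, and the parsing only handles simple dotted version strings.

```python
import re
from importlib import metadata

SUGGESTED_JAXLIB = (0, 1, 70)  # version suggested in the comment above


def parse_release(version):
    """Turns '0.1.69+cuda111' into (0, 1, 69), ignoring any local or dev suffix."""
    public = version.split("+", 1)[0]
    parts = []
    for piece in public.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def has_suggested_jaxlib(version):
    """True if `version` is at least the suggested jaxlib release."""
    return parse_release(version) >= SUGGESTED_JAXLIB


try:
    installed = metadata.version("jaxlib")
except metadata.PackageNotFoundError:
    installed = None

if installed is None:
    print("jaxlib is not installed in this environment")
else:
    print(f"jaxlib {installed}: suggested version or newer = "
          f"{has_suggested_jaxlib(installed)}")
```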

jakevdp commented, Aug 17, 2021 (0 reactions)

Looks like this is already addressed in the latest jaxlib, so I’m going to close the issue. Thanks for the report!
