How to make ``scatter`` (Just CUDA) results repeatable.

See original GitHub issue

❓ Questions & Help

I am trying to reproduce my results, but scatter from torch_scatter gives unstable results on CUDA, even with a fixed random seed.

Because scatter is used inside the MessagePassing class, I thought this was worth raising.

Or did I make a mistake or overlook something?

My test results are below.

I would appreciate any help or ideas.


import os
import random

import numpy as np
import torch
from torch import nn
from torch.backends import cudnn
from torch.nn import Linear, Softplus
from torch_scatter import scatter


class ReadOutLayer(nn.Module):
    """Merge node layer."""

    def __init__(self, num_filters, out_size=1, readout="add", temp_to_cpu=True):
        super(ReadOutLayer, self).__init__()
        self.readout = readout
        self.lin1 = Linear(num_filters, num_filters * 5)
        self.s1 = Softplus()
        self.lin2 = Linear(num_filters * 5, num_filters)
        self.s2 = Softplus()
        self.lin3 = Linear(num_filters, out_size)
        self.temp_to_cpu = temp_to_cpu

    def forward(self, h, batch):
        h = self.lin1(h)
        h = self.s1(h)
        h = self.jump(h, batch)
        h = self.lin2(h)
        h = self.s2(h)
        h = self.lin3(h)
        return h

    def jump(self, h, batch):
        if self.temp_to_cpu:
            # Workaround: torch_scatter's scatter is not repeatable on CUDA (most visible on small data),
            # so run it on the CPU and move the result back to the original device.
            old_device = h.device
            device = torch.device("cpu")
            h = h.to(device=device)
            batch = batch.to(device=device)
            h = scatter(h, batch, dim=0, reduce=self.readout)
            h = h.to(device=old_device)
        else:
            h = scatter(h, batch, dim=0, reduce=self.readout)

        return h


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    # torch.backends.cudnn.enabled = False
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(1)

# Build synthetic data: x (node features), y (per-graph targets), batch (graph index of each node)
x = torch.rand((1000, 100), requires_grad=True)
y = torch.rand((100, 1), requires_grad=True)
# Draw 100 sorted cut points; every node from cut point i onward is assigned to graph n
batch_mark = torch.randint(low=0, high=1000, size=(100,))
batch_mark = torch.sort(batch_mark).values
batch = torch.zeros((1000,))
for n, i in enumerate(batch_mark):
    batch[i:] = n

batch = batch.to(torch.int64)


# Train the model in one of three device configurations and return the final loss

def scatter_check(x, y, batch, test):

    if test == "just cpu":
        temp_to_cpu = False
        device = torch.device("cpu")

    elif test == "cuda with cpu scatter":
        temp_to_cpu = True
        device = torch.device("cuda:0")

    elif test == "cuda":
        temp_to_cpu = False
        device = torch.device("cuda:0")
    else:
        raise NotImplementedError(f"unknown test mode: {test}")

    model = ReadOutLayer(100, temp_to_cpu=temp_to_cpu)
    x = x.to(device)
    y = y.to(device)
    batch = batch.to(device)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01)
    loss_method = torch.nn.MSELoss()

    loss_ir = 0

    for i in range(300):
        p_y = model(x, batch)

        lossi = loss_method(p_y, y)
        print(i, lossi.item())
        loss_ir = lossi.item()

        lossi.backward()
        optimizer.step()
        optimizer.zero_grad()
    return loss_ir

set_seed(1)
a = scatter_check(x, y, batch, test="just cpu")
set_seed(1)
b = scatter_check(x, y, batch, test="just cpu")
# CPU only: the two runs match.
assert a == b

set_seed(1)
a = scatter_check(x, y, batch, test="cuda with cpu scatter")
set_seed(1)
b = scatter_check(x, y, batch, test="cuda with cpu scatter")
# CUDA model, but ``scatter`` runs on the CPU: the two runs match.
assert a == b

set_seed(1)
a = scatter_check(x, y, batch, test="cuda")
set_seed(1)
b = scatter_check(x, y, batch, test="cuda")
# CUDA model with ``scatter`` on the GPU: the two runs differ.
assert a != b

Issue Analytics

  • State: open
  • Created 2 years ago
  • Reactions: 4
  • Comments: 6 (3 by maintainers)

Top GitHub Comments

rusty1s commented, Jul 2, 2021 (5 reactions)

Scatter is non-deterministic by design: it uses atomic operations, so the order of aggregation can change from run to run, which leads to minor numerical differences. As an alternative, you can use the segment_csr operation of torch_scatter (see the torch_scatter documentation).
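As a rough illustration of that alternative, the sketch below swaps the scatter call for segment_csr. It assumes the batch vector is sorted (as it is in the test script above), since segment_csr reduces over contiguous ranges given by a CSR-style pointer. The helper names batch_to_ptr and deterministic_readout are hypothetical and not part of torch_scatter.

```python
import torch


def batch_to_ptr(batch: torch.Tensor, num_graphs: int) -> torch.Tensor:
    """Convert a SORTED batch vector into CSR boundaries of length num_graphs + 1.

    ptr[k] is the index of the first node whose graph id is >= k, so the nodes
    of graph k are the rows ptr[k]:ptr[k + 1].
    """
    graph_ids = torch.arange(num_graphs + 1, device=batch.device)
    return torch.searchsorted(batch, graph_ids)


def deterministic_readout(h, batch, num_graphs, reduce="sum"):
    """Hypothetical drop-in for scatter(h, batch, dim=0, reduce=...)."""
    # Imported here so that batch_to_ptr can be used without torch_scatter installed.
    from torch_scatter import segment_csr

    ptr = batch_to_ptr(batch, num_graphs)
    # With a 1-D ptr, segment_csr reduces along dim 0 of h.
    return segment_csr(h, ptr, reduce=reduce)


# Small check of the pointer construction (graph 2 has no nodes).
example_batch = torch.tensor([0, 0, 1, 1, 1, 3])
print(batch_to_ptr(example_batch, 4).tolist())  # [0, 2, 5, 5, 6]
```

In the ReadOutLayer above, this would mean calling deterministic_readout(h, batch, num_graphs=100) inside jump instead of scatter. The number of graphs has to be passed in explicitly here, which is a design choice of this sketch.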

For message passing layers, deterministic aggregation is only guaranteed when using SparseTensor.

In the end, I wouldn’t worry too much about it. In a deep learning scenario, such numerical instabilities should only be noticeable on really small datasets. Although it is correct that exact reproducibility is no longer guaranteed when using non-deterministic operations, we can only enforce reproducibility for a single permutation (which does not exist in the context of graphs).
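If exact equality is not required, one option is to compare repeated runs within a tolerance rather than with ==. The sketch below is only an illustration: the helper name losses_match, the tolerance values, and the loss values are all assumptions and would need tuning for a real model, since small differences can grow over many training steps.

```python
import math


def losses_match(loss_a: float, loss_b: float,
                 rel_tol: float = 1e-4, abs_tol: float = 1e-6) -> bool:
    """Return True if two final losses agree up to the given tolerances."""
    return math.isclose(loss_a, loss_b, rel_tol=rel_tol, abs_tol=abs_tol)


# Hypothetical final losses from two runs with the same seed.
print(losses_match(0.083512, 0.083513))  # True: differs only in the sixth decimal
print(losses_match(0.08, 0.09))          # False: a real divergence
```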

rusty1s commented, Jul 19, 2022 (1 reaction)

Yes, this is due to how floating-point precision works. In case the ordering of operations is not deterministic internally, you may get slightly different outputs, e.g., (1 + 2) + 3 may be different from 1 + (2 + 3).
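The effect can be reproduced in plain Python, without any GPU. The snippet below is a minimal sketch using ordinary double-precision floats; ordered_sum is a hypothetical helper that simply adds values left to right, standing in for one particular aggregation order.

```python
def ordered_sum(values):
    """Add the values strictly left to right, like one fixed aggregation order."""
    total = 0.0
    for v in values:
        total += v
    return total


# Same three numbers, different grouping, different rounding.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print(left == right)  # False
print(left, right)    # 0.6000000000000001 0.6

# Same numbers, different order: the small term is lost when added to 1e16 first.
print(ordered_sum([1e16, 1.0, -1e16]))  # 0.0
print(ordered_sum([1e16, -1e16, 1.0]))  # 1.0
```

With integers such as 1, 2 and 3 both groupings give exactly 6.0, so the difference only shows up for values whose sums need rounding.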
