Proper accuracy calculation on multi-GPU DistributedDataParallel with unevenly divisible data size
How do we calculate a proper DiceMetric accuracy (or any other metric) on multi-GPU DistributedDataParallel when the dataset is not evenly divisible by the number of GPUs?
For example, on an 8-GPU machine the last iteration of an epoch may have only 5 images left, so 3 GPUs will be idle. We therefore need to record results only from the 5 GPUs that received data and ignore the other 3. MONAI has a subclass of DistributedSampler that seems to account for non-even division via the even_divisible parameter, but if I set it to False, DistributedDataParallel freezes because some GPUs now receive no data.
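To make the split concrete, here is a small sketch in plain Python (not MONAI internals; the interleaved assignment order is an assumption) of the per-rank sample counts that the two even_divisible settings imply for 13 images on 8 GPUs:

```python
num_samples, world_size = 13, 8

# even_divisible=True: the sampler pads to the next multiple of world_size,
# so every rank sees the same number of samples (3 of them are duplicates)
padded_total = -(-num_samples // world_size) * world_size  # ceil -> 16
per_rank_even = padded_total // world_size                 # 2 samples on every rank

# even_divisible=False: no padding, so the trailing ranks get one sample fewer
per_rank_uneven = [
    num_samples // world_size + (1 if r < num_samples % world_size else 0)
    for r in range(world_size)
]                                                          # [2, 2, 2, 2, 2, 1, 1, 1]
print(per_rank_even, per_rank_uneven)
```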
Based on the tutorial code from here, the script below creates a dataset of 13 images, and when run on 8 GPUs it freezes.
It freezes at any sync between GPUs; in this case the line current_accuracy = dice_metric.aggregate().item() (added to demonstrate the issue) is the problem, but any torch.distributed.barrier() would also freeze it. DDP expects the same code path on all processes, and with DistributedSampler(even_divisible=False) the ranks take different paths in the last iteration. The code only works if no sync is called at all, which is not convenient, since we need to calculate intermediate results or make other reduction calls.
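One generic pattern that sidesteps the per-iteration collectives (plain torch.distributed, not a MONAI API; the function name and calling convention below are illustrative only) is to accumulate a local sum and count on each rank and issue a single all_reduce after the loop; because every rank reaches that one collective exactly once, the result stays correct even when ranks hold different numbers of samples:

```python
import torch
import torch.distributed as dist


def global_mean_dice(local_scores: torch.Tensor, device: torch.device) -> float:
    """Combine per-sample dice scores from all ranks into one global mean.

    local_scores may have a different length on every rank (even zero);
    reducing the sum and the count separately keeps the mean exact for
    unevenly split datasets. NaN entries (e.g. from empty classes) should
    be filtered out by the caller beforehand.
    """
    stats = torch.tensor(
        [local_scores.sum(), local_scores.numel()], dtype=torch.float64, device=device
    )
    dist.all_reduce(stats, op=dist.ReduceOp.SUM)  # every rank must reach this exactly once
    return (stats[0] / stats[1]).item()
```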
Here is the example code:
import argparse
import os
from glob import glob
import nibabel as nib
import numpy as np
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import monai
from monai.data import DataLoader, Dataset, create_test_image_3d, DistributedSampler, decollate_batch
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import Activations, AsChannelFirstd, AsDiscrete, Compose, LoadImaged, ScaleIntensityd, EnsureTyped, EnsureType
def evaluate(args):
    if args.local_rank == 0 and not os.path.exists(args.dir):
        # create 16 random image/mask pairs for evaluation
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate the same random data on every node
        np.random.seed(seed=0)
        for i in range(16):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))
    # initialize the distributed evaluation process, every GPU runs in a separate process
    dist.init_process_group(backend="nccl", init_method="env://")

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )

    # create an evaluation data loader; truncate to 13 images so the dataset does not divide evenly across 8 GPUs
    val_files = val_files[:13]
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create an evaluation data sampler
    val_sampler = DistributedSampler(dataset=val_ds, even_divisible=False, shuffle=False)
    # sliding window inference needs to receive 1 image in every iteration
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=True, sampler=val_sampler)
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

    # create the UNet model on the GPU assigned to this process
    # print(args)
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    # wrap the model with the DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[device])
    # config mapping to the expected GPU device
    # map_location = {"cuda:0": f"cuda:{args.local_rank}"}
    # load model parameters to the GPU device
    # model.load_state_dict(torch.load("final_model.pth", map_location=map_location))

    model.eval()
with torch.no_grad():
i=0
for val_data in val_loader:
val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
# define sliding window size and batch size for windows inference
roi_size = (96, 96, 96)
sw_batch_size = 4
val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
d= dice_metric(y_pred=val_outputs, y=val_labels)
current_accuracy = dice_metric.aggregate().item() # this line requires sync across gpus, and will freeze the system
print(args.local_rank, 'iter', i, 'dice', d, current_accuracy, "\n")
i = i + 1
print(args.local_rank, 'aggregate', len(dice_metric.get_buffer()))
metric = dice_metric.aggregate().item()
dice_metric.reset()
if dist.get_rank() == 0:
print("evaluation metric:", metric)
dist.destroy_process_group()
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--dir", default="./testdata", type=str, help="directory to create random data")
    # must parse the command-line argument: ``--local_rank=LOCAL_PROCESS_RANK``, which will be provided by DDP
    parser.add_argument("--local_rank", type=int)
    args = parser.parse_args()
    evaluate(args=args)


# usage example (refer to https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py):
# python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=12345 unet_evaluation_ddp.py -d ./tmp

if __name__ == "__main__":
    main()
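For what it's worth, a possible rearrangement of the loop above (a sketch, not an official fix; the local_sum / local_count bookkeeping is a hypothetical addition) keeps per-iteration feedback rank-local, since the per-batch value returned by dice_metric(...) involves no collective, and calls aggregate() exactly once after the loop, where every rank is guaranteed to arrive:

```python
local_sum, local_count = 0.0, 0
with torch.no_grad():
    for i, val_data in enumerate(val_loader):
        val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
        val_outputs = sliding_window_inference(val_images, (96, 96, 96), 4, model)
        val_outputs = [post_trans(o) for o in decollate_batch(val_outputs)]
        # per-batch scores: computed locally, no cross-GPU communication
        d = dice_metric(y_pred=val_outputs, y=val_labels)
        local_sum += d.nansum().item()
        local_count += d.numel()
        print(args.local_rank, "iter", i, "running local mean dice", local_sum / max(local_count, 1))
    # the only collective call: every rank reaches it once, after its own (possibly shorter) loop
    metric = dice_metric.aggregate().item()
    dice_metric.reset()
if dist.get_rank() == 0:
    print("evaluation metric:", metric)
```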
Issue Analytics
- Created a year ago
- Comments: 14 (9 by maintainers)
Top GitHub Comments
Yes, that's what I mean too. Many real test datasets have an uneven length, and we need to calculate the correct accuracy on multi-GPU. It seems at the moment we are able to do it, but we can't get intermediate results to print, and the validation loop gives no metric feedback to the user... so perhaps we can come up with a better way. I'm not sure what that way is, but it should be a common problem in many domains (not only medical imaging).
Agreed @ericspod - validation does need to go that extra distance to ensure precision (unless we mandate all test sets to be of size 5040 😃 )
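Following up on the precision point: one possible direction (a sketch, not an agreed MONAI solution) is to keep even_divisible=True, so every rank runs an identical number of iterations and every collective matches, and then drop the padding duplicates before the final mean. It assumes each rank also records the dataset index of every sample it scored, which the current tutorial does not do; dedup_mean_dice and its arguments are hypothetical names:

```python
import torch
import torch.distributed as dist


def dedup_mean_dice(local_scores: torch.Tensor, local_indices: torch.Tensor, dataset_len: int) -> float:
    """Gather per-sample scores and their dataset indices from every rank, then
    average one score per unique sample so the padding duplicates do not bias
    the mean. Requires even_divisible=True (equal tensor sizes on all ranks)
    and, with the NCCL backend, tensors that live on the GPU.
    """
    world_size = dist.get_world_size()
    score_parts = [torch.empty_like(local_scores) for _ in range(world_size)]
    index_parts = [torch.empty_like(local_indices) for _ in range(world_size)]
    dist.all_gather(score_parts, local_scores)
    dist.all_gather(index_parts, local_indices)
    scores, indices = torch.cat(score_parts), torch.cat(index_parts)
    # one slot per dataset sample; duplicates overwrite each other with the same value
    unique = torch.zeros(dataset_len, dtype=scores.dtype, device=scores.device)
    unique[indices] = scores
    return unique.mean().item()
```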