Handling empty datasets in distributed metric computation
🐛 Bug description
Metric computation does not work properly in distributed settings when some processes do not handle any batch in the dataset. It becomes a problem when small validation or test datasets are distributed to processes in an imbalanced manner.
How to Reproduce
Create a Python script named `main.py` with the following content.
```python
import torch
import ignite.distributed as idist
from torch.utils.data import IterableDataset, DataLoader
from ignite.metrics import Loss
from ignite.engine.engine import Engine
from ignite.engine.events import Events


class SampleDataset(IterableDataset):
    def __iter__(self):
        # Only rank 0 yields a batch; all other ranks see an empty dataset.
        if idist.get_rank() == 0:
            yield torch.zeros((2, 3)), torch.ones((2, 3))


def report_metrics(engine):
    print(engine.state.metrics)


def test(local_rank):
    data_loader = DataLoader(SampleDataset(), batch_size=None)
    engine = Engine(lambda _engine, batch: batch)
    Loss(torch.nn.BCELoss(reduction="mean")).attach(engine, "loss")
    engine.add_event_handler(Events.COMPLETED, report_metrics)
    engine.run(data_loader)


with idist.Parallel(backend="gloo") as parallel:
    parallel.run(test)
```
Run the following command inside a CPU Docker container with PyTorch and Ignite installed.

```
python -m torch.distributed.launch --nproc_per_node=2 --use_env main.py
```
Problem 1
The command terminates with an error. Part of the output is shown below.

```
terminate called after throwing an instance of 'gloo::EnforceNotMet'
what(): [enforce fail at /opt/conda/conda-bld/pytorch_1595629403081/work/third_party/gloo/gloo/transport/tcp/pair.cc:490] op.preamble.length <= op.nbytes. 8 vs 4
```
It seems there is a type inconsistency (`int` vs `float`) inside `idist.all_reduce()` when calling `compute()`, because not all processes have called `update()` at least once. A simple fix could be changing this line to `self._sum = 0.0`.
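The failure mode can be illustrated with a toy, pure-Python simulation (this is not the gloo or Ignite code; `encode` and `all_reduce` are hypothetical stand-ins): gloo requires every rank to contribute a buffer of the same byte size, so a rank that never called `update()` and still holds the integer default would send an 8-byte int64 buffer while the others send a 4-byte float32 buffer, matching the "8 vs 4" error above.

```python
import struct


def encode(value):
    """Toy serialization of a rank's accumulator:
    Python int -> 8-byte int64 buffer, float -> 4-byte float32 buffer."""
    if isinstance(value, int):
        return struct.pack("<q", value)  # 8 bytes
    return struct.pack("<f", value)      # 4 bytes


def all_reduce(per_rank_values):
    """Toy all_reduce: fails like gloo when buffer sizes disagree."""
    buffers = [encode(v) for v in per_rank_values]
    sizes = {len(b) for b in buffers}
    if len(sizes) != 1:
        raise RuntimeError(
            f"op.preamble.length <= op.nbytes. {max(sizes)} vs {min(sizes)}"
        )
    return sum(per_rank_values)


# Rank 0 saw data and accumulated a float; rank 1 saw nothing and still
# holds the integer default from reset() (self._sum = 0).
try:
    all_reduce([1.5, 0])
except RuntimeError as e:
    print("mixed types:", e)  # 8 vs 4

# With the proposed fix (self._sum = 0.0 in reset), all ranks hold floats.
print("fixed:", all_reduce([1.5, 0.0]))
```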
However, this issue could affect other metrics as well. We probably need unit tests covering this scenario for all metrics.
Problem 2
In the above script, if we change `Loss(...)` to a precision or recall metric (e.g. `Precision()`), we get the following error message.

```
Engine run is terminating due to exception: Precision must have at least one example before it can be computed..
```
The issue is that the verification should actually be moved after `idist.all_reduce()`. Although some processes may have seen an empty dataset, the metric is still valid collectively.
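The proposed reordering can be sketched as follows (toy stand-ins, not the Ignite implementation; `all_reduce` here is a plain-Python placeholder for `idist.all_reduce()`):

```python
def all_reduce(per_rank_values):
    # Stand-in for the distributed sum across processes.
    return sum(per_rank_values)


def compute_precision(per_rank_tp, per_rank_positives):
    total_tp = all_reduce(per_rank_tp)
    total_positives = all_reduce(per_rank_positives)
    # Validate AFTER the reduction: only fail if no rank saw any example,
    # instead of aborting on any rank whose local dataset was empty.
    if total_positives == 0:
        raise RuntimeError(
            "Precision must have at least one example before it can be computed."
        )
    return total_tp / total_positives


# Rank 1 saw no data (contributes 0, 0), yet the collective result is valid.
print(compute_precision([3, 0], [4, 0]))  # 0.75
```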
Problem 3
After fixing Problem 2, there is still an issue with multi-label precision or recall. For example, changing `Loss(...)` to `Precision(is_multilabel=True, average=True)` and running the script gives the following error:

```
Engine run is terminating due to exception: 'float' object has no attribute 'mean'.
```
The issue is with this line. Because not all processes have called `update()` at least once, there is again a type inconsistency: in some processes `self._true_positives` is of type `float`, while in other processes it is a scalar tensor.
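One possible shape of a fix, sketched with a minimal tensor stand-in (`FakeTensor` and `compute` are hypothetical, not Ignite code): coerce the plain-float default into the tensor-like type before calling `.mean()`, so ranks that saw no batches don't raise `AttributeError`.

```python
class FakeTensor:
    """Minimal tensor stand-in exposing only the .mean() used here."""

    def __init__(self, values):
        self.values = list(values)

    def mean(self):
        return sum(self.values) / len(self.values)


def compute(true_positives):
    # A rank that saw no batches still holds the plain-float default 0
    # from reset(); coerce it so .mean() is always available.
    if not isinstance(true_positives, FakeTensor):
        true_positives = FakeTensor([true_positives])
    return true_positives.mean()


print(compute(FakeTensor([1.0, 0.5])))  # rank that saw data -> 0.75
print(compute(0))                       # empty rank: no AttributeError -> 0.0
```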
Environment
- PyTorch Version: 1.6.0
- Ignite Version: 0.4.1
- OS: Linux
- How you installed Ignite (`conda`, `pip`, source): `pip`
- Python version: 3.7.7
- Any other relevant information: N/A
Issue Analytics
- Created: 3 years ago
- Comments: 8 (1 by maintainers)
Top GitHub Comments
@linhr thanks for the report! Let me reproduce and investigate the issue.
cc @n2cholas as we are working on metrics right now, we may have to take this into account
Since v1.7.0, PyTorch seems to support uneven inputs across participating processes: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel.join