Sampling patches from multiple subject datasets
🚀 Feature

Suppose I have two (or more) SubjectsDatasets: one (dataset A) with scans from hospital A and another (dataset B) with scans from hospital B. I want my model to see as many patches from dataset A as from dataset B. However, dataset A contains 500 subjects, while dataset B contains only 10.
In TensorFlow there is a method sample_from_datasets; for this example it would look like this:
import tensorflow as tf

# weights are per-dataset sampling probabilities, so equal weights
# give a 50/50 mix regardless of the dataset sizes
dataset = tf.data.experimental.sample_from_datasets(
    [dataset_A, dataset_B], weights=[0.5, 0.5], seed=None
)
This new dataset is then what could be passed to the TorchIO Queue.
In PyTorch this looks a bit different, but the result is the same:
import numpy as np
from torch.utils.data import ConcatDataset, DataLoader, WeightedRandomSampler

sets = [dataset_A, dataset_B]
dataset = ConcatDataset(sets)
# Weight each sample by the fraction of data outside its own dataset,
# so samples from the smaller dataset are drawn more often
dist = np.concatenate([[(len(dataset) - len(s)) / len(dataset)] * len(s) for s in sets])
# Draw (smallest dataset size) * (number of datasets) samples per epoch
sampler = WeightedRandomSampler(weights=dist, num_samples=min(len(s) for s in sets) * len(sets))
dataloader = DataLoader(dataset, sampler=sampler)
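For reference, this is roughly how a dataset is consumed by the TorchIO Queue today (the parameter values below are made up, and whether the Queue accepts a plain ConcatDataset rather than a SubjectsDataset is part of what this issue asks). Note there is currently no hook for a subject-level sampler, only a patch sampler:

import torch
import torchio as tio

patch_sampler = tio.UniformSampler(patch_size=64)  # samples patches, not subjects
queue = tio.Queue(
    dataset,                 # the dataset built above
    max_length=300,          # made-up value
    samples_per_volume=10,   # made-up value
    sampler=patch_sampler,
)
# num_workers must be 0 here; the Queue uses its own workers internally
loader = torch.utils.data.DataLoader(queue, batch_size=4, num_workers=0)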
What would be the best way to integrate this feature into TorchIO?
- being able to also pass a sampler to the Queue?
- having a SampledSubjectsDataset(datasets: List[SubjectsDataset], weights: List[float]) that can be passed to the Queue instead of a SubjectsDataset? (A rough sketch of this option follows the list.)
- how many subjects should be sampled in one epoch when the datasets are unbalanced? The length of the smallest dataset times the number of datasets?
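As a starting point for the second option, here is a minimal sketch of what such a SampledSubjectsDataset could look like. Only the class name and the epoch-length convention come from the suggestions above; everything else is an assumption:

import random
from typing import List

from torch.utils.data import Dataset


class SampledSubjectsDataset(Dataset):
    """Draws subjects from several datasets according to per-dataset weights."""

    def __init__(self, datasets: List[Dataset], weights: List[float]):
        self.datasets = list(datasets)
        self.weights = list(weights)
        # One epoch = size of the smallest dataset times the number of datasets
        self._length = min(len(d) for d in self.datasets) * len(self.datasets)

    def __len__(self) -> int:
        return self._length

    def __getitem__(self, index: int):
        # The index is ignored: first pick a dataset according to the
        # weights, then pick a random subject from it
        dataset, = random.choices(self.datasets, weights=self.weights, k=1)
        return dataset[random.randrange(len(dataset))]

With weights=[1, 1], a 500-subject and a 10-subject dataset would then contribute patches in equal proportion.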
Motivation
Training with balanced datasets is known to be important for getting good results.
Alternatives
An alternative is to deep-copy dataset B 50 times (before the images are loaded) to balance the data. However, that is not a very elegant solution.
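Concretely, that workaround could look like this (subjects_a and subjects_b are hypothetical names for the subject lists behind the two datasets):

import copy

import torchio as tio

# Duplicate the 10 subjects of dataset B 50 times; copying is cheap
# because only paths and metadata are stored, not the images themselves
balanced_subjects = subjects_a + [
    copy.deepcopy(subject) for subject in subjects_b for _ in range(50)
]
balanced_dataset = tio.SubjectsDataset(balanced_subjects)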
Top GitHub Comments
@dmus beat me. The thing is that preprocessed images are not stored in the dataset, because the datasets are basically lists of paths with methods to load, transform, and return the images.
I don’t think there is a need in this case. There is a need in the __getitem__ method of SubjectsDataset: there, a deep copy is made to avoid running preprocessing on an already preprocessed subject.
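Simplified, that part of SubjectsDataset looks roughly like this (paraphrased, not the exact source):

import copy

import torch

class SubjectsDataset(torch.utils.data.Dataset):
    ...
    def __getitem__(self, index):
        subject = self._subjects[index]
        # Cheap: the subject holds only paths and metadata at this point
        subject = copy.deepcopy(subject)
        subject.load()
        if self._transform is not None:
            subject = self._transform(subject)
        return subject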