
Variable epoch_length for different epochs

See original GitHub issue

❓ Questions/Help/Support

Hi,

I’m working with a model that increases in complexity during training. To avoid memory issues, I reduce the batch size accordingly at each epoch. This means that, for a fixed length of the dataset, the number of iterations per epoch increases each epoch.

Something like this:


batch_size_per_epoch = [16, 8, 4, 2]
dataset = ImageDataset(...)

loaders = (DataLoader(dataset, batch_size=bs, shuffle=True) for bs in batch_size_per_epoch)

# Then I can run the engine either with
for i, loader in enumerate(loaders):
    engine.run(loader, max_epochs=i+1)
# or by calling engine.set_data in a properly defined event handler.

The problem is that engine.state.epoch_length is set once for the first loader and the subsequent loaders run as many iterations as the first one. Setting engine.state.epoch_length by hand is not only ugly, but also messes up the saving/loading of the engine (epoch and iterations are inferred assuming a fixed epoch length).
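
(For concreteness, this is the kind of manual override I mean, continuing from the snippet above; the handler name and structure are just a sketch.)

from ignite.engine import Events

@engine.on(Events.EPOCH_COMPLETED)
def switch_loader(engine):
    # Swap in the loader for the next epoch and override the epoch length
    # by hand -- it runs, but checkpoint save/load then reconstructs epoch
    # and iteration assuming this epoch_length is fixed.
    next_loader = next(loaders)
    engine.set_data(next_loader)
    engine.state.epoch_length = len(next_loader)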

Is there any way to use variable epoch lengths or variable batch sizes with ignite? I’ve been thinking of building a new engine for each epoch, but keeping the state from previous engines, saving, loading and reusing the loggers/metrics/handlers is rather messy. Is there an alternative?

Best

Issue Analytics

  • State: open
  • Created 3 years ago
  • Comments: 6 (2 by maintainers)

Top GitHub Comments

2 reactions
vfdev-5 commented, Feb 12, 2021

@pmneila well, I agree there is no proper way to do that. Here is a hacky approach to achieve what you’d like:

import torch
from ignite.engine import Engine, Events

n_samples = 100
batch_size_per_epoch = [32, 16, 8, 4, 2]

loaders = [torch.rand(n_samples // bs, bs, 3, 32, 32) for bs in batch_size_per_epoch]


trainer = Engine(lambda e, b: print(f"{e.state.epoch} - {e.state.iteration} : {b.shape}"))
trainer.state.loader_index = 0


@trainer.on(Events.EPOCH_COMPLETED)
def set_next_loader():
    trainer.state.loader_index += 1
    print(f"Set next loader: {trainer.state.loader_index}")
    new_loader = loaders[trainer.state.loader_index]
    trainer.set_data(new_loader)
    trainer.state.epoch_length = len(new_loader)


trainer.run(loaders[trainer.state.loader_index], max_epochs=2)

print("\nBefore load_state_dict...")
print(f"trainer.state.epoch={trainer.state.epoch}")  # Prints: 2
print(f"trainer.state.iteration={trainer.state.iteration}")  # Prints: 9
sd = trainer.state_dict()
last_epoch = trainer.state.epoch

trainer.load_state_dict(sd)
# Set explicitly the epoch
trainer.state.epoch = last_epoch

print("\nAfter load_state_dict...")
print(f"trainer.state.epoch={trainer.state.epoch}")  # Prints: 2 here only because it was set explicitly above; load_state_dict alone gives 0
print(f"trainer.state.iteration={trainer.state.iteration}")  # Prints: 9

# we have to reset the data when the engine starts
# so that it avoids calling `_setup_engine()`
@trainer.on(Events.STARTED)
def reset_things():
    new_loader = loaders[trainer.state.loader_index]
    trainer.set_data(new_loader)
    trainer.state.epoch_length = len(new_loader)


trainer.run(loaders[trainer.state.loader_index], max_epochs=4)
Output:
1 - 1 : torch.Size([32, 3, 32, 32])
1 - 2 : torch.Size([32, 3, 32, 32])
1 - 3 : torch.Size([32, 3, 32, 32])
Set next loader: 1
2 - 4 : torch.Size([16, 3, 32, 32])
2 - 5 : torch.Size([16, 3, 32, 32])
2 - 6 : torch.Size([16, 3, 32, 32])
2 - 7 : torch.Size([16, 3, 32, 32])
2 - 8 : torch.Size([16, 3, 32, 32])
2 - 9 : torch.Size([16, 3, 32, 32])
Set next loader: 2

Before load_state_dict...
trainer.state.epoch=2
trainer.state.iteration=9

After load_state_dict...
trainer.state.epoch=2
trainer.state.iteration=9
3 - 10 : torch.Size([8, 3, 32, 32])
3 - 11 : torch.Size([8, 3, 32, 32])
3 - 12 : torch.Size([8, 3, 32, 32])
3 - 13 : torch.Size([8, 3, 32, 32])
3 - 14 : torch.Size([8, 3, 32, 32])
3 - 15 : torch.Size([8, 3, 32, 32])
3 - 16 : torch.Size([8, 3, 32, 32])
3 - 17 : torch.Size([8, 3, 32, 32])
3 - 18 : torch.Size([8, 3, 32, 32])
3 - 19 : torch.Size([8, 3, 32, 32])
3 - 20 : torch.Size([8, 3, 32, 32])
3 - 21 : torch.Size([8, 3, 32, 32])
Set next loader: 3
4 - 22 : torch.Size([4, 3, 32, 32])
4 - 23 : torch.Size([4, 3, 32, 32])
4 - 24 : torch.Size([4, 3, 32, 32])
4 - 25 : torch.Size([4, 3, 32, 32])
4 - 26 : torch.Size([4, 3, 32, 32])
4 - 27 : torch.Size([4, 3, 32, 32])
4 - 28 : torch.Size([4, 3, 32, 32])
4 - 29 : torch.Size([4, 3, 32, 32])
4 - 30 : torch.Size([4, 3, 32, 32])
4 - 31 : torch.Size([4, 3, 32, 32])
4 - 32 : torch.Size([4, 3, 32, 32])
4 - 33 : torch.Size([4, 3, 32, 32])
4 - 34 : torch.Size([4, 3, 32, 32])
4 - 35 : torch.Size([4, 3, 32, 32])
4 - 36 : torch.Size([4, 3, 32, 32])
4 - 37 : torch.Size([4, 3, 32, 32])
4 - 38 : torch.Size([4, 3, 32, 32])
4 - 39 : torch.Size([4, 3, 32, 32])
4 - 40 : torch.Size([4, 3, 32, 32])
4 - 41 : torch.Size([4, 3, 32, 32])
4 - 42 : torch.Size([4, 3, 32, 32])
4 - 43 : torch.Size([4, 3, 32, 32])
4 - 44 : torch.Size([4, 3, 32, 32])
4 - 45 : torch.Size([4, 3, 32, 32])
4 - 46 : torch.Size([4, 3, 32, 32])
Set next loader: 4

Anyway, I agree that separating epoch, iteration and epoch_length could be an interesting feature to have.
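
(As a rough illustration of what keeping them separate could look like on the user side today — a hypothetical helper, not an ignite API — one can record the per-epoch lengths and recover the epoch from the absolute iteration by walking the cumulative lengths.)

# Hypothetical user-side helper (not part of ignite): recover the epoch
# from an absolute iteration count when epoch lengths vary per epoch.
def epoch_from_iteration(iteration, epoch_lengths):
    completed = 0
    for epoch, length in enumerate(epoch_lengths, start=1):
        completed += length
        if iteration <= completed:
            return epoch
    return len(epoch_lengths)

# With the loaders above (100 samples): lengths are [3, 6, 12, 25, 50].
epoch_lengths = [len(loader) for loader in loaders]
assert epoch_from_iteration(9, epoch_lengths) == 2   # last iteration of epoch 2
assert epoch_from_iteration(10, epoch_lengths) == 3  # first iteration of epoch 3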

2 reactions
sdesrozis commented, Feb 12, 2021

@pmneila thanks for this discussion.

I think having such a way to handle dynamic batch sizes and epoch lengths could enable implementations like the one in the following paper:

https://arxiv.org/abs/1711.00489
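
(A rough sketch — names, dataset and schedule are purely illustrative, not taken from the paper’s code — of how an increasing-batch-size schedule in that spirit could be wired up with the set_data pattern above.)

import torch
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import Engine, Events

# Illustrative only: grow the batch size at chosen epochs instead of
# decaying the learning rate, in the spirit of arXiv:1711.00489.
dataset = TensorDataset(torch.rand(1000, 3, 32, 32), torch.randint(0, 10, (1000,)))
batch_size_schedule = {5: 64, 10: 128}  # epoch -> new batch size

def train_step(engine, batch):
    x, y = batch
    return x.shape[0]  # placeholder for the real training step

trainer = Engine(train_step)

@trainer.on(Events.EPOCH_COMPLETED)
def maybe_grow_batch_size(engine):
    # Same pattern as the hack above: swap loaders between epochs and
    # override epoch_length by hand.
    new_bs = batch_size_schedule.get(engine.state.epoch + 1)
    if new_bs is not None:
        loader = DataLoader(dataset, batch_size=new_bs, shuffle=True)
        engine.set_data(loader)
        engine.state.epoch_length = len(loader)

trainer.run(DataLoader(dataset, batch_size=32, shuffle=True), max_epochs=15)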


