RuntimeError: Expected to mark a variable ready only once when using Swin Transformer
Trying to use SwAV to pretrain Swin Transformers, but I am getting this error:
RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons:
1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes
2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases yet.
3) Incorrect unused parameter detection. The return value of the `forward` function is inspected by the distributed data parallel wrapper to figure out if any of the module's parameters went unused. For unused parameters, DDP would not expect gradients from then. However, if an unused parameter becomes part of the autograd graph at a later point in time (e.g., in a reentrant backward when using `checkpoint`), the gradient will show up unexpectedly. If all parameters in the model participate in the backward pass, you can disable unused parameter detection by passing the keyword argument `find_unused_parameters=False` to `torch.nn.parallel.DistributedDataParallel`.
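Reason 2 is the likely match for this setup. In the reproduction code below, `training_step` calls `self.forward` once per crop (2 + 6 = 8 crops), and `use_checkpoint=pretrained` turns on checkpointing inside the Swin blocks. Assuming the library uses PyTorch's long-standing default, the reentrant checkpoint mode, every checkpointed block then runs a separate reentrant backward pass for each crop. The toy sketch below counts how often a gradient hook on a layer's weight fires during one backward pass, as a rough proxy for the per-parameter hook DDP relies on; it is not DDP itself. `count_grad_hook_calls` is a hypothetical helper, `nn.Linear` stands in for a Swin block, and a PyTorch version whose `checkpoint` accepts `use_reentrant` is assumed.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


def count_grad_hook_calls(use_reentrant: bool, n_forwards: int = 2) -> int:
    """Hypothetical helper: count how often the weight's gradient hook fires
    in ONE backward pass when the same layer is checkpointed n_forwards times."""
    torch.manual_seed(0)
    layer = nn.Linear(4, 4)  # stand-in for one checkpointed Swin block
    calls = []
    # The hook runs every time a gradient for the weight is produced.
    layer.weight.register_hook(lambda grad: calls.append(1))
    # Reentrant checkpointing only yields parameter gradients if an input requires grad.
    x = torch.randn(3, 4, requires_grad=True)
    loss = sum(
        checkpoint(layer, x, use_reentrant=use_reentrant).sum()
        for _ in range(n_forwards)
    )
    loss.backward()
    return len(calls)


print(count_grad_hook_calls(use_reentrant=True))
print(count_grad_hook_calls(use_reentrant=False))
```

With `use_reentrant=True` the hook fires once per checkpointed call (twice here, eight times for eight crops), which matches the "marking a variable ready multiple times" pattern in the error. With `use_reentrant=False` the contributions are summed inside a single backward pass and the hook fires once.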
Entire code for reproduction
import warnings
warnings.filterwarnings("ignore")
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
import torch
from torch import nn
import torchvision
import pytorch_lightning as pl
from lightly.data import LightlyDataset
from lightly.data import SwaVCollateFunction
from lightly.loss import SwaVLoss
from lightly.models.modules import SwaVProjectionHead
from lightly.models.modules import SwaVPrototypes
from swin_transformer_v2 import SwinTransformerV2
from swin_transformer_v2 import swin_transformer_v2_t, swin_transformer_v2_s
import torch.utils.model_zoo as model_zoo
model_urls = {
    'tiny': 'https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
    # 'tiny': 'https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth',
    'small': 'https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth',
}
def swin_transformer_version2(pretrained=False, channels=3, **kwargs):
    version = str(kwargs.pop('version'))
    window_size = kwargs.pop('window_size')
    resolution = kwargs.pop('resolution')
    # Note: use_checkpoint=pretrained turns on gradient checkpointing whenever
    # pretrained weights are loaded.
    if version == 'tiny':
        model = swin_transformer_v2_t(
            in_channels=3,
            window_size=8,
            input_resolution=(256, 256),
            sequential_self_attention=False,
            use_checkpoint=pretrained
        )
    else:
        model = swin_transformer_v2_s(
            in_channels=3,
            window_size=8,
            input_resolution=(256, 256),
            sequential_self_attention=False,
            use_checkpoint=pretrained
        )
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls[version], model_dir='.'), strict=False)
    if resolution > 256:
        print(window_size, resolution, version, pretrained)
        model.update_resolution(new_window_size=window_size, new_input_resolution=(resolution, resolution))
    return model
import lightly.data as data
class SwaV(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.swin_model = swin_transformer_version2(pretrained=True, version='tiny', window_size=8, resolution=256)
        self.avg_pool = torch.nn.AdaptiveAvgPool2d(1)
        self.projection_head = SwaVProjectionHead(768, 512, 128)
        self.prototypes = SwaVPrototypes(128, n_prototypes=512)
        # enable sinkhorn_gather_distributed to gather features from all gpus
        # while running the sinkhorn algorithm in the loss calculation
        self.criterion = SwaVLoss(sinkhorn_gather_distributed=True)
    def forward(self, x):
        x = self.swin_model(x)
        x = self.avg_pool(x[-1])
        x = x.flatten(start_dim=1)
        x = self.projection_head(x)
        x = nn.functional.normalize(x, dim=1, p=2)
        p = self.prototypes(x)
        return p
    def training_step(self, batch, batch_idx):
        self.prototypes.normalize()
        crops, _, _ = batch
        # one full forward pass through the backbone per crop (2 + 6 = 8 per step)
        multi_crop_features = [self.forward(x.to(self.device)) for x in crops]
        high_resolution = multi_crop_features[:2]
        low_resolution = multi_crop_features[2:]
        loss = self.criterion(high_resolution, low_resolution)
        log_dict = {"train_loss": loss}
        self.log('train_loss', loss, on_step=True, on_epoch=True, sync_dist=True)
        return {"loss": loss, "log": log_dict}
        # return loss
    def training_epoch_end(self, outputs):
        train_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log('train_loss', train_loss, sync_dist=True)
    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=0.001)
        return optim
def main(args):
    seed_everything(args.seed)
    checkpoint_callback = ModelCheckpoint(
        monitor="train_loss",
        dirpath=f"logs/{args.name}",
        filename=args.project + "-{epoch:02d}-{train_loss:.4f}",
        save_top_k=1,
        mode="min",
        every_n_epochs=5,
    )
    dataset = data.LightlyDataset(input_dir='data2/train')
    collate_fn = SwaVCollateFunction(
        crop_sizes=[256, 256], crop_counts=[2, 6], crop_min_scales=[0.14, 0.05], crop_max_scales=[1.0, 0.14], vf_prob=0.5, rr_prob=0.5,
        gaussian_blur=0.25, cj_strength=0.5, kernel_size=5, normalize={'mean': [0, 0, 0], 'std': [1, 1, 1]}
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch,
        collate_fn=collate_fn,
        shuffle=True,
        drop_last=True,
        num_workers=16,
    )
    gpus = torch.cuda.device_count()
    # trainer.fit(model=model, train_dataloaders=dataloader)
    wandb_logger = WandbLogger(project=args.project, name=args.name)
    model = SwaV()
    trainer = Trainer(
        gpus=gpus,
        max_epochs=args.epochs,
        precision=args.precision,
        callbacks=[checkpoint_callback],
        logger=wandb_logger,
        log_every_n_steps=2,
        # strategy='ddp',
        sync_batchnorm=True,
        strategy="ddp_find_unused_parameters_false",
    )
    trainer.fit(model=model, train_dataloaders=dataloader)
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Implementation of SwAV")
    parser.add_argument("--project", default="swav_test", type=str, help="wandb project")
    parser.add_argument("--name", default="exp", type=str, help="wandb name")
    parser.add_argument("--input", default="data/train", type=str, help="input directory")
    parser.add_argument("--epochs", default=10, type=int)
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--device", default=0, type=int)
    parser.add_argument("--precision", default=16, type=int)
    parser.add_argument("--batch", default=16, type=int)
    args = parser.parse_args()
    main(args)
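One workaround that would keep checkpointing on is to run the backbone once per step instead of once per crop, so that, in principle, each checkpointed block takes part in only one reentrant backward pass. All crops in this config have the same size (`crop_sizes=[256, 256]`), so they can be concatenated along the batch dimension and the features split back afterwards. This is an unverified sketch: `forward_crops_batched` is a hypothetical helper, and in the `SwaV` module above `backbone` would be something like `lambda x: self.avg_pool(self.swin_model(x)[-1]).flatten(start_dim=1)`. Only the backbone is batched, because the projection head (lightly's `SwaVProjectionHead`, which as far as I know contains `BatchNorm1d`) would otherwise compute batch statistics over all eight crops together.

```python
from typing import Callable, List

import torch
from torch import nn


def forward_crops_batched(
    backbone: Callable[[torch.Tensor], torch.Tensor], crops: List[torch.Tensor]
) -> List[torch.Tensor]:
    """Hypothetical helper: run `backbone` once on all crops stacked along the
    batch dimension, then split the output back into one tensor per crop.
    Assumes every crop has the same shape (true here: crop_sizes=[256, 256])."""
    sizes = [crop.shape[0] for crop in crops]
    features = backbone(torch.cat(crops, dim=0))
    return list(torch.split(features, sizes, dim=0))


# Toy backbone without batch-dependent layers, standing in for Swin + avg_pool.
torch.manual_seed(0)
backbone = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.AdaptiveAvgPool2d(1), nn.Flatten())
crops = [torch.randn(2, 3, 16, 16) for _ in range(8)]  # 2 + 6 crops, batch size 2
features = forward_crops_batched(backbone, crops)
```

`training_step` would then apply `projection_head`, the normalization and `prototypes` to each element of the returned list, as it does now.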
cc @justusschock @kaushikb11 @awaelchli @akihironitta @rohitgr7
Hey @sarmientoj24, thank you for raising the issue! Just to cross-check and debug further, could you please try disabling `use_checkpoint` in the `swin_transformer_v2_t`/`swin_transformer_v2_s` function calls (`use_checkpoint=False`)? I know it's not the solution, but I'm just cross-checking whether it's the issue.
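If the aim is to keep the memory savings of checkpointing rather than turn it off, PyTorch also offers a non-reentrant mode (`use_reentrant=False`), which recent versions of the `DistributedDataParallel` documentation recommend when combining checkpointing with DDP. Using it here would mean changing the `checkpoint` call inside the Swin-Transformer-V2 code, which I have not checked. The sketch below wraps a toy block in a hypothetical `MaybeCheckpointed` module and calls it twice per step, as the crops do; it only shows that the non-reentrant path gives the same gradients as no checkpointing, not that it fixes the DDP error in this exact setup.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class MaybeCheckpointed(nn.Module):
    """Hypothetical wrapper: runs `block` with non-reentrant activation
    checkpointing when `use_checkpoint` is set and the module is training."""

    def __init__(self, block: nn.Module, use_checkpoint: bool) -> None:
        super().__init__()
        self.block = block
        self.use_checkpoint = use_checkpoint

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_checkpoint and self.training:
            # Activations inside `block` are recomputed during backward instead
            # of being stored; no nested (reentrant) backward pass is started.
            return checkpoint(self.block, x, use_reentrant=False)
        return self.block(x)


def make_block() -> nn.Module:
    return nn.Sequential(nn.Linear(4, 8), nn.GELU(), nn.Linear(8, 4))


torch.manual_seed(0)
plain = MaybeCheckpointed(make_block(), use_checkpoint=False)
ckpt = MaybeCheckpointed(make_block(), use_checkpoint=True)
ckpt.load_state_dict(plain.state_dict())  # identical weights in both copies

x = torch.randn(5, 4)  # non-reentrant mode does not need an input that requires grad
for model in (plain, ckpt):
    # Two forward calls in one step, like one call per crop.
    loss = model(x).sum() + model(x * 2).pow(2).sum()
    loss.backward()
```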
Sharing full stack trace for anyone who is trying to debug this:
- Create a `data2/train` folder and add around 4-10 images (to not go OOM if you don't have enough GPU memory).
- Change the batch size (`--batch`, default `16`) to something smaller if you don't have at least 16 images in the folder.
- Install `lightly` and `Swin-Transformer-V2`: `pip install lightly`, `pip install git+https://github.com/ChristophReich1996/Swin-Transformer-V2`
Full stack trace
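For the first step, if no real images are at hand, a short script can fill the folder with random images. `make_dummy_images` is a hypothetical helper that assumes NumPy and Pillow are installed (Pillow usually comes with torchvision); noise images are only meant to get the script as far as the backward pass, not to train anything useful.

```python
from pathlib import Path
from typing import List

import numpy as np
from PIL import Image


def make_dummy_images(
    out_dir: str = "data2/train", n: int = 8, size: int = 256, seed: int = 0
) -> List[Path]:
    """Hypothetical helper: write `n` random RGB PNG images of size x size pixels."""
    rng = np.random.default_rng(seed)
    folder = Path(out_dir)
    folder.mkdir(parents=True, exist_ok=True)
    paths = []
    for i in range(n):
        pixels = rng.integers(0, 256, size=(size, size, 3), dtype=np.uint8)
        path = folder / f"dummy_{i:03d}.png"
        Image.fromarray(pixels).save(path)
        paths.append(path)
    return paths
```

For example, call `make_dummy_images("data2/train", n=8)` and then run the script with a `--batch` value no larger than the number of images available to each process, since the loader uses `drop_last=True` and would otherwise yield no batches.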
Thanks, this works for me! Turning off checkpointing solves this problem.
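Another option described in the PyTorch `DistributedDataParallel` documentation is `static_graph=True`, which tells DDP that the set of used parameters does not change between iterations and, per those docs, supports activation checkpointing of the same layers multiple times, including reentrant backwards. The sketch runs one step of a toy model under DDP in a single CPU process (gloo backend, `world_size=1`) so it needs no GPUs or launcher. `CheckpointedTwice` and `one_step_with_static_graph` are hypothetical names, the `static_graph` argument requires a reasonably recent PyTorch, and how to pass it through the Lightning strategy depends on the Lightning version and is not shown here.

```python
import os
import tempfile
from typing import Tuple

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import checkpoint


class CheckpointedTwice(nn.Module):
    """Toy stand-in for a backbone that is checkpointed in reentrant mode
    and used twice per step, like running it once per crop."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        first = checkpoint(self.layer, x, use_reentrant=True)
        second = checkpoint(self.layer, x, use_reentrant=True)
        return first + second


def one_step_with_static_graph() -> Tuple[torch.Tensor, torch.Tensor]:
    """Run one forward/backward step under DDP with static_graph=True and
    return (weight gradient, input)."""
    # Single-process "cluster" on CPU: gloo backend, file-based rendezvous.
    store_path = os.path.join(tempfile.mkdtemp(), "ddp_store")
    dist.init_process_group("gloo", init_method=f"file://{store_path}", rank=0, world_size=1)
    try:
        torch.manual_seed(0)
        model = DDP(CheckpointedTwice(), static_graph=True)
        # Reentrant checkpointing needs an input that requires grad;
        # otherwise the parameters receive no gradient at all.
        x = torch.randn(3, 4, requires_grad=True)
        model(x).sum().backward()
        return model.module.layer.weight.grad.clone(), x.detach()
    finally:
        dist.destroy_process_group()


grad, x = one_step_with_static_graph()
```

The loss is `sum(2 * (x @ W.T + b))`, so each row of the weight gradient should equal `2 * x.sum(0)`; with a single process DDP's averaging leaves it unchanged.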