
Manual optimization doesn't work with multiple TPUs with `pytorch-lightning: 1.7.1`

See original GitHub issue

πŸ› Bug

As the title says, manual optimization works only on a single TPU core. With multiple cores, it raises an AssertionError at this line: https://github.com/Lightning-AI/lightning/blob/acd4805f1a284e513272d150de6f98f27a0489b3/src/pytorch_lightning/loops/optimization/manual_loop.py#L110
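
For context, the failing frame in the traceback below is the point where the manual-optimization loop hands `training_step` off to the strategy; the line it references reads (quoted from the traceback):

    training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())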

dinhanhx@t1v-n-8b0bf8c6-w-0:~/storage/projects/boring$ conda activate torch-12
(torch-12) dinhanhx@t1v-n-8b0bf8c6-w-0:~/storage/projects/boring$ python3 boring.py
2022-08-17 15:38:55.853913: E tensorflow/core/framework/op_kernel.cc:1676] OpKernel ('op: "TPURoundRobin" device_type: "CPU"') for unknown op: TPURoundRobin
2022-08-17 15:38:55.853977: E tensorflow/core/framework/op_kernel.cc:1676] OpKernel ('op: "TpuHandleToProtoKey" device_type: "CPU"') for unknown op: TpuHandleToProtoKey
GPU available: False, used: False
TPU available: True, using: 8 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.
[the two OpKernel messages above repeat once per spawned TPU worker process]
/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1894: PossibleUserWarning: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
Epoch 0    ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0/2 0:00:00 • -:--:-- 0.00it/s
pt-xla-profiler: TransferFromServerTime too frequent: 5 counts during 1 steps
[the profiler warning above is printed 8 times, once per TPU core]
Exception in device=TPU:1, TPU:2, TPU:3, TPU:4, TPU:5, TPU:6, TPU:7
[the seven worker processes raise identical tracebacks, interleaved in the raw
output; they are consolidated into the single traceback below]
Traceback (most recent call last):
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 324, in _start_fn
    fn(gindex, *args)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/xla.py", line 157, in wrapped
    fn(rank, *_args)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/xla.py", line 107, in _wrapping_function
    results = function(*args, **kwargs)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 737, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1168, in _run
    results = self._run_stage()
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1254, in _run_stage
    return self._run_train()
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1285, in _run_train
    self.fit_loop.run()
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 270, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 203, in advance
    batch_output = self.batch_loop.run(kwargs)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 89, in advance
    outputs = self.manual_loop.run(kwargs)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/manual_loop.py", line 110, in advance
    training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
AssertionError

To Reproduce

import os

import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import LinearLR

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import RichProgressBar

import torch_xla.core.xla_model as xm
from torch.utils.data.distributed import DistributedSampler


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # opt in to manual optimization
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        # Manual optimization: fetch the optimizer and drive the update by hand.
        opt = self.optimizers()
        opt.zero_grad()
        loss = self(batch).sum()
        self.log("train_loss", loss, sync_dist=True, sync_dist_group=True, rank_zero_only=True)
        self.manual_backward(loss)  # Lightning's wrapper around loss.backward()
        opt.step()
        # The LR scheduler must also be stepped manually in this mode.
        sch = self.lr_schedulers()
        sch.step()
        # return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        opt = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        scheduler = LinearLR(opt, start_factor=0.5, total_iters=4)
        return [opt], [scheduler]


def run():
    ds = RandomDataset(32, 64)
    # Note: this sampler is constructed but never passed to the DataLoaders
    # below; Lightning normally injects its own DistributedSampler on TPU.
    sampler = DistributedSampler(
        ds, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True
    )

    train_data = DataLoader(ds, batch_size=2)
    val_data = DataLoader(ds, batch_size=2)
    test_data = DataLoader(ds, batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        logger=CSVLogger("csvlogs"), 
        accelerator='tpu', devices=8, 
        callbacks=[RichProgressBar()], 
        strategy="tpu_spawn_debug"
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()

Run with:

export TPU_LOG_DIR="disabled"
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
python boring.py
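
Aside: as noted in the script, the DistributedSampler built in run() is never handed to a DataLoader. If it were meant to be wired in explicitly, it would look something like the sketch below; Lightning's default replace_sampler_ddp=True normally injects an equivalent sampler on TPU, so the script leaves it unused.

    # Hypothetical explicit wiring of the sampler created in run():
    train_data = DataLoader(ds, batch_size=2, sampler=sampler)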

Expected behavior

Training runs to completion on all eight TPU cores, just as it does on a single core.

Environment

* CUDA:
        - GPU:               None
        - available:         False
        - version:           10.2
* Lightning:
        - pytorch-lightning: 1.7.1
        - torch:             1.11.0
        - torch-xla:         1.11
        - torchinfo:         1.7.0
        - torchmetrics:      0.9.3
        - torchvision:       0.12.0
* Packages:
        - absl-py:           1.2.0
        - aiohttp:           3.8.1
        - aiosignal:         1.2.0
        - argon2-cffi:       21.3.0
        - argon2-cffi-bindings: 21.2.0
        - astroid:           2.11.7
        - asttokens:         2.0.5
        - async-timeout:     4.0.2
        - attrs:             22.1.0
        - backcall:          0.2.0
        - beautifulsoup4:    4.11.1
        - bleach:            5.0.1
        - cachetools:        4.2.4
        - certifi:           2022.6.15
        - cffi:              1.15.1
        - charset-normalizer: 2.1.0
        - cloud-tpu-client:  0.10
        - cloud-tpu-profiler: 2.4.0
        - commonmark:        0.9.1
        - debugpy:           1.6.2
        - decorator:         5.1.1
        - defusedxml:        0.7.1
        - dill:              0.3.5.1
        - einops:            0.4.1
        - entrypoints:       0.4
        - executing:         0.9.1
        - fastjsonschema:    2.16.1
        - filelock:          3.7.1
        - flake8:            5.0.4
        - frozenlist:        1.3.1
        - fsspec:            2022.7.1
        - google-api-core:   1.32.0
        - google-api-python-client: 1.8.0
        - google-auth:       1.35.0
        - google-auth-httplib2: 0.1.0
        - google-auth-oauthlib: 0.4.6
        - googleapis-common-protos: 1.56.4
        - grpcio:            1.47.0
        - httplib2:          0.20.4
        - huggingface-hub:   0.8.1
        - idna:              3.3
        - importlib-metadata: 4.12.0
        - importlib-resources: 5.9.0
        - install:           1.3.5
        - ipykernel:         6.15.1
        - ipython:           8.4.0
        - ipython-genutils:  0.2.0
        - ipywidgets:        7.7.1
        - isort:             5.10.1
        - jedi:              0.18.1
        - jinja2:            3.1.2
        - jsonschema:        4.9.1
        - jupyter-client:    7.3.4
        - jupyter-core:      4.11.1
        - jupyterlab-pygments: 0.2.2
        - jupyterlab-widgets: 1.1.1
        - lazy-object-proxy: 1.7.1
        - libtpu-nightly:    0.1.dev20220303
        - markdown:          3.4.1
        - markupsafe:        2.1.1
        - matplotlib-inline: 0.1.3
        - mccabe:            0.7.0
        - mistune:           0.8.4
        - multidict:         6.0.2
        - nbclient:          0.6.6
        - nbconvert:         6.5.0
        - nbformat:          5.4.0
        - nest-asyncio:      1.5.5
        - notebook:          6.4.12
        - numpy:             1.23.1
        - oauth2client:      4.1.3
        - oauthlib:          3.2.0
        - packaging:         21.3
        - pandocfilters:     1.5.0
        - parso:             0.8.3
        - pexpect:           4.8.0
        - pickleshare:       0.7.5
        - pillow:            9.2.0
        - pip:               22.1.2
        - pkgutil-resolve-name: 1.3.10
        - platformdirs:      2.5.2
        - prometheus-client: 0.14.1
        - prompt-toolkit:    3.0.30
        - protobuf:          3.19.4
        - psutil:            5.9.1
        - ptyprocess:        0.7.0
        - pure-eval:         0.2.2
        - pyasn1:            0.4.8
        - pyasn1-modules:    0.2.8
        - pycodestyle:       2.9.1
        - pycparser:         2.21
        - pydeprecate:       0.3.2
        - pyflakes:          2.5.0
        - pygments:          2.12.0
        - pylint:            2.14.5
        - pyparsing:         3.0.9
        - pyrsistent:        0.18.1
        - python-dateutil:   2.8.2
        - pytorch-lightning: 1.7.1
        - pytz:              2022.1
        - pyyaml:            6.0
        - pyzmq:             23.2.0
        - regex:             2022.7.25
        - requests:          2.28.1
        - requests-oauthlib: 1.3.1
        - rich:              12.5.1
        - rsa:               4.9
        - send2trash:        1.8.0
        - sentencepiece:     0.1.97
        - setuptools:        61.2.0
        - six:               1.16.0
        - soupsieve:         2.3.2.post1
        - stack-data:        0.3.0
        - tablign:           0.3.4
        - tensorboard:       2.10.0
        - tensorboard-data-server: 0.6.1
        - tensorboard-plugin-wit: 1.8.1
        - terminado:         0.15.0
        - tinycss2:          1.1.1
        - tokenizers:        0.12.1
        - tomli:             2.0.1
        - tomlkit:           0.11.1
        - torch:             1.11.0
        - torch-xla:         1.11
        - torchinfo:         1.7.0
        - torchmetrics:      0.9.3
        - torchvision:       0.12.0
        - tornado:           6.2
        - tqdm:              4.64.0
        - traitlets:         5.3.0
        - transformers:      4.21.1
        - typing-extensions: 4.3.0
        - uritemplate:       3.0.1
        - urllib3:           1.26.11
        - wcwidth:           0.2.5
        - webencodings:      0.5.1
        - werkzeug:          2.2.2
        - wheel:             0.37.1
        - widgetsnbextension: 3.6.1
        - yarl:              1.8.1
        - zipp:              3.8.1
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         x86_64
        - python:            3.8.13
        - version:           #46-Ubuntu SMP Mon Apr 19 19:17:04 UTC 2021

Additional context

I run it on a Google Cloud TPU VM (v3-8).
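
To restate the scope of the failure: the device count is the only variable between the working and failing configurations. A minimal sketch (the Trainer arguments are otherwise the same as in the reproduction script above):

    # Reported to work on pytorch-lightning 1.7.1 (single TPU core):
    trainer = Trainer(accelerator="tpu", devices=1)

    # Reported to fail with an AssertionError in manual_loop.py (eight cores):
    trainer = Trainer(accelerator="tpu", devices=8)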

cc @kaushikb11 @rohitgr7

Issue Analytics

  • State:closed
  • Created a year ago
  • Comments:11 (6 by maintainers)

Top GitHub Comments

1 reaction
dinhanhx commented, Aug 22, 2022

@awaelchli I just tried master as of 2022-08-22. Yes, it works.
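
(For reference, installing from master at the time would have looked something like `pip install git+https://github.com/Lightning-AI/lightning.git`.)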

1 reaction
awaelchli commented, Aug 22, 2022

Correction: I no longer see these failures from a few weeks ago. All tests are passing. @dinhanhx maybe it is worth trying master?

Read more comments on GitHub >
