Manual optimization doesn't work with multiple TPUs with `pytorch-lightning: 1.7.1`
See original GitHub issue.

🐛 Bug

As the title says, manual optimization works only on a single TPU core. With multiple cores, it raises an AssertionError at this line: https://github.com/Lightning-AI/lightning/blob/acd4805f1a284e513272d150de6f98f27a0489b3/src/pytorch_lightning/loops/optimization/manual_loop.py#L110
(torch-12) dinhanhx@t1v-n-8b0bf8c6-w-0:~/storage/projects/boring$ python3 boring.py
2022-08-17 15:38:55.853913: E tensorflow/core/framework/op_kernel.cc:1676] OpKernel ('op: "TPURoundRobin" device_type: "CPU"') for unknown op: TPURoundRobin
2022-08-17 15:38:55.853977: E tensorflow/core/framework/op_kernel.cc:1676] OpKernel ('op: "TpuHandleToProtoKey" device_type: "CPU"') for unknown op: TpuHandleToProtoKey
GPU available: False, used: False
TPU available: True, using: 8 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.
(the two OpKernel messages above are printed again by each of the spawned TPU processes; repeats omitted)
/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1894: PossibleUserWarning: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
rank_zero_warn(
Epoch 0   0/2   0:00:00 • -:--:--   0.00it/s
pt-xla-profiler: TransferFromServerTime too frequent: 5 counts during 1 steps
(the profiler message above is printed once per TPU core)
Exception in device=TPU:1 (the identical AssertionError is raised on every worker process, TPU:1 through TPU:7; their interleaved output has been collapsed into a single traceback):
Traceback (most recent call last):
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 324, in _start_fn
    fn(gindex, *args)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/xla.py", line 157, in wrapped
    fn(rank, *_args)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/xla.py", line 107, in _wrapping_function
    results = function(*args, **kwargs)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 737, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1168, in _run
    results = self._run_stage()
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1254, in _run_stage
    return self._run_train()
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1285, in _run_train
    self.fit_loop.run()
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 270, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 203, in advance
    batch_output = self.batch_loop.run(kwargs)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 89, in advance
    outputs = self.manual_loop.run(kwargs)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/manual_loop.py", line 110, in advance
    training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
AssertionError
To Reproduce
import os

import torch
from torch.optim.lr_scheduler import LinearLR
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
import torch_xla.core.xla_model as xm
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.loggers import CSVLogger


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # manual optimization
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self(batch).sum()
        self.log("train_loss", loss, sync_dist=True, sync_dist_group=True, rank_zero_only=True)
        self.manual_backward(loss)
        opt.step()
        sch = self.lr_schedulers()
        sch.step()
        # return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        opt = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        scheduler = LinearLR(opt, start_factor=0.5, total_iters=4)
        return [opt], [scheduler]


def run():
    ds = RandomDataset(32, 64)
    # NOTE: this sampler is constructed but never passed to the DataLoaders below.
    sampler = DistributedSampler(
        ds, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True
    )
    train_data = DataLoader(ds, batch_size=2)
    val_data = DataLoader(ds, batch_size=2)
    test_data = DataLoader(ds, batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        logger=CSVLogger("csvlogs"),
        accelerator="tpu",
        devices=8,
        callbacks=[RichProgressBar()],
        strategy="tpu_spawn_debug",
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()
export TPU_LOG_DIR="disabled"
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
python boring.py
Expected behavior
The script runs to completion without errors.
Environment
* CUDA:
- GPU: None
- available: False
- version: 10.2
* Lightning:
- pytorch-lightning: 1.7.1
- torch: 1.11.0
- torch-xla: 1.11
- torchinfo: 1.7.0
- torchmetrics: 0.9.3
- torchvision: 0.12.0
* Packages:
- absl-py: 1.2.0
- aiohttp: 3.8.1
- aiosignal: 1.2.0
- argon2-cffi: 21.3.0
- argon2-cffi-bindings: 21.2.0
- astroid: 2.11.7
- asttokens: 2.0.5
- async-timeout: 4.0.2
- attrs: 22.1.0
- backcall: 0.2.0
- beautifulsoup4: 4.11.1
- bleach: 5.0.1
- cachetools: 4.2.4
- certifi: 2022.6.15
- cffi: 1.15.1
- charset-normalizer: 2.1.0
- cloud-tpu-client: 0.10
- cloud-tpu-profiler: 2.4.0
- commonmark: 0.9.1
- debugpy: 1.6.2
- decorator: 5.1.1
- defusedxml: 0.7.1
- dill: 0.3.5.1
- einops: 0.4.1
- entrypoints: 0.4
- executing: 0.9.1
- fastjsonschema: 2.16.1
- filelock: 3.7.1
- flake8: 5.0.4
- frozenlist: 1.3.1
- fsspec: 2022.7.1
- google-api-core: 1.32.0
- google-api-python-client: 1.8.0
- google-auth: 1.35.0
- google-auth-httplib2: 0.1.0
- google-auth-oauthlib: 0.4.6
- googleapis-common-protos: 1.56.4
- grpcio: 1.47.0
- httplib2: 0.20.4
- huggingface-hub: 0.8.1
- idna: 3.3
- importlib-metadata: 4.12.0
- importlib-resources: 5.9.0
- install: 1.3.5
- ipykernel: 6.15.1
- ipython: 8.4.0
- ipython-genutils: 0.2.0
- ipywidgets: 7.7.1
- isort: 5.10.1
- jedi: 0.18.1
- jinja2: 3.1.2
- jsonschema: 4.9.1
- jupyter-client: 7.3.4
- jupyter-core: 4.11.1
- jupyterlab-pygments: 0.2.2
- jupyterlab-widgets: 1.1.1
- lazy-object-proxy: 1.7.1
- libtpu-nightly: 0.1.dev20220303
- markdown: 3.4.1
- markupsafe: 2.1.1
- matplotlib-inline: 0.1.3
- mccabe: 0.7.0
- mistune: 0.8.4
- multidict: 6.0.2
- nbclient: 0.6.6
- nbconvert: 6.5.0
- nbformat: 5.4.0
- nest-asyncio: 1.5.5
- notebook: 6.4.12
- numpy: 1.23.1
- oauth2client: 4.1.3
- oauthlib: 3.2.0
- packaging: 21.3
- pandocfilters: 1.5.0
- parso: 0.8.3
- pexpect: 4.8.0
- pickleshare: 0.7.5
- pillow: 9.2.0
- pip: 22.1.2
- pkgutil-resolve-name: 1.3.10
- platformdirs: 2.5.2
- prometheus-client: 0.14.1
- prompt-toolkit: 3.0.30
- protobuf: 3.19.4
- psutil: 5.9.1
- ptyprocess: 0.7.0
- pure-eval: 0.2.2
- pyasn1: 0.4.8
- pyasn1-modules: 0.2.8
- pycodestyle: 2.9.1
- pycparser: 2.21
- pydeprecate: 0.3.2
- pyflakes: 2.5.0
- pygments: 2.12.0
- pylint: 2.14.5
- pyparsing: 3.0.9
- pyrsistent: 0.18.1
- python-dateutil: 2.8.2
- pytorch-lightning: 1.7.1
- pytz: 2022.1
- pyyaml: 6.0
- pyzmq: 23.2.0
- regex: 2022.7.25
- requests: 2.28.1
- requests-oauthlib: 1.3.1
- rich: 12.5.1
- rsa: 4.9
- send2trash: 1.8.0
- sentencepiece: 0.1.97
- setuptools: 61.2.0
- six: 1.16.0
- soupsieve: 2.3.2.post1
- stack-data: 0.3.0
- tablign: 0.3.4
- tensorboard: 2.10.0
- tensorboard-data-server: 0.6.1
- tensorboard-plugin-wit: 1.8.1
- terminado: 0.15.0
- tinycss2: 1.1.1
- tokenizers: 0.12.1
- tomli: 2.0.1
- tomlkit: 0.11.1
- torch: 1.11.0
- torch-xla: 1.11
- torchinfo: 1.7.0
- torchmetrics: 0.9.3
- torchvision: 0.12.0
- tornado: 6.2
- tqdm: 4.64.0
- traitlets: 5.3.0
- transformers: 4.21.1
- typing-extensions: 4.3.0
- uritemplate: 3.0.1
- urllib3: 1.26.11
- wcwidth: 0.2.5
- webencodings: 0.5.1
- werkzeug: 2.2.2
- wheel: 0.37.1
- widgetsnbextension: 3.6.1
- yarl: 1.8.1
- zipp: 3.8.1
* System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.8.13
- version: #46-Ubuntu SMP Mon Apr 19 19:17:04 UTC 2021
Additional context
I ran this on a Google Cloud TPU VM (v3-8).
Issue Analytics: created a year ago; 11 comments (6 by maintainers).

@awaelchli I just tried master (2022-08-22). Yes, it works.
Correction: I no longer see these failures from a few weeks ago. All tests are passing. @dinhanhx maybe it is worth trying master?