Error when saving hyperparameters with torch.dtype as parameter
🐛 Bug
When a `torch.dtype` object is passed into a `LightningModule` constructor, and then `self.save_hyperparameters()` is called, the code errors out with `ValueError: dictionary update sequence element #0 has length 1; 2 is required`.
To Reproduce
I have reproduced this issue in this BoringModel.
- Create a `LightningModule` that takes a `torch.dtype` in its constructor.
- Add `self.save_hyperparameters()` within the `__init__()`.
- Create a trainer and run `trainer.fit()` (a minimal sketch is shown below).
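For reference, a minimal sketch along the lines of those steps (the layer shape, random data, and trainer flags here are assumptions, not the exact BoringModel from the Colab):

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class BoringModel(pl.LightningModule):
    def __init__(self, my_dtype: torch.dtype = torch.float32):
        super().__init__()
        # Capturing a torch.dtype constructor argument is what later breaks
        # when the logger tries to write hparams.yaml.
        self.save_hyperparameters()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        return self(batch[0]).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
trainer = pl.Trainer(max_epochs=1, enable_model_summary=False)
trainer.fit(BoringModel(), train_dataloaders=train_data)  # errors while logging hyperparameters
```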
Expected behavior
The hyperparameters should be saved and the model should continue without error.
Environment
From the Colab:
- CUDA:
  - GPU:
    - Tesla T4
  - available: True
  - version: 11.3
- Packages:
  - numpy: 1.21.6
  - pyTorch_debug: False
  - pyTorch_version: 1.11.0+cu113
  - pytorch-lightning: 1.6.3
  - tqdm: 4.64.0
- System:
  - OS: Linux
  - architecture:
    - 64bit
  - processor: x86_64
  - python: 3.7.13
  - version: #1 SMP Sun Apr 24 10:03:06 PDT 2022
My own machine:
- PyTorch Lightning Version (e.g., 1.5.0): 1.6.3
- PyTorch Version (e.g., 1.10): 1.11.0
- Python version (e.g., 3.9): 3.8
- OS (e.g., Linux): Linux
- CUDA/cuDNN version: 11.6
- GPU models and configuration: GTX 1080
- How you installed PyTorch (`conda`, `pip`, source): conda
- If compiling from source, the output of `torch.__config__.show()`:
- Any other relevant information:
  - torchmetrics version: 0.8.2
Additional context
This seems similar to Issue #9318, but that issue is marked as solved.
This issue from metrics may also be related, but that issue has already been addressed and a patch merged.
Here is the trace from the Colab, for reference:
<ipython-input-6-66d757539e81> in run()
14 enable_model_summary=False,
15 )
---> 16 trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
17 trainer.test(model, dataloaders=test_data)
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
767 self.strategy.model = model
768 self._call_and_handle_interrupt(
--> 769 self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
770 )
771
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
719 return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
720 else:
--> 721 return trainer_fn(*args, **kwargs)
722 # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
723 except KeyboardInterrupt as exception:
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
807 ckpt_path, model_provided=True, model_connected=self.lightning_module is not None
808 )
--> 809 results = self._run(model, ckpt_path=self.ckpt_path)
810
811 assert self.state.stopped
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
1220 self._call_lightning_module_hook("on_fit_start")
1221
-> 1222 self._log_hyperparams()
1223
1224 if self.strategy.restore_checkpoint_after_setup:
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _log_hyperparams(self)
1290 logger.log_hyperparams(hparams_initial)
1291 logger.log_graph(self.lightning_module)
-> 1292 logger.save()
1293
1294 def _teardown(self):
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/rank_zero.py in wrapped_fn(*args, **kwargs)
30 def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
31 if rank_zero_only.rank == 0:
---> 32 return fn(*args, **kwargs)
33 return None
34
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loggers/tensorboard.py in save(self)
264 # save the metatags file if it doesn't exist and the log directory exists
265 if self._fs.isdir(dir_path) and not self._fs.isfile(hparams_file):
--> 266 save_hparams_to_yaml(hparams_file, self.hparams)
267
268 @rank_zero_only
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/core/saving.py in save_hparams_to_yaml(config_yaml, hparams, use_omegaconf)
400 try:
401 v = v.name if isinstance(v, Enum) else v
--> 402 yaml.dump(v)
403 except TypeError:
404 warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")
/usr/local/lib/python3.7/dist-packages/yaml/__init__.py in dump(data, stream, Dumper, **kwds)
251 If stream is None, return the produced string instead.
252 """
--> 253 return dump_all([data], stream, Dumper=Dumper, **kwds)
254
255 def safe_dump_all(documents, stream=None, **kwds):
/usr/local/lib/python3.7/dist-packages/yaml/__init__.py in dump_all(documents, stream, Dumper, default_style, default_flow_style, canonical, indent, width, allow_unicode, line_break, encoding, explicit_start, explicit_end, version, tags, sort_keys)
239 dumper.open()
240 for data in documents:
--> 241 dumper.represent(data)
242 dumper.close()
243 finally:
/usr/local/lib/python3.7/dist-packages/yaml/representer.py in represent(self, data)
25
26 def represent(self, data):
---> 27 node = self.represent_data(data)
28 self.serialize(node)
29 self.represented_objects = {}
/usr/local/lib/python3.7/dist-packages/yaml/representer.py in represent_data(self, data)
50 for data_type in data_types:
51 if data_type in self.yaml_multi_representers:
---> 52 node = self.yaml_multi_representers[data_type](self, data)
53 break
54 else:
/usr/local/lib/python3.7/dist-packages/yaml/representer.py in represent_object(self, data)
328 listitems = list(listitems)
329 if dictitems is not None:
--> 330 dictitems = dict(dictitems)
331 if function.__name__ == '__newobj__':
332 function = args[0]
ValueError: dictionary update sequence element #0 has length 1; 2 is required
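The last frames above suggest the failure is really `yaml.dump` choking on the `torch.dtype` value rather than anything Lightning-specific. Until that is fixed, one possible workaround (my own suggestion, not from the issue; the module and parameter name below are made up) is to keep the dtype out of the saved hyperparameters:

```python
import torch
import yaml
import pytorch_lightning as pl

# The bottom of the trace can be reproduced without Lightning: PyYAML falls back
# to its generic object representer for torch.dtype and raises the same ValueError.
try:
    yaml.dump(torch.float16)
except ValueError as err:
    print(err)  # dictionary update sequence element #0 has length 1; 2 is required


class LitModel(pl.LightningModule):
    """Hypothetical module showing a workaround."""

    def __init__(self, target_dtype: torch.dtype = torch.float16):
        super().__init__()
        # Exclude the dtype from the saved hyperparameters so hparams.yaml can be
        # written; alternatively, pass it around as a string (e.g. "float16") and
        # convert it back where needed.
        self.save_hyperparameters(ignore=["target_dtype"])
        self.target_dtype = target_dtype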
Hi!
This is a problem upstream. I suggest you open an issue in https://github.com/pytorch/pytorch.
One improvement would be to also catch `ValueError` here: https://github.com/PyTorchLightning/pytorch-lightning/blob/dd475183227644a8d22dca3deb18c99fb0a9b2c4/pytorch_lightning/core/saving.py#L427, but it wouldn't get saved anyway.

Closing. Please track https://github.com/pytorch/pytorch/issues/78720 instead.
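For illustration, a minimal sketch of what that suggestion would look like, mirroring the loop shown in the `saving.py` frame of the trace (the surrounding names and the example hparams are assumptions, not the actual Lightning source):

```python
from enum import Enum
from warnings import warn

import torch
import yaml

hparams = {"target_dtype": torch.float16, "lr": 0.1}

for k, v in hparams.items():
    try:
        v = v.name if isinstance(v, Enum) else v
        yaml.dump(v)
    except (TypeError, ValueError):
        # Broadening the except clause would skip values such as torch.dtype
        # (which currently raise ValueError) instead of crashing, but the value
        # still would not be written to hparams.yaml.
        warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")
```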