Can't save a full model using torch.save (at least with faster-RCNN)
It is not possible to save a full model using the default settings of torch.save (see the stack trace below). This is because remove_internal_model_transforms defines inner functions and assigns them to the model's transform. The default pickle module does not support inner functions.
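The underlying limitation can be reproduced without torch or icevision. The following is a minimal sketch using only the standard library; the names make_local_noop, module_level_noop, and can_pickle are hypothetical helpers introduced for illustration and are not part of icevision.

```python
import pickle


def make_local_noop():
    # Inner function: its qualified name contains "<locals>", so the
    # standard pickle module cannot look it up by name when serializing.
    def noop(x):
        return x

    return noop


def module_level_noop(x):
    # Module-level function: pickle stores a reference (module + name).
    return x


def can_pickle(obj):
    """Return True if the standard pickle module can serialize obj."""
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError):
        # The exact exception type depends on the Python version.
        return False


print("local function picklable:", can_pickle(make_local_noop()))
print("module-level function picklable:", can_pickle(module_level_noop))
```

The same failure should occur for any object that holds a reference to a local function, which appears to be what happens once the no-op transforms are attached to the model.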
Workaround: use the dill module instead, which does support inner functions.
Suggested fix: It does not look as if the inner functions are necessary. If they were moved to standard module-level functions, then the default pickle module should work.
torch.save(model, 'mod.pth', pickle_module=pickle)
causes an error.
torch.save(model, 'mod.pth', pickle_module=dill)
is a workaround.
To Reproduce
torch.save(model, 'mod1-full.pth', pickle_module=pickle)
results in:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-12-50f3761f4f3c> in <module>
----> 1 torch.save(model, 'mod1-full.pth', pickle_module=pickle)
~/anaconda3/envs/dlm/lib/python3.8/site-packages/torch/serialization.py in save(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization)
370 if _use_new_zipfile_serialization:
371 with _open_zipfile_writer(opened_file) as opened_zipfile:
--> 372 _save(obj, opened_zipfile, pickle_module, pickle_protocol)
373 return
374 _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
~/anaconda3/envs/dlm/lib/python3.8/site-packages/torch/serialization.py in _save(obj, zip_file, pickle_module, pickle_protocol)
474 pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
475 pickler.persistent_id = persistent_id
--> 476 pickler.dump(obj)
477 data_value = data_buf.getvalue()
478 zip_file.write_record('data.pkl', data_value, len(data_value))
AttributeError: Can't pickle local object 'remove_internal_model_transforms.<locals>.noop_normalize'
Relevant definition:
def remove_internal_model_transforms(model: GeneralizedRCNN):
    def noop_normalize(image: Tensor) -> Tensor:
        return image

    def noop_resize(
        image: Tensor, target: Optional[Dict[str, Tensor]]
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        return image, target

    model.transform.normalize = noop_normalize
    model.transform.resize = noop_resize
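The suggested fix could look roughly like the sketch below, which moves the two no-op functions to module level. This is an illustration under assumptions, not the actual icevision patch: to keep it runnable without torch, a SimpleNamespace stands in for the GeneralizedRCNN model, and the type annotations are omitted.

```python
import pickle
from types import SimpleNamespace


def noop_normalize(image):
    # Module-level replacement for the inner noop_normalize.
    return image


def noop_resize(image, target):
    # Module-level replacement for the inner noop_resize.
    return image, target


def remove_internal_model_transforms(model):
    # Same assignments as the original, but the functions are now
    # importable by name, so the standard pickle module can reference them.
    model.transform.normalize = noop_normalize
    model.transform.resize = noop_resize


# Hypothetical stand-in for a real model (assumption for this sketch).
model = SimpleNamespace(transform=SimpleNamespace())
remove_internal_model_transforms(model)

# Round-trip through the standard pickle module.
restored = pickle.loads(pickle.dumps(model))
print(restored.transform.normalize("image"))
```

With a real model, the equivalent check would be whether torch.save(model, 'mod.pth') succeeds with the default pickle module after such a change.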
Issue Analytics
- Created 3 years ago
- Comments:7 (1 by maintainers)
Top GitHub Comments
I just ran into this issue in icevision 0.8.0. Saving actually worked, but loading didn't. I was using:
model_type = models.mmdet.retinanet
backbone = model_type.backbones.resnet50_fpn_1x
The workaround with dill worked for me.
We put together a new inference API. This works now.