Current ONNX opsets don't support exporting Stable Baselines3 policies natively; a wrapper class is required.
I am interested in using stable-baselines to train an agent and then export it through ONNX so I can run it inside the Unity engine via Barracuda. I was hoping to write up the documentation for this too!
Unfortunately, neither opset 9 nor opset 12 in ONNX seems to support exporting trained policies. The export fails with:
RuntimeError: Exporting the operator broadcast_tensors to ONNX opset version 9 is not supported.
Please open a bug to request ONNX export support for the missing operator.
Although broadcast_tensors isn't called explicitly anywhere in the codebase, it may come from the use of torch.distributions. Unfortunately, this seems to have been an open issue since November 2019, so I am pessimistic about it being solved soon.
It's unlikely, but do you think there might be a way around this? Either way, I wanted to raise this issue so the team is aware.
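A minimal sketch of the wrapper approach named in the title, written in pure PyTorch so it runs without SB3. It assumes the broadcast_tensors call comes from building a torch.distributions.Normal (its constructor broadcasts loc and scale together), so the wrapper keeps only the trained layers that compute the deterministic action (the Gaussian mean) and never constructs a distribution. ToyGaussianPolicy, OnnxableActor, export_to_onnx and the attribute names latent and action_net are hypothetical stand-ins, not SB3's actual classes; observation preprocessing and action rescaling are not included.

```python
import torch
import torch.nn as nn


class ToyGaussianPolicy(nn.Module):
    """Hypothetical stand-in for a trained Gaussian policy."""

    def __init__(self, obs_dim: int, act_dim: int):
        super().__init__()
        self.latent = nn.Sequential(nn.Linear(obs_dim, 64), nn.Tanh())
        self.action_net = nn.Linear(64, act_dim)  # outputs the Gaussian mean
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        mean = self.action_net(self.latent(obs))
        # Building the distribution broadcasts mean and std together,
        # which is the assumed source of the broadcast_tensors op.
        dist = torch.distributions.Normal(mean, self.log_std.exp())
        return dist.sample()


class OnnxableActor(nn.Module):
    """Exposes only the deterministic path: obs -> latent -> mean action."""

    def __init__(self, policy: ToyGaussianPolicy):
        super().__init__()
        # Reuse the trained submodules; do not construct new ones.
        self.latent = policy.latent
        self.action_net = policy.action_net

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.action_net(self.latent(obs))


def export_to_onnx(module: nn.Module, obs_dim: int, path: str, opset: int = 9) -> None:
    """Trace the wrapper with a dummy batch of one observation and write an ONNX file."""
    module.eval()
    dummy_obs = torch.randn(1, obs_dim)
    torch.onnx.export(module, dummy_obs, path, opset_version=opset,
                      input_names=["obs"], output_names=["action"])
```

Usage (hypothetical): export_to_onnx(OnnxableActor(policy), obs_dim=4, path="actor.onnx"). Whether a given opset (9 here, as in the issue) is accepted depends on the installed PyTorch version and exporter.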
Checklist
- I have read the documentation (required)
- I have checked that there is no similar issue in the repo (required)
Issue Analytics
- Created 2 years ago
- Comments: 19 (8 by maintainers)
Related documentation:
Exporting models (Stable Baselines3 1.7.0a5 documentation): Stable Baselines3 does not include tools to export models to other frameworks, but this document aims to cover parts that are required for...
Top GitHub Comments
The "after the fact" instantiate option. I'm still not sure where exactly the broadcast is, but I guess it's not in these 3 modules :p
Thanks =), but it seems that you are passing a randomly initialized features extractor instead of the trained one:
model.policy.make_features_extractor()
instead of
model.policy.actor.features_extractor
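The fix suggested in this comment can be sketched as a deterministic SAC-style actor wrapper that stores references to the already trained submodules. This assumes the default (non-gSDE) SAC actor, whose deterministic action is tanh(mu(latent_pi(features))); the argument names mirror the actor attributes features_extractor, latent_pi and mu, but OnnxableSACActor itself is hypothetical. Observation preprocessing and rescaling from [-1, 1] to the environment's action bounds are left out.

```python
import torch
import torch.nn as nn


class OnnxableSACActor(nn.Module):
    """Deterministic SAC-style actor: features -> latent_pi -> mu -> tanh."""

    def __init__(self, features_extractor: nn.Module, latent_pi: nn.Module, mu: nn.Module):
        super().__init__()
        # Keep references to the TRAINED modules (e.g. the actor's own
        # features_extractor), not freshly constructed, randomly initialized ones.
        self.features_extractor = features_extractor
        self.latent_pi = latent_pi
        self.mu = mu

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        features = self.features_extractor(obs)
        # tanh squashes the mean into [-1, 1]; rescaling to the
        # environment's action bounds is not included here.
        return torch.tanh(self.mu(self.latent_pi(features)))
```

With a trained SAC model this would presumably be built as OnnxableSACActor(model.policy.actor.features_extractor, model.policy.actor.latent_pi, model.policy.actor.mu); passing model.policy.make_features_extractor() instead would silently export untrained feature weights.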