
[Bug] SAC Policy with loaded weights returns incorrect value


After loading external weights into an SAC policy (latent_pi and mu), the policy output does not match the output of a reference model loaded with the same weights (one that does not use SAC). I measured these results from sac\policies::get_action_dist_params to avoid having log_std affect the results.

Looking at it in my debugger, it looks like it's correctly placing the values into latent_pi and mu, but the output isn't quite what it should be. Searching online suggests some people have seen occasional problems loading/saving PyTorch models; I'm not sure if that's what's happening here.
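One hedged way to localize this kind of discrepancy is to compare the two models layer by layer on the same input, instead of only comparing the final outputs. A minimal sketch, with plain callables standing in for the actual torch modules (the names `layers_a`/`layers_b` are hypothetical, not from the issue):

```python
def first_divergence(layers_a, layers_b, x, tol=1e-6):
    """Return the index of the first layer whose outputs differ, or None."""
    for i, (fa, fb) in enumerate(zip(layers_a, layers_b)):
        x_a, x_b = fa(x), fb(x)
        if abs(x_a - x_b) > tol:
            return i       # outputs diverge starting at this layer
        x = x_a            # both agree; continue with the shared activation
    return None            # all layers matched within tolerance
```

With real models you would iterate over the corresponding `torch.nn.Sequential` children and compare tensors with `torch.allclose`; the scalar version above just illustrates the bisection idea.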

input:

tensor([[ 0.0000e+00,  4.6363e-01,  2.8770e-01,  3.3224e-02,  6.4662e-01,
         -3.1398e-01, -1.4604e-01,  4.7870e-04, -7.0000e-05, -2.0692e-02,
         -5.1917e-01,  5.2600e-01,  3.4382e-03,  0.0000e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  1.0000e+00, -5.5713e-01, -3.4492e-01,
          4.4638e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.2450e-01,
         -5.3079e-01,  4.4315e-02, -5.0478e-04, -4.4743e-03, -1.7209e-02]],
       device='cuda:0')

expected output:

tensor([[ 0.2665, -0.0349, -0.1470,  0.0130, -0.1034,  0.0865,  0.2703,  0.0194]],
       device='cuda:0', grad_fn=<TanhBackward>)

output:

tensor([[ 0.2731, -0.0350, -0.1481,  0.0131, -0.1038,  0.0867,  0.2772,  0.0194]],
       device='cuda:0')

Network structure:

external model

torch.nn.Sequential(
    torch.nn.Linear(30, 512), torch.nn.ReLU(),
    torch.nn.Linear(512, 512), torch.nn.ReLU(),
    torch.nn.Linear(512, 512), torch.nn.ReLU(),
    torch.nn.Linear(512, 8), torch.nn.Tanh(),
)

SAC

policy_kwargs = dict(net_arch=dict(pi=[512, 512, 512], qf=[256, 256]))

(input is 30, output is 8)

Model weights are loaded and renamed to match the SAC naming convention:

bestWeights.zip

checkpoint = torch.load(unzippedWeights)

# hold modified checkpoint data for loading
piCheckpoint, muCheckpoint = renameCheckpointKeys(checkpoint)
model.actor.latent_pi.load_state_dict(piCheckpoint)
model.actor.mu.load_state_dict(muCheckpoint)
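For reference, a hypothetical sketch of what renameCheckpointKeys might do, assuming the external Sequential checkpoint uses keys "0.weight" through "6.bias" (the exact key layout is an assumption; the issue does not show the helper's body). The three hidden Linear layers map onto actor.latent_pi (itself a Sequential), and the final Linear becomes actor.mu:

```python
def rename_checkpoint_keys(checkpoint):
    """Split an external Sequential state dict into latent_pi/mu dicts."""
    pi_ckpt, mu_ckpt = {}, {}
    for key, value in checkpoint.items():
        layer, param = key.split(".", 1)   # e.g. "0", "weight"
        if int(layer) < 6:
            # hidden layers keep their Sequential-style keys for latent_pi
            pi_ckpt[f"{layer}.{param}"] = value
        else:
            # the last Linear is a single module in SAC, so keys are bare
            mu_ckpt[param] = value
    return pi_ckpt, mu_ckpt
```

Note that the Tanh at the end of the external model has no parameters, so nothing needs to be renamed for it; SAC applies its own tanh squashing after mu.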

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments:5 (3 by maintainers)

Top GitHub Comments

1 reaction
araffin commented, Jan 21, 2021

> If that’s the case though, why was the mean_action above not giving the correct result?

not sure what you mean… The deterministic output is the mean of the Gaussian squashed using tanh.
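The point above can be sketched with scalars standing in for the actual tensors: the deterministic action depends only on the Gaussian mean, and log_std enters only when sampling (a minimal illustration, not the library's implementation):

```python
import math

def deterministic_action(mean):
    # deterministic output: the Gaussian mean squashed through tanh
    return math.tanh(mean)

def stochastic_action(mean, log_std, noise):
    # noise ~ N(0, 1); only here does log_std affect the result
    return math.tanh(mean + math.exp(log_std) * noise)
```

So two policies with identical mean networks can still produce different sampled actions if their log_std heads differ, while their deterministic outputs agree.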

> Regardless, the agent doesn’t seem to be properly taking advantage of the effective pre-trained strategy as the non-determinism destroys its useful behavior.

This is not an issue of stochasticity. You may take a look at this paper: https://arxiv.org/abs/2006.09359, where they show that pre-training with behavior cloning does not really help (because you will run into the same issues as offline RL as soon as you start training). You may give residual RL a try though 😉

Closing this, as the original issue seems to be resolved.

0 reactions
DanielDowns commented, Jan 20, 2021

Ah, using the model.predict() yields a correct output:

[[ 0.26650167 -0.0349468  -0.1470437   0.01304841 -0.10342735  0.08647192  0.2703234   0.01935792]]

So it appears the non-determinism is the factor here. If that’s the case though, why was the mean_action above not giving the correct result?

Regardless, the agent doesn’t seem to be properly taking advantage of the effective pre-trained strategy as the non-determinism destroys its useful behavior. Maybe I need to try TD3 (deterministic) or figure out a way to play with log_std so it has decent behavior.
