[Bug] SAC Policy with loaded weights returns incorrect value
After loading external weights into an SAC policy (latent_pi and mu), the SAC policy's output does not match the expected output from a test model (not using SAC) loaded with the same weights. I measured these results from sac/policies.py::get_action_dist_params to avoid having log_std affect the comparison.
Looking at it in my debugger, the values appear to be placed correctly into latent_pi and mu, but the output isn't quite what it should be. Searching online suggests some people have seen occasional problems loading/saving PyTorch models; I'm not sure if that's what's happening here.
input:
tensor([[ 0.0000e+00, 4.6363e-01, 2.8770e-01, 3.3224e-02, 6.4662e-01,
-3.1398e-01, -1.4604e-01, 4.7870e-04, -7.0000e-05, -2.0692e-02,
-5.1917e-01, 5.2600e-01, 3.4382e-03, 0.0000e+00, 0.0000e+00,
1.0000e+00, 0.0000e+00, 1.0000e+00, -5.5713e-01, -3.4492e-01,
4.4638e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.2450e-01,
-5.3079e-01, 4.4315e-02, -5.0478e-04, -4.4743e-03, -1.7209e-02]],
device='cuda:0')
expected output:
tensor([[ 0.2665, -0.0349, -0.1470, 0.0130, -0.1034, 0.0865, 0.2703, 0.0194]],
device='cuda:0', grad_fn=<TanhBackward>)
output:
tensor([[ 0.2731, -0.0350, -0.1481, 0.0131, -0.1038, 0.0867, 0.2772, 0.0194]],
device='cuda:0')
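For reference, a minimal sketch of the comparison described above, assuming external_model is the plain Sequential network shown below and model is the SAC model with the copied weights; get_action_dist_params returning (mean_actions, log_std, kwargs) matches recent stable-baselines3 versions, but check yours:
import torch
# Minimal comparison sketch (names taken from this issue; exact API may differ by SB3 version).
obs = torch.zeros((1, 30), device="cuda:0")  # stand-in for the observation printed above
with torch.no_grad():
    expected = external_model(obs)                                  # already tanh-squashed
    mean_actions, log_std, _ = model.actor.get_action_dist_params(obs)
    actual = torch.tanh(mean_actions)                               # squash the SAC mean the same way
print(torch.allclose(expected, actual, atol=1e-4))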
Network structure:
external model
torch.nn.Sequential(
    torch.nn.Linear(30, 512), torch.nn.ReLU(),
    torch.nn.Linear(512, 512), torch.nn.ReLU(),
    torch.nn.Linear(512, 512), torch.nn.ReLU(),
    torch.nn.Linear(512, 8), torch.nn.Tanh())
SAC
policy_kwargs = dict(net_arch=dict(pi=[512, 512, 512], qf=[256, 256]))
(input is 30, output is 8)
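For reference, a sketch of how that model might be constructed and inspected; env is a placeholder environment with a 30-dim observation space and an 8-dim action space:
from stable_baselines3 import SAC
policy_kwargs = dict(net_arch=dict(pi=[512, 512, 512], qf=[256, 256]))
model = SAC("MlpPolicy", env, policy_kwargs=policy_kwargs)
# With this net_arch, the actor should contain:
#   model.actor.latent_pi -> Linear(30, 512), ReLU, Linear(512, 512), ReLU, Linear(512, 512), ReLU
#   model.actor.mu        -> Linear(512, 8)
# i.e. the external network above minus its final Tanh, which SAC applies when squashing actions.
print(model.actor.latent_pi)
print(model.actor.mu)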
Model weights loaded and renamed to match the SAC naming convention:
checkpoint = torch.load(unzippedWeights)
# hold modified checkpoint data for loading
piCheckpoint, muCheckpoint = renameCheckpointKeys(checkpoint)
model.actor.latent_pi.load_state_dict(piCheckpoint)
model.actor.mu.load_state_dict(muCheckpoint)
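renameCheckpointKeys is not shown in the issue; a hypothetical sketch of what it could look like, assuming the external state_dict uses the default Sequential key names (indices 0/2/4/6 for the Linear layers):
def renameCheckpointKeys(checkpoint):
    # Hypothetical mapping: Sequential indices 0, 2, 4 (hidden Linear layers) feed latent_pi,
    # which happens to use the same indices; index 6 (the final Linear before Tanh) becomes mu.
    piCheckpoint = {k: v for k, v in checkpoint.items()
                    if k.split(".")[0] in ("0", "2", "4")}
    muCheckpoint = {"weight": checkpoint["6.weight"], "bias": checkpoint["6.bias"]}
    return piCheckpoint, muCheckpoint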
Top GitHub Comments
not sure what you mean… The deterministic output is the mean of the Gaussian squashed using tanh.
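A rough restatement of that point (illustrative only, not SB3's exact code):
import torch
# The policy outputs a Gaussian over pre-squash actions; the deterministic action is
# its mean passed through tanh, while a stochastic action is a tanh-squashed sample.
mu = torch.tensor([0.3, -0.1])
log_std = torch.tensor([-1.0, -1.0])
dist = torch.distributions.Normal(mu, log_std.exp())
stochastic_action = torch.tanh(dist.rsample())
deterministic_action = torch.tanh(mu)
print(stochastic_action, deterministic_action)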
This is not an issue of stochasticity. You may take a look at this paper: https://arxiv.org/abs/2006.09359, where they show that pre-training with behavior cloning does not really help (because you will run into the same issues as offline RL as soon as you start training). You may give residual RL a try though 😉
Closing this, as the original issue seems to be resolved.
Ah, using model.predict() yields the correct output:
[[ 0.26650167 -0.0349468 -0.1470437 0.01304841 -0.10342735 0.08647192, 0.2703234 0.01935792]]
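A sketch of that call (whether deterministic was set in the original run is not shown here; with deterministic=True, predict returns the tanh-squashed mean, converted back to the environment's action space):
import numpy as np
obs = np.zeros(30, dtype=np.float32)  # stand-in for the observation printed above
action, _states = model.predict(obs, deterministic=True)  # deterministic flag assumed
print(action)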
So it appears the non-determinism is the factor here. If that's the case though, why did the mean_action above not give the correct result?
Regardless, the agent doesn't seem to be taking proper advantage of the pre-trained strategy, as the non-determinism destroys its useful behavior. Maybe I need to try TD3 (deterministic) or figure out a way to adjust log_std so it starts with decent behavior.
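If it helps, a hedged sketch of one way to bias log_std low at the start, assuming the non-SDE SAC actor exposes its log_std head as model.actor.log_std (a Linear layer in recent SB3 versions); this is an illustration of the "play with log_std" idea, not a maintainer recommendation:
import torch
# Hypothetical tweak: initialise the log_std head so the initial policy std is small,
# keeping early stochastic actions close to the pre-trained mean.
with torch.no_grad():
    model.actor.log_std.weight.fill_(0.0)
    model.actor.log_std.bias.fill_(-3.0)   # std ~= exp(-3) ~= 0.05 before clamping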