Proper way to re-load a replay buffer, for resuming training (of agents that use a replay buffer)
Hi @astooke I am trying to write some code that shows how to save and load models for resuming training (not just for inference). For algorithms that use a replay buffer, this requires extra work, because we also need to reload the replay buffer that was in use at the time the model was saved.
The most naive way to save and load a replay buffer does not seem to work. I cloned the repo and made these changes:
```
(rlpyt-astooke) seita@stout:~/rlpyt_astooke (master) $ git diff
diff --git a/linux_cuda10.yml b/linux_cuda10.yml
index fb73927..68ea0c2 100644
--- a/linux_cuda10.yml
+++ b/linux_cuda10.yml
@@ -1,4 +1,4 @@
-name: rlpyt
+name: rlpyt-astooke
 channels:
 - pytorch
 dependencies:
diff --git a/rlpyt/runners/minibatch_rl.py b/rlpyt/runners/minibatch_rl.py
index 6850b93..7dd4741 100644
--- a/rlpyt/runners/minibatch_rl.py
+++ b/rlpyt/runners/minibatch_rl.py
@@ -296,6 +296,18 @@ class MinibatchRlEval(MinibatchRlBase):
         specified log interval.
         """
         n_itr = self.startup()
+
+        import pickle, sys
+        replay_buffer = self.algo.replay_buffer
+        print('saving buffer ...')
+        with open('buffer.pkl', 'wb') as fh:
+            pickle.dump(replay_buffer, fh, protocol=4)
+        print('done with saving buffer. now let us load ...')
+        with open('buffer.pkl', 'rb') as fh:
+            replay_buffer = pickle.load(fh)
+        print('buffer loaded')
+        sys.exit()
+
         with logger.prefix(f"itr #0 "):
             eval_traj_infos, eval_time = self.evaluate_agent(0)
             self.log_diagnostics(0, eval_traj_infos, eval_time)
(rlpyt-astooke) seita@stout:~/rlpyt_astooke (master) $
```
and then ran the script. Within the same Python process, the buffer saves and loads fine:
```
(rlpyt-astooke) seita@stout:~/rlpyt_astooke (master) $ python examples/example_5.py
logger_context received log_dir outside of /home/seita/rlpyt_astooke/data: prepending by /home/seita/rlpyt_astooke/data/local/<yyyymmdd>/
2020-02-25 12:48:50.955948 | dqn_pong_0 Runner master CPU affinity: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15].
2020-02-25 12:48:50.956036 | dqn_pong_0 Runner master Torch threads: 8.
using seed 4803
2020-02-25 12:48:50.958320 | dqn_pong_0 Total parallel evaluation envs: 10.
using seed 4805
using seed 4804
2020-02-25 12:48:51.552859 | dqn_pong_0 Sampler rank 1 initialized, CPU affinity [1], Torch threads 1, Seed 4805
2020-02-25 12:48:51.552953 | dqn_pong_0 Sampler rank 0 initialized, CPU affinity [0], Torch threads 1, Seed 4804
2020-02-25 12:48:54.984253 | dqn_pong_0 Sampler decorrelating envs, max steps: 0
2020-02-25 12:48:55.877700 | dqn_pong_0 Running 750000 iterations of minibatch RL.
2020-02-25 12:48:55.878391 | dqn_pong_0 From sampler batch size 64, training batch size 128, and replay ratio 8, computed 4 updates per iteration.
2020-02-25 12:48:55.878459 | dqn_pong_0 Agent setting min/max epsilon itrs: 781, 15625
2020-02-25 12:48:55.880199 | dqn_pong_0 Frame-based buffer using 4-frame sequences.
saving buffer ...
done with saving buffer. now let us load ...
buffer loaded
```
The buffer file is there and is 16 GB in size (the buffer is pre-allocated):
```
(rlpyt-astooke) seita@stout:~/rlpyt_astooke (master) $ ls -lh buffer.pkl
-rw-rw-r-- 1 seita seita 16G Feb 25 12:49 buffer.pkl
(rlpyt-astooke) seita@stout:~/rlpyt_astooke (master) $
```
However, I cannot do the same re-loading from a separate Python process. For example:
```
(rlpyt-astooke) seita@stout:~/rlpyt_astooke (master) $ ipython
Python 3.7.5 (default, Oct 25 2019, 15:51:11)
Type 'copyright', 'credits' or 'license' for more information
IPython 7.10.2 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import pickle

In [2]: with open('buffer.pkl', 'rb') as fh:
   ...:     buffer = pickle.load(fh)
   ...:
---------------------------------------------------------------------------
UnpicklingError                           Traceback (most recent call last)
<ipython-input-2-97de235d15c6> in <module>
      1 with open('buffer.pkl', 'rb') as fh:
----> 2     buffer = pickle.load(fh)
      3

UnpicklingError: NEWOBJ class argument isn't a type object

In [3]: exit
(rlpyt-astooke) seita@stout:~/rlpyt_astooke (master) $
```
I believe this has to do with the general difficulty of pickling objects whose class definitions are not available to the process doing the loading. (The torch save/load functions, which use pickle under the hood, have similar caveats.)
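As a side note, here is a minimal, self-contained illustration of that class-lookup problem. This is not rlpyt code, and it may not be the exact mechanism behind the `NEWOBJ` error above, just the same general class of failure:

```python
import collections
import pickle
import sys

def make_class():
    # Create a namedtuple class at runtime and register it in this module's
    # namespace, so that pickling-by-reference succeeds in *this* process.
    cls = collections.namedtuple("ObsTuple", ["image", "reward"])
    setattr(sys.modules[cls.__module__], "ObsTuple", cls)
    return cls

cls = make_class()
payload = pickle.dumps(cls(image=1, reward=0.0))  # works: this process registered ObsTuple

# A fresh interpreter that never ran make_class() has no ObsTuple to look up,
# so pickle.loads(payload) there fails with a class-resolution error, analogous
# to the UnpicklingError shown above.
```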
As of now it seems like the way I should save and load replay buffers, at a high level, would be:
- Save the arguments used to initialize the replay buffer.
- Save the samples from the replay buffer separately.
- In the new script, recreate the buffer with the same arguments.
- Load the samples and explicitly assign them to the newly created buffer.
I will need to double-check whether there are other variables to track that depend on time (e.g., prioritization terms). If you have any advice, that would be great. Thanks!
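To make the plan concrete, here is a rough sketch of what I have in mind. To be clear, `save_buffer_state` / `load_buffer_state` are hypothetical helpers (not rlpyt functions), and the attribute names (`samples`, `t`, `_buffer_full`) are my guesses from reading the buffer code, so they may not match exactly:

```python
import numpy as np

def save_buffer_state(replay_buffer, path):
    """Save only the array contents and cursor state, not the buffer object."""
    # Assumes every field of `samples` is a flat numpy array; nested
    # namedarraytuple fields (e.g. structured observations) would need to be
    # flattened into separate keys first.
    arrays = {field: np.asarray(getattr(replay_buffer.samples, field))
              for field in replay_buffer.samples._fields}
    np.savez(path, t=replay_buffer.t, buffer_full=replay_buffer._buffer_full, **arrays)

def load_buffer_state(replay_buffer, path):
    """`replay_buffer` must already be constructed with the same arguments
    (size, example batch, etc.) used when the state was saved."""
    state = np.load(path)
    for field in replay_buffer.samples._fields:
        getattr(replay_buffer.samples, field)[:] = state[field]  # fill pre-allocated arrays
    replay_buffer.t = int(state["t"])                            # restore cursor
    replay_buffer._buffer_full = bool(state["buffer_full"])

# Example usage (hypothetical):
# save_buffer_state(self.algo.replay_buffer, "buffer_state.npz")
# ...later, in a new process, after rebuilding the buffer with the same args:
# load_buffer_state(new_replay_buffer, "buffer_state.npz")
```

A prioritized buffer would presumably also need its priority-tree values saved in the same way.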
Top GitHub Comments
AFAICT, yes, but my case is slightly different from Daniel's.
Ok really good question and I haven’t tried this before.
One possible difficulty is in the use of the `namedarraytuple`, which uses class definitions that can lead to some pickling difficulties, recently talked about in #99. A solution was suggested in that issue and I just posted a first attempt at building that out…check out the `namedtuple_schema` branch. It has classes (https://github.com/astooke/rlpyt/blob/namedtuple_schema/rlpyt/utils/namedtuple_schema.py) for making objects that behave like namedarraytuples without requiring the class definition. Hopefully not too painful to grab that and replace every `X=namedtuple(...)` and `Y=namedarraytuple(..)` definition with `X=NamedTupleSchema(...)` and `Y=NamedArrayTupleSchema(...)`. Then retry the pickle test? Hopefully this would be the most universal solution.
I did work previously to make sure `namedarraytuple` can be pickled/unpickled, for example by always defining them at the module level and never nested inside another class…but I may have fallen short.
EDIT: oh you might also need to go into `buffer_from_example()` and prevent it from using `namedarraytuple_like()`, but have it do the `NamedArrayTuple` equivalent (which I haven't written yet)
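For concreteness, the substitution described above would look roughly like this. This assumes the schema classes on the `namedtuple_schema` branch take the same `(name, field_names)` arguments as `namedtuple` / `namedarraytuple`, which I have not verified, and the field names below are only illustrative:

```python
from rlpyt.utils.collections import namedarraytuple
# On the namedtuple_schema branch (module path taken from the link above):
# from rlpyt.utils.namedtuple_schema import NamedArrayTupleSchema

# Before: a class created at run time; pickling an instance stores a reference
# to this class, which a fresh process may fail to resolve.
SamplesToBuffer = namedarraytuple(
    "SamplesToBuffer", ["observation", "action", "reward", "done"])

# After (assumed drop-in replacement): a schema *object* that carries the field
# names with it, so no class definition needs to be importable at load time.
# SamplesToBuffer = NamedArrayTupleSchema(
#     "SamplesToBuffer", ["observation", "action", "reward", "done"])
```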