Maybe there is a problem in the implementation of the class PrioritizedReplayBuffer
In the file https://github.com/hill-a/stable-baselines/blob/master/stable_baselines/common/buffers.py, line 206,
total = self._it_sum.sum(0, len(self._storage) - 1)
computes the total priority by passing len(self._storage) - 1 as the end parameter of self._it_sum.sum:
def _sample_proportional(self, batch_size):
    mass = []
    total = self._it_sum.sum(0, len(self._storage) - 1)
    # TODO(szymon): should we ensure no repeats?
    mass = np.random.random(size=batch_size) * total
    idx = self._it_sum.find_prefixsum_idx(mass)
    return idx
But in the file https://github.com/hill-a/stable-baselines/blob/master/stable_baselines/common/segment_tree.py, line 75 (end -= 1) in the function reduce, which self._it_sum.sum calls, also subtracts 1:
def reduce(self, start=0, end=None):
    """
    Returns result of applying `self.operation`
    to a contiguous subsequence of the array.
        self.operation(arr[start], operation(arr[start+1], operation(... arr[end])))
    :param start: (int) beginning of the subsequence
    :param end: (int) end of the subsequences
    :return: (Any) result of reducing self.operation over the specified range of array elements.
    """
    if end is None:
        end = self._capacity
    if end < 0:
        end += self._capacity
    end -= 1
    return self._reduce_helper(start, end, 1, 0, self._capacity - 1)
So isn't 1 being subtracted twice?
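Because _reduce_helper works on an inclusive range, reduce treats end as exclusive (like a Python slice) and decrements it once itself. The hypothetical reduce_sum below mirrors only that argument handling on a plain list (it is not the library's segment tree); the capacity of 16 and the ten priorities of 1.0 are assumptions chosen to match the test in this issue:

```python
def reduce_sum(arr, start=0, end=None):
    # Hypothetical stand-in that mirrors SegmentTree.reduce's handling of `end`:
    # `end` is exclusive, so it is decremented once before summing the
    # inclusive range [start, end].
    capacity = len(arr)
    if end is None:
        end = capacity
    if end < 0:
        end += capacity
    end -= 1
    return sum(arr[start:end + 1])


# Assumed capacity 16: ten stored transitions with priority 1.0, empty slots at 0.0.
tree = [1.0] * 10 + [0.0] * 6
n_stored = 10

print(reduce_sum(tree, 0, n_stored - 1))  # 9.0: caller's "- 1" plus reduce's "- 1"
print(reduce_sum(tree, 0, n_stored))      # 10.0: every stored priority
```

Under this reading, the caller should pass len(self._storage) and let reduce do the single decrement.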
I verified this with the following quick test:
import numpy as np
from stable_baselines.common.buffers import PrioritizedReplayBuffer

buffer = PrioritizedReplayBuffer(100, 0.6)  # size=100, alpha=0.6
x = np.array([1.])
for _ in range(10):
    buffer.add(x, x, x, x, x)  # obs_t, action, reward, obs_tp1, done

print(buffer._it_sum.sum(0, len(buffer._storage) - 1))  # result: 9.0
print(buffer._it_sum.sum(0, len(buffer._storage)))  # result: 10.0
Changing len(buffer._storage) - 1 to len(buffer._storage) gives the correct result.
Since I added 10 transitions to the buffer, each stored with priority 1.0, I expect the total priority to be 10.
If I have misunderstood the code, please let me know.
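To see what the smaller total does to sampling, here is a hypothetical flat-array sketch of proportional sampling: np.cumsum stands in for the sum tree's prefix sums and np.searchsorted for find_prefixsum_idx. The function name sample_proportional and the ten uniform priorities are assumptions for illustration, not the library's code:

```python
import numpy as np


def sample_proportional(priorities, batch_size, total, rng):
    # Hypothetical stand-in for SumSegmentTree: np.cumsum gives the prefix sums
    # and np.searchsorted replaces find_prefixsum_idx.
    prefix = np.cumsum(priorities)
    # Uniform "mass" in [0, total), as in _sample_proportional.
    mass = rng.random(batch_size) * total
    # side="right" returns i with prefix[i - 1] <= mass < prefix[i],
    # i.e. the highest i such that sum(priorities[:i]) <= mass.
    return np.searchsorted(prefix, mass, side="right")


priorities = np.ones(10)
rng = np.random.default_rng(0)

# Total over ALL stored priorities.
fixed = sample_proportional(priorities, 10_000, priorities.sum(), rng)
# Total that skips the last priority, as the double "- 1" does.
buggy = sample_proportional(priorities, 10_000, priorities[:-1].sum(), rng)

print(sorted(set(fixed.tolist())))  # every index 0..9 appears
print(sorted(set(buggy.tolist())))  # index 9 never appears
```

With the smaller total, mass never reaches the last prefix interval, so the highest occupied index cannot be drawn.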
Top GitHub Comments
@Jogima-cyber
Dare I say most of it works, except that the last added sample is not included in the random sampling process. Given the number of samples in the buffer, this seems like a minuscule error (which should still be fixed!), but I cannot say for sure whether the effect on learning is small.
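The size of the effect can be made exact with a small hypothetical helper, effective_probs (not part of stable-baselines), that computes each index's sampling probability when mass is drawn uniformly from [0, total). It assumes total is at most the sum of all priorities. While the buffer is not yet full, the excluded slot is the most recently added transition; once it has wrapped around, the excluded slot is index capacity - 1, whichever transition it holds:

```python
from fractions import Fraction


def effective_probs(priorities, total):
    # Index i owns the interval [prefix_i, prefix_i + p_i); its probability is
    # the overlap of that interval with [0, total), divided by total.
    probs = []
    lo = Fraction(0)
    for p in priorities:
        hi = lo + p
        overlap = max(Fraction(0), min(hi, total) - lo)
        probs.append(overlap / total)
        lo = hi
    return probs


priorities = [Fraction(1)] * 10
buggy_total = sum(priorities[:-1])  # 9: the total computed with the extra "- 1"
probs = effective_probs(priorities, buggy_total)
print(probs[-1], probs[0])  # 0 1/9
```

Every other transition is sampled slightly more often (1/9 instead of 1/10 here) to make up for the excluded one.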
@UPUPGOO
Any update on a PR for this? I am asking to check whether somebody is already working on this and wants to make a PR out of it. If not, I can add it.
By the way, the code that calculates the weights in the function sample of the class PrioritizedReplayBuffer can be simplified. Change
to
This follows from a simple mathematical derivation.
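The snippets that belonged after "Change" and "to" did not survive. Assuming the weights follow the importance-sampling definition from the PER paper, and that the sum tree stores q_i = p_i^α (so P(i) = q_i / Σ_k q_k), one derivation consistent with the claim is sketched below; it is a reconstruction, not the commenter's code:

```latex
w_i = \frac{\left(N\,P(i)\right)^{-\beta}}{\max_j \left(N\,P(j)\right)^{-\beta}},
\qquad P(i) = \frac{q_i}{\sum_k q_k}

% For beta >= 0 the maximum is attained at the smallest probability,
% P_min = q_min / sum_k q_k, so both N and sum_k q_k cancel:
w_i = \left(\frac{N\,P(i)}{N\,P_{\min}}\right)^{-\beta}
    = \left(\frac{q_i}{q_{\min}}\right)^{-\beta}
```

The simplified form needs only the sampled tree values and the tree minimum, with no total sum and no buffer length.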
I also ran some experiments to verify this with the following code. The results are all 0. The original code follows the paper's formula more closely, but I think the simplified code is much faster.
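The experiment code itself is not shown, so here is a hypothetical equivalent on plain NumPy arrays: a long form modeled on the paper's formula and the cancelled form (q_i / q_min)^(-β). The function names, the random stand-in priorities and β = 0.4 are all assumptions for illustration:

```python
import numpy as np


def weights_long_form(tree_values, idxes, beta):
    # PER-style weights: (N * P(i)) ** -beta, normalised by the largest weight,
    # which belongs to the smallest P(j).
    n = len(tree_values)
    total = tree_values.sum()
    p_min = tree_values.min() / total
    max_weight = (p_min * n) ** (-beta)
    p_sample = tree_values[idxes] / total
    return (p_sample * n) ** (-beta) / max_weight


def weights_short_form(tree_values, idxes, beta):
    # N and the total cancel out, leaving (q_i / q_min) ** -beta.
    return (tree_values[idxes] / tree_values.min()) ** (-beta)


rng = np.random.default_rng(0)
tree_values = rng.uniform(0.1, 5.0, size=1000)  # stand-in for stored p_i ** alpha
idxes = rng.integers(0, len(tree_values), size=256)

w_long = weights_long_form(tree_values, idxes, beta=0.4)
w_short = weights_short_form(tree_values, idxes, beta=0.4)
print(np.abs(w_long - w_short).max())  # zero up to floating-point rounding
```

In floating point the two forms can differ in the last few bits, so comparing with np.allclose is safer than expecting exact zeros.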