# Source: TD3-robot-nav-irsim / replay_buffer.py
# Commit a54010a ("init: td3 robot nav") by TaherFattahi
import random
from collections import deque
import itertools
import numpy as np
class ReplayBuffer(object):
    """Bounded FIFO store of ``(s, a, r, t, s2)`` transition tuples.

    Transitions are appended on the right of ``self.buffer``; once
    ``buffer_size`` of them are held, every new one pushes the oldest
    (leftmost) transition out.
    """

    def __init__(self, buffer_size, random_seed=123):
        """Create an empty buffer.

        Args:
            buffer_size: maximum number of transitions kept.
            random_seed: passed to ``random.seed``. This seeds the
                module-global generator, so it also affects any other code
                in the process that draws from ``random``.
        """
        self.buffer_size = buffer_size
        self.count = 0  # number of transitions currently stored
        self.buffer = deque()
        random.seed(random_seed)

    @staticmethod
    def _to_arrays(experiences):
        """Split transition tuples into five per-field NumPy arrays.

        Field 2 (reward) and field 3 (terminal flag) are reshaped into
        ``(n, 1)`` column vectors; the other fields are stacked as-is.
        ``experiences`` is iterated once per field, so it must be a
        re-iterable container (list or deque), not a one-shot iterator.
        """
        s, a, r, t, s2 = (
            np.array([experience[field] for experience in experiences])
            for field in range(5)
        )
        return s, a, r.reshape(-1, 1), t.reshape(-1, 1), s2

    def add(self, s, a, r, t, s2):
        """Store one transition, evicting the oldest one when the buffer is full."""
        if self.count < self.buffer_size:
            self.count += 1
        else:
            # Full: make room by dropping the oldest (leftmost) transition.
            self.buffer.popleft()
        self.buffer.append((s, a, r, t, s2))

    def size(self):
        """Return the number of transitions currently stored."""
        return self.count

    def sample_batch(self, batch_size):
        """Sample distinct transitions uniformly at random (without replacement).

        If fewer than ``batch_size`` transitions are stored, all of them are
        returned in random order.

        Returns:
            ``(s, a, r, t, s2)`` arrays; ``r`` and ``t`` have shape ``(n, 1)``.
        """
        n_draw = self.count if self.count < batch_size else batch_size
        return self._to_arrays(random.sample(self.buffer, n_draw))

    def return_buffer(self):
        """Return every stored transition, oldest first, as ``(s, a, r, t, s2)`` arrays."""
        return self._to_arrays(self.buffer)

    def clear(self):
        """Remove all stored transitions."""
        self.buffer.clear()
        self.count = 0
class RolloutReplayBuffer(object):
    """Replay buffer that keeps whole episodes and samples state histories.

    ``self.buffer`` is a deque of episodes, each a list of
    ``(s, a, r, t, s2)`` tuples in the order they were added. The rightmost
    list is the episode currently being filled; adding a transition with a
    truthy ``t`` closes it and opens a new empty one. The deque holds at most
    ``buffer_size`` lists *including* the open one, so opening a new episode
    on a full deque evicts the oldest closed episode.
    """

    def __init__(self, buffer_size, random_seed=123, history_len=10):
        """Create an empty buffer.

        Args:
            buffer_size: maximum number of episode lists kept, counting the
                episode in progress.
            random_seed: passed to ``random.seed``; this seeds the
                module-global generator used by ``sample_batch`` (and by any
                other code in the process that draws from ``random``).
            history_len: number of consecutive states returned per sample.
        """
        self.buffer_size = buffer_size
        self.count = 0  # number of closed episodes currently stored
        self.buffer = deque(maxlen=buffer_size)
        random.seed(random_seed)
        self.buffer.append([])  # the episode in progress
        self.history_len = history_len

    def add(self, s, a, r, t, s2):
        """Append a transition to the current episode; close it if ``t`` is truthy."""
        self.buffer[-1].append((s, a, r, t, s2))
        if t:
            self.buffer.append([])
            # deque(maxlen=...) may have just evicted the oldest episode, so
            # count what is actually stored (all lists but the open one)
            # rather than incrementing without bound.
            self.count = len(self.buffer) - 1

    def size(self):
        """Return the number of closed episodes available for sampling."""
        return self.count

    def _history_window(self, values):
        """Return the last ``history_len`` entries of ``values``, oldest first.

        When ``values`` is shorter than ``history_len`` it is left-padded by
        repeating its first entry. ``values`` must be non-empty.
        """
        missing = self.history_len - len(values)
        if missing > 0:
            return [values[0]] * missing + values
        return values[len(values) - self.history_len:]

    def sample_batch(self, batch_size):
        """Sample one time step from each of up to ``batch_size`` closed episodes.

        Episodes are drawn without replacement, and the episode still in
        progress is never sampled. Within each drawn episode a time step is
        chosen uniformly at random.

        Returns:
            ``(s, a, r, t, s2)`` where ``s`` / ``s2`` hold, per sample, the
            ``history_len`` states / next states up to and including the
            chosen step (left-padded by repeating the earliest available
            entry when the episode is shorter), ``a`` holds the chosen
            step's action, and ``r`` / ``t`` are ``(n, 1)`` column vectors.
        """
        n_closed = max(len(self.buffer) - 1, 0)
        n_draw = n_closed if n_closed < batch_size else batch_size
        # Sample positions rather than copying the whole deque into a list
        # on every call; only the chosen episodes are fetched.
        batch = [self.buffer[j] for j in random.sample(range(n_closed), n_draw)]
        # Closed episodes always contain at least their terminal transition.
        steps = [random.randint(0, len(episode) - 1) for episode in batch]

        s_batch, s2_batch = [], []
        for episode, step in zip(batch, steps):
            upto = episode[: step + 1]
            s_batch.append(self._history_window([tr[0] for tr in upto]))
            s2_batch.append(self._history_window([tr[4] for tr in upto]))

        chosen = [episode[step] for episode, step in zip(batch, steps)]
        a_batch = np.array([tr[1] for tr in chosen])
        r_batch = np.array([tr[2] for tr in chosen]).reshape(-1, 1)
        t_batch = np.array([tr[3] for tr in chosen]).reshape(-1, 1)
        return np.array(s_batch), a_batch, r_batch, t_batch, np.array(s2_batch)

    def return_buffer(self):
        """Return every stored transition as ``(s, a, r, t, s2)`` arrays.

        Episodes are flattened oldest first, including the transitions of
        the episode still in progress; ``r`` and ``t`` are ``(n, 1)`` column
        vectors.
        """
        # self.buffer holds episodes (lists of transitions), not transitions.
        transitions = list(itertools.chain.from_iterable(self.buffer))
        s = np.array([tr[0] for tr in transitions])
        a = np.array([tr[1] for tr in transitions])
        r = np.array([tr[2] for tr in transitions]).reshape(-1, 1)
        t = np.array([tr[3] for tr in transitions]).reshape(-1, 1)
        s2 = np.array([tr[4] for tr in transitions])
        return s, a, r, t, s2

    def clear(self):
        """Remove all episodes and open a fresh, empty episode."""
        self.buffer.clear()
        # add() writes to self.buffer[-1], so an open episode must always exist.
        self.buffer.append([])
        self.count = 0