import itertools
import random
from collections import deque

import numpy as np


def _stack_transitions(transitions):
    """Split ``(s, a, r, t, s2)`` tuples into five column arrays.

    Args:
        transitions: A re-iterable collection (list, deque, ...) of
            5-tuples ``(s, a, r, t, s2)``.

    Returns:
        ``(s, a, r, t, s2)`` numpy arrays; ``r`` and ``t`` are reshaped
        to column vectors of shape ``(n, 1)``.
    """
    s = np.array([e[0] for e in transitions])
    a = np.array([e[1] for e in transitions])
    r = np.array([e[2] for e in transitions]).reshape(-1, 1)
    t = np.array([e[3] for e in transitions]).reshape(-1, 1)
    s2 = np.array([e[4] for e in transitions])
    return s, a, r, t, s2


class ReplayBuffer(object):
    """Bounded FIFO store of single transitions with uniform sampling."""

    def __init__(self, buffer_size, random_seed=123):
        """Create an empty buffer.

        The right side of the deque contains the most recent experiences.

        Args:
            buffer_size: Maximum number of transitions kept.
            random_seed: Seed passed to ``random.seed``; note that this
                seeds the module-level (global) ``random`` generator.
        """
        self.buffer_size = buffer_size
        self.count = 0
        self.buffer = deque()
        random.seed(random_seed)

    def add(self, s, a, r, t, s2):
        """Append one transition, dropping the oldest one when full."""
        experience = (s, a, r, t, s2)
        if self.count < self.buffer_size:
            self.buffer.append(experience)
            self.count += 1
        else:
            self.buffer.popleft()
            self.buffer.append(experience)

    def size(self):
        """Return the number of transitions currently stored."""
        return self.count

    def sample_batch(self, batch_size):
        """Sample up to ``batch_size`` transitions without replacement.

        When fewer than ``batch_size`` transitions are stored, all of
        them are returned (in random order).

        Returns:
            ``(s_batch, a_batch, r_batch, t_batch, s2_batch)`` arrays;
            ``r_batch`` and ``t_batch`` have shape ``(n, 1)``.
        """
        batch = random.sample(self.buffer, min(self.count, batch_size))
        return _stack_transitions(batch)

    def return_buffer(self):
        """Return every stored transition, oldest first, as column arrays."""
        return _stack_transitions(self.buffer)

    def clear(self):
        """Remove all transitions."""
        self.buffer.clear()
        self.count = 0


class RolloutReplayBuffer(object):
    """Episode-based replay buffer that samples fixed-length state histories.

    ``self.buffer`` is a deque of episodes, each episode being a list of
    ``(s, a, r, t, s2)`` tuples. The last element is always the episode
    currently being recorded; every element before it is a finished
    episode.
    """

    def __init__(self, buffer_size, random_seed=123, history_len=10):
        """Create an empty buffer.

        The right side of the deque contains the most recent experiences.

        Args:
            buffer_size: ``maxlen`` of the episode deque. Because one slot
                is always used by the episode being recorded, at most
                ``buffer_size - 1`` finished episodes are retained.
            random_seed: Seed passed to ``random.seed``; note that this
                seeds the module-level (global) ``random`` generator.
            history_len: Number of consecutive states returned per sample.
        """
        self.buffer_size = buffer_size
        self.count = 0
        self.buffer = deque(maxlen=buffer_size)
        random.seed(random_seed)
        self.buffer.append([])
        self.history_len = history_len

    def add(self, s, a, r, t, s2):
        """Append a transition to the current episode.

        When ``t`` is truthy the episode is closed and a new empty one is
        started, which may evict the oldest stored episode.
        """
        self.buffer[-1].append((s, a, r, t, s2))
        if t:
            self.buffer.append([])
            # The bounded deque may just have evicted the oldest episode,
            # so derive the count from what is actually stored instead of
            # incrementing it without bound.
            self.count = len(self.buffer) - 1

    def size(self):
        """Return the number of finished episodes currently stored."""
        return self.count

    def _complete_episodes(self):
        """Return the finished episodes (all but the one being recorded)."""
        stop = max(len(self.buffer) - 1, 0)
        return list(itertools.islice(self.buffer, 0, stop))

    @staticmethod
    def _history(episode, last, field, history_len):
        """Collect ``field`` over the ``history_len`` steps ending at ``last``.

        If the episode has fewer than ``history_len`` steps up to and
        including ``last``, the result is left-padded with the earliest
        available value so its length is always ``history_len``.
        """
        window = episode[max(0, last + 1 - history_len): last + 1]
        values = [e[field] for e in window]
        missing = history_len - len(values)
        if missing > 0:
            values = [values[0]] * missing + values
        return values

    def sample_batch(self, batch_size):
        """Sample up to ``batch_size`` distinct finished episodes.

        For each sampled episode a random step is chosen; the action,
        reward and terminal flag of that step are returned together with
        the ``history_len`` most recent states (``s``) and next states
        (``s2``) ending at that step.

        Returns:
            ``(s_batch, a_batch, r_batch, t_batch, s2_batch)`` where
            ``s_batch`` / ``s2_batch`` stack one history per sample and
            ``r_batch`` / ``t_batch`` have shape ``(n, 1)``.
        """
        episodes = self._complete_episodes()
        # Never request more episodes than are stored; random.sample
        # raises ValueError when the sample exceeds the population.
        batch = random.sample(episodes, min(batch_size, len(episodes)))
        idx = [random.randint(0, len(episode) - 1) for episode in batch]

        s_batch = [
            self._history(episode, i, 0, self.history_len)
            for episode, i in zip(batch, idx)
        ]
        s2_batch = [
            self._history(episode, i, 4, self.history_len)
            for episode, i in zip(batch, idx)
        ]
        picked = [episode[i] for episode, i in zip(batch, idx)]
        a_batch = np.array([e[1] for e in picked])
        r_batch = np.array([e[2] for e in picked]).reshape(-1, 1)
        t_batch = np.array([e[3] for e in picked]).reshape(-1, 1)
        return np.array(s_batch), a_batch, r_batch, t_batch, np.array(s2_batch)

    def return_buffer(self):
        """Return every stored transition as column arrays.

        Episodes are flattened oldest first; transitions of the episode
        still being recorded are included at the end.
        """
        transitions = list(itertools.chain.from_iterable(self.buffer))
        return _stack_transitions(transitions)

    def clear(self):
        """Remove all episodes and start a fresh, empty one."""
        self.buffer.clear()
        # add() always appends to self.buffer[-1], so keep an open episode.
        self.buffer.append([])
        self.count = 0