"""Pre-training utilities: fill a replay buffer with transitions recorded in
YAML files and optionally run offline training steps on those experiences."""

from typing import List

import yaml
from tqdm import tqdm

from replay_buffer import ReplayBuffer, RolloutReplayBuffer


class Pretraining:
    def __init__(
        self,
        file_names: List[str],
        model: object,
        replay_buffer: object,
        reward_function,
    ):
        self.file_names = file_names
        self.model = model
        self.replay_buffer = replay_buffer
        self.reward_function = reward_function

    def load_buffer(self):
        """Convert recorded samples into (s, a, r, done, s') transitions."""
        for file_name in self.file_names:
            print("Loading file: ", file_name)
            with open(file_name, "r") as file:
                samples = yaml.full_load(file)

            # Pair each sample with its successor to form a transition.
            for i in tqdm(range(1, len(samples) - 1)):
                sample = samples[i]
                latest_scan = sample["latest_scan"]
                distance = sample["distance"]
                cos = sample["cos"]
                sin = sample["sin"]
                collision = sample["collision"]
                goal = sample["goal"]
                action = sample["action"]
                state, terminal = self.model.prepare_state(
                    latest_scan, distance, cos, sin, collision, goal, action
                )
                # A terminal sample has no valid successor state; skip it.
                if terminal:
                    continue

                next_sample = samples[i + 1]
                next_latest_scan = next_sample["latest_scan"]
                next_distance = next_sample["distance"]
                next_cos = next_sample["cos"]
                next_sin = next_sample["sin"]
                next_collision = next_sample["collision"]
                next_goal = next_sample["goal"]
                next_action = next_sample["action"]
                next_state, next_terminal = self.model.prepare_state(
                    next_latest_scan,
                    next_distance,
                    next_cos,
                    next_sin,
                    next_collision,
                    next_goal,
                    next_action,
                )
                # The reward is computed from the outcome of taking `action`.
                reward = self.reward_function(
                    next_goal, next_collision, action, next_latest_scan
                )
                self.replay_buffer.add(
                    state, action, reward, next_terminal, next_state
                )
        return self.replay_buffer

    def train(
        self,
        pretraining_iterations,
        replay_buffer,
        iterations,
        batch_size,
    ):
        """Run offline training steps on the pre-filled replay buffer."""
        print("Running Pretraining")
        for _ in tqdm(range(pretraining_iterations)):
            self.model.train(
                replay_buffer=replay_buffer,
                iterations=iterations,
                batch_size=batch_size,
            )
        print("Model Pretrained")


def get_buffer(
    model,
    sim,
    load_saved_buffer,
    pretrain,
    pretraining_iterations,
    training_iterations,
    batch_size,
    buffer_size=50000,
    random_seed=666,
    file_names=["assets/data.yml"],
    history_len=10,
):
    replay_buffer = ReplayBuffer(buffer_size=buffer_size, random_seed=random_seed)

    if pretrain:
        assert (
            load_saved_buffer
        ), "To pre-train the model, load_saved_buffer must be set to True"

    if load_saved_buffer:
        # Instantiate the pre-training helper and fill the buffer with
        # experiences from the recorded data file(s).
        pretraining = Pretraining(
            file_names=file_names,
            model=model,
            replay_buffer=replay_buffer,
            reward_function=sim.get_reward,
        )
        replay_buffer = pretraining.load_buffer()
        if pretrain:
            # Run offline pre-training on the loaded experiences.
            pretraining.train(
                pretraining_iterations=pretraining_iterations,
                replay_buffer=replay_buffer,
                iterations=training_iterations,
                batch_size=batch_size,
            )
    return replay_buffer
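
# Example usage: a minimal smoke test with stub `model` and `sim` objects that
# satisfy the interfaces `get_buffer` relies on (`model.prepare_state`,
# `model.train`, `sim.get_reward`). The stubs, their state encoding, and the
# reward values are illustrative assumptions, not part of this module's API;
# a real run would pass the project's actual policy model and simulation
# environment. It also assumes the default assets/data.yml recording exists.
if __name__ == "__main__":

    class _StubModel:
        def prepare_state(
            self, latest_scan, distance, cos, sin, collision, goal, action
        ):
            # Flatten the raw readings into a feature vector; treat reaching
            # the goal or colliding as the terminal flag.
            state = list(latest_scan) + [distance, cos, sin] + list(action)
            return state, int(collision or goal)

        def train(self, replay_buffer, iterations, batch_size):
            pass  # a real model would sample batches here and update networks

    class _StubSim:
        def get_reward(self, goal, collision, action, laser_scan):
            # Sparse placeholder reward: bonus on goal, penalty on collision.
            if goal:
                return 100.0
            if collision:
                return -100.0
            return 0.0

    buffer = get_buffer(
        model=_StubModel(),
        sim=_StubSim(),
        load_saved_buffer=True,  # must be True whenever pretrain=True
        pretrain=True,
        pretraining_iterations=2,
        training_iterations=10,
        batch_size=32,
    )
    print("Pre-training smoke test finished")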