diff --git a/stoix/configs/arch/sebulba.yaml b/stoix/configs/arch/sebulba.yaml
index 104f3b98..7051c2bb 100644
--- a/stoix/configs/arch/sebulba.yaml
+++ b/stoix/configs/arch/sebulba.yaml
@@ -17,11 +17,11 @@ learner:
   device_ids: [2,3] # Define which devices to use for the learner.
 
 # Size of the queue for the pipeline where actors push data and the learner pulls data.
-pipeline:
-  samples_per_insert_tolerance_rate: 0.1
-  min_replay_size: 10
+# pipeline:
+#   samples_per_insert_tolerance_rate: 0.1
+#   min_replay_size: 10
 
-# pipeline_queue_size: 10
+pipeline_queue_size: 10
 
 # --- Evaluation ---
 evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select
diff --git a/stoix/configs/system/ff_dqn.yaml b/stoix/configs/system/ff_dqn.yaml
index 4c38263f..9e4e60db 100644
--- a/stoix/configs/system/ff_dqn.yaml
+++ b/stoix/configs/system/ff_dqn.yaml
@@ -3,7 +3,7 @@
 system_name: ff_dqn # Name of the system.
 
 # --- RL hyperparameters ---
-rollout_length: 2 # Number of environment steps per vectorised environment.
+rollout_length: 16 # Number of environment steps per vectorised environment.
 epochs: 16 # Number of sgd steps per rollout.
 warmup_steps: 16 # Number of steps to collect before training.
 total_buffer_size: 1_000_000 # Total effective size of the replay buffer across all devices and vectorised update steps. This means each device has a buffer of size buffer_size//num_devices which is further divided by the update_batch_size. This value must be divisible by num_devices*update_batch_size.
diff --git a/stoix/systems/ppo/sebulba/ff_ppo.py b/stoix/systems/ppo/sebulba/ff_ppo.py
index dce260a3..77d80063 100644
--- a/stoix/systems/ppo/sebulba/ff_ppo.py
+++ b/stoix/systems/ppo/sebulba/ff_ppo.py
@@ -422,6 +422,7 @@ def get_learner_rollout_fn(
     network parameters to a queue for evaluation."""
 
     def learner_rollout(learner_state: CoreLearnerState) -> None:
+        learner_step = 0
         # Loop for the total number of evaluations selected to be performed.
         for _ in range(config.arch.num_evaluation):
             # Create the lists to store metrics and timings for this learning iteration.
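The `rollout_length` bump from 2 to 16 above means each actor now pushes 16 steps per vectorised environment on every pipeline insert; the ff_dqn.py hunks below reason about inserts at exactly this granularity for rate limiting and buffer-size reporting. A minimal sketch of that per-insert arithmetic, with an illustrative `envs_per_actor` (an assumed value, not taken from this diff):

```python
# Per-insert arithmetic implied by the config change above. envs_per_actor is an
# illustrative assumption (in Stoix it would come from the arch config), not a
# value taken from this diff.
rollout_length = 16   # ff_dqn.yaml after this change (was 2)
envs_per_actor = 64   # assumed number of vectorised environments per actor

# Each pipeline.put() carries one rollout from every vectorised environment.
transitions_per_insert = rollout_length * envs_per_actor
print(transitions_per_insert)  # 1024 transitions added to the replay buffer per insert
```

The SampleToInsertRatio limiter introduced in the later hunks reasons in terms of these insert counts (`min_inserts` is derived from `total_batch_size // steps_per_insert`).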
@@ -450,6 +451,7 @@ def learner_rollout(learner_state: CoreLearnerState) -> None:
                     with RecordTimeTo(learner_timings["learning_time"]):
                         learner_state, train_metrics = learn_step(learner_state, traj_batch)
+                    learner_step += 1
 
                     # We store the metrics and timings for this update
                     metrics.append((episode_metrics, train_metrics))
                     actor_timings.append(actor_times)
@@ -469,6 +471,7 @@ def learner_rollout(learner_state: CoreLearnerState) -> None:
             actor_timings = jax.tree.map(lambda *x: np.mean(x), *actor_timings)
             timing_dict = actor_timings | learner_timings
             timing_dict["pipeline_qsize"] = q_sizes
+            timing_dict["learner_step"] = [learner_step]
             timing_dict = jax.tree.map(np.mean, timing_dict, is_leaf=lambda x: isinstance(x, list))
 
             try:  # We add a timeout mainly for sanity checks
diff --git a/stoix/systems/q_learning/sebulba/ff_dqn.py b/stoix/systems/q_learning/sebulba/ff_dqn.py
index 7e547a8b..7795c7b8 100644
--- a/stoix/systems/q_learning/sebulba/ff_dqn.py
+++ b/stoix/systems/q_learning/sebulba/ff_dqn.py
@@ -38,7 +38,7 @@
 from stoix.utils.env_factory import EnvFactory
 from stoix.utils.logger import LogEvent, StoixLogger
 from stoix.utils.loss import q_learning
-from stoix.utils.rate_limiters import MinSize, SampleToInsertRatio
+from stoix.utils.rate_limiters import SampleToInsertRatio
 from stoix.utils.sebulba_utils import (
     OffPolicyPipeline,
     ParamsSource,
@@ -156,8 +156,8 @@ def rollout_fn(rng_key: chex.PRNGKey) -> None:
                 # Send the trajectory to the pipeline
                 with RecordTimeTo(actor_timings_dict["rollout_put_time"]):
                     try:
-                        pipeline.put(traj, actor_timings_dict, episode_metrics)
-                    except queue.Full:
+                        pipeline.put(traj, actor_timings_dict, episode_metrics, timeout=60)
+                    except TimeoutError:
                         warnings.warn(
                             "Waited too long to add to the rollout queue, killing the actor thread",
                             stacklevel=2,
@@ -325,7 +325,14 @@ def learner_rollout(learner_state: CoreOffPolicyLearnerState) -> None:
                 # Get the trajectory batch from the pipeline
                 # This is blocking so it will wait until the pipeline has data.
                with RecordTimeTo(learner_timings["rollout_get_time"]):
-                    traj_batch, actor_times, episode_metrics = pipeline.get()  # type: ignore
+                    try:
+                        traj_batch, actor_times, episode_metrics = pipeline.get(timeout=60)  # type: ignore
+                    except TimeoutError:
+                        warnings.warn(
+                            "Waited too long to sample from pipeline, killing the learner thread",
+                            stacklevel=2,
+                        )
+                        break
 
                 # We then call the update function to update the networks
                 with RecordTimeTo(learner_timings["learning_time"]):
@@ -348,6 +355,13 @@ def learner_rollout(learner_state: CoreOffPolicyLearnerState) -> None:
             episode_metrics, train_metrics = jax.tree.map(lambda *x: np.asarray(x), *metrics)
             actor_timings = jax.tree.map(lambda *x: np.mean(x), *actor_timings)
             timing_dict = actor_timings | learner_timings
+            buffer_size = (
+                (pipeline.get_num_inserts() - pipeline.get_num_deletes())
+                * config.arch.actor.envs_per_actor
+                * config.system.rollout_length
+            )
+            timing_dict["buffer_size"] = [buffer_size]
+            timing_dict["num_samples"] = [pipeline.get_num_samples()]
             timing_dict = jax.tree.map(np.mean, timing_dict, is_leaf=lambda x: isinstance(x, list))
 
             try:  # We add a timeout mainly for sanity checks
@@ -623,18 +637,12 @@ def run_experiment(_config: DictConfig) -> float:
     # Now we create the pipeline
     replay_buffer_add, replay_buffer_sample = buffer_fns
     # Set up the rate limiter that controls how actors and learners interact with the pipeline
-    if config.system.epochs > 0:
-        samples_per_insert_tolerance_rate = config.arch.pipeline.samples_per_insert_tolerance_rate
-        samples_per_insert_tolerance = (
-            samples_per_insert_tolerance_rate * config.system.epochs
-        )
-        error_buffer = config.arch.pipeline.min_replay_size * samples_per_insert_tolerance
-        rate_limiter = SampleToInsertRatio(
-            config.system.epochs, config.arch.pipeline.min_replay_size, error_buffer
-        )
-    else:
-        pass
-        rate_limiter = MinSize(config.arch.pipeline.min_replay_size)  # type: ignore
+    samples_per_insert_tolerance_rate = 0.1  # This allows for 10% tolerance
+    samples_per_insert_tolerance = samples_per_insert_tolerance_rate * config.system.epochs
+    error_buffer = config.system.total_batch_size * samples_per_insert_tolerance
+    min_inserts = max(config.system.total_batch_size // steps_per_insert, 1)
+    rate_limiter = SampleToInsertRatio(config.system.epochs, min_inserts, error_buffer)
+
     pipeline = OffPolicyPipeline(
         replay_buffer_add,
         replay_buffer_sample,
diff --git a/stoix/utils/sebulba_utils.py b/stoix/utils/sebulba_utils.py
index 5064e5ad..cbe54126 100644
--- a/stoix/utils/sebulba_utils.py
+++ b/stoix/utils/sebulba_utils.py
@@ -196,6 +196,7 @@ def put(
         traj: Sequence[StoixTransition],
         actor_timings_dict: Dict[str, List[float]],
         actor_episode_metrics: List[Dict[str, List[float]]],
+        timeout: Union[float, None] = None,
     ) -> None:
         start_condition, end_condition = (threading.Condition(), threading.Condition())
         with start_condition:
@@ -207,7 +208,7 @@ def put(
 
         # wait until we can insert the data
         try:
-            self.rate_limiter.await_can_insert(timeout=180)
+            self.rate_limiter.await_can_insert(timeout=timeout)
         except TimeoutError:
             print(
                 f"{Fore.RED}{Style.BRIGHT}Actor has timed out on insertion, "
@@ -223,7 +224,7 @@ def put(
 
         # signal that we have inserted the data
         self.rate_limiter.insert()
-
+
         # Concatenate metrics - List[Dict[str, List[float]]] --> Dict[str, List[float]]
         actor_episode_metrics = self.concatenate_metrics(actor_episode_metrics)
 
@@ -261,6 +262,18 @@ def get(self, timeout: Union[float, None] = None) -> Tuple[StoixTransition, Dict
         return sharded_sampled_batch, actor_timings, actor_metrics  # type: ignore
+    def get_num_inserts(self) -> int:
+        """Get the number of inserts that have been made to the buffer."""
+        return self.rate_limiter.num_inserts()
+
+    def get_num_samples(self) -> int:
+        """Get the number of samples that have been made from the buffer."""
+        return self.rate_limiter.num_samples()
+
+    def get_num_deletes(self) -> int:
+        """Get the number of deletes that have been made from the buffer."""
+        return self.rate_limiter.num_deletes()
+
 
     @partial(jax.jit, static_argnums=(0, 2))
     def stack_trajectory(self, trajectory: List[StoixTransition], axis: int = 0) -> StoixTransition:
         """Stack a list of parallel_env transitions into a single
@@ -271,7 +284,7 @@ def shard_split_playload(self, payload: Any, axis: int = 0) -> Any:
         """Split the payload over the learner devices."""
         split_payload = jnp.split(payload, len(self.learner_devices), axis=axis)
         return jax.device_put_sharded(split_payload, devices=self.learner_devices)
-
+
     @partial(jax.jit, static_argnums=(0,))
     def concatenate_metrics(
         self, actor_metrics: List[Dict[str, List[float]]]
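The new `get_num_inserts` / `get_num_deletes` / `get_num_samples` accessors above just expose the rate limiter's counters so the DQN learner can log an approximate buffer size. A minimal standalone sketch of that bookkeeping, using a stand-in pipeline object and illustrative numbers (both assumptions, not part of this diff):

```python
# Stand-in for the new OffPolicyPipeline accessors; counter values are illustrative.
class FakePipeline:
    def __init__(self, inserts: int, deletes: int, samples: int) -> None:
        self._inserts, self._deletes, self._samples = inserts, deletes, samples

    def get_num_inserts(self) -> int:
        return self._inserts

    def get_num_deletes(self) -> int:
        return self._deletes

    def get_num_samples(self) -> int:
        return self._samples


pipeline = FakePipeline(inserts=500, deletes=100, samples=6_400)
envs_per_actor = 64   # assumed stand-in for the arch config's envs per actor
rollout_length = 16   # matches the ff_dqn.yaml change above

# Each live insert holds one rollout from every vectorised environment, so the
# number of transitions currently resident in the buffer is:
buffer_size = (
    (pipeline.get_num_inserts() - pipeline.get_num_deletes())
    * envs_per_actor
    * rollout_length
)
print(buffer_size)                 # 409600 transitions in the buffer
print(pipeline.get_num_samples())  # 6400 sample calls made by the learner so far
```

Pairing these counters with the 60-second `put`/`get` timeouts means a stalled peer now surfaces as a warning and a clean thread exit rather than an indefinite block.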