From ad07426cc0307be1a076f7043eba039f249ca50c Mon Sep 17 00:00:00 2001
From: EdanToledo
Date: Tue, 27 Aug 2024 20:21:50 +0000
Subject: [PATCH] chore: refactor to add more metrics for sebulba ppo

---
 stoix/evaluator.py                    |   2 +-
 stoix/systems/ppo/sebulba/ff_ppo.py   |  57 ++++--------
 stoix/utils/total_timestep_checker.py | 119 +++++++++++++++++++++-----
 3 files changed, 113 insertions(+), 65 deletions(-)

diff --git a/stoix/evaluator.py b/stoix/evaluator.py
index 2a86890..27c6103 100644
--- a/stoix/evaluator.py
+++ b/stoix/evaluator.py
@@ -371,7 +371,7 @@ def get_sebulba_eval_fn(
     # we will run all episodes in parallel.
     # Otherwise we will run `num_envs` parallel envs and loop enough times
     # so that we do at least `eval_episodes` number of episodes.
-    n_parallel_envs = int(min(eval_episodes, config.arch.num_envs))
+    n_parallel_envs = int(min(eval_episodes, config.arch.total_num_envs))
     episode_loops = math.ceil(eval_episodes / n_parallel_envs)
     envs = env_factory(n_parallel_envs)
     cpu = jax.devices("cpu")[0]
diff --git a/stoix/systems/ppo/sebulba/ff_ppo.py b/stoix/systems/ppo/sebulba/ff_ppo.py
index ed3c10b..8bae635 100644
--- a/stoix/systems/ppo/sebulba/ff_ppo.py
+++ b/stoix/systems/ppo/sebulba/ff_ppo.py
@@ -1,6 +1,7 @@
 import copy
 import queue
 import threading
+import time
 import warnings
 from collections import defaultdict
 from queue import Queue
@@ -91,7 +92,7 @@ def get_rollout_fn(
     move_to_device = lambda tree: jax.tree.map(lambda x: jax.device_put(x, actor_device), tree)
     split_key_fn = jax.jit(jax.random.split, device=actor_device)
     # Build the environments
-    envs = env_factory(config.arch.actor.envs_per_actor)
+    envs = env_factory(config.arch.actor.num_envs_per_actor)

     # Create the rollout function
     def rollout_fn(rng_key: chex.PRNGKey) -> None:
@@ -348,7 +349,7 @@ def _critic_loss_fn(

        # SHUFFLE MINIBATCHES
        # Since we shard the envs per actor across the devices
-        envs_per_batch = config.arch.actor.envs_per_actor // len(config.arch.learner.device_ids)
+        envs_per_batch = config.arch.actor.num_envs_per_actor // config.num_learner_devices
        batch_size = config.system.rollout_length * envs_per_batch
        permutation = jax.random.permutation(shuffle_key, batch_size)
        batch = (traj_batch, advantages, targets)
@@ -448,7 +449,7 @@ def learner_rollout(learner_state: CoreLearnerState) -> None:
            # an additional timestep which it can use to bootstrap.
            learner_state = learner_state._replace(timestep=timestep)
            # We then call the update function to update the networks
-            with RecordTimeTo(learner_timings["learning_time"]):
+            with RecordTimeTo(learner_timings["learner_step_time"]):
                learner_state, train_metrics = learn_step(learner_state, traj_batch)

            # We store the metrics and timings for this update
@@ -618,18 +619,6 @@ def run_experiment(_config: DictConfig) -> float:
     """Runs experiment."""
     config = copy.deepcopy(_config)

-    # Perform some checks on the config
-    # This additionally calculates certains
-    # values based on the config
-    config = check_total_timesteps(config)
-
-    assert (
-        config.arch.num_updates > config.arch.num_evaluation
-    ), "Number of updates per evaluation must be less than total number of updates."
-
-    # Calculate the number of updates per evaluation
-    config.arch.num_updates_per_eval = int(config.arch.num_updates // config.arch.num_evaluation)
-
     # Get the learner and actor devices
     local_devices = jax.local_devices()
     global_devices = jax.devices()
@@ -648,28 +637,13 @@ def run_experiment(_config: DictConfig) -> float:
     print(f"{Fore.MAGENTA}{Style.BRIGHT}Global devices: {global_devices}{Style.RESET_ALL}")
     # Set the number of learning and acting devices in the config
     # useful for keeping track of experimental setup
-    config.num_learning_devices = len(local_learner_devices)
-    config.num_actor_actor_devices = len(actor_devices)
-
-    # Calculate the number of envs per actor
-    assert (
-        config.arch.num_envs == config.arch.total_num_envs
-    ), "arch.num_envs must equal arch.total_num_envs for Sebulba architectures"
-    # We first simply take the total number of envs and divide by the number of actor devices
-    # to get the number of envs per actor device
-    num_envs_per_actor_device = config.arch.total_num_envs // len(actor_devices)
-    # We then divide this by the number of actors per device to get the number of envs per actor
-    num_envs_per_actor = int(num_envs_per_actor_device // config.arch.actor.actor_per_device)
-    config.arch.actor.envs_per_actor = num_envs_per_actor
-
-    # We then perform a simple check to ensure that the number of envs per actor is
-    # divisible by the number of learner devices. This is because we shard the envs
-    # per actor across the learner devices This check is mainly relevant for on-policy
-    # algorithms
-    assert num_envs_per_actor % len(local_learner_devices) == 0, (
-        f"The number of envs per actor must be divisible by the number of learner devices. "
-        f"Got {num_envs_per_actor} envs per actor and {len(local_learner_devices)} learner devices"
-    )
+    config.num_learner_devices = len(local_learner_devices)
+    config.num_actor_devices = len(actor_devices)
+
+    # Perform some checks on the config
+    # This additionally calculates certain
+    # values based on the config
+    config = check_total_timesteps(config)

     # Create the environment factory.
     env_factory = environments.make_factory(config)
@@ -713,7 +687,7 @@ def run_experiment(_config: DictConfig) -> float:
     initial_params = unreplicate(learner_state.params)

     # Get the number of steps consumed by the learner per learner step
-    steps_per_learner_step = config.system.rollout_length * config.arch.actor.envs_per_actor
+    steps_per_learner_step = config.system.rollout_length * config.arch.actor.num_envs_per_actor
     # Get the number of steps consumed by the learner per evaluation
     steps_consumed_per_eval = steps_per_learner_step * config.arch.num_updates_per_eval

@@ -744,7 +718,7 @@ def run_experiment(_config: DictConfig) -> float:
        for i in range(config.arch.actor.actor_per_device):
            key, actors_key = jax.random.split(key)
            seeds = np_rng.integers(
-                np.iinfo(np.int32).max, size=config.arch.actor.envs_per_actor
+                np.iinfo(np.int32).max, size=config.arch.actor.num_envs_per_actor
            ).tolist()
            actor_thread = get_actor_thread(
                env_factory,
@@ -880,9 +854,10 @@ def hydra_entry_point(cfg: DictConfig) -> float:
     OmegaConf.set_struct(cfg, False)

     # Run experiment.
+    start = time.monotonic()
     eval_performance = run_experiment(cfg)
-
-    print(f"{Fore.CYAN}{Style.BRIGHT}PPO experiment completed{Style.RESET_ALL}")
+    end = time.monotonic()
+    print(f"{Fore.CYAN}{Style.BRIGHT}PPO experiment completed in {end - start:.2f} seconds.{Style.RESET_ALL}")

     return eval_performance

diff --git a/stoix/utils/total_timestep_checker.py b/stoix/utils/total_timestep_checker.py
index 9e64685..4ccce27 100644
--- a/stoix/utils/total_timestep_checker.py
+++ b/stoix/utils/total_timestep_checker.py
@@ -5,38 +5,33 @@
 def check_total_timesteps(config: DictConfig) -> DictConfig:
     """Check if total_timesteps is set, if not, set it based on the other parameters"""

-    # Check if the number of devices is set in the config
-    # If not, it is assumed that the number of devices is 1
-    # For the case of using a sebulba config, the number of
-    # devices is set to 1 for the calculation
-    # of the number of environments per device, etc
-    if "num_devices" not in config:
-        num_devices = 1
+    # If num_devices and update_batch_size are not in the config,
+    # usually this means a sebulba config is being used.
+    if "num_devices" not in config and "update_batch_size" not in config.arch:
+        return check_total_timesteps_sebulba(config)
     else:
-        num_devices = config.num_devices
+        return check_total_timesteps_anakin(config)

-    # If update_batch_size is not in the config, usually this means a sebulba config is being used.
-    if "update_batch_size" not in config.arch:
-        update_batch_size = 1
-        print(f"{Fore.YELLOW}{Style.BRIGHT}Using Sebulba System!{Style.RESET_ALL}")
-    else:
-        update_batch_size = config.arch.update_batch_size
-        print(f"{Fore.YELLOW}{Style.BRIGHT}Using Anakin System!{Style.RESET_ALL}")
-
-    assert config.arch.total_num_envs % (num_devices * update_batch_size) == 0, (
+
+def check_total_timesteps_anakin(config: DictConfig) -> DictConfig:
+    """Check if total_timesteps is set, if not, set it based on the other parameters"""
+
+    print(f"{Fore.YELLOW}{Style.BRIGHT}Using Anakin System!{Style.RESET_ALL}")
+
+    assert config.arch.total_num_envs % (config.num_devices * config.arch.update_batch_size) == 0, (
         f"{Fore.RED}{Style.BRIGHT}The total number of environments "
         + f"should be divisible by the n_devices*update_batch_size!{Style.RESET_ALL}"
     )
     config.arch.num_envs = int(
-        config.arch.total_num_envs // (num_devices * update_batch_size)
+        config.arch.total_num_envs // (config.num_devices * config.arch.update_batch_size)
     )  # Number of environments per device

     if config.arch.total_timesteps is None:
         config.arch.total_timesteps = (
-            num_devices
+            config.num_devices
             * config.arch.num_updates
             * config.system.rollout_length
-            * update_batch_size
+            * config.arch.update_batch_size
             * config.arch.num_envs
         )
         print(
@@ -49,9 +44,9 @@ def check_total_timesteps(config: DictConfig) -> DictConfig:
         config.arch.num_updates = (
             config.arch.total_timesteps
             // config.system.rollout_length
-            // update_batch_size
+            // config.arch.update_batch_size
             // config.arch.num_envs
-            // num_devices
+            // config.num_devices
         )
         print(
             f"{Fore.YELLOW}{Style.BRIGHT}Changing the number of updates "
@@ -63,10 +58,10 @@ def check_total_timesteps(config: DictConfig) -> DictConfig:
     # Calculate the actual number of timesteps that will be run
     num_updates_per_eval = config.arch.num_updates // config.arch.num_evaluation
     steps_per_rollout = (
-        num_devices
+        config.num_devices
         * num_updates_per_eval
         * config.system.rollout_length
-        * update_batch_size
+        * config.arch.update_batch_size
         * config.arch.num_envs
     )
     total_actual_timesteps = steps_per_rollout * config.arch.num_evaluation
@@ -80,3 +75,81 @@ def check_total_timesteps(config: DictConfig) -> DictConfig:
     )

     return config
+
+
+def check_total_timesteps_sebulba(config: DictConfig) -> DictConfig:
+    """Check if total_timesteps is set, if not, set it based on the other parameters"""
+
+    print(f"{Fore.YELLOW}{Style.BRIGHT}Using Sebulba System!{Style.RESET_ALL}")
+
+    assert (
+        config.arch.total_num_envs % (config.num_actor_devices * config.arch.actor.actor_per_device)
+        == 0
+    ), (
+        f"{Fore.RED}{Style.BRIGHT}The total number of environments "
+        + f"should be divisible by the number of actor devices * actor_per_device!{Style.RESET_ALL}"
+    )
+    # We first simply take the total number of envs and divide by the number of actor devices
+    # to get the number of envs per actor device
+    num_envs_per_actor_device = config.arch.total_num_envs // config.num_actor_devices
+    # We then divide this by the number of actors per device to get the number of envs per actor
+    num_envs_per_actor = int(num_envs_per_actor_device // config.arch.actor.actor_per_device)
+    config.arch.actor.num_envs_per_actor = num_envs_per_actor
+
+    # We base the total number of timesteps on the number of steps the learner consumes
+    if config.arch.total_timesteps is None:
+        config.arch.total_timesteps = (
+            config.arch.num_updates
+            * config.system.rollout_length
+            * config.arch.actor.num_envs_per_actor
+        )
+        print(
+            f"{Fore.YELLOW}{Style.BRIGHT}Changing the total number of timesteps "
+            + f"to {config.arch.total_timesteps}: If you want to train"
+            + " for a specific number of timesteps, please set num_updates to None!"
+            + f"{Style.RESET_ALL}"
+        )
+    else:
+        config.arch.num_updates = (
+            config.arch.total_timesteps
+            // config.system.rollout_length
+            // config.arch.actor.num_envs_per_actor
+        )
+        print(
+            f"{Fore.YELLOW}{Style.BRIGHT}Changing the number of updates "
+            + f"to {config.arch.num_updates}: If you want to train"
+            + " for a specific number of updates, please set total_timesteps to None!"
+            + f"{Style.RESET_ALL}"
+        )
+
+    # Calculate the number of updates per evaluation
+    config.arch.num_updates_per_eval = int(config.arch.num_updates // config.arch.num_evaluation)
+    # Get the number of steps consumed by the learner per learner step
+    steps_per_learner_step = config.system.rollout_length * config.arch.actor.num_envs_per_actor
+    # Get the number of steps consumed by the learner per evaluation
+    steps_consumed_per_eval = steps_per_learner_step * config.arch.num_updates_per_eval
+    total_actual_timesteps = steps_consumed_per_eval * config.arch.num_evaluation
+    print(
+        f"{Fore.RED}{Style.BRIGHT}Warning: Due to the interaction of various factors such as "
+        f"rollout length, number of evaluations, etc... the actual number of timesteps that "
+        f"will be run is {total_actual_timesteps}! This is a difference of "
+        f"{config.arch.total_timesteps - total_actual_timesteps} timesteps! To change this, "
+        f"see total_timestep_checker.py in the utils folder. "
+        f"{Style.RESET_ALL}"
+    )
+
+    assert (
+        config.arch.num_updates > config.arch.num_evaluation
+    ), "The number of evaluations must be less than the total number of updates."
+
+    # We then perform a simple check to ensure that the number of envs per actor is
+    # divisible by the number of learner devices. This is because we shard the envs
+    # per actor across the learner devices. This check is mainly relevant for on-policy
+    # algorithms.
+    assert config.arch.actor.num_envs_per_actor % config.num_learner_devices == 0, (
+        f"The number of envs per actor must be divisible by the number of learner devices. "
+        f"Got {config.arch.actor.num_envs_per_actor} envs per actor "
+        f"and {config.num_learner_devices} learner devices"
+    )
+
+    return config
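
Worked example (reviewer sketch, not part of the patch): the environment-partitioning and timestep arithmetic that check_total_timesteps_sebulba introduces, traced with made-up values. Every number below is an illustrative assumption rather than a Stoix default; only the formulas mirror the diff.

# Illustrative sketch only: all example values are assumptions, not repo defaults.
total_num_envs = 1024        # config.arch.total_num_envs
num_actor_devices = 2        # len(actor_devices)
actor_per_device = 2         # config.arch.actor.actor_per_device
num_learner_devices = 2      # len(local_learner_devices)
rollout_length = 16          # config.system.rollout_length
total_timesteps = 1_000_000  # config.arch.total_timesteps
num_evaluation = 10          # config.arch.num_evaluation

# Envs are split across actor devices, then across actor threads per device.
assert total_num_envs % (num_actor_devices * actor_per_device) == 0
num_envs_per_actor = (total_num_envs // num_actor_devices) // actor_per_device  # 256

# The timestep budget is counted in learner-consumed steps, so the number of updates
# (and updates per evaluation) follows from one rollout per learner step.
steps_per_learner_step = rollout_length * num_envs_per_actor  # 4096
num_updates = total_timesteps // steps_per_learner_step       # 244
num_updates_per_eval = num_updates // num_evaluation          # 24
total_actual_timesteps = steps_per_learner_step * num_updates_per_eval * num_evaluation  # 983040

# Each actor's envs are sharded across the learner devices, hence the divisibility check.
assert num_envs_per_actor % num_learner_devices == 0
print(num_envs_per_actor, num_updates, total_actual_timesteps)

With these assumed values the checker would report a shortfall of 16,960 timesteps between total_timesteps and total_actual_timesteps, which is exactly what the new warning print surfaces.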