chore: refactor to add more metrics for sebulba ppo
EdanToledo committed Aug 27, 2024
1 parent db6a23e commit ad07426
Showing 3 changed files with 113 additions and 65 deletions.
2 changes: 1 addition & 1 deletion stoix/evaluator.py
@@ -371,7 +371,7 @@ def get_sebulba_eval_fn(
# we will run all episodes in parallel.
# Otherwise we will run `num_envs` parallel envs and loop enough times
# so that we do at least `eval_episodes` number of episodes.
n_parallel_envs = int(min(eval_episodes, config.arch.num_envs))
n_parallel_envs = int(min(eval_episodes, config.arch.total_num_envs))
episode_loops = math.ceil(eval_episodes / n_parallel_envs)
envs = env_factory(n_parallel_envs)
cpu = jax.devices("cpu")[0]
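For context, a minimal sketch of the evaluator arithmetic that this one-line change feeds into, with hypothetical numbers (the real values come from the Hydra config; only the switch to the `total_num_envs` key is what changed here):

import math

# Hypothetical values for illustration only.
eval_episodes = 200
total_num_envs = 64  # config.arch.total_num_envs after this change

# Run at most `total_num_envs` envs in parallel, then loop enough times
# to complete at least `eval_episodes` episodes.
n_parallel_envs = int(min(eval_episodes, total_num_envs))
episode_loops = math.ceil(eval_episodes / n_parallel_envs)
print(n_parallel_envs, episode_loops)  # 64 4 -> 4 * 64 = 256 >= 200 episodes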
57 changes: 16 additions & 41 deletions stoix/systems/ppo/sebulba/ff_ppo.py
@@ -1,6 +1,7 @@
import copy
import queue
import threading
import time
import warnings
from collections import defaultdict
from queue import Queue
@@ -91,7 +92,7 @@ def get_rollout_fn(
move_to_device = lambda tree: jax.tree.map(lambda x: jax.device_put(x, actor_device), tree)
split_key_fn = jax.jit(jax.random.split, device=actor_device)
# Build the environments
envs = env_factory(config.arch.actor.envs_per_actor)
envs = env_factory(config.arch.actor.num_envs_per_actor)

# Create the rollout function
def rollout_fn(rng_key: chex.PRNGKey) -> None:
@@ -348,7 +349,7 @@ def _critic_loss_fn(

# SHUFFLE MINIBATCHES
# Since we shard the envs per actor across the devices
envs_per_batch = config.arch.actor.envs_per_actor // len(config.arch.learner.device_ids)
envs_per_batch = config.arch.actor.num_envs_per_actor // config.num_learner_devices
batch_size = config.system.rollout_length * envs_per_batch
permutation = jax.random.permutation(shuffle_key, batch_size)
batch = (traj_batch, advantages, targets)
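For context, a hedged sketch of the minibatch bookkeeping in this hunk, with hypothetical values; the change only swaps the divisor to the `num_learner_devices` value now computed in the timestep checker:

import jax

# Hypothetical values for illustration only.
num_envs_per_actor = 256      # config.arch.actor.num_envs_per_actor
num_learner_devices = 4       # config.num_learner_devices
rollout_length = 128          # config.system.rollout_length

envs_per_batch = num_envs_per_actor // num_learner_devices   # 64 envs per learner device
batch_size = rollout_length * envs_per_batch                 # 8192 transitions per device

# Shuffle transition indices before carving out minibatches, as in the PPO update above.
shuffle_key = jax.random.PRNGKey(0)
permutation = jax.random.permutation(shuffle_key, batch_size)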
@@ -448,7 +449,7 @@ def learner_rollout(learner_state: CoreLearnerState) -> None:
# an additional timestep which it can use to bootstrap.
learner_state = learner_state._replace(timestep=timestep)
# We then call the update function to update the networks
with RecordTimeTo(learner_timings["learning_time"]):
with RecordTimeTo(learner_timings["learner_step_time"]):
learner_state, train_metrics = learn_step(learner_state, traj_batch)

# We store the metrics and timings for this update
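Here the timing key is renamed from `learning_time` to `learner_step_time`; the timing mechanism itself is unchanged. As a rough illustration only, a hypothetical stand-in for the `RecordTimeTo` pattern, assuming it appends elapsed seconds to the list held in the timings dict (an assumption for illustration, not Stoix's actual implementation):

import time
from collections import defaultdict

class RecordElapsedTo:
    """Hypothetical stand-in: append elapsed seconds to the given list on exit."""
    def __init__(self, sink: list):
        self.sink = sink
    def __enter__(self) -> None:
        self.start = time.monotonic()
    def __exit__(self, *exc) -> None:
        self.sink.append(time.monotonic() - self.start)

learner_timings = defaultdict(list)
with RecordElapsedTo(learner_timings["learner_step_time"]):
    time.sleep(0.01)  # stand-in for learn_step(learner_state, traj_batch)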
@@ -618,18 +619,6 @@ def run_experiment(_config: DictConfig) -> float:
"""Runs experiment."""
config = copy.deepcopy(_config)

# Perform some checks on the config
# This additionally calculates certains
# values based on the config
config = check_total_timesteps(config)

assert (
config.arch.num_updates > config.arch.num_evaluation
), "Number of updates per evaluation must be less than total number of updates."

# Calculate the number of updates per evaluation
config.arch.num_updates_per_eval = int(config.arch.num_updates // config.arch.num_evaluation)

# Get the learner and actor devices
local_devices = jax.local_devices()
global_devices = jax.devices()
@@ -648,28 +637,13 @@ def run_experiment(_config: DictConfig) -> float:
print(f"{Fore.MAGENTA}{Style.BRIGHT}Global devices: {global_devices}{Style.RESET_ALL}")
# Set the number of learning and acting devices in the config
# useful for keeping track of experimental setup
config.num_learning_devices = len(local_learner_devices)
config.num_actor_actor_devices = len(actor_devices)

# Calculate the number of envs per actor
assert (
config.arch.num_envs == config.arch.total_num_envs
), "arch.num_envs must equal arch.total_num_envs for Sebulba architectures"
# We first simply take the total number of envs and divide by the number of actor devices
# to get the number of envs per actor device
num_envs_per_actor_device = config.arch.total_num_envs // len(actor_devices)
# We then divide this by the number of actors per device to get the number of envs per actor
num_envs_per_actor = int(num_envs_per_actor_device // config.arch.actor.actor_per_device)
config.arch.actor.envs_per_actor = num_envs_per_actor

# We then perform a simple check to ensure that the number of envs per actor is
# divisible by the number of learner devices. This is because we shard the envs
# per actor across the learner devices This check is mainly relevant for on-policy
# algorithms
assert num_envs_per_actor % len(local_learner_devices) == 0, (
f"The number of envs per actor must be divisible by the number of learner devices. "
f"Got {num_envs_per_actor} envs per actor and {len(local_learner_devices)} learner devices"
)
config.num_learner_devices = len(local_learner_devices)
config.num_actor_devices = len(actor_devices)

# Perform some checks on the config
# This additionally calculates certains
# values based on the config
config = check_total_timesteps(config)

# Create the environment factory.
env_factory = environments.make_factory(config)
@@ -713,7 +687,7 @@ def run_experiment(_config: DictConfig) -> float:
initial_params = unreplicate(learner_state.params)

# Get the number of steps consumed by the learner per learner step
steps_per_learner_step = config.system.rollout_length * config.arch.actor.envs_per_actor
steps_per_learner_step = config.system.rollout_length * config.arch.actor.num_envs_per_actor
# Get the number of steps consumed by the learner per evaluation
steps_consumed_per_eval = steps_per_learner_step * config.arch.num_updates_per_eval

@@ -744,7 +718,7 @@ def run_experiment(_config: DictConfig) -> float:
for i in range(config.arch.actor.actor_per_device):
key, actors_key = jax.random.split(key)
seeds = np_rng.integers(
np.iinfo(np.int32).max, size=config.arch.actor.envs_per_actor
np.iinfo(np.int32).max, size=config.arch.actor.num_envs_per_actor
).tolist()
actor_thread = get_actor_thread(
env_factory,
@@ -880,9 +854,10 @@ def hydra_entry_point(cfg: DictConfig) -> float:
OmegaConf.set_struct(cfg, False)

# Run experiment.
start = time.monotonic()
eval_performance = run_experiment(cfg)

print(f"{Fore.CYAN}{Style.BRIGHT}PPO experiment completed{Style.RESET_ALL}")
end = time.monotonic()
print(f"{Fore.CYAN}{Style.BRIGHT}PPO experiment completed in {end - start:.2f} seconds.{Style.RESET_ALL}")
return eval_performance


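As a quick worked example of the learner-side step accounting used in this file (hypothetical numbers; the real values come from the config, and `num_envs_per_actor` is the renamed key):

# Hypothetical values for illustration only.
rollout_length = 128           # config.system.rollout_length
num_envs_per_actor = 256       # config.arch.actor.num_envs_per_actor
num_updates_per_eval = 50      # config.arch.num_updates_per_eval

# Environment steps consumed by the learner per learner step and per evaluation window.
steps_per_learner_step = rollout_length * num_envs_per_actor             # 32_768
steps_consumed_per_eval = steps_per_learner_step * num_updates_per_eval  # 1_638_400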
119 changes: 96 additions & 23 deletions stoix/utils/total_timestep_checker.py
@@ -5,38 +5,33 @@
def check_total_timesteps(config: DictConfig) -> DictConfig:
"""Check if total_timesteps is set, if not, set it based on the other parameters"""

# Check if the number of devices is set in the config
# If not, it is assumed that the number of devices is 1
# For the case of using a sebulba config, the number of
# devices is set to 1 for the calculation
# of the number of environments per device, etc
if "num_devices" not in config:
num_devices = 1
# If num_devices and update_batch_size are not in the config,
# usually this means a sebulba config is being used.
if "num_devices" not in config and "update_batch_size" not in config.arch:
return check_total_timesteps_sebulba(config)
else:
num_devices = config.num_devices
return check_total_timesteps_anakin(config)

# If update_batch_size is not in the config, usually this means a sebulba config is being used.
if "update_batch_size" not in config.arch:
update_batch_size = 1
print(f"{Fore.YELLOW}{Style.BRIGHT}Using Sebulba System!{Style.RESET_ALL}")
else:
update_batch_size = config.arch.update_batch_size
print(f"{Fore.YELLOW}{Style.BRIGHT}Using Anakin System!{Style.RESET_ALL}")

assert config.arch.total_num_envs % (num_devices * update_batch_size) == 0, (
def check_total_timesteps_anakin(config: DictConfig) -> DictConfig:
"""Check if total_timesteps is set, if not, set it based on the other parameters"""

print(f"{Fore.YELLOW}{Style.BRIGHT}Using Anakin System!{Style.RESET_ALL}")

assert config.arch.total_num_envs % (config.num_devices * config.arch.update_batch_size) == 0, (
f"{Fore.RED}{Style.BRIGHT}The total number of environments "
+ f"should be divisible by the n_devices*update_batch_size!{Style.RESET_ALL}"
)
config.arch.num_envs = int(
config.arch.total_num_envs // (num_devices * update_batch_size)
config.arch.total_num_envs // (config.num_devices * config.arch.update_batch_size)
) # Number of environments per device

if config.arch.total_timesteps is None:
config.arch.total_timesteps = (
num_devices
config.num_devices
* config.arch.num_updates
* config.system.rollout_length
* update_batch_size
* config.arch.update_batch_size
* config.arch.num_envs
)
print(
@@ -49,9 +44,9 @@ def check_total_timesteps(config: DictConfig) -> DictConfig:
config.arch.num_updates = (
config.arch.total_timesteps
// config.system.rollout_length
// update_batch_size
// config.arch.update_batch_size
// config.arch.num_envs
// num_devices
// config.num_devices
)
print(
f"{Fore.YELLOW}{Style.BRIGHT}Changing the number of updates "
@@ -63,10 +58,10 @@ def check_total_timesteps(config: DictConfig) -> DictConfig:
# Calculate the actual number of timesteps that will be run
num_updates_per_eval = config.arch.num_updates // config.arch.num_evaluation
steps_per_rollout = (
num_devices
config.num_devices
* num_updates_per_eval
* config.system.rollout_length
* update_batch_size
* config.arch.update_batch_size
* config.arch.num_envs
)
total_actual_timesteps = steps_per_rollout * config.arch.num_evaluation
@@ -80,3 +75,81 @@ def check_total_timesteps(config: DictConfig) -> DictConfig:
)

return config


def check_total_timesteps_sebulba(config: DictConfig) -> DictConfig:
"""Check if total_timesteps is set, if not, set it based on the other parameters"""

print(f"{Fore.YELLOW}{Style.BRIGHT}Using Sebulba System!{Style.RESET_ALL}")

assert (
config.arch.total_num_envs % (config.num_actor_devices * config.arch.actor.actor_per_device)
== 0
), (
f"{Fore.RED}{Style.BRIGHT}The total number of environments "
+ f"should be divisible by the number of actor devices * actor_per_device!{Style.RESET_ALL}"
)
# We first simply take the total number of envs and divide by the number of actor devices
# to get the number of envs per actor device
num_envs_per_actor_device = config.arch.total_num_envs // config.num_actor_devices
# We then divide this by the number of actors per device to get the number of envs per actor
num_envs_per_actor = int(num_envs_per_actor_device // config.arch.actor.actor_per_device)
config.arch.actor.num_envs_per_actor = num_envs_per_actor

# We base the total number of timesteps based on the number of steps the learner consumes
if config.arch.total_timesteps is None:
config.arch.total_timesteps = (
config.arch.num_updates
* config.system.rollout_length
* config.arch.actor.num_envs_per_actor
)
print(
f"{Fore.YELLOW}{Style.BRIGHT}Changing the total number of timesteps "
+ f"to {config.arch.total_timesteps}: If you want to train"
+ " for a specific number of timesteps, please set num_updates to None!"
+ f"{Style.RESET_ALL}"
)
else:
config.arch.num_updates = (
config.arch.total_timesteps
// config.system.rollout_length
// config.arch.actor.num_envs_per_actor
)
print(
f"{Fore.YELLOW}{Style.BRIGHT}Changing the number of updates "
+ f"to {config.arch.num_updates}: If you want to train"
+ " for a specific number of updates, please set total_timesteps to None!"
+ f"{Style.RESET_ALL}"
)

# Calculate the number of updates per evaluation
config.arch.num_updates_per_eval = int(config.arch.num_updates // config.arch.num_evaluation)
# Get the number of steps consumed by the learner per learner step
steps_per_learner_step = config.system.rollout_length * config.arch.actor.num_envs_per_actor
# Get the number of steps consumed by the learner per evaluation
steps_consumed_per_eval = steps_per_learner_step * config.arch.num_updates_per_eval
total_actual_timesteps = steps_consumed_per_eval * config.arch.num_evaluation
print(
f"{Fore.RED}{Style.BRIGHT}Warning: Due to the interaction of various factors such as "
f"rollout length, number of evaluations, etc... the actual number of timesteps that "
f"will be run is {total_actual_timesteps}! This is a difference of "
f"{config.arch.total_timesteps - total_actual_timesteps} timesteps! To change this, "
f"see total_timestep_checker.py in the utils folder. "
f"{Style.RESET_ALL}"
)

assert (
config.arch.num_updates > config.arch.num_evaluation
), "Number of updates per evaluation must be less than total number of updates."

# We then perform a simple check to ensure that the number of envs per actor is
# divisible by the number of learner devices. This is because we shard the envs
# per actor across the learner devices This check is mainly relevant for on-policy
# algorithms
assert config.arch.actor.num_envs_per_actor % config.num_learner_devices == 0, (
f"The number of envs per actor must be divisible by the number of learner devices. "
f"Got {config.arch.actor.num_envs_per_actor} envs per actor "
f"and {config.num_learner_devices} learner devices"
)

return config
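To make the Sebulba bookkeeping above concrete, a small worked example with hypothetical config values (illustrative only; the actual numbers depend on the experiment config):

# Hypothetical config values for illustration only.
total_num_envs = 1024
num_actor_devices = 2
actor_per_device = 2
num_learner_devices = 4
rollout_length = 128
num_updates = 1000
num_evaluation = 20

assert total_num_envs % (num_actor_devices * actor_per_device) == 0
num_envs_per_actor = (total_num_envs // num_actor_devices) // actor_per_device  # 256

# total_timesteps is derived from what the learner consumes.
total_timesteps = num_updates * rollout_length * num_envs_per_actor  # 32_768_000

num_updates_per_eval = num_updates // num_evaluation                                 # 50
steps_consumed_per_eval = rollout_length * num_envs_per_actor * num_updates_per_eval
total_actual_timesteps = steps_consumed_per_eval * num_evaluation                    # 32_768_000, so no truncation here

assert num_updates > num_evaluation
assert num_envs_per_actor % num_learner_devices == 0  # envs per actor shard across learner devices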
