Fix: Sebulba PPO Metrics #108

Merged 4 commits on Aug 27, 2024
2 changes: 1 addition & 1 deletion stoix/evaluator.py
@@ -371,7 +371,7 @@ def get_sebulba_eval_fn(
# we will run all episodes in parallel.
# Otherwise we will run `num_envs` parallel envs and loop enough times
# so that we do at least `eval_episodes` number of episodes.
n_parallel_envs = int(min(eval_episodes, config.arch.num_envs))
n_parallel_envs = int(min(eval_episodes, config.arch.total_num_envs))
episode_loops = math.ceil(eval_episodes / n_parallel_envs)
envs = env_factory(n_parallel_envs)
cpu = jax.devices("cpu")[0]
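Note: a quick sanity check of the new cap in the evaluator, using illustrative numbers that are not taken from the PR. With the Sebulba config refactor, the number of parallel eval envs is now capped by `config.arch.total_num_envs` rather than the old per-actor `config.arch.num_envs`, and the loop count follows from that cap:

import math

# Illustrative values only -- not from the PR's configs.
eval_episodes = 128
total_num_envs = 64  # config.arch.total_num_envs

n_parallel_envs = int(min(eval_episodes, total_num_envs))    # 64
episode_loops = math.ceil(eval_episodes / n_parallel_envs)   # 2
assert n_parallel_envs * episode_loops >= eval_episodes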
145 changes: 66 additions & 79 deletions stoix/systems/ppo/sebulba/ff_ppo.py
@@ -1,6 +1,7 @@
import copy
import queue
import threading
import time
import warnings
from collections import defaultdict
from queue import Queue
@@ -91,7 +92,7 @@ def get_rollout_fn(
move_to_device = lambda tree: jax.tree.map(lambda x: jax.device_put(x, actor_device), tree)
split_key_fn = jax.jit(jax.random.split, device=actor_device)
# Build the environments
envs = env_factory(config.arch.actor.envs_per_actor)
envs = env_factory(config.arch.actor.num_envs_per_actor)

# Create the rollout function
def rollout_fn(rng_key: chex.PRNGKey) -> None:
@@ -348,7 +349,7 @@ def _critic_loss_fn(

# SHUFFLE MINIBATCHES
# Since we shard the envs per actor across the devices
envs_per_batch = config.arch.actor.envs_per_actor // len(config.arch.learner.device_ids)
envs_per_batch = config.arch.actor.num_envs_per_actor // config.num_learner_devices
batch_size = config.system.rollout_length * envs_per_batch
permutation = jax.random.permutation(shuffle_key, batch_size)
batch = (traj_batch, advantages, targets)
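A minimal sketch of the per-device shuffle this hunk sets up, with made-up sizes (none of these numbers come from the PR). The envs handled by one actor are sharded across the learner devices, so each device shuffles `rollout_length * envs_per_batch` flattened transitions:

import jax
import jax.numpy as jnp

rollout_length = 8          # config.system.rollout_length
num_envs_per_actor = 16     # config.arch.actor.num_envs_per_actor
num_learner_devices = 2     # config.num_learner_devices

envs_per_batch = num_envs_per_actor // num_learner_devices   # 8 envs on this device
batch_size = rollout_length * envs_per_batch                  # 64 transitions

shuffle_key = jax.random.PRNGKey(0)
permutation = jax.random.permutation(shuffle_key, batch_size)

# Stand-in for (traj_batch, advantages, targets): anything with a
# leading axis of `batch_size` once time and env dims are merged.
batch = {"obs": jnp.zeros((batch_size, 4)), "advantages": jnp.zeros((batch_size,))}
shuffled = jax.tree.map(lambda x: jnp.take(x, permutation, axis=0), batch)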
@@ -429,38 +430,39 @@ def learner_rollout(learner_state: CoreLearnerState) -> None:
actor_timings: List[Dict] = []
learner_timings: Dict[str, List[float]] = defaultdict(list)
q_sizes: List[int] = []
# Loop for the number of updates per evaluation
for _ in range(config.arch.num_updates_per_eval):
# Get the trajectory batch from the pipeline
# This is blocking so it will wait until the pipeline has data.
with RecordTimeTo(learner_timings["rollout_get_time"]):
(
traj_batch,
timestep,
actor_times,
episode_metrics,
) = pipeline.get( # type: ignore
block=True
)
# We then replace the timestep in the learner state with the latest timestep
# This means the learner has access to the entire trajectory as well as
# an additional timestep which it can use to bootstrap.
learner_state = learner_state._replace(timestep=timestep)
# We then call the update function to update the networks
with RecordTimeTo(learner_timings["learning_time"]):
learner_state, train_metrics = learn_step(learner_state, traj_batch)

# We store the metrics and timings for this update
metrics.append((episode_metrics, train_metrics))
actor_timings.append(actor_times)
q_sizes.append(pipeline.qsize())

# After the update we need to update the params sources with the new params
unreplicated_params = unreplicate(learner_state.params)
# We loop over all params sources and update them with the new params
# This is so that all the actors can get the latest params
for source in params_sources:
source.update(unreplicated_params)
with RecordTimeTo(learner_timings["learner_time_per_eval"]):
# Loop for the number of updates per evaluation
for _ in range(config.arch.num_updates_per_eval):
# Get the trajectory batch from the pipeline
# This is blocking so it will wait until the pipeline has data.
with RecordTimeTo(learner_timings["rollout_get_time"]):
(
traj_batch,
timestep,
actor_times,
episode_metrics,
) = pipeline.get( # type: ignore
block=True
)
# We then replace the timestep in the learner state with the latest timestep
# This means the learner has access to the entire trajectory as well as
# an additional timestep which it can use to bootstrap.
learner_state = learner_state._replace(timestep=timestep)
# We then call the update function to update the networks
with RecordTimeTo(learner_timings["learner_step_time"]):
learner_state, train_metrics = learn_step(learner_state, traj_batch)

# We store the metrics and timings for this update
metrics.append((episode_metrics, train_metrics))
actor_timings.append(actor_times)
q_sizes.append(pipeline.qsize())

# After the update we need to update the params sources with the new params
unreplicated_params = unreplicate(learner_state.params)
# We loop over all params sources and update them with the new params
# This is so that all the actors can get the latest params
for source in params_sources:
source.update(unreplicated_params)

# We then pass all the environment metrics, training metrics, current learner state
# and timings to the evaluation queue. This is so the evaluator correctly evaluates
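For reference, `RecordTimeTo` is a Stoix utility; assuming it is a context manager that appends elapsed wall-clock seconds to the given list, the timing change above can be pictured with this minimal stand-in (the name `record_time_to` and the loop count are hypothetical). The new outer `learner_time_per_eval` timer wraps the whole update loop so SGD throughput can be derived later, while `learner_step_time` (renamed from `learning_time`) still times each individual `learn_step`:

import time
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, List

@contextmanager
def record_time_to(to: List[float]):
    # Stand-in for stoix's RecordTimeTo: append elapsed seconds to `to`.
    start = time.monotonic()
    try:
        yield
    finally:
        to.append(time.monotonic() - start)

learner_timings: Dict[str, List[float]] = defaultdict(list)
with record_time_to(learner_timings["learner_time_per_eval"]):
    for _ in range(3):                       # num_updates_per_eval
        with record_time_to(learner_timings["learner_step_time"]):
            pass                             # learn_step(...) would run here
# learner_timings now holds one outer time and three per-step times.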
@@ -617,18 +619,6 @@ def run_experiment(_config: DictConfig) -> float:
"""Runs experiment."""
config = copy.deepcopy(_config)

# Perform some checks on the config
# This additionally calculates certain
# values based on the config
config = check_total_timesteps(config)

assert (
config.arch.num_updates > config.arch.num_evaluation
), "Number of updates per evaluation must be less than total number of updates."

# Calculate the number of updates per evaluation
config.arch.num_updates_per_eval = int(config.arch.num_updates // config.arch.num_evaluation)

# Get the learner and actor devices
local_devices = jax.local_devices()
global_devices = jax.devices()
@@ -647,28 +637,13 @@ def run_experiment(_config: DictConfig) -> float:
print(f"{Fore.MAGENTA}{Style.BRIGHT}Global devices: {global_devices}{Style.RESET_ALL}")
# Set the number of learning and acting devices in the config
# useful for keeping track of experimental setup
config.num_learning_devices = len(local_learner_devices)
config.num_actor_actor_devices = len(actor_devices)

# Calculate the number of envs per actor
assert (
config.arch.num_envs == config.arch.total_num_envs
), "arch.num_envs must equal arch.total_num_envs for Sebulba architectures"
# We first simply take the total number of envs and divide by the number of actor devices
# to get the number of envs per actor device
num_envs_per_actor_device = config.arch.total_num_envs // len(actor_devices)
# We then divide this by the number of actors per device to get the number of envs per actor
num_envs_per_actor = int(num_envs_per_actor_device // config.arch.actor.actor_per_device)
config.arch.actor.envs_per_actor = num_envs_per_actor

# We then perform a simple check to ensure that the number of envs per actor is
# divisible by the number of learner devices. This is because we shard the envs
# per actor across the learner devices This check is mainly relevant for on-policy
# algorithms
assert num_envs_per_actor % len(local_learner_devices) == 0, (
f"The number of envs per actor must be divisible by the number of learner devices. "
f"Got {num_envs_per_actor} envs per actor and {len(local_learner_devices)} learner devices"
)
config.num_learner_devices = len(local_learner_devices)
config.num_actor_devices = len(actor_devices)

# Perform some checks on the config
# This additionally calculates certain
# values based on the config
config = check_total_timesteps(config)
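The env-sharding arithmetic that used to live inline here (deriving the envs per actor from `total_num_envs`) is removed by this hunk; assuming it now happens inside `check_total_timesteps`, the derivation is roughly the following, with illustrative numbers:

# Illustrative derivation only -- the PR moves this logic out of run_experiment,
# presumably into check_total_timesteps.
total_num_envs = 128        # config.arch.total_num_envs
num_actor_devices = 2       # len(actor_devices)
actor_per_device = 2        # config.arch.actor.actor_per_device
num_learner_devices = 2     # config.num_learner_devices

num_envs_per_actor_device = total_num_envs // num_actor_devices      # 64
num_envs_per_actor = num_envs_per_actor_device // actor_per_device   # 32
# Envs per actor must shard evenly across learner devices (on-policy case).
assert num_envs_per_actor % num_learner_devices == 0
# config.arch.actor.num_envs_per_actor = num_envs_per_actor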

# Create the environment factory.
env_factory = environments.make_factory(config)
@@ -711,10 +686,10 @@ def run_experiment(_config: DictConfig) -> float:
# Get initial parameters
initial_params = unreplicate(learner_state.params)

# Get the number of steps per rollout
steps_per_rollout = (
config.system.rollout_length * config.arch.total_num_envs * config.arch.num_updates_per_eval
)
# Get the number of steps consumed by the learner per learner step
steps_per_learner_step = config.system.rollout_length * config.arch.actor.num_envs_per_actor
# Get the number of steps consumed by the learner per evaluation
steps_consumed_per_eval = steps_per_learner_step * config.arch.num_updates_per_eval
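With this change the logged timestep tracks the data the learner actually consumes (one actor's envs per learner step) rather than `total_num_envs` across all actor threads. A worked example with illustrative numbers:

# Illustrative numbers to show the bookkeeping change.
rollout_length = 16           # config.system.rollout_length
num_envs_per_actor = 64       # config.arch.actor.num_envs_per_actor
num_updates_per_eval = 50     # config.arch.num_updates_per_eval

steps_per_learner_step = rollout_length * num_envs_per_actor             # 1_024
steps_consumed_per_eval = steps_per_learner_step * num_updates_per_eval  # 51_200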

# Creating the pipeline
# First we create the lifetime so we can stop the pipeline when we want
@@ -743,7 +718,7 @@ def run_experiment(_config: DictConfig) -> float:
for i in range(config.arch.actor.actor_per_device):
key, actors_key = jax.random.split(key)
seeds = np_rng.integers(
np.iinfo(np.int32).max, size=config.arch.actor.envs_per_actor
np.iinfo(np.int32).max, size=config.arch.actor.num_envs_per_actor
).tolist()
actor_thread = get_actor_thread(
env_factory,
@@ -779,17 +754,25 @@ def run_experiment(_config: DictConfig) -> float:
episode_metrics, train_metrics, learner_state, timings_dict = eval_queue.get(block=True)

# Log the metrics and timings
t = int(steps_per_rollout * (eval_step + 1))
t = int(steps_consumed_per_eval * (eval_step + 1))
timings_dict["timestep"] = t
logger.log(timings_dict, t, eval_step, LogEvent.MISC)

episode_metrics, ep_completed = get_final_step_metrics(episode_metrics)
# Calculate steps per second for actor
# Here we use the number of steps pushed to the pipeline each time
# and the average time it takes to do a single rollout across
# all the updates per evaluation
episode_metrics["steps_per_second"] = (
steps_per_rollout / timings_dict["single_rollout_time"]
steps_per_learner_step / timings_dict["single_rollout_time"]
)
if ep_completed:
logger.log(episode_metrics, t, eval_step, LogEvent.ACT)

train_metrics["learner_step"] = (eval_step + 1) * config.arch.num_updates_per_eval
train_metrics["sgd_steps_per_second"] = (config.arch.num_updates_per_eval) / timings_dict[
"learner_time_per_eval"
]
logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)
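The two throughput metrics introduced here measure different things: environment steps pushed per rollout versus SGD updates completed per evaluation window. A worked example with made-up timings (not measured numbers from the PR):

# Illustrative timing values only.
steps_per_learner_step = 1_024
num_updates_per_eval = 50
single_rollout_time = 0.50        # s, avg time for an actor to push one rollout
learner_time_per_eval = 20.0      # s, wall time for the whole update loop

steps_per_second = steps_per_learner_step / single_rollout_time       # 2_048 env steps/s
sgd_steps_per_second = num_updates_per_eval / learner_time_per_eval   # 2.5 learner steps/s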

# Evaluate the current model and log the metrics
@@ -803,7 +786,7 @@ def run_experiment(_config: DictConfig) -> float:
if save_checkpoint:
# Save checkpoint of learner state
checkpointer.save(
timestep=steps_per_rollout * (eval_step + 1),
timestep=steps_consumed_per_eval * (eval_step + 1),
unreplicated_learner_state=unreplicate(learner_state),
episode_return=episode_return,
)
@@ -824,6 +807,7 @@ def run_experiment(_config: DictConfig) -> float:

# Now we stop the actors and params sources
print(f"{Fore.MAGENTA}{Style.BRIGHT}Closing actors...{Style.RESET_ALL}")
pipeline.clear()
for actor in actor_threads:
# We clear the pipeline before stopping each actor thread
# since actors can be blocked on the pipeline
@@ -850,7 +834,7 @@ def run_experiment(_config: DictConfig) -> float:
key, eval_key = jax.random.split(key, 2)
eval_metrics = abs_metric_evaluator(best_params, eval_key)

t = int(steps_per_rollout * (eval_step + 1))
t = int(steps_consumed_per_eval * (eval_step + 1))
logger.log(eval_metrics, t, eval_step, LogEvent.ABSOLUTE)
abs_metric_evaluator_envs.close()

@@ -871,9 +855,12 @@ def hydra_entry_point(cfg: DictConfig) -> float:
OmegaConf.set_struct(cfg, False)

# Run experiment.
start = time.monotonic()
eval_performance = run_experiment(cfg)

print(f"{Fore.CYAN}{Style.BRIGHT}PPO experiment completed{Style.RESET_ALL}")
end = time.monotonic()
print(
f"{Fore.CYAN}{Style.BRIGHT}PPO experiment completed in {end - start:.2f}s.{Style.RESET_ALL}"
)
return eval_performance

