chore: some cleanup
EdanToledo committed Aug 13, 2024
1 parent 7b206a5 commit 7278a60
Showing 8 changed files with 108 additions and 94 deletions.
4 changes: 2 additions & 2 deletions stoix/evaluator.py
@@ -396,7 +396,7 @@ def get_sebulba_eval_fn(
)

def eval_fn(params: FrozenDict, key: chex.PRNGKey) -> Dict:
"""Evaluates the given params on an environment and returns relevent metrics.
"""Evaluates the given params on an environment and returns relevant metrics.
Metrics are collected by the `RecordEpisodeMetrics` wrapper: episode return and length,
also win rate for environments that support it.
@@ -431,7 +431,7 @@ def _episode(key: chex.PRNGKey) -> Tuple[chex.PRNGKey, Dict]:
# care about subsequent steps because we only want the results from the first episode
done_idx = np.argmax(timesteps.last(), axis=0)
metrics = jax.tree_map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics)
del metrics["is_terminal_step"] # uneeded for logging
del metrics["is_terminal_step"] # unneeded for logging

return key, metrics

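An illustrative aside on the hunk above: `np.argmax` over the time axis of `timesteps.last()` returns the index of the first terminal step for each parallel environment, and the advanced indexing then pulls each environment's metrics at exactly that step. A minimal NumPy sketch with made-up shapes (5 timesteps, 3 envs; not values from the config):

```python
import numpy as np

# last[t, e] is True when env e finishes an episode at step t (made-up data).
last = np.array(
    [
        [False, False, False],
        [False, True, False],
        [True, False, False],
        [False, False, True],
        [True, True, True],
    ]
)

# argmax over the time axis gives the first True per env: the step at which
# each env's first episode ended.
done_idx = np.argmax(last, axis=0)  # array([2, 1, 3])

# A per-step metric with shape (timesteps, envs), e.g. episode return.
episode_return = np.arange(15, dtype=float).reshape(5, 3)

# Advanced indexing keeps, for each env, only the value at its first terminal
# step, which is why subsequent episodes in the same rollout are ignored.
first_episode_return = episode_return[done_idx, np.arange(last.shape[1])]
print(first_episode_return)  # [ 6.  4. 11.]
```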
76 changes: 34 additions & 42 deletions stoix/systems/ppo/sebulba/ff_ppo.py
@@ -1,6 +1,5 @@
import copy
import threading
import time
from collections import defaultdict
from queue import Queue
from typing import Any, Dict, List, Sequence, Tuple
@@ -15,7 +14,6 @@
from colorama import Fore, Style
from flax.core.frozen_dict import FrozenDict
from flax.jax_utils import unreplicate
from jumanji.env import Environment
from omegaconf import DictConfig, OmegaConf
from rich.pretty import pprint

@@ -24,29 +22,19 @@
ActorCriticOptStates,
ActorCriticParams,
CriticApply,
EnvFactory,
ExperimentOutput,
LearnerFn,
LearnerState,
Observation,
SebulbaLearnerFn,
)
from stoix.evaluator import (
evaluator_setup,
get_distribution_act_fn,
get_sebulba_eval_fn,
)
from stoix.evaluator import get_distribution_act_fn, get_sebulba_eval_fn
from stoix.networks.base import FeedForwardActor as Actor
from stoix.networks.base import FeedForwardCritic as Critic
from stoix.systems.ppo.ppo_types import PPOTransition
from stoix.utils import make_env as environments
from stoix.utils.checkpointing import Checkpointer
from stoix.utils.env_factory import EnvPoolFactory, GymnasiumFactory
from stoix.utils.jax_utils import (
merge_leading_dims,
unreplicate_batch_dim,
unreplicate_n_dims,
)
from stoix.utils.env_factory import EnvFactory
from stoix.utils.jax_utils import merge_leading_dims
from stoix.utils.logger import LogEvent, StoixLogger
from stoix.utils.loss import clipped_value_loss, ppo_clip_loss
from stoix.utils.multistep import batch_truncated_generalized_advantage_estimation
@@ -78,7 +66,9 @@ def get_rollout_fn(
critic_apply_fn = jax.jit(critic_apply_fn, device=actor_device)
cpu = jax.devices("cpu")[0]
split_key_fn = jax.jit(jax.random.split, device=actor_device)
move_to_device = lambda tree: jax.tree_map(lambda x: jax.device_put(x, actor_device), tree)
move_to_device = lambda tree: jax.tree_util.tree_map(
lambda x: jax.device_put(x, actor_device), tree
)
move_to_device = jax.jit(move_to_device)
# Build the environments
envs = env_factory(config.arch.actor.envs_per_actor)
@@ -115,7 +105,7 @@ def rollout_fn(rng_key: chex.PRNGKey) -> None:
cached_next_dones = move_to_device(next_dones)
cached_next_trunc = move_to_device(next_trunc)

# Run the actor and critic networks to get the action, value and log_prob
# Run the actor and critic networks to get the action, value and log_prob
with RecordTimeTo(timings_dict["compute_action_time"]):
rng_key, policy_key = split_key_fn(rng_key)
pi = actor_apply_fn(params.actor_params, cached_next_obs)
@@ -126,7 +116,7 @@ def rollout_fn(rng_key: chex.PRNGKey) -> None:
# Move the action to the CPU
with RecordTimeTo(timings_dict["put_action_on_cpu_time"]):
action_cpu = np.asarray(jax.device_put(action, cpu))

# Step the environment
with RecordTimeTo(timings_dict["env_step_time"]):
timestep = envs.step(action_cpu)
@@ -158,8 +148,8 @@ def rollout_fn(rng_key: chex.PRNGKey) -> None:
# Send the trajectory to the pipeline
with RecordTimeTo(timings_dict["rollout_put_time"]):
pipeline.put(traj, timestep, timings_dict)
# Close the environments

# Close the environments
envs.close()

return rollout_fn
@@ -205,7 +195,7 @@ def get_learner_update_fn(
update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn],
config: DictConfig,
) -> SebulbaLearnerFn[LearnerState, PPOTransition]:
"""Get the learner update function which is used to update the actor and critic networks.
"""Get the learner update function which is used to update the actor and critic networks.
This function is used by the learner thread to update the networks."""

# Get apply and update functions for actor and critic networks.
@@ -407,19 +397,19 @@ def get_learner_rollout_fn(
pipeline: Pipeline,
params_sources: Sequence[ParamsSource],
):
"""Get the learner rollout function that is used by the learner thread to update the networks.
"""Get the learner rollout function that is used by the learner thread to update the networks.
This function is what is actually run by the learner thread. It gets the data from the pipeline and
uses the learner update function to update the networks. It then sends these intermediate network parameters
uses the learner update function to update the networks. It then sends these intermediate network parameters
to a queue for evaluation."""

def learner_rollout(learner_state: LearnerState) -> None:
# Loop for the total number of evaluations selected to be performed.
for _ in range(config.arch.num_evaluation):
# Create the lists to store metrics and timings for this learning iteration.
metrics: List[Tuple[Dict, Dict]] = []
rollout_times: List[Dict] = []
learn_timings: Dict[str, List[float]] = defaultdict(list)
# Loop for the number of updates per evaluation
# Loop for the number of updates per evaluation
for _ in range(config.arch.num_updates_per_eval):
# Get the trajectory batch from the pipeline
# This is blocking so it will wait until the pipeline has data.
@@ -429,15 +419,15 @@ def learner_rollout(learner_state: LearnerState) -> None:
# This means the learner has access to the entire trajectory as well as an additional timestep
# which it can use to bootstrap.
learner_state = learner_state._replace(timestep=timestep)
# We then call the update function to update the networks
# We then call the update function to update the networks
with RecordTimeTo(learn_timings["learning_time"]):
learner_state, episode_metrics, train_metrics = learn(learner_state, traj_batch)

# We store the metrics and timings for this update
metrics.append((episode_metrics, train_metrics))
rollout_times.append(rollout_time)
# After the update we need to update the params sources with the new params

# After the update we need to update the params sources with the new params
unreplicated_params = unreplicate(learner_state.params)
# We loop over all params sources and update them with the new params
# This is so that all the actors can get the latest params
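As a sketch of the handoff the comments above describe (this is not the stoix `ParamsSource` class, just the general pattern it implies): the learner thread publishes its latest parameters after every update, and each actor thread reads whatever version is current when it begins a rollout.

```python
import threading
from typing import Any, Optional


class LatestParams:
    """Thread-safe holder for the most recent parameters (illustrative only)."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._params: Optional[Any] = None

    def update(self, params: Any) -> None:
        # Called by the learner thread after each update step.
        with self._lock:
            self._params = params

    def get(self) -> Optional[Any]:
        # Called by an actor thread before starting a new rollout.
        with self._lock:
            return self._params
```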
@@ -483,7 +473,7 @@ def learner_setup(
config: DictConfig,
) -> Tuple[LearnerFn[LearnerState], Actor, LearnerState]:
"""Setup for the learner state and networks."""

# Create a single environment just to get the observation and action specs.
env = env_factory(num_envs=1)
# Get number/dimension of actions.
@@ -536,7 +526,7 @@ def learner_setup(

# Pack params.
params = ActorCriticParams(actor_params, critic_params)

# Extract apply functions.
actor_network_apply_fn = actor_network.apply
critic_network_apply_fn = critic_network.apply
@@ -579,17 +569,17 @@
def run_experiment(_config: DictConfig) -> float:
"""Runs experiment."""
config = copy.deepcopy(_config)

# Perform some checks on the config
# This additionally calculates certain
# values based on the config
config = check_total_timesteps(config)

assert (
config.arch.num_updates > config.arch.num_evaluation
), "Number of updates per evaluation must be less than total number of updates."
# Calculate the number of updates per evaluation

# Calculate the number of updates per evaluation
config.arch.num_updates_per_eval = int(config.arch.num_updates // config.arch.num_evaluation)

# Get the learner and actor devices
Expand All @@ -598,7 +588,7 @@ def run_experiment(_config: DictConfig) -> float:
assert len(local_devices) == len(
global_devices
), "Local and global devices must be the same for now. We dont support multihost just yet"
# Extract the actor and learner devices
# Extract the actor and learner devices
actor_devices = [local_devices[device_id] for device_id in config.arch.actor.device_ids]
local_learner_devices = [
local_devices[device_id] for device_id in config.arch.learner.device_ids
@@ -607,13 +597,15 @@ def run_experiment(_config: DictConfig) -> float:
print(
f"{Fore.GREEN}{Style.BRIGHT}[Sebulba] Learner devices: {local_learner_devices}{Style.RESET_ALL}"
)
# Set the number of learning and acting devices in the config
# Set the number of learning and acting devices in the config
# useful for keeping track of experimental setup
config.num_learning_devices = len(local_learner_devices)
config.num_actor_actor_devices = len(actor_devices)

# Calculate the number of envs per actor
assert config.arch.num_envs == config.arch.total_num_envs, ("arch.num_envs must equal arch.total_num_envs for Sebulba architectures")
assert (
config.arch.num_envs == config.arch.total_num_envs
), "arch.num_envs must equal arch.total_num_envs for Sebulba architectures"
# We first simply take the total number of envs and divide by the number of actor devices
# to get the number of envs per actor device
num_envs_per_actor_device = config.arch.total_num_envs // len(actor_devices)
@@ -629,11 +621,10 @@ def run_experiment(_config: DictConfig) -> float:
), "The number of envs per actor must be divisible by the number of learner devices"

# Create the environment factory.
# env_factory = EnvPoolFactory(
# "CartPole-v1",
# config.arch.seed
# )
env_factory = GymnasiumFactory("CartPole-v1")
env_factory = environments.make(config)
assert isinstance(
env_factory, EnvFactory
), "Environment factory must be an instance of EnvFactory"

# PRNG keys.
key, key_e, actor_net_key, critic_net_key = jax.random.split(
@@ -654,6 +645,7 @@ def run_experiment(_config: DictConfig) -> float:
np_rng,
absolute_metric=False,
)
evaluator_envs.close()

# Logger setup
logger = StoixLogger(config)
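For concreteness, the env-partitioning arithmetic in `run_experiment` above works out as follows with made-up numbers (the real values come from the config; the per-device actor count is an assumption for illustration):

```python
total_num_envs = 64        # stands in for config.arch.total_num_envs
num_actor_devices = 2      # len(actor_devices)
actors_per_device = 4      # assumed number of actor threads per device
num_learner_devices = 2    # len(local_learner_devices)

num_envs_per_actor_device = total_num_envs // num_actor_devices      # 32
num_envs_per_actor = num_envs_per_actor_device // actors_per_device  # 8

# Mirrors the assertion in the hunk: each actor's env batch must shard evenly
# across the learner devices.
assert num_envs_per_actor % num_learner_devices == 0
```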
18 changes: 14 additions & 4 deletions stoix/utils/env_factory.py
@@ -13,8 +13,8 @@ class EnvFactory(abc.ABC):
"""
Abstract class to create environments
"""
def __init__(self, task_id : str, init_seed: int = 42, **kwargs: Any):

def __init__(self, task_id: str, init_seed: int = 42, **kwargs: Any):
self.task_id = task_id
self.seed = init_seed
# a lock is needed because this object will be used from different threads.
@@ -36,13 +36,23 @@ def __call__(self, num_envs: int) -> Any:
with self.lock:
seed = self.seed
self.seed += num_envs
return EnvPoolToJumanji(envpool.make(task_id=self.task_id, env_type="gymnasium", num_envs=num_envs, seed=seed, gym_reset_return_info=True, **self.kwargs))
return EnvPoolToJumanji(
envpool.make(
task_id=self.task_id,
env_type="gymnasium",
num_envs=num_envs,
seed=seed,
gym_reset_return_info=True,
**self.kwargs
)
)


class GymnasiumFactory(EnvFactory):
"""
Create environments using gymnasium
"""

def __call__(self, num_envs: int) -> Any:
with self.lock:
vec_env = gymnasium.make_vec(id=self.task_id, num_envs=num_envs, **self.kwargs)
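The lock in `EnvFactory.__init__` exists because several actor threads call the factory concurrently; each call reserves a contiguous block of seeds so no two threads build environments with overlapping seeds. A minimal sketch of that bookkeeping (a dummy class, not the stoix factories):

```python
import threading


class DummySeedFactory:
    def __init__(self, init_seed: int = 42) -> None:
        self.seed = init_seed
        self.lock = threading.Lock()

    def __call__(self, num_envs: int) -> range:
        with self.lock:
            seed = self.seed
            self.seed += num_envs
        # A real factory would construct `num_envs` environments here; this
        # just returns the reserved seed block to show the bookkeeping.
        return range(seed, seed + num_envs)


factory = DummySeedFactory()
threads = [threading.Thread(target=factory, args=(8,)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(factory.seed)  # 42 + 4 * 8 = 74
```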
28 changes: 24 additions & 4 deletions stoix/utils/make_env.py
@@ -1,5 +1,5 @@
import copy
from typing import Tuple
from typing import Tuple, Union

import gymnax
import hydra
@@ -25,6 +25,7 @@
from xminigrid.registration import _REGISTRY as XMINIGRID_REGISTRY

from stoix.utils.debug_env import IdentityGame, SequenceGame
from stoix.utils.env_factory import EnvFactory, EnvPoolFactory, GymnasiumFactory
from stoix.wrappers import GymnaxWrapper, JumanjiWrapper, RecordEpisodeMetrics
from stoix.wrappers.brax import BraxJumanjiWrapper
from stoix.wrappers.jaxmarl import JaxMarlWrapper, MabraxWrapper, SmaxWrapper
@@ -371,19 +372,38 @@ def make_navix_env(env_name: str, config: DictConfig) -> Tuple[Environment, Envi
return env, eval_env


def make(config: DictConfig) -> Tuple[Environment, Environment]:
def make_gymnasium_factory(env_name: str, config: DictConfig) -> GymnasiumFactory:

env_factory = GymnasiumFactory(env_name, init_seed=config.arch.seed, **config.env.kwargs)

return env_factory


def make_envpool_factory(env_name: str, config: DictConfig) -> EnvPoolFactory:

env_factory = EnvPoolFactory(env_name, init_seed=config.arch.seed, **config.env.kwargs)

return env_factory


def make(config: DictConfig) -> Union[EnvFactory, Tuple[Environment, Environment]]:
"""
Create environments for training and evaluation.
Args:
config (Dict): The configuration of the environment.
Returns:
A tuple of the environments.
training and evaluation environments or a factory to create them.
"""
env_name = config.env.scenario.name
suite_name = config.env.env_name

if env_name in gymnax_environments:
if "envpool" in suite_name:
return make_envpool_factory(env_name, config)
elif "gymnasium" in suite_name:
return make_gymnasium_factory(env_name, config)
elif env_name in gymnax_environments:
envs = make_gymnax_env(env_name, config)
elif env_name in JUMANJI_REGISTRY:
envs = make_jumanji_env(env_name, config)
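With this change, `make` returns either an `EnvFactory` (for the envpool and gymnasium suites used by Sebulba) or the usual `(train_env, eval_env)` tuple. A hedged usage sketch, assuming the caller holds a valid Hydra/OmegaConf config:

```python
from stoix.utils import make_env as environments
from stoix.utils.env_factory import EnvFactory


def build_envs(config):  # config: DictConfig with env.env_name / env.scenario.name
    made = environments.make(config)
    if isinstance(made, EnvFactory):
        # Sebulba path: each actor thread calls the factory later with the
        # number of envs it should own.
        return made
    # Anakin path: ready-made training and evaluation environments.
    train_env, eval_env = made
    return train_env, eval_env
```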
1 change: 0 additions & 1 deletion stoix/utils/sebulba_utils.py
@@ -6,7 +6,6 @@
import jax
import jax.numpy as jnp
from jumanji.types import TimeStep
from omegaconf import DictConfig

from stoix.base_types import Parameters, StoixTransition

4 changes: 2 additions & 2 deletions stoix/utils/total_timestep_checker.py
@@ -7,14 +7,14 @@ def check_total_timesteps(config: DictConfig) -> DictConfig:

# Check if the number of devices is set in the config
# If not, it is assumed that the number of devices is 1
# For the case of using a sebulba config, the number of
# For the case of using a sebulba config, the number of
# devices is set to 1 for the calculation
# of the number of environments per device, etc
if "num_devices" not in config:
num_devices = 1
else:
num_devices = config.num_devices
# If update_batch_size is not in the config, usualyl this means a sebulba config is being used.
# If update_batch_size is not in the config, usually this means a sebulba config is being used.
if "update_batch_size" not in config.arch:
update_batch_size = 1
else: