From 7278a60c12abee95c089cd38d09d940211680205 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Tue, 13 Aug 2024 17:28:12 +0000 Subject: [PATCH] chore: some cleanup --- stoix/evaluator.py | 4 +- stoix/systems/ppo/sebulba/ff_ppo.py | 76 ++++++++++++--------------- stoix/utils/env_factory.py | 18 +++++-- stoix/utils/make_env.py | 28 ++++++++-- stoix/utils/sebulba_utils.py | 1 - stoix/utils/total_timestep_checker.py | 4 +- stoix/wrappers/envpool.py | 42 +++++++-------- stoix/wrappers/gymnasium.py | 29 +++++----- 8 files changed, 108 insertions(+), 94 deletions(-) diff --git a/stoix/evaluator.py b/stoix/evaluator.py index 57e64324..40362705 100644 --- a/stoix/evaluator.py +++ b/stoix/evaluator.py @@ -396,7 +396,7 @@ def get_sebulba_eval_fn( ) def eval_fn(params: FrozenDict, key: chex.PRNGKey) -> Dict: - """Evaluates the given params on an environment and returns relevent metrics. + """Evaluates the given params on an environment and returns relevant metrics. Metrics are collected by the `RecordEpisodeMetrics` wrapper: episode return and length, also win rate for environments that support it. @@ -431,7 +431,7 @@ def _episode(key: chex.PRNGKey) -> Tuple[chex.PRNGKey, Dict]: # care about subsequent steps because we only the results from the first episode done_idx = np.argmax(timesteps.last(), axis=0) metrics = jax.tree_map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics) - del metrics["is_terminal_step"] # uneeded for logging + del metrics["is_terminal_step"] # unneeded for logging return key, metrics diff --git a/stoix/systems/ppo/sebulba/ff_ppo.py b/stoix/systems/ppo/sebulba/ff_ppo.py index 1be21eac..7c32dafd 100644 --- a/stoix/systems/ppo/sebulba/ff_ppo.py +++ b/stoix/systems/ppo/sebulba/ff_ppo.py @@ -1,6 +1,5 @@ import copy import threading -import time from collections import defaultdict from queue import Queue from typing import Any, Dict, List, Sequence, Tuple @@ -15,7 +14,6 @@ from colorama import Fore, Style from flax.core.frozen_dict import FrozenDict from flax.jax_utils import unreplicate -from jumanji.env import Environment from omegaconf import DictConfig, OmegaConf from rich.pretty import pprint @@ -24,29 +22,19 @@ ActorCriticOptStates, ActorCriticParams, CriticApply, - EnvFactory, ExperimentOutput, LearnerFn, LearnerState, - Observation, SebulbaLearnerFn, ) -from stoix.evaluator import ( - evaluator_setup, - get_distribution_act_fn, - get_sebulba_eval_fn, -) +from stoix.evaluator import get_distribution_act_fn, get_sebulba_eval_fn from stoix.networks.base import FeedForwardActor as Actor from stoix.networks.base import FeedForwardCritic as Critic from stoix.systems.ppo.ppo_types import PPOTransition from stoix.utils import make_env as environments from stoix.utils.checkpointing import Checkpointer -from stoix.utils.env_factory import EnvPoolFactory, GymnasiumFactory -from stoix.utils.jax_utils import ( - merge_leading_dims, - unreplicate_batch_dim, - unreplicate_n_dims, -) +from stoix.utils.env_factory import EnvFactory +from stoix.utils.jax_utils import merge_leading_dims from stoix.utils.logger import LogEvent, StoixLogger from stoix.utils.loss import clipped_value_loss, ppo_clip_loss from stoix.utils.multistep import batch_truncated_generalized_advantage_estimation @@ -78,7 +66,9 @@ def get_rollout_fn( critic_apply_fn = jax.jit(critic_apply_fn, device=actor_device) cpu = jax.devices("cpu")[0] split_key_fn = jax.jit(jax.random.split, device=actor_device) - move_to_device = lambda tree: jax.tree_map(lambda x: jax.device_put(x, actor_device), tree) + move_to_device = 
lambda tree: jax.tree_util.tree_map( + lambda x: jax.device_put(x, actor_device), tree + ) move_to_device = jax.jit(move_to_device) # Build the environments envs = env_factory(config.arch.actor.envs_per_actor) @@ -115,7 +105,7 @@ def rollout_fn(rng_key: chex.PRNGKey) -> None: cached_next_dones = move_to_device(next_dones) cached_next_trunc = move_to_device(next_trunc) - # Run the actor and critic networks to get the action, value and log_prob + # Run the actor and critic networks to get the action, value and log_prob with RecordTimeTo(timings_dict["compute_action_time"]): rng_key, policy_key = split_key_fn(rng_key) pi = actor_apply_fn(params.actor_params, cached_next_obs) @@ -126,7 +116,7 @@ def rollout_fn(rng_key: chex.PRNGKey) -> None: # Move the action to the CPU with RecordTimeTo(timings_dict["put_action_on_cpu_time"]): action_cpu = np.asarray(jax.device_put(action, cpu)) - + # Step the environment with RecordTimeTo(timings_dict["env_step_time"]): timestep = envs.step(action_cpu) @@ -158,8 +148,8 @@ def rollout_fn(rng_key: chex.PRNGKey) -> None: # Send the trajectory to the pipeline with RecordTimeTo(timings_dict["rollout_put_time"]): pipeline.put(traj, timestep, timings_dict) - - # Close the environments + + # Close the environments envs.close() return rollout_fn @@ -205,7 +195,7 @@ def get_learner_update_fn( update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], config: DictConfig, ) -> SebulbaLearnerFn[LearnerState, PPOTransition]: - """Get the learner update function which is used to update the actor and critic networks. + """Get the learner update function which is used to update the actor and critic networks. This function is used by the learner thread to update the networks.""" # Get apply and update functions for actor and critic networks. @@ -407,11 +397,11 @@ def get_learner_rollout_fn( pipeline: Pipeline, params_sources: Sequence[ParamsSource], ): - """Get the learner rollout function that is used by the learner thread to update the networks. + """Get the learner rollout function that is used by the learner thread to update the networks. This function is what is actually run by the learner thread. It gets the data from the pipeline and - uses the learner update function to update the networks. It then sends these intermediate network parameters + uses the learner update function to update the networks. It then sends these intermediate network parameters to a queue for evaluation.""" - + def learner_rollout(learner_state: LearnerState) -> None: # Loop for the total number of evaluations selected to be performed. for _ in range(config.arch.num_evaluation): @@ -419,7 +409,7 @@ def learner_rollout(learner_state: LearnerState) -> None: metrics: List[Tuple[Dict, Dict]] = [] rollout_times: List[Dict] = [] learn_timings: Dict[str, List[float]] = defaultdict(list) - # Loop for the number of updates per evaluation + # Loop for the number of updates per evaluation for _ in range(config.arch.num_updates_per_eval): # Get the trajectory batch from the pipeline # This is blocking so it will wait until the pipeline has data. @@ -429,15 +419,15 @@ def learner_rollout(learner_state: LearnerState) -> None: # This means the learner has access to the entire trajectory as well as an additional timestep # which it can use to bootstrap. 
learner_state = learner_state._replace(timestep=timestep) - # We then call the update function to update the networks + # We then call the update function to update the networks with RecordTimeTo(learn_timings["learning_time"]): learner_state, episode_metrics, train_metrics = learn(learner_state, traj_batch) # We store the metrics and timings for this update metrics.append((episode_metrics, train_metrics)) rollout_times.append(rollout_time) - - # After the update we need to update the params sources with the new params + + # After the update we need to update the params sources with the new params unreplicated_params = unreplicate(learner_state.params) # We loop over all params sources and update them with the new params # This is so that all the actors can get the latest params @@ -483,7 +473,7 @@ def learner_setup( config: DictConfig, ) -> Tuple[LearnerFn[LearnerState], Actor, LearnerState]: """Setup for the learner state and networks.""" - + # Create a single environment just to get the observation and action specs. env = env_factory(num_envs=1) # Get number/dimension of actions. @@ -536,7 +526,7 @@ def learner_setup( # Pack params. params = ActorCriticParams(actor_params, critic_params) - + # Extract apply functions. actor_network_apply_fn = actor_network.apply critic_network_apply_fn = critic_network.apply @@ -579,17 +569,17 @@ def learner_setup( def run_experiment(_config: DictConfig) -> float: """Runs experiment.""" config = copy.deepcopy(_config) - + # Perform some checks on the config # This additionally calculates certains # values based on the config config = check_total_timesteps(config) - + assert ( config.arch.num_updates > config.arch.num_evaluation ), "Number of updates per evaluation must be less than total number of updates." - - # Calculate the number of updates per evaluation + + # Calculate the number of updates per evaluation config.arch.num_updates_per_eval = int(config.arch.num_updates // config.arch.num_evaluation) # Get the learner and actor devices @@ -598,7 +588,7 @@ def run_experiment(_config: DictConfig) -> float: assert len(local_devices) == len( global_devices ), "Local and global devices must be the same for now. 
We dont support multihost just yet" - # Extract the actor and learner devices + # Extract the actor and learner devices actor_devices = [local_devices[device_id] for device_id in config.arch.actor.device_ids] local_learner_devices = [ local_devices[device_id] for device_id in config.arch.learner.device_ids @@ -607,13 +597,15 @@ def run_experiment(_config: DictConfig) -> float: print( f"{Fore.GREEN}{Style.BRIGHT}[Sebulba] Learner devices: {local_learner_devices}{Style.RESET_ALL}" ) - # Set the number of learning and acting devices in the config + # Set the number of learning and acting devices in the config # useful for keeping track of experimental setup config.num_learning_devices = len(local_learner_devices) config.num_actor_actor_devices = len(actor_devices) # Calculate the number of envs per actor - assert config.arch.num_envs == config.arch.total_num_envs, ("arch.num_envs must equal arch.total_num_envs for Sebulba architectures") + assert ( + config.arch.num_envs == config.arch.total_num_envs + ), "arch.num_envs must equal arch.total_num_envs for Sebulba architectures" # We first simply take the total number of envs and divide by the number of actor devices # to get the number of envs per actor device num_envs_per_actor_device = config.arch.total_num_envs // len(actor_devices) @@ -629,11 +621,10 @@ def run_experiment(_config: DictConfig) -> float: ), "The number of envs per actor must be divisible by the number of learner devices" # Create the environment factory. - # env_factory = EnvPoolFactory( - # "CartPole-v1", - # config.arch.seed - # ) - env_factory = GymnasiumFactory("CartPole-v1") + env_factory = environments.make(config) + assert isinstance( + env_factory, EnvFactory + ), "Environment factory must be an instance of EnvFactory" # PRNG keys. key, key_e, actor_net_key, critic_net_key = jax.random.split( @@ -654,6 +645,7 @@ def run_experiment(_config: DictConfig) -> float: np_rng, absolute_metric=False, ) + evaluator_envs.close() # Logger setup logger = StoixLogger(config) diff --git a/stoix/utils/env_factory.py b/stoix/utils/env_factory.py index 3096bafa..9fbaf4bd 100644 --- a/stoix/utils/env_factory.py +++ b/stoix/utils/env_factory.py @@ -13,8 +13,8 @@ class EnvFactory(abc.ABC): """ Abstract class to create environments """ - - def __init__(self, task_id : str, init_seed: int = 42, **kwargs: Any): + + def __init__(self, task_id: str, init_seed: int = 42, **kwargs: Any): self.task_id = task_id self.seed = init_seed # a lock is needed because this object will be used from different threads. 
@@ -36,13 +36,23 @@ def __call__(self, num_envs: int) -> Any: with self.lock: seed = self.seed self.seed += num_envs - return EnvPoolToJumanji(envpool.make(task_id=self.task_id, env_type="gymnasium", num_envs=num_envs, seed=seed, gym_reset_return_info=True, **self.kwargs)) + return EnvPoolToJumanji( + envpool.make( + task_id=self.task_id, + env_type="gymnasium", + num_envs=num_envs, + seed=seed, + gym_reset_return_info=True, + **self.kwargs + ) + ) + class GymnasiumFactory(EnvFactory): """ Create environments using gymnasium """ - + def __call__(self, num_envs: int) -> Any: with self.lock: vec_env = gymnasium.make_vec(id=self.task_id, num_envs=num_envs, **self.kwargs) diff --git a/stoix/utils/make_env.py b/stoix/utils/make_env.py index 67ebed25..a3b1b236 100644 --- a/stoix/utils/make_env.py +++ b/stoix/utils/make_env.py @@ -1,5 +1,5 @@ import copy -from typing import Tuple +from typing import Tuple, Union import gymnax import hydra @@ -25,6 +25,7 @@ from xminigrid.registration import _REGISTRY as XMINIGRID_REGISTRY from stoix.utils.debug_env import IdentityGame, SequenceGame +from stoix.utils.env_factory import EnvFactory, EnvPoolFactory, GymnasiumFactory from stoix.wrappers import GymnaxWrapper, JumanjiWrapper, RecordEpisodeMetrics from stoix.wrappers.brax import BraxJumanjiWrapper from stoix.wrappers.jaxmarl import JaxMarlWrapper, MabraxWrapper, SmaxWrapper @@ -371,7 +372,21 @@ def make_navix_env(env_name: str, config: DictConfig) -> Tuple[Environment, Envi return env, eval_env -def make(config: DictConfig) -> Tuple[Environment, Environment]: +def make_gymnasium_factory(env_name: str, config: DictConfig) -> GymnasiumFactory: + + env_factory = GymnasiumFactory(env_name, init_seed=config.arch.seed, **config.env.kwargs) + + return env_factory + + +def make_envpool_factory(env_name: str, config: DictConfig) -> EnvPoolFactory: + + env_factory = EnvPoolFactory(env_name, init_seed=config.arch.seed, **config.env.kwargs) + + return env_factory + + +def make(config: DictConfig) -> Union[EnvFactory, Tuple[Environment, Environment]]: """ Create environments for training and evaluation.. @@ -379,11 +394,16 @@ def make(config: DictConfig) -> Tuple[Environment, Environment]: config (Dict): The configuration of the environment. Returns: - A tuple of the environments. + training and evaluation environments or a factory to create them. 
""" env_name = config.env.scenario.name + suite_name = config.env.env_name - if env_name in gymnax_environments: + if "envpool" in suite_name: + return make_envpool_factory(env_name, config) + elif "gymnasium" in suite_name: + return make_gymnasium_factory(env_name, config) + elif env_name in gymnax_environments: envs = make_gymnax_env(env_name, config) elif env_name in JUMANJI_REGISTRY: envs = make_jumanji_env(env_name, config) diff --git a/stoix/utils/sebulba_utils.py b/stoix/utils/sebulba_utils.py index d68c572e..a9a63d6e 100644 --- a/stoix/utils/sebulba_utils.py +++ b/stoix/utils/sebulba_utils.py @@ -6,7 +6,6 @@ import jax import jax.numpy as jnp from jumanji.types import TimeStep -from omegaconf import DictConfig from stoix.base_types import Parameters, StoixTransition diff --git a/stoix/utils/total_timestep_checker.py b/stoix/utils/total_timestep_checker.py index c85f4861..6a468de3 100644 --- a/stoix/utils/total_timestep_checker.py +++ b/stoix/utils/total_timestep_checker.py @@ -7,14 +7,14 @@ def check_total_timesteps(config: DictConfig) -> DictConfig: # Check if the number of devices is set in the config # If not, it is assumed that the number of devices is 1 - # For the case of using a sebulba config, the number of + # For the case of using a sebulba config, the number of # devices is set to 1 for the calculation # of the number of environments per device, etc if "num_devices" not in config: num_devices = 1 else: num_devices = num_devices - # If update_batch_size is not in the config, usualyl this means a sebulba config is being used. + # If update_batch_size is not in the config, usually this means a sebulba config is being used. if "update_batch_size" not in config.arch: update_batch_size = 1 else: diff --git a/stoix/wrappers/envpool.py b/stoix/wrappers/envpool.py index 225a515d..053af42c 100644 --- a/stoix/wrappers/envpool.py +++ b/stoix/wrappers/envpool.py @@ -1,16 +1,10 @@ -import sys -import traceback -import warnings -from multiprocessing import Queue -from multiprocessing.connection import Connection -from typing import Any, Callable, Dict, Optional, Tuple, Union - -import envpool -import jax +from typing import Any, Dict, Optional + import numpy as np +from jumanji.specs import Array, DiscreteArray, Spec from jumanji.types import StepType, TimeStep from numpy.typing import NDArray -from jumanji.specs import Array, Spec, DiscreteArray + from stoix.base_types import Observation @@ -29,7 +23,7 @@ def __init__(self, env: Any): self.running_count_episode_length = np.zeros(self.num_envs, dtype=int) self.episode_return = np.zeros(self.num_envs, dtype=float) self.episode_length = np.zeros(self.num_envs, dtype=int) - + def reset( self, *, seed: Optional[list[int]] = None, options: Optional[list[dict]] = None ) -> TimeStep: @@ -38,20 +32,20 @@ def reset( ep_done = np.zeros(self.num_envs, dtype=float) rewards = np.zeros(self.num_envs, dtype=float) terminated = np.zeros(self.num_envs, dtype=float) - + # Reset the metrics self.running_count_episode_return = np.zeros(self.num_envs, dtype=float) self.running_count_episode_length = np.zeros(self.num_envs, dtype=int) self.episode_return = np.zeros(self.num_envs, dtype=float) self.episode_length = np.zeros(self.num_envs, dtype=int) - + # Create the metrics dict metrics = { "episode_return": np.zeros(self.num_envs, dtype=float), "episode_length": np.zeros(self.num_envs, dtype=int), "is_terminal_step": np.zeros(self.num_envs, dtype=bool), } - + info["metrics"] = metrics timestep = self._create_timestep(obs, ep_done, terminated, rewards, info) 
@@ -62,13 +56,13 @@ def step(self, action: list) -> TimeStep: obs, rewards, terminated, truncated, info = self.env.step(action) ep_done = np.logical_or(terminated, truncated) not_done = 1 - ep_done - + # Counting episode return and length. if "reward" in info: metric_reward = info["reward"] else: metric_reward = rewards - + # Counting episode return and length. new_episode_return = self.running_count_episode_return + metric_reward new_episode_length = self.running_count_episode_length + 1 @@ -85,11 +79,11 @@ def step(self, action: list) -> TimeStep: info["metrics"] = metrics # Update the metrics - self.running_count_episode_return=new_episode_return * not_done - self.running_count_episode_length=new_episode_length * not_done - self.episode_return=episode_return_info - self.episode_length=episode_length_info - + self.running_count_episode_return = new_episode_return * not_done + self.running_count_episode_length = new_episode_length * not_done + self.episode_return = episode_return_info + self.episode_length = episode_length_info + timestep = self._create_timestep(obs, ep_done, terminated, rewards, info) return timestep @@ -112,7 +106,7 @@ def _create_timestep( observation=obs, extras=extras, ) - + def observation_spec(self) -> Spec: agent_view_spec = Array(shape=self.obs_shape, dtype=float) return Spec( @@ -125,6 +119,6 @@ def observation_spec(self) -> Spec: def action_spec(self) -> Spec: return DiscreteArray(num_values=self.num_actions) - + def close(self) -> None: - self.env.close() \ No newline at end of file + self.env.close() diff --git a/stoix/wrappers/gymnasium.py b/stoix/wrappers/gymnasium.py index 269b63c1..ee12358c 100644 --- a/stoix/wrappers/gymnasium.py +++ b/stoix/wrappers/gymnasium.py @@ -1,18 +1,17 @@ -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Dict, Optional import gymnasium -import jax import numpy as np +from jumanji.specs import Array, DiscreteArray, Spec from jumanji.types import StepType, TimeStep from numpy.typing import NDArray -from jumanji.specs import Array, Spec, DiscreteArray from stoix.base_types import Observation class VecGymToJumanji: """Converts from a Vectorised Gymnasium environment to Jumanji's API.""" - + def __init__(self, env: gymnasium.vector.AsyncVectorEnv): self.env = env self.num_envs = int(self.env.num_envs) @@ -30,7 +29,7 @@ def __init__(self, env: gymnasium.vector.AsyncVectorEnv): self.running_count_episode_length = np.zeros(self.num_envs, dtype=int) self.episode_return = np.zeros(self.num_envs, dtype=float) self.episode_length = np.zeros(self.num_envs, dtype=int) - + def reset( self, *, seed: Optional[list[int]] = None, options: Optional[list[dict]] = None ) -> TimeStep: @@ -39,20 +38,20 @@ def reset( ep_done = np.zeros(self.num_envs, dtype=float) rewards = np.zeros(self.num_envs, dtype=float) terminated = np.zeros(self.num_envs, dtype=float) - + # Reset the metrics self.running_count_episode_return = np.zeros(self.num_envs, dtype=float) self.running_count_episode_length = np.zeros(self.num_envs, dtype=int) self.episode_return = np.zeros(self.num_envs, dtype=float) self.episode_length = np.zeros(self.num_envs, dtype=int) - + # Create the metrics dict metrics = { "episode_return": np.zeros(self.num_envs, dtype=float), "episode_length": np.zeros(self.num_envs, dtype=int), "is_terminal_step": np.zeros(self.num_envs, dtype=bool), } - + info["metrics"] = metrics timestep = self._create_timestep(obs, ep_done, terminated, rewards, info) @@ -67,7 +66,7 @@ def step(self, action: list) -> TimeStep: truncated = 
np.asarray(truncated) ep_done = np.logical_or(terminated, truncated) not_done = 1 - ep_done - + # Counting episode return and length. new_episode_return = self.running_count_episode_return + rewards new_episode_length = self.running_count_episode_length + 1 @@ -84,10 +83,10 @@ def step(self, action: list) -> TimeStep: info["metrics"] = metrics # Update the metrics - self.running_count_episode_return=new_episode_return * not_done - self.running_count_episode_length=new_episode_length * not_done - self.episode_return=episode_return_info - self.episode_length=episode_length_info + self.running_count_episode_return = new_episode_return * not_done + self.running_count_episode_length = new_episode_length * not_done + self.episode_return = episode_return_info + self.episode_length = episode_length_info timestep = self._create_timestep(obs, ep_done, terminated, rewards, info) @@ -111,7 +110,7 @@ def _create_timestep( observation=obs, extras=extras, ) - + def observation_spec(self) -> Spec: agent_view_spec = Array(shape=self.obs_shape, dtype=float) return Spec( @@ -129,4 +128,4 @@ def action_spec(self) -> Spec: return Array(shape=(self.num_actions,), dtype=float) def close(self): - self.env.close() \ No newline at end of file + self.env.close()
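
Usage sketch (not part of the diff above): how a Sebulba system is expected to consume the factory path that stoix/utils/make_env.py now exposes, mirroring the `env_factory = environments.make(config)` call and `isinstance(env_factory, EnvFactory)` assertion added in ff_ppo.py. The config layout below is an assumption that only covers the fields the new code reads (env.env_name, env.scenario.name, arch.seed, env.kwargs); CartPole-v1 is an illustrative task, not mandated by the patch.

    # Illustrative sketch, assuming the config fields read by make() and the
    # factory classes added/used in this patch; not a verbatim Stoix config.
    import numpy as np
    from omegaconf import OmegaConf

    from stoix.utils import make_env as environments
    from stoix.utils.env_factory import EnvFactory

    config = OmegaConf.create(
        {
            "env": {
                "env_name": "gymnasium",            # routes make() to make_gymnasium_factory
                "scenario": {"name": "CartPole-v1"},  # task id passed to the factory
                "kwargs": {},
            },
            "arch": {"seed": 42},
        }
    )

    # For "gymnasium"/"envpool" suites, make() now returns an EnvFactory instead
    # of a (train_env, eval_env) tuple, hence the type assertion in ff_ppo.py.
    env_factory = environments.make(config)
    assert isinstance(env_factory, EnvFactory)

    # Each actor thread (and the evaluator) builds its own vectorised environments
    # from the factory; the returned wrapper follows Jumanji's TimeStep API.
    envs = env_factory(num_envs=4)
    timestep = envs.reset()
    timestep = envs.step(np.zeros(4, dtype=int))  # CartPole-v1 has discrete actions
    envs.close()

Because the factory increments its internal seed by num_envs under a lock on every call, repeated calls from different actor threads produce differently seeded, non-overlapping batches of environments.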