Commit
chore: precommit formatting
EdanToledo committed Aug 15, 2024
1 parent 07590df commit 7684678
Showing 7 changed files with 67 additions and 46 deletions.
2 changes: 1 addition & 1 deletion stoix/configs/arch/sebulba.yaml
@@ -17,7 +17,7 @@ learner:
device_ids: [1,2] # Define which devices to use for the learner.

# Size of the queue for the pipeline where actors push data and the learner pulls data.
pipeline_queue_size: 20
pipeline_queue_size: 10

# --- Evaluation ---
evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select
2 changes: 1 addition & 1 deletion stoix/configs/logger/base_logger.yaml
@@ -2,7 +2,7 @@

base_exp_path: results # Base path for logging.
use_console: True # Whether to log to stdout.
use_tb: True # Whether to use tensorboard logging.
use_tb: False # Whether to use tensorboard logging.
use_json: False # Whether to log marl-eval style to json files.
use_neptune: False # Whether to log to neptune.ai.
use_wandb: False # Whether to log to wandb.ai.
12 changes: 8 additions & 4 deletions stoix/evaluator.py
@@ -1,6 +1,5 @@
import math
import time
import warnings
from typing import Any, Dict, Optional, Tuple, Union

import chex
@@ -377,7 +376,8 @@ def get_sebulba_eval_fn(
and returns actions and optionally a state (see `EvalActFn`).
config: the system config.
np_rng: a numpy random number generator.
eval_multiplier: a scalar that will increase the number of evaluation episodes by a fixed factor.
eval_multiplier: a scalar that will increase the number of evaluation episodes
by a fixed factor.
"""
eval_episodes = config.arch.num_eval_episodes * eval_multiplier

@@ -394,9 +394,13 @@

# Warnings if num eval episodes is not divisible by num parallel envs.
if eval_episodes % n_parallel_envs != 0:
print(
f"{Fore.YELLOW}{Style.BRIGHT}Number of evaluation episodes ({eval_episodes}) is not divisible by `num_envs`. Some extra evaluations will be executed. New number of evaluation episodes = {episode_loops * n_parallel_envs}{Style.RESET_ALL}"
msg = (
f"Please note that the number of evaluation episodes ({eval_episodes}) is not "
f"evenly divisible by `num_envs`. As a result, some additional evaluations will be "
f"conducted. The adjusted number of evaluation episodes is now "
f"{episode_loops * n_parallel_envs}."
)
print(f"{Fore.YELLOW}{Style.BRIGHT}{msg}{Style.RESET_ALL}")

def eval_fn(params: FrozenDict, key: chex.PRNGKey) -> Dict:
"""Evaluates the given params on an environment and returns relevant metrics.
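As an aside, the adjustment described in the warning above just rounds the requested number of evaluation episodes up to the nearest multiple of the parallel eval environments. A minimal sketch of that arithmetic (the ceiling division is an assumption suggested by the `math` import, not shown in this excerpt):

import math

def adjusted_eval_episodes(eval_episodes: int, n_parallel_envs: int) -> int:
    # Round up to a whole number of evaluation loops, then run every env in each loop.
    episode_loops = math.ceil(eval_episodes / n_parallel_envs)
    return episode_loops * n_parallel_envs

# Example: 100 requested episodes across 16 parallel envs -> 7 loops -> 112 episodes.
assert adjusted_eval_episodes(100, 16) == 112
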
48 changes: 24 additions & 24 deletions stoix/systems/ppo/sebulba/ff_ppo.py
@@ -63,7 +63,7 @@ def get_act_fn(
def actor_fn(
params: ActorCriticParams, observation: Observation, rng_key: chex.PRNGKey
) -> Tuple[chex.Array, chex.Array]:
"""Takes in the actor params, observation and rng_key and returns the action, value and log prob."""
"""Get the action, value and log_prob from the actor and critic networks."""
rng_key, policy_key = jax.random.split(rng_key)
pi = actor_apply_fn(params.actor_params, observation)
value = critic_apply_fn(params.critic_params, observation)
@@ -186,7 +186,8 @@ def get_actor_thread(
thread_lifetime: ThreadLifetime,
name: str,
):
"""Get the actor thread that once started will collect data from the environment and send it to the pipeline."""
"""Get the actor thread that once started will collect data from the
environment and send it to the pipeline."""
rng_key = jax.device_put(rng_key, actor_device)

rollout_fn = get_rollout_fn(
@@ -417,9 +418,9 @@ def get_learner_rollout_fn(
params_sources: Sequence[ParamsSource],
):
"""Get the learner rollout function that is used by the learner thread to update the networks.
This function is what is actually run by the learner thread. It gets the data from the pipeline and
uses the learner update function to update the networks. It then sends these intermediate network parameters
to a queue for evaluation."""
This function is what is actually run by the learner thread. It gets the data from the pipeline
and uses the learner update function to update the networks. It then sends these intermediate
network parameters to a queue for evaluation."""

def learner_rollout(learner_state: LearnerState) -> None:
# Loop for the total number of evaluations selected to be performed.
@@ -435,8 +436,8 @@ def learner_rollout(learner_state: LearnerState) -> None:
with RecordTimeTo(learn_timings["rollout_get_time"]):
traj_batch, timestep, rollout_time = pipeline.get(block=True)
# We then replace the timestep in the learner state with the latest timestep
# This means the learner has access to the entire trajectory as well as an additional timestep
# which it can use to bootstrap.
# This means the learner has access to the entire trajectory as well as
# an additional timestep which it can use to bootstrap.
learner_state = learner_state._replace(timestep=timestep)
# We then call the update function to update the networks
with RecordTimeTo(learn_timings["learning_time"]):
@@ -453,8 +454,8 @@ def learner_rollout(learner_state: LearnerState) -> None:
for source in params_sources:
source.update(unreplicated_params)

# We then pass all the environment metrics, training metrics, current learner state and timings to the evaluation queue
# This is so the evaluator correctly evaluates the performance of the networks at this point in time.
# We then pass all the environment metrics, training metrics, current learner state
# and timings to the evaluation queue. This is so the evaluator correctly evaluates
# the performance of the networks at this point in time.
episode_metrics, train_metrics = jax.tree.map(lambda *x: np.asarray(x), *metrics)
rollout_times = jax.tree.map(lambda *x: np.mean(x), *rollout_times)
timing_dict = rollout_times | learn_timings
@@ -614,13 +616,9 @@ def run_experiment(_config: DictConfig) -> float:
local_learner_devices = [
local_devices[device_id] for device_id in config.arch.learner.device_ids
]
print(f"{Fore.BLUE}{Style.BRIGHT}[Sebulba] Actors devices: {actor_devices}{Style.RESET_ALL}")
print(
f"{Fore.GREEN}{Style.BRIGHT}[Sebulba] Learner devices: {local_learner_devices}{Style.RESET_ALL}"
)
print(
f"{Fore.MAGENTA}{Style.BRIGHT}[Sebulba] Global devices: {global_devices}{Style.RESET_ALL}"
)
print(f"{Fore.BLUE}{Style.BRIGHT}Actors devices: {actor_devices}{Style.RESET_ALL}")
print(f"{Fore.GREEN}{Style.BRIGHT}Learner devices: {local_learner_devices}{Style.RESET_ALL}")
print(f"{Fore.MAGENTA}{Style.BRIGHT}Global devices: {global_devices}{Style.RESET_ALL}")
# Set the number of learning and acting devices in the config
# useful for keeping track of experimental setup
config.num_learning_devices = len(local_learner_devices)
@@ -637,15 +635,16 @@ def run_experiment(_config: DictConfig) -> float:
num_envs_per_actor = int(num_envs_per_actor_device // config.arch.actor.actor_per_device)
config.arch.actor.envs_per_actor = num_envs_per_actor

# We then perform a simple check to ensure that the number of envs per actor is divisible by the number of learner devices
# This is because we shard the envs per actor across the learner devices
# This check is mainly relevant for on-policy algorithms
# We then perform a simple check to ensure that the number of envs per actor is
# divisible by the number of learner devices. This is because we shard the envs
# per actor across the learner devices. This check is mainly relevant for on-policy
# algorithms.
assert (
num_envs_per_actor % len(local_learner_devices) == 0
), "The number of envs per actor must be divisible by the number of learner devices"

# Create the environment factory.
env_factory = environments.make(config)
env_factory = environments.make_factory(config)
assert isinstance(
env_factory, EnvFactory
), "Environment factory must be an instance of EnvFactory"
@@ -706,7 +705,8 @@ def run_experiment(_config: DictConfig) -> float:
params_sources: List[ParamsSource] = []
actor_threads: List[threading.Thread] = []
for actor_device in actor_devices:
# Create 1 params source per actor device as this will be used to pass the params to the actors
# Create 1 params source per actor device as this will be used
# to pass the params to the actors
params_source = ParamsSource(initial_params, actor_device, params_sources_lifetime)
params_source.start()
params_sources.append(params_source)
@@ -797,16 +797,16 @@ def run_experiment(_config: DictConfig) -> float:
# Now we stop the actors and params sources
for actor in actor_threads:
actor.join()

# Stop the pipeline
pipeline_lifetime.stop()
pipeline.join()

# Stop the params sources
params_sources_lifetime.stop()
for param_source in params_sources:
param_source.join()

# Measure absolute metric.
if config.arch.absolute_metric:
abs_metric_evaluator, abs_metric_evaluator_envs = get_sebulba_eval_fn(
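The `ParamsSource` objects created per actor device above implement a simple thread-safe "latest params" handoff: the learner publishes fresh parameters with `update(...)`, and each actor thread reads the newest copy before its next rollout. A stripped-down sketch of that pattern (illustrative only; the lock-based implementation and the `get` method name are assumptions, not the repository's actual class):

import threading
from typing import Any


class LatestParamsSource:
    """Hold the most recently published params for actor threads to read."""

    def __init__(self, initial_params: Any) -> None:
        self._lock = threading.Lock()
        self._params = initial_params

    def update(self, new_params: Any) -> None:
        # Called by the learner thread after each update step.
        with self._lock:
            self._params = new_params

    def get(self) -> Any:
        # Called by an actor thread before collecting its next rollout.
        with self._lock:
            return self._params
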
30 changes: 22 additions & 8 deletions stoix/utils/make_env.py
@@ -25,7 +25,7 @@
from xminigrid.registration import _REGISTRY as XMINIGRID_REGISTRY

from stoix.utils.debug_env import IdentityGame, SequenceGame
from stoix.utils.env_factory import EnvFactory, EnvPoolFactory, GymnasiumFactory
from stoix.utils.env_factory import EnvPoolFactory, GymnasiumFactory
from stoix.wrappers import GymnaxWrapper, JumanjiWrapper, RecordEpisodeMetrics
from stoix.wrappers.brax import BraxJumanjiWrapper
from stoix.wrappers.jaxmarl import JaxMarlWrapper, MabraxWrapper, SmaxWrapper
@@ -386,7 +386,7 @@ def make_envpool_factory(env_name: str, config: DictConfig) -> EnvPoolFactory:
return env_factory


def make(config: DictConfig) -> Union[EnvFactory, Tuple[Environment, Environment]]:
def make(config: DictConfig) -> Tuple[Environment, Environment]:
"""
Create environments for training and evaluation.
@@ -397,13 +397,8 @@ def make(config: DictConfig) -> Union[EnvFactory, Tuple[Environment, Environment
training and evaluation environments or a factory to create them.
"""
env_name = config.env.scenario.name
suite_name = config.env.env_name

if "envpool" in suite_name:
return make_envpool_factory(env_name, config)
elif "gymnasium" in suite_name:
return make_gymnasium_factory(env_name, config)
elif env_name in gymnax_environments:
if env_name in gymnax_environments:
envs = make_gymnax_env(env_name, config)
elif env_name in JUMANJI_REGISTRY:
envs = make_jumanji_env(env_name, config)
@@ -429,3 +424,22 @@ def make(config: DictConfig) -> Union[EnvFactory, Tuple[Environment, Environment
envs = apply_optional_wrappers(envs, config)

return envs


def make_factory(config: DictConfig) -> Union[GymnasiumFactory, EnvPoolFactory]:
"""
Create an env_factory for sebulba systems.
Args:
config (Dict): The configuration of the environment.
Returns:
A factory to create environments.
"""
env_name = config.env.scenario.name
suite_name = config.env.env_name

if "envpool" in suite_name:
return make_envpool_factory(env_name, config)
elif "gymnasium" in suite_name:
return make_gymnasium_factory(env_name, config)
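With this split, Anakin-style systems keep calling `environments.make(config)` for a (train, eval) environment pair, while Sebulba systems call the new `environments.make_factory(config)` and get back an environment factory, as `ff_ppo.py` above now does. A hypothetical usage sketch (the import alias and the factory call signature are assumptions for illustration):

from stoix.utils import make_env as environments  # assumed alias, as used in ff_ppo.py

# JAX-native (Anakin) systems: paired training and evaluation environments.
train_env, eval_env = environments.make(config)

# Sebulba systems: a factory each actor thread can use to build its own envs.
env_factory = environments.make_factory(config)  # only "envpool" or "gymnasium" suites
envs = env_factory(config.arch.actor.envs_per_actor)  # assumed callable interface
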
17 changes: 10 additions & 7 deletions stoix/utils/sebulba_utils.py
@@ -76,13 +76,16 @@ def put(self, traj: Sequence[StoixTransition], timestep: TimeStep, timings_dict:
# [(num_envs / num_learner_devices, ...)] * num_learner_devices
sharded_timestep = jax.tree.map(self.shard_split_playload, timestep)

# We block on the put to ensure that actors wait for the learners to catch up. This does two things:
# 1. It ensures that the actors don't get too far ahead of the learners, which could lead to off-policy data.
# 2. It ensures that the actors don't in a sense "waste" samples and their time by generating samples that
# the learners can't consume.
# However, we put a timeout of 90 seconds to avoid deadlocks in case the learner is not consuming the data.
# This is a safety measure and should not be hit in normal operation.
# We use a try-finally since the lock has to be released even if an exception is raised.
# We block on the put to ensure that actors wait for the learners to catch up. This does two
# things:
# 1. It ensures that the actors don't get too far ahead of the learners, which could lead to
# off-policy data.
# 2. It ensures that the actors don't in a sense "waste" samples and their time by
# generating samples that the learners can't consume.
# However, we put a timeout of 90 seconds to avoid deadlocks in case the learner
# is not consuming the data. This is a safety measure and should not be hit in normal
# operation. We use a try-finally since the lock has to be released even if an exception
# is raised.
try:
self._queue.put((sharded_traj, sharded_timestep, timings_dict), block=True, timeout=90)
finally:
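The block comment above describes a standard back-pressure pattern: a bounded queue whose blocking put keeps actors in step with the learner, a timeout so a stalled learner raises instead of deadlocking, and a try/finally so the lock is always released. A stripped-down sketch of the pattern (not the repository's Pipeline class; raising queue.Full on timeout is an assumption):

import queue
import threading
from typing import Any


class BoundedPipeline:
    """Bounded producer/consumer channel with back-pressure and a deadlock timeout."""

    def __init__(self, max_size: int) -> None:
        self._queue: "queue.Queue[Any]" = queue.Queue(maxsize=max_size)
        self._lock = threading.Lock()

    def put(self, item: Any) -> None:
        self._lock.acquire()
        try:
            # Blocks until the consumer frees a slot, so producers cannot run far ahead.
            # The timeout is a safety valve: raise queue.Full rather than deadlock forever.
            self._queue.put(item, block=True, timeout=90)
        finally:
            self._lock.release()

    def get(self, block: bool = True) -> Any:
        return self._queue.get(block=block)
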
2 changes: 1 addition & 1 deletion stoix/wrappers/gymnasium.py
@@ -127,5 +127,5 @@ def action_spec(self) -> Spec:
else:
return Array(shape=(self.num_actions,), dtype=float)

def close(self):
def close(self) -> None:
self.env.close()
