From 3c379e9e853c84a7df1483983bf03a52c06fbf71 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Thu, 29 Aug 2024 21:53:11 +0100 Subject: [PATCH 1/3] chore: quick rough untested wrapper and env factory --- stoix/base_types.py | 3 - stoix/evaluator.py | 2 +- stoix/utils/make_env.py | 22 +++++-- stoix/wrappers/jax_to_factory.py | 102 +++++++++++++++++++++++++++++++ 4 files changed, 121 insertions(+), 8 deletions(-) create mode 100644 stoix/wrappers/jax_to_factory.py diff --git a/stoix/base_types.py b/stoix/base_types.py index af26199..83691d9 100644 --- a/stoix/base_types.py +++ b/stoix/base_types.py @@ -203,6 +203,3 @@ class EvaluationOutput(NamedTuple, Generic[StoixState]): [FrozenDict, HiddenState, RNNObservation, chex.PRNGKey], Tuple[HiddenState, chex.Array] ] RecCriticApply = Callable[[FrozenDict, HiddenState, RNNObservation], Tuple[HiddenState, Value]] - - -EnvFactory = Callable[[int], Any] diff --git a/stoix/evaluator.py b/stoix/evaluator.py index 27c6103..ecebd77 100644 --- a/stoix/evaluator.py +++ b/stoix/evaluator.py @@ -15,7 +15,6 @@ from stoix.base_types import ( ActFn, ActorApply, - EnvFactory, EvalFn, EvalState, EvaluationOutput, @@ -25,6 +24,7 @@ RNNObservation, SebulbaEvalFn, ) +from stoix.utils.env_factory import EnvFactory from stoix.utils.jax_utils import unreplicate_batch_dim diff --git a/stoix/utils/make_env.py b/stoix/utils/make_env.py index 65821a8..2348069 100644 --- a/stoix/utils/make_env.py +++ b/stoix/utils/make_env.py @@ -1,8 +1,9 @@ import copy -from typing import Tuple, Union +from typing import Tuple import gymnax import hydra +import jax import jax.numpy as jnp import jaxmarl import jumanji @@ -25,9 +26,10 @@ from xminigrid.registration import _REGISTRY as XMINIGRID_REGISTRY from stoix.utils.debug_env import IdentityGame, SequenceGame -from stoix.utils.env_factory import EnvPoolFactory, GymnasiumFactory +from stoix.utils.env_factory import EnvFactory, EnvPoolFactory, GymnasiumFactory from stoix.wrappers import GymnaxWrapper, JumanjiWrapper, RecordEpisodeMetrics from stoix.wrappers.brax import BraxJumanjiWrapper +from stoix.wrappers.jax_to_factory import JaxToStateful from stoix.wrappers.jaxmarl import JaxMarlWrapper, MabraxWrapper, SmaxWrapper from stoix.wrappers.navix import NavixWrapper from stoix.wrappers.pgx import PGXWrapper @@ -426,7 +428,19 @@ def make(config: DictConfig) -> Tuple[Environment, Environment]: return envs -def make_factory(config: DictConfig) -> Union[GymnasiumFactory, EnvPoolFactory]: +class JaxEnvFactory(EnvFactory): + """ + Create environments using stoix-ready JAX environments + """ + + def __call__(self, num_envs: int) -> JaxToStateful: + cpu = jax.devices("cpu")[0] + with self.lock: + train_env, _ = make(self.task_id) + return JaxToStateful(train_env, num_envs, cpu, self.seed, **self.kwargs) + + +def make_factory(config: DictConfig) -> EnvFactory: """ Create a env_factory for sebulba systems. @@ -444,4 +458,4 @@ def make_factory(config: DictConfig) -> Union[GymnasiumFactory, EnvPoolFactory]: elif "gymnasium" in suite_name: return make_gymnasium_factory(env_name, config) else: - raise ValueError(f"{suite_name} is not a supported suite.") + return JaxEnvFactory(config, init_seed=config.arch.seed, **config.env.kwargs) diff --git a/stoix/wrappers/jax_to_factory.py b/stoix/wrappers/jax_to_factory.py new file mode 100644 index 0000000..6198d9f --- /dev/null +++ b/stoix/wrappers/jax_to_factory.py @@ -0,0 +1,102 @@ +from typing import Optional + +import jax +import numpy as np +from jumanji.env import Environment +from jumanji.specs import Spec +from jumanji.types import TimeStep + + +class JaxToStateful: + """Converts a Stoix-ready JAX environment to a stateful one to be used by Sebulba systems.""" + + def __init__(self, env: Environment, num_envs: int, device: jax.Device, init_seed: int): + self.env = env + self.num_envs = num_envs + self.device = device + + # Create the metrics + self.running_count_episode_return = np.zeros(self.num_envs, dtype=float) + self.running_count_episode_length = np.zeros(self.num_envs, dtype=int) + self.episode_return = np.zeros(self.num_envs, dtype=float) + self.episode_length = np.zeros(self.num_envs, dtype=int) + + # Create the seeds + init_seeds = jax.random.randint(jax.random.PRNGKey(init_seed), (num_envs,), 0, 2**32) + self.rng_keys = jax.vmap(jax.random.PRNGKey)(init_seeds) + + # Vmap and compile the reset and step functions + self.vmapped_reset = jax.jit(jax.vmap(self.env.reset), device=self.device) + self.vmapped_step = jax.jit(jax.vmap(self.env.step, in_axes=(0, 0)), device=self.device) + + def reset( + self, *, seed: Optional[list[int]] = None, options: Optional[list[dict]] = None + ) -> TimeStep: + with jax.default_device(self.device): + + self.state, timestep = self.vmapped_reset(self.rng_keys) + + # Reset the metrics + self.running_count_episode_return = np.zeros(self.num_envs, dtype=float) + self.running_count_episode_length = np.zeros(self.num_envs, dtype=int) + self.episode_return = np.zeros(self.num_envs, dtype=float) + self.episode_length = np.zeros(self.num_envs, dtype=int) + + # Create the metrics dict + metrics = { + "episode_return": np.zeros(self.num_envs, dtype=float), + "episode_length": np.zeros(self.num_envs, dtype=int), + "is_terminal_step": np.zeros(self.num_envs, dtype=bool), + } + + timestep_extras = timestep.extras + + timestep_extras["metrics"] = metrics + + timestep = timestep.replace(extras=timestep_extras) + + return timestep + + def step(self, action: list) -> TimeStep: + with jax.default_device(self.device): + self.state, timestep = self.vmapped_step(self.state, action) + + ep_done = timestep.last() + not_done = ~ep_done + + # Counting episode return and length. + new_episode_return = self.running_count_episode_return + timestep.reward + new_episode_length = self.running_count_episode_length + 1 + + # Update the episode return and length if the episode is done otherwise + # keep the previous values + episode_return_info = self.episode_return * not_done + new_episode_return * ep_done + episode_length_info = self.episode_length * not_done + new_episode_length * ep_done + # Update the running count + self.running_count_episode_return = new_episode_return * not_done + self.running_count_episode_length = new_episode_length * not_done + + self.episode_return = episode_return_info + self.episode_length = episode_length_info + + # Create the metrics dict + metrics = { + "episode_return": episode_return_info, + "episode_length": episode_length_info, + "is_terminal_step": ep_done, + } + + timestep_extras = timestep.extras + timestep_extras["metrics"] = metrics + timestep = timestep.replace(extras=timestep_extras) + + return timestep + + def observation_spec(self) -> Spec: + return self.env.observation_spec() + + def action_spec(self) -> Spec: + return self.env.action_spec() + + def close(self) -> None: + self.env.close() From 7da6be498c02775811108a86394b608fa7e8f13b Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Fri, 30 Aug 2024 13:38:16 +0100 Subject: [PATCH 2/3] chore: slight edit to structure --- stoix/utils/make_env.py | 19 +++---------------- stoix/wrappers/jax_to_factory.py | 32 +++++++++++++++++++++++++++++--- 2 files changed, 32 insertions(+), 19 deletions(-) diff --git a/stoix/utils/make_env.py b/stoix/utils/make_env.py index 2348069..a391dd0 100644 --- a/stoix/utils/make_env.py +++ b/stoix/utils/make_env.py @@ -26,10 +26,10 @@ from xminigrid.registration import _REGISTRY as XMINIGRID_REGISTRY from stoix.utils.debug_env import IdentityGame, SequenceGame -from stoix.utils.env_factory import EnvFactory, EnvPoolFactory, GymnasiumFactory +from stoix.utils.env_factory import EnvFactory, GymnasiumFactory, EnvPoolFactory from stoix.wrappers import GymnaxWrapper, JumanjiWrapper, RecordEpisodeMetrics from stoix.wrappers.brax import BraxJumanjiWrapper -from stoix.wrappers.jax_to_factory import JaxToStateful +from stoix.wrappers.jax_to_factory import JaxEnvFactory, JaxToStateful from stoix.wrappers.jaxmarl import JaxMarlWrapper, MabraxWrapper, SmaxWrapper from stoix.wrappers.navix import NavixWrapper from stoix.wrappers.pgx import PGXWrapper @@ -427,19 +427,6 @@ def make(config: DictConfig) -> Tuple[Environment, Environment]: return envs - -class JaxEnvFactory(EnvFactory): - """ - Create environments using stoix-ready JAX environments - """ - - def __call__(self, num_envs: int) -> JaxToStateful: - cpu = jax.devices("cpu")[0] - with self.lock: - train_env, _ = make(self.task_id) - return JaxToStateful(train_env, num_envs, cpu, self.seed, **self.kwargs) - - def make_factory(config: DictConfig) -> EnvFactory: """ Create a env_factory for sebulba systems. @@ -458,4 +445,4 @@ def make_factory(config: DictConfig) -> EnvFactory: elif "gymnasium" in suite_name: return make_gymnasium_factory(env_name, config) else: - return JaxEnvFactory(config, init_seed=config.arch.seed, **config.env.kwargs) + return JaxEnvFactory(make(config)[0], init_seed=config.arch.seed) diff --git a/stoix/wrappers/jax_to_factory.py b/stoix/wrappers/jax_to_factory.py index 6198d9f..18485d7 100644 --- a/stoix/wrappers/jax_to_factory.py +++ b/stoix/wrappers/jax_to_factory.py @@ -1,11 +1,15 @@ -from typing import Optional +import threading +from typing import Any, Optional import jax +import jax.numpy as jnp import numpy as np from jumanji.env import Environment from jumanji.specs import Spec from jumanji.types import TimeStep +from stoix.utils.env_factory import EnvFactory + class JaxToStateful: """Converts a Stoix-ready JAX environment to a stateful one to be used by Sebulba systems.""" @@ -22,7 +26,9 @@ def __init__(self, env: Environment, num_envs: int, device: jax.Device, init_see self.episode_length = np.zeros(self.num_envs, dtype=int) # Create the seeds - init_seeds = jax.random.randint(jax.random.PRNGKey(init_seed), (num_envs,), 0, 2**32) + max_int = np.iinfo(np.int32).max + min_int = np.iinfo(np.int32).min + init_seeds = jax.random.randint(jax.random.PRNGKey(init_seed), (num_envs,), min_int, max_int) self.rng_keys = jax.vmap(jax.random.PRNGKey)(init_seeds) # Vmap and compile the reset and step functions @@ -99,4 +105,24 @@ def action_spec(self) -> Spec: return self.env.action_spec() def close(self) -> None: - self.env.close() + pass + + +class JaxEnvFactory(EnvFactory): + """ + Create environments using stoix-ready JAX environments + """ + + def __init__(self, jax_env : Environment, init_seed: int): + self.jax_env = jax_env + self.cpu = jax.devices("cpu")[0] + self.seed = init_seed + # a lock is needed because this object will be used from different threads. + # We want to make sure all seeds are unique + self.lock = threading.Lock() + + def __call__(self, num_envs: int) -> JaxToStateful: + with self.lock: + seed = self.seed + self.seed += num_envs + return JaxToStateful(self.jax_env, num_envs, self.cpu, seed) \ No newline at end of file From c34e82efab119ef602d1a84b920b1779b975e1b5 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Fri, 20 Sep 2024 16:49:03 +0200 Subject: [PATCH 3/3] chore: merge main --- stoix/utils/make_env.py | 6 +++--- stoix/wrappers/jax_to_factory.py | 13 +++++++------ 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/stoix/utils/make_env.py b/stoix/utils/make_env.py index 07b1bf3..1e155f9 100644 --- a/stoix/utils/make_env.py +++ b/stoix/utils/make_env.py @@ -3,7 +3,6 @@ import gymnax import hydra -import jax import jax.numpy as jnp import jaxmarl import jumanji @@ -26,10 +25,10 @@ from xminigrid.registration import _REGISTRY as XMINIGRID_REGISTRY from stoix.utils.debug_env import IdentityGame, SequenceGame -from stoix.utils.env_factory import EnvFactory, GymnasiumFactory, EnvPoolFactory +from stoix.utils.env_factory import EnvFactory, EnvPoolFactory, GymnasiumFactory from stoix.wrappers import GymnaxWrapper, JumanjiWrapper, RecordEpisodeMetrics from stoix.wrappers.brax import BraxJumanjiWrapper -from stoix.wrappers.jax_to_factory import JaxEnvFactory, JaxToStateful +from stoix.wrappers.jax_to_factory import JaxEnvFactory from stoix.wrappers.jaxmarl import JaxMarlWrapper, MabraxWrapper, SmaxWrapper from stoix.wrappers.navix import NavixWrapper from stoix.wrappers.pgx import PGXWrapper @@ -427,6 +426,7 @@ def make(config: DictConfig) -> Tuple[Environment, Environment]: return envs + def make_factory(config: DictConfig) -> EnvFactory: """ Create a env_factory for sebulba systems. diff --git a/stoix/wrappers/jax_to_factory.py b/stoix/wrappers/jax_to_factory.py index 18485d7..ce474a4 100644 --- a/stoix/wrappers/jax_to_factory.py +++ b/stoix/wrappers/jax_to_factory.py @@ -1,8 +1,7 @@ import threading -from typing import Any, Optional +from typing import Optional import jax -import jax.numpy as jnp import numpy as np from jumanji.env import Environment from jumanji.specs import Spec @@ -28,7 +27,9 @@ def __init__(self, env: Environment, num_envs: int, device: jax.Device, init_see # Create the seeds max_int = np.iinfo(np.int32).max min_int = np.iinfo(np.int32).min - init_seeds = jax.random.randint(jax.random.PRNGKey(init_seed), (num_envs,), min_int, max_int) + init_seeds = jax.random.randint( + jax.random.PRNGKey(init_seed), (num_envs,), min_int, max_int + ) self.rng_keys = jax.vmap(jax.random.PRNGKey)(init_seeds) # Vmap and compile the reset and step functions @@ -112,8 +113,8 @@ class JaxEnvFactory(EnvFactory): """ Create environments using stoix-ready JAX environments """ - - def __init__(self, jax_env : Environment, init_seed: int): + + def __init__(self, jax_env: Environment, init_seed: int): self.jax_env = jax_env self.cpu = jax.devices("cpu")[0] self.seed = init_seed @@ -125,4 +126,4 @@ def __call__(self, num_envs: int) -> JaxToStateful: with self.lock: seed = self.seed self.seed += num_envs - return JaxToStateful(self.jax_env, num_envs, self.cpu, seed) \ No newline at end of file + return JaxToStateful(self.jax_env, num_envs, self.cpu, seed)