diff --git a/stoix/base_types.py b/stoix/base_types.py index af26199..83691d9 100644 --- a/stoix/base_types.py +++ b/stoix/base_types.py @@ -203,6 +203,3 @@ class EvaluationOutput(NamedTuple, Generic[StoixState]): [FrozenDict, HiddenState, RNNObservation, chex.PRNGKey], Tuple[HiddenState, chex.Array] ] RecCriticApply = Callable[[FrozenDict, HiddenState, RNNObservation], Tuple[HiddenState, Value]] - - -EnvFactory = Callable[[int], Any] diff --git a/stoix/evaluator.py b/stoix/evaluator.py index 27c6103..ecebd77 100644 --- a/stoix/evaluator.py +++ b/stoix/evaluator.py @@ -15,7 +15,6 @@ from stoix.base_types import ( ActFn, ActorApply, - EnvFactory, EvalFn, EvalState, EvaluationOutput, @@ -25,6 +24,7 @@ RNNObservation, SebulbaEvalFn, ) +from stoix.utils.env_factory import EnvFactory from stoix.utils.jax_utils import unreplicate_batch_dim diff --git a/stoix/utils/make_env.py b/stoix/utils/make_env.py index 37812f7..1e155f9 100644 --- a/stoix/utils/make_env.py +++ b/stoix/utils/make_env.py @@ -1,5 +1,5 @@ import copy -from typing import Tuple, Union +from typing import Tuple import gymnax import hydra @@ -25,9 +25,10 @@ from xminigrid.registration import _REGISTRY as XMINIGRID_REGISTRY from stoix.utils.debug_env import IdentityGame, SequenceGame -from stoix.utils.env_factory import EnvPoolFactory, GymnasiumFactory +from stoix.utils.env_factory import EnvFactory, EnvPoolFactory, GymnasiumFactory from stoix.wrappers import GymnaxWrapper, JumanjiWrapper, RecordEpisodeMetrics from stoix.wrappers.brax import BraxJumanjiWrapper +from stoix.wrappers.jax_to_factory import JaxEnvFactory from stoix.wrappers.jaxmarl import JaxMarlWrapper, MabraxWrapper, SmaxWrapper from stoix.wrappers.navix import NavixWrapper from stoix.wrappers.pgx import PGXWrapper @@ -426,7 +427,7 @@ def make(config: DictConfig) -> Tuple[Environment, Environment]: return envs -def make_factory(config: DictConfig) -> Union[GymnasiumFactory, EnvPoolFactory]: +def make_factory(config: DictConfig) -> EnvFactory: """ Create a env_factory for sebulba systems. @@ -444,4 +445,4 @@ def make_factory(config: DictConfig) -> Union[GymnasiumFactory, EnvPoolFactory]: elif "gymnasium" in suite_name: return make_gymnasium_factory(env_name, config) else: - raise ValueError(f"{suite_name} is not a supported suite.") + return JaxEnvFactory(make(config)[0], init_seed=config.arch.seed) diff --git a/stoix/wrappers/jax_to_factory.py b/stoix/wrappers/jax_to_factory.py new file mode 100644 index 0000000..ce474a4 --- /dev/null +++ b/stoix/wrappers/jax_to_factory.py @@ -0,0 +1,129 @@ +import threading +from typing import Optional + +import jax +import numpy as np +from jumanji.env import Environment +from jumanji.specs import Spec +from jumanji.types import TimeStep + +from stoix.utils.env_factory import EnvFactory + + +class JaxToStateful: + """Converts a Stoix-ready JAX environment to a stateful one to be used by Sebulba systems.""" + + def __init__(self, env: Environment, num_envs: int, device: jax.Device, init_seed: int): + self.env = env + self.num_envs = num_envs + self.device = device + + # Create the metrics + self.running_count_episode_return = np.zeros(self.num_envs, dtype=float) + self.running_count_episode_length = np.zeros(self.num_envs, dtype=int) + self.episode_return = np.zeros(self.num_envs, dtype=float) + self.episode_length = np.zeros(self.num_envs, dtype=int) + + # Create the seeds + max_int = np.iinfo(np.int32).max + min_int = np.iinfo(np.int32).min + init_seeds = jax.random.randint( + jax.random.PRNGKey(init_seed), (num_envs,), min_int, max_int + ) + self.rng_keys = jax.vmap(jax.random.PRNGKey)(init_seeds) + + # Vmap and compile the reset and step functions + self.vmapped_reset = jax.jit(jax.vmap(self.env.reset), device=self.device) + self.vmapped_step = jax.jit(jax.vmap(self.env.step, in_axes=(0, 0)), device=self.device) + + def reset( + self, *, seed: Optional[list[int]] = None, options: Optional[list[dict]] = None + ) -> TimeStep: + with jax.default_device(self.device): + + self.state, timestep = self.vmapped_reset(self.rng_keys) + + # Reset the metrics + self.running_count_episode_return = np.zeros(self.num_envs, dtype=float) + self.running_count_episode_length = np.zeros(self.num_envs, dtype=int) + self.episode_return = np.zeros(self.num_envs, dtype=float) + self.episode_length = np.zeros(self.num_envs, dtype=int) + + # Create the metrics dict + metrics = { + "episode_return": np.zeros(self.num_envs, dtype=float), + "episode_length": np.zeros(self.num_envs, dtype=int), + "is_terminal_step": np.zeros(self.num_envs, dtype=bool), + } + + timestep_extras = timestep.extras + + timestep_extras["metrics"] = metrics + + timestep = timestep.replace(extras=timestep_extras) + + return timestep + + def step(self, action: list) -> TimeStep: + with jax.default_device(self.device): + self.state, timestep = self.vmapped_step(self.state, action) + + ep_done = timestep.last() + not_done = ~ep_done + + # Counting episode return and length. + new_episode_return = self.running_count_episode_return + timestep.reward + new_episode_length = self.running_count_episode_length + 1 + + # Update the episode return and length if the episode is done otherwise + # keep the previous values + episode_return_info = self.episode_return * not_done + new_episode_return * ep_done + episode_length_info = self.episode_length * not_done + new_episode_length * ep_done + # Update the running count + self.running_count_episode_return = new_episode_return * not_done + self.running_count_episode_length = new_episode_length * not_done + + self.episode_return = episode_return_info + self.episode_length = episode_length_info + + # Create the metrics dict + metrics = { + "episode_return": episode_return_info, + "episode_length": episode_length_info, + "is_terminal_step": ep_done, + } + + timestep_extras = timestep.extras + timestep_extras["metrics"] = metrics + timestep = timestep.replace(extras=timestep_extras) + + return timestep + + def observation_spec(self) -> Spec: + return self.env.observation_spec() + + def action_spec(self) -> Spec: + return self.env.action_spec() + + def close(self) -> None: + pass + + +class JaxEnvFactory(EnvFactory): + """ + Create environments using stoix-ready JAX environments + """ + + def __init__(self, jax_env: Environment, init_seed: int): + self.jax_env = jax_env + self.cpu = jax.devices("cpu")[0] + self.seed = init_seed + # a lock is needed because this object will be used from different threads. + # We want to make sure all seeds are unique + self.lock = threading.Lock() + + def __call__(self, num_envs: int) -> JaxToStateful: + with self.lock: + seed = self.seed + self.seed += num_envs + return JaxToStateful(self.jax_env, num_envs, self.cpu, seed)