Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/add jax env factory #118

Merged
merged 4 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions stoix/base_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,3 @@ class EvaluationOutput(NamedTuple, Generic[StoixState]):
[FrozenDict, HiddenState, RNNObservation, chex.PRNGKey], Tuple[HiddenState, chex.Array]
]
RecCriticApply = Callable[[FrozenDict, HiddenState, RNNObservation], Tuple[HiddenState, Value]]


EnvFactory = Callable[[int], Any]
2 changes: 1 addition & 1 deletion stoix/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from stoix.base_types import (
ActFn,
ActorApply,
EnvFactory,
EvalFn,
EvalState,
EvaluationOutput,
Expand All @@ -25,6 +24,7 @@
RNNObservation,
SebulbaEvalFn,
)
from stoix.utils.env_factory import EnvFactory
from stoix.utils.jax_utils import unreplicate_batch_dim


Expand Down
9 changes: 5 additions & 4 deletions stoix/utils/make_env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from typing import Tuple, Union
from typing import Tuple

import gymnax
import hydra
Expand All @@ -25,9 +25,10 @@
from xminigrid.registration import _REGISTRY as XMINIGRID_REGISTRY

from stoix.utils.debug_env import IdentityGame, SequenceGame
from stoix.utils.env_factory import EnvPoolFactory, GymnasiumFactory
from stoix.utils.env_factory import EnvFactory, EnvPoolFactory, GymnasiumFactory
from stoix.wrappers import GymnaxWrapper, JumanjiWrapper, RecordEpisodeMetrics
from stoix.wrappers.brax import BraxJumanjiWrapper
from stoix.wrappers.jax_to_factory import JaxEnvFactory
from stoix.wrappers.jaxmarl import JaxMarlWrapper, MabraxWrapper, SmaxWrapper
from stoix.wrappers.navix import NavixWrapper
from stoix.wrappers.pgx import PGXWrapper
Expand Down Expand Up @@ -426,7 +427,7 @@ def make(config: DictConfig) -> Tuple[Environment, Environment]:
return envs


def make_factory(config: DictConfig) -> Union[GymnasiumFactory, EnvPoolFactory]:
def make_factory(config: DictConfig) -> EnvFactory:
"""
Create an env_factory for sebulba systems.

Expand All @@ -444,4 +445,4 @@ def make_factory(config: DictConfig) -> Union[GymnasiumFactory, EnvPoolFactory]:
elif "gymnasium" in suite_name:
return make_gymnasium_factory(env_name, config)
else:
raise ValueError(f"{suite_name} is not a supported suite.")
return JaxEnvFactory(make(config)[0], init_seed=config.arch.seed)
129 changes: 129 additions & 0 deletions stoix/wrappers/jax_to_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import threading
from typing import Optional

import jax
import numpy as np
from jumanji.env import Environment
from jumanji.specs import Spec
from jumanji.types import TimeStep

from stoix.utils.env_factory import EnvFactory


class JaxToStateful:
    """Stateful adapter around a Stoix-ready JAX environment for Sebulba systems.

    The wrapped environment's ``reset`` and ``step`` are vmapped over
    ``num_envs`` and jitted for ``device``. The batched environment state is
    kept on the instance (``self.state``) between calls, and per-environment
    episode return/length statistics are tracked and exposed through
    ``timestep.extras["metrics"]``.
    """

    def __init__(self, env: Environment, num_envs: int, device: jax.Device, init_seed: int):
        """Set up metric buffers, per-environment PRNG keys and compiled functions.

        Args:
            env: The JAX environment to wrap.
            num_envs: Number of environment copies handled as one batch.
            device: Device passed to ``jax.jit`` and used as default device.
            init_seed: Seed from which one PRNG key per environment is derived.
        """
        self.env = env
        self.num_envs = num_envs
        self.device = device

        self._clear_episode_statistics()

        # Draw one int32 seed per environment from `init_seed`, then turn each
        # seed into its own PRNG key.
        int32_info = np.iinfo(np.int32)
        per_env_seeds = jax.random.randint(
            jax.random.PRNGKey(init_seed),
            (num_envs,),
            int32_info.min,
            int32_info.max,
        )
        self.rng_keys = jax.vmap(jax.random.PRNGKey)(per_env_seeds)

        # Batch the environment functions over the leading axis and compile them.
        batched_reset = jax.vmap(self.env.reset)
        batched_step = jax.vmap(self.env.step, in_axes=(0, 0))
        self.vmapped_reset = jax.jit(batched_reset, device=self.device)
        self.vmapped_step = jax.jit(batched_step, device=self.device)

    def _clear_episode_statistics(self) -> None:
        """Zero the running counters and the last-completed-episode statistics."""
        self.running_count_episode_return = np.zeros(self.num_envs, dtype=float)
        self.running_count_episode_length = np.zeros(self.num_envs, dtype=int)
        self.episode_return = np.zeros(self.num_envs, dtype=float)
        self.episode_length = np.zeros(self.num_envs, dtype=int)

    @staticmethod
    def _with_metrics(timestep: TimeStep, metrics: dict) -> TimeStep:
        """Store ``metrics`` under ``extras["metrics"]`` and return the updated timestep."""
        extras = timestep.extras
        extras["metrics"] = metrics
        return timestep.replace(extras=extras)

    def reset(
        self, *, seed: Optional[list[int]] = None, options: Optional[list[dict]] = None
    ) -> TimeStep:
        """Reset every environment and clear all episode statistics.

        ``seed`` and ``options`` are accepted but not used; the reset always
        uses the PRNG keys created in ``__init__``.

        Returns:
            The batched initial timestep, with all-zero metrics in its extras.
        """
        with jax.default_device(self.device):
            self.state, timestep = self.vmapped_reset(self.rng_keys)

            self._clear_episode_statistics()

            zero_metrics = {
                "episode_return": np.zeros(self.num_envs, dtype=float),
                "episode_length": np.zeros(self.num_envs, dtype=int),
                "is_terminal_step": np.zeros(self.num_envs, dtype=bool),
            }
            return self._with_metrics(timestep, zero_metrics)

    def step(self, action: list) -> TimeStep:
        """Advance every environment by one step and update the episode statistics.

        Args:
            action: Batched actions, one per environment.

        Returns:
            The batched timestep, with episode metrics stored in its extras.
        """
        with jax.default_device(self.device):
            self.state, timestep = self.vmapped_step(self.state, action)

            # `timestep.last()` is treated as the per-environment episode-done flag.
            finished = timestep.last()
            ongoing = ~finished

            accumulated_return = self.running_count_episode_return + timestep.reward
            accumulated_length = self.running_count_episode_length + 1

            # Reported statistics only change for environments whose episode just
            # finished; the others keep the previously reported values.
            self.episode_return = self.episode_return * ongoing + accumulated_return * finished
            self.episode_length = self.episode_length * ongoing + accumulated_length * finished

            # Running counters restart from zero wherever an episode finished.
            self.running_count_episode_return = accumulated_return * ongoing
            self.running_count_episode_length = accumulated_length * ongoing

            metrics = {
                "episode_return": self.episode_return,
                "episode_length": self.episode_length,
                "is_terminal_step": finished,
            }
            return self._with_metrics(timestep, metrics)

    def observation_spec(self) -> Spec:
        """Return the wrapped environment's observation spec."""
        return self.env.observation_spec()

    def action_spec(self) -> Spec:
        """Return the wrapped environment's action spec."""
        return self.env.action_spec()

    def close(self) -> None:
        """Do nothing; present so the object offers a ``close`` method."""
        pass


class JaxEnvFactory(EnvFactory):
    """Factory producing ``JaxToStateful`` batches from a stoix-ready JAX environment."""

    def __init__(self, jax_env: Environment, init_seed: int):
        """Store the environment, the first CPU device and the starting seed.

        Args:
            jax_env: The JAX environment every created batch will wrap.
            init_seed: Seed handed to the first batch that is created.
        """
        # The factory is called from several threads; the lock keeps the seed
        # bookkeeping consistent so that each batch gets its own seed.
        self.lock = threading.Lock()
        self.seed = init_seed
        self.jax_env = jax_env
        cpu_devices = jax.devices("cpu")
        self.cpu = cpu_devices[0]

    def __call__(self, num_envs: int) -> JaxToStateful:
        """Create a stateful batch of ``num_envs`` environments on the CPU device.

        The current seed is passed to the new batch and the stored seed is then
        advanced by ``num_envs``.
        """
        with self.lock:
            batch_seed = self.seed
            self.seed += num_envs
            return JaxToStateful(self.jax_env, num_envs, self.cpu, batch_seed)
Loading