Skip to content

Commit

Permalink
chore: merge main
Browse files Browse the repository at this point in the history
  • Loading branch information
EdanToledo committed Sep 20, 2024
1 parent 558f552 commit c34e82e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
6 changes: 3 additions & 3 deletions stoix/utils/make_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import gymnax
import hydra
import jax
import jax.numpy as jnp
import jaxmarl
import jumanji
Expand All @@ -26,10 +25,10 @@
from xminigrid.registration import _REGISTRY as XMINIGRID_REGISTRY

from stoix.utils.debug_env import IdentityGame, SequenceGame
from stoix.utils.env_factory import EnvFactory, GymnasiumFactory, EnvPoolFactory
from stoix.utils.env_factory import EnvFactory, EnvPoolFactory, GymnasiumFactory
from stoix.wrappers import GymnaxWrapper, JumanjiWrapper, RecordEpisodeMetrics
from stoix.wrappers.brax import BraxJumanjiWrapper
from stoix.wrappers.jax_to_factory import JaxEnvFactory, JaxToStateful
from stoix.wrappers.jax_to_factory import JaxEnvFactory
from stoix.wrappers.jaxmarl import JaxMarlWrapper, MabraxWrapper, SmaxWrapper
from stoix.wrappers.navix import NavixWrapper
from stoix.wrappers.pgx import PGXWrapper
Expand Down Expand Up @@ -427,6 +426,7 @@ def make(config: DictConfig) -> Tuple[Environment, Environment]:

return envs


def make_factory(config: DictConfig) -> EnvFactory:
"""
Create a env_factory for sebulba systems.
Expand Down
13 changes: 7 additions & 6 deletions stoix/wrappers/jax_to_factory.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import threading
from typing import Any, Optional
from typing import Optional

import jax
import jax.numpy as jnp
import numpy as np
from jumanji.env import Environment
from jumanji.specs import Spec
Expand All @@ -28,7 +27,9 @@ def __init__(self, env: Environment, num_envs: int, device: jax.Device, init_see
# Create the seeds
max_int = np.iinfo(np.int32).max
min_int = np.iinfo(np.int32).min
init_seeds = jax.random.randint(jax.random.PRNGKey(init_seed), (num_envs,), min_int, max_int)
init_seeds = jax.random.randint(
jax.random.PRNGKey(init_seed), (num_envs,), min_int, max_int
)
self.rng_keys = jax.vmap(jax.random.PRNGKey)(init_seeds)

# Vmap and compile the reset and step functions
Expand Down Expand Up @@ -112,8 +113,8 @@ class JaxEnvFactory(EnvFactory):
"""
Create environments using stoix-ready JAX environments
"""
def __init__(self, jax_env : Environment, init_seed: int):

def __init__(self, jax_env: Environment, init_seed: int):
self.jax_env = jax_env
self.cpu = jax.devices("cpu")[0]
self.seed = init_seed
Expand All @@ -125,4 +126,4 @@ def __call__(self, num_envs: int) -> JaxToStateful:
with self.lock:
seed = self.seed
self.seed += num_envs
return JaxToStateful(self.jax_env, num_envs, self.cpu, seed)
return JaxToStateful(self.jax_env, num_envs, self.cpu, seed)

0 comments on commit c34e82e

Please sign in to comment.