Commit 08e41ba

feat: intermediate work

EdanToledo committed Aug 3, 2024
1 parent a4c9260 commit 08e41ba
Showing 4 changed files with 198 additions and 10 deletions.
2 changes: 1 addition & 1 deletion stoix/configs/arch/sebulba.yaml
@@ -2,7 +2,7 @@
architecture_name : sebulba
# --- Training ---
seed: 42 # RNG seed.
total_num_envs: 1024 # Total Number of vectorised environments across all actors. Needs to be divisible by the number of actor devices and actors per device.
total_num_envs: 4 # Total Number of vectorised environments across all actors. Needs to be divisible by the number of actor devices and actors per device.
total_timesteps: 1e7 # Set the total environment steps.
# If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value.
num_updates: ~ # Number of updates
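The debug value of 4 still has to satisfy the divisibility rule stated in the comment: the total is split evenly across the actor devices and the actors on each device. A minimal sketch of that arithmetic, where num_actor_devices and actors_per_device are illustrative names rather than keys from this config:

def envs_per_actor(total_num_envs: int, num_actor_devices: int, actors_per_device: int) -> int:
    # Each actor thread receives an equal slice of the vectorised environments.
    num_actors = num_actor_devices * actors_per_device
    assert total_num_envs % num_actors == 0, "total_num_envs must divide evenly across actors"
    return total_num_envs // num_actors

# Previous default: envs_per_actor(1024, 2, 4) -> 128 envs per actor.
# Debug value:      envs_per_actor(4, 1, 2)    -> 2 envs per actor.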
21 changes: 12 additions & 9 deletions stoix/systems/ppo/sebulba/ff_ppo.py
@@ -40,7 +40,7 @@
from stoix.systems.ppo.ppo_types import PPOTransition
from stoix.utils import make_env as environments
from stoix.utils.checkpointing import Checkpointer
from stoix.utils.env_factory import EnvPoolFactory
from stoix.utils.env_factory import EnvPoolFactory, make_gym_env_factory
from stoix.utils.jax_utils import (
merge_leading_dims,
unreplicate_batch_dim,
@@ -84,7 +84,7 @@ def rollout(rng: chex.PRNGKey) -> None:
with jax.default_device(actor_device):
# Reset the environment
# TODO(edan): put seeds in reset
timestep = envs.reset()
timestep = envs.reset(seed=seeds)
next_dones = np.logical_and(
np.array(timestep.last()), np.array(timestep.discount == 0.0)
)
@@ -147,6 +147,8 @@ def rollout(rng: chex.PRNGKey) -> None:
with RecordTimeTo(timings_dict["rollout_put_time"]):
pipeline.put(traj, timestep, timings_dict)

envs.close()

return rollout


@@ -454,8 +456,8 @@ def learner_setup(

# Get number/dimension of actions.
env = env_factory(num_envs=1)
obs_shape = env.observation_spec().obs.shape
num_actions = int(env.action_spec().num_values)
obs_shape = env.unwrapped.single_observation_space.shape
num_actions = int(env.env.unwrapped.single_action_space.n)
env.close()
config.system.action_dim = num_actions

@@ -574,11 +576,12 @@ def run_experiment(_config: DictConfig) -> float:
config.arch.actor.envs_per_actor = num_envs_per_actor

# Create the environments for train and eval.
env_factory = EnvPoolFactory(
config.arch.seed,
task_id="CartPole-v1",
env_type="dm",
)
# env_factory = EnvPoolFactory(
# config.arch.seed,
# task_id="CartPole-v1",
# env_type="dm",
# )
env_factory = make_gym_env_factory()

# PRNG keys.
key, key_e, actor_net_key, critic_net_key = jax.random.split(
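Both the commented-out EnvPoolFactory and the new make_gym_env_factory are consumed the same way, as a callable that takes num_envs, which is why the swap above is a one-line change; only the spec-reading lines in learner_setup had to move to the gymnasium vector-env attributes. A rough sketch of the resulting call pattern (illustrative, not the verbatim source):

# The factory is interchangeable at the call site; only its construction differs.
# env_factory = EnvPoolFactory(config.arch.seed, task_id="CartPole-v1", env_type="dm")
env_factory = make_gym_env_factory()

env = env_factory(num_envs=1)  # probe a single vectorised env to read its specs
obs_shape = env.unwrapped.single_observation_space.shape
num_actions = int(env.unwrapped.single_action_space.n)
env.close()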
19 changes: 19 additions & 0 deletions stoix/utils/env_factory.py
@@ -3,6 +3,10 @@
from typing import Any

import envpool
import gymnasium
from omegaconf import DictConfig

from stoix.wrappers.gym import GymRecordEpisodeMetrics, GymToJumanji


class EnvFactory(abc.ABC):
@@ -32,3 +36,18 @@ def __call__(self, num_envs: int) -> Any:
seed = self.seed
self.seed += num_envs
return envpool.make(**self.kwargs, num_envs=num_envs, seed=seed)


def make_gym_env_factory() -> EnvFactory:
def create_gym_env(name) -> gymnasium.Env:
env = gymnasium.make(name)
env = GymRecordEpisodeMetrics(env)
return env

def env_factory(num_envs):
envs = gymnasium.vector.AsyncVectorEnv(
[lambda: create_gym_env("CartPole-v1") for _ in range(num_envs)],
)
return GymToJumanji(envs)

return env_factory
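For reference, this is roughly how the rollout in ff_ppo.py is expected to consume the factory; the env count and seed list are illustrative, and it assumes the wrapper stack can populate the action mask that GymToJumanji reads from the info dict:

factory = make_gym_env_factory()
envs = factory(num_envs=4)                # AsyncVectorEnv of CartPole-v1 behind GymToJumanji
timestep = envs.reset(seed=[0, 1, 2, 3])  # mirrors envs.reset(seed=seeds) in the rollout
# ... act, step and collect trajectories here ...
envs.close()                              # mirrors the envs.close() added at the end of the rollout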
166 changes: 166 additions & 0 deletions stoix/wrappers/gym.py
@@ -0,0 +1,166 @@
import sys
import traceback
import warnings
from multiprocessing import Queue
from multiprocessing.connection import Connection
from typing import Any, Callable, Dict, Optional, Tuple, Union

import gymnasium
import jax
import numpy as np
from jumanji.types import StepType, TimeStep
from numpy.typing import NDArray

from stoix.base_types import Observation

# Filter out the warnings
warnings.filterwarnings("ignore", module="gymnasium.utils.passive_env_checker")


class GymWrapper(gymnasium.Wrapper):
"""Base wrapper for gym environments."""

def __init__(
self,
env: gymnasium.Env,
):
"""Initialise the gym wrapper
Args:
env (gymnasium.env): gymnasium env instance.
"""
super().__init__(env)
self._env = env
self.num_actions = int(self._env.action_space.n)

def reset(
self, seed: Optional[int] = None, options: Optional[dict] = None
) -> Tuple[NDArray, Dict]:
# Gymnasium's Env.seed() was removed; the seed is forwarded through reset instead.
agents_view, info = self._env.reset(seed=seed, options=options)

info = {"actions_mask": self.get_actions_mask(info)}

return np.array(agents_view), info

def step(self, actions: NDArray) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]:
agents_view, reward, terminated, truncated, info = self._env.step(actions)

info = {"actions_mask": self.get_actions_mask(info)}

reward = np.array(reward)

return agents_view, reward, terminated, truncated, info

def get_actions_mask(self, info: Dict) -> NDArray:
if "action_mask" in info:
return np.array(info["action_mask"])
return np.ones((self.num_actions,), dtype=np.float32)

class GymRecordEpisodeMetrics(gymnasium.Wrapper):
"""Record the episode returns and lengths."""

def __init__(self, env: gymnasium.Env):
super().__init__(env)
self._env = env
self.running_count_episode_return = 0.0
self.running_count_episode_length = 0.0

def reset(
self, seed: Optional[int] = None, options: Optional[dict] = None
) -> Tuple[NDArray, Dict]:
agents_view, info = self._env.reset(seed=seed, options=options)

# Create the metrics dict
metrics = {
"episode_return": self.running_count_episode_return,
"episode_length": self.running_count_episode_length,
"is_terminal_step": True,
}

# Reset the metrics
self.running_count_episode_return = 0.0
self.running_count_episode_length = 0

if "won_episode" in info:
metrics["won_episode"] = info["won_episode"]

info["metrics"] = metrics

return agents_view, info

def step(self, actions: NDArray) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]:
agents_view, reward, terminated, truncated, info = self._env.step(actions)

self.running_count_episode_return += float(np.mean(reward))
self.running_count_episode_length += 1

metrics = {
"episode_return": self.running_count_episode_return,
"episode_length": self.running_count_episode_length,
"is_terminal_step": False,
}
if "won_episode" in info:
metrics["won_episode"] = info["won_episode"]

info["metrics"] = metrics

return agents_view, reward, terminated, truncated, info


class GymToJumanji(gymnasium.Wrapper):
"""Converts from the Gym API to the dm_env API, using Jumanji's Timestep type."""

def reset(
self, seed: Optional[list[int]] = None, options: Optional[list[dict]] = None
) -> TimeStep:
obs, info = self.env.reset(seed=seed, options=options)

num_envs = self.env.num_envs

ep_done = np.zeros(num_envs, dtype=float)
rewards = np.zeros((num_envs,), dtype=float)
terminated = np.zeros((num_envs,), dtype=float)
self.step_count = np.zeros((num_envs,), dtype=float)

timestep = self._create_timestep(obs, ep_done, terminated, rewards, info)

return timestep

def step(self, action: list) -> TimeStep:
obs, rewards, terminated, truncated, info = self.env.step(action)

ep_done = np.logical_or(terminated, truncated).all(axis=1)
self.step_count += 1

timestep = self._create_timestep(obs, ep_done, terminated, rewards, info)

return timestep

def _format_observation(
self, obs: NDArray, info: Dict
) -> Observation:
"""Create an observation from the raw observation and environment state."""

obs = np.array(obs)
action_mask = info["actions_mask"]
obs_data = {"agents_view": obs, "action_mask": action_mask, "step_count": self.step_count}

return Observation(**obs_data)

def _create_timestep(
self, obs: NDArray, ep_done: NDArray, terminated: NDArray, rewards: NDArray, info: Dict
) -> TimeStep:
obs = self._format_observation(obs, info)
extras = jax.tree.map(lambda *x: np.stack(x), *info["metrics"])
step_type = np.where(ep_done, StepType.LAST, StepType.MID)
terminated = np.all(terminated, axis=1)

return TimeStep(
step_type=step_type,
reward=rewards,
discount=1.0 - terminated,
observation=obs,
extras=extras,
)
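The conversion above encodes the usual dm_env conventions: termination zeroes the discount, truncation keeps it at 1 so the value can still be bootstrapped, and either kind of episode end is marked StepType.LAST. A small illustration with made-up per-environment flags (shapes simplified to one flag per env):

import numpy as np
from jumanji.types import StepType

terminated = np.array([0.0, 1.0, 0.0])  # env 1 reached a terminal state
truncated = np.array([0.0, 0.0, 1.0])   # env 2 hit its time limit

ep_done = np.logical_or(terminated, truncated)              # [False, True, True]
step_type = np.where(ep_done, StepType.LAST, StepType.MID)  # [MID, LAST, LAST]
discount = 1.0 - terminated                                 # [1.0, 0.0, 1.0]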
