Skip to content

Commit

Permalink
chore: slight modifications
Browse files Browse the repository at this point in the history
  • Loading branch information
EdanToledo committed Aug 16, 2024
1 parent 6a0acf9 commit 0c12e86
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 17 deletions.
4 changes: 2 additions & 2 deletions stoix/configs/arch/sebulba.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ num_updates: ~ # Number of updates
# Define the number of actors per device and which devices to use.
actor:
device_ids: [0,1] # Define which devices to use for the actors.
actor_per_device: 8 # number of different threads per actor device.
actor_per_device: 4 # number of different threads per actor device.

# Define which devices to use for the learner.
learner:
device_ids: [2,3] # Define which devices to use for the learner.

# Size of the queue for the pipeline where actors push data and the learner pulls data.
pipeline_queue_size: 10
pipeline_queue_size: 20

# --- Evaluation ---
evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select
Expand Down
22 changes: 22 additions & 0 deletions stoix/configs/env/envpool/breakout.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# ---Environment Configs---
env_name: envpool # Used for logging purposes and selection of the corresponding wrapper.

scenario:
name: Breakout-v5
task_name: breakout # For logging purposes.

kwargs:
episodic_life: True
repeat_action_probability: 0
noop_max: 30
full_action_space: False
max_episode_steps: 27000


# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return

# optional - defines the threshold that needs to be reached in order to consider the environment solved.
# if present then solve rate will be logged.
solved_return_threshold: 400.0
19 changes: 19 additions & 0 deletions stoix/configs/env/envpool/vizdoom_basic.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# ---Environment Configs---
env_name: envpool # Used for logging purposes and selection of the corresponding wrapper.

scenario:
name: Basic-v1
task_name: vizdoom_basic # For logging purposes.

kwargs:
episodic_life: True
use_combined_action : True


# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return

# optional - defines the threshold that needs to be reached in order to consider the environment solved.
# if present then solve rate will be logged.
solved_return_threshold: 100.0
14 changes: 7 additions & 7 deletions stoix/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,8 @@ def _run_episodes(key: chex.PRNGKey) -> Tuple[chex.PRNGKey, Dict]:
seeds = np_rng.integers(np.iinfo(np.int32).max, size=n_parallel_envs).tolist()
timestep = envs.reset(seed=seeds)

all_timesteps = [timestep]
all_metrics = [timestep.extras]
all_dones = [timestep.last()]
finished_eps = timestep.last()

# Loop until all episodes are done.
Expand All @@ -429,17 +430,16 @@ def _run_episodes(key: chex.PRNGKey) -> Tuple[chex.PRNGKey, Dict]:
action = act_fn(params, timestep.observation, act_key)
action_cpu = np.asarray(jax.device_put(action, cpu))
timestep = envs.step(action_cpu)
all_timesteps.append(timestep)

all_metrics.append(timestep.extras)
all_dones.append(timestep.last())
finished_eps = np.logical_or(finished_eps, timestep.last())

all_timesteps = jax.tree.map(lambda *x: np.stack(x), *all_timesteps)

metrics = all_timesteps.extras
metrics = jax.tree.map(lambda *x: np.stack(x), *all_metrics)
dones = np.stack(all_dones)

# find the first instance of done to get the metrics at that timestep, we don't
# care about subsequent steps because we only the results from the first episode
done_idx = np.argmax(all_timesteps.last(), axis=0)
done_idx = np.argmax(dones, axis=0)
metrics = jax.tree_map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics)
del metrics["is_terminal_step"] # unneeded for logging

Expand Down
5 changes: 2 additions & 3 deletions stoix/networks/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,14 @@ class VisualResNetTorso(nn.Module):

@nn.compact
def __call__(self, observation: chex.Array) -> chex.Array:


if observation.ndim > 4:
return nn.batch_apply.BatchApply(self.__call__)(observation)

# If the input is in the form of [B, C, H, W], we need to transpose it to [B, H, W, C]
# If the input is in the form of [B, C, H, W], we need to transpose it to [B, H, W, C]
if self.channel_first:
observation = observation.transpose((0, 2, 3, 1))

assert (
observation.ndim == 4
), f"Expected inputs to have shape [B, H, W, C] but got shape {observation.shape}."
Expand Down
4 changes: 2 additions & 2 deletions stoix/systems/ppo/sebulba/ff_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,11 +615,11 @@ def run_experiment(_config: DictConfig) -> float:
), "Local and global devices must be the same for now. We dont support multihost just yet"
# Extract the actor and learner devices
actor_devices = [local_devices[device_id] for device_id in config.arch.actor.device_ids]
# For evaluation we simply use the first actor device as its less computationally intensive
evaluator_device = actor_devices[0]
local_learner_devices = [
local_devices[device_id] for device_id in config.arch.learner.device_ids
]
# For evaluation we simply use the first learner device
evaluator_device = local_learner_devices[0]
print(f"{Fore.BLUE}{Style.BRIGHT}Actors devices: {actor_devices}{Style.RESET_ALL}")
print(f"{Fore.GREEN}{Style.BRIGHT}Learner devices: {local_learner_devices}{Style.RESET_ALL}")
print(f"{Fore.MAGENTA}{Style.BRIGHT}Global devices: {global_devices}{Style.RESET_ALL}")
Expand Down
8 changes: 5 additions & 3 deletions stoix/utils/sebulba_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import time
from typing import Any, Dict, List, Sequence, Tuple, Union

from colorama import Fore, Style
import jax
import jax.numpy as jnp
from colorama import Fore, Style
from jumanji.types import TimeStep

from stoix.base_types import Parameters, StoixTransition
Expand Down Expand Up @@ -109,10 +109,12 @@ def shard_split_playload(self, payload: Any, axis: int = 0) -> Any:

def clear(self) -> None:
"""Clear the pipeline."""
num_items = self._queue.qsize()
n_items = self._queue.qsize()
while not self._queue.empty():
self._queue.get()
print(f"{Fore.YELLOW}{Style.BRIGHT}Cleared {num_items} items from the pipeline{Style.RESET_ALL}")
print(
f"{Fore.YELLOW}{Style.BRIGHT}Cleared {n_items} items from the pipeline{Style.RESET_ALL}"
)


class ParamsSource(threading.Thread):
Expand Down

0 comments on commit 0c12e86

Please sign in to comment.