diff --git a/stoix/configs/arch/sebulba.yaml b/stoix/configs/arch/sebulba.yaml
index c1a8ee4..818f8f4 100644
--- a/stoix/configs/arch/sebulba.yaml
+++ b/stoix/configs/arch/sebulba.yaml
@@ -10,14 +10,14 @@ num_updates: ~ # Number of updates
 # Define the number of actors per device and which devices to use.
 actor:
   device_ids: [0,1] # Define which devices to use for the actors.
-  actor_per_device: 8 # number of different threads per actor device.
+  actor_per_device: 4 # number of different threads per actor device.
 
 # Define which devices to use for the learner.
 learner:
   device_ids: [2,3] # Define which devices to use for the learner.
 
 # Size of the queue for the pipeline where actors push data and the learner pulls data.
-pipeline_queue_size: 10
+pipeline_queue_size: 20
 
 # --- Evaluation ---
 evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select
diff --git a/stoix/configs/env/envpool/breakout.yaml b/stoix/configs/env/envpool/breakout.yaml
new file mode 100644
index 0000000..135d956
--- /dev/null
+++ b/stoix/configs/env/envpool/breakout.yaml
@@ -0,0 +1,22 @@
+# ---Environment Configs---
+env_name: envpool # Used for logging purposes and selection of the corresponding wrapper.
+
+scenario:
+  name: Breakout-v5
+  task_name: breakout # For logging purposes.
+
+kwargs:
+  episodic_life: True
+  repeat_action_probability: 0
+  noop_max: 30
+  full_action_space: False
+  max_episode_steps: 27000
+
+
+# Defines the metric that will be used to evaluate the performance of the agent.
+# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
+eval_metric: episode_return
+
+# optional - defines the threshold that needs to be reached in order to consider the environment solved.
+# if present then solve rate will be logged.
+solved_return_threshold: 400.0
diff --git a/stoix/configs/env/envpool/vizdoom_basic.yaml b/stoix/configs/env/envpool/vizdoom_basic.yaml
new file mode 100644
index 0000000..e2e588b
--- /dev/null
+++ b/stoix/configs/env/envpool/vizdoom_basic.yaml
@@ -0,0 +1,19 @@
+# ---Environment Configs---
+env_name: envpool # Used for logging purposes and selection of the corresponding wrapper.
+
+scenario:
+  name: Basic-v1
+  task_name: vizdoom_basic # For logging purposes.
+
+kwargs:
+  episodic_life: True
+  use_combined_action : True
+
+
+# Defines the metric that will be used to evaluate the performance of the agent.
+# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
+eval_metric: episode_return
+
+# optional - defines the threshold that needs to be reached in order to consider the environment solved.
+# if present then solve rate will be logged.
+solved_return_threshold: 100.0
diff --git a/stoix/evaluator.py b/stoix/evaluator.py
index 1bb2a59..a84c1e7 100644
--- a/stoix/evaluator.py
+++ b/stoix/evaluator.py
@@ -420,7 +420,8 @@ def _run_episodes(key: chex.PRNGKey) -> Tuple[chex.PRNGKey, Dict]:
         seeds = np_rng.integers(np.iinfo(np.int32).max, size=n_parallel_envs).tolist()
         timestep = envs.reset(seed=seeds)
 
-        all_timesteps = [timestep]
+        all_metrics = [timestep.extras]
+        all_dones = [timestep.last()]
         finished_eps = timestep.last()
 
         # Loop until all episodes are done.
@@ -429,17 +430,16 @@ def _run_episodes(key: chex.PRNGKey) -> Tuple[chex.PRNGKey, Dict]:
             action = act_fn(params, timestep.observation, act_key)
             action_cpu = np.asarray(jax.device_put(action, cpu))
             timestep = envs.step(action_cpu)
-            all_timesteps.append(timestep)
-
+            all_metrics.append(timestep.extras)
+            all_dones.append(timestep.last())
             finished_eps = np.logical_or(finished_eps, timestep.last())
 
-        all_timesteps = jax.tree.map(lambda *x: np.stack(x), *all_timesteps)
-
-        metrics = all_timesteps.extras
+        metrics = jax.tree.map(lambda *x: np.stack(x), *all_metrics)
+        dones = np.stack(all_dones)
 
         # find the first instance of done to get the metrics at that timestep, we don't
         # care about subsequent steps because we only the results from the first episode
-        done_idx = np.argmax(all_timesteps.last(), axis=0)
+        done_idx = np.argmax(dones, axis=0)
         metrics = jax.tree_map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics)
 
         del metrics["is_terminal_step"]  # unneeded for logging
diff --git a/stoix/networks/resnet.py b/stoix/networks/resnet.py
index 6a21ae1..d2b95c5 100644
--- a/stoix/networks/resnet.py
+++ b/stoix/networks/resnet.py
@@ -115,15 +115,14 @@ class VisualResNetTorso(nn.Module):
 
     @nn.compact
    def __call__(self, observation: chex.Array) -> chex.Array:
-
         if observation.ndim > 4:
             return nn.batch_apply.BatchApply(self.__call__)(observation)
 
 
-        # If the input is in the form of [B, C, H, W], we need to transpose it to [B, H, W, C]
+        # If the input is in the form of [B, C, H, W], we need to transpose it to [B, H, W, C]
         if self.channel_first:
             observation = observation.transpose((0, 2, 3, 1))
-
+
         assert (
             observation.ndim == 4
         ), f"Expected inputs to have shape [B, H, W, C] but got shape {observation.shape}."
diff --git a/stoix/systems/ppo/sebulba/ff_ppo.py b/stoix/systems/ppo/sebulba/ff_ppo.py
index 970ba50..37df0d8 100644
--- a/stoix/systems/ppo/sebulba/ff_ppo.py
+++ b/stoix/systems/ppo/sebulba/ff_ppo.py
@@ -615,11 +615,11 @@ def run_experiment(_config: DictConfig) -> float:
     ), "Local and global devices must be the same for now. We dont support multihost just yet"
     # Extract the actor and learner devices
     actor_devices = [local_devices[device_id] for device_id in config.arch.actor.device_ids]
-    # For evaluation we simply use the first actor device as its less computationally intensive
-    evaluator_device = actor_devices[0]
     local_learner_devices = [
         local_devices[device_id] for device_id in config.arch.learner.device_ids
     ]
+    # For evaluation we simply use the first learner device
+    evaluator_device = local_learner_devices[0]
     print(f"{Fore.BLUE}{Style.BRIGHT}Actors devices: {actor_devices}{Style.RESET_ALL}")
     print(f"{Fore.GREEN}{Style.BRIGHT}Learner devices: {local_learner_devices}{Style.RESET_ALL}")
     print(f"{Fore.MAGENTA}{Style.BRIGHT}Global devices: {global_devices}{Style.RESET_ALL}")
diff --git a/stoix/utils/sebulba_utils.py b/stoix/utils/sebulba_utils.py
index ef81620..6f0bbcc 100644
--- a/stoix/utils/sebulba_utils.py
+++ b/stoix/utils/sebulba_utils.py
@@ -3,9 +3,9 @@
 import time
 from typing import Any, Dict, List, Sequence, Tuple, Union
 
-from colorama import Fore, Style
 import jax
 import jax.numpy as jnp
+from colorama import Fore, Style
 from jumanji.types import TimeStep
 
 from stoix.base_types import Parameters, StoixTransition
@@ -109,10 +109,12 @@ def shard_split_playload(self, payload: Any, axis: int = 0) -> Any:
 
     def clear(self) -> None:
         """Clear the pipeline."""
-        num_items = self._queue.qsize()
+        n_items = self._queue.qsize()
         while not self._queue.empty():
             self._queue.get()
-        print(f"{Fore.YELLOW}{Style.BRIGHT}Cleared {num_items} items from the pipeline{Style.RESET_ALL}")
+        print(
+            f"{Fore.YELLOW}{Style.BRIGHT}Cleared {n_items} items from the pipeline{Style.RESET_ALL}"
+        )
 
 
 class ParamsSource(threading.Thread):
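
Note (not part of the patch): the evaluator change in stoix/evaluator.py stops stacking whole timesteps and instead keeps only the per-step metrics and done flags. A minimal sketch of that first-done bookkeeping, using plain numpy and made-up toy values in place of the real `timestep.extras` pytree and `jax.tree.map` stacking:

import numpy as np

# Toy stand-ins for what the evaluator collects each step: 5 steps, 3 parallel envs.
n_steps, n_envs = 5, 3
all_metrics = [{"episode_return": np.array([0.0, 1.0, 2.0]) * t} for t in range(n_steps)]
all_dones = [np.array([t >= 2, t >= 4, t >= 3]) for t in range(n_steps)]

# Stack the per-step values into [T, N] arrays (the diff does this with jax.tree.map).
metrics = {k: np.stack([m[k] for m in all_metrics]) for k in all_metrics[0]}
dones = np.stack(all_dones)  # shape [T, N], True once an env's episode has ended

# argmax over the time axis returns the index of the first True per env,
# i.e. the step at which each env's first episode finished.
done_idx = np.argmax(dones, axis=0)  # -> array([2, 4, 3])

# Gather the metrics recorded at that first terminal step for every env; later
# steps are ignored because only the first finished episode per env is logged.
first_episode_metrics = {k: v[done_idx, np.arange(n_envs)] for k, v in metrics.items()}
print(first_episode_metrics["episode_return"])  # [0. 4. 6.]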