
Commit

feat: intermediate work
EdanToledo committed Aug 25, 2024
1 parent c17b2a1 commit ca59569
Showing 5 changed files with 48 additions and 24 deletions.
stoix/configs/arch/sebulba.yaml (8 changes: 4 additions & 4 deletions)
@@ -17,11 +17,11 @@ learner:
device_ids: [2,3] # Define which devices to use for the learner.

# Size of the queue for the pipeline where actors push data and the learner pulls data.
-pipeline:
-samples_per_insert_tolerance_rate: 0.1
-min_replay_size: 10
+# pipeline:
+# samples_per_insert_tolerance_rate: 0.1
+# min_replay_size: 10

-# pipeline_queue_size: 10
+pipeline_queue_size: 10

# --- Evaluation ---
evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select
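Note (illustrative, not part of this commit): the arch config now exposes a plain pipeline_queue_size while the rate-limiter settings are commented out above. Assuming the on-policy pipeline is backed by a standard bounded queue (an assumption, not shown in this diff), the setting would behave roughly like this sketch:

    import queue

    pipeline_queue_size = 10  # value from the YAML above
    # A bounded queue applies back-pressure: once it holds pipeline_queue_size
    # trajectory batches, actor put() calls block (or time out) until the
    # learner's get() frees a slot.
    pipeline_queue: queue.Queue = queue.Queue(maxsize=pipeline_queue_size)

    pipeline_queue.put({"rollout": 0}, timeout=60)  # actor side
    batch = pipeline_queue.get(timeout=60)          # learner side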
stoix/configs/system/ff_dqn.yaml (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
system_name: ff_dqn # Name of the system.

# --- RL hyperparameters ---
-rollout_length: 2 # Number of environment steps per vectorised environment.
+rollout_length: 16 # Number of environment steps per vectorised environment.
epochs: 16 # Number of sgd steps per rollout.
warmup_steps: 16 # Number of steps to collect before training.
total_buffer_size: 1_000_000 # Total effective size of the replay buffer across all devices and vectorised update steps. This means each device has a buffer of size buffer_size//num_devices which is further divided by the update_batch_size. This value must be divisible by num_devices*update_batch_size.
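Note (illustrative, not part of this commit): the divisibility constraint described in the total_buffer_size comment above can be checked with simple arithmetic; num_devices and update_batch_size below are assumed example values.

    # Worked example of the per-device buffer split described above.
    total_buffer_size = 1_000_000
    num_devices = 2          # assumed example
    update_batch_size = 2    # assumed example

    assert total_buffer_size % (num_devices * update_batch_size) == 0
    per_device_buffer = total_buffer_size // num_devices               # 500_000
    per_update_batch_buffer = per_device_buffer // update_batch_size   # 250_000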
stoix/systems/ppo/sebulba/ff_ppo.py (3 changes: 3 additions & 0 deletions)
@@ -422,6 +422,7 @@ def get_learner_rollout_fn(
network parameters to a queue for evaluation."""

def learner_rollout(learner_state: CoreLearnerState) -> None:
+learner_step = 0
# Loop for the total number of evaluations selected to be performed.
for _ in range(config.arch.num_evaluation):
# Create the lists to store metrics and timings for this learning iteration.
@@ -450,6 +451,7 @@ def learner_rollout(learner_state: CoreLearnerState) -> None:
with RecordTimeTo(learner_timings["learning_time"]):
learner_state, train_metrics = learn_step(learner_state, traj_batch)

+learner_step += 1
# We store the metrics and timings for this update
metrics.append((episode_metrics, train_metrics))
actor_timings.append(actor_times)
@@ -469,6 +471,7 @@ def learner_rollout(learner_state: CoreLearnerState) -> None:
actor_timings = jax.tree.map(lambda *x: np.mean(x), *actor_timings)
timing_dict = actor_timings | learner_timings
timing_dict["pipeline_qsize"] = q_sizes
timing_dict["learner_step"] = [learner_step]
timing_dict = jax.tree.map(np.mean, timing_dict, is_leaf=lambda x: isinstance(x, list))
try:
# We add a timeout mainly for sanity checks
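Note (illustrative, not part of this commit): the new learner_step counter is stored as a single-element list so that the existing jax.tree.map(np.mean, ..., is_leaf=lambda x: isinstance(x, list)) reduction treats it like every other list-valued timing entry and collapses it to one scalar for logging. A minimal sketch of that behaviour:

    import jax
    import numpy as np

    timing_dict = {
        "learning_time": [0.12, 0.11, 0.13],  # per-update timings
        "pipeline_qsize": [3, 4, 2],
        "learner_step": [42],                 # wrapped so the mean is just 42.0
    }
    logged = jax.tree.map(np.mean, timing_dict, is_leaf=lambda x: isinstance(x, list))
    # {'learning_time': 0.12, 'pipeline_qsize': 3.0, 'learner_step': 42.0}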
stoix/systems/q_learning/sebulba/ff_dqn.py (40 changes: 24 additions & 16 deletions)
@@ -38,7 +38,7 @@
from stoix.utils.env_factory import EnvFactory
from stoix.utils.logger import LogEvent, StoixLogger
from stoix.utils.loss import q_learning
-from stoix.utils.rate_limiters import MinSize, SampleToInsertRatio
+from stoix.utils.rate_limiters import SampleToInsertRatio
from stoix.utils.sebulba_utils import (
OffPolicyPipeline,
ParamsSource,
@@ -156,8 +156,8 @@ def rollout_fn(rng_key: chex.PRNGKey) -> None:
# Send the trajectory to the pipeline
with RecordTimeTo(actor_timings_dict["rollout_put_time"]):
try:
-pipeline.put(traj, actor_timings_dict, episode_metrics)
-except queue.Full:
+pipeline.put(traj, actor_timings_dict, episode_metrics, timeout=60)
+except TimeoutError:
warnings.warn(
"Waited too long to add to the rollout queue, killing the actor thread",
stacklevel=2,
@@ -325,7 +325,14 @@ def learner_rollout(learner_state: CoreOffPolicyLearnerState) -> None:
# Get the trajectory batch from the pipeline
# This is blocking so it will wait until the pipeline has data.
with RecordTimeTo(learner_timings["rollout_get_time"]):
-traj_batch, actor_times, episode_metrics = pipeline.get() # type: ignore
+try:
+traj_batch, actor_times, episode_metrics = pipeline.get(timeout=60) # type: ignore
+except TimeoutError:
+warnings.warn(
+"Waited too long to sample from pipeline, killing the learner thread",
+stacklevel=2,
+)
+break

# We then call the update function to update the networks
with RecordTimeTo(learner_timings["learning_time"]):
@@ -348,6 +355,13 @@ def learner_rollout(learner_state: CoreOffPolicyLearnerState) -> None:
episode_metrics, train_metrics = jax.tree.map(lambda *x: np.asarray(x), *metrics)
actor_timings = jax.tree.map(lambda *x: np.mean(x), *actor_timings)
timing_dict = actor_timings | learner_timings
+buffer_size = (
+(pipeline.get_num_inserts() - pipeline.get_num_deletes())
+* config.arch.actor.envs_per_actor
+* config.system.rollout_length
+)
+timing_dict["buffer_size"] = [buffer_size]
+timing_dict["num_samples"] = [pipeline.get_num_samples()]
timing_dict = jax.tree.map(np.mean, timing_dict, is_leaf=lambda x: isinstance(x, list))
try:
# We add a timeout mainly for sanity checks
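Note (illustrative, not part of this commit): the buffer_size logged above is an estimate rebuilt from the rate-limiter counters; the numbers below are assumed examples.

    # Each insert contributes envs_per_actor * rollout_length transitions, and each
    # delete (assumed to be an eviction of one earlier insert) removes the same amount.
    num_inserts = 500      # pipeline.get_num_inserts() (assumed example)
    num_deletes = 120      # pipeline.get_num_deletes() (assumed example)
    envs_per_actor = 64    # config.arch.actor.envs_per_actor (assumed example)
    rollout_length = 16    # config.system.rollout_length (value from ff_dqn.yaml)

    buffer_size = (num_inserts - num_deletes) * envs_per_actor * rollout_length
    print(buffer_size)  # 389_120 transitions currently held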
@@ -623,18 +637,12 @@ def run_experiment(_config: DictConfig) -> float:
# Now we create the pipeline
replay_buffer_add, replay_buffer_sample = buffer_fns
# Set up the rate limiter that controls how actors and learners interact with the pipeline
-if config.system.epochs > 0:
-samples_per_insert_tolerance_rate = config.arch.pipeline.samples_per_insert_tolerance_rate
-samples_per_insert_tolerance = (
-samples_per_insert_tolerance_rate * config.system.epochs
-)
-error_buffer = config.arch.pipeline.min_replay_size * samples_per_insert_tolerance
-rate_limiter = SampleToInsertRatio(
-config.system.epochs, config.arch.pipeline.min_replay_size, error_buffer
-)
-else:
-pass
-rate_limiter = MinSize(config.arch.pipeline.min_replay_size) # type: ignore
+samples_per_insert_tolerance_rate = 0.1 # This allows for 10% tolerance
+samples_per_insert_tolerance = samples_per_insert_tolerance_rate * config.system.epochs
+error_buffer = config.system.total_batch_size * samples_per_insert_tolerance
+min_inserts = max(config.system.total_batch_size // steps_per_insert, 1)
+rate_limiter = SampleToInsertRatio(config.system.epochs, min_inserts, error_buffer)

pipeline = OffPolicyPipeline(
replay_buffer_add,
replay_buffer_sample,
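Note (illustrative, not part of this commit): a rough worked example of the SampleToInsertRatio arguments built above, assuming a Reverb-style (samples_per_insert, min_size_to_sample, error_buffer) convention; total_batch_size and steps_per_insert values are made up, and steps_per_insert itself is computed earlier in run_experiment (not shown in this diff).

    epochs = 16              # config.system.epochs: target samples per insert
    total_batch_size = 256   # config.system.total_batch_size (assumed example)
    steps_per_insert = 1024  # envs_per_actor * rollout_length (assumed example)

    samples_per_insert_tolerance = 0.1 * epochs                      # 1.6
    error_buffer = total_batch_size * samples_per_insert_tolerance   # 409.6
    min_inserts = max(total_batch_size // steps_per_insert, 1)       # max(0, 1) -> 1

    # SampleToInsertRatio(epochs, min_inserts, error_buffer) would then refuse to
    # sample until at least min_inserts inserts have landed, and afterwards keep
    # the long-run samples-per-insert ratio near `epochs`, give or take a slack of
    # roughly error_buffer samples.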
stoix/utils/sebulba_utils.py (19 changes: 16 additions & 3 deletions)
@@ -196,6 +196,7 @@ def put(
traj: Sequence[StoixTransition],
actor_timings_dict: Dict[str, List[float]],
actor_episode_metrics: List[Dict[str, List[float]]],
+timeout: Union[float, None] = None,
) -> None:
start_condition, end_condition = (threading.Condition(), threading.Condition())
with start_condition:
@@ -207,7 +208,7 @@

# wait until we can insert the data
try:
-self.rate_limiter.await_can_insert(timeout=180)
+self.rate_limiter.await_can_insert(timeout=timeout)
except TimeoutError:
print(
f"{Fore.RED}{Style.BRIGHT}Actor has timed out on insertion, "
@@ -223,7 +224,7 @@

# signal that we have inserted the data
self.rate_limiter.insert()

+# Concatenate metrics - List[Dict[str, List[float]]] --> Dict[str, List[float]]
actor_episode_metrics = self.concatenate_metrics(actor_episode_metrics)

@@ -261,6 +262,18 @@ def get(self, timeout: Union[float, None] = None) -> Tuple[StoixTransition, Dict

return sharded_sampled_batch, actor_timings, actor_metrics # type: ignore

+def get_num_inserts(self) -> int:
+"""Get the number of inserts that have been made to the buffer."""
+return self.rate_limiter.num_inserts()
+
+def get_num_samples(self) -> int:
+"""Get the number of samples that have been made from the buffer."""
+return self.rate_limiter.num_samples()
+
+def get_num_deletes(self) -> int:
+"""Get the number of deletes that have been made from the buffer."""
+return self.rate_limiter.num_deletes()

@partial(jax.jit, static_argnums=(0, 2))
def stack_trajectory(self, trajectory: List[StoixTransition], axis: int = 0) -> StoixTransition:
"""Stack a list of parallel_env transitions into a single
@@ -271,7 +284,7 @@ def shard_split_playload(self, payload: Any, axis: int = 0) -> Any:
"""Split the payload over the learner devices."""
split_payload = jnp.split(payload, len(self.learner_devices), axis=axis)
return jax.device_put_sharded(split_payload, devices=self.learner_devices)

@partial(jax.jit, static_argnums=(0,))
def concatenate_metrics(
self, actor_metrics: List[Dict[str, List[float]]]
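Note (illustrative, not the Stoix implementation): the new get_num_inserts / get_num_samples / get_num_deletes pass-throughs assume the underlying rate limiter keeps running counters of its insert, sample, and delete events. A minimal sketch of such a counter-keeping limiter:

    import threading

    class CountingRateLimiter:
        """Tracks how often insert/sample/delete have been signalled."""

        def __init__(self) -> None:
            self._lock = threading.Lock()
            self._inserts = 0
            self._samples = 0
            self._deletes = 0

        def insert(self) -> None:
            with self._lock:
                self._inserts += 1

        def sample(self) -> None:
            with self._lock:
                self._samples += 1

        def delete(self) -> None:
            with self._lock:
                self._deletes += 1

        def num_inserts(self) -> int:
            return self._inserts

        def num_samples(self) -> int:
            return self._samples

        def num_deletes(self) -> int:
            return self._deletes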
