Skip to content

Commit

Permalink
Allows other values than NaN for unplayable levels in SMB (#196)
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgondu authored May 24, 2024
1 parent 8ac4ca6 commit 6907b1d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,14 @@ def __init__(
alphabet: List[str] = smb_info.alphabet,
max_time: int = 30,
visualize: bool = False,
value_on_unplayable: float = np.NaN,
):
self.alphabet = alphabet
self.alphabet_s_to_i = {s: i for i, s in enumerate(alphabet)}
self.alphabet_i_to_s = {i: s for i, s in enumerate(alphabet)}
self.max_time = max_time
self.visualize = visualize
self.value_on_unplayable = value_on_unplayable

def __call__(self, x: np.ndarray, context=None) -> np.ndarray:
"""Computes number of jumps in a given latent code x."""
Expand All @@ -95,7 +97,7 @@ def __call__(self, x: np.ndarray, context=None) -> np.ndarray:
if res["marioStatus"] == 1:
jumps = res["jumpActionsPerformed"]
else:
jumps = np.nan
jumps = self.value_on_unplayable

jumps_for_all_levels.append(jumps)

Expand Down
8 changes: 7 additions & 1 deletion src/poli/objective_repository/super_mario_bros/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(
self,
max_time: int = 30,
visualize: bool = False,
value_on_unplayable: float = np.NaN,
batch_size: int = None,
parallelize: bool = False,
num_workers: int = None,
Expand Down Expand Up @@ -91,8 +92,9 @@ def __init__(
evaluation_budget=evaluation_budget,
)
self.force_isolation = force_isolation
self.max_time = max_time
self.max_time = int(max_time)
self.visualize = visualize
self.value_on_unplayable = value_on_unplayable
_ = get_inner_function(
isolated_function_name="super_mario_bros__isolated",
class_name="SMBIsolatedLogic",
Expand All @@ -101,6 +103,7 @@ def __init__(
alphabet=smb_info.alphabet,
max_time=self.max_time,
visualize=self.visualize,
value_on_unplayable=self.value_on_unplayable,
)

def _black_box(self, x: np.ndarray, context=None) -> np.ndarray:
Expand All @@ -114,6 +117,7 @@ def _black_box(self, x: np.ndarray, context=None) -> np.ndarray:
alphabet=smb_info.alphabet,
max_time=self.max_time,
visualize=self.visualize,
value_on_unplayable=self.value_on_unplayable,
)
return inner_function(x, context)

Expand Down Expand Up @@ -142,6 +146,7 @@ def create(
self,
max_time: int = 30,
visualize: bool = False,
value_on_unplayable: float = np.NaN,
seed: int = None,
batch_size: int = None,
parallelize: bool = False,
Expand Down Expand Up @@ -182,6 +187,7 @@ def create(
f = SuperMarioBrosBlackBox(
max_time=max_time,
visualize=visualize,
value_on_unplayable=value_on_unplayable,
batch_size=batch_size,
parallelize=parallelize,
num_workers=num_workers,
Expand Down

0 comments on commit 6907b1d

Please sign in to comment.