Skip to content

Commit

Permalink
Remove assigning null when the value is null (NeMo launcher) (#250)
Browse files Browse the repository at this point in the history
* Remove assigning null when the value is null

* Reflect Andrei's comments
  • Loading branch information
TaekyungHeo authored Oct 9, 2024
1 parent 7455f42 commit 2b6181b
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -135,18 +135,9 @@ def _set_node_config(self, nodes: List[str], num_nodes: int) -> None:
self.final_cmd_args["training.trainer.num_nodes"] = str(len(nodes)) if nodes else num_nodes

def _validate_data_config(self) -> None:
"""Validate the data directory and prefix configuration for non-mock environments."""
"""Validate the data prefix configuration for non-mock environments."""
if self.final_cmd_args.get("training.model.data.data_impl") != "mock":
data_dir = self.final_cmd_args.get("data_dir")
data_prefix = self.final_cmd_args.get("training.model.data.data_prefix")

if not data_dir or data_dir == "~":
raise ValueError(
"The 'data_dir' field of the NeMo launcher test contains an invalid placeholder '~'. "
"Please provide a valid path to the dataset in the test schema TOML file. "
"The 'data_dir' field must point to an actual dataset location."
)

if data_prefix == "[]":
raise ValueError(
"The 'data_prefix' field of the NeMo launcher test is missing or empty. "
Expand Down Expand Up @@ -198,10 +189,7 @@ def _generate_cmd_args_str(self, args: Dict[str, str], nodes: List[str]) -> str:
value = f"\\'{value}\\'"
env_var_str_parts.append(f"+{key}={value}")
else:
if value == "~":
cmd_arg_str_parts.append(f"~{key}=null")
else:
cmd_arg_str_parts.append(f"{key}={value}")
cmd_arg_str_parts.append(f"{key}={value}")

if nodes:
nodes_str = ",".join(nodes)
Expand Down
1 change: 0 additions & 1 deletion src/cloudai/test_definitions/nemo_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ class NeMoLauncherCmdArgs(CmdArgs):
repository_commit_hash: str = "cf411a9ede3b466677df8ee672bcc6c396e71e1a"
docker_image_url: str = "nvcr.io/nvidia/nemo:24.01.01"
stages: str = '["training"]'
data_dir: str = "~"
numa_mapping: NumaMapping = Field(default_factory=NumaMapping)
cluster: Cluster = Field(default_factory=Cluster)
training: Training = Field(default_factory=Training)
Expand Down
24 changes: 3 additions & 21 deletions tests/slurm_command_gen_strategy/test_nemo_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,40 +135,22 @@ def test_gpus_per_node_value(self, nemo_cmd_gen: NeMoLauncherSlurmCommandGenStra

assert "cluster.gpus_per_node=null" in cmd

def test_argument_with_tilde_value(self, nemo_cmd_gen: NeMoLauncherSlurmCommandGenStrategy, nemo_test_run: TestRun):
tdef: NeMoLauncherTestDefinition = cast(NeMoLauncherTestDefinition, nemo_test_run.test.test_definition)
tdef.cmd_args.training.model.micro_batch_size = "~" # type: ignore

cmd = nemo_cmd_gen.gen_exec_command(nemo_test_run)
assert "~training.model.micro_batch_size=null" in cmd

def test_data_impl_mock_skips_checks(
self, nemo_cmd_gen: NeMoLauncherSlurmCommandGenStrategy, nemo_test_run: TestRun
):
tdef: NeMoLauncherTestDefinition = cast(NeMoLauncherTestDefinition, nemo_test_run.test.test_definition)
tdef.cmd_args.data_dir = "DATA_DIR"

tdef.extra_cmd_args = {"data_dir": "DATA_DIR"}
cmd = nemo_cmd_gen.gen_exec_command(nemo_test_run)
assert "data_dir" in cmd
assert "data_dir=DATA_DIR" in cmd

def test_data_dir_and_data_prefix_validation(
self, nemo_cmd_gen: NeMoLauncherSlurmCommandGenStrategy, nemo_test_run: TestRun
):
tdef: NeMoLauncherTestDefinition = cast(NeMoLauncherTestDefinition, nemo_test_run.test.test_definition)
tdef.cmd_args.training.model.data.data_impl = "not_mock"
tdef.cmd_args.training.model.data.data_prefix = "[]"
tdef.extra_cmd_args = {"data_dir": "DATA_DIR"}

with pytest.raises(
ValueError,
match=(
"The 'data_dir' field of the NeMo launcher test contains an invalid placeholder '~'. "
"Please provide a valid path to the dataset in the test schema TOML file. "
"The 'data_dir' field must point to an actual dataset location."
),
):
nemo_cmd_gen.gen_exec_command(nemo_test_run)

tdef.cmd_args.data_dir = "/fake/data_dir"
with pytest.raises(ValueError, match="The 'data_prefix' field of the NeMo launcher test is missing or empty."):
nemo_cmd_gen.gen_exec_command(nemo_test_run)

Expand Down

0 comments on commit 2b6181b

Please sign in to comment.