Skip to content

Commit

Permalink
Remove non_default_setting validation in high level interface for dev
Browse files Browse the repository at this point in the history
GitOrigin-RevId: cd46411fc2b146bd5c212835f18e53c6b9359813
  • Loading branch information
kboyd committed Jul 16, 2024
1 parent 1213338 commit f336be4
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/gretel_client/gretel/config_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import yaml

from gretel_client.config import get_session_config
from gretel_client.gretel.artifact_fetching import ReportType
from gretel_client.gretel.exceptions import (
ConfigSettingError,
Expand Down Expand Up @@ -221,6 +222,7 @@ def create_model_config_from_base(
config = smart_read_model_config(base_config)
model_type, model_config_section = extract_model_config_section(config)
setup = CONFIG_SETUP_DICT[ModelType(model_type)]
is_gretel_dev = get_session_config().stage == "dev"

config = _backwards_compat_transform_config(config, non_default_settings)

Expand All @@ -230,20 +232,20 @@ def create_model_config_from_base(
for section, settings in non_default_settings.items():
if not isinstance(settings, dict):
extra_kwargs = setup.extra_kwargs or []
if section in extra_kwargs:
if section in extra_kwargs or is_gretel_dev:
model_config_section[section] = settings
else:
raise ConfigSettingError(
f"`{section}` is an invalid keyword argument. Valid options "
f"include {setup.config_sections + extra_kwargs}."
)
elif section not in setup.config_sections:
elif section in setup.config_sections or is_gretel_dev:
model_config_section.setdefault(section, {}).update(settings)
else:
raise ConfigSettingError(
f"`{section}` is not a valid `{setup.model_name}` config section. "
f"Must be one of [{setup.config_sections}]."
)
else:
model_config_section.setdefault(section, {}).update(settings)
return config


Expand Down
27 changes: 27 additions & 0 deletions tests/gretel_client/test_gretel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import Path
from typing import Callable
from unittest.mock import MagicMock, patch

import pytest

Expand Down Expand Up @@ -116,6 +117,26 @@ def test_update_non_nested_setting(gpt_x_config_file_path: Path):
assert settings["column_name"] == "custom_name"


def test_no_validation_on_dev():
# Additional keys are accepted for internal testing on dev.
with patch(
"gretel_client.gretel.config_setup.get_session_config"
) as mock_get_session_config:
mock_session_config = MagicMock()
mock_session_config.stage = "dev"
mock_get_session_config.return_value = mock_session_config

settings = create_model_config_from_base(
base_config="navigator-ft",
is_gretel_dev=True,
params={"rope_scaling_factor": 2},
extra_stuff={"foo": "bar"},
)["models"][0]["navigator_ft"]

assert settings["params"]["rope_scaling_factor"] == 2
assert settings["extra_stuff"]["foo"] == "bar"


def test_gpt_x_backwards_compatibility(
gpt_x_old_config_file_path: Path, gpt_x_config_file_path: Path
):
Expand Down Expand Up @@ -186,3 +207,9 @@ def test_create_config_settings_error():
base_config="tabular-actgan",
invalid_section={"invalid": "section"},
)

with pytest.raises(ConfigSettingError):
create_model_config_from_base(
base_config="navigator-ft",
extra_stuff={"foo": "bar"},
)

0 comments on commit f336be4

Please sign in to comment.