Skip to content

Commit

Permalink
[INTERNAL] PLAT-2553: python client uses serverless and validates conn
Browse files Browse the repository at this point in the history
GitOrigin-RevId: f7f4b8b59773460d143e93fe2d5bcaa01d1570be
  • Loading branch information
benmccown committed Oct 4, 2024
1 parent 4f42276 commit 590a543
Show file tree
Hide file tree
Showing 16 changed files with 114 additions and 43 deletions.
8 changes: 7 additions & 1 deletion src/gretel_client/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ def _check_endpoint(endpoint: str) -> str:
metavar="PROJECT",
help="Default Gretel project.",
)
@click.option(
"--skip-validate",
is_flag=True,
help="The API connection will be validated by default unless this flag is set.",
)
@pass_session
def configure(
sc: SessionContext,
Expand All @@ -94,6 +99,7 @@ def configure(
api_key: str,
project: str,
default_runner: str,
skip_validate: bool,
):
project_name = None if project == "none" else project
endpoint = _check_endpoint(endpoint)
Expand All @@ -108,7 +114,7 @@ def configure(
default_runner=default_runner,
)
config.update_default_project(project_id=project_name)
configure_session(config)
configure_session(config, validate=not skip_validate)

config_path = write_config(config)
sc.log.info(f"Configuration written to {config_path}. Done.")
Expand Down
30 changes: 28 additions & 2 deletions src/gretel_client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from gretel_client.rest.api.users_api import UsersApi
from gretel_client.rest.api_client import ApiClient
from gretel_client.rest.configuration import Configuration
from gretel_client.rest_v1.api.serverless_api import ServerlessApi
from gretel_client.rest_v1.api_client import ApiClient as V1ApiClient
from gretel_client.rest_v1.configuration import Configuration as V1Configuration

Expand Down Expand Up @@ -292,7 +293,10 @@ def email(self) -> str:

@property
def stage(self) -> str:
if "https://api-dev.gretel" in self.endpoint:
if (
"https://api-dev.gretel" in self.endpoint
or ".dev.gretel.cloud" in self.endpoint
):
return "dev"
return "prod"

Expand All @@ -316,6 +320,25 @@ def _get_v1_api_client(self, *args, **kwargs) -> V1ApiClient:
V1ApiClient, V1Configuration, *args, **kwargs
)

def set_serverless_api(self) -> bool:
serverless: ServerlessApi = self.get_v1_api(ServerlessApi)
serverless_tenants_resp = serverless.list_serverless_tenants()
if (
isinstance(serverless_tenants_resp.tenants, list)
and len(serverless_tenants_resp.tenants) > 0
):
tenant = serverless_tenants_resp.tenants[0]
tenant_endpoint = tenant.config.api_endpoint
if not tenant_endpoint.startswith("https://"):
tenant_endpoint = f"https://{tenant_endpoint}"
if tenant_endpoint != self.endpoint:
print(
"Found a serverless tenant associated with this API key. Updating client configuration to use the tenant API endpoint."
)
self.endpoint = tenant_endpoint
return True
return False

def get_api(
self,
api_interface: Type[T],
Expand Down Expand Up @@ -822,7 +845,7 @@ def configure_session(
endpoint: Optional[str] = None,
artifact_endpoint: Optional[str] = None,
cache: str = "no",
validate: bool = False,
validate: bool = True,
clear: bool = False,
):
"""Updates client config for the session
Expand Down Expand Up @@ -895,6 +918,9 @@ def configure_session(
_session_client_config = config

if validate:
is_using_serverless = config.set_serverless_api()
if is_using_serverless:
print("Serverless tenant detected.")
print(f"Using endpoint {config.endpoint}")
try:
print(f"Logged in as {config.email} \u2705")
Expand Down
11 changes: 9 additions & 2 deletions src/gretel_client/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,16 @@ class GretelFactories:

_session: ClientConfig

def __init__(self, *, session: Optional[ClientConfig] = None, **session_kwargs):
def __init__(
self,
*,
session: Optional[ClientConfig] = None,
skip_configure_session: Optional[bool] = False,
**session_kwargs,
):
if session is None:
if len(session_kwargs) > 0:
# Only used for unit tests
if not skip_configure_session:
configure_session(**session_kwargs)
session = get_session_config()
elif len(session_kwargs) > 0:
Expand Down
16 changes: 13 additions & 3 deletions src/gretel_client/gretel/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
from pathlib import Path
from typing import Optional, Union

from gretel_client.config import add_session_context, ClientConfig, configure_session
from gretel_client.config import (
add_session_context,
ClientConfig,
configure_session,
get_session_config,
)
from gretel_client.dataframe import _DataFrameT
from gretel_client.factories import GretelFactories
from gretel_client.gretel.artifact_fetching import (
Expand Down Expand Up @@ -101,11 +106,14 @@ def __init__(
project_name: Optional[str] = None,
project_display_name: Optional[str] = None,
session: Optional[ClientConfig] = None,
skip_configure_session: Optional[bool] = False,
**session_kwargs,
):
if session is None:
if len(session_kwargs) > 0:
# Only used for unit tests
if not skip_configure_session:
configure_session(**session_kwargs)
session = get_session_config()
elif len(session_kwargs) > 0:
raise ValueError("cannot specify session arguments when passing a session")

Expand All @@ -114,7 +122,9 @@ def __init__(
)
self._user_id: str = get_me(session=self._session)["_id"][9:]
self._project: Optional[Project] = None
self.factories = GretelFactories(session=self._session)
self.factories = GretelFactories(
session=self._session, skip_configure_session=skip_configure_session
)

if project_name is not None:
self.set_project(name=project_name, display_name=project_display_name)
Expand Down
6 changes: 4 additions & 2 deletions src/gretel_client/inference_api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,16 +150,18 @@ def __init__(
*,
verify_ssl: bool = True,
session: Optional[ClientConfig] = None,
skip_configure_session: Optional[bool] = False,
**session_kwargs,
):
if session is None:
if len(session_kwargs) > 0:
# Only used for unit tests
if not skip_configure_session:
configure_session(**session_kwargs)
session = get_session_config()
elif len(session_kwargs) > 0:
raise ValueError("cannot specify session arguments when passing a session")

if session.default_runner != "cloud":
if session.default_runner != "cloud" and not ".serverless." in session.endpoint:
raise GretelInferenceAPIError(
"Gretel's Inference API is currently only "
"available within Gretel Cloud. Your current runner "
Expand Down
3 changes: 2 additions & 1 deletion tests/gretel_client/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def configure_session_client():
MagicMock(
default_runner=DEFAULT_RUNNER,
artifact_endpoint=DEFAULT_GRETEL_ARTIFACT_ENDPOINT,
)
),
validate=False,
)


Expand Down
4 changes: 2 additions & 2 deletions tests/gretel_client/inference_api/test_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_generate_error_retry(mock_models, mock_all_models):
mock_models.return_value = [
{"model_id": "gretelai/tabular-v0", "model_type": "TABULAR"}
]
api = tabular.TabularInferenceAPI()
api = tabular.TabularInferenceAPI(validate=False)
api_response = {
"data": [
{
Expand Down Expand Up @@ -67,7 +67,7 @@ def test_generate_timeout(
mock_models.return_value = [
{"model_id": "gretelai/tabular-v0", "model_type": "TABULAR"}
]
api = tabular.TabularInferenceAPI()
api = tabular.TabularInferenceAPI(validate=False)

timeout = 60
api_response = {
Expand Down
4 changes: 2 additions & 2 deletions tests/gretel_client/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ def configure_session_client(request):
configurator = configure_hybrid_session
configure_kwargs = hybrid.kwargs or {}

configurator(config, **configure_kwargs)
configurator(config, validate=False, **configure_kwargs)

yield
configure_session(_load_config())
configure_session(_load_config(), validate=False)


@pytest.fixture
Expand Down
18 changes: 10 additions & 8 deletions tests/gretel_client/integration/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ def clear_session_config():
an existing Gretel config.
"""
with patch.dict(os.environ, {}, clear=True):
configure_session(ClientConfig())
configure_session(ClientConfig(), validate=False)
try:
yield
finally:
configure_session(_load_config())
configure_session(_load_config(), validate=False)


def test_cli(runner):
Expand All @@ -45,7 +45,9 @@ def test_cli(runner):

@patch("gretel_client.cli.cli.write_config")
def test_cli_does_configure(write_config: MagicMock, runner: CliRunner):
cmd = runner.invoke(cli, ["configure"], input="\n\n\ngrtu...\n\n")
cmd = runner.invoke(
cli, ["configure", "--skip-validate"], input="\n\n\ngrtu...\n\n"
)
assert not cmd.exception
write_config.assert_called_once_with(
ClientConfig(
Expand All @@ -65,7 +67,7 @@ def test_cli_does_configure_with_project(
with clear_session_config():
cmd = runner.invoke(
cli,
["configure"],
["configure", "--skip-validate"],
input=f"https://api-dev.gretel.cloud\n\n\n{os.getenv(GRETEL_API_KEY)}\n{project.name}\n",
catch_exceptions=True,
)
Expand All @@ -88,7 +90,7 @@ def test_cli_does_configure_with_custom_artifact_endpoint_and_hybrid_runner(
with clear_session_config():
cmd = runner.invoke(
cli,
["configure"],
["configure", "--skip-validate"],
input=f"https://api-dev.gretel.cloud\ns3://my-bucket\nhybrid\n{os.getenv(GRETEL_API_KEY)}\n\n",
catch_exceptions=True,
)
Expand All @@ -112,7 +114,7 @@ def test_cli_fails_configure_with_custom_artifact_endpoint_and_default_cloud_run
with clear_session_config():
cmd = runner.invoke(
cli,
["configure"],
["configure", "--skip-validate"],
input=f"https://api-dev.gretel.cloud\ns3://my-bucket\n\n{os.getenv(GRETEL_API_KEY)}\n\n",
catch_exceptions=True,
)
Expand All @@ -127,7 +129,7 @@ def test_cli_does_pass_configure_with_bad_project(
with clear_session_config():
cmd = runner.invoke(
cli,
["configure"],
["configure", "--skip-validate"],
input=f"{DEFAULT_GRETEL_ENDPOINT}\n\n{os.getenv(GRETEL_API_KEY)}\nbad-project-key\n",
catch_exceptions=True,
)
Expand All @@ -144,7 +146,7 @@ def test_missing_api_key(runner: CliRunner):

def test_invalid_api_key(runner: CliRunner):
with clear_session_config():
configure_session(ClientConfig(api_key="invalid"))
configure_session(ClientConfig(api_key="invalid"), validate=False)

cmd = runner.invoke(cli, ["projects", "create", "--name", "foo"])
assert cmd.exit_code == 1
Expand Down
2 changes: 2 additions & 0 deletions tests/gretel_client/integration/test_gretel_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def gretel() -> Gretel:
project_name=f"pytest-tabular-{uuid.uuid4().hex[:8]}",
api_key=os.getenv("GRETEL_API_KEY"),
endpoint="https://api-dev.gretel.cloud",
validate=False,
)
yield gretel
gretel._project.delete()
Expand Down Expand Up @@ -161,6 +162,7 @@ def test_gretel_no_project_set_exceptions():
gretel = Gretel(
api_key=os.getenv("GRETEL_API_KEY"),
endpoint="https://api-dev.gretel.cloud",
validate=False,
)

assert gretel._project is None
Expand Down
1 change: 1 addition & 0 deletions tests/gretel_client/integration/test_gretel_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def gretel() -> Gretel:
project_name=f"pytest-timeseries-{uuid.uuid4().hex[:8]}",
api_key=os.getenv("GRETEL_API_KEY"),
endpoint="https://api-dev.gretel.cloud",
validate=False,
)
yield gretel
gretel._project.delete()
Expand Down
10 changes: 7 additions & 3 deletions tests/gretel_client/integration/test_inference_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,18 @@
@pytest.fixture(scope="module")
def llm():
return NaturalLanguageInferenceAPI(
api_key=os.getenv("GRETEL_API_KEY"), endpoint="https://api-dev.gretel.cloud"
api_key=os.getenv("GRETEL_API_KEY"),
endpoint="https://api-dev.gretel.cloud",
validate=False,
)


@pytest.fixture(scope="module")
def nav():
return TabularInferenceAPI(
api_key=os.getenv("GRETEL_API_KEY"), endpoint="https://api-dev.gretel.cloud"
api_key=os.getenv("GRETEL_API_KEY"),
endpoint="https://api-dev.gretel.cloud",
validate=False,
)


Expand Down Expand Up @@ -142,7 +146,7 @@ def test_nav_inference_api_edit_stream(nav):

def test_nav_inference_api_invalid_backend_model():
with pytest.raises(GretelInferenceAPIError):
TabularInferenceAPI(backend_model="invalid_model")
TabularInferenceAPI(backend_model="invalid_model", skip_configure_session=True)


def test_nav_inference_api_edit_invalid_seed_data_type(nav):
Expand Down
1 change: 1 addition & 0 deletions tests/gretel_client/integration/test_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def gretel() -> Gretel:
project_name=f"pytest-tuner-{uuid.uuid4().hex[:8]}",
api_key=os.getenv("GRETEL_API_KEY"),
endpoint="https://api-dev.gretel.cloud",
validate=False,
)
yield gretel
gretel.get_project().delete()
Expand Down
Loading

0 comments on commit 590a543

Please sign in to comment.