Skip to content

Commit

Permalink
PLAT-1895: support for combined models image in job polling apis
Browse files Browse the repository at this point in the history
Co-authored-by: Drew Newberry <drew@gretel.ai>
GitOrigin-RevId: 622447e2fa66edaa702c71ce4c9e0137ff8e8415
  • Loading branch information
benmccown and drew committed Jun 21, 2024
1 parent 0a2e98f commit 87cdb20
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 1 deletion.
4 changes: 3 additions & 1 deletion docs/rest/JobsApi.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,13 @@ with gretel_client.rest.ApiClient(configuration) as api_client:
] # [str] | (optional)
org_only = True # bool | Query for jobs within the same organization only (optional)
cluster_guid = "cluster_guid_example" # str | GUID of the cluster for which to retrieve jobs (optional)
use_combined_models_image = False # bool | If True, the jobs' container_image field is set to the combined models image (optional); if omitted, the server uses the default value of False

# example passing only required values which don't have defaults set
# and optional values
try:
# Get Gretel job for scheduling
api_response = api_instance.receive_one(project_id=project_id, project_ids=project_ids, runner_modes=runner_modes, org_only=org_only, cluster_guid=cluster_guid)
api_response = api_instance.receive_one(project_id=project_id, project_ids=project_ids, runner_modes=runner_modes, org_only=org_only, cluster_guid=cluster_guid, use_combined_models_image=use_combined_models_image)
pprint(api_response)
except gretel_client.rest.ApiException as e:
print("Exception when calling JobsApi->receive_one: %s\n" % e)
Expand All @@ -71,6 +72,7 @@ Name | Type | Description | Notes
**runner_modes** | **[str]**| | [optional]
**org_only** | **bool**| Query for jobs within the same organization only | [optional]
**cluster_guid** | **str**| GUID of the cluster for which to retrieve jobs | [optional]
**use_combined_models_image** | **bool**| If True, the jobs' container_image field is set to the combined models image | [optional]; if omitted, the server uses the default value of False

### Return type

Expand Down
7 changes: 7 additions & 0 deletions src/gretel_client/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
_CLUSTERID_HEADER_KEY = "X-Gretel-Clusterid"
_APP_VERSION_KEY = "X-Gretel-AppVersion"
_IMAGE_VERSION_KEY = "X-Gretel-ImageVersion"
USE_COMBINED_MODELS_IMAGE_ENV_NAME = "USE_COMBINED_MODELS_IMAGE"


class AgentError(Exception): ...
Expand Down Expand Up @@ -279,6 +280,7 @@ class Job:
container_image: str
worker_token: str
instance_type: str
model_type: str
log: Optional[Callable] = None
cloud_creds: Optional[List[CloudCreds]] = None
artifact_endpoint: Optional[str] = None
Expand All @@ -294,6 +296,7 @@ def from_dict(cls, source: dict, agent_config: AgentConfig) -> Job:
job_type=source["job_type"],
instance_type=source["instance_type"],
container_image=source["container_image"],
model_type=source.get("model_type", ""),
worker_token=source["worker_token"],
log=agent_config.log_factory(
source.get("run_id") or source.get("model_id")
Expand Down Expand Up @@ -485,6 +488,10 @@ def poll_endpoint(self) -> Optional[Job]:
]
if self._agent_config.cluster_guid:
api_kwargs["cluster_guid"] = self._agent_config.cluster_guid
# This feature flag exists so that we stay backward compatible with hybrid deployments.
# In the future it can be hard-coded to True so that new releases get the updated combined image.
if os.getenv(USE_COMBINED_MODELS_IMAGE_ENV_NAME) == "true":
api_kwargs["use_combined_models_image"] = True
next_job = self._jobs_api.receive_one(**api_kwargs)
if next_job["data"]["job"] is not None:
return Job.from_dict(next_job["data"]["job"], self._agent_config)
Expand Down
7 changes: 7 additions & 0 deletions src/gretel_client/agents/drivers/k8s.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
CPU_MODEL_WORKER_RESOURCES_ENV_NAME = "GRETEL_MODEL_WORKER_CPU_RESOURCES"
GPU_MODEL_WORKER_RESOURCES_ENV_NAME = "GRETEL_MODEL_WORKER_GPU_RESOURCES"
WORKER_MEMORY_GB_ENV_NAME = "MEMORY_LIMIT_IN_GB"
GRETEL_MODEL_TYPE_ENV_NAME = "GRETEL_MODEL_TYPE"
PULL_SECRET_ENV_NAME = "GRETEL_PULL_SECRET"
PULL_SECRETS_ENV_NAME = "GRETEL_PULL_SECRETS"
GPU_NODE_SELECTOR_ENV_NAME = "GPU_NODE_SELECTOR"
Expand Down Expand Up @@ -547,6 +548,12 @@ def _setup_environment_variables(
)
env.append(client.V1EnvVar(name="GRETEL_STAGE", value=job_config.gretel_stage))

env.append(
client.V1EnvVar(
name=GRETEL_MODEL_TYPE_ENV_NAME, value=job_config.model_type
)
)

if cpu_limit := resources.limits.get("cpu"):
cpu_quantity = parse_quantity(cpu_limit)
cpu_count = max(math.floor(cpu_quantity), 1)
Expand Down
5 changes: 5 additions & 0 deletions src/gretel_client/rest/api/jobs_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __receive_one(self, **kwargs):
runner_modes ([str]): [optional]
org_only (bool): Query for jobs within the same organization only. [optional]
cluster_guid (str): GUID of the cluster for which to retrieve jobs. [optional]
use_combined_models_image (bool): If True, the jobs' container_image field is set to the combined models image. [optional]; if omitted, the server uses the default value of False
_return_http_data_only (bool): response data without head status
code and headers. Default is True.
_preload_content (bool): if False, the urllib3.HTTPResponse object
Expand Down Expand Up @@ -117,6 +118,7 @@ def __receive_one(self, **kwargs):
"runner_modes",
"org_only",
"cluster_guid",
"use_combined_models_image",
],
"required": [],
"nullable": [],
Expand All @@ -140,20 +142,23 @@ def __receive_one(self, **kwargs):
"runner_modes": ([str],),
"org_only": (bool,),
"cluster_guid": (str,),
"use_combined_models_image": (bool,),
},
"attribute_map": {
"project_id": "project_id",
"project_ids": "project_ids",
"runner_modes": "runner_modes",
"org_only": "org_only",
"cluster_guid": "cluster_guid",
"use_combined_models_image": "use_combined_models_image",
},
"location_map": {
"project_id": "query",
"project_ids": "query",
"runner_modes": "query",
"org_only": "query",
"cluster_guid": "query",
"use_combined_models_image": "query",
},
"collection_format_map": {
"project_ids": "csv",
Expand Down

0 comments on commit 87cdb20

Please sign in to comment.