From 3ac9943a5c2be62eb361807fd9ad8f9c18b9f40a Mon Sep 17 00:00:00 2001 From: Gretel Team Date: Thu, 10 Oct 2024 12:05:04 -0400 Subject: [PATCH] Project import generated by Copybara. GitOrigin-RevId: 16bf3a4b2bf28a0d0cb40d681e08d05153463f61 --- src/gretel_client/inference_api/base.py | 3 +- src/gretel_client/navigator/__init__.py | 1 + .../navigator/blueprints/__init__.py | 3 + .../navigator/blueprints/base.py | 16 ++ .../navigator/blueprints/config.py | 0 .../blueprints/text_to_code/__init__.py | 0 .../blueprints/text_to_code/blueprint.py | 157 ++++++++++++ .../text_to_code/prompt_templates.py | 49 ++++ .../blueprints/text_to_code/utils.py | 50 ++++ .../navigator/client/__init__.py | 11 + .../navigator/client/interface.py | 136 +++++++++++ src/gretel_client/navigator/client/remote.py | 225 ++++++++++++++++++ src/gretel_client/navigator/log.py | 17 ++ src/gretel_client/navigator/tasks/__init__.py | 10 + src/gretel_client/navigator/tasks/base.py | 78 ++++++ .../navigator/tasks/generate/__init__.py | 0 .../generate/generate_column_from_template.py | 60 +++++ .../generate/generate_contextual_tags.py | 67 ++++++ src/gretel_client/navigator/tasks/io.py | 3 + .../navigator/tasks/seed/__init__.py | 0 .../tasks/seed/seed_from_contextual_tags.py | 27 +++ .../navigator/tasks/seed/seed_from_records.py | 24 ++ src/gretel_client/navigator/workflow.py | 176 ++++++++++++++ 23 files changed, 1111 insertions(+), 2 deletions(-) create mode 100644 src/gretel_client/navigator/__init__.py create mode 100644 src/gretel_client/navigator/blueprints/__init__.py create mode 100644 src/gretel_client/navigator/blueprints/base.py create mode 100644 src/gretel_client/navigator/blueprints/config.py create mode 100644 src/gretel_client/navigator/blueprints/text_to_code/__init__.py create mode 100644 src/gretel_client/navigator/blueprints/text_to_code/blueprint.py create mode 100644 src/gretel_client/navigator/blueprints/text_to_code/prompt_templates.py create mode 100644 
src/gretel_client/navigator/blueprints/text_to_code/utils.py create mode 100644 src/gretel_client/navigator/client/__init__.py create mode 100644 src/gretel_client/navigator/client/interface.py create mode 100644 src/gretel_client/navigator/client/remote.py create mode 100644 src/gretel_client/navigator/log.py create mode 100644 src/gretel_client/navigator/tasks/__init__.py create mode 100644 src/gretel_client/navigator/tasks/base.py create mode 100644 src/gretel_client/navigator/tasks/generate/__init__.py create mode 100644 src/gretel_client/navigator/tasks/generate/generate_column_from_template.py create mode 100644 src/gretel_client/navigator/tasks/generate/generate_contextual_tags.py create mode 100644 src/gretel_client/navigator/tasks/io.py create mode 100644 src/gretel_client/navigator/tasks/seed/__init__.py create mode 100644 src/gretel_client/navigator/tasks/seed/seed_from_contextual_tags.py create mode 100644 src/gretel_client/navigator/tasks/seed/seed_from_records.py create mode 100644 src/gretel_client/navigator/workflow.py diff --git a/src/gretel_client/inference_api/base.py b/src/gretel_client/inference_api/base.py index 365da5c..be70f21 100644 --- a/src/gretel_client/inference_api/base.py +++ b/src/gretel_client/inference_api/base.py @@ -8,7 +8,6 @@ from gretel_client.config import ClientConfig, configure_session, get_session_config from gretel_client.rest.api_client import ApiClient -from gretel_client.rest.configuration import Configuration MODELS_API_PATH = "/v1/inference/models" @@ -161,7 +160,7 @@ def __init__( elif len(session_kwargs) > 0: raise ValueError("cannot specify session arguments when passing a session") - if session.default_runner != "cloud" and not ".serverless." in session.endpoint: + if session.default_runner != "cloud" and ".serverless." not in session.endpoint: raise GretelInferenceAPIError( "Gretel's Inference API is currently only " "available within Gretel Cloud. 
Your current runner " diff --git a/src/gretel_client/navigator/__init__.py b/src/gretel_client/navigator/__init__.py new file mode 100644 index 0000000..fc2c8e3 --- /dev/null +++ b/src/gretel_client/navigator/__init__.py @@ -0,0 +1 @@ +from gretel_client.navigator.workflow import NavigatorWorkflow diff --git a/src/gretel_client/navigator/blueprints/__init__.py b/src/gretel_client/navigator/blueprints/__init__.py new file mode 100644 index 0000000..c13f183 --- /dev/null +++ b/src/gretel_client/navigator/blueprints/__init__.py @@ -0,0 +1,3 @@ +from gretel_client.navigator.blueprints.text_to_code.blueprint import ( + TextToCodeBlueprint, +) diff --git a/src/gretel_client/navigator/blueprints/base.py b/src/gretel_client/navigator/blueprints/base.py new file mode 100644 index 0000000..b29dfc1 --- /dev/null +++ b/src/gretel_client/navigator/blueprints/base.py @@ -0,0 +1,16 @@ +from abc import ABC + + +class NavigatorBlueprint(ABC): + """Base class for all blueprint classes.""" + + @property + def name(self) -> str: + """The name of the blueprint.""" + return self.__class__.__name__ + + def __str__(self) -> str: + return self.name + + def __repr__(self) -> str: + return f"<{self.name}>" diff --git a/src/gretel_client/navigator/blueprints/config.py b/src/gretel_client/navigator/blueprints/config.py new file mode 100644 index 0000000..e69de29 diff --git a/src/gretel_client/navigator/blueprints/text_to_code/__init__.py b/src/gretel_client/navigator/blueprints/text_to_code/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/gretel_client/navigator/blueprints/text_to_code/blueprint.py b/src/gretel_client/navigator/blueprints/text_to_code/blueprint.py new file mode 100644 index 0000000..476daea --- /dev/null +++ b/src/gretel_client/navigator/blueprints/text_to_code/blueprint.py @@ -0,0 +1,157 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Union + +from gretel_client.gretel.config_setup import smart_load_yaml 
+from gretel_client.navigator.blueprints.base import NavigatorBlueprint +from gretel_client.navigator.blueprints.text_to_code.prompt_templates import ( + CODE_PROMPT, + FIELD_GENERATION_PROMPT, + TEXT_PROMPT, +) +from gretel_client.navigator.blueprints.text_to_code.utils import display_nl2code_sample +from gretel_client.navigator.tasks import ( + GenerateColumnFromTemplate, + GenerateContextualTags, + SeedFromContextualTags, +) +from gretel_client.navigator.tasks.base import Task +from gretel_client.navigator.tasks.io import Dataset +from gretel_client.navigator.workflow import NavigatorWorkflow + +output_parser_instructions = { + "pass_through": "* Return only the requested text, without any additional comments or instructions.", + "json_array": "* Respond only with the list as a valid JSON array.", +} + +output_parser_type_map = { + "str": "pass_through", + "string": "pass_through", + "text": "pass_through", + "json": "json", + "dict": "json", + "list": "json_array", + "json_array": "json_array", + "code": "extract_code", +} + + +@dataclass +class DataPreview: + dataset: Dataset + contextual_columns: list[dict] + blueprint_config: dict + tag_values: dict + + def display_sample(self, index: Optional[int] = None, **kwargs): + if index is None: + record = self.dataset.sample(1).iloc[0] + else: + record = self.dataset.loc[index] + display_nl2code_sample( + lang=self.blueprint_config["programming_language"], + record=record, + contextual_columns=self.contextual_columns, + **kwargs, + ) + + +class TextToCodeBlueprint(NavigatorBlueprint): + + def __init__(self, config: Union[str, dict, Path], **session_kwargs): + self.config = smart_load_yaml(config) + self.lang = self.config["programming_language"] + self.task_list = self._build_sequential_task_list() + self.workflow = NavigatorWorkflow.from_sequential_tasks( + self.task_list, **session_kwargs + ) + + def _create_context_template(self, columns: list) -> str: + return "\n".join( + [f" * {c.replace('_', ' 
').capitalize()}: {{{c}}}" for c in columns] + ) + + def _create_contextual_column_task(self, field) -> Task: + output_parser = output_parser_type_map[field["column_type"]] + generation_type = "text" if field["llm_type"] == "nl" else "code" + system_prompt = self.config[f"{generation_type}_generation_instructions"] + return GenerateColumnFromTemplate( + prompt_template=FIELD_GENERATION_PROMPT.format( + name=field["name"], + description=field["description"], + context=self._create_context_template(field["relevant_columns"]), + generation_type=generation_type.capitalize(), + parser_instructions=output_parser_instructions[output_parser], + ), + response_column_name=field["name"], + system_prompt=system_prompt, + workflow_label=f"{field['name'].replace('_', ' ')}", + llm_type=field["llm_type"], + output_parser=output_parser, + ) + + def _build_sequential_task_list(self) -> list[Task]: + additional_context_columns = [] + for field in self.config.get("additional_contextual_columns", []): + additional_context_columns.append( + self._create_contextual_column_task(field) + ) + + generate_text_column = GenerateColumnFromTemplate( + prompt_template=TEXT_PROMPT.format( + lang=self.lang, + context=self._create_context_template( + self.config["text_relevant_columns"] + ), + ), + llm_type="nl", + response_column_name="text", + system_prompt=self.config["text_generation_instructions"], + workflow_label="text prompt", + ) + + generate_code_column = GenerateColumnFromTemplate( + prompt_template=CODE_PROMPT.format( + lang=self.lang, + context=self._create_context_template( + self.config["code_relevant_columns"] + ), + ), + llm_type="nl", + response_column_name="code", + system_prompt=self.config["code_generation_instructions"], + workflow_label="code prompt", + output_parser="extract_code", + ) + + return [ + GenerateContextualTags(**self.config["contextual_tags"]), + SeedFromContextualTags(), + *additional_context_columns, + generate_text_column, + generate_code_column, + ] + + def 
generate_dataset_preview(self) -> DataPreview: + results = self.workflow.generate_dataset_preview() + + tags = {} + for tag in results.auxiliary_outputs[0]["tags"]: + tags[tag["name"]] = tag["seed_values"] + tag["generated_values"] + + additional_context = self.config.get("additional_contextual_columns", []) + context_cols = [tag["name"] for tag in self.config["contextual_tags"]["tags"]] + return DataPreview( + dataset=results.dataset, + contextual_columns=context_cols + + [field["name"] for field in additional_context], + blueprint_config=self.config, + tag_values=tags, + ) + + def submit_batch_job( + self, num_records: int, project_name: Optional[str] = None + ) -> None: + self.workflow.submit_batch_job( + num_records=num_records, project_name=project_name + ) diff --git a/src/gretel_client/navigator/blueprints/text_to_code/prompt_templates.py b/src/gretel_client/navigator/blueprints/text_to_code/prompt_templates.py new file mode 100644 index 0000000..4ff6035 --- /dev/null +++ b/src/gretel_client/navigator/blueprints/text_to_code/prompt_templates.py @@ -0,0 +1,49 @@ +TEXT_PROMPT = """\ +Your task is to generate the natural language component of a text-to-{lang} dataset, \ +carefully following the given context and instructions. + +### Context: +{context} + +### Instructions: + * Generate text related to {lang} code based on the given context. + * Do NOT return any code in the response. + * Return only the requested text, without any additional comments or instructions. + +### Text: +""" + + +CODE_PROMPT = """\ +Your task is to generate {lang} code that corresponds to the text and context given below. + +### Text: +{{text}} + +### Context: +{context} + +### Instructions: + * Remember to base your response on the given context. + * Include ONLY a SINGLE block of code WITHOUT ANY additional text. + +### Code: +""" + + +FIELD_GENERATION_PROMPT = """\ +Your task is to generate a `{name}` field in a dataset based on the given description and context. 
+ +### Description: +{description} + +### Context: +{context} + +### Instructions: + * Generate `{name}` as described above. + * Remember to base your response on the given context. + {parser_instructions} + +### Response: +""" diff --git a/src/gretel_client/navigator/blueprints/text_to_code/utils.py b/src/gretel_client/navigator/blueprints/text_to_code/utils.py new file mode 100644 index 0000000..4daccbb --- /dev/null +++ b/src/gretel_client/navigator/blueprints/text_to_code/utils.py @@ -0,0 +1,50 @@ +from typing import Optional, Union + +import pandas as pd + +from rich.console import Console +from rich.panel import Panel +from rich.syntax import Syntax +from rich.table import Table +from rich.text import Text + +console = Console() + + +def display_nl2code_sample( + lang: str, + record: Union[dict, pd.Series], + contextual_columns: list[str], + theme: str = "dracula", + background_color: Optional[str] = None, +): + if isinstance(record, (dict, pd.Series)): + record = pd.DataFrame([record]).iloc[0] + else: + raise ValueError("record must be a dictionary or pandas Series") + + table = Table(title="Contextual Columns") + + for col in contextual_columns: + table.add_column(col.replace("_", " ").capitalize()) + table.add_row(*[str(record[col]) for col in contextual_columns]) + + console.print(table) + + panel = Panel( + Text(record.text, justify="left", overflow="fold"), + title="Text", + ) + console.print(panel) + + panel = Panel( + Syntax( + record.code, + lexer=lang.lower(), + theme=theme, + word_wrap=True, + background_color=background_color, + ), + title="Code", + ) + console.print(panel) diff --git a/src/gretel_client/navigator/client/__init__.py b/src/gretel_client/navigator/client/__init__.py new file mode 100644 index 0000000..f7b61e5 --- /dev/null +++ b/src/gretel_client/navigator/client/__init__.py @@ -0,0 +1,11 @@ +from gretel_client.config import configure_session +from gretel_client.navigator.client.interface import Client, ClientAdapter +from 
def get_navigator_client(
    client_adapter: "Optional[ClientAdapter]" = None, **session_kwargs
) -> Client:
    """Configure a Gretel session and return a Navigator client.

    Args:
        client_adapter: Adapter implementing the transport layer. Defaults to
            a ``RemoteClient`` created lazily at call time.
        session_kwargs: Forwarded to ``configure_session`` (e.g. api_key,
            endpoint). ``validate`` defaults to False when not given.

    Returns:
        A ``Client`` wrapping the adapter.
    """
    # Bug fix 1: pop (not get) `validate` so it is not passed to
    # configure_session twice, which raised
    # "TypeError: got multiple values for keyword argument 'validate'".
    validate = session_kwargs.pop("validate", False)
    configure_session(validate=validate, **session_kwargs)
    # Bug fix 2: previously `RemoteClient()` was the *default argument*,
    # instantiated at import time — an import side effect that required a
    # configured session and printed to the console. Create it lazily instead.
    if client_adapter is None:
        client_adapter = RemoteClient()
    return Client(client_adapter)
+ + def submit_batch_workflow( + self, + workflow_config: dict, + num_records: int, + project_name: Optional[str] = None, + ): + raise NotImplementedError("Cannot submit batch Workflows") + + +class TaskOutput(ABC): + """ + Abstract TaskOutput class that represents the output of a task. + Task output (regardless of the client) is always a stream, so one way of consuming + this output is to iterate over it (`__iter__()`). + + Additionally, when the output is consumed, data outputs and attributes are captured + and can be retrieved with `data_outputs()` and `attribute_outputs()` methods. + + Note: if the stream wasn't consumed yet, calling these methods will consume the stream. + """ + + def __init__(self): + self._consumed = False + + def _ensure_consumed(self) -> None: + if not self._consumed: + self._consume() + + self._consumed = True + + def _consume(self) -> None: + if self._consumed: + return + + # exhaust the iterator, without doing anything with the records + for _ in self: + pass + + @abstractmethod + def as_input(self) -> list[TaskInput]: + """ + Converts this output to inputs that can be passed to other tasks. + """ + ... + + @abstractmethod + def data_outputs(self) -> list[pd.DataFrame]: ... + + @abstractmethod + def attribute_outputs(self) -> list[dict]: ... + + @abstractmethod + def __iter__(self) -> Iterator: ... 
+ + +class StructuredInput(BaseModel): + dataset: Optional[list[dict]] = None + attributes: list[dict] = Field(default_factory=list) + + def serialize(self) -> dict: + return self.model_dump(exclude_none=True) + + +class TaskInput(BaseModel): + raw_data: Optional[bytes] = None + structured_data: Optional[StructuredInput] = None + + @classmethod + def from_dataset(cls, dataset: pd.DataFrame) -> TaskInput: + return cls( + structured_data=StructuredInput(dataset=dataset.to_dict(orient="records")) + ) + + @classmethod + def from_attribute(cls, name: str, value: object) -> TaskInput: + return cls(structured_data=StructuredInput(attributes=[{name: value}])) diff --git a/src/gretel_client/navigator/client/remote.py b/src/gretel_client/navigator/client/remote.py new file mode 100644 index 0000000..24cb064 --- /dev/null +++ b/src/gretel_client/navigator/client/remote.py @@ -0,0 +1,225 @@ +from __future__ import annotations + +import json +import logging + +from dataclasses import dataclass +from typing import Iterator, Optional, Union + +import pandas as pd +import requests + +from pydantic import BaseModel +from requests import Response +from rich import print + +from gretel_client.config import get_session_config +from gretel_client.gretel.interface import Gretel +from gretel_client.navigator.client.interface import ( + ClientAdapter, + StructuredInput, + TaskInput, + TaskOutput, +) +from gretel_client.navigator.log import get_logger + +DATA_FRAME_OUTPUT_TYPE = "data_frame" +STREAMING_RECORD_OUTPUT_TYPE = "streaming_record" +LOG_OUTPUT_TYPE = "log_line" +STEP_STATE_CHANGE_TYPE = "step_state_change" + +NON_ATTRIBUTE_OUTPUT_TYPES = { + DATA_FRAME_OUTPUT_TYPE, + STREAMING_RECORD_OUTPUT_TYPE, + LOG_OUTPUT_TYPE, + STEP_STATE_CHANGE_TYPE, +} + +gretel_interface_logger = logging.getLogger("gretel_client.gretel.interface") +gretel_interface_logger.setLevel(logging.WARNING) + +logger = get_logger(__name__, level="INFO") + + +@dataclass +class AttributeOutput: + name: str + data: 
class StepOutput(BaseModel):
    """A single NDJSON record streamed back from a remote workflow/task step.

    Attributes:
        step: Name of the workflow step that produced this record.
        type: Output type discriminator (see the *_OUTPUT_TYPE constants).
        output: Payload.
            - list is used for data_frame outputs
            - dict is used for streaming_record and attribute outputs
    """

    step: str
    type: str
    output: Union[list, dict]

    def is_data_frame(self) -> bool:
        return self.type == DATA_FRAME_OUTPUT_TYPE

    def is_attribute(self) -> bool:
        # Attributes are the open-ended set: anything that is not one of the
        # known structural/log output types.
        return self.type not in NON_ATTRIBUTE_OUTPUT_TYPES

    def is_streaming_record(self) -> bool:
        return self.type == STREAMING_RECORD_OUTPUT_TYPE

    def is_log(self) -> bool:
        return self.type == LOG_OUTPUT_TYPE

    def is_step_state_change(self) -> bool:
        # Bug fix: this previously compared against LOG_OUTPUT_TYPE
        # (copy-paste from is_log), so state-change records were
        # misclassified as log lines and vice versa.
        return self.type == STEP_STATE_CHANGE_TYPE
+ if step_output.is_data_frame(): + self._data_outputs.append(pd.DataFrame.from_records(step_output.output)) + elif step_output.is_attribute(): + self._attributes.append( + AttributeOutput(name=step_output.type, data=step_output.output) + ) + + return step_output + + def data_outputs(self) -> list[pd.DataFrame]: + self._ensure_consumed() + return self._data_outputs + + def attribute_outputs(self) -> list[dict]: + self._ensure_consumed() + return [{attr.name: attr.data} for attr in self._attributes] + + def as_input(self) -> list[TaskInput]: + inputs = [] + for output in self.data_outputs(): + inputs.append( + TaskInput( + structured_data=StructuredInput( + dataset=output.to_dict(orient="records") + ), + ) + ) + + attributes = [{output.name: output.data} for output in self._attributes] + if attributes: + inputs.append( + TaskInput(structured_data=StructuredInput(attributes=attributes)) + ) + + return inputs + + def __iter__(self) -> Iterator: + for json_str in self._response.iter_lines(decode_unicode=True): + try: + yield self._consume_single_output(json.loads(json_str)) + except json.JSONDecodeError: + logger.error(f"Failed to decode JSON record: {json_str!r}") + + self._consumed = True + self._response.close() + + +class RemoteClient(ClientAdapter): + + def __init__(self, jarvis_endpoint: str = "https://jarvis.dev.gretel.cloud"): + self._session = get_session_config() + self._req_headers = {"Authorization": self._session.api_key} + self._jarvis_endpoint = jarvis_endpoint + + print(f"🌎 Connecting to {self._jarvis_endpoint}") + + def run_task(self, name: str, config: dict, inputs: list[TaskInput]) -> TaskOutput: + if config is None: + config = {} + if inputs is None: + inputs = [] + + inputs_as_json = [] + for _input in inputs: + if _input.raw_data: + raise NotImplementedError( + "RemoteClient doesn't support raw data inputs." 
+ ) + inputs_as_json.append(_input.structured_data.serialize()) + + response = requests.post( + f"{self._jarvis_endpoint}/tasks/exec", + json={"name": name, "config": config, "inputs": inputs_as_json}, + headers=self._req_headers, + stream=True, + ) + response.raise_for_status() + + return RemoteTaskOutput(response) + + def stream_workflow_outputs(self, workflow: dict) -> Iterator: + with requests.post( + f"{self._jarvis_endpoint}/workflows/exec_streaming", + json=workflow, + headers=self._req_headers, + stream=True, + ) as outputs: + outputs.raise_for_status() + + for output in outputs.iter_lines(): + yield json.loads(output.decode("utf-8")) + + def submit_batch_workflow( + self, + workflow_config: dict, + num_records: int, + project_name: Optional[str] = None, + ) -> dict: + + for step in workflow_config["steps"]: + if "num_records" in step["config"]: + step["config"]["num_records"] = num_records + + gretel = Gretel(session=self._session) + gretel.set_project(name=project_name) + project = gretel.get_project() + + logger.info( + f"🔗 Connecting to your [link={project.get_console_url()}]Gretel Project[/link]", + extra={"markup": True}, + ) + logger.info(f"🚀 Submitting batch workflow to generate {num_records} records") + + response = requests.post( + f"{self._jarvis_endpoint}/workflows/exec_batch", + json={ + "workflow_config": workflow_config, + "project_id": project.project_guid, + }, + headers=self._req_headers, + ) + response.raise_for_status() + workflow_ids = response.json() + workflow_run_url = ( + f"{project.get_console_url().replace(project.project_guid, '')}workflows/" + f"{workflow_ids['workflow_id']}/runs/{workflow_ids['workflow_run_id']}" + ) + + logger.info( + f"👀 Follow along: [link={workflow_run_url}]Workflow Run[/link]", + extra={"markup": True}, + ) + + def registry(self) -> list[dict]: + response = requests.get( + f"{self._jarvis_endpoint}/registry", headers=self._req_headers + ) + response.raise_for_status() + + return response.json()["tasks"] 
def get_logger(name: str, *, level: "Union[int, str]" = logging.INFO) -> logging.Logger:
    """Return a logger with rich formatting at the requested level.

    Args:
        name: Logger name, typically ``__name__``.
        level: Logging level as an int (e.g. ``logging.DEBUG``) or a level
            name string (e.g. ``"DEBUG"``); ``Logger.setLevel`` accepts both.
            Callers in this package pass strings such as ``level="INFO"``.

    Returns:
        The configured logger (non-propagating, single RichHandler).
    """
    logger = logging.getLogger(name)
    logger.propagate = False
    # Avoid stacking duplicate handlers when get_logger is called more than
    # once for the same name (e.g. once per importing module).
    if not logger.handlers:
        rich_handler = RichHandler(
            console=Console(theme=Theme({"logging.level.info": "green"}))
        )
        rich_handler.setFormatter(logging.Formatter("%(message)s", datefmt="[%X]"))
        logger.addHandler(rich_handler)
    # Bug fix: the `level` argument was previously ignored — the level was
    # hard-coded to logging.INFO, so get_logger(..., level="DEBUG") had no
    # effect.
    logger.setLevel(level)
    return logger
TaskResults: + dataset: Optional[Dataset] = None + attributes: Optional[list[dict]] = None + + +class Task(ABC): + + def __init__(self, config: BaseModel, workflow_label: Optional[str] = None): + self.config = config + self.workflow_label = workflow_label + self._client = get_navigator_client() + + @staticmethod + def _create_task_inputs( + dataset: Optional[Dataset] = None, attributes: Optional[list[dict]] = None + ) -> list[TaskInput]: + if dataset is None and attributes is None: + return [] + structured_data = StructuredInput( + dataset=( + None + if dataset is None + else ( + dataset.to_dict(orient="records") + if isinstance(dataset, Dataset) + else dataset + ) + ), + attributes=attributes or [], + ) + return [TaskInput(structured_data=structured_data)] + + def _set_client(self, adapter: ClientAdapter): + """Set client adapter for task execution. + + This is an internal method that is not useable by end users. + """ + self._client = get_navigator_client(adapter) + + def _run( + self, + dataset: Optional[Union[Dataset, list[dict]]] = None, + attributes: Optional[list[dict]] = None, + ) -> TaskResults: + output = self._client.run_task( + name=self.name, + config=self.config.model_dump(), + inputs=self._create_task_inputs(dataset, attributes), + ) + return TaskResults( + dataset=( + pd.concat(out, axis=0, ignore_index=True) + if (out := output.data_outputs()) + else None + ), + attributes=output.attribute_outputs(), + ) + + @property + @abstractmethod + def name(self) -> str: ... + + @abstractmethod + def run(self) -> TaskResults: ... 
diff --git a/src/gretel_client/navigator/tasks/generate/__init__.py b/src/gretel_client/navigator/tasks/generate/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/gretel_client/navigator/tasks/generate/generate_column_from_template.py b/src/gretel_client/navigator/tasks/generate/generate_column_from_template.py new file mode 100644 index 0000000..f7903a3 --- /dev/null +++ b/src/gretel_client/navigator/tasks/generate/generate_column_from_template.py @@ -0,0 +1,60 @@ +from enum import Enum +from typing import Optional, Union + +from pydantic import BaseModel + +from gretel_client.navigator.tasks.base import Task, TaskResults +from gretel_client.navigator.tasks.io import Dataset + +DEFAULT_RESPONSE_COLUMN_NAME = "response" + + +class LLMType(str, Enum): + NL = "nl" + CODE = "code" + JUDGE = "judge" + + +class TextParserType(str, Enum): + EXTRACT_CODE = "extract_code" + JSON = "json" + JSON_ARRAY = "json_array" + PASS_THROUGH = "pass_through" + + +class GenerateColumnFromTemplateConfig(BaseModel): + prompt_template: str + response_column_name: str = DEFAULT_RESPONSE_COLUMN_NAME + output_parser: TextParserType = TextParserType.PASS_THROUGH + llm_type: LLMType = LLMType.NL + system_prompt: Optional[str] = None + + +class GenerateColumnFromTemplate(Task): + + def __init__( + self, + prompt_template: str, + response_column_name: str = DEFAULT_RESPONSE_COLUMN_NAME, + output_parser: TextParserType = TextParserType.PASS_THROUGH, + llm_type: LLMType = LLMType.NL, + system_prompt: Optional[str] = None, + workflow_label: Optional[str] = None, + ): + super().__init__( + config=GenerateColumnFromTemplateConfig( + prompt_template=prompt_template, + response_column_name=response_column_name, + output_parser=output_parser, + llm_type=llm_type, + system_prompt=system_prompt, + ), + workflow_label=workflow_label, + ) + + @property + def name(self) -> str: + return "generate_column_from_template" + + def run(self, template_kwargs: Union[Dataset, list[dict]]) -> 
def __init__(
    self,
    tags: Union[str, Path, list[dict], list[ContextualTag]],
    task_context: str,
    workflow_label: Optional[str] = None,
):
    """Create a contextual-tag generation task.

    Args:
        tags: Tags as a YAML path/string, a list of dicts, or a list of
            ContextualTag objects; normalized via ``_check_and_get_tags``.
        task_context: Free-text context describing the generation task.
        workflow_label: Optional label used when naming workflow steps.
    """
    # Bug fix: `workflow_label` belongs to the Task base class (it is read as
    # `task.workflow_label` when building workflow step names), but it was
    # previously passed into GenerateContextualTagsConfig — a model with no
    # such field — so it never reached Task and the label was silently lost.
    super().__init__(
        config=GenerateContextualTagsConfig(
            tags=self._check_and_get_tags(tags),
            task_context=task_context,
        ),
        workflow_label=workflow_label,
    )
def run(self, contextual_tags: list[dict]) -> TaskResults:
    """Run the seeding task in preview mode.

    Args:
        contextual_tags: Attribute dicts (tag outputs from an upstream
            contextual-tag step) used to seed the generated records.

    Returns:
        TaskResults with the seeded preview dataset.

    Raises:
        ValueError: If the configured num_records exceeds the preview
            limit of 10.
    """
    # Bug fix: the signature previously read `contextual_tags=list[dict]`,
    # which assigned the *type object* as a default value instead of
    # annotating the parameter. Also fixed the doubled "to to" typo in the
    # error message.
    if self.config.num_records > 10:
        raise ValueError("You can only preview up to 10 records at a time.")
    return self._run(attributes=contextual_tags)
# --- src/gretel_client/navigator/tasks/seed/seed_from_records.py ---

from typing import Optional

from pydantic import BaseModel

from gretel_client.navigator.tasks.base import Task, TaskResults


class SeedFromRecordsConfig(BaseModel):
    """Task configuration: the literal records to seed generation with."""

    records: list[dict]


class SeedFromRecords(Task):
    """Task that seeds a workflow from a fixed list of records.

    Args:
        records: The records to seed from, one dict per record.
        workflow_label: Optional label appended to the step name in a workflow.
    """

    def __init__(self, records: list[dict], workflow_label: Optional[str] = None):
        super().__init__(
            config=SeedFromRecordsConfig(records=records), workflow_label=workflow_label
        )

    @property
    def name(self) -> str:
        return "seed_from_records"

    def run(self) -> TaskResults:
        return self._run(attributes=[{"records": self.config.records}])


# --- src/gretel_client/navigator/workflow.py ---

import json

from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Optional, Union

import pandas as pd
import yaml

from pydantic import BaseModel
from typing_extensions import Self

from gretel_client.navigator.client import get_navigator_client
from gretel_client.navigator.log import get_logger
from gretel_client.navigator.tasks.base import Task

logger = get_logger(__name__, level="DEBUG")

DEFAULT_WORKFLOW_NAME = "navigator-workflow"

# Emoji shown next to a step's log line, keyed by task-name prefix.
TASK_TYPE_EMOJI_MAP = {
    "generate": "🦜",
    "seed": "🌱",
    "tags": "🏷️",
}


def _get_task_log_emoji(task_name: str) -> str:
    """Return the emoji (with a trailing space) for a task name, or ""."""
    log_emoji = ""
    for task_type, emoji in TASK_TYPE_EMOJI_MAP.items():
        if task_name.startswith(task_type):
            log_emoji = emoji + " "
    return log_emoji


@dataclass
class WorkflowResults:
    """Outputs of a workflow preview run."""

    # Concatenation of the final step's data-frame outputs.
    dataset: pd.DataFrame
    # Non-dataframe outputs emitted by earlier steps.
    auxiliary_outputs: Optional[list[dict]] = None


class Step(BaseModel):
    """A single step in a workflow definition.

    ``inputs`` lists the names of steps whose outputs feed this step.
    """

    name: Optional[str] = None
    task: str
    config: dict
    inputs: Optional[list[str]] = []


class NavigatorWorkflow:
    """Builds, serializes, previews, and submits Gretel Navigator workflows.

    Args:
        steps: Optional initial list of workflow steps.
        workflow_name: Name for the workflow; defaults to a timestamped name.
        **session_kwargs: Forwarded to the Navigator client factory.
    """

    def __init__(
        self,
        steps: Optional[list[Step]] = None,
        workflow_name: Optional[str] = None,
        **session_kwargs,
    ):
        self._workflow_name = (
            workflow_name
            or f"{DEFAULT_WORKFLOW_NAME}-{datetime.now().isoformat(timespec='seconds')}"
        )
        self._client = get_navigator_client(**session_kwargs)
        self._steps = steps or []

    @classmethod
    def from_sequential_tasks(
        cls, task_list: list[Task], workflow_name: Optional[str] = None
    ) -> Self:
        """Create a workflow whose steps run in order, each feeding the next.

        Step names are derived from task names (1-indexed, dash-separated),
        with the task's workflow_label appended when present.

        BUG FIX: ``workflow_name`` was annotated ``str = None``; the correct
        annotation for a None default is ``Optional[str]``.
        """
        step_names = []
        workflow = cls(workflow_name=workflow_name)
        for i, task in enumerate(task_list):
            suffix = "" if task.workflow_label is None else f"-{task.workflow_label}"
            step_names.append(
                f"{task.name}-{i + 1}{suffix}".replace("_", "-").replace(" ", "-")
            )
            # Every step after the first takes the previous step's output.
            inputs = [step_names[i - 1]] if i > 0 else []
            workflow.add_step(
                Step(
                    name=step_names[i],
                    task=task.name,
                    config=task.config.model_dump(),
                    inputs=inputs,
                )
            )
        return workflow

    @classmethod
    def from_yaml(cls, yaml_str: str) -> Self:
        """Create a workflow from a YAML document with ``name`` and ``steps``."""
        yaml_dict = yaml.safe_load(yaml_str)
        workflow = cls(workflow_name=yaml_dict["name"])
        workflow.add_steps([Step(**step) for step in yaml_dict["steps"]])
        return workflow

    def add_step(self, step: Step) -> None:
        """Append a single step to the workflow."""
        self._steps.append(step)

    def add_steps(self, steps: list[Step]) -> None:
        """Append multiple steps to the workflow."""
        self._steps.extend(steps)

    def to_dict(self) -> dict:
        """Return the workflow as a plain dict (steps serialized to dicts)."""
        return dict(
            name=self._workflow_name,
            steps=list(
                map(lambda x: x.model_dump() if isinstance(x, Step) else x, self._steps)
            ),
        )

    def to_json(self, file_path: Optional[Union[Path, str]] = None) -> Optional[str]:
        """Serialize to JSON; return the string, or write to ``file_path``."""
        json_str = json.dumps(self.to_dict(), indent=4)
        if file_path is None:
            return json_str
        with open(file_path, "w") as f:
            f.write(json_str)

    def to_yaml(self, file_path: Optional[Union[Path, str]] = None) -> Optional[str]:
        """Serialize to YAML; return the string, or write to ``file_path``."""
        # Round-trip through JSON so pydantic types are reduced to plain data.
        yaml_str = yaml.dump(json.loads(self.to_json()), default_flow_style=False)
        if file_path is None:
            return yaml_str
        with open(file_path, "w") as f:
            f.write(yaml_str)

    def generate_dataset_preview(self) -> WorkflowResults:
        """Stream a preview run of the workflow and collect its outputs.

        Returns:
            WorkflowResults with the final step's data frames concatenated
            into one dataset, plus any non-dataframe auxiliary outputs.
        """
        current_step = None
        auxiliary_outputs = []
        last_step_data_outputs = []

        logger.info("🚀 Generating dataset preview")

        step_idx = 0
        for step_output in self._client.get_workflow_preview(self.to_dict()):
            if not isinstance(step_output, dict):
                step_output = step_output.as_dict()

            logger.debug(f"Step output: {json.dumps(step_output, indent=4)}")

            # Log a header line the first time we see output from a new step.
            if step_output["step"] != current_step:
                current_step = step_output["step"]
                # Hacky way to get a decently formatted log output
                task_name = self._steps[step_idx].task.replace("_", "-")
                step_name = step_output["step"].replace("-" + str(step_idx + 1), "")
                label = (
                    ""
                    if task_name == step_name
                    else f" >>{step_name.split(task_name)[-1].replace('-', ' ')}"
                )
                logger.info(
                    f"{_get_task_log_emoji(task_name)}Step {step_idx + 1}: "
                    f"{task_name.replace('-', ' ').capitalize()}{label}"
                )

            # Any non-state-change output marks the step as having produced
            # its result, so advance to the next step.
            if step_output["type"] != "step_state_change":
                step_idx += 1
                if (
                    step_output["step"] == self._steps[-1].name
                    and step_output["type"] == "data_frame"
                ):
                    last_step_data_outputs.append(step_output["output"])
                elif step_output["type"] != "data_frame":
                    auxiliary_outputs.append(step_output["output"])

        # NOTE(review): pd.concat raises if the final step produced no
        # data-frame outputs — presumably the server guarantees at least one;
        # confirm before hardening.
        df_list = [pd.DataFrame.from_records(r) for r in last_step_data_outputs]
        logger.info("👀 Your preview is ready for a peek!")

        return WorkflowResults(
            dataset=pd.concat(df_list, axis=0),
            auxiliary_outputs=auxiliary_outputs,
        )

    def submit_batch_job(self, num_records: int, project_name: Optional[str] = None):
        """Submit the workflow as a batch job generating ``num_records`` records."""
        return self._client.submit_batch_workflow(
            self.to_dict(), num_records, project_name
        )