Skip to content

Commit

Permalink
Project import generated by Copybara.
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 16bf3a4b2bf28a0d0cb40d681e08d05153463f61
  • Loading branch information
Gretel Team authored and johnnygreco committed Oct 10, 2024
1 parent b0ec8e1 commit 3ac9943
Show file tree
Hide file tree
Showing 23 changed files with 1,111 additions and 2 deletions.
3 changes: 1 addition & 2 deletions src/gretel_client/inference_api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from gretel_client.config import ClientConfig, configure_session, get_session_config
from gretel_client.rest.api_client import ApiClient
from gretel_client.rest.configuration import Configuration

MODELS_API_PATH = "/v1/inference/models"

Expand Down Expand Up @@ -161,7 +160,7 @@ def __init__(
elif len(session_kwargs) > 0:
raise ValueError("cannot specify session arguments when passing a session")

if session.default_runner != "cloud" and not ".serverless." in session.endpoint:
if session.default_runner != "cloud" and ".serverless." not in session.endpoint:
raise GretelInferenceAPIError(
"Gretel's Inference API is currently only "
"available within Gretel Cloud. Your current runner "
Expand Down
1 change: 1 addition & 0 deletions src/gretel_client/navigator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from gretel_client.navigator.workflow import NavigatorWorkflow
3 changes: 3 additions & 0 deletions src/gretel_client/navigator/blueprints/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from gretel_client.navigator.blueprints.text_to_code.blueprint import (
TextToCodeBlueprint,
)
16 changes: 16 additions & 0 deletions src/gretel_client/navigator/blueprints/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from abc import ABC


class NavigatorBlueprint(ABC):
"""Base class for all blueprint classes."""

@property
def name(self) -> str:
"""The name of the blueprint."""
return self.__class__.__name__

def __str__(self) -> str:
return self.name

def __repr__(self) -> str:
return f"<{self.name}>"
Empty file.
Empty file.
157 changes: 157 additions & 0 deletions src/gretel_client/navigator/blueprints/text_to_code/blueprint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union

from gretel_client.gretel.config_setup import smart_load_yaml
from gretel_client.navigator.blueprints.base import NavigatorBlueprint
from gretel_client.navigator.blueprints.text_to_code.prompt_templates import (
CODE_PROMPT,
FIELD_GENERATION_PROMPT,
TEXT_PROMPT,
)
from gretel_client.navigator.blueprints.text_to_code.utils import display_nl2code_sample
from gretel_client.navigator.tasks import (
GenerateColumnFromTemplate,
GenerateContextualTags,
SeedFromContextualTags,
)
from gretel_client.navigator.tasks.base import Task
from gretel_client.navigator.tasks.io import Dataset
from gretel_client.navigator.workflow import NavigatorWorkflow

output_parser_instructions = {
"pass_through": "* Return only the requested text, without any additional comments or instructions.",
"json_array": "* Respond only with the list as a valid JSON array.",
}

output_parser_type_map = {
"str": "pass_through",
"string": "pass_through",
"text": "pass_through",
"json": "json",
"dict": "json",
"list": "json_array",
"json_array": "json_array",
"code": "extract_code",
}


@dataclass
class DataPreview:
dataset: Dataset
contextual_columns: list[dict]
blueprint_config: dict
tag_values: dict

def display_sample(self, index: Optional[int] = None, **kwargs):
if index is None:
record = self.dataset.sample(1).iloc[0]
else:
record = self.dataset.loc[index]
display_nl2code_sample(
lang=self.blueprint_config["programming_language"],
record=record,
contextual_columns=self.contextual_columns,
**kwargs,
)


class TextToCodeBlueprint(NavigatorBlueprint):

def __init__(self, config: Union[str, dict, Path], **session_kwargs):
self.config = smart_load_yaml(config)
self.lang = self.config["programming_language"]
self.task_list = self._build_sequential_task_list()
self.workflow = NavigatorWorkflow.from_sequential_tasks(
self.task_list, **session_kwargs
)

def _create_context_template(self, columns: list) -> str:
return "\n".join(
[f" * {c.replace('_', ' ').capitalize()}: {{{c}}}" for c in columns]
)

def _create_contextual_column_task(self, field) -> Task:
output_parser = output_parser_type_map[field["column_type"]]
generation_type = "text" if field["llm_type"] == "nl" else "code"
system_prompt = self.config[f"{generation_type}_generation_instructions"]
return GenerateColumnFromTemplate(
prompt_template=FIELD_GENERATION_PROMPT.format(
name=field["name"],
description=field["description"],
context=self._create_context_template(field["relevant_columns"]),
generation_type=generation_type.capitalize(),
parser_instructions=output_parser_instructions[output_parser],
),
response_column_name=field["name"],
system_prompt=system_prompt,
workflow_label=f"{field['name'].replace('_', ' ')}",
llm_type=field["llm_type"],
output_parser=output_parser,
)

def _build_sequential_task_list(self) -> list[Task]:
additional_context_columns = []
for field in self.config.get("additional_contextual_columns", []):
additional_context_columns.append(
self._create_contextual_column_task(field)
)

generate_text_column = GenerateColumnFromTemplate(
prompt_template=TEXT_PROMPT.format(
lang=self.lang,
context=self._create_context_template(
self.config["text_relevant_columns"]
),
),
llm_type="nl",
response_column_name="text",
system_prompt=self.config["text_generation_instructions"],
workflow_label="text prompt",
)

generate_code_column = GenerateColumnFromTemplate(
prompt_template=CODE_PROMPT.format(
lang=self.lang,
context=self._create_context_template(
self.config["code_relevant_columns"]
),
),
llm_type="nl",
response_column_name="code",
system_prompt=self.config["code_generation_instructions"],
workflow_label="code prompt",
output_parser="extract_code",
)

return [
GenerateContextualTags(**self.config["contextual_tags"]),
SeedFromContextualTags(),
*additional_context_columns,
generate_text_column,
generate_code_column,
]

def generate_dataset_preview(self) -> DataPreview:
results = self.workflow.generate_dataset_preview()

tags = {}
for tag in results.auxiliary_outputs[0]["tags"]:
tags[tag["name"]] = tag["seed_values"] + tag["generated_values"]

additional_context = self.config.get("additional_contextual_columns", [])
context_cols = [tag["name"] for tag in self.config["contextual_tags"]["tags"]]
return DataPreview(
dataset=results.dataset,
contextual_columns=context_cols
+ [field["name"] for field in additional_context],
blueprint_config=self.config,
tag_values=tags,
)

def submit_batch_job(
self, num_records: int, project_name: Optional[str] = None
) -> None:
self.workflow.submit_batch_job(
num_records=num_records, project_name=project_name
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
TEXT_PROMPT = """\
Your task is to generate the natural language component of a text-to-{lang} dataset, \
carefully following the given context and instructions.
### Context:
{context}
### Instructions:
* Generate text related to {lang} code based on the given context.
* Do NOT return any code in the response.
* Return only the requested text, without any additional comments or instructions.
### Text:
"""


CODE_PROMPT = """\
Your task is to generate {lang} code that corresponds to the text and context given below.
### Text:
{{text}}
### Context:
{context}
### Instructions:
* Remember to base your response on the given context.
* Include ONLY a SINGLE block of code WITHOUT ANY additional text.
### Code:
"""


FIELD_GENERATION_PROMPT = """\
Your task is to generate a `{name}` field in a dataset based on the given description and context.
### Description:
{description}
### Context:
{context}
### Instructions:
* Generate `{name}` as described above.
* Remember to base your response on the given context.
{parser_instructions}
### Response:
"""
50 changes: 50 additions & 0 deletions src/gretel_client/navigator/blueprints/text_to_code/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Optional, Union

import pandas as pd

from rich.console import Console
from rich.panel import Panel
from rich.syntax import Syntax
from rich.table import Table
from rich.text import Text

console = Console()


def display_nl2code_sample(
lang: str,
record: Union[dict, pd.Series],
contextual_columns: list[str],
theme: str = "dracula",
background_color: Optional[str] = None,
):
if isinstance(record, (dict, pd.Series)):
record = pd.DataFrame([record]).iloc[0]
else:
raise ValueError("record must be a dictionary or pandas Series")

table = Table(title="Contextual Columns")

for col in contextual_columns:
table.add_column(col.replace("_", " ").capitalize())
table.add_row(*[str(record[col]) for col in contextual_columns])

console.print(table)

panel = Panel(
Text(record.text, justify="left", overflow="fold"),
title="Text",
)
console.print(panel)

panel = Panel(
Syntax(
record.code,
lexer=lang.lower(),
theme=theme,
word_wrap=True,
background_color=background_color,
),
title="Code",
)
console.print(panel)
11 changes: 11 additions & 0 deletions src/gretel_client/navigator/client/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from gretel_client.config import configure_session
from gretel_client.navigator.client.interface import Client, ClientAdapter
from gretel_client.navigator.client.remote import RemoteClient


def get_navigator_client(
client_adapter: ClientAdapter = RemoteClient(), **session_kwargs
) -> Client:
validate = session_kwargs.get("validate", False)
configure_session(validate=validate, **session_kwargs)
return Client(client_adapter)
Loading

0 comments on commit 3ac9943

Please sign in to comment.