Skip to content

Commit

Permalink
Project import generated by Copybara.
Browse files Browse the repository at this point in the history
GitOrigin-RevId: e424718f4c45dbeaef3429f68e813a57b471dbd9
  • Loading branch information
Gretel Team authored and johnnygreco committed Oct 10, 2024
1 parent b0ec8e1 commit 845c293
Show file tree
Hide file tree
Showing 22 changed files with 1,115 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/gretel_client/navigator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from gretel_client.navigator.workflow import NavigatorWorkflow
3 changes: 3 additions & 0 deletions src/gretel_client/navigator/blueprints/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from gretel_client.navigator.blueprints.text_to_code.blueprint import (
TextToCodeBlueprint,
)
16 changes: 16 additions & 0 deletions src/gretel_client/navigator/blueprints/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from abc import ABC


class NavigatorBlueprint(ABC):
"""Base class for all blueprint classes."""

@property
def name(self) -> str:
"""The name of the blueprint."""
return self.__class__.__name__

def __str__(self) -> str:
return self.name

def __repr__(self) -> str:
return f"<{self.name}>"
Empty file.
Empty file.
159 changes: 159 additions & 0 deletions src/gretel_client/navigator/blueprints/text_to_code/blueprint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union

from gretel_client.gretel.config_setup import smart_load_yaml
from gretel_client.navigator.blueprints.base import NavigatorBlueprint
from gretel_client.navigator.blueprints.text_to_code.prompt_templates import (
CODE_PROMPT,
FIELD_GENERATION_PROMPT,
TEXT_PROMPT,
)
from gretel_client.navigator.blueprints.text_to_code.utils import display_nl2code_sample
from gretel_client.navigator.tasks import (
GenerateColumnFromTemplate,
GenerateContextualTags,
SeedFromContextualTags,
)
from gretel_client.navigator.tasks.base import Task
from gretel_client.navigator.tasks.io import Dataset
from gretel_client.navigator.workflow import NavigatorWorkflow

output_parser_instructions = {
"pass_through": "* Return only the requested text, without any additional comments or instructions.",
"json_array": "* Respond only with the list as a valid JSON array.",
}


@dataclass
class DataPreview:
dataset: Dataset
contextual_columns: list[dict]
blueprint_config: dict

def display_sample(self, index: Optional[int] = None, **kwargs):
if index is None:
record = self.dataset.sample(1).iloc[0]
else:
record = self.dataset.loc[index]
display_nl2code_sample(
lang=self.blueprint_config["lang"],
record=record,
contextual_columns=self.contextual_columns,
**kwargs,
)


class TextToCodeBlueprint(NavigatorBlueprint):

def __init__(self, config: Union[str, dict, Path]):
self.config = smart_load_yaml(config)
self.lang = self.config["lang"]
self.task_list = self._build_sequential_task_list()
self.workflow = NavigatorWorkflow.from_sequential_tasks(self.task_list)

self._contextual_columns = [
tag["name"] for tag in self.config["contextual_tags"]["tags"]
] + [
field["name"]
for field in self.config.get("additional_code_context_columns", [])
]

def _get_context_string(
self,
exclude_tags: Optional[list[str]] = None,
additional_context: Optional[list[str]] = None,
) -> str:
context_string = "\n".join(
[
f" * {tag['name'].replace('_', ' ').capitalize()}: {{{tag['name']}}}"
for tag in self.config["contextual_tags"]["tags"]
if tag["name"] not in (exclude_tags or [])
]
)
if additional_context is not None:
context_string += "\n" + "\n".join(
[
f" * {column.replace('_', ' ').capitalize()}: {{{column}}}"
for column in additional_context
]
)
return context_string

def _create_code_context_task(self, field):
exclude_columns = [
tag["name"]
for tag in self.config["contextual_tags"]["tags"]
if tag["name"] not in field["relevant_columns"]
]
llm_type = "nl" if field["generation_type"] == "text" else "code"
system_prompt = self.config[
field.get("generation_type", "text") + "_generation_instructions"
]
return GenerateColumnFromTemplate(
prompt_template=FIELD_GENERATION_PROMPT.format(
name=field["name"],
description=field["description"],
context=self._get_context_string(exclude_columns),
generation_type=field["generation_type"].capitalize(),
parser_instructions=output_parser_instructions[field["output_parser"]],
),
response_column_name=field["name"],
system_prompt=system_prompt,
workflow_label=f"{field['name'].replace('_', ' ')}",
llm_type=llm_type,
output_parser=field["output_parser"],
)

def _build_sequential_task_list(self) -> list[Task]:
generate_text_column = GenerateColumnFromTemplate(
prompt_template=TEXT_PROMPT.format(
lang=self.lang, context=self._get_context_string()
),
response_column_name="text",
system_prompt=self.config["text_generation_instructions"],
workflow_label="text prompt",
)

additional_context = []
additional_code_context_tasks = []
for field in self.config.get("additional_code_context_columns", []):
additional_context.append(field["name"])
additional_code_context_tasks.append(self._create_code_context_task(field))

generate_code_column = GenerateColumnFromTemplate(
prompt_template=CODE_PROMPT.format(
lang=self.lang,
context=self._get_context_string(additional_context=additional_context),
),
response_column_name="code",
system_prompt=self.config["code_generation_instructions"],
workflow_label="code prompt",
output_parser="extract_code",
)

return [
GenerateContextualTags(**self.config["contextual_tags"]),
SeedFromContextualTags(),
generate_text_column,
*additional_code_context_tasks,
generate_code_column,
]

def generate_dataset_preview(self):
results = self.workflow.generate_dataset_preview()

tags = {}
for tag in results.auxiliary_outputs[0]["tags"]:
tags[tag["name"]] = tag["seed_values"] + tag["generated_values"]

return DataPreview(
dataset=results.dataset,
contextual_columns=self._contextual_columns,
blueprint_config=self.config,
)

def submit_batch_job(self, num_records: int, project_name: Optional[str] = None):
self.workflow.submit_batch_job(
num_records=num_records, project_name=project_name
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
TEXT_PROMPT = """\
Your task is to generate the natural language component of a text-to-{lang} dataset, \
carefully following the given context and instructions.
### Context:
{context}
### Instructions:
* Generate text related to {lang} code based on the given context.
* Do NOT return any code in the response.
* Return only the requested text, without any additional comments or instructions.
### Text:
"""


CODE_PROMPT = """\
Your task is to generate {lang} code that corresponds to the text and context given below.
### Text:
{{text}}
### Context:
{context}
### Instructions:
* Remember to base your response on the given context.
* Include ONLY a SINGLE block of code WITHOUT ANY additional text.
### Code:
"""


FIELD_GENERATION_PROMPT = """\
Your task is to generate a `{name}` field in a dataset based on the given description and context.
### Description:
{description}
### Context:
{context}
### Instructions:
* Generate `{name}` as described above.
* Remember to base your response on the given context.
{parser_instructions}
### Response:
"""
50 changes: 50 additions & 0 deletions src/gretel_client/navigator/blueprints/text_to_code/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Optional, Union

import pandas as pd

from rich.console import Console
from rich.panel import Panel
from rich.syntax import Syntax
from rich.table import Table
from rich.text import Text

console = Console()


def display_nl2code_sample(
lang: str,
record: Union[dict, pd.Series],
contextual_columns: list[str],
theme: str = "dracula",
background_color: Optional[str] = None,
):
if isinstance(record, (dict, pd.Series)):
record = pd.DataFrame([record]).iloc[0]
else:
raise ValueError("record must be a dictionary or pandas Series")

table = Table(title="Contextual Columns")

for col in contextual_columns:
table.add_column(col.replace("_", " ").capitalize())
table.add_row(*[str(record[col]) for col in contextual_columns])

console.print(table)

panel = Panel(
Text(record.text, justify="left", overflow="fold"),
title="Text",
)
console.print(panel)

panel = Panel(
Syntax(
record.code,
lexer=lang.lower(),
theme=theme,
word_wrap=True,
background_color=background_color,
),
title="Code",
)
console.print(panel)
22 changes: 22 additions & 0 deletions src/gretel_client/navigator/client/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from gretel_client.navigator.client.interface import Client, ClientAdapter
from gretel_client.navigator.client.remote import RemoteClient


class ClientManager:

def __init__(self):
self._client = None

def get_client(self):
if self._client is None:
self.set_remote_client()
return self._client

def set_client(self, adapter: ClientAdapter):
self._client = Client(adapter)

def set_remote_client(self):
self.set_client(RemoteClient())


client_manager = ClientManager()
Loading

0 comments on commit 845c293

Please sign in to comment.