Skip to content

Commit

Permalink
Add natural language API to the high-level interface
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 8719b6d90c202130f718b9c10947119576696890
  • Loading branch information
johnnygreco committed May 2, 2024
1 parent 4c19ec2 commit b3708fe
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 67 deletions.
19 changes: 10 additions & 9 deletions src/gretel_client/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
GretelInferenceAPIError,
InferenceAPIModelType,
)
from gretel_client.inference_api.tabular import (
NAVIGATOR_DEFAULT_MODEL,
NavigatorInferenceAPI,
)
from gretel_client.inference_api.natural_language import NaturalLanguageInferenceAPI
from gretel_client.inference_api.tabular import NavigatorInferenceAPI

# Module logger; propagation is disabled so messages are not
# duplicated by ancestor loggers.
logger = logging.getLogger(__name__)
logger.propagate = False
Expand All @@ -37,31 +35,34 @@ def __init__(self, *, session: Optional[ClientConfig] = None, **session_kwargs):
def initialize_inference_api(
    self,
    model_type: InferenceAPIModelType = InferenceAPIModelType.NAVIGATOR,
    *,
    backend_model: Optional[str] = None,
) -> BaseInferenceAPI:
    """Initializes and returns a gretel inference API object.

    Args:
        model_type: The type of the inference API model.
        backend_model: The model used under the hood by the inference API.
            If None, the latest default model will be used.

    Raises:
        GretelInferenceAPIError: If the specified model type is not valid.

    Returns:
        An instance of the initialized inference API object.
    """
    # Resolve the concrete API class by equality against each known type.
    # Equality (not a dict lookup) is used so that plain-string inputs such
    # as "navigator" still match the str-based enum members.
    inference_api_cls = None
    for candidate_type, candidate_cls in (
        (InferenceAPIModelType.NAVIGATOR, NavigatorInferenceAPI),
        (InferenceAPIModelType.NATURAL_LANGUAGE, NaturalLanguageInferenceAPI),
    ):
        if model_type == candidate_type:
            inference_api_cls = candidate_cls
            break

    if inference_api_cls is None:
        raise GretelInferenceAPIError(
            f"{model_type} is not a valid inference API model type."
            f"Valid types are {[t.value for t in InferenceAPIModelType]}"
        )

    gretel_api = inference_api_cls(
        backend_model=backend_model, session=self._session
    )
    logger.info("API path: %s%s", gretel_api.endpoint, gretel_api.api_path)
    logger.info("Initialized %s 🚀", gretel_api.name)
    return gretel_api
9 changes: 9 additions & 0 deletions src/gretel_client/gretel/config_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,22 @@
)


# Default parameters for the Navigator and Natural Language inference APIs.
@dataclass(frozen=True)
class NavigatorDefaultParams:
    """Default sampling parameters for the Navigator (tabular) inference API."""

    temperature: float = 0.7  # sampling temperature; higher values -> more random output
    top_k: int = 40  # sample only from the 40 highest-probability tokens
    top_p: float = 0.95  # nucleus-sampling cumulative-probability cutoff


@dataclass(frozen=True)
class NaturalLanguageDefaultParams:
    """Default sampling parameters for the Natural Language inference API."""

    temperature: float = 0.6  # sampling temperature; higher values -> more random output
    top_k: int = 43  # sample only from the 43 highest-probability tokens
    top_p: float = 0.9  # nucleus-sampling cumulative-probability cutoff
    max_tokens: int = 512  # upper bound on tokens generated per request


class ModelType(str, Enum):
"""Name of the model parameter dict in the config.
Expand Down
66 changes: 47 additions & 19 deletions src/gretel_client/inference_api/base.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,43 @@
import json
import logging
import sys

from abc import ABC, abstractproperty
from enum import Enum
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from gretel_client import analysis_utils
from gretel_client.config import ClientConfig, configure_session, get_session_config
from gretel_client.dataframe import _DataFrameT

# Endpoint that lists the models available to the inference APIs.
MODELS_API_PATH = "/v1/inference/models"

# Module logger writes directly to stdout; propagation is disabled so
# messages are not duplicated by ancestor loggers.
logger = logging.getLogger(__name__)
logger.propagate = False
logger.addHandler(logging.StreamHandler(sys.stdout))
logger.setLevel(logging.INFO)


class GretelInferenceAPIError(Exception):
    """Raised when an error occurs with the Inference API.

    Examples include an invalid model type passed to the factory or an
    invalid backend model assigned to an inference API object.
    """


class InferenceAPIModelType(str, Enum):
    """Model types that the inference API factory can initialize.

    Inherits from ``str`` so plain string values (e.g. "navigator")
    compare equal to the corresponding enum members.
    """

    NAVIGATOR = "navigator"
    NATURAL_LANGUAGE = "natural_language"


class BaseInferenceAPI(ABC):
"""Base class for Gretel Inference API objects."""

_available_backend_models: list[str]
_model_type: str

def __init__(self, *, session: Optional[ClientConfig] = None, **session_kwargs):
def __init__(
self,
backend_model: Optional[str] = None,
*,
session: Optional[ClientConfig] = None,
**session_kwargs,
):
if session is None:
if len(session_kwargs) > 0:
configure_session(**session_kwargs)
Expand All @@ -38,15 +51,44 @@ def __init__(self, *, session: Optional[ClientConfig] = None, **session_kwargs):
"available within Gretel Cloud. Your current runner "
f"is configured to: {session.default_runner}"
)
self._api_client = session._get_api_client()
self.endpoint = session.endpoint
self._api_client = session._get_api_client()
self._available_backend_models = [
m for m in self._call_api("get", self.models_api_path).get("models", [])
]
self.backend_model = backend_model

# REST path prefix for this inference API; implemented by subclasses.
@abstractproperty
def api_path(self) -> str: ...

# Model-type tag used to filter the available backend models
# (compared case-insensitively in backend_model_list).
@abstractproperty
def model_type(self) -> str: ...

@property
def backend_model_list(self) -> List[str]:
    """Returns list of backend models for this model type."""
    # Keep only the models whose type matches this API's tag,
    # comparing case-insensitively.
    wanted_type = self.model_type.upper()
    model_ids: List[str] = []
    for model_info in self._available_backend_models:
        if model_info["model_type"].upper() == wanted_type:
            model_ids.append(model_info["model_id"])
    return model_ids

@property
def backend_model(self) -> str:
    """The backend model id currently used by this inference API."""
    return self._backend_model

@backend_model.setter
def backend_model(self, backend_model: Optional[str]) -> None:
    # None selects the first model reported for this model type.
    # NOTE(review): raises IndexError if the API reports no models for this
    # type — presumably Gretel Cloud always returns at least one; confirm.
    if backend_model is None:
        backend_model = self.backend_model_list[0]
    elif backend_model not in self.backend_model_list:
        raise GretelInferenceAPIError(
            f"Model {backend_model} is not a valid backend model. "
            f"Valid models are: {self.backend_model_list}"
        )
    self._backend_model = backend_model
    # NOTE(review): the previous NavigatorInferenceAPI setter also called
    # self._reset_stream() here; confirm stream state is reset elsewhere
    # when switching models mid-stream.
    logger.info("Backend model: %s", backend_model)

@property
def models_api_path(self) -> str:
    """Path of the endpoint that lists available inference models."""
    return MODELS_API_PATH
Expand All @@ -55,20 +97,6 @@ def models_api_path(self) -> str:
def name(self) -> str:
    """Returns the class name as the display name for this API."""
    return self.__class__.__name__

def display_dataframe_in_notebook(
self, dataframe: _DataFrameT, settings: Optional[dict] = None
) -> None:
"""Display pandas DataFrame in notebook with better settings for readability.
This function is intended to be used in a Jupyter notebook.
Args:
dataframe: The pandas DataFrame to display.
settings: Optional properties to set on the DataFrame's style.
If None, default settings with text wrapping are used.
"""
analysis_utils.display_dataframe_in_notebook(dataframe, settings)

def _call_api(
self,
method: str,
Expand Down
79 changes: 79 additions & 0 deletions src/gretel_client/inference_api/natural_language.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import logging
import sys

from gretel_client.gretel.config_setup import NaturalLanguageDefaultParams
from gretel_client.inference_api.base import BaseInferenceAPI

# Module logger writes directly to stdout; propagation is disabled so
# messages are not duplicated by ancestor loggers.
logger = logging.getLogger(__name__)
logger.propagate = False
logger.addHandler(logging.StreamHandler(sys.stdout))
logger.setLevel(logging.INFO)

# REST path prefix for the natural-language inference endpoints.
NATURAL_LANGUAGE_API_PATH = "/v1/inference/natural/"


class NaturalLanguageInferenceAPI(BaseInferenceAPI):
    """Inference API for real-time text generation with Gretel Natural Language.

    Args:
        backend_model (str, optional): The model that is used under the hood.
            If None, the latest default model will be used. See the
            `backend_model_list` property for a list of available models.
        **session_kwargs: kwargs for your Gretel session.

    Raises:
        GretelInferenceAPIError: If the specified backend model is not valid.
    """

    @property
    def api_path(self) -> str:
        """REST path prefix for the natural-language endpoints."""
        return NATURAL_LANGUAGE_API_PATH

    @property
    def model_type(self) -> str:
        """Tag used to filter backend models to natural-language ones."""
        return "natural"

    @property
    def generate_api_path(self) -> str:
        """Full path of the text-generation endpoint."""
        return self.api_path + "generate"

    @property
    def name(self) -> str:
        """Returns display name for this inference api."""
        return "Gretel Natural Language"

    def generate(
        self,
        prompt: str,
        temperature: float = NaturalLanguageDefaultParams.temperature,
        max_tokens: int = NaturalLanguageDefaultParams.max_tokens,
        top_p: float = NaturalLanguageDefaultParams.top_p,
        top_k: int = NaturalLanguageDefaultParams.top_k,
    ) -> str:
        """Generate synthetic text.

        Args:
            prompt: The prompt for generating synthetic text.
            temperature: Sampling temperature. Higher values make output more random.
            max_tokens: The maximum number of tokens to generate.
            top_p: The cumulative probability cutoff for sampling tokens.
            top_k: Number of highest probability tokens to keep for top-k filtering.

        Returns:
            The generated text as a string.
        """
        response = self._call_api(
            method="post",
            path=self.generate_api_path,
            body={
                "model_id": self.backend_model,
                "prompt": prompt,
                "params": {
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                    "top_p": top_p,
                    "top_k": top_k,
                },
            },
        )
        return response["text"]
61 changes: 28 additions & 33 deletions src/gretel_client/inference_api/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from tqdm import tqdm

from gretel_client import analysis_utils
from gretel_client.dataframe import _DataFrameT
from gretel_client.gretel.artifact_fetching import PANDAS_IS_INSTALLED
from gretel_client.gretel.config_setup import NavigatorDefaultParams
Expand All @@ -25,20 +26,20 @@
# Constants for the tabular (Navigator) inference API.
MAX_ROWS_PER_STREAM = 50
REQUEST_TIMEOUT_SEC = 60
TABULAR_API_PATH = "/v1/inference/tabular/"
NAVIGATOR_DEFAULT_MODEL = "gretelai/tabular-v0"

# tqdm bar format: left label, bar, then "n/total [elapsed, rate]".
PROGRESS_BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed},{rate_noinv_fmt}]"

# Forms in which streamed results may be returned to the caller.
StreamReturnType = Union[
    Iterator[dict[str, Any]], _DataFrameT, List[dict[str, Any]], dict[str, Any]
]


class NavigatorInferenceAPI(BaseInferenceAPI):
"""Inference API for real-time data generation with Gretel's tabular LLM.
"""Inference API for real-time data generation with Gretel Navigator.
Args:
backend_model (str, optional): The model that is used under the hood.
See the `backend_model_list` property for a list of available models.
If None, the latest default model will be used. See the
`backend_model_list` property for a list of available models.
**session_kwargs: kwargs for your Gretel session.
Raises:
Expand All @@ -62,14 +63,14 @@ class NavigatorInferenceAPI(BaseInferenceAPI):
process raises a user-facing error.
"""

def __init__(self, backend_model: str = NAVIGATOR_DEFAULT_MODEL, **session_kwargs):
super().__init__(**session_kwargs)
self.backend_model = backend_model

@property
def api_path(self) -> str:
    """REST path prefix for the tabular endpoints."""
    return TABULAR_API_PATH

@property
def model_type(self) -> str:
    """Tag used to filter backend models to tabular ones."""
    return "tabular"

@property
def stream_api_path(self) -> str:
    """Full path of the streaming endpoint."""
    return self.api_path + "stream"
Expand All @@ -82,35 +83,25 @@ def iterate_api_path(self) -> str:
def generate_api_path(self) -> str:
    """Full path of the generate endpoint."""
    return self.api_path + "generate"

@property
def backend_model(self) -> str:
return self._backend_model

@backend_model.setter
def backend_model(self, backend_model: str) -> None:
if backend_model not in self.backend_model_list:
raise GretelInferenceAPIError(
f"Model {backend_model} is not a valid backend model. "
f"Valid models are: {self.backend_model_list}"
)
self._backend_model = backend_model
logger.info("Backend model: %s", backend_model)
self._reset_stream()

@property
def backend_model_list(self) -> List[str]:
"""Returns list of available tabular backend models."""
return [
m["model_id"]
for m in self._available_backend_models
if m["model_type"].upper() == "TABULAR"
]

@property
def name(self) -> str:
    """Returns the display name for this inference API."""
    return "Gretel Navigator"

def display_dataframe_in_notebook(
    self, dataframe: _DataFrameT, settings: Optional[dict] = None
) -> None:
    """Display pandas DataFrame in notebook with better settings for readability.

    This function is intended to be used in a Jupyter notebook.

    Args:
        dataframe: The pandas DataFrame to display.
        settings: Optional properties to set on the DataFrame's style.
            If None, default settings with text wrapping are used.
    """
    # Thin wrapper around the shared analysis_utils helper.
    analysis_utils.display_dataframe_in_notebook(dataframe, settings)

def _reset_stream(self) -> None:
"""Reset the stream state."""
self._curr_stream_id = None
Expand Down Expand Up @@ -304,7 +295,11 @@ def _get_stream_results(

generated_records = []
with tqdm(
total=num_records, desc=pbar_desc, disable=disable_pbar, initial=1
total=num_records,
desc=pbar_desc,
disable=disable_pbar,
unit=" records",
bar_format=PROGRESS_BAR_FORMAT,
) as pbar:
for record in stream_iterator:
generated_records.append(record)
Expand Down
Loading

0 comments on commit b3708fe

Please sign in to comment.