
Commit

update docs
Kevin Maik Jablonka committed Oct 4, 2023
1 parent 98ec11f commit 705000a
Showing 6 changed files with 80 additions and 8 deletions.
docs/source/api.rst (3 changes: 3 additions & 0 deletions)
@@ -16,3 +16,6 @@ In-context learning

.. automodule:: chemlift.icl.fewshotregressor
:members:

.. automodule:: chemlift.icl.fewshotpredictor
:members:
src/chemlift/errorestimate.py (10 changes: 8 additions & 2 deletions)
@@ -2,11 +2,17 @@
from gptchem.evaluator import get_regression_metrics


def estimate_rounding_error(y, num_digit):
def estimate_rounding_error(y, num_digit) -> dict:
"""
Estimates the regression performance metrics (minimal error)
due to rounding of the target values.
Args:
y: The target values.
num_digit: The number of digits to round to.
Returns:
A dictionary containing the regression performance metrics.
"""
rounded = np.round(y, num_digit)
return get_regression_metrics(y, rounded)
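For orientation, here is a minimal sketch of how the documented helper might be called, assuming the repository's src/ layout installs as the chemlift package; the target values are made up for illustration and this snippet is not part of the commit:

import numpy as np

from chemlift.errorestimate import estimate_rounding_error

# Hypothetical targets; rounding to one digit bounds how well any model that
# reports one-digit predictions could score on these values.
y = np.array([1.234, 2.345, 3.456, 4.567])
metrics = estimate_rounding_error(y, num_digit=1)
print(metrics)  # dictionary of regression metrics for y vs. round(y, 1)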

src/chemlift/finetune/peftmodels.py (30 changes: 28 additions & 2 deletions)
@@ -331,6 +331,19 @@ def predict(
formatted: Optional[pd.DataFrame] = None,
return_std: bool = False,
):
"""
Args:
X (ArrayLike): Input data (typically array of molecular representations)
temperature (float, optional): Temperature for sampling. Defaults to 0.7.
do_sample (bool, optional): Whether to sample or not. Defaults to False.
formatted (pd.DataFrame, optional): Formatted data (typically output of `formatter`).
Defaults to None. If None, X must be provided.
return_std (bool, optional): Whether to return the standard deviation of the predictions.
Defaults to False.
Returns:
ArrayLike: Predicted property values (and standard deviation if `return_std` is True)
"""
predictions = self._predict(
X=X, temperature=temperature, do_sample=do_sample, formatted=formatted
)
@@ -459,11 +472,24 @@ def __init__(
def predict(
self,
X: Optional[ArrayLike] = None,
temperature=0.7,
do_sample=False,
temperature: float = 0.7,
do_sample: bool = False,
formatted: Optional[pd.DataFrame] = None,
return_std: bool = False,
):
"""
Args:
X (ArrayLike): Input data (typically array of molecular representations)
temperature (float, optional): Temperature for sampling. Defaults to 0.7.
do_sample (bool, optional): Whether to sample or not. Defaults to False.
formatted (pd.DataFrame, optional): Formatted data (typically output of `formatter`).
Defaults to None. If None, X must be provided.
return_std (bool, optional): Whether to return the standard deviation of the predictions.
Defaults to False.
Returns:
ArrayLike: Predicted property values (and standard deviation if `return_std` is True)
"""
predictions = self._predict(
X=X, temperature=temperature, do_sample=do_sample, formatted=formatted
)
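Both predict methods in this file now document the same calling convention. The following is a hedged sketch of how they might be used; peft_model stands in for an already fine-tuned predictor from chemlift.finetune.peftmodels (the concrete class name is not shown in this diff), and the SMILES strings are illustrative only, not part of the commit:

from typing import Sequence

def predict_with_uncertainty(peft_model, smiles: Sequence[str]):
    # Default call: greedy decoding (do_sample defaults to False); the
    # temperature setting typically only matters when sampling.
    point_predictions = peft_model.predict(smiles)

    # Sampling mode with an uncertainty estimate; per the new docstrings the
    # standard deviation is returned alongside the predictions when
    # return_std=True (assumed here to come back as a (preds, std) pair).
    preds, std = peft_model.predict(
        smiles, do_sample=True, temperature=0.7, return_std=True
    )
    return point_predictions, preds, std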
src/chemlift/icl/fewshotclassifier.py (12 changes: 10 additions & 2 deletions)
@@ -1,10 +1,13 @@
from loguru import logger
from numpy.typing import ArrayLike

from typing import Union
from chemlift.icl.fewshotpredictor import FewShotPredictor
from chemlift.icl.utils import LangChainChatModelWrapper


class FewShotClassifier(FewShotPredictor):
"""A few-shot classifier using in-context learning."""

intify = True

def _extract(self, generations, expected_len):
@@ -16,7 +19,6 @@ def _extract(self, generations, expected_len):
],
[],
)
print(generations, len(generations))
if len(generations) != expected_len:
logger.warning(f"Expected {expected_len} generations, got {len(generations)}")
return [None] * expected_len
@@ -33,5 +35,11 @@ def _extract(self, generations, expected_len):
return generations

def predict(self, X: ArrayLike, generation_kwargs: dict = {}):
"""Predict the class of a list of examples.
Args:
X: A list of examples.
generation_kwargs: Keyword arguments to pass to the language model.
"""
generations = self._predict(X, generation_kwargs)
return self._extract(generations, expected_len=len(X))
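A short usage sketch for the classifier interface documented above, mirroring the example added to fewshotpredictor.py below; the property name and labels are invented for illustration and this snippet is not part of the commit:

from langchain.llms import OpenAI

from chemlift.icl.fewshotclassifier import FewShotClassifier

llm = OpenAI(model_name="text-ada-001")
classifier = FewShotClassifier(llm, property_name="toxicity class")

# Support set: example molecules with integer class labels.
classifier.fit(["water", "ethanol"], [0, 1])

# generation_kwargs is forwarded to the language model; empty here.
labels = classifier.predict(["methanol"], generation_kwargs={})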
src/chemlift/icl/fewshotpredictor.py (33 changes: 31 additions & 2 deletions)
@@ -5,6 +5,9 @@
from more_itertools import chunked
from numpy.typing import ArrayLike
import enum
from typing import Union
from chemlift.icl.utils import LangChainChatModelWrapper


class Strategy(enum.Enum):
@@ -39,14 +42,41 @@ class FewShotPredictor:

def __init__(
self,
llm: BaseLLM,
llm: Union[BaseLLM, LangChainChatModelWrapper],
property_name: str,
n_support: int = 5,
strategy: Strategy = Strategy.RANDOM,
seed: int = 42,
prefix: str = "You are an expert chemist. ",
max_test: int = 5,
):
"""Initialize the few-shot predictor.
Args:
llm (Union[BaseLLM, LangChainChatModelWrapper]): The language model to use.
property_name (str): The property to predict.
n_support (int, optional): The number of examples to use as support set.
Defaults to 5.
strategy (Strategy, optional): The strategy to use to pick the support set.
Defaults to Strategy.RANDOM.
seed (int, optional): The random seed to use. Defaults to 42.
prefix (str, optional): The prefix to use for the prompt.
Defaults to "You are an expert chemist. ".
max_test (int, optional): The maximum number of examples to predict at once.
Defaults to 5.
Raises:
ValueError: If the strategy is unknown.
Examples:
>>> from chemlift.icl.fewshotpredictor import FewShotPredictor
>>> from langchain.llms import OpenAI
>>> llm = OpenAI(model_name="text-ada-001")
>>> predictor = FewShotPredictor(llm, property_name="melting point")
>>> predictor.fit(["water", "ethanol"], [0, 1])
>>> predictor.predict(["methanol"])
[0]
"""
self._support_set = None
self._llm = llm
self._n_support = n_support
@@ -134,7 +164,6 @@ def _predict(self, X: ArrayLike, generation_kwargs: dict = {}):
else:
examples = self._format_examples(support_examples, support_targets)
queries = chunk[0]
allowed_values = ", ".join(map(str, list(self._allowed_values)))
prompt = self.template_single.format(
property_name=self._property_name,
query=queries,
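To make the constructor options documented above concrete, here is a hedged sketch that sets the support-set size, selection strategy, seed, prefix, and query batch size explicitly; the property name is invented and the snippet is not part of the commit:

from langchain.llms import OpenAI

from chemlift.icl.fewshotpredictor import FewShotPredictor, Strategy

llm = OpenAI(model_name="text-ada-001")
predictor = FewShotPredictor(
    llm,
    property_name="aqueous solubility class",
    n_support=8,  # number of support examples included in each prompt
    strategy=Strategy.RANDOM,  # how the support set is selected
    seed=123,
    prefix="You are an expert chemist. ",
    max_test=10,  # maximum number of query examples predicted per prompt
)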
File renamed without changes.
