Skip to content

Commit

Permalink
Fix up according to linters
Browse files Browse the repository at this point in the history
  • Loading branch information
rizerphe committed Jul 3, 2023
1 parent cbeca90 commit 6aedcab
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 28 deletions.
4 changes: 2 additions & 2 deletions openai_functions/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Callable, Literal, TYPE_CHECKING, overload

import openai
from openai.error import RateLimitError
import openai.error

from .functions.union import UnionSkillSet
from .openai_types import (
Expand Down Expand Up @@ -138,7 +138,7 @@ def _generate_message(
while True:
try:
response = self._generate_raw_message(function_call)
except RateLimitError as error:
except openai.error.RateLimitError as error:
if retries == 0:
raise
retries -= 1
Expand Down
5 changes: 3 additions & 2 deletions openai_functions/functions/basic_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,13 @@ def run_function(self, input_data: FunctionCall) -> FunctionResult:
Raises:
FunctionNotFoundError: If the function is not found
InvalidJsonError: If the arguments are not valid JSON
"""
function = self.find_function(input_data["name"])
try:
arguments = json.loads(input_data["arguments"])
except json.decoder.JSONDecodeError as e:
raise InvalidJsonError(input_data["arguments"]) from e
except json.decoder.JSONDecodeError as err:
raise InvalidJsonError(input_data["arguments"]) from err
result = self.get_function_result(function, arguments)
return FunctionResult(
function.name, result, function.remove_call, function.interpret_as_response
Expand Down
7 changes: 5 additions & 2 deletions openai_functions/functions/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ def parse_arguments(self, arguments: dict[str, JsonType]) -> OrderedDict[str, An
Args:
arguments (dict[str, JsonType]): The arguments to parse
Raises:
BrokenSchemaError: If the arguments do not match the schema
Returns:
OrderedDict[str, Any]: The parsed arguments
"""
Expand All @@ -252,8 +255,8 @@ def parse_arguments(self, arguments: dict[str, JsonType]) -> OrderedDict[str, An
(name, argument_parsers[name].parse_value(value))
for name, value in arguments.items()
)
except KeyError as e:
raise BrokenSchemaError(arguments, self.arguments_schema) from e
except KeyError as err:
raise BrokenSchemaError(arguments, self.arguments_schema) from err

def __call__(self, arguments: dict[str, JsonType]) -> Any:
"""Call the wrapped function
Expand Down
49 changes: 28 additions & 21 deletions openai_functions/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,26 +38,34 @@ def __call__(
...


@dataclass
class NLPWrapperConfig:
"""A configuration for the nlp decorator"""

name: str | None = None
description: str | None = None
serialize: bool = True

model: str = "gpt-3.5-turbo-0613"
system_prompt: str | None = None


class Wrapper(Generic[Param, Return]):
"""A wrapper for a function that provides a natural language interface"""

def __init__(
self,
origin: Callable[..., Return],
system_prompt: str | None = None,
model: str = "gpt-3.5-turbo-0613",
name: str | None = None,
description: str | None = None,
serialize: bool = True,
config: NLPWrapperConfig,
) -> None:
self.origin = origin
self.system_prompt = system_prompt
self.conversation = Conversation(model=model)
self.config = config
self.conversation = Conversation(model=config.model)
self.openai_function = FunctionWrapper(
self.origin,
WrapperConfig(serialize=serialize),
name=name,
description=description,
WrapperConfig(serialize=config.serialize),
name=config.name,
description=config.description,
)
self.conversation.add_function(self.openai_function)

Expand All @@ -67,11 +75,11 @@ def __call__(self, *args: Param.args, **kwds: Param.kwargs) -> Return:
def _initialize_conversation(self) -> None:
"""Initialize the conversation"""
self.conversation.clear_messages()
if self.system_prompt is not None:
if self.config.system_prompt is not None:
self.conversation.add_message(
{
"role": "system",
"content": self.system_prompt,
"content": self.config.system_prompt,
}
)

Expand Down Expand Up @@ -153,18 +161,17 @@ def _nlp(
The function, with natural language input, or a decorator to add natural
language input to a function
"""

wrapped: Wrapper[Param, Return] = Wrapper(
return Wrapper(
function,
system_prompt=system_prompt,
model=model,
name=name,
description=description,
serialize=serialize,
NLPWrapperConfig(
system_prompt=system_prompt,
model=model,
name=name,
description=description,
serialize=serialize,
),
)

return wrapped


@overload
def nlp(
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ basepython = python3.11
deps =
pylint
commands =
pylint openai_functions
pylint --fail-under 9 openai_functions

[testenv:flake8]
description = Flake8 environment
Expand Down

0 comments on commit 6aedcab

Please sign in to comment.