Add support for Pydantic v2 (#288)
* Add support for Pydantic v2

* Fix lint & mypy

* Fix comments

* Fix tests

* Fix message
izellevy authored Feb 15, 2024
1 parent ae7b635 commit 04e43dc
Showing 33 changed files with 120 additions and 123 deletions.
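
The changes are almost entirely mechanical renames from the Pydantic v1 API to its v2 equivalents: .dict() becomes .model_dump(), .copy() becomes .model_copy(), .parse_obj() becomes .model_validate(), @validator becomes @field_validator, class Config becomes model_config, and __root__ custom-root models become RootModel subclasses. A minimal sketch of the mapping (illustrative only, not taken from the diff; assumes pydantic>=2 is installed):

from pydantic import BaseModel, ConfigDict, field_validator


class Message(BaseModel):
    # v1: `class Config: extra = "forbid"`  ->  v2: model_config
    model_config = ConfigDict(extra="forbid")

    role: str
    content: str

    # v1: @validator("role")  ->  v2: @field_validator("role") + @classmethod
    @field_validator("role")
    @classmethod
    def check_role(cls, v: str) -> str:
        if v not in ("user", "assistant", "system"):
            raise ValueError(f"unexpected role: {v}")
        return v


msg = Message.model_validate({"role": "user", "content": "hi"})  # v1: parse_obj()
print(msg.model_dump())                                          # v1: .dict()
updated = msg.model_copy(update={"content": "bye"})              # v1: .copy()
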
pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -14,7 +14,7 @@ python = ">=3.9,<3.13"
 python-dotenv = "^1.0.0"
 openai = "^1.2.3"
 tiktoken = "^0.3.3"
-pydantic = "^1.10.7"
+pydantic = "^2.0.0"
 pandas-stubs = "^2.0.3.230814"
 fastapi = ">=0.93.0, <1.0.0"
 uvicorn = ">=0.20.0, <1.0.0"
src/canopy/chat_engine/chat_engine.py (2 changes: 1 addition & 1 deletion)
@@ -221,7 +221,7 @@ def chat(self,
                                   model_params=model_params_dict)
         debug_info = {}
         if CANOPY_DEBUG_INFO:
-            debug_info['context'] = context.dict()
+            debug_info['context'] = context.model_dump()
             debug_info['context'].update(context.debug_info)
 
         if stream:
src/canopy/context_engine/context_builder/stuffing.py (14 changes: 6 additions & 8 deletions)
@@ -1,3 +1,4 @@
+import json
 from itertools import zip_longest
 from typing import List, Tuple
 
@@ -23,15 +24,12 @@ class ContextQueryResult(BaseModel):
 
 
 class StuffingContextContent(ContextContent):
-    __root__: List[ContextQueryResult]
-
-    def dict(self, **kwargs):
-        return super().dict(**kwargs)['__root__']
+    root: List[ContextQueryResult]
 
     # In the case of StuffingContextBuilder, we simply want the text representation to
     # be a json. Other ContextContent subclasses may render into text differently
     def to_text(self, **kwargs):
-        return self.json(**kwargs)
+        return json.dumps(self.model_dump(), **kwargs)
 
 
 # ------------- CONTEXT BUILDER -------------
@@ -52,10 +50,10 @@ def build(self,
             ContextQueryResult(query=qr.query, snippets=[])
             for qr in query_results]
         debug_info = {"num_docs": len(sorted_docs_with_origin), "snippet_ids": []}
-        content = StuffingContextContent(__root__=context_query_results)
+        content = StuffingContextContent(context_query_results)
 
         if self._tokenizer.token_count(content.to_text()) > max_context_tokens:
-            return Context(content=StuffingContextContent(__root__=[]),
+            return Context(content=StuffingContextContent([]),
                            num_tokens=1, debug_info=debug_info)
 
         seen_doc_ids = set()
@@ -78,7 +76,7 @@ def build(self,
 
         # remove queries with no snippets
         content = StuffingContextContent(
-            __root__=[qr for qr in context_query_results if len(qr.snippets) > 0]
+            [qr for qr in context_query_results if len(qr.snippets) > 0]
         )
 
         return Context(content=content,
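
Note on the StuffingContextContent change above: v2 drops the __root__ field in favor of RootModel, which also takes its root value positionally, so the __root__= keyword disappears at every call site, and the v1 dict() override that unwrapped ['__root__'] can simply be deleted because RootModel.model_dump() already returns the root value. A small illustrative sketch (assumed names, not from the codebase):

from typing import List

from pydantic import RootModel


class Names(RootModel[List[str]]):
    def to_text(self) -> str:
        return ", ".join(self.root)  # v1: self.__root__


names = Names(["a", "b"])            # v1: Names(__root__=["a", "b"])
print(names.model_dump())            # ['a', 'b'] -- no dict() override needed
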
src/canopy/context_engine/context_engine.py (2 changes: 1 addition & 1 deletion)
@@ -110,7 +110,7 @@ def query(self, queries: List[Query],
 
         if CANOPY_DEBUG_INFO:
             context.debug_info["query_results"] = [
-                {**qr.dict(), **qr.debug_info} for qr in query_results
+                {**qr.model_dump(), **qr.debug_info} for qr in query_results
             ]
         return context
 
src/canopy/knowledge_base/knowledge_base.py (6 changes: 3 additions & 3 deletions)
@@ -445,7 +445,7 @@ def query(self,
                 query=rr.query,
                 documents=[
                     DocumentWithScore(
-                        **d.dict(exclude={
+                        **d.model_dump(exclude={
                             'document_id'
                         })
                     )
@@ -455,13 +455,13 @@ def query(self,
                 query=r.query,
                 documents=[
                     DocumentWithScore(
-                        **d.dict(exclude={
+                        **d.model_dump(exclude={
                             'document_id'
                         })
                     )
                     for d in r.documents
                 ]
-            ).dict()} if CANOPY_DEBUG_INFO else {}
+            ).model_dump()} if CANOPY_DEBUG_INFO else {}
             ) for rr, r in zip(ranked_results, results)
         ]
 
src/canopy/knowledge_base/record_encoder/dense.py (6 changes: 4 additions & 2 deletions)
@@ -40,7 +40,7 @@ def _encode_documents_batch(self,
             encoded chunks: A list of KBEncodedDocChunk, with the `values` field populated by the generated embeddings vector.
         """  # noqa: E501
         dense_values = self._dense_encoder.encode_documents([d.text for d in documents])
-        return [KBEncodedDocChunk(**d.dict(), values=v) for d, v in
+        return [KBEncodedDocChunk(**d.model_dump(), values=v) for d, v in
                 zip(documents, dense_values)]
 
     def _encode_queries_batch(self, queries: List[Query]) -> List[KBQuery]:
@@ -52,7 +52,9 @@ def _encode_queries_batch(self, queries: List[Query]) -> List[KBQuery]:
             encoded queries: A list of KBQuery, with the `values` field populated by the generated embeddings vector.
         """  # noqa: E501
         dense_values = self._dense_encoder.encode_queries([q.text for q in queries])
-        return [KBQuery(**q.dict(), values=v) for q, v in zip(queries, dense_values)]
+        return [
+            KBQuery(**q.model_dump(), values=v) for q, v in zip(queries, dense_values)
+        ]
 
     @cached_property
     def dimension(self) -> int:
src/canopy/knowledge_base/record_encoder/hybrid.py (2 changes: 1 addition & 1 deletion)
@@ -124,7 +124,7 @@ def _encode_queries_batch(self, queries: List[Query]) -> List[KBQuery]:
             zip(dense_queries, sparse_values)
         ]
 
-        return [q.copy(update=dict(values=v, sparse_values=sv)) for q, (v, sv) in
+        return [q.model_copy(update=dict(values=v, sparse_values=sv)) for q, (v, sv) in
                 zip(dense_queries, scaled_values)]
 
     @property
src/canopy/knowledge_base/reranker/cohere.py (2 changes: 1 addition & 1 deletion)
@@ -70,7 +70,7 @@ def rerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
 
         reranked_docs = []
         for rerank_result in response:
-            doc = result.documents[rerank_result.index].copy(
+            doc = result.documents[rerank_result.index].model_copy(
                 deep=True,
                 update=dict(score=rerank_result.relevance_score)
             )
src/canopy/llm/cohere.py (4 changes: 2 additions & 2 deletions)
@@ -396,8 +396,8 @@ def generate_documents_from_stuffing_context_content(
     """
     documents = []
 
-    for result in content.__root__:
+    for result in content.root:
        for snippet in result.snippets:
-            documents.append(snippet.dict())
+            documents.append(snippet.model_dump())
 
     return documents
src/canopy/llm/models.py (11 changes: 6 additions & 5 deletions)
@@ -1,6 +1,6 @@
 from typing import Optional, List, Union
 
-from pydantic import BaseModel
+from pydantic import BaseModel, model_serializer
 
 
 class FunctionPrimitiveProperty(BaseModel):
@@ -17,8 +17,8 @@ class FunctionArrayProperty(BaseModel):
     # because the model is more struggling with them
     description: str
 
-    def dict(self, *args, **kwargs):
-        super_dict = super().dict(*args, **kwargs)
+    def model_dump(self, *args, **kwargs):
+        super_dict = super().model_dump(*args, **kwargs)
         if "items_type" in super_dict:
             super_dict["type"] = "array"
             super_dict["items"] = {"type": super_dict.pop("items_type")}
@@ -32,11 +32,12 @@ class FunctionParameters(BaseModel):
     required_properties: List[FunctionProperty]
     optional_properties: List[FunctionProperty] = []
 
-    def dict(self, *args, **kwargs):
+    @model_serializer()
+    def serialize_model(self):
         return {
             "type": "object",
             "properties": {
-                pro.name: pro.dict(exclude_none=True, exclude={"name"})
+                pro.name: pro.model_dump(exclude_none=True, exclude={"name"})
                 for pro in self.required_properties + self.optional_properties
             },
             "required": [pro.name for pro in self.required_properties],
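
Note on the FunctionParameters change above: in v1 the class overrode dict() to emit a JSON-schema-shaped payload. Overriding model_dump() in v2 works for direct calls (as FunctionArrayProperty still does) but can be bypassed when the model is serialized as a field of another model, so a @model_serializer is the idiomatic replacement and also covers model_dump_json(). A standalone sketch (illustrative, not from the diff):

from pydantic import BaseModel, model_serializer


class Point(BaseModel):
    x: int
    y: int

    @model_serializer()
    def serialize_model(self) -> dict:
        # Emit a custom shape instead of the default field-by-field dict.
        return {"coords": [self.x, self.y]}


print(Point(x=1, y=2).model_dump())        # {'coords': [1, 2]}
print(Point(x=1, y=2).model_dump_json())   # {"coords":[1,2]}
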
src/canopy/llm/openai.py (26 changes: 14 additions & 12 deletions)
@@ -5,7 +5,9 @@
 import openai
 import json
 
-from openai.types.chat import ChatCompletionToolParam
+from openai import Stream
+from openai.types.chat import (ChatCompletionToolParam, ChatCompletionChunk,
+                               ChatCompletion)
 from tenacity import (
     retry,
     stop_after_attempt,
@@ -121,8 +123,8 @@ def chat_completion(self,
             system_message = system_prompt
         else:
             system_message = system_prompt + f"\nContext: {context.to_text()}"
-        messages = [SystemMessage(content=system_message).dict()
-                    ] + [m.dict() for m in chat_history]
+        messages = [SystemMessage(content=system_message).model_dump()
+                    ] + [m.model_dump() for m in chat_history]
         try:
             response = self._client.chat.completions.create(model=model,
                                                             messages=messages,
@@ -131,14 +133,14 @@
         except openai.OpenAIError as e:
             self._handle_chat_error(e)
 
-        def streaming_iterator(response):
-            for chunk in response:
-                yield StreamingChatChunk.parse_obj(chunk)
+        def streaming_iterator(chunks: Stream[ChatCompletionChunk]):
+            for chunk in chunks:
+                yield StreamingChatChunk.model_validate(chunk.model_dump())
 
         if stream:
-            return streaming_iterator(response)
+            return streaming_iterator(cast(Stream[ChatCompletionChunk], response))
 
-        return ChatResponse.parse_obj(response)
+        return ChatResponse.model_validate(cast(ChatCompletion, response).model_dump())
 
     @retry(
         reraise=True,
@@ -206,10 +208,10 @@ def enforced_function_call(self,
         model = model_params_dict.pop("model", self.model_name)
 
         function_dict = cast(ChatCompletionToolParam,
-                             {"type": "function", "function": function.dict()})
+                             {"type": "function", "function": function.model_dump()})
 
-        messages = [SystemMessage(content=system_prompt).dict()
-                    ] + [m.dict() for m in chat_history]
+        messages = [SystemMessage(content=system_prompt).model_dump()
+                    ] + [m.model_dump() for m in chat_history]
         try:
             chat_completion = self._client.chat.completions.create(
                 model=model,
@@ -226,7 +228,7 @@ def enforced_function_call(self,
         result = chat_completion.choices[0].message.tool_calls[0].function.arguments
         arguments = json.loads(result)
 
-        jsonschema.validate(instance=arguments, schema=function.parameters.dict())
+        jsonschema.validate(instance=arguments, schema=function.parameters.model_dump())
         return arguments
 
     async def achat_completion(self,
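
Note on the streaming changes above: v2's model_validate() replaces parse_obj(). The OpenAI SDK's response objects are themselves Pydantic models, so the code dumps them to plain dicts before re-validating them as Canopy's own models. A toy sketch of that model-to-model conversion (stand-in classes, not the real SDK types):

from pydantic import BaseModel


class SdkChunk(BaseModel):       # stand-in for openai.types.chat.ChatCompletionChunk
    id: str
    delta: str


class AppChunk(BaseModel):       # stand-in for canopy's StreamingChatChunk
    id: str
    delta: str


sdk = SdkChunk(id="c-1", delta="hello")
app = AppChunk.model_validate(sdk.model_dump())   # v1: AppChunk.parse_obj(sdk)
print(app)
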
src/canopy/models/api_models.py (8 changes: 2 additions & 6 deletions)
@@ -1,6 +1,6 @@
 from typing import Optional, Sequence, Iterable
 
-from pydantic import BaseModel, Field, validator
+from pydantic import BaseModel, Field
 
 from canopy.models.data_models import MessageBase
 
@@ -20,11 +20,7 @@ class _StreamChoice(BaseModel):
 class TokenCounts(BaseModel):
     prompt_tokens: int
     completion_tokens: int
-    total_tokens: Optional[int] = None
-
-    @validator("total_tokens", always=True)
-    def calc_total_tokens(cls, v, values, **kwargs):
-        return values["prompt_tokens"] + values["completion_tokens"]
+    total_tokens: int
 
 
 class ChatResponse(BaseModel):
src/canopy/models/data_models.py (21 changes: 10 additions & 11 deletions)
@@ -2,8 +2,8 @@
 from enum import Enum
 from typing import Optional, List, Union, Dict, Literal
 
-from pydantic import BaseModel, Field, validator, Extra
-from typing import TypedDict
+from pydantic import field_validator, ConfigDict, BaseModel, Field, RootModel
+from typing_extensions import TypedDict
 
 Metadata = Dict[str, Union[str, int, float, List[str]]]
 
@@ -42,11 +42,10 @@ class Document(BaseModel):
         default_factory=dict,
         description="The document metadata. To learn more about metadata, see https://docs.pinecone.io/docs/manage-data",  # noqa: E501
     )
+    model_config = ConfigDict(extra="forbid", coerce_numbers_to_str=True)
 
-    class Config:
-        extra = Extra.forbid
-
-    @validator("metadata")
+    @field_validator("metadata")
+    @classmethod
     def metadata_reseved_fields(cls, v):
         if "text" in v:
             raise ValueError('Metadata cannot contain reserved field "text"')
@@ -57,7 +56,7 @@ def metadata_reseved_fields(cls, v):
         return v
 
 
-class ContextContent(BaseModel, ABC):
+class ContextContent(RootModel, ABC):
     # Any context should be able to be represented as well formatted text.
     # In the most minimal case, that could simply be a call to `.json()`.
     @abstractmethod
@@ -69,10 +68,10 @@ def __str__(self):
 
 
 class StringContextContent(ContextContent):
-    __root__: str
+    root: str
 
     def to_text(self, **kwargs) -> str:
-        return self.__root__
+        return self.root
 
 
 class Context(BaseModel):
@@ -98,8 +97,8 @@ class MessageBase(BaseModel):
                       "Can be one of ['User', 'Assistant', 'System']")
     content: str = Field(description="The contents of the message.")
 
-    def dict(self, *args, **kwargs):
-        d = super().dict(*args, **kwargs)
+    def model_dump(self, *args, **kwargs):
+        d = super().model_dump(*args, **kwargs)
         d["role"] = d["role"].value
         return d
 
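
Note on the Document model above: v2 validation is stricter than v1 and no longer coerces numbers to strings by default, so the new ConfigDict sets coerce_numbers_to_str=True to keep v1's lax behavior for string fields such as document ids. A quick illustrative sketch (not from the diff):

from pydantic import BaseModel, ConfigDict


class Doc(BaseModel):
    model_config = ConfigDict(coerce_numbers_to_str=True)
    id: str


print(Doc(id=123).id)   # "123" -- without the flag, v2 raises a ValidationError
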
src/canopy/tokenizer/cohere.py (4 changes: 2 additions & 2 deletions)
@@ -105,7 +105,7 @@ def messages_token_count(self, messages: Messages) -> int:
         num_tokens = 0
         for message in messages:
             num_tokens += self.MESSAGE_TOKENS_OVERHEAD
-            for key, value in message.dict().items():
+            for key, value in message.model_dump().items():
                 num_tokens += self.token_count(value)
         num_tokens += self.FIXED_PREFIX_TOKENS
         return num_tokens
@@ -191,7 +191,7 @@ def messages_token_count(self, messages: Messages) -> int:
         num_tokens = 0
         for message in messages:
             num_tokens += self.MESSAGE_TOKENS_OVERHEAD
-            for key, value in message.dict().items():
+            for key, value in message.model_dump().items():
                 num_tokens += self.token_count(value)
         num_tokens += self.FIXED_PREFIX_TOKENS
         return num_tokens
src/canopy/tokenizer/llama.py (2 changes: 1 addition & 1 deletion)
@@ -111,7 +111,7 @@ def messages_token_count(self, messages: Messages) -> int:
         num_tokens = 0
         for message in messages:
             num_tokens += self.MESSAGE_TOKENS_OVERHEAD
-            for key, value in message.dict().items():
+            for key, value in message.model_dump().items():
                 num_tokens += self.token_count(value)
         num_tokens += self.FIXED_PREFIX_TOKENS
         return num_tokens
src/canopy/tokenizer/openai.py (2 changes: 1 addition & 1 deletion)
@@ -91,7 +91,7 @@ def messages_token_count(self, messages: Messages) -> int:
         num_tokens = 0
         for message in messages:
             num_tokens += self.MESSAGE_TOKENS_OVERHEAD
-            for key, value in message.dict().items():
+            for key, value in message.model_dump().items():
                 num_tokens += self.token_count(value)
         num_tokens += self.FIXED_PREFIX_TOKENS
         return num_tokens
src/canopy_cli/cli.py (2 changes: 1 addition & 1 deletion)
@@ -398,7 +398,7 @@ def upsert(index_name: str,
         )
         raise CLIError(msg)
     pd.options.display.max_colwidth = 20
-    click.echo(pd.DataFrame([doc.dict(exclude_none=True) for doc in data[:5]]))
+    click.echo(pd.DataFrame([doc.model_dump(exclude_none=True) for doc in data[:5]]))
     click.echo(click.style(f"\nTotal records: {len(data)}"))
     click.confirm(click.style("\nDoes this data look right?", fg="red"),
                   abort=True)
src/canopy_server/app.py (2 changes: 1 addition & 1 deletion)
@@ -132,7 +132,7 @@ async def chat(
     session_id = request.user or "None"  # noqa: F841
     question_id = str(uuid.uuid4())
     logger.debug(f"Received chat request: {request.messages[-1].content}")
-    model_params = request.dict(exclude={"messages", "stream"})
+    model_params = request.model_dump(exclude={"messages", "stream"})
     answer = await run_in_threadpool(
         chat_engine.chat,
         messages=request.messages,
src/canopy_server/models/v1/api_models.py (6 changes: 2 additions & 4 deletions)
@@ -1,6 +1,6 @@
 from typing import Dict, List, Optional, Union
 
-from pydantic import BaseModel, Field
+from pydantic import ConfigDict, BaseModel, Field
 
 from canopy.models.data_models import Messages, Query, Document
 
@@ -70,9 +70,7 @@ class ChatRequest(BaseModel):
         default=None,
         description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. Unused, reserved for future extensions",  # noqa: E501
     )
-
-    class Config:
-        extra = "ignore"
+    model_config = ConfigDict(extra="ignore")
 
 
 class ContextQueryRequest(BaseModel):