Refactor vector store and encoders, decoupling encoder from vector store (#41)
maxyu1115 authored Oct 15, 2023
1 parent 41d7d56 commit 8d5affa
Showing 13 changed files with 169 additions and 40 deletions.
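The heart of the refactor: MilvusSentenceVectorStore now takes any TextEncoder instead of hardcoding the Universal Sentence Encoder. A minimal sketch of the new wiring, using only classes introduced in this diff (the ADA key below is a placeholder):

from memas.encoder.openai_ada_encoder import ADATextEncoder
from memas.encoder.universal_sentence_encoder import USETextEncoder
from memas.storage_driver.corpus_vector_store import MilvusSentenceVectorStore

# Any TextEncoder can back the store; the Milvus collection name and
# vector dimension are derived from ENCODER_NAME / VECTOR_DIMENSION.
store = MilvusSentenceVectorStore(USETextEncoder())
# store = MilvusSentenceVectorStore(ADATextEncoder("PLACE_HOLDER"))
store.init()  # loads the Milvus collection and initializes the encoder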
1 change: 1 addition & 0 deletions benchmark-tests/README.md
@@ -0,0 +1 @@
This subdirectory is meant for random benchmarking experiments on memas internal components.
86 changes: 86 additions & 0 deletions benchmark-tests/compare_encoders.py
@@ -0,0 +1,86 @@
import datasets
from datetime import datetime
from memas.interface.encoder import TextEncoder
from memas.encoder import openai_ada_encoder, universal_sentence_encoder
from memas.text_parsing import text_parsers


def prep_dataset():
    wikipedia = datasets.load_dataset("wikipedia", "20220301.en")
    test_sentences = []
    i = 0

    start = datetime.now()
    for row in wikipedia["train"]:
        test_sentences.extend(text_parsers.split_doc(row["text"], 1024))
        i += 1
        if i > 10:
            break
    end = datetime.now()
    print(f"Splitting {i} documents into {len(test_sentences)} sentences took {(end - start).total_seconds()}s")

    batch_sentences = {}
    for batch_size in [5, 10, 20, 50, 100]:
        batched_list = [test_sentences[i:i+batch_size] for i in range(0, len(test_sentences), batch_size)]
        # pop the last one since likely not fully populated
        batched_list.pop()
        batch_sentences[batch_size] = batched_list

    return test_sentences, batch_sentences


def benchmark_single(test_sentences: list[str], encoder: TextEncoder):
    start = datetime.now()
    i = 0
    for sentence in test_sentences:
        i += 1
        try:
            encoder.embed(sentence)
        except Exception as err:
            print(err)
            print(f"{i}!", sentence)

    end = datetime.now()
    return (end - start).total_seconds()


def benchmark_batch(batched_list: list[list[str]], encoder: TextEncoder):
    start = datetime.now()
    i = 0
    for batch in batched_list:
        i += 1
        try:
            encoder.embed_multiple(batch)
        except Exception as err:
            print(err)
            print(f"{i}!", batch)
    end = datetime.now()
    return (end - start).total_seconds()


def compare_encoders(encoders: dict[str, TextEncoder]):
    test_sentences, batch_sentences = prep_dataset()
    print(len(test_sentences))
    output = {"single": {}}
    for name, encoder in encoders.items():
        single = benchmark_single(test_sentences, encoder)
        print(f"[{name}] Single: total {single}s, avg {single/len(test_sentences)}s per item")
        output["single"][name] = (single, single/len(test_sentences))

    for batch_size, batched_list in batch_sentences.items():
        output[batch_size] = {}
        for name, encoder in encoders.items():
            batch_time = benchmark_batch(batched_list, encoder)
            output[batch_size][name] = (batch_time, batch_time/len(batched_list))
            print(f"[{name}] {batch_size} batch: total {batch_time}s, avg {batch_time/len(batched_list)}s per item")
    return output


if __name__ == "__main__":
    USE_encoder = universal_sentence_encoder.USETextEncoder()
    USE_encoder.init()
    output = compare_encoders({
        "ada": openai_ada_encoder.ADATextEncoder("PLACE_HOLDER"),
        "use": USE_encoder
    })
    print(output)
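For reference, the dictionary printed at the end maps "single" and each batch size to per-encoder timings; a sketch of its shape, inferred from the code above:

# output = {
#     "single": {name: (total_seconds, avg_seconds_per_sentence), ...},
#     5: {name: (total_seconds, avg_seconds_per_batch), ...},
#     10: {...}, 20: {...}, 50: {...}, 100: {...},
# }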
4 changes: 4 additions & 0 deletions benchmark-tests/requirements-no-deps.txt
@@ -0,0 +1,4 @@
# This is needed to resolve a bug with apache beam + datasets
# Read https://github.com/huggingface/datasets/issues/5613 for more details
multiprocess==0.70.11
dill==0.3.6
7 changes: 7 additions & 0 deletions benchmark-tests/requirements.txt
@@ -0,0 +1,7 @@
ipykernel
datasets==2.13.1
apache_beam==2.49.0
openai

memas-sdk
memas-client
4 changes: 4 additions & 0 deletions benchmark-tests/setup-env.sh
@@ -0,0 +1,4 @@
#!/bin/bash
pip install -r requirements.txt
# TODO: remove this after beam/datasets package upgrade
pip install --no-deps -r requirements-no-deps.txt
4 changes: 2 additions & 2 deletions integration-tests/conftest.py
@@ -14,7 +14,7 @@
# TODO: properly create different sets of configs and load them according to scenario

corpus_doc_store.CORPUS_INDEX = "memas-integ-test"
- corpus_vector_store.USE_COLLECTION_NAME = "memas_USE_integ_test"
+ corpus_vector_store.ENCODER_COLLECTION_NAME = "memas_{encoder}_integ_test"


CONFIG_PATH = "../integration-tests/integ-test-config.yml"
@@ -33,7 +33,7 @@ def clean_resources():

    try:
        milvus_connection.connect("default", host=constants.milvus_ip, port=constants.milvus_port)
-         utility.drop_collection(collection_name=corpus_vector_store.USE_COLLECTION_NAME)
+         utility.drop_collection(collection_name=corpus_vector_store.ENCODER_COLLECTION_NAME.format(encoder="USE"))
        milvus_connection.disconnect("default")
    except Exception:
        pass
5 changes: 3 additions & 2 deletions integration-tests/storage_driver/test_corpus_vector_store.py
@@ -1,10 +1,11 @@
import numpy as np
import uuid
import time
+ from memas.encoder.universal_sentence_encoder import USETextEncoder
from memas.interface.storage_driver import DocumentEntity
- from memas.storage_driver.corpus_vector_store import MilvusUSESentenceVectorStore
+ from memas.storage_driver.corpus_vector_store import MilvusSentenceVectorStore

- store = MilvusUSESentenceVectorStore()
+ store = MilvusSentenceVectorStore(USETextEncoder())


def test_init():
3 changes: 2 additions & 1 deletion memas/context_manager.py
@@ -7,6 +7,7 @@
from cassandra.cqlengine import connection as c_connection
from elasticsearch import Elasticsearch
from pymilvus import connections as milvus_connection
+ from memas.encoder.universal_sentence_encoder import USETextEncoder
from memas.interface.exceptions import IllegalStateException
from memas.interface.storage_driver import CorpusDocumentMetadataStore, CorpusDocumentStore, CorpusVectorStore, MemasMetadataStore
from memas.storage_driver import corpus_doc_metadata, corpus_doc_store, corpus_vector_store, memas_metadata
@@ -64,7 +65,7 @@ def __init__(self, app_config: Config):
        # Data Stores
        self.memas_metadata: MemasMetadataStore = memas_metadata.SINGLETON
        self.corpus_metadata: CorpusDocumentMetadataStore = corpus_doc_metadata.SINGLETON
-         self.corpus_vec: CorpusVectorStore = corpus_vector_store.SINGLETON
+         self.corpus_vec: CorpusVectorStore = corpus_vector_store.MilvusSentenceVectorStore(USETextEncoder())
        self.corpus_doc: CorpusDocumentStore

        # clients
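With the store decoupled, pointing the app at a different encoder is now a one-line change here. A hypothetical example (the OPENAI_API_KEY config key is an assumption, not part of this diff):

# Hypothetical: back the corpus vector store with OpenAI ADA instead of USE.
from memas.encoder.openai_ada_encoder import ADATextEncoder
self.corpus_vec: CorpusVectorStore = corpus_vector_store.MilvusSentenceVectorStore(
    ADATextEncoder(app_config["OPENAI_API_KEY"]))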
22 changes: 22 additions & 0 deletions memas/encoder/openai_ada_encoder.py
@@ -0,0 +1,22 @@
import numpy as np
import openai
from memas.interface.encoder import TextEncoder


ADA_MODEL = "text-embedding-ada-002"


class ADATextEncoder(TextEncoder):
    def __init__(self, api_key) -> None:
        super().__init__(ENCODER_NAME="ADA", VECTOR_DIMENSION=1536)
        openai.api_key = api_key

    def init(self):
        pass

    def embed(self, text: str) -> np.ndarray:
        return np.array(openai.Embedding.create(input=[text], model=ADA_MODEL)['data'][0]['embedding'])

    def embed_multiple(self, text_list: list[str]) -> list[np.ndarray]:
        embeddings = openai.Embedding.create(input=text_list, model=ADA_MODEL)['data']
        return [np.array(resp['embedding']) for resp in embeddings]
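A quick usage sketch of the new encoder (the key is a placeholder; ada-002 embeddings are 1536-dimensional, matching the VECTOR_DIMENSION passed above):

encoder = ADATextEncoder("PLACE_HOLDER")  # a real OpenAI API key is required
encoder.init()  # no-op for ADA
vec = encoder.embed("hello world")
assert vec.shape == (1536,)
vecs = encoder.embed_multiple(["first sentence", "second sentence"])
assert len(vecs) == 2 and vecs[0].shape == (1536,)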
4 changes: 2 additions & 2 deletions memas/encoder/universal_sentence_encoder.py
@@ -15,14 +15,14 @@

class USETextEncoder(TextEncoder):
    def __init__(self, model_url: str = USE_LOCAL_MODEL_URL) -> None:
-         super().__init__()
+         super().__init__(ENCODER_NAME="USE", VECTOR_DIMENSION=USE_VECTOR_DIMENSION)
        self.model_url: str = model_url

    def init(self):
        self.encoder = hub.load(self.model_url)

    def embed(self, text: str) -> np.ndarray:
-         return self.encoder(text).numpy()
+         return self.encoder([text]).numpy()

    def embed_multiple(self, text_list: list[str]) -> list[np.ndarray]:
        return [x.numpy() for x in self.encoder(text_list)]
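One note on the embed fix: the TF Hub USE model embeds batches of strings, so the single text is now wrapped in a list. An observation not stated in the diff: the batched call returns a 2-D array, so this embed yields shape (1, 512) rather than a flat vector like ADATextEncoder's (1536,):

encoder = USETextEncoder()
encoder.init()  # loads the TF Hub model
emb = encoder.embed("hello")
# emb.shape == (1, 512): a batch of one 512-d vector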
4 changes: 4 additions & 0 deletions memas/interface/encoder.py
@@ -3,6 +3,10 @@


class TextEncoder(ABC):
+     def __init__(self, ENCODER_NAME: str, VECTOR_DIMENSION: int) -> None:
+         self.ENCODER_NAME: str = ENCODER_NAME
+         self.VECTOR_DIMENSION: int = VECTOR_DIMENSION
+
    @abstractmethod
    def init(self):
        """Initialize the encoder
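Since the interface now carries the encoder's name and output width, the vector store can derive its collection name and schema from any implementation. A minimal custom encoder against this interface (the embed/embed_multiple signatures follow the two concrete encoders in this commit):

import numpy as np
from memas.interface.encoder import TextEncoder


class RandomTextEncoder(TextEncoder):
    """Toy encoder for tests: returns random 128-d vectors."""

    def __init__(self) -> None:
        super().__init__(ENCODER_NAME="RAND", VECTOR_DIMENSION=128)

    def init(self):
        pass  # nothing to load

    def embed(self, text: str) -> np.ndarray:
        return np.random.rand(self.VECTOR_DIMENSION)

    def embed_multiple(self, text_list: list[str]) -> list[np.ndarray]:
        return [self.embed(text) for text in text_list]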
64 changes: 31 additions & 33 deletions memas/storage_driver/corpus_vector_store.py
@@ -11,17 +11,18 @@
    DataType,
    Collection,
)
- from memas.encoder.universal_sentence_encoder import USE_VECTOR_DIMENSION, USETextEncoder
+ from memas.interface.encoder import TextEncoder
from memas.interface.storage_driver import CorpusVectorStore, DocumentEntity
from memas.text_parsing.text_parsers import split_doc


_log = logging.getLogger(__name__)


- USE_COLLECTION_NAME = "corpus_USE_sentence_store"
+ ENCODER_COLLECTION_NAME = "corpus_{encoder}_sentence_store"


+ COMPOSITE_ID = "composite_id"
CORPUS_FIELD = "corpus_id"
DOCUMENT_NAME = "document_name"
EMBEDDING_FIELD = "embedding"
@@ -32,26 +33,8 @@
MAX_TEXT_LENGTH = 1024


- fields = [
-     # The first 32 length is the document id, while the later 32 is the sentence id.
-     # the sentence id is just used to avoid key collision.
-     FieldSchema(name="composite_id", dtype=DataType.VARCHAR,
-                 max_length=64, is_primary=True, auto_id=False),
-     FieldSchema(name=CORPUS_FIELD, dtype=DataType.VARCHAR,
-                 max_length=32, is_partition_key=True),
-     FieldSchema(name=DOCUMENT_NAME, dtype=DataType.VARCHAR,
-                 max_length=256),
-     FieldSchema(name=TEXT_PREVIEW, dtype=DataType.VARCHAR, max_length=MAX_TEXT_LENGTH),
-     FieldSchema(name=EMBEDDING_FIELD, dtype=DataType.FLOAT_VECTOR, dim=USE_VECTOR_DIMENSION),
-     FieldSchema(name=START_FIELD, dtype=DataType.INT64),
-     FieldSchema(name=END_FIELD, dtype=DataType.INT64),
- ]
- sentance_schema = CollectionSchema(
-     fields, "Corpus Vector Table for storing Universal Sentence Encoder embeddings")


@dataclass
- class USESentenceObject:
+ class MilvusSentenceObject:
    composite_id: str
    corpus_id: str
    document_name: str
@@ -68,7 +51,7 @@ def hash_sentence_id(document_id: uuid.UUID, sentence: str) -> uuid.UUID:
    return uuid.uuid5(document_id, sentence)


- def convert_batch(objects: list[USESentenceObject]):
+ def convert_batch(objects: list[MilvusSentenceObject]):
    composite_ids, corpus_ids, document_names, text_previews, embeddings, start_indices, end_indices = [], [], [], [], [], [], []
    for obj in objects:
        composite_ids.append(obj.composite_id)
@@ -82,13 +65,31 @@ def convert_batch(objects: list[MilvusSentenceObject]):
    return [composite_ids, corpus_ids, document_names, text_previews, np.row_stack(embeddings), start_indices, end_indices]


- class MilvusUSESentenceVectorStore(CorpusVectorStore):
-     def __init__(self) -> None:
-         super().__init__(USETextEncoder())
+ class MilvusSentenceVectorStore(CorpusVectorStore):
+     def __init__(self, sentence_encoder: TextEncoder) -> None:
+         super().__init__(sentence_encoder)
        # Don't instantiate the Collection object yet, since the constructor creates the collection in milvus
        self.collection: Collection
+         fields = [
+             # The first 32 length is the document id, while the later 32 is the sentence id.
+             # the sentence id is just used to avoid key collision.
+             FieldSchema(name=COMPOSITE_ID, dtype=DataType.VARCHAR,
+                         max_length=64, is_primary=True, auto_id=False),
+             FieldSchema(name=CORPUS_FIELD, dtype=DataType.VARCHAR,
+                         max_length=32, is_partition_key=True),
+             FieldSchema(name=DOCUMENT_NAME, dtype=DataType.VARCHAR,
+                         max_length=256),
+             FieldSchema(name=TEXT_PREVIEW, dtype=DataType.VARCHAR, max_length=MAX_TEXT_LENGTH),
+             FieldSchema(name=EMBEDDING_FIELD, dtype=DataType.FLOAT_VECTOR, dim=self.encoder.VECTOR_DIMENSION),
+             FieldSchema(name=START_FIELD, dtype=DataType.INT64),
+             FieldSchema(name=END_FIELD, dtype=DataType.INT64),
+         ]
+         self.sentance_schema: CollectionSchema = CollectionSchema(
+             fields, "Corpus Vector Table for storing sentence embeddings")
+         self.collection_name: str = ENCODER_COLLECTION_NAME.format(encoder=self.encoder.ENCODER_NAME)

    def first_init(self):
-         self.collection: Collection = Collection(USE_COLLECTION_NAME, sentance_schema)
+         self.collection: Collection = Collection(self.collection_name, self.sentance_schema)
        index = {
            "index_type": "FLAT",
            "metric_type": "L2",
@@ -99,7 +100,7 @@ def first_init(self):
        self.encoder.init()

    def init(self):
-         self.collection: Collection = Collection(USE_COLLECTION_NAME, sentance_schema)
+         self.collection: Collection = Collection(self.collection_name, self.sentance_schema)
        self.collection.load()
        self.encoder.init()

@@ -135,7 +136,7 @@ def save_documents(self, doc_entities: list[DocumentEntity]) -> bool:
        sentence_count = 0
        for doc_entity in doc_entities:
            sentences = split_doc(doc_entity.document, MAX_TEXT_LENGTH)
-             objects: list[USESentenceObject] = []
+             objects: list[MilvusSentenceObject] = []
            sentence_count = sentence_count + len(sentences)

            doc_embeddings = self.encoder.embed_multiple(sentences)
@@ -146,14 +147,11 @@
                sentence_id = hash_sentence_id(doc_entity.document_id, sentence)
                composite_id = doc_entity.document_id.hex + sentence_id.hex
                end = start + len(sentence)
-                 objects.append(USESentenceObject(composite_id, doc_entity.corpus_id.hex, doc_entity.document_name,
-                                sentence[:MAX_TEXT_LENGTH], doc_embeddings[index], start, end))
+                 objects.append(MilvusSentenceObject(composite_id, doc_entity.corpus_id.hex, doc_entity.document_name,
+                                sentence[:MAX_TEXT_LENGTH], doc_embeddings[index], start, end))
                index = index + 1
                start = end

            insert_count = insert_count + self.collection.insert(convert_batch(objects)).insert_count

        return insert_count == sentence_count


- SINGLETON: CorpusVectorStore = MilvusUSESentenceVectorStore()
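The composite-id scheme deserves a worked example: two 32-character UUID hex strings concatenate to exactly the 64-character VARCHAR primary key, and uuid5 makes the sentence id deterministic per (document, sentence) pair:

import uuid

document_id = uuid.uuid4()
sentence = "Some sentence from the document."
sentence_id = uuid.uuid5(document_id, sentence)  # as in hash_sentence_id above

composite_id = document_id.hex + sentence_id.hex
assert len(composite_id) == 64  # 32 hex characters per UUID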
1 change: 1 addition & 0 deletions requirements.txt
@@ -5,5 +5,6 @@ pymilvus==2.2.8
elasticsearch==8.8.0
scylla-driver==3.26.2
nltk
+ openai
gunicorn[eventlet]
futurist
