Commit

Fix: FAISSDocumentStore - make write_documents properly work in combination with `update_embeddings` (deepset-ai#5221)

* Update VERSION.txt

* first draft

* simplify method and test

* rm unnecessary pb.close

* integrate feedback
anakin87 authored Jul 3, 2023
1 parent aee8628 commit 1be3936
Showing 2 changed files with 74 additions and 46 deletions.
101 changes: 55 additions & 46 deletions haystack/document_stores/faiss.py
@@ -257,54 +257,63 @@ def write_documents(
         document_objects = self._handle_duplicate_documents(
             documents=document_objects, index=index, duplicate_documents=duplicate_documents
         )
-        if len(document_objects) > 0:
-            add_vectors = all(doc.embedding is not None for doc in document_objects)
-
-            if self.duplicate_documents == "overwrite" and add_vectors:
-                logger.warning(
-                    "You have to provide `duplicate_documents = 'overwrite'` arg and "
-                    "`FAISSDocumentStore` does not support update in existing `faiss_index`.\n"
-                    "Please call `update_embeddings` method to repopulate `faiss_index`"
-                )
-
-            vector_id = self.faiss_indexes[index].ntotal
-            with tqdm(
-                total=len(document_objects), disable=not self.progress_bar, position=0, desc="Writing Documents"
-            ) as progress_bar:
-                for i in range(0, len(document_objects), batch_size):
-                    if add_vectors:
-                        if not self.faiss_indexes[index].is_trained:
-                            raise ValueError(
-                                "FAISS index of type {} must be trained before adding vectors. Call `train_index()` "
-                                "method before adding the vectors. For details, refer to the documentation: "
-                                "[FAISSDocumentStore API](https://docs.haystack.deepset.ai/reference/document-store-api#faissdocumentstoretrain_index)."
-                                "".format(self.faiss_index_factory_str)
-                            )
-
-                        embeddings = [doc.embedding for doc in document_objects[i : i + batch_size]]
-                        embeddings_to_index = np.array(embeddings, dtype="float32")
-
-                        if self.similarity == "cosine":
-                            self.normalize_embedding(embeddings_to_index)
-
-                        self.faiss_indexes[index].add(embeddings_to_index)
-
-                    docs_to_write_in_sql = []
-                    for doc in document_objects[i : i + batch_size]:
-                        meta = doc.meta
-                        if add_vectors:
-                            meta["vector_id"] = vector_id
-                            vector_id += 1
-                        docs_to_write_in_sql.append(doc)
-
-                    super(FAISSDocumentStore, self).write_documents(
-                        docs_to_write_in_sql,
-                        index=index,
-                        duplicate_documents=duplicate_documents,
-                        batch_size=batch_size,
-                    )
-                    progress_bar.update(batch_size)
-            progress_bar.close()
+        if len(document_objects) == 0:
+            return
+
+        vector_id = self.faiss_indexes[index].ntotal
+        add_vectors = all(doc.embedding is not None for doc in document_objects)
+
+        if vector_id > 0 and self.duplicate_documents == "overwrite" and add_vectors:
+            logger.warning(
+                "`FAISSDocumentStore` is adding new vectors to an existing `faiss_index`.\n"
+                "Please call `update_embeddings` method to correctly repopulate `faiss_index`"
+            )
+
+        with tqdm(
+            total=len(document_objects), disable=not self.progress_bar, position=0, desc="Writing Documents"
+        ) as progress_bar:
+            for i in range(0, len(document_objects), batch_size):
+                batch_documents = document_objects[i : i + batch_size]
+                if add_vectors:
+                    if not self.faiss_indexes[index].is_trained:
+                        raise ValueError(
+                            f"FAISS index of type {self.faiss_index_factory_str} must be trained before adding vectors. Call `train_index()` "
+                            "method before adding the vectors. For details, refer to the documentation: "
+                            "[FAISSDocumentStore API](https://docs.haystack.deepset.ai/reference/document-store-api#faissdocumentstoretrain_index)."
+                        )
+
+                    embeddings = [doc.embedding for doc in batch_documents]
+                    embeddings_to_index = np.array(embeddings, dtype="float32")
+
+                    if self.similarity == "cosine":
+                        self.normalize_embedding(embeddings_to_index)
+
+                    self.faiss_indexes[index].add(embeddings_to_index)
+
+                # write_documents method (duplicate_documents="overwrite") should properly work in combination with
+                # update_embeddings method (update_existing_embeddings=False).
+                # If no new embeddings are provided, we save the existing FAISS vector ids
+                elif self.duplicate_documents == "overwrite":
+                    existing_docs = self.get_documents_by_id(ids=[doc.id for doc in batch_documents], index=index)
+                    existing_docs_vector_ids = {
+                        doc.id: doc.meta["vector_id"] for doc in existing_docs if doc.meta and "vector_id" in doc.meta
+                    }
+
+                docs_to_write_in_sql = []
+                for doc in batch_documents:
+                    meta = doc.meta
+                    if add_vectors:
+                        meta["vector_id"] = vector_id
+                        vector_id += 1
+                    elif self.duplicate_documents == "overwrite" and doc.id in existing_docs_vector_ids:
+                        meta["vector_id"] = existing_docs_vector_ids[doc.id]
+                    docs_to_write_in_sql.append(doc)
+
+                super(FAISSDocumentStore, self).write_documents(
+                    docs_to_write_in_sql, index=index, duplicate_documents=duplicate_documents, batch_size=batch_size
+                )
+                progress_bar.update(batch_size)
 
     def _create_document_field_map(self) -> Dict:
         return {self.index: self.embedding_field}
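For context, a minimal usage sketch of the interaction this commit fixes: with duplicate_documents="overwrite", re-written documents now keep their existing vector_id, so a later update_embeddings(update_existing_embeddings=False) only embeds documents that have no vector yet. The sketch assumes a Haystack 1.x setup with faiss installed; the model name, store settings, and document contents are illustrative placeholders, not part of the commit.

from haystack import Document
from haystack.document_stores import FAISSDocumentStore
from haystack.nodes import EmbeddingRetriever

document_store = FAISSDocumentStore(
    sql_url="sqlite:///faiss_store.db",
    embedding_dim=384,
    similarity="cosine",
    duplicate_documents="overwrite",
)
retriever = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",
)

docs = [Document(content="FAISS keeps the vectors."), Document(content="SQL keeps the metadata.")]

# First pass: write documents without embeddings, then embed them.
document_store.write_documents(docs)
document_store.update_embeddings(retriever=retriever)

# Second pass: re-write the same documents plus a new one. With this fix, the
# existing documents keep their vector_id, so the call below only embeds the
# new document instead of losing the mapping for the re-written ones.
document_store.write_documents(docs + [Document(content="A brand new document.")])
document_store.update_embeddings(retriever=retriever, update_existing_embeddings=False)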
19 changes: 19 additions & 0 deletions test/document_stores/test_faiss.py
@@ -204,6 +204,25 @@ def test_update_docs_different_indexes(self, ds, documents_with_embeddings):
         assert len(docs_from_index_b) == len(docs_b)
         assert {int(doc.meta["vector_id"]) for doc in docs_from_index_b} == {0, 1, 2, 3}
 
+    @pytest.mark.integration
+    def test_dont_update_existing_embeddings(self, ds, docs):
+        retriever = MockDenseRetriever(document_store=ds)
+        first_doc_id = docs[0].id
+
+        for i in range(1, 4):
+            ds.write_documents(docs[:i])
+            ds.update_embeddings(retriever=retriever, update_existing_embeddings=False)
+
+            assert ds.get_document_count() == i
+            assert ds.get_embedding_count() == i
+            assert ds.get_document_by_id(id=first_doc_id).meta["vector_id"] == "0"
+
+            # Check if the embeddings of the first document remain unchanged after multiple updates
+            if i == 1:
+                first_doc_embedding = ds.get_document_by_id(id=first_doc_id).embedding
+            else:
+                assert np.array_equal(ds.get_document_by_id(id=first_doc_id).embedding, first_doc_embedding)
+
     @pytest.mark.integration
     def test_passing_index_from_outside(self, documents_with_embeddings, tmp_path):
         d = 768
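The new test depends on the ds and docs fixtures and the MockDenseRetriever helper defined elsewhere in Haystack's test suite, none of which appear in this diff. As a rough, hypothetical illustration of the shape of such a helper (update_embeddings mainly needs the retriever to expose embed_documents returning one vector per document), something like the following would do; the name, dimension, and random vectors are assumptions, not the real fixture.

import numpy as np


class FakeDenseRetriever:
    """Hypothetical stand-in for a dense retriever: one random vector per document."""

    def __init__(self, document_store, embedding_dim=768):
        self.document_store = document_store
        self.embedding_dim = embedding_dim

    def embed_documents(self, documents):
        # Return an array of shape (num_documents, embedding_dim) for the store to index.
        return np.random.rand(len(documents), self.embedding_dim).astype(np.float32)

Random vectors are sufficient here because the test's whole point is that, with update_existing_embeddings=False, the embedding the first document received in the first iteration must remain unchanged in later iterations.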
