Skip to content

Commit

Permalink
make sure id field of Documents in FAISS docstore have the same id as…
Browse files Browse the repository at this point in the history
… values in index_to_docstore_id, implement get_by_ids method
  • Loading branch information
nhols committed Oct 4, 2024
1 parent e8e5d67 commit 9e7c4df
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 6 deletions.
22 changes: 16 additions & 6 deletions libs/community/langchain_community/vectorstores/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Iterable,
List,
Optional,
Sequence,
Sized,
Tuple,
Union,
Expand Down Expand Up @@ -292,13 +293,17 @@ def __add(
)

_len_check_if_sized(texts, metadatas, "texts", "metadatas")

ids = ids or [str(uuid.uuid4()) for _ in texts]
_len_check_if_sized(texts, ids, "texts", "ids")

_metadatas = metadatas or ({} for _ in texts)
documents = [
Document(page_content=t, metadata=m) for t, m in zip(texts, _metadatas)
Document(id=id_, page_content=t, metadata=m)
for id_, t, m in zip(ids, texts, _metadatas)
]

_len_check_if_sized(documents, embeddings, "documents", "embeddings")
_len_check_if_sized(documents, ids, "documents", "ids")

if ids and len(ids) != len(set(ids)):
raise ValueError("Duplicate ids found in the ids list.")
Expand All @@ -310,7 +315,6 @@ def __add(
self.index.add(vector)

# Add information to docstore and index.
ids = ids or [str(uuid.uuid4()) for _ in texts]
self.docstore.add({id_: doc for id_, doc in zip(ids, documents)})
starting_len = len(self.index_to_docstore_id)
index_to_id = {starting_len + j: id_ for j, id_ in enumerate(ids)}
Expand Down Expand Up @@ -1359,10 +1363,16 @@ def _create_filter_func(

def filter_func(metadata: Dict[str, Any]) -> bool:
return all(
metadata.get(key) in value
if isinstance(value, list)
else metadata.get(key) == value
(
metadata.get(key) in value
if isinstance(value, list)
else metadata.get(key) == value
)
for key, value in filter.items() # type: ignore
)

return filter_func

def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
docs = [self.docstore.search(id_) for id_ in ids]
return [doc for doc in docs if isinstance(doc, Document)]
30 changes: 30 additions & 0 deletions libs/community/tests/unit_tests/vectorstores/test_faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,3 +802,33 @@ def test_faiss_with_duplicate_ids() -> None:
FAISS.from_texts(texts, FakeEmbeddings(), ids=duplicate_ids)

assert "Duplicate ids found in the ids list." in str(exc_info.value)


@pytest.mark.requires("faiss")
def test_faiss_document_ids() -> None:
"""Test whether FAISS assigns the correct document ids."""
ids = ["id1", "id2", "id3"]
texts = ["foo", "bar", "baz"]

vstore = FAISS.from_texts(texts, FakeEmbeddings(), ids=ids)
for id_, text in (ids, texts):
doc = vstore.docstore.search(id_)
assert doc.id == id_
assert doc.page_content == text


@pytest.mark.requires("faiss")
def test_faiss_get_by_ids() -> None:
"""Test FAISS `get_by_ids` method."""
ids = ["id1", "id2", "id3"]
texts = ["foo", "bar", "baz"]

vstore = FAISS.from_texts(texts, FakeEmbeddings(), ids=ids)
docs = vstore.get_by_ids(ids)
assert len(docs) == 3
assert {doc.id for doc in docs} == set(ids)

for id_ in ids:
res = vstore.get_by_ids([id_])
assert len(res) == 1
assert res[0] == id_

0 comments on commit 9e7c4df

Please sign in to comment.