Skip to content

Commit

Permalink
feat: Add id field for Element (#183)
Browse files Browse the repository at this point in the history
  • Loading branch information
ludwiktrammer authored Nov 13, 2024
1 parent a91fcc3 commit a2e9abc
Show file tree
Hide file tree
Showing 10 changed files with 58 additions and 38 deletions.
2 changes: 1 addition & 1 deletion examples/document-search/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ async def main() -> None:

print()
print(f"Documents similar to: {query}")
print([element.get_key() for element in results])
print([element.get_text_representation() for element in results])


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion examples/document-search/chroma_otel.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ async def main() -> None:

print()
print(f"Documents similar to: {query}")
print([element.get_key() for element in results])
print([element.get_text_representation() for element in results])


if __name__ == "__main__":
Expand Down
3 changes: 2 additions & 1 deletion packages/ragbits-core/src/ragbits/core/vector_stores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ class VectorStoreEntry(BaseModel):
An object representing a vector database entry.
"""

key: str
id: str
vector: list[float]
content: str
metadata: dict


Expand Down
18 changes: 9 additions & 9 deletions packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import json
from hashlib import sha256
from typing import Literal

import chromadb
Expand Down Expand Up @@ -84,9 +83,8 @@ async def store(self, entries: list[VectorStoreEntry]) -> None:
Args:
entries: The entries to store.
"""
# TODO: Think about better id components for hashing and move hash computing to VectorStoreEntry
ids = [sha256(entry.key.encode("utf-8")).hexdigest() for entry in entries]
documents = [entry.key for entry in entries]
ids = [entry.id for entry in entries]
documents = [entry.content for entry in entries]
embeddings = [entry.vector for entry in entries]
metadatas = [entry.metadata for entry in entries]
metadatas = (
Expand Down Expand Up @@ -132,12 +130,13 @@ async def retrieve(self, vector: list[float], options: VectorStoreOptions | None

return [
VectorStoreEntry(
key=document,
id=id,
content=document,
vector=list(embeddings),
metadata=metadata, # type: ignore
)
for batch in zip(metadatas, embeddings, distances, documents, strict=False)
for metadata, embeddings, distance, document in zip(*batch, strict=False)
for batch in zip(ids, metadatas, embeddings, distances, documents, strict=False)
for id, metadata, embeddings, distance, document in zip(*batch, strict=False)
if options.max_distance is None or distance <= options.max_distance
]

Expand Down Expand Up @@ -182,9 +181,10 @@ async def list(

return [
VectorStoreEntry(
key=document,
id=id,
content=document,
vector=list(embedding),
metadata=metadata, # type: ignore
)
for metadata, embedding, document in zip(metadatas, embeddings, documents, strict=False)
for id, metadata, embedding, document in zip(ids, metadatas, embeddings, documents, strict=False)
]
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def store(self, entries: list[VectorStoreEntry]) -> None:
entries: The entries to store.
"""
for entry in entries:
self._storage[entry.key] = entry
self._storage[entry.id] = entry

@traceable
async def retrieve(self, vector: list[float], options: VectorStoreOptions | None = None) -> list[VectorStoreEntry]:
Expand Down
17 changes: 12 additions & 5 deletions packages/ragbits-core/tests/unit/vector_stores/test_chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ async def test_get_chroma_collection(mock_chromadb_store: ChromaVectorStore) ->
async def test_store(mock_chromadb_store: ChromaVectorStore) -> None:
data = [
VectorStoreEntry(
key="test_key",
id="test_key",
content="test content",
vector=[0.1, 0.2, 0.3],
metadata={
"content": "test content",
Expand All @@ -39,15 +40,15 @@ async def test_store(mock_chromadb_store: ChromaVectorStore) -> None:

mock_chromadb_store._client.get_or_create_collection().add.assert_called_once() # type: ignore
mock_chromadb_store._client.get_or_create_collection().add.assert_called_with( # type: ignore
ids=["92488e1e3eeecdf99f3ed2ce59233efb4b4fb612d5655c0ce9ea52b5a502e655"],
ids=[data[0].id],
embeddings=[[0.1, 0.2, 0.3]],
metadatas=[
{
"__metadata": '{"content": "test content", "document": {"title": "test title", "source":'
' {"path": "/test/path"}, "document_type": "test_type"}}',
}
],
documents=["test_key"],
documents=["test content"],
)


Expand Down Expand Up @@ -85,7 +86,7 @@ async def test_retrieve(
],
"embeddings": [[[0.12, 0.25, 0.29], [0.13, 0.26, 0.30]]],
"distances": [[0.1, 0.2]],
"documents": [["test_key_1", "test_key_2"]],
"documents": [["test content 1", "test content 2"]],
"ids": [["test_id_1", "test_id_2"]],
}

Expand All @@ -96,6 +97,8 @@ async def test_retrieve(
assert entry.metadata["content"] == result["content"]
assert entry.metadata["document"]["title"] == result["title"]
assert entry.vector == result["vector"]
assert entry.id == f"test_id_{results.index(result) + 1}"
assert entry.content == result["content"]


async def test_list(mock_chromadb_store: ChromaVectorStore) -> None:
Expand All @@ -112,7 +115,7 @@ async def test_list(mock_chromadb_store: ChromaVectorStore) -> None:
},
],
"embeddings": [[0.12, 0.25, 0.29], [0.13, 0.26, 0.30]],
"documents": ["test_key", "test_key_2"],
"documents": ["test content 1", "test content2"],
"ids": ["test_id_1", "test_id_2"],
}

Expand All @@ -122,6 +125,10 @@ async def test_list(mock_chromadb_store: ChromaVectorStore) -> None:
assert entries[0].metadata["content"] == "test content"
assert entries[0].metadata["document"]["title"] == "test title"
assert entries[0].vector == [0.12, 0.25, 0.29]
assert entries[0].content == "test content 1"
assert entries[0].id == "test_id_1"
assert entries[1].metadata["content"] == "test content 2"
assert entries[1].metadata["document"]["title"] == "test title 2"
assert entries[1].vector == [0.13, 0.26, 0.30]
assert entries[1].content == "test content2"
assert entries[1].id == "test_id_2"
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,6 @@ class AnimalElement(Element):
type: str
age: int

def get_key(self) -> str:
"""
Get the key of the element which will be used to generate the vector.
Returns:
The key.
"""
return self.name

def get_text_representation(self) -> str:
"""
Get the text representation of the element.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,6 @@ async def insert_elements(self, elements: list[Element]) -> None:
Args:
elements: The list of Elements to insert.
"""
vectors = await self.embedder.embed_text([element.get_key() for element in elements])
vectors = await self.embedder.embed_text([element.get_text_for_embedding() for element in elements])
entries = [element.to_vector_db_entry(vector) for element, vector in zip(elements, vectors, strict=False)]
await self.vector_store.store(entries)
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import uuid
from abc import ABC, abstractmethod
from typing import Any, ClassVar

from pydantic import BaseModel
from pydantic import BaseModel, computed_field

from ragbits.core.vector_stores.base import VectorStoreEntry
from ragbits.document_search.documents.document import DocumentMeta
Expand All @@ -27,12 +28,32 @@ class Element(BaseModel, ABC):

_elements_registry: ClassVar[dict[str, type["Element"]]] = {}

def get_key(self) -> str:
@computed_field # type: ignore[prop-decorator]
@property
def id(self) -> str:
"""
Get the key of the element which will be used to generate the vector.
Get the ID of the element. The id is primarly used as a key in the vector store.
The current representation is a UUID5 hash of various element metadata, including
its contents and location where it was sourced from.
Returns:
The key.
The ID in the form of a UUID5 hash.
"""
id_components = [
self.document_meta.id,
self.get_text_for_embedding(),
self.get_text_representation(),
str(self.location),
]

return str(uuid.uuid5(uuid.NAMESPACE_OID, ";".join(id_components)))

def get_text_for_embedding(self) -> str:
"""
Get the text representation of the element for embedding.
Returns:
The text representation for embedding.
"""
return self.get_text_representation()

Expand Down Expand Up @@ -82,8 +103,9 @@ def to_vector_db_entry(self, vector: list[float]) -> VectorStoreEntry:
The vector database entry
"""
return VectorStoreEntry(
key=self.get_key(),
id=self.id,
vector=vector,
content=self.get_text_for_embedding(),
metadata=self.model_dump(),
)

Expand Down
9 changes: 4 additions & 5 deletions packages/ragbits-document-search/tests/unit/test_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,13 @@ class MyElement(Element):
element_type: str = "custom_element"
foo: str

def get_key(self) -> str:
return self.foo + self.foo

def get_text_representation(self) -> str:
return self.foo + self.foo

element = Element.from_vector_db_entry(
db_entry=VectorStoreEntry(
key="key",
id="test id",
content="test content",
vector=[0.1, 0.2],
metadata={
"element_type": "custom_element",
Expand All @@ -31,6 +29,7 @@ def get_text_representation(self) -> str:

assert isinstance(element, MyElement)
assert element.foo == "bar"
assert element.get_key() == "barbar"
assert element.get_text_for_embedding() == "barbar"
assert element.get_text_representation() == "barbar"
assert element.document_meta.document_type == DocumentType.TXT
assert element.document_meta.source.source_type == "local_file_source"

0 comments on commit a2e9abc

Please sign in to comment.