Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add id field for Element #183

Merged
merged 1 commit into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/document-search/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ async def main() -> None:

print()
print(f"Documents similar to: {query}")
print([element.get_key() for element in results])
print([element.get_text_representation() for element in results])


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion examples/document-search/chroma_otel.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ async def main() -> None:

print()
print(f"Documents similar to: {query}")
print([element.get_key() for element in results])
print([element.get_text_representation() for element in results])


if __name__ == "__main__":
Expand Down
3 changes: 2 additions & 1 deletion packages/ragbits-core/src/ragbits/core/vector_stores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ class VectorStoreEntry(BaseModel):
An object representing a vector database entry.
"""

key: str
id: str
vector: list[float]
content: str
metadata: dict


Expand Down
18 changes: 9 additions & 9 deletions packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import json
from hashlib import sha256
from typing import Literal

import chromadb
Expand Down Expand Up @@ -84,9 +83,8 @@ async def store(self, entries: list[VectorStoreEntry]) -> None:
Args:
entries: The entries to store.
"""
# TODO: Think about better id components for hashing and move hash computing to VectorStoreEntry
ids = [sha256(entry.key.encode("utf-8")).hexdigest() for entry in entries]
documents = [entry.key for entry in entries]
ids = [entry.id for entry in entries]
documents = [entry.content for entry in entries]
embeddings = [entry.vector for entry in entries]
metadatas = [entry.metadata for entry in entries]
metadatas = (
Expand Down Expand Up @@ -132,12 +130,13 @@ async def retrieve(self, vector: list[float], options: VectorStoreOptions | None

return [
VectorStoreEntry(
key=document,
id=id,
content=document,
vector=list(embeddings),
metadata=metadata, # type: ignore
)
for batch in zip(metadatas, embeddings, distances, documents, strict=False)
for metadata, embeddings, distance, document in zip(*batch, strict=False)
for batch in zip(ids, metadatas, embeddings, distances, documents, strict=False)
for id, metadata, embeddings, distance, document in zip(*batch, strict=False)
if options.max_distance is None or distance <= options.max_distance
]

Expand Down Expand Up @@ -182,9 +181,10 @@ async def list(

return [
VectorStoreEntry(
key=document,
id=id,
content=document,
vector=list(embedding),
metadata=metadata, # type: ignore
)
for metadata, embedding, document in zip(metadatas, embeddings, documents, strict=False)
for id, metadata, embedding, document in zip(ids, metadatas, embeddings, documents, strict=False)
]
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def store(self, entries: list[VectorStoreEntry]) -> None:
entries: The entries to store.
"""
for entry in entries:
self._storage[entry.key] = entry
self._storage[entry.id] = entry

@traceable
async def retrieve(self, vector: list[float], options: VectorStoreOptions | None = None) -> list[VectorStoreEntry]:
Expand Down
17 changes: 12 additions & 5 deletions packages/ragbits-core/tests/unit/vector_stores/test_chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ async def test_get_chroma_collection(mock_chromadb_store: ChromaVectorStore) ->
async def test_store(mock_chromadb_store: ChromaVectorStore) -> None:
data = [
VectorStoreEntry(
key="test_key",
id="test_key",
content="test content",
vector=[0.1, 0.2, 0.3],
metadata={
"content": "test content",
Expand All @@ -39,15 +40,15 @@ async def test_store(mock_chromadb_store: ChromaVectorStore) -> None:

mock_chromadb_store._client.get_or_create_collection().add.assert_called_once() # type: ignore
mock_chromadb_store._client.get_or_create_collection().add.assert_called_with( # type: ignore
ids=["92488e1e3eeecdf99f3ed2ce59233efb4b4fb612d5655c0ce9ea52b5a502e655"],
ids=[data[0].id],
embeddings=[[0.1, 0.2, 0.3]],
metadatas=[
{
"__metadata": '{"content": "test content", "document": {"title": "test title", "source":'
' {"path": "/test/path"}, "document_type": "test_type"}}',
}
],
documents=["test_key"],
documents=["test content"],
)


Expand Down Expand Up @@ -85,7 +86,7 @@ async def test_retrieve(
],
"embeddings": [[[0.12, 0.25, 0.29], [0.13, 0.26, 0.30]]],
"distances": [[0.1, 0.2]],
"documents": [["test_key_1", "test_key_2"]],
"documents": [["test content 1", "test content 2"]],
"ids": [["test_id_1", "test_id_2"]],
}

Expand All @@ -96,6 +97,8 @@ async def test_retrieve(
assert entry.metadata["content"] == result["content"]
assert entry.metadata["document"]["title"] == result["title"]
assert entry.vector == result["vector"]
assert entry.id == f"test_id_{results.index(result) + 1}"
assert entry.content == result["content"]


async def test_list(mock_chromadb_store: ChromaVectorStore) -> None:
Expand All @@ -112,7 +115,7 @@ async def test_list(mock_chromadb_store: ChromaVectorStore) -> None:
},
],
"embeddings": [[0.12, 0.25, 0.29], [0.13, 0.26, 0.30]],
"documents": ["test_key", "test_key_2"],
"documents": ["test content 1", "test content2"],
"ids": ["test_id_1", "test_id_2"],
}

Expand All @@ -122,6 +125,10 @@ async def test_list(mock_chromadb_store: ChromaVectorStore) -> None:
assert entries[0].metadata["content"] == "test content"
assert entries[0].metadata["document"]["title"] == "test title"
assert entries[0].vector == [0.12, 0.25, 0.29]
assert entries[0].content == "test content 1"
assert entries[0].id == "test_id_1"
assert entries[1].metadata["content"] == "test content 2"
assert entries[1].metadata["document"]["title"] == "test title 2"
assert entries[1].vector == [0.13, 0.26, 0.30]
assert entries[1].content == "test content2"
assert entries[1].id == "test_id_2"
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,6 @@ class AnimalElement(Element):
type: str
age: int

def get_key(self) -> str:
"""
Get the key of the element which will be used to generate the vector.

Returns:
The key.
"""
return self.name

def get_text_representation(self) -> str:
"""
Get the text representation of the element.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,6 @@ async def insert_elements(self, elements: list[Element]) -> None:
Args:
elements: The list of Elements to insert.
"""
vectors = await self.embedder.embed_text([element.get_key() for element in elements])
vectors = await self.embedder.embed_text([element.get_text_for_embedding() for element in elements])
entries = [element.to_vector_db_entry(vector) for element, vector in zip(elements, vectors, strict=False)]
await self.vector_store.store(entries)
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import uuid
from abc import ABC, abstractmethod
from typing import Any, ClassVar

from pydantic import BaseModel
from pydantic import BaseModel, computed_field

from ragbits.core.vector_stores.base import VectorStoreEntry
from ragbits.document_search.documents.document import DocumentMeta
Expand All @@ -27,12 +28,32 @@ class Element(BaseModel, ABC):

_elements_registry: ClassVar[dict[str, type["Element"]]] = {}

def get_key(self) -> str:
@computed_field # type: ignore[prop-decorator]
@property
def id(self) -> str:
"""
Get the key of the element which will be used to generate the vector.
Get the ID of the element. The id is primarly used as a key in the vector store.
The current representation is a UUID5 hash of various element metadata, including
its contents and location where it was sourced from.

Returns:
The key.
The ID in the form of a UUID5 hash.
"""
id_components = [
self.document_meta.id,
self.get_text_for_embedding(),
self.get_text_representation(),
str(self.location),
]

return str(uuid.uuid5(uuid.NAMESPACE_OID, ";".join(id_components)))

def get_text_for_embedding(self) -> str:
"""
Get the text representation of the element for embedding.

Returns:
The text representation for embedding.
"""
return self.get_text_representation()

Expand Down Expand Up @@ -82,8 +103,9 @@ def to_vector_db_entry(self, vector: list[float]) -> VectorStoreEntry:
The vector database entry
"""
return VectorStoreEntry(
key=self.get_key(),
id=self.id,
vector=vector,
content=self.get_text_for_embedding(),
metadata=self.model_dump(),
)

Expand Down
9 changes: 4 additions & 5 deletions packages/ragbits-document-search/tests/unit/test_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,13 @@ class MyElement(Element):
element_type: str = "custom_element"
foo: str

def get_key(self) -> str:
return self.foo + self.foo

def get_text_representation(self) -> str:
return self.foo + self.foo

element = Element.from_vector_db_entry(
db_entry=VectorStoreEntry(
key="key",
id="test id",
content="test content",
vector=[0.1, 0.2],
metadata={
"element_type": "custom_element",
Expand All @@ -31,6 +29,7 @@ def get_text_representation(self) -> str:

assert isinstance(element, MyElement)
assert element.foo == "bar"
assert element.get_key() == "barbar"
assert element.get_text_for_embedding() == "barbar"
assert element.get_text_representation() == "barbar"
assert element.document_meta.document_type == DocumentType.TXT
assert element.document_meta.source.source_type == "local_file_source"
Loading