From 250ea70f30ec64b733d61747314f0c51985db3fb Mon Sep 17 00:00:00 2001 From: kdziedzic68 Date: Fri, 29 Nov 2024 09:59:17 +0100 Subject: [PATCH] feat(document-search): Option to choose between image and text embeddings for ImageElements (#205) Co-authored-by: Ludwik Trammer --- examples/apps/documents_chat.py | 4 +- .../src/ragbits/core/embeddings/__init__.py | 4 +- .../src/ragbits/core/embeddings/base.py | 16 ++++ .../unit/vector_stores/test_in_memory.py | 5 +- .../src/ragbits/document_search/_main.py | 26 ++++-- .../document_search/documents/element.py | 82 +++++++++++++------ 6 files changed, 99 insertions(+), 38 deletions(-) diff --git a/examples/apps/documents_chat.py b/examples/apps/documents_chat.py index 8f7f24a4..576d6e37 100644 --- a/examples/apps/documents_chat.py +++ b/examples/apps/documents_chat.py @@ -125,7 +125,9 @@ async def _handle_message( if not self._documents_ingested: yield self.NO_DOCUMENTS_INGESTED_MESSAGE results = await self.document_search.search(message[-1]) - prompt = RAGPrompt(QueryWithContext(query=message, context=[i.text_representation for i in results])) + prompt = RAGPrompt( + QueryWithContext(query=message, context=[i.text_representation for i in results if i.text_representation]) + ) response = await self._llm.generate(prompt) yield response.answer diff --git a/packages/ragbits-core/src/ragbits/core/embeddings/__init__.py b/packages/ragbits-core/src/ragbits/core/embeddings/__init__.py index 15e47529..5ad6c218 100644 --- a/packages/ragbits-core/src/ragbits/core/embeddings/__init__.py +++ b/packages/ragbits-core/src/ragbits/core/embeddings/__init__.py @@ -2,10 +2,10 @@ from ragbits.core.utils.config_handling import get_cls_from_config -from .base import Embeddings +from .base import Embeddings, EmbeddingType from .noop import NoopEmbeddings -__all__ = ["Embeddings", "NoopEmbeddings"] +__all__ = ["EmbeddingType", "Embeddings", "NoopEmbeddings"] module = sys.modules[__name__] diff --git a/packages/ragbits-core/src/ragbits/core/embeddings/base.py b/packages/ragbits-core/src/ragbits/core/embeddings/base.py index 66c2716d..e03087b6 100644 --- a/packages/ragbits-core/src/ragbits/core/embeddings/base.py +++ b/packages/ragbits-core/src/ragbits/core/embeddings/base.py @@ -1,4 +1,20 @@ from abc import ABC, abstractmethod +from enum import Enum + + +class EmbeddingType(Enum): + """ + Indicates the type of embedding in regard to what kind of data has been embedded. + + Used to specify the embedding type for a given element. Unlike `Element` type, + which categorizes the element itself, `EmbeddingType` determines how the + element's data is represented. For example, an image element can support + multiple embedding types, such as a description, OCR output, or raw bytes, + allowing for the creation of different embeddings for the same element. + """ + + TEXT: str = "text" + IMAGE: str = "image" class Embeddings(ABC): diff --git a/packages/ragbits-core/tests/unit/vector_stores/test_in_memory.py b/packages/ragbits-core/tests/unit/vector_stores/test_in_memory.py index 2c167330..9896d6a9 100644 --- a/packages/ragbits-core/tests/unit/vector_stores/test_in_memory.py +++ b/packages/ragbits-core/tests/unit/vector_stores/test_in_memory.py @@ -3,6 +3,7 @@ import pytest from pydantic import computed_field +from ragbits.core.embeddings import EmbeddingType from ragbits.core.vector_stores.base import VectorStoreOptions from ragbits.core.vector_stores.in_memory import InMemoryVectorStore from ragbits.document_search.documents.document import DocumentMeta, DocumentType @@ -48,7 +49,9 @@ async def store_fixture() -> InMemoryVectorStore: ), ] - entries = [element[0].to_vector_db_entry(vector=element[1]) for element in elements] + entries = [ + element[0].to_vector_db_entry(vector=element[1], embedding_type=EmbeddingType.TEXT) for element in elements + ] store = InMemoryVectorStore() await store.store(entries) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/_main.py b/packages/ragbits-document-search/src/ragbits/document_search/_main.py index 706e1e86..db432eba 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/_main.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/_main.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field from ragbits.core.audit import traceable -from ragbits.core.embeddings import Embeddings, get_embeddings +from ragbits.core.embeddings import Embeddings, EmbeddingType, get_embeddings from ragbits.core.vector_stores import VectorStore, get_vector_store from ragbits.core.vector_stores.base import VectorStoreOptions from ragbits.document_search.documents.document import Document, DocumentMeta @@ -150,23 +150,31 @@ async def insert_elements(self, elements: list[Element]) -> None: Args: elements: The list of Elements to insert. """ - vectors = await self.embedder.embed_text([element.key for element in elements]) + elements_with_text = [element for element in elements if element.key] + images_with_text = [element for element in elements_with_text if isinstance(element, ImageElement)] + vectors = await self.embedder.embed_text([element.key for element in elements_with_text]) # type: ignore image_elements = [element for element in elements if isinstance(element, ImageElement)] - entries = [element.to_vector_db_entry(vector) for element, vector in zip(elements, vectors, strict=False)] + + entries = [ + element.to_vector_db_entry(vector, EmbeddingType.TEXT) + for element, vector in zip(elements_with_text, vectors, strict=False) + ] + not_embedded_image_elements = [ + image_element for image_element in image_elements if image_element not in images_with_text + ] if image_elements and self.embedder.image_support(): image_vectors = await self.embedder.embed_image([element.image_bytes for element in image_elements]) entries.extend( [ - element.to_vector_db_entry(vector) + element.to_vector_db_entry(vector, EmbeddingType.IMAGE) for element, vector in zip(image_elements, image_vectors, strict=False) ] ) - elif image_elements: - warnings.warn( - f"Image elements are not supported by the embedder {self.embedder}. " - f"Skipping {len(image_elements)} image elements." - ) + not_embedded_image_elements = [] + + for image_element in not_embedded_image_elements: + warnings.warn(f"Image: {image_element.id} could not be embedded") await self.vector_store.store(entries) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/documents/element.py b/packages/ragbits-document-search/src/ragbits/document_search/documents/element.py index 63bf82bd..d049d349 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/documents/element.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/documents/element.py @@ -1,9 +1,11 @@ +import hashlib import uuid from abc import ABC, abstractmethod from typing import Any, ClassVar from pydantic import BaseModel, computed_field +from ragbits.core.embeddings import EmbeddingType from ragbits.core.vector_stores.base import VectorStoreEntry from ragbits.document_search.documents.document import DocumentMeta @@ -28,29 +30,37 @@ class Element(BaseModel, ABC): _elements_registry: ClassVar[dict[str, type["Element"]]] = {} - @computed_field # type: ignore[prop-decorator] + # type: ignore[prop-decorator] @property def id(self) -> str: """ - Get the ID of the element. The id is primarly used as a key in the vector store. - The current representation is a UUID5 hash of various element metadata, including - its contents and location where it was sourced from. + Retrieve the ID of the element, primarily used to represent the element's data. Returns: - The ID in the form of a UUID5 hash. + str: string representing element """ - id_components = [ - self.document_meta.id, - self.element_type, - self.key, - self.text_representation, - str(self.location), - ] - return str(uuid.uuid5(uuid.NAMESPACE_OID, ";".join(id_components))) + id_components = self.get_id_components() + return "&".join(f"{k}={v}" for k, v in id_components.items()) + + def get_id_components(self) -> dict[str, str]: + """ + Creates a dictionary of key value pairs of id components + + Returns: + dict: a dictionary + """ + id_components = { + "meta": self.document_meta.id, + "type": self.element_type, + "key": str(self.key), + "text": str(self.text_representation), + "location": str(self.location), + } + return id_components @computed_field # type: ignore[prop-decorator] @property - def key(self) -> str: + def key(self) -> str | None: """ Get the representation of the element for embedding. @@ -62,7 +72,7 @@ def key(self) -> str: @computed_field # type: ignore[prop-decorator] @property @abstractmethod - def text_representation(self) -> str: + def text_representation(self) -> str | None: """ Get the text representation of the element. @@ -90,24 +100,28 @@ def from_vector_db_entry(cls, db_entry: VectorStoreEntry) -> "Element": """ element_type = db_entry.metadata["element_type"] element_cls = Element._elements_registry[element_type] + if "embedding_type" in db_entry.metadata: + del db_entry.metadata["embedding_type"] return element_cls(**db_entry.metadata) - def to_vector_db_entry(self, vector: list[float]) -> VectorStoreEntry: + def to_vector_db_entry(self, vector: list[float], embedding_type: EmbeddingType) -> VectorStoreEntry: """ Create a vector database entry from the element. Args: vector: The vector. - + embedding_type: EmbeddingTypes Returns: The vector database entry """ - return VectorStoreEntry( - id=self.id, - key=self.key, - vector=vector, - metadata=self.model_dump(exclude={"id", "key"}), - ) + id_components = [ + self.id, + str(embedding_type), + ] + vector_store_entry_id = str(uuid.uuid5(uuid.NAMESPACE_OID, ";".join(id_components))) + metadata = self.model_dump(exclude={"id", "key"}) + metadata["embedding_type"] = str(embedding_type) + return VectorStoreEntry(id=vector_store_entry_id, key=str(self.key), vector=vector, metadata=metadata) class TextElement(Element): @@ -142,11 +156,29 @@ class ImageElement(Element): @computed_field # type: ignore[prop-decorator] @property - def text_representation(self) -> str: + def text_representation(self) -> str | None: """ Get the text representation of the element. Returns: The text representation. """ - return f"Description: {self.description}\nExtracted text: {self.ocr_extracted_text}" + if not self.description and not self.ocr_extracted_text: + return None + repr = "" + if self.description: + repr += f"Description: {self.description}\n" + if self.ocr_extracted_text: + repr += f"Extracted text: {self.ocr_extracted_text}" + return repr + + def get_id_components(self) -> dict[str, str]: + """ + Creates a dictionary of key value pairs of id components + + Returns: + dict: a dictionary + """ + id_components = super().get_id_components() + id_components["image_hash"] = hashlib.sha256(self.image_bytes).hexdigest() + return id_components