Skip to content

Commit

Permalink
feat(document-search): Option to choose between image and text embedd…
Browse files Browse the repository at this point in the history
…ings for ImageElements (#205)

Co-authored-by: Ludwik Trammer <[email protected]>
  • Loading branch information
kdziedzic68 and ludwiktrammer authored Nov 29, 2024
1 parent 0c4ef7b commit 250ea70
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 38 deletions.
4 changes: 3 additions & 1 deletion examples/apps/documents_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ async def _handle_message(
if not self._documents_ingested:
yield self.NO_DOCUMENTS_INGESTED_MESSAGE
results = await self.document_search.search(message[-1])
prompt = RAGPrompt(QueryWithContext(query=message, context=[i.text_representation for i in results]))
prompt = RAGPrompt(
QueryWithContext(query=message, context=[i.text_representation for i in results if i.text_representation])
)
response = await self._llm.generate(prompt)
yield response.answer

Expand Down
4 changes: 2 additions & 2 deletions packages/ragbits-core/src/ragbits/core/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from ragbits.core.utils.config_handling import get_cls_from_config

from .base import Embeddings
from .base import Embeddings, EmbeddingType
from .noop import NoopEmbeddings

__all__ = ["Embeddings", "NoopEmbeddings"]
__all__ = ["EmbeddingType", "Embeddings", "NoopEmbeddings"]

module = sys.modules[__name__]

Expand Down
16 changes: 16 additions & 0 deletions packages/ragbits-core/src/ragbits/core/embeddings/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,20 @@
from abc import ABC, abstractmethod
from enum import Enum


class EmbeddingType(Enum):
"""
Indicates the type of embedding in regard to what kind of data has been embedded.
Used to specify the embedding type for a given element. Unlike `Element` type,
which categorizes the element itself, `EmbeddingType` determines how the
element's data is represented. For example, an image element can support
multiple embedding types, such as a description, OCR output, or raw bytes,
allowing for the creation of different embeddings for the same element.
"""

TEXT: str = "text"
IMAGE: str = "image"


class Embeddings(ABC):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from pydantic import computed_field

from ragbits.core.embeddings import EmbeddingType
from ragbits.core.vector_stores.base import VectorStoreOptions
from ragbits.core.vector_stores.in_memory import InMemoryVectorStore
from ragbits.document_search.documents.document import DocumentMeta, DocumentType
Expand Down Expand Up @@ -48,7 +49,9 @@ async def store_fixture() -> InMemoryVectorStore:
),
]

entries = [element[0].to_vector_db_entry(vector=element[1]) for element in elements]
entries = [
element[0].to_vector_db_entry(vector=element[1], embedding_type=EmbeddingType.TEXT) for element in elements
]

store = InMemoryVectorStore()
await store.store(entries)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import BaseModel, Field

from ragbits.core.audit import traceable
from ragbits.core.embeddings import Embeddings, get_embeddings
from ragbits.core.embeddings import Embeddings, EmbeddingType, get_embeddings
from ragbits.core.vector_stores import VectorStore, get_vector_store
from ragbits.core.vector_stores.base import VectorStoreOptions
from ragbits.document_search.documents.document import Document, DocumentMeta
Expand Down Expand Up @@ -150,23 +150,31 @@ async def insert_elements(self, elements: list[Element]) -> None:
Args:
elements: The list of Elements to insert.
"""
vectors = await self.embedder.embed_text([element.key for element in elements])
elements_with_text = [element for element in elements if element.key]
images_with_text = [element for element in elements_with_text if isinstance(element, ImageElement)]
vectors = await self.embedder.embed_text([element.key for element in elements_with_text]) # type: ignore

image_elements = [element for element in elements if isinstance(element, ImageElement)]
entries = [element.to_vector_db_entry(vector) for element, vector in zip(elements, vectors, strict=False)]

entries = [
element.to_vector_db_entry(vector, EmbeddingType.TEXT)
for element, vector in zip(elements_with_text, vectors, strict=False)
]
not_embedded_image_elements = [
image_element for image_element in image_elements if image_element not in images_with_text
]

if image_elements and self.embedder.image_support():
image_vectors = await self.embedder.embed_image([element.image_bytes for element in image_elements])
entries.extend(
[
element.to_vector_db_entry(vector)
element.to_vector_db_entry(vector, EmbeddingType.IMAGE)
for element, vector in zip(image_elements, image_vectors, strict=False)
]
)
elif image_elements:
warnings.warn(
f"Image elements are not supported by the embedder {self.embedder}. "
f"Skipping {len(image_elements)} image elements."
)
not_embedded_image_elements = []

for image_element in not_embedded_image_elements:
warnings.warn(f"Image: {image_element.id} could not be embedded")

await self.vector_store.store(entries)
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import hashlib
import uuid
from abc import ABC, abstractmethod
from typing import Any, ClassVar

from pydantic import BaseModel, computed_field

from ragbits.core.embeddings import EmbeddingType
from ragbits.core.vector_stores.base import VectorStoreEntry
from ragbits.document_search.documents.document import DocumentMeta

Expand All @@ -28,29 +30,37 @@ class Element(BaseModel, ABC):

_elements_registry: ClassVar[dict[str, type["Element"]]] = {}

@computed_field # type: ignore[prop-decorator]
# type: ignore[prop-decorator]
@property
def id(self) -> str:
"""
Get the ID of the element. The id is primarly used as a key in the vector store.
The current representation is a UUID5 hash of various element metadata, including
its contents and location where it was sourced from.
Retrieve the ID of the element, primarily used to represent the element's data.
Returns:
The ID in the form of a UUID5 hash.
str: string representing element
"""
id_components = [
self.document_meta.id,
self.element_type,
self.key,
self.text_representation,
str(self.location),
]
return str(uuid.uuid5(uuid.NAMESPACE_OID, ";".join(id_components)))
id_components = self.get_id_components()
return "&".join(f"{k}={v}" for k, v in id_components.items())

def get_id_components(self) -> dict[str, str]:
"""
Creates a dictionary of key value pairs of id components
Returns:
dict: a dictionary
"""
id_components = {
"meta": self.document_meta.id,
"type": self.element_type,
"key": str(self.key),
"text": str(self.text_representation),
"location": str(self.location),
}
return id_components

@computed_field # type: ignore[prop-decorator]
@property
def key(self) -> str:
def key(self) -> str | None:
"""
Get the representation of the element for embedding.
Expand All @@ -62,7 +72,7 @@ def key(self) -> str:
@computed_field # type: ignore[prop-decorator]
@property
@abstractmethod
def text_representation(self) -> str:
def text_representation(self) -> str | None:
"""
Get the text representation of the element.
Expand Down Expand Up @@ -90,24 +100,28 @@ def from_vector_db_entry(cls, db_entry: VectorStoreEntry) -> "Element":
"""
element_type = db_entry.metadata["element_type"]
element_cls = Element._elements_registry[element_type]
if "embedding_type" in db_entry.metadata:
del db_entry.metadata["embedding_type"]
return element_cls(**db_entry.metadata)

def to_vector_db_entry(self, vector: list[float]) -> VectorStoreEntry:
def to_vector_db_entry(self, vector: list[float], embedding_type: EmbeddingType) -> VectorStoreEntry:
"""
Create a vector database entry from the element.
Args:
vector: The vector.
embedding_type: EmbeddingTypes
Returns:
The vector database entry
"""
return VectorStoreEntry(
id=self.id,
key=self.key,
vector=vector,
metadata=self.model_dump(exclude={"id", "key"}),
)
id_components = [
self.id,
str(embedding_type),
]
vector_store_entry_id = str(uuid.uuid5(uuid.NAMESPACE_OID, ";".join(id_components)))
metadata = self.model_dump(exclude={"id", "key"})
metadata["embedding_type"] = str(embedding_type)
return VectorStoreEntry(id=vector_store_entry_id, key=str(self.key), vector=vector, metadata=metadata)


class TextElement(Element):
Expand Down Expand Up @@ -142,11 +156,29 @@ class ImageElement(Element):

@computed_field # type: ignore[prop-decorator]
@property
def text_representation(self) -> str:
def text_representation(self) -> str | None:
"""
Get the text representation of the element.
Returns:
The text representation.
"""
return f"Description: {self.description}\nExtracted text: {self.ocr_extracted_text}"
if not self.description and not self.ocr_extracted_text:
return None
repr = ""
if self.description:
repr += f"Description: {self.description}\n"
if self.ocr_extracted_text:
repr += f"Extracted text: {self.ocr_extracted_text}"
return repr

def get_id_components(self) -> dict[str, str]:
"""
Creates a dictionary of key value pairs of id components
Returns:
dict: a dictionary
"""
id_components = super().get_id_components()
id_components["image_hash"] = hashlib.sha256(self.image_bytes).hexdigest()
return id_components

0 comments on commit 250ea70

Please sign in to comment.