From 72a3bb18a68c8f4a96896ddef3dd6653c07d0d96 Mon Sep 17 00:00:00 2001
From: "Richard Kuo (Danswer)"
Date: Mon, 30 Sep 2024 10:26:33 -0700
Subject: [PATCH] add indexing heartbeat

---
 .../background/indexing/run_indexing.py       | 16 +++-
 backend/danswer/indexing/chunker.py           | 15 +++-
 backend/danswer/indexing/embedder.py          | 42 +++------
 .../danswer/indexing/indexing_heartbeat.py    | 41 +++++++++
 backend/danswer/indexing/indexing_pipeline.py | 20 ++---
 .../search_nlp_models.py                      |  6 ++
 .../tests/unit/danswer/indexing/conftest.py   | 18 ++++
 .../unit/danswer/indexing/test_chunker.py     | 50 +++++++++--
 .../unit/danswer/indexing/test_embedder.py    | 90 +++++++++++++++++++
 .../unit/danswer/indexing/test_heartbeat.py   | 80 +++++++++++++++++
 10 files changed, 324 insertions(+), 54 deletions(-)
 create mode 100644 backend/danswer/indexing/indexing_heartbeat.py
 create mode 100644 backend/tests/unit/danswer/indexing/conftest.py
 create mode 100644 backend/tests/unit/danswer/indexing/test_embedder.py
 create mode 100644 backend/tests/unit/danswer/indexing/test_heartbeat.py

diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py
index a29ddd76c2b..499899ac225 100644
--- a/backend/danswer/background/indexing/run_indexing.py
+++ b/backend/danswer/background/indexing/run_indexing.py
@@ -29,6 +29,7 @@
 from danswer.db.models import IndexModelStatus
 from danswer.document_index.factory import get_default_document_index
 from danswer.indexing.embedder import DefaultIndexingEmbedder
+from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
 from danswer.indexing.indexing_pipeline import build_indexing_pipeline
 from danswer.utils.logger import IndexAttemptSingleton
 from danswer.utils.logger import setup_logger
@@ -103,15 +104,24 @@ def _run_indexing(
     )
 
     embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
-        search_settings=search_settings
+        search_settings=search_settings,
+        heartbeat=IndexingHeartbeat(
+            index_attempt_id=index_attempt.id,
+            db_session=db_session,
+            # let the world know we're still making progress after
+            # every 10 batches
+            freq=10,
+        ),
     )
 
     indexing_pipeline = build_indexing_pipeline(
         attempt_id=index_attempt.id,
         embedder=embedding_model,
         document_index=document_index,
-        ignore_time_skip=index_attempt.from_beginning
-        or (search_settings.status == IndexModelStatus.FUTURE),
+        ignore_time_skip=(
+            index_attempt.from_beginning
+            or (search_settings.status == IndexModelStatus.FUTURE)
+        ),
         db_session=db_session,
     )
 
diff --git a/backend/danswer/indexing/chunker.py b/backend/danswer/indexing/chunker.py
index 03a03f30f49..a25cfc3d32b 100644
--- a/backend/danswer/indexing/chunker.py
+++ b/backend/danswer/indexing/chunker.py
@@ -10,6 +10,7 @@
     get_metadata_keys_to_ignore,
 )
 from danswer.connectors.models import Document
+from danswer.indexing.indexing_heartbeat import Heartbeat
 from danswer.indexing.models import DocAwareChunk
 from danswer.natural_language_processing.utils import BaseTokenizer
 from danswer.utils.logger import setup_logger
@@ -123,6 +124,7 @@ def __init__(
         chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE,
         chunk_overlap: int = CHUNK_OVERLAP,
         mini_chunk_size: int = MINI_CHUNK_SIZE,
+        heartbeat: Heartbeat | None = None,
     ) -> None:
         from llama_index.text_splitter import SentenceSplitter
 
@@ -131,6 +133,7 @@ def __init__(
         self.enable_multipass = enable_multipass
         self.enable_large_chunks = enable_large_chunks
         self.tokenizer = tokenizer
+        self.heartbeat = heartbeat
 
         self.blurb_splitter = SentenceSplitter(
             tokenizer=tokenizer.tokenize,
@@ -255,7 +258,7 @@ def _create_chunk(
         # If the chunk does not have any useable content, it will not be indexed
         return chunks
 
-    def chunk(self, document: Document) -> list[DocAwareChunk]:
+    def _handle_single_document(self, document: Document) -> list[DocAwareChunk]:
         # Specifically for reproducing an issue with gmail
         if document.source == DocumentSource.GMAIL:
             logger.debug(f"Chunking {document.semantic_identifier}")
@@ -302,3 +305,13 @@ def chunk(self, document: Document) -> list[DocAwareChunk]:
             normal_chunks.extend(large_chunks)
 
         return normal_chunks
+
+    def chunk(self, documents: list[Document]) -> list[DocAwareChunk]:
+        final_chunks: list[DocAwareChunk] = []
+        for document in documents:
+            final_chunks.extend(self._handle_single_document(document))
+
+            if self.heartbeat:
+                self.heartbeat.heartbeat()
+
+        return final_chunks
diff --git a/backend/danswer/indexing/embedder.py b/backend/danswer/indexing/embedder.py
index d25a0659c62..259bebd3fd9 100644
--- a/backend/danswer/indexing/embedder.py
+++ b/backend/danswer/indexing/embedder.py
@@ -1,12 +1,8 @@
 from abc import ABC
 from abc import abstractmethod
 
-from sqlalchemy.orm import Session
-
-from danswer.db.models import IndexModelStatus
 from danswer.db.models import SearchSettings
-from danswer.db.search_settings import get_current_search_settings
-from danswer.db.search_settings import get_secondary_search_settings
+from danswer.indexing.indexing_heartbeat import Heartbeat
 from danswer.indexing.models import ChunkEmbedding
 from danswer.indexing.models import DocAwareChunk
 from danswer.indexing.models import IndexChunk
@@ -24,6 +20,9 @@
 
 
 class IndexingEmbedder(ABC):
+    """Converts chunks into chunks with embeddings. Note that one chunk may have
+    multiple embeddings associated with it."""
+
     def __init__(
         self,
         model_name: str,
@@ -33,6 +32,7 @@ def __init__(
         provider_type: EmbeddingProvider | None,
         api_key: str | None,
         api_url: str | None,
+        heartbeat: Heartbeat | None,
     ):
         self.model_name = model_name
         self.normalize = normalize
@@ -54,6 +54,7 @@ def __init__(
             server_host=INDEXING_MODEL_SERVER_HOST,
             server_port=INDEXING_MODEL_SERVER_PORT,
             retrim_content=True,
+            heartbeat=heartbeat,
         )
 
     @abstractmethod
@@ -74,6 +75,7 @@ def __init__(
         provider_type: EmbeddingProvider | None = None,
         api_key: str | None = None,
         api_url: str | None = None,
+        heartbeat: Heartbeat | None = None,
     ):
         super().__init__(
             model_name,
@@ -83,6 +85,7 @@ def __init__(
             provider_type,
             api_key,
             api_url,
+            heartbeat,
         )
 
     @log_function_time()
@@ -166,7 +169,7 @@ def embed_chunks(
                     title_embed_dict[title] = title_embedding
 
             new_embedded_chunk = IndexChunk(
-                **chunk.dict(),
+                **chunk.model_dump(),
                 embeddings=ChunkEmbedding(
                     full_embedding=chunk_embeddings[0],
                     mini_chunk_embeddings=chunk_embeddings[1:],
@@ -180,7 +183,7 @@ def embed_chunks(
 
     @classmethod
     def from_db_search_settings(
-        cls, search_settings: SearchSettings
+        cls, search_settings: SearchSettings, heartbeat: Heartbeat | None = None
     ) -> "DefaultIndexingEmbedder":
         return cls(
             model_name=search_settings.model_name,
@@ -190,28 +193,5 @@ def from_db_search_settings(
             provider_type=search_settings.provider_type,
             api_key=search_settings.api_key,
             api_url=search_settings.api_url,
+            heartbeat=heartbeat,
         )
-
-
-def get_embedding_model_from_search_settings(
-    db_session: Session, index_model_status: IndexModelStatus = IndexModelStatus.PRESENT
-) -> IndexingEmbedder:
-    search_settings: SearchSettings | None
-    if index_model_status == IndexModelStatus.PRESENT:
-        search_settings = get_current_search_settings(db_session)
-    elif index_model_status == IndexModelStatus.FUTURE:
-        search_settings = get_secondary_search_settings(db_session)
-        if not search_settings:
-            raise RuntimeError("No secondary index configured")
-    else:
-        raise RuntimeError("Not supporting embedding model rollbacks")
-
-    return DefaultIndexingEmbedder(
-        model_name=search_settings.model_name,
-        normalize=search_settings.normalize,
-        query_prefix=search_settings.query_prefix,
-        passage_prefix=search_settings.passage_prefix,
-        provider_type=search_settings.provider_type,
-        api_key=search_settings.api_key,
-        api_url=search_settings.api_url,
-    )
diff --git a/backend/danswer/indexing/indexing_heartbeat.py b/backend/danswer/indexing/indexing_heartbeat.py
new file mode 100644
index 00000000000..c500a0ad559
--- /dev/null
+++ b/backend/danswer/indexing/indexing_heartbeat.py
@@ -0,0 +1,41 @@
+import abc
+from typing import Any
+
+from sqlalchemy import func
+from sqlalchemy.orm import Session
+
+from danswer.db.index_attempt import get_index_attempt
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()
+
+
+class Heartbeat(abc.ABC):
+    """Useful for any long-running work that goes through a bunch of items
+    and needs to occasionally give updates on progress.
+    e.g. chunking, embedding, updating vespa, etc."""
+
+    @abc.abstractmethod
+    def heartbeat(self, metadata: Any = None) -> None:
+        raise NotImplementedError
+
+
+class IndexingHeartbeat(Heartbeat):
+    def __init__(self, index_attempt_id: int, db_session: Session, freq: int):
+        self.cnt = 0
+
+        self.index_attempt_id = index_attempt_id
+        self.db_session = db_session
+        self.freq = freq
+
+    def heartbeat(self, metadata: Any = None) -> None:
+        self.cnt += 1
+        if self.cnt % self.freq == 0:
+            index_attempt = get_index_attempt(
+                db_session=self.db_session, index_attempt_id=self.index_attempt_id
+            )
+            if index_attempt:
+                index_attempt.time_updated = func.now()
+                self.db_session.commit()
+            else:
+                logger.error("Index attempt not found, this should not happen!")
diff --git a/backend/danswer/indexing/indexing_pipeline.py b/backend/danswer/indexing/indexing_pipeline.py
index 5d7412ea9d7..f58ab8a69aa 100644
--- a/backend/danswer/indexing/indexing_pipeline.py
+++ b/backend/danswer/indexing/indexing_pipeline.py
@@ -31,6 +31,7 @@
 from danswer.document_index.interfaces import DocumentMetadata
 from danswer.indexing.chunker import Chunker
 from danswer.indexing.embedder import IndexingEmbedder
+from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
 from danswer.indexing.models import DocAwareChunk
 from danswer.indexing.models import DocMetadataAwareIndexChunk
 from danswer.utils.logger import setup_logger
@@ -283,18 +284,10 @@ def index_doc_batch(
         return 0, 0
 
     logger.debug("Starting chunking")
-    chunks: list[DocAwareChunk] = []
-    for document in ctx.updatable_docs:
-        chunks.extend(chunker.chunk(document=document))
+    chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
 
     logger.debug("Starting embedding")
-    chunks_with_embeddings = (
-        embedder.embed_chunks(
-            chunks=chunks,
-        )
-        if chunks
-        else []
-    )
+    chunks_with_embeddings = embedder.embed_chunks(chunks) if chunks else []
 
     updatable_ids = [doc.id for doc in ctx.updatable_docs]
 
@@ -406,6 +399,13 @@ def build_indexing_pipeline(
         tokenizer=embedder.embedding_model.tokenizer,
         enable_multipass=multipass,
         enable_large_chunks=enable_large_chunks,
+        # after every doc, update status in case there are a bunch of
+        # really long docs
+        heartbeat=IndexingHeartbeat(
+            index_attempt_id=attempt_id, db_session=db_session, freq=1
+        )
+        if attempt_id
+        else None,
     )
 
     return partial(
diff --git a/backend/danswer/natural_language_processing/search_nlp_models.py b/backend/danswer/natural_language_processing/search_nlp_models.py
index 6dcec724345..2fbf94a5be2 100644
--- a/backend/danswer/natural_language_processing/search_nlp_models.py
+++ b/backend/danswer/natural_language_processing/search_nlp_models.py
@@ -16,6 +16,7 @@
 )
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.db.models import SearchSettings
+from danswer.indexing.indexing_heartbeat import Heartbeat
 from danswer.natural_language_processing.utils import get_tokenizer
 from danswer.natural_language_processing.utils import tokenizer_trim_content
 from danswer.utils.logger import setup_logger
@@ -95,6 +96,7 @@ def __init__(
         api_url: str | None,
         provider_type: EmbeddingProvider | None,
         retrim_content: bool = False,
+        heartbeat: Heartbeat | None = None,
     ) -> None:
         self.api_key = api_key
         self.provider_type = provider_type
@@ -107,6 +109,7 @@ def __init__(
         self.tokenizer = get_tokenizer(
             model_name=model_name, provider_type=provider_type
         )
+        self.heartbeat = heartbeat
 
         model_server_url = build_model_server_url(server_host, server_port)
         self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
@@ -166,6 +169,9 @@ def _batch_encode_texts(
 
             response = self._make_model_server_request(embed_request)
             embeddings.extend(response.embeddings)
+
+            if self.heartbeat:
+                self.heartbeat.heartbeat()
         return embeddings
 
     def encode(
diff --git a/backend/tests/unit/danswer/indexing/conftest.py b/backend/tests/unit/danswer/indexing/conftest.py
new file mode 100644
index 00000000000..36e5659143f
--- /dev/null
+++ b/backend/tests/unit/danswer/indexing/conftest.py
@@ -0,0 +1,18 @@
+from typing import Any
+
+import pytest
+
+from danswer.indexing.indexing_heartbeat import Heartbeat
+
+
+class MockHeartbeat(Heartbeat):
+    def __init__(self) -> None:
+        self.call_count = 0
+
+    def heartbeat(self, metadata: Any = None) -> None:
+        self.call_count += 1
+
+
+@pytest.fixture
+def mock_heartbeat() -> MockHeartbeat:
+    return MockHeartbeat()
diff --git a/backend/tests/unit/danswer/indexing/test_chunker.py b/backend/tests/unit/danswer/indexing/test_chunker.py
index f3a72fe17a3..71c3bbd886f 100644
--- a/backend/tests/unit/danswer/indexing/test_chunker.py
+++ b/backend/tests/unit/danswer/indexing/test_chunker.py
@@ -1,11 +1,24 @@
+import pytest
+
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.indexing.chunker import Chunker
 from danswer.indexing.embedder import DefaultIndexingEmbedder
+from tests.unit.danswer.indexing.conftest import MockHeartbeat
+
+
+@pytest.fixture
+def embedder() -> DefaultIndexingEmbedder:
+    return DefaultIndexingEmbedder(
+        model_name="intfloat/e5-base-v2",
+        normalize=True,
+        query_prefix=None,
+        passage_prefix=None,
+    )
 
 
-def test_chunk_document() -> None:
+def test_chunk_document(embedder: DefaultIndexingEmbedder) -> None:
     short_section_1 = "This is a short section."
     long_section = (
         "This is a long section that should be split into multiple chunks. " * 100
@@ -30,18 +43,11 @@
         ],
     )
 
-    embedder = DefaultIndexingEmbedder(
-        model_name="intfloat/e5-base-v2",
-        normalize=True,
-        query_prefix=None,
-        passage_prefix=None,
-    )
-
     chunker = Chunker(
         tokenizer=embedder.embedding_model.tokenizer,
         enable_multipass=False,
     )
-    chunks = chunker.chunk(document)
+    chunks = chunker.chunk([document])
 
     assert len(chunks) == 5
     assert short_section_1 in chunks[0].content
@@ -49,3 +55,29 @@
     assert short_section_4 in chunks[-1].content
     assert "tag1" in chunks[0].metadata_suffix_keyword
     assert "tag2" in chunks[0].metadata_suffix_semantic
+
+
+def test_chunker_heartbeat(
+    embedder: DefaultIndexingEmbedder, mock_heartbeat: MockHeartbeat
+) -> None:
+    document = Document(
+        id="test_doc",
+        source=DocumentSource.WEB,
+        semantic_identifier="Test Document",
+        metadata={"tags": ["tag1", "tag2"]},
+        doc_updated_at=None,
+        sections=[
+            Section(text="This is a short section.", link="link1"),
+        ],
+    )
+
+    chunker = Chunker(
+        tokenizer=embedder.embedding_model.tokenizer,
+        enable_multipass=False,
+        heartbeat=mock_heartbeat,
+    )
+
+    chunks = chunker.chunk([document])
+
+    assert mock_heartbeat.call_count == 1
+    assert len(chunks) > 0
diff --git a/backend/tests/unit/danswer/indexing/test_embedder.py b/backend/tests/unit/danswer/indexing/test_embedder.py
new file mode 100644
index 00000000000..6611c2e5985
--- /dev/null
+++ b/backend/tests/unit/danswer/indexing/test_embedder.py
@@ -0,0 +1,90 @@
+from collections.abc import Generator
+from unittest.mock import Mock
+from unittest.mock import patch
+
+import pytest
+
+from danswer.configs.constants import DocumentSource
+from danswer.connectors.models import Document
+from danswer.connectors.models import Section
+from danswer.indexing.embedder import DefaultIndexingEmbedder
+from danswer.indexing.models import ChunkEmbedding
+from danswer.indexing.models import DocAwareChunk
+from danswer.indexing.models import IndexChunk
+from shared_configs.enums import EmbeddingProvider
+from shared_configs.enums import EmbedTextType
+
+
+@pytest.fixture
+def mock_embedding_model() -> Generator[Mock, None, None]:
+    with patch("danswer.indexing.embedder.EmbeddingModel") as mock:
+        yield mock
+
+
+def test_default_indexing_embedder_embed_chunks(mock_embedding_model: Mock) -> None:
+    # Setup
+    embedder = DefaultIndexingEmbedder(
+        model_name="test-model",
+        normalize=True,
+        query_prefix=None,
+        passage_prefix=None,
+        provider_type=EmbeddingProvider.OPENAI,
+    )
+
+    # Mock the encode method of the embedding model
+    mock_embedding_model.return_value.encode.side_effect = [
+        [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],  # Main chunk embeddings
+        [[7.0, 8.0, 9.0]],  # Title embedding
+    ]
+
+    # Create test input
+    source_doc = Document(
+        id="test_doc",
+        source=DocumentSource.WEB,
+        semantic_identifier="Test Document",
+        metadata={"tags": ["tag1", "tag2"]},
+        doc_updated_at=None,
+        sections=[
+            Section(text="This is a short section.", link="link1"),
+        ],
+    )
+    chunks: list[DocAwareChunk] = [
+        DocAwareChunk(
+            chunk_id=1,
+            blurb="This is a short section.",
+            content="Test chunk",
+            source_links={0: "link1"},
+            section_continuation=False,
+            source_document=source_doc,
+            title_prefix="Title: ",
+            metadata_suffix_semantic="",
+            metadata_suffix_keyword="",
+            mini_chunk_texts=None,
+            large_chunk_reference_ids=[],
+        )
+    ]
+
+    # Execute
+    result: list[IndexChunk] = embedder.embed_chunks(chunks)
+
+    # Assert
+    assert len(result) == 1
+    assert isinstance(result[0], IndexChunk)
+    assert result[0].content == "Test chunk"
+    assert result[0].embeddings == ChunkEmbedding(
+        full_embedding=[1.0, 2.0, 3.0],
+        mini_chunk_embeddings=[],
+    )
+    assert result[0].title_embedding == [7.0, 8.0, 9.0]
+
+    # Verify the embedding model was called correctly
+    mock_embedding_model.return_value.encode.assert_any_call(
+        texts=["Title: Test chunk"],
+        text_type=EmbedTextType.PASSAGE,
+        large_chunks_present=False,
+    )
+    # title only embedding call
+    mock_embedding_model.return_value.encode.assert_any_call(
+        ["Test Document"],
+        text_type=EmbedTextType.PASSAGE,
+    )
diff --git a/backend/tests/unit/danswer/indexing/test_heartbeat.py b/backend/tests/unit/danswer/indexing/test_heartbeat.py
new file mode 100644
index 00000000000..a59fac81283
--- /dev/null
+++ b/backend/tests/unit/danswer/indexing/test_heartbeat.py
@@ -0,0 +1,80 @@
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+import pytest
+from sqlalchemy.orm import Session
+
+from danswer.db.index_attempt import IndexAttempt
+from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
+
+
+@pytest.fixture
+def mock_db_session() -> MagicMock:
+    return MagicMock(spec=Session)
+
+
+@pytest.fixture
+def mock_index_attempt() -> MagicMock:
+    return MagicMock(spec=IndexAttempt)
+
+
+def test_indexing_heartbeat(
+    mock_db_session: MagicMock, mock_index_attempt: MagicMock
+) -> None:
+    with patch(
+        "danswer.indexing.indexing_heartbeat.get_index_attempt"
+    ) as mock_get_index_attempt:
+        mock_get_index_attempt.return_value = mock_index_attempt
+
+        heartbeat = IndexingHeartbeat(
+            index_attempt_id=1, db_session=mock_db_session, freq=5
+        )
+
+        # Test that heartbeat doesn't update before freq is reached
+        for _ in range(4):
+            heartbeat.heartbeat()
+
+        mock_db_session.commit.assert_not_called()
+
+        # Test that heartbeat updates when freq is reached
+        heartbeat.heartbeat()
+
+        mock_get_index_attempt.assert_called_once_with(
+            db_session=mock_db_session, index_attempt_id=1
+        )
+        assert mock_index_attempt.time_updated is not None
+        mock_db_session.commit.assert_called_once()
+
+        # Reset mock calls
+        mock_db_session.reset_mock()
+        mock_get_index_attempt.reset_mock()
+
+        # Test that heartbeat updates again after freq more calls
+        for _ in range(5):
+            heartbeat.heartbeat()
+
+        mock_get_index_attempt.assert_called_once()
+        mock_db_session.commit.assert_called_once()
+
+
+def test_indexing_heartbeat_not_found(mock_db_session: MagicMock) -> None:
+    with patch(
+        "danswer.indexing.indexing_heartbeat.get_index_attempt"
+    ) as mock_get_index_attempt, patch(
+        "danswer.indexing.indexing_heartbeat.logger"
+    ) as mock_logger:
+        mock_get_index_attempt.return_value = None
+
+        heartbeat = IndexingHeartbeat(
+            index_attempt_id=1, db_session=mock_db_session, freq=1
+        )
+
+        heartbeat.heartbeat()
+
+        mock_get_index_attempt.assert_called_once_with(
+            db_session=mock_db_session, index_attempt_id=1
+        )
+        mock_logger.error.assert_called_once_with(
+            "Index attempt not found, this should not happen!"
+        )
+        mock_db_session.commit.assert_not_called()
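
Usage sketch (illustrative only, patterned on MockHeartbeat in the new conftest.py and
the Chunker/embedder wiring above): the LoggingHeartbeat class and the `documents`
variable below are hypothetical stand-ins, not code introduced by this patch.

    from typing import Any

    from danswer.indexing.chunker import Chunker
    from danswer.indexing.indexing_heartbeat import Heartbeat


    class LoggingHeartbeat(Heartbeat):
        """Toy Heartbeat implementation: counts progress ticks instead of
        touching the database like IndexingHeartbeat does."""

        def __init__(self) -> None:
            self.ticks = 0

        def heartbeat(self, metadata: Any = None) -> None:
            # Chunker.chunk() calls this once per document; the embedding
            # client calls it once per encoded batch.
            self.ticks += 1


    # chunker = Chunker(tokenizer=..., enable_multipass=False, heartbeat=LoggingHeartbeat())
    # chunks = chunker.chunk(documents)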