diff --git a/backend/danswer/background/celery/apps/primary.py b/backend/danswer/background/celery/apps/primary.py index 14d2b006bb8..ffa12ab4b34 100644 --- a/backend/danswer/background/celery/apps/primary.py +++ b/backend/danswer/background/celery/apps/primary.py @@ -14,10 +14,14 @@ import danswer.background.celery.apps.app_base as app_base from danswer.background.celery.apps.app_base import task_logger from danswer.background.celery.celery_utils import celery_is_worker_primary +from danswer.background.celery.tasks.vespa.tasks import get_unfenced_index_attempt_ids from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT from danswer.configs.constants import DanswerRedisLocks from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME +from danswer.db.engine import get_session_with_default_tenant from danswer.db.engine import SqlEngine +from danswer.db.index_attempt import get_index_attempt +from danswer.db.index_attempt import mark_attempt_failed from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair from danswer.redis.redis_connector_delete import RedisConnectorDelete from danswer.redis.redis_connector_index import RedisConnectorIndex @@ -134,6 +138,23 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None: RedisConnectorStop.reset_all(r) + # mark orphaned index attempts as failed + with get_session_with_default_tenant() as db_session: + unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r) + for attempt_id in unfenced_attempt_ids: + attempt = get_index_attempt(db_session, attempt_id) + if not attempt: + continue + + failure_reason = ( + f"Orphaned index attempt found on startup: " + f"index_attempt={attempt.id} " + f"cc_pair={attempt.connector_credential_pair_id} " + f"search_settings={attempt.search_settings_id}" + ) + logger.warning(failure_reason) + mark_attempt_failed(attempt.id, db_session, failure_reason) + @worker_ready.connect def on_worker_ready(sender: Any, **kwargs: Any) -> None: diff --git a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py index 360481015bb..213b46ae594 100644 --- a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py +++ b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py @@ -1,12 +1,12 @@ from datetime import datetime from datetime import timezone -import redis from celery import Celery from celery import shared_task from celery import Task from celery.exceptions import SoftTimeLimitExceeded from redis import Redis +from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from danswer.background.celery.apps.app_base import task_logger @@ -87,7 +87,7 @@ def try_generate_document_cc_pair_cleanup_tasks( cc_pair_id: int, db_session: Session, r: Redis, - lock_beat: redis.lock.Lock, + lock_beat: RedisLock, tenant_id: str | None, ) -> int | None: """Returns an int if syncing is needed. The int represents the number of sync tasks generated. 
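Note on the recurring annotation change in this diff: lock parameters are now typed as RedisLock, which is simply an alias for the same redis.lock.Lock class the old annotations referenced through the module path. A minimal sketch of the aliased usage, assuming a local Redis instance and a hypothetical lock name (not code from this PR):

import redis.lock
from redis import Redis
from redis.lock import Lock as RedisLock

assert RedisLock is redis.lock.Lock  # same class, just imported directly for annotations

r = Redis()  # assumes a local Redis instance
lock_beat: RedisLock = r.lock("hypothetical:beat:lock", timeout=120)
if lock_beat.acquire(blocking=False):
    try:
        lock_beat.reacquire()  # resets the TTL; the beat tasks below rely on this to stay alive
    finally:
        lock_beat.release()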
diff --git a/backend/danswer/background/celery/tasks/indexing/tasks.py b/backend/danswer/background/celery/tasks/indexing/tasks.py index 666defd9586..0af1f019502 100644 --- a/backend/danswer/background/celery/tasks/indexing/tasks.py +++ b/backend/danswer/background/celery/tasks/indexing/tasks.py @@ -3,13 +3,14 @@ from http import HTTPStatus from time import sleep -import redis import sentry_sdk from celery import Celery from celery import shared_task from celery import Task from celery.exceptions import SoftTimeLimitExceeded from redis import Redis +from redis.exceptions import LockError +from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from danswer.background.celery.apps.app_base import task_logger @@ -44,7 +45,7 @@ from danswer.natural_language_processing.search_nlp_models import EmbeddingModel from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder from danswer.redis.redis_connector import RedisConnector -from danswer.redis.redis_connector_index import RedisConnectorIndexingFenceData +from danswer.redis.redis_connector_index import RedisConnectorIndexPayload from danswer.redis.redis_pool import get_redis_client from danswer.utils.logger import setup_logger from danswer.utils.variable_functionality import global_version @@ -61,14 +62,18 @@ def __init__( self, stop_key: str, generator_progress_key: str, - redis_lock: redis.lock.Lock, + redis_lock: RedisLock, redis_client: Redis, ): super().__init__() - self.redis_lock: redis.lock.Lock = redis_lock + self.redis_lock: RedisLock = redis_lock self.stop_key: str = stop_key self.generator_progress_key: str = generator_progress_key self.redis_client = redis_client + self.started: datetime = datetime.now(timezone.utc) + self.redis_lock.reacquire() + + self.last_lock_reacquire: datetime = datetime.now(timezone.utc) def should_stop(self) -> bool: if self.redis_client.exists(self.stop_key): @@ -76,7 +81,19 @@ def should_stop(self) -> bool: return False def progress(self, amount: int) -> None: - self.redis_lock.reacquire() + try: + self.redis_lock.reacquire() + self.last_lock_reacquire = datetime.now(timezone.utc) + except LockError: + logger.exception( + f"RunIndexingCallback - lock.reacquire exceptioned. 
" + f"lock_timeout={self.redis_lock.timeout} " + f"start={self.started} " + f"last_reacquired={self.last_lock_reacquire} " + f"now={datetime.now(timezone.utc)}" + ) + raise + self.redis_client.incrby(self.generator_progress_key, amount) @@ -325,7 +342,7 @@ def try_creating_indexing_task( redis_connector_index.generator_clear() # set a basic fence to start - payload = RedisConnectorIndexingFenceData( + payload = RedisConnectorIndexPayload( index_attempt_id=None, started=None, submitted=datetime.now(timezone.utc), @@ -368,7 +385,7 @@ def try_creating_indexing_task( redis_connector_index.set_fence(payload) except Exception: - redis_connector_index.set_fence(payload) + redis_connector_index.set_fence(None) task_logger.exception( f"Unexpected exception: " f"tenant={tenant_id} " diff --git a/backend/danswer/background/celery/tasks/vespa/tasks.py b/backend/danswer/background/celery/tasks/vespa/tasks.py index b01a0eac815..11274666944 100644 --- a/backend/danswer/background/celery/tasks/vespa/tasks.py +++ b/backend/danswer/background/celery/tasks/vespa/tasks.py @@ -13,6 +13,7 @@ from celery.result import AsyncResult from celery.states import READY_STATES from redis import Redis +from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from tenacity import RetryError @@ -162,7 +163,7 @@ def try_generate_stale_document_sync_tasks( celery_app: Celery, db_session: Session, r: Redis, - lock_beat: redis.lock.Lock, + lock_beat: RedisLock, tenant_id: str | None, ) -> int | None: # the fence is up, do nothing @@ -180,7 +181,12 @@ def try_generate_stale_document_sync_tasks( f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair." ) - task_logger.info("RedisConnector.generate_tasks starting by cc_pair.") + task_logger.info( + "RedisConnector.generate_tasks starting by cc_pair. " + "Documents spanning multiple cc_pairs will only be synced once." + ) + + docs_to_skip: set[str] = set() # rkuo: we could technically sync all stale docs in one big pass. # but I feel it's more understandable to group the docs by cc_pair @@ -188,22 +194,21 @@ def try_generate_stale_document_sync_tasks( cc_pairs = get_connector_credential_pairs(db_session) for cc_pair in cc_pairs: rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id) - tasks_generated = rc.generate_tasks( - celery_app, db_session, r, lock_beat, tenant_id - ) + rc.set_skip_docs(docs_to_skip) + result = rc.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id) - if tasks_generated is None: + if result is None: continue - if tasks_generated == 0: + if result[1] == 0: continue task_logger.info( f"RedisConnector.generate_tasks finished for single cc_pair. " - f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}" + f"cc_pair={cc_pair.id} tasks_generated={result[0]} tasks_possible={result[1]}" ) - total_tasks_generated += tasks_generated + total_tasks_generated += result[0] task_logger.info( f"RedisConnector.generate_tasks finished for all cc_pairs. 
total_tasks_generated={total_tasks_generated}" @@ -218,7 +223,7 @@ def try_generate_document_set_sync_tasks( document_set_id: int, db_session: Session, r: Redis, - lock_beat: redis.lock.Lock, + lock_beat: RedisLock, tenant_id: str | None, ) -> int | None: lock_beat.reacquire() @@ -246,12 +251,11 @@ def try_generate_document_set_sync_tasks( ) # Add all documents that need to be updated into the queue - tasks_generated = rds.generate_tasks( - celery_app, db_session, r, lock_beat, tenant_id - ) - if tasks_generated is None: + result = rds.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id) + if result is None: return None + tasks_generated = result[0] # Currently we are allowing the sync to proceed with 0 tasks. # It's possible for sets/groups to be generated initially with no entries # and they still need to be marked as up to date. @@ -260,7 +264,7 @@ def try_generate_document_set_sync_tasks( task_logger.info( f"RedisDocumentSet.generate_tasks finished. " - f"document_set_id={document_set.id} tasks_generated={tasks_generated}" + f"document_set={document_set.id} tasks_generated={tasks_generated}" ) # set this only after all tasks have been added @@ -273,7 +277,7 @@ def try_generate_user_group_sync_tasks( usergroup_id: int, db_session: Session, r: Redis, - lock_beat: redis.lock.Lock, + lock_beat: RedisLock, tenant_id: str | None, ) -> int | None: lock_beat.reacquire() @@ -302,12 +306,11 @@ def try_generate_user_group_sync_tasks( task_logger.info( f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}" ) - tasks_generated = rug.generate_tasks( - celery_app, db_session, r, lock_beat, tenant_id - ) - if tasks_generated is None: + result = rug.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id) + if result is None: return None + tasks_generated = result[0] # Currently we are allowing the sync to proceed with 0 tasks. # It's possible for sets/groups to be generated initially with no entries # and they still need to be marked as up to date. @@ -316,7 +319,7 @@ def try_generate_user_group_sync_tasks( task_logger.info( f"RedisUserGroup.generate_tasks finished. 
" - f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}" + f"usergroup={usergroup.id} tasks_generated={tasks_generated}" ) # set this only after all tasks have been added @@ -580,8 +583,8 @@ def monitor_ccpair_indexing_taskset( progress = redis_connector_index.get_progress() if progress is not None: task_logger.info( - f"Connector indexing progress: cc_pair_id={cc_pair_id} " - f"search_settings_id={search_settings_id} " + f"Connector indexing progress: cc_pair={cc_pair_id} " + f"search_settings={search_settings_id} " f"progress={progress} " f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}" ) @@ -602,8 +605,8 @@ def monitor_ccpair_indexing_taskset( # if it isn't, then the worker crashed task_logger.info( f"Connector indexing aborted: " - f"cc_pair_id={cc_pair_id} " - f"search_settings_id={search_settings_id} " + f"cc_pair={cc_pair_id} " + f"search_settings={search_settings_id} " f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}" ) @@ -621,8 +624,8 @@ def monitor_ccpair_indexing_taskset( status_enum = HTTPStatus(status_int) task_logger.info( - f"Connector indexing finished: cc_pair_id={cc_pair_id} " - f"search_settings_id={search_settings_id} " + f"Connector indexing finished: cc_pair={cc_pair_id} " + f"search_settings={search_settings_id} " f"status={status_enum.name} " f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}" ) @@ -630,6 +633,37 @@ def monitor_ccpair_indexing_taskset( redis_connector_index.reset() +def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]: + """Gets a list of unfenced index attempts. Should not be possible, so we'd typically + want to clean them up. + + Unfenced = attempt not in terminal state and fence does not exist. + """ + unfenced_attempts: list[int] = [] + + # do some cleanup before clearing fences + # check the db for any outstanding index attempts + attempts: list[IndexAttempt] = [] + attempts.extend( + get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session) + ) + attempts.extend( + get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session) + ) + + for attempt in attempts: + # if attempts exist in the db but we don't detect them in redis, mark them as failed + fence_key = RedisConnectorIndex.fence_key_with_ids( + attempt.connector_credential_pair_id, attempt.search_settings_id + ) + if r.exists(fence_key): + continue + + unfenced_attempts.append(attempt.id) + + return unfenced_attempts + + @shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True) def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: """This is a celery beat task that monitors and finalizes metadata sync tasksets. 
@@ -643,7 +677,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: """ r = get_redis_client(tenant_id=tenant_id) - lock_beat: redis.lock.Lock = r.lock( + lock_beat: RedisLock = r.lock( DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK, timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, ) @@ -677,31 +711,24 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: f"pruning={n_pruning}" ) - # do some cleanup before clearing fences - # check the db for any outstanding index attempts + # Fail any index attempts in the DB that don't have fences with get_session_with_tenant(tenant_id) as db_session: - attempts: list[IndexAttempt] = [] - attempts.extend( - get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session) - ) - attempts.extend( - get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session) - ) - - for a in attempts: - # if attempts exist in the db but we don't detect them in redis, mark them as failed - fence_key = RedisConnectorIndex.fence_key_with_ids( - a.connector_credential_pair_id, a.search_settings_id + unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r) + for attempt_id in unfenced_attempt_ids: + attempt = get_index_attempt(db_session, attempt_id) + if not attempt: + continue + + failure_reason = ( + f"Unfenced index attempt found in DB: " + f"index_attempt={attempt.id} " + f"cc_pair={attempt.connector_credential_pair_id} " + f"search_settings={attempt.search_settings_id}" + ) + task_logger.warning(failure_reason) + mark_attempt_failed( + attempt.id, db_session, failure_reason=failure_reason ) - if not r.exists(fence_key): - failure_reason = ( - f"Unknown index attempt. Might be left over from a process restart: " - f"index_attempt={a.id} " - f"cc_pair={a.connector_credential_pair_id} " - f"search_settings={a.search_settings_id}" - ) - task_logger.warning(failure_reason) - mark_attempt_failed(a.id, db_session, failure_reason=failure_reason) lock_beat.reacquire() if r.exists(RedisConnectorCredentialPair.get_fence_key()): diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index 7b3ea8e81bb..0bc0dc60008 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -74,7 +74,7 @@ # needs to be long enough to cover the maximum time it takes to download an object # if we can get callbacks as object bytes download, we could lower this a lot. -CELERY_INDEXING_LOCK_TIMEOUT = 60 * 60 # 60 min +CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60 # 3 hours # needs to be long enough to cover the maximum time it takes to download an object # if we can get callbacks as object bytes download, we could lower this a lot.
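A minimal sketch of the lock-extension pattern that motivates raising CELERY_INDEXING_LOCK_TIMEOUT to 3 hours: the indexing callback periodically calls reacquire() to reset the lock's TTL, and now surfaces a LockError if the lock was lost. The lock name and the loop below are hypothetical stand-ins, not code from this PR, and assume a local Redis instance.

from redis import Redis
from redis.exceptions import LockError
from redis.lock import Lock as RedisLock

CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60  # 3 hours, matching the constant above

r = Redis()
lock: RedisLock = r.lock("hypothetical:indexing:lock", timeout=CELERY_INDEXING_LOCK_TIMEOUT)

if lock.acquire(blocking=False):
    try:
        for _batch in range(3):  # stand-in for the real document batches
            try:
                lock.reacquire()  # reset the TTL while work is still progressing
            except LockError:
                # lock expired or changed owners; log timing details and fail loudly
                raise
            # ... index the batch ...
    finally:
        if lock.owned():
            lock.release()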
diff --git a/backend/danswer/db/document.py b/backend/danswer/db/document.py index 32b2530e7c4..9dd448f6a6d 100644 --- a/backend/danswer/db/document.py +++ b/backend/danswer/db/document.py @@ -169,6 +169,7 @@ def get_document_connector_counts( def get_document_counts_for_cc_pairs( db_session: Session, cc_pair_identifiers: list[ConnectorCredentialPairIdentifier] ) -> Sequence[tuple[int, int, int]]: + """Returns a sequence of tuples of (connector_id, credential_id, document count)""" stmt = ( select( DocumentByConnectorCredentialPair.connector_id, @@ -509,7 +510,7 @@ def prepare_to_modify_documents( db_session.commit() # ensure that we're not in a transaction lock_acquired = False - for _ in range(_NUM_LOCK_ATTEMPTS): + for i in range(_NUM_LOCK_ATTEMPTS): try: with db_session.begin() as transaction: lock_acquired = acquire_document_locks( @@ -520,7 +521,7 @@ def prepare_to_modify_documents( break except OperationalError as e: logger.warning( - f"Failed to acquire locks for documents, retrying. Error: {e}" + f"Failed to acquire locks for documents on attempt {i}, retrying. Error: {e}" ) time.sleep(retry_delay) diff --git a/backend/danswer/indexing/indexing_pipeline.py b/backend/danswer/indexing/indexing_pipeline.py index 9d179f440b6..df4909a13b3 100644 --- a/backend/danswer/indexing/indexing_pipeline.py +++ b/backend/danswer/indexing/indexing_pipeline.py @@ -93,14 +93,15 @@ def _upsert_documents_in_db( document_id=doc.id, db_session=db_session, ) - else: - create_or_add_document_tag( - tag_key=k, - tag_value=v, - source=doc.source, - document_id=doc.id, - db_session=db_session, - ) + continue + + create_or_add_document_tag( + tag_key=k, + tag_value=v, + source=doc.source, + document_id=doc.id, + db_session=db_session, + ) def get_doc_ids_to_update( @@ -196,7 +197,7 @@ def index_doc_batch_prepare( db_session: Session, ignore_time_skip: bool = False, ) -> DocumentBatchPrepareContext | None: - """This sets up the documents in the relational DB (source of truth) for permissions, metadata, etc. + """Sets up the documents in the relational DB (source of truth) for permissions, metadata, etc. This preceeds indexing it into the actual document index.""" documents: list[Document] = [] for document in document_batch: @@ -213,16 +214,17 @@ def index_doc_batch_prepare( logger.warning( f"Skipping document with ID {document.id} as it has neither title nor content." ) - elif ( - document.title is not None and not document.title.strip() and empty_contents - ): + continue + + if document.title is not None and not document.title.strip() and empty_contents: # The title is explicitly empty ("" and not None) and the document is empty # so when building the chunk text representation, it will be empty and unuseable logger.warning( f"Skipping document with ID {document.id} as the chunks will be empty." 
) - else: - documents.append(document) + continue + + documents.append(document) # Create a trimmed list of docs that don't have a newer updated at # Shortcuts the time-consuming flow on connector index retries @@ -284,7 +286,10 @@ def index_doc_batch( ) -> tuple[int, int]: """Takes different pieces of the indexing pipeline and applies it to a batch of documents Note that the documents should already be batched at this point so that it does not inflate the - memory requirements""" + memory requirements + + Returns a tuple where the first element is the number of new docs and the + second element is the number of chunks.""" no_access = DocumentAccess.build( user_emails=[], @@ -327,9 +332,9 @@ def index_doc_batch( # we're concerned about race conditions where multiple simultaneous indexings might result # in one set of metadata overwriting another one in vespa. - # we still write data here for immediate and most likely correct sync, but + # we still write data here for the immediate and most likely correct sync, but # to resolve this, an update of the last modified field at the end of this loop - # always triggers a final metadata sync + # always triggers a final metadata sync via the celery queue access_aware_chunks = [ DocMetadataAwareIndexChunk.from_index_chunk( index_chunk=chunk, @@ -366,7 +371,8 @@ def index_doc_batch( ids_to_new_updated_at = {} for doc in successful_docs: last_modified_ids.append(doc.id) - # doc_updated_at is the connector source's idea of when the doc was last modified + # doc_updated_at is the source's idea (on the other end of the connector) + # of when the doc was last modified if doc.doc_updated_at is None: continue ids_to_new_updated_at[doc.id] = doc.doc_updated_at @@ -381,10 +387,13 @@ def index_doc_batch( db_session.commit() - return len([r for r in insertion_records if r.already_existed is False]), len( - access_aware_chunks + result = ( + len([r for r in insertion_records if r.already_existed is False]), + len(access_aware_chunks), ) + return result + def build_indexing_pipeline( *, diff --git a/backend/danswer/redis/redis_connector_credential_pair.py b/backend/danswer/redis/redis_connector_credential_pair.py index bbad3700111..7ed09d76a2d 100644 --- a/backend/danswer/redis/redis_connector_credential_pair.py +++ b/backend/danswer/redis/redis_connector_credential_pair.py @@ -1,9 +1,10 @@ import time +from typing import cast from uuid import uuid4 -import redis from celery import Celery from redis import Redis +from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT @@ -13,6 +14,7 @@ from danswer.db.document import ( construct_document_select_for_connector_credential_pair_by_needs_sync, ) +from danswer.db.models import Document from danswer.redis.redis_object_helper import RedisObjectHelper @@ -30,6 +32,9 @@ class RedisConnectorCredentialPair(RedisObjectHelper): def __init__(self, tenant_id: str | None, id: int) -> None: super().__init__(tenant_id, str(id)) + # documents that should be skipped + self.skip_docs: set[str] = set() + @classmethod def get_fence_key(cls) -> str: return RedisConnectorCredentialPair.FENCE_PREFIX @@ -45,14 +50,19 @@ def taskset_key(self) -> str: # example: connector_taskset return f"{self.TASKSET_PREFIX}" + def set_skip_docs(self, skip_docs: set[str]) -> None: + # documents that should be skipped. 
Note that this class updates + # the list on the fly + self.skip_docs = skip_docs + def generate_tasks( self, celery_app: Celery, db_session: Session, redis_client: Redis, - lock: redis.lock.Lock, + lock: RedisLock, tenant_id: str | None, - ) -> int | None: + ) -> tuple[int, int] | None: last_lock_time = time.monotonic() async_results = [] @@ -63,7 +73,10 @@ def generate_tasks( stmt = construct_document_select_for_connector_credential_pair_by_needs_sync( cc_pair.connector_id, cc_pair.credential_id ) + + num_docs = 0 for doc in db_session.scalars(stmt).yield_per(1): + doc = cast(Document, doc) current_time = time.monotonic() if current_time - last_lock_time >= ( CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4 @@ -71,6 +84,12 @@ def generate_tasks( lock.reacquire() last_lock_time = current_time + num_docs += 1 + + # check if we should skip the document (typically because it's already syncing) + if doc.id in self.skip_docs: + continue + # celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac" # the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac" # we prefix the task id so it's easier to keep track of who created the task @@ -93,5 +112,6 @@ def generate_tasks( ) async_results.append(result) + self.skip_docs.add(doc.id) - return len(async_results) + return len(async_results), num_docs diff --git a/backend/danswer/redis/redis_connector_delete.py b/backend/danswer/redis/redis_connector_delete.py index ba250c5b0f7..6de4a9ec079 100644 --- a/backend/danswer/redis/redis_connector_delete.py +++ b/backend/danswer/redis/redis_connector_delete.py @@ -6,6 +6,7 @@ import redis from celery import Celery from pydantic import BaseModel +from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT @@ -83,7 +84,7 @@ def generate_tasks( self, celery_app: Celery, db_session: Session, - lock: redis.lock.Lock, + lock: RedisLock, ) -> int | None: """Returns None if the cc_pair doesn't exist. 
Otherwise, returns an int with the number of generated tasks.""" diff --git a/backend/danswer/redis/redis_connector_index.py b/backend/danswer/redis/redis_connector_index.py index 3883ddceaa3..10fd3667fda 100644 --- a/backend/danswer/redis/redis_connector_index.py +++ b/backend/danswer/redis/redis_connector_index.py @@ -6,7 +6,7 @@ from pydantic import BaseModel -class RedisConnectorIndexingFenceData(BaseModel): +class RedisConnectorIndexPayload(BaseModel): index_attempt_id: int | None started: datetime | None submitted: datetime @@ -71,22 +71,20 @@ def fenced(self) -> bool: return False @property - def payload(self) -> RedisConnectorIndexingFenceData | None: + def payload(self) -> RedisConnectorIndexPayload | None: # read related data and evaluate/print task progress fence_bytes = cast(bytes, self.redis.get(self.fence_key)) if fence_bytes is None: return None fence_str = fence_bytes.decode("utf-8") - payload = RedisConnectorIndexingFenceData.model_validate_json( - cast(str, fence_str) - ) + payload = RedisConnectorIndexPayload.model_validate_json(cast(str, fence_str)) return payload def set_fence( self, - payload: RedisConnectorIndexingFenceData | None, + payload: RedisConnectorIndexPayload | None, ) -> None: if not payload: self.redis.delete(self.fence_key) diff --git a/backend/danswer/redis/redis_connector_prune.py b/backend/danswer/redis/redis_connector_prune.py index 8892c12b647..25e0a6314de 100644 --- a/backend/danswer/redis/redis_connector_prune.py +++ b/backend/danswer/redis/redis_connector_prune.py @@ -4,6 +4,7 @@ import redis from celery import Celery +from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT @@ -105,7 +106,7 @@ def generate_tasks( documents_to_prune: set[str], celery_app: Celery, db_session: Session, - lock: redis.lock.Lock | None, + lock: RedisLock | None, ) -> int | None: last_lock_time = time.monotonic() diff --git a/backend/danswer/redis/redis_document_set.py b/backend/danswer/redis/redis_document_set.py index 102e910feec..879d955eb88 100644 --- a/backend/danswer/redis/redis_document_set.py +++ b/backend/danswer/redis/redis_document_set.py @@ -5,6 +5,7 @@ import redis from celery import Celery from redis import Redis +from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT @@ -50,9 +51,9 @@ def generate_tasks( celery_app: Celery, db_session: Session, redis_client: Redis, - lock: redis.lock.Lock, + lock: RedisLock, tenant_id: str | None, - ) -> int | None: + ) -> tuple[int, int] | None: last_lock_time = time.monotonic() async_results = [] @@ -84,7 +85,7 @@ def generate_tasks( async_results.append(result) - return len(async_results) + return len(async_results), len(async_results) def reset(self) -> None: self.redis.delete(self.taskset_key) diff --git a/backend/danswer/redis/redis_object_helper.py b/backend/danswer/redis/redis_object_helper.py index 629f15e6058..35366a36aab 100644 --- a/backend/danswer/redis/redis_object_helper.py +++ b/backend/danswer/redis/redis_object_helper.py @@ -1,9 +1,9 @@ from abc import ABC from abc import abstractmethod -import redis from celery import Celery from redis import Redis +from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from danswer.redis.redis_pool import get_redis_client @@ -85,7 +85,13 @@ def generate_tasks( celery_app: Celery, db_session: Session, redis_client: Redis, - lock: redis.lock.Lock, + lock: 
RedisLock, tenant_id: str | None, - ) -> int | None: - pass + ) -> tuple[int, int] | None: + """First element should be the number of actual tasks generated, second should + be the number of docs that were candidates to be synced for the cc pair. + + This is needed when we are syncing stale docs referenced by multiple + connectors. In a single pass across multiple cc pairs, we only want a task + to be created for a particular document id the first time we see it. + The rest can be skipped.""" diff --git a/backend/danswer/redis/redis_usergroup.py b/backend/danswer/redis/redis_usergroup.py index 53d2d4fc0a9..7c49b9c7fb8 100644 --- a/backend/danswer/redis/redis_usergroup.py +++ b/backend/danswer/redis/redis_usergroup.py @@ -5,6 +5,7 @@ import redis from celery import Celery from redis import Redis +from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT @@ -51,15 +52,15 @@ def generate_tasks( celery_app: Celery, db_session: Session, redis_client: Redis, - lock: redis.lock.Lock, + lock: RedisLock, tenant_id: str | None, - ) -> int | None: + ) -> tuple[int, int] | None: last_lock_time = time.monotonic() async_results = [] if not global_version.is_ee_version(): - return 0 + return 0, 0 try: construct_document_select_by_usergroup = fetch_versioned_implementation( @@ -67,7 +68,7 @@ def generate_tasks( "construct_document_select_by_usergroup", ) except ModuleNotFoundError: - return 0 + return 0, 0 stmt = construct_document_select_by_usergroup(int(self._id)) for doc in db_session.scalars(stmt).yield_per(1): @@ -97,7 +98,7 @@ def generate_tasks( async_results.append(result) - return len(async_results) + return len(async_results), len(async_results) def reset(self) -> None: self.redis.delete(self.taskset_key)
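A self-contained sketch of the new (tasks_generated, tasks_possible) accounting and the shared skip set described in the redis_object_helper docstring above. The function and data below are hypothetical; they model the semantics only, not the actual Redis/Celery plumbing.

from collections.abc import Iterable

def plan_sync_tasks(doc_ids_by_cc_pair: dict[int, Iterable[str]]) -> tuple[int, int]:
    """Toy model: returns (tasks_generated, tasks_possible) across all cc_pairs,
    creating at most one task per document id."""
    skip_docs: set[str] = set()  # plays the role of the shared set_skip_docs() set
    tasks_generated = 0
    tasks_possible = 0

    for _cc_pair_id, doc_ids in doc_ids_by_cc_pair.items():
        for doc_id in doc_ids:
            tasks_possible += 1  # candidate doc for this cc_pair
            if doc_id in skip_docs:  # already queued via another cc_pair
                continue
            skip_docs.add(doc_id)
            tasks_generated += 1  # a real task would be enqueued here

    return tasks_generated, tasks_possible

# a doc shared by two cc_pairs only generates one task but counts as possible twice
assert plan_sync_tasks({1: ["a", "b"], 2: ["b", "c"]}) == (3, 4)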