diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py index ccc673c62..0888407a4 100644 --- a/daras_ai_v2/vector_search.py +++ b/daras_ai_v2/vector_search.py @@ -14,21 +14,19 @@ import gooey_gui as gui import numpy as np import requests -from app_users.models import AppUser -from daras_ai.image_input import ( - get_mimetype_from_response, - safe_filename, - upload_file_from_bytes, -) from django.db import transaction -from django.db.models import F +from django.db.models import F, Q from django.utils import timezone -from embeddings.models import EmbeddedFile, EmbeddingsReference -from files.models import FileMetadata from furl import furl from loguru import logger from pydantic import BaseModel, Field +from app_users.models import AppUser +from daras_ai.image_input import ( + get_mimetype_from_response, + safe_filename, + upload_file_from_bytes, +) from daras_ai_v2 import settings from daras_ai_v2.asr import ( AsrModels, @@ -71,6 +69,8 @@ remove_quotes, ) from daras_ai_v2.text_splitter import Document, text_splitter +from embeddings.models import EmbeddedFile, EmbeddingsReference +from files.models import FileMetadata class DocSearchRequest(BaseModel): @@ -137,9 +137,7 @@ def get_top_k_references( """ from recipes.BulkRunner import url_to_runs - yield "Fetching latest knowledge docs..." input_docs = request.documents or [] - check_document_updates = request.check_document_updates if request.doc_extract_url: page_cls, sr, pr = url_to_runs(request.doc_extract_url) @@ -154,22 +152,34 @@ def get_top_k_references( EmbeddedFile._meta.get_field("embedding_model").default ), ) - embedded_files: list[EmbeddedFile] = yield from apply_parallel( - lambda f_url: get_or_create_embedded_file( - f_url=f_url, - max_context_words=request.max_context_words, - scroll_jump=request.scroll_jump, - google_translate_target=google_translate_target, - selected_asr_model=selected_asr_model, - embedding_model=embedding_model, - is_user_url=is_user_url, - check_document_updates=check_document_updates, - current_user=current_user, - ), - input_docs, - max_workers=4, - message="Fetching latest knowledge docs & Embeddings...", + + embedded_files, args_to_create = yield from do_check_document_updates( + input_docs=input_docs, + max_context_words=request.max_context_words, + scroll_jump=request.scroll_jump, + google_translate_target=google_translate_target, + selected_asr_model=selected_asr_model, + embedding_model=embedding_model, + check_document_updates=request.check_document_updates, ) + + if args_to_create: + embedded_files += yield from apply_parallel( + lambda args: create_embedded_file( + *args, + max_context_words=request.max_context_words, + scroll_jump=request.scroll_jump, + google_translate_target=google_translate_target, + selected_asr_model=selected_asr_model, + embedding_model=embedding_model, + is_user_url=is_user_url, + current_user=current_user, + ), + args_to_create, + max_workers=4, + message="Creating knowledge embeddings...", + ) + if not embedded_files: yield "No embeddings found - skipping search" return [] @@ -239,7 +249,7 @@ def query_vespa( return {"root": {"children": []}} file_ids_str = ", ".join(map(repr, file_ids)) query = f"select * from {settings.VESPA_SCHEMA} where file_id in (@fileIds) and (userQuery() or ({{targetHits: {limit}}}nearestNeighbor(embedding, q))) limit {limit}" - logger.debug(f"Vespa query: {'-'*80}\n{query}\n{'-'*80}") + logger.debug(f"Vespa query: {query!r}") if semantic_weight == 1.0: ranking = "semantic" elif semantic_weight == 0.0: @@ -286,20 +296,22 @@ def doc_or_yt_url_to_file_metas( def yt_info_to_playlist_metadata(data: dict) -> FileMetadata: + etag = data.get("modified_date") or data.get("playlist_count") return FileMetadata( name=data.get("title", "YouTube Playlist"), # youtube doesn't provide etag, so we use modified_date / playlist_count - etag=data.get("modified_date") or data.get("playlist_count"), + etag=etag and str(etag) or None, # will be converted later & saved as wav mime_type="audio/wav", ) def yt_info_to_video_metadata(data: dict) -> FileMetadata: + etag = data.get("filesize_approx") or data.get("upload_date") return FileMetadata( name=data.get("title", "YouTube Video"), # youtube doesn't provide etag, so we use filesize_approx or upload_date - etag=data.get("filesize_approx") or data.get("upload_date"), + etag=etag and str(etag) or None, # we will later convert & save as wav mime_type="audio/wav", total_bytes=data.get("filesize_approx", 0), @@ -328,15 +340,17 @@ def doc_url_to_file_metadata(f_url: str) -> FileMetadata: total_bytes = int(meta.get("size") or 0) export_links = meta.get("exportLinks", None) else: + if is_user_uploaded_url(f_url): + kwargs = {} + else: + kwargs = requests_scraping_kwargs() | dict( + timeout=settings.EXTERNAL_REQUEST_TIMEOUT_SEC + ) try: - if is_user_uploaded_url(f_url): - r = requests.head(f_url) - else: - r = requests.head( - f_url, - timeout=settings.EXTERNAL_REQUEST_TIMEOUT_SEC, - **requests_scraping_kwargs(), - ) + r = requests.head(f_url, **kwargs) + if r.status_code == 405: + r = requests.get(f_url, **kwargs, stream=True) + r.close() raise_for_status(r) except requests.RequestException as e: logger.warning(f"ignore error while downloading {f_url}: {e}") @@ -403,52 +417,100 @@ def yt_dlp_extract_info(url: str, **params) -> dict: return data -def get_or_create_embedded_file( +def do_check_document_updates( *, - f_url: str, + input_docs: list[str], max_context_words: int, scroll_jump: int, google_translate_target: str | None, selected_asr_model: str | None, embedding_model: EmbeddingModels, - is_user_url: bool, check_document_updates: bool, +) -> typing.Generator[ + str, + None, + tuple[ + list[EmbeddedFile], + list[tuple[dict, FileMetadata, list[tuple[str, FileMetadata]]]], + ], +]: + lookups = {} + q = Q() + for f_url in input_docs: + lookup = dict( + url=f_url, + max_context_words=max_context_words, + scroll_jump=scroll_jump, + google_translate_target=google_translate_target or "", + selected_asr_model=selected_asr_model or "", + embedding_model=embedding_model.name, + ) + q |= Q(**lookup) + lookups[f_url] = lookup + + cached_files = { + f.url: f + for f in ( + EmbeddedFile.objects.filter(q) + .select_related("metadata") + .order_by("url", "-updated_at") + .distinct("url") + ) + } + + for f_url in cached_files: + if is_user_uploaded_url(f_url) or not check_document_updates: + lookups.pop(f_url, None) + + metadatas = yield from apply_parallel( + doc_or_yt_url_to_file_metas, + lookups.keys(), + message="Fetching latest knowlege docs...", + max_workers=100, + ) + + args_to_create = [] + for (f_url, lookup), (file_meta, leaf_url_metas) in zip(lookups.items(), metadatas): + f = cached_files.get(f_url) + if f and f.metadata == file_meta: + continue + else: + args_to_create.append((lookup, file_meta, leaf_url_metas)) + + return list(cached_files.values()), args_to_create + + +def create_embedded_file( + lookup: dict, + file_meta: FileMetadata, + leaf_url_metas: list[tuple[str, FileMetadata]], + *, + max_context_words: int, + scroll_jump: int, + google_translate_target: str | None, + selected_asr_model: str | None, + embedding_model: EmbeddingModels, + is_user_url: bool, current_user: AppUser, ) -> EmbeddedFile: """ Return Vespa document ids and document tags for a given document url + metadata. """ - lookup = dict( - url=f_url, - max_context_words=max_context_words, - scroll_jump=scroll_jump, - google_translate_target=google_translate_target or "", - selected_asr_model=selected_asr_model or "", - embedding_model=embedding_model.name, - ) - lock_id = hashlib.sha256(str(lookup).encode()).hexdigest() + lock_id = _sha256(lookup) with redis_lock(f"gooey/get_or_create_embeddings/v1/{lock_id}"): - try: - embedded_file = EmbeddedFile.objects.filter(**lookup).order_by( - "-updated_at" - )[0] - except IndexError: - embedded_file = None - else: - # skip metadata check for bucket urls (since they are unique & static) - if is_user_uploaded_url(f_url) or not check_document_updates: - return embedded_file - - file_meta, leaf_url_metas = doc_or_yt_url_to_file_metas(f_url) - if embedded_file and embedded_file.metadata.astuple() == file_meta.astuple(): - # metadata hasn't changed, return existing file - return embedded_file - - file_id_fields = lookup | dict(metadata=file_meta.astuple()) - file_id = hashlib.sha256(str(file_id_fields).encode()).hexdigest() + # check if embeddings already exist and are up-to-date + f = ( + EmbeddedFile.objects.filter(**lookup) + .select_related("metadata") + .order_by("-updated_at") + .first() + ) + if f and f.metadata == file_meta: + return f # create fresh embeddings + file_id = _sha256(lookup | dict(metadata=file_meta.astuple())) for leaf_url, leaf_meta in leaf_url_metas: refs = create_embeddings_in_search_db( f_url=leaf_url, @@ -462,12 +524,15 @@ def get_or_create_embedded_file( is_user_url=is_user_url, ) with transaction.atomic(): - EmbeddedFile.objects.filter(**lookup).delete() + EmbeddedFile.objects.filter(Q(**lookup) | Q(vespa_file_id=file_id)).delete() file_meta.save() - embedded_file = EmbeddedFile.objects.get_or_create( + embedded_file = EmbeddedFile.objects.create( vespa_file_id=file_id, - defaults=lookup | dict(metadata=file_meta, created_by=current_user), - )[0] + metadata=file_meta, + created_by=current_user, + **lookup, + ) + logger.debug(f"created: {embedded_file}") for ref in refs: ref.embedded_file = embedded_file EmbeddingsReference.objects.bulk_create( @@ -503,7 +568,7 @@ def create_embeddings_in_search_db( embedding_model=embedding_model, is_user_url=is_user_url, ): - doc_id = file_id + "/" + hashlib.sha256(str(ref).encode()).hexdigest() + doc_id = file_id + "/" + _sha256(ref) db_ref = EmbeddingsReference( vespa_doc_id=doc_id, url=ref["url"], @@ -526,6 +591,10 @@ def create_embeddings_in_search_db( return list(refs.values()) +def _sha256(x) -> str: + return hashlib.sha256(str(x).encode()).hexdigest() + + def get_embeds_for_doc( *, f_url: str, diff --git a/files/models.py b/files/models.py index a84a5303a..686d9ba04 100644 --- a/files/models.py +++ b/files/models.py @@ -1,5 +1,6 @@ from django.db import models from django.template.defaultfilters import filesizeformat +from loguru import logger class FileMetadata(models.Model): @@ -18,6 +19,16 @@ def __str__(self): ret += f" - {self.etag}" return ret + def __eq__(self, other): + ret = bool( + isinstance(other, FileMetadata) + # avoid null comparisions -- when metadata is not available and hence not comparable + and (self.etag or other.etag or self.total_bytes or other.total_bytes) + and self.astuple() == other.astuple() + ) + logger.debug(f"checking: `{self}` == `{other}` ({ret})") + return ret + def astuple(self) -> tuple: return self.name, self.etag, self.mime_type, self.total_bytes diff --git a/functions/recipe_functions.py b/functions/recipe_functions.py index 945d8f286..68e45f94f 100644 --- a/functions/recipe_functions.py +++ b/functions/recipe_functions.py @@ -148,8 +148,11 @@ def call_recipe_functions( state: dict, trigger: FunctionTrigger, ) -> typing.Iterable[str]: + tools = list(get_tools_from_state(state, trigger)) + if not tools: + return yield f"Running {trigger.name} hooks..." - for tool in get_tools_from_state(state, trigger): + for tool in tools: tool.bind( saved_run=saved_run, workspace=workspace,