diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py index 302e5cda9..3588def12 100644 --- a/daras_ai_v2/vector_search.py +++ b/daras_ai_v2/vector_search.py @@ -84,7 +84,7 @@ class DocSearchRequest(BaseModel): scroll_jump: int | None doc_extract_url: str | None - + cache_knowledge: typing.Optional[bool] = True embedding_model: typing.Literal[tuple(e.name for e in EmbeddingModels)] | None dense_weight: float | None = Field( ge=0.0, @@ -146,9 +146,14 @@ def get_top_k_references( else: selected_asr_model = google_translate_target = None - file_url_metas = flatmap_parallel(doc_or_yt_url_to_metadatas, input_docs) + cache_knowledge = request.cache_knowledge + file_url_metas = flatmap_parallel( + lambda f_url: doc_or_yt_url_to_metadatas(f_url, cache_knowledge), input_docs + ) file_urls, file_metas = zip(*file_url_metas) + logger.debug(f"file_urls: {file_urls}") + yield "Creating knowledge embeddings..." embedding_model = EmbeddingModels.get( @@ -269,7 +274,9 @@ def get_vespa_app(): return Vespa(url=settings.VESPA_URL) -def doc_or_yt_url_to_metadatas(f_url: str) -> list[tuple[str, FileMetadata]]: +def doc_or_yt_url_to_metadatas( + f_url: str, cache_knowledge: bool = True +) -> list[tuple[str, FileMetadata]]: if is_yt_dlp_able_url(f_url): entries = yt_dlp_get_video_entries(f_url) return [ @@ -287,10 +294,10 @@ def doc_or_yt_url_to_metadatas(f_url: str) -> list[tuple[str, FileMetadata]]: for entry in entries ] else: - return [(f_url, doc_url_to_file_metadata(f_url))] + return [(f_url, doc_url_to_file_metadata(f_url, cache_knowledge))] -def doc_url_to_file_metadata(f_url: str) -> FileMetadata: +def doc_url_to_file_metadata(f_url: str, cache_knowledge: bool = True) -> FileMetadata: from googleapiclient.errors import HttpError f = furl(f_url.strip("/")) @@ -314,7 +321,20 @@ def doc_url_to_file_metadata(f_url: str) -> FileMetadata: else: try: if is_user_uploaded_url(f_url): - r = requests.head(f_url) + + if cache_knowledge: + logger.debug(f"using cached metadata for {f_url}") + name = f.path.segments[-1] + return FileMetadata( + name=name, + etag=None, + mime_type=mimetypes.guess_type(name)[0], + total_bytes=0, + ) + else: + logger.debug(f"fetching latest metadata for {f_url}") + r = requests.head(f_url) + else: r = requests.head( f_url, @@ -402,10 +422,10 @@ def get_or_create_embedded_file( """ lookup = dict( url=f_url, - metadata__name=file_meta.name, - metadata__etag=file_meta.etag, - metadata__mime_type=file_meta.mime_type, - metadata__total_bytes=file_meta.total_bytes, + # metadata__name=file_meta.name, + # metadata__etag=file_meta.etag, + # metadata__mime_type=file_meta.mime_type, + # metadata__total_bytes=file_meta.total_bytes, max_context_words=max_context_words, scroll_jump=scroll_jump, google_translate_target=google_translate_target or "", diff --git a/recipes/DocExtract.py b/recipes/DocExtract.py index 0fa063379..951808784 100644 --- a/recipes/DocExtract.py +++ b/recipes/DocExtract.py @@ -225,6 +225,7 @@ def run_v2( message="Updating sheet...", ) else: + # TODO: implement this as well file_url_metas = yield from flatapply_parallel( doc_or_yt_url_to_metadatas, request.documents,