Improve embedded file caching
- avoid race conditions when running multiple refresh-cache calls (see the lock sketch after this list)
- check document metadata updates with higher parallelism
- fix the metadata equality check to avoid null comparisons when metadata is not present
- handle 405 Method Not Allowed when making a head() request for metadata
- fix the incorrect None == 'None' etag type comparison in YouTube metadata
- use more efficient db queries when looking up existing metadata
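
A minimal sketch of the locking pattern behind the first point, assuming a local Redis instance and the redis-py client rather than the repository's own redis_lock helper (which appears in the diff below); the rebuild callback is hypothetical:

import hashlib

import redis

redis_client = redis.Redis()  # assumption: a local Redis instance is available


def refresh_embeddings(lookup: dict, rebuild) -> None:
    # one lock per lookup key, so concurrent refresh-cache calls for the same
    # document wait for each other instead of creating duplicate embeddings
    lock_id = hashlib.sha256(str(lookup).encode()).hexdigest()
    with redis_client.lock(f"refresh-cache/{lock_id}", timeout=600):
        rebuild(lookup)  # hypothetical callback that recreates the embeddings
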
devxpy committed Jan 9, 2025
1 parent 3d022fd commit ea85932
Showing 3 changed files with 156 additions and 73 deletions.
213 changes: 141 additions & 72 deletions daras_ai_v2/vector_search.py
@@ -14,21 +14,19 @@
import gooey_gui as gui
import numpy as np
import requests
from app_users.models import AppUser
from daras_ai.image_input import (
get_mimetype_from_response,
safe_filename,
upload_file_from_bytes,
)
from django.db import transaction
from django.db.models import F
from django.db.models import F, Q
from django.utils import timezone
from embeddings.models import EmbeddedFile, EmbeddingsReference
from files.models import FileMetadata
from furl import furl
from loguru import logger
from pydantic import BaseModel, Field

from app_users.models import AppUser
from daras_ai.image_input import (
get_mimetype_from_response,
safe_filename,
upload_file_from_bytes,
)
from daras_ai_v2 import settings
from daras_ai_v2.asr import (
AsrModels,
@@ -71,6 +69,8 @@
remove_quotes,
)
from daras_ai_v2.text_splitter import Document, text_splitter
from embeddings.models import EmbeddedFile, EmbeddingsReference
from files.models import FileMetadata


class DocSearchRequest(BaseModel):
@@ -137,9 +137,7 @@ def get_top_k_references(
"""
from recipes.BulkRunner import url_to_runs

yield "Fetching latest knowledge docs..."
input_docs = request.documents or []
check_document_updates = request.check_document_updates

if request.doc_extract_url:
page_cls, sr, pr = url_to_runs(request.doc_extract_url)
@@ -154,22 +152,34 @@
EmbeddedFile._meta.get_field("embedding_model").default
),
)
embedded_files: list[EmbeddedFile] = yield from apply_parallel(
lambda f_url: get_or_create_embedded_file(
f_url=f_url,
max_context_words=request.max_context_words,
scroll_jump=request.scroll_jump,
google_translate_target=google_translate_target,
selected_asr_model=selected_asr_model,
embedding_model=embedding_model,
is_user_url=is_user_url,
check_document_updates=check_document_updates,
current_user=current_user,
),
input_docs,
max_workers=4,
message="Fetching latest knowledge docs & Embeddings...",

embedded_files, args_to_create = yield from do_check_document_updates(
input_docs=input_docs,
max_context_words=request.max_context_words,
scroll_jump=request.scroll_jump,
google_translate_target=google_translate_target,
selected_asr_model=selected_asr_model,
embedding_model=embedding_model,
check_document_updates=request.check_document_updates,
)

if args_to_create:
embedded_files += yield from apply_parallel(
lambda args: create_embedded_file(
*args,
max_context_words=request.max_context_words,
scroll_jump=request.scroll_jump,
google_translate_target=google_translate_target,
selected_asr_model=selected_asr_model,
embedding_model=embedding_model,
is_user_url=is_user_url,
current_user=current_user,
),
args_to_create,
max_workers=4,
message="Creating knowledge embeddings...",
)

if not embedded_files:
yield "No embeddings found - skipping search"
return []
@@ -239,7 +249,7 @@ def query_vespa(
return {"root": {"children": []}}
file_ids_str = ", ".join(map(repr, file_ids))
query = f"select * from {settings.VESPA_SCHEMA} where file_id in (@fileIds) and (userQuery() or ({{targetHits: {limit}}}nearestNeighbor(embedding, q))) limit {limit}"
logger.debug(f"Vespa query: {'-'*80}\n{query}\n{'-'*80}")
logger.debug(f"Vespa query: {query!r}")
if semantic_weight == 1.0:
ranking = "semantic"
elif semantic_weight == 0.0:
@@ -286,20 +296,22 @@ def doc_or_yt_url_to_file_metas(


def yt_info_to_playlist_metadata(data: dict) -> FileMetadata:
etag = data.get("modified_date") or data.get("playlist_count")
return FileMetadata(
name=data.get("title", "YouTube Playlist"),
# youtube doesn't provide etag, so we use modified_date / playlist_count
etag=data.get("modified_date") or data.get("playlist_count"),
etag=etag and str(etag) or None,
# will be converted later & saved as wav
mime_type="audio/wav",
)


def yt_info_to_video_metadata(data: dict) -> FileMetadata:
etag = data.get("filesize_approx") or data.get("upload_date")
return FileMetadata(
name=data.get("title", "YouTube Video"),
# youtube doesn't provide etag, so we use filesize_approx or upload_date
etag=data.get("filesize_approx") or data.get("upload_date"),
etag=etag and str(etag) or None,
# we will later convert & save as wav
mime_type="audio/wav",
total_bytes=data.get("filesize_approx", 0),
@@ -328,15 +340,17 @@ def doc_url_to_file_metadata(f_url: str) -> FileMetadata:
total_bytes = int(meta.get("size") or 0)
export_links = meta.get("exportLinks", None)
else:
if is_user_uploaded_url(f_url):
kwargs = {}
else:
kwargs = requests_scraping_kwargs() | dict(
timeout=settings.EXTERNAL_REQUEST_TIMEOUT_SEC
)
try:
if is_user_uploaded_url(f_url):
r = requests.head(f_url)
else:
r = requests.head(
f_url,
timeout=settings.EXTERNAL_REQUEST_TIMEOUT_SEC,
**requests_scraping_kwargs(),
)
r = requests.head(f_url, **kwargs)
if r.status_code == 405:
r = requests.get(f_url, **kwargs, stream=True)
r.close()
raise_for_status(r)
except requests.RequestException as e:
logger.warning(f"ignore error while downloading {f_url}: {e}")
@@ -403,52 +417,100 @@ def yt_dlp_extract_info(url: str, **params) -> dict:
return data


def get_or_create_embedded_file(
def do_check_document_updates(
*,
f_url: str,
input_docs: list[str],
max_context_words: int,
scroll_jump: int,
google_translate_target: str | None,
selected_asr_model: str | None,
embedding_model: EmbeddingModels,
is_user_url: bool,
check_document_updates: bool,
) -> typing.Generator[
str,
None,
tuple[
list[EmbeddedFile],
list[tuple[dict, FileMetadata, list[tuple[str, FileMetadata]]]],
],
]:
lookups = {}
q = Q()
for f_url in input_docs:
lookup = dict(
url=f_url,
max_context_words=max_context_words,
scroll_jump=scroll_jump,
google_translate_target=google_translate_target or "",
selected_asr_model=selected_asr_model or "",
embedding_model=embedding_model.name,
)
q |= Q(**lookup)
lookups[f_url] = lookup

cached_files = {
f.url: f
for f in (
EmbeddedFile.objects.filter(q)
.select_related("metadata")
.order_by("url", "-updated_at")
.distinct("url")
)
}

for f_url in cached_files:
if is_user_uploaded_url(f_url) or not check_document_updates:
lookups.pop(f_url, None)

metadatas = yield from apply_parallel(
doc_or_yt_url_to_file_metas,
lookups.keys(),
message="Fetching latest knowlege docs...",
max_workers=100,
)

args_to_create = []
for (f_url, lookup), (file_meta, leaf_url_metas) in zip(lookups.items(), metadatas):
f = cached_files.get(f_url)
if f and f.metadata == file_meta:
continue
else:
args_to_create.append((lookup, file_meta, leaf_url_metas))

return list(cached_files.values()), args_to_create


def create_embedded_file(
lookup: dict,
file_meta: FileMetadata,
leaf_url_metas: list[tuple[str, FileMetadata]],
*,
max_context_words: int,
scroll_jump: int,
google_translate_target: str | None,
selected_asr_model: str | None,
embedding_model: EmbeddingModels,
is_user_url: bool,
current_user: AppUser,
) -> EmbeddedFile:
"""
Return Vespa document ids and document tags
for a given document url + metadata.
"""
lookup = dict(
url=f_url,
max_context_words=max_context_words,
scroll_jump=scroll_jump,
google_translate_target=google_translate_target or "",
selected_asr_model=selected_asr_model or "",
embedding_model=embedding_model.name,
)
lock_id = hashlib.sha256(str(lookup).encode()).hexdigest()
lock_id = _sha256(lookup)
with redis_lock(f"gooey/get_or_create_embeddings/v1/{lock_id}"):
try:
embedded_file = EmbeddedFile.objects.filter(**lookup).order_by(
"-updated_at"
)[0]
except IndexError:
embedded_file = None
else:
# skip metadata check for bucket urls (since they are unique & static)
if is_user_uploaded_url(f_url) or not check_document_updates:
return embedded_file

file_meta, leaf_url_metas = doc_or_yt_url_to_file_metas(f_url)
if embedded_file and embedded_file.metadata.astuple() == file_meta.astuple():
# metadata hasn't changed, return existing file
return embedded_file

file_id_fields = lookup | dict(metadata=file_meta.astuple())
file_id = hashlib.sha256(str(file_id_fields).encode()).hexdigest()
# check if embeddings already exist and are up-to-date
f = (
EmbeddedFile.objects.filter(**lookup)
.select_related("metadata")
.order_by("-updated_at")
.first()
)
if f and f.metadata == file_meta:
return f

# create fresh embeddings
file_id = _sha256(lookup | dict(metadata=file_meta.astuple()))
for leaf_url, leaf_meta in leaf_url_metas:
refs = create_embeddings_in_search_db(
f_url=leaf_url,
@@ -462,12 +524,15 @@ def get_or_create_embedded_file(
is_user_url=is_user_url,
)
with transaction.atomic():
EmbeddedFile.objects.filter(**lookup).delete()
EmbeddedFile.objects.filter(Q(**lookup) | Q(vespa_file_id=file_id)).delete()
file_meta.save()
embedded_file = EmbeddedFile.objects.get_or_create(
embedded_file = EmbeddedFile.objects.create(
vespa_file_id=file_id,
defaults=lookup | dict(metadata=file_meta, created_by=current_user),
)[0]
metadata=file_meta,
created_by=current_user,
**lookup,
)
logger.debug(f"created: {embedded_file}")
for ref in refs:
ref.embedded_file = embedded_file
EmbeddingsReference.objects.bulk_create(
@@ -503,7 +568,7 @@ def create_embeddings_in_search_db(
embedding_model=embedding_model,
is_user_url=is_user_url,
):
doc_id = file_id + "/" + hashlib.sha256(str(ref).encode()).hexdigest()
doc_id = file_id + "/" + _sha256(ref)
db_ref = EmbeddingsReference(
vespa_doc_id=doc_id,
url=ref["url"],
@@ -526,6 +591,10 @@
return list(refs.values())


def _sha256(x) -> str:
return hashlib.sha256(str(x).encode()).hexdigest()


def get_embeds_for_doc(
*,
f_url: str,
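
The 405 fallback added in doc_url_to_file_metadata can be read in isolation roughly as below; this is a sketch using plain requests, without the repository's scraping kwargs or its raise_for_status helper:

import requests


def fetch_headers(url: str, timeout: float = 10) -> requests.Response:
    # some servers reject HEAD with 405 Method Not Allowed; fall back to a
    # streamed GET and close it immediately so only the headers are fetched
    r = requests.head(url, timeout=timeout)
    if r.status_code == 405:
        r = requests.get(url, timeout=timeout, stream=True)
        r.close()
    r.raise_for_status()
    return r
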
11 changes: 11 additions & 0 deletions files/models.py
@@ -1,5 +1,6 @@
from django.db import models
from django.template.defaultfilters import filesizeformat
from loguru import logger


class FileMetadata(models.Model):
@@ -18,6 +19,16 @@ def __str__(self):
ret += f" - {self.etag}"
return ret

def __eq__(self, other):
ret = bool(
isinstance(other, FileMetadata)
# avoid null comparisons -- when metadata is not available and hence not comparable
and (self.etag or other.etag or self.total_bytes or other.total_bytes)
and self.astuple() == other.astuple()
)
logger.debug(f"checking: `{self}` == `{other}` ({ret})")
return ret

def astuple(self) -> tuple:
return self.name, self.etag, self.mime_type, self.total_bytes

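To see why the null guard in FileMetadata.__eq__ matters, consider hypothetical metadata values: with neither an etag nor a byte count there is nothing meaningful to compare, so the files are treated as changed rather than assumed identical.

# hypothetical values illustrating the __eq__ added above
a = FileMetadata(name="doc.pdf", etag=None, mime_type="application/pdf", total_bytes=0)
b = FileMetadata(name="doc.pdf", etag=None, mime_type="application/pdf", total_bytes=0)
assert a != b  # nothing comparable, so treat as changed and re-embed

c = FileMetadata(name="doc.pdf", etag="abc123", mime_type="application/pdf", total_bytes=1024)
d = FileMetadata(name="doc.pdf", etag="abc123", mime_type="application/pdf", total_bytes=1024)
assert c == d  # etags present and all fields match
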
5 changes: 4 additions & 1 deletion functions/recipe_functions.py
@@ -148,8 +148,11 @@ def call_recipe_functions(
state: dict,
trigger: FunctionTrigger,
) -> typing.Iterable[str]:
tools = list(get_tools_from_state(state, trigger))
if not tools:
return
yield f"Running {trigger.name} hooks..."
for tool in get_tools_from_state(state, trigger):
for tool in tools:
tool.bind(
saved_run=saved_run,
workspace=workspace,
