diff --git a/bots/admin.py b/bots/admin.py
index 94eb48997..49321ba35 100644
--- a/bots/admin.py
+++ b/bots/admin.py
@@ -219,14 +219,16 @@ class PublishedRunAdmin(admin.ModelAdmin):
         "view_user",
         "open_in_gooey",
         "linked_saved_run",
+        "view_runs",
         "created_at",
         "updated_at",
     ]
     list_filter = ["workflow", "visibility", "created_by__is_paying"]
-    search_fields = ["workflow", "published_run_id"]
+    search_fields = ["workflow", "published_run_id", "title", "notes"]
     autocomplete_fields = ["saved_run", "created_by", "last_edited_by"]
     readonly_fields = [
         "open_in_gooey",
+        "view_runs",
         "created_at",
         "updated_at",
     ]
@@ -243,19 +245,28 @@ def linked_saved_run(self, published_run: PublishedRun):
 
     linked_saved_run.short_description = "Linked Run"
 
+    @admin.display(description="View Runs")
+    def view_runs(self, published_run: PublishedRun):
+        return list_related_html_url(
+            SavedRun.objects.filter(parent_version__published_run=published_run),
+            query_param="parent_version__published_run__id__exact",
+            instance_id=published_run.id,
+            show_add=False,
+        )
+
 
 @admin.register(SavedRun)
 class SavedRunAdmin(admin.ModelAdmin):
     list_display = [
         "__str__",
-        "example_id",
         "run_id",
         "view_user",
-        "created_at",
+        "open_in_gooey",
+        "view_parent_published_run",
         "run_time",
-        "updated_at",
         "price",
-        "preview_input",
+        "created_at",
+        "updated_at",
     ]
     list_filter = ["workflow"]
     search_fields = ["workflow", "example_id", "run_id", "uid"]
@@ -278,6 +289,11 @@ class SavedRunAdmin(admin.ModelAdmin):
         django.db.models.JSONField: {"widget": JSONEditorWidget},
     }
 
+    def lookup_allowed(self, key, value):
+        if key in ["parent_version__published_run__id__exact"]:
+            return True
+        return super().lookup_allowed(key, value)
+
     def view_user(self, saved_run: SavedRun):
         return change_obj_url(
             AppUser.objects.get(uid=saved_run.uid),
@@ -291,9 +307,10 @@ def view_bots(self, saved_run: SavedRun):
 
     view_bots.short_description = "View Bots"
 
-    @admin.display(description="Input")
-    def preview_input(self, saved_run: SavedRun):
-        return truncate_text_words(BasePage.preview_input(saved_run.state) or "", 100)
+    @admin.display(description="View Published Run")
+    def view_parent_published_run(self, saved_run: SavedRun):
+        pr = saved_run.parent_published_run()
+        return pr and change_obj_url(pr)
 
 
 @admin.register(PublishedRunVersion)
diff --git a/bots/models.py b/bots/models.py
index 61c56abc3..0714f187b 100644
--- a/bots/models.py
+++ b/bots/models.py
@@ -280,7 +280,15 @@ class Meta:
         ]
 
     def __str__(self):
-        return self.get_app_url()
+        from daras_ai_v2.breadcrumbs import get_title_breadcrumbs
+
+        title = get_title_breadcrumbs(
+            Workflow(self.workflow).page_cls, self, self.parent_published_run()
+        ).h1_title
+        return title or self.get_app_url()
+
+    def parent_published_run(self) -> "PublishedRun":
+        return self.parent_version and self.parent_version.published_run
 
     def get_app_url(self):
         workflow = Workflow(self.workflow)
diff --git a/daras_ai_v2/all_pages.py b/daras_ai_v2/all_pages.py
index f4a317416..096150662 100644
--- a/daras_ai_v2/all_pages.py
+++ b/daras_ai_v2/all_pages.py
@@ -108,7 +108,7 @@ def normalize_slug(page_slug):
     normalize_slug(slug): page
     for page in (all_api_pages + all_hidden_pages)
     for slug in page.slug_versions
-}
+} | {str(page.workflow.value): page for page in (all_api_pages + all_hidden_pages)}
 
 workflow_map: dict[Workflow, typing.Type[BasePage]] = {
     page.workflow: page for page in all_api_pages
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 519b70a28..7187178b7 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -923,7 +923,7 @@ def get_runs_from_query_params(
     ) -> tuple[SavedRun, PublishedRun | None]:
         if run_id and uid:
             sr = cls.run_doc_sr(run_id, uid)
-            pr = (sr and sr.parent_version and sr.parent_version.published_run) or None
+            pr = sr.parent_published_run()
         else:
             pr = cls.get_published_run(published_run_id=example_id or "")
             sr = pr.saved_run
@@ -940,9 +940,7 @@ def get_pr_from_query_params(
     ) -> PublishedRun | None:
         if run_id and uid:
             sr = cls.get_sr_from_query_params(example_id, run_id, uid)
-            return (
-                sr and sr.parent_version and sr.parent_version.published_run
-            ) or None
+            return sr.parent_published_run()
         elif example_id:
             return cls.get_published_run(published_run_id=example_id)
         else:
diff --git a/daras_ai_v2/functional.py b/daras_ai_v2/functional.py
index 625e9c026..f8415ec87 100644
--- a/daras_ai_v2/functional.py
+++ b/daras_ai_v2/functional.py
@@ -8,8 +8,8 @@
 
 
 def flatapply_parallel(
-    fn: typing.Callable[[T], list[R]],
-    *iterables: typing.Sequence[T],
+    fn: typing.Callable[..., list[R]],
+    *iterables,
     max_workers: int = None,
     message: str = "",
 ) -> typing.Generator[str, None, list[R]]:
@@ -20,8 +20,8 @@ def flatapply_parallel(
 
 
 def apply_parallel(
-    fn: typing.Callable[[T], R],
-    *iterables: typing.Sequence[T],
+    fn: typing.Callable[..., R],
+    *iterables,
     max_workers: int = None,
     message: str = "",
 ) -> typing.Generator[str, None, list[R]]:
@@ -42,8 +42,8 @@ def apply_parallel(
 
 
 def fetch_parallel(
-    fn: typing.Callable[[T], R],
-    *iterables: typing.Sequence[T],
+    fn: typing.Callable[..., R],
+    *iterables,
     max_workers: int = None,
 ) -> typing.Generator[R, None, None]:
     assert iterables, "fetch_parallel() requires at least one iterable"
@@ -57,16 +57,16 @@ def fetch_parallel(
 
 
 def flatmap_parallel(
-    fn: typing.Callable[[T], list[R]],
-    *iterables: typing.Sequence[T],
+    fn: typing.Callable[..., list[R]],
+    *iterables,
     max_workers: int = None,
 ) -> list[R]:
     return flatten(map_parallel(fn, *iterables, max_workers=max_workers))
 
 
 def map_parallel(
-    fn: typing.Callable[[T], R],
-    *iterables: typing.Sequence[T],
+    fn: typing.Callable[..., R],
+    *iterables,
     max_workers: int = None,
 ) -> list[R]:
     assert iterables, "map_parallel() requires at least one iterable"
diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py
index 070aa1c22..f28e3ac3f 100644
--- a/daras_ai_v2/language_model.py
+++ b/daras_ai_v2/language_model.py
@@ -18,7 +18,11 @@
 from django.conf import settings
 from jinja2.lexer import whitespace_re
 from loguru import logger
-from openai.types.chat import ChatCompletionContentPartParam
+from openai import Stream
+from openai.types.chat import (
+    ChatCompletionContentPartParam,
+    ChatCompletionChunk,
+)
 
 from daras_ai_v2.asr import get_google_auth_session
 from daras_ai_v2.exceptions import raise_for_status
@@ -27,7 +31,10 @@
 from daras_ai_v2.redis_cache import (
     get_redis_cache,
 )
-from daras_ai_v2.text_splitter import default_length_function
+from daras_ai_v2.text_splitter import (
+    default_length_function,
+    default_separators,
+)
 
 DEFAULT_SYSTEM_MSG = "You are an intelligent AI assistant. Follow the instructions as closely as possible."
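The `functional.py` change above loosens the parallel helpers from single-iterable to variadic: like the builtin `map()`, `fn` now receives one positional argument per iterable. A minimal standalone sketch of the assumed semantics (not the actual module, which also threads progress messages through generators):

```python
import typing
from concurrent.futures import ThreadPoolExecutor

R = typing.TypeVar("R")


def map_parallel(fn: typing.Callable[..., R], *iterables, max_workers: int = None) -> list[R]:
    # like map(fn, a, b, ...): fn is called with one item from each iterable
    assert iterables, "map_parallel() requires at least one iterable"
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(fn, *iterables))


# this is what lets vector_search.py zip urls with their metadata:
# map_parallel(lambda f_url, doc_meta: ..., input_docs, doc_metas)
```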
@@ -38,6 +45,9 @@
 CHATML_ROLE_ASSISTANT = "assistant"
 CHATML_ROLE_USER = "user"
 
+# nice for showing streaming progress
+SUPERSCRIPT = str.maketrans("0123456789", "⁰¹²³⁴⁵⁶⁷⁸⁹")
+
 
 class LLMApis(Enum):
     vertex_ai = "Vertex AI"
@@ -327,8 +337,13 @@ def run_language_model(
     stop: list[str] = None,
     avoid_repetition: bool = False,
     tools: list[LLMTools] = None,
+    stream: bool = False,
     response_format_type: typing.Literal["text", "json_object"] = None,
-) -> list[str] | tuple[list[str], list[list[dict]]] | list[dict]:
+) -> (
+    list[str]
+    | tuple[list[str], list[list[dict]]]
+    | typing.Generator[list[str], None, None]
+):
     assert bool(prompt) != bool(
         messages
     ), "Pleave provide exactly one of { prompt, messages }"
@@ -336,10 +351,9 @@ def run_language_model(
     model: LargeLanguageModels = LargeLanguageModels[str(model)]
     api = llm_api[model]
     model_name = llm_model_names[model]
+    is_chatml = False
     if model.is_chat_model():
-        if messages:
-            is_chatml = False
-        else:
+        if not messages:
             # if input is chatml, convert it into json messages
             is_chatml, messages = parse_chatml(prompt)  # type: ignore
         messages = messages or []
@@ -349,7 +363,7 @@ def run_language_model(
             format_chat_entry(role=entry["role"], content=get_entry_text(entry))
             for entry in messages
         ]
-        result = _run_chat_model(
+        entries = _run_chat_model(
            api=api,
             model=model_name,
             messages=messages,  # type: ignore
@@ -360,28 +374,18 @@ def run_language_model(
             avoid_repetition=avoid_repetition,
             tools=tools,
             response_format_type=response_format_type,
+            # we can't stream with tools or json yet
+            stream=stream and not (tools or response_format_type),
         )
-        if response_format_type == "json_object":
-            out_content = [json.loads(entry["content"]) for entry in result]
-        else:
-            out_content = [
-                # return messages back as either chatml or json messages
-                (
-                    format_chatml_message(entry)
-                    if is_chatml
-                    else (entry.get("content") or "").strip()
-                )
-                for entry in result
-            ]
-        if tools:
-            return out_content, [(entry.get("tool_calls") or []) for entry in result]
+        if stream:
+            return _stream_llm_outputs(entries, is_chatml, response_format_type, tools)
         else:
-            return out_content
+            return _parse_entries(entries, is_chatml, response_format_type, tools)
     else:
         if tools:
             raise ValueError("Only OpenAI chat models support Tools")
         logger.info(f"{model_name=}, {len(prompt)=}, {max_tokens=}, {temperature=}")
-        result = _run_text_model(
+        msgs = _run_text_model(
             api=api,
             model=model_name,
             prompt=prompt,
@@ -392,7 +396,41 @@ def run_language_model(
             avoid_repetition=avoid_repetition,
             quality=quality,
         )
-        return [msg.strip() for msg in result]
+        ret = [msg.strip() for msg in msgs]
+        if stream:
+            ret = [ret]
+        return ret
+
+
+def _stream_llm_outputs(result, is_chatml, response_format_type, tools):
+    if isinstance(result, list):  # compatibility with non-streaming apis
+        result = [result]
+    for entries in result:
+        yield _parse_entries(entries, is_chatml, response_format_type, tools)
+
+
+def _parse_entries(
+    entries: list[dict],
+    is_chatml: bool,
+    response_format_type: typing.Literal["text", "json_object"] | None,
+    tools: list[dict] | None,
+):
+    if response_format_type == "json_object":
+        ret = [json.loads(entry["content"]) for entry in entries]
+    else:
+        ret = [
+            # return messages back as either chatml or json messages
+            (
+                format_chatml_message(entry)
+                if is_chatml
+                else (entry.get("content") or "").strip()
+            )
+            for entry in entries
+        ]
+    if tools:
+        return ret, [(entry.get("tool_calls") or []) for entry in entries]
+    else:
+        return ret
 
 
 def _run_text_model(
@@ -443,7 +481,8 @@ def _run_chat_model(
     avoid_repetition: bool,
     tools: list[LLMTools] | None,
     response_format_type: typing.Literal["text", "json_object"] | None,
-) -> list[ConversationEntry]:
+    stream: bool = False,
+) -> list[ConversationEntry] | typing.Generator[list[ConversationEntry], None, None]:
     match api:
         case LLMApis.openai:
             return _run_openai_chat(
@@ -456,6 +495,7 @@ def _run_chat_model(
                 temperature=temperature,
                 tools=tools,
                 response_format_type=response_format_type,
+                stream=stream,
             )
         case LLMApis.vertex_ai:
             if tools:
@@ -494,7 +534,8 @@ def _run_openai_chat(
     avoid_repetition: bool,
     tools: list[LLMTools] | None,
     response_format_type: typing.Literal["text", "json_object"] | None,
-) -> list[ConversationEntry]:
+    stream: bool = False,
+) -> list[ConversationEntry] | typing.Generator[list[ConversationEntry], None, None]:
     from openai._types import NOT_GIVEN
 
     if avoid_repetition:
@@ -523,11 +564,72 @@ def _run_openai_chat(
                 if response_format_type
                 else NOT_GIVEN
             ),
+            stream=stream,
             )
             for model_str in model
         ],
     )
-    return [choice.message.dict() for choice in r.choices]
+    if stream:
+        return _stream_openai_chunked(r)
+    else:
+        return [choice.message.dict() for choice in r.choices]
+
+
+def _stream_openai_chunked(
+    r: Stream[ChatCompletionChunk],
+    start_chunk_size: int = 50,
+    stop_chunk_size: int = 400,
+    step_chunk_size: int = 150,
+):
+    ret = []
+    chunk_size = start_chunk_size
+
+    for completion_chunk in r:
+        changed = False
+        for choice in completion_chunk.choices:
+            try:
+                entry = ret[choice.index]
+            except IndexError:
+                # initialize the entry
+                entry = choice.delta.dict() | {"content": "", "chunk": ""}
+                ret.append(entry)
+
+            # append the delta to the current chunk
+            if not choice.delta.content:
+                continue
+            entry["chunk"] += choice.delta.content
+            # if the chunk is too small, we need to wait for more data
+            chunk = entry["chunk"]
+            if len(chunk) < chunk_size:
+                continue
+
+            # iterate through the separators and find the best one that matches
+            for sep in default_separators[:-1]:
+                # find the last occurrence of the separator
+                match = None
+                for match in re.finditer(sep, chunk):
+                    pass
+                if not match:
+                    continue  # no match, try the next separator or wait for more data
+                # append text before the separator to the content
+                part = chunk[: match.end()]
+                if len(part) < chunk_size:
+                    continue  # not enough text, try the next separator or wait for more data
+                entry["content"] += part
+                # set text after the separator as the next chunk
+                entry["chunk"] = chunk[match.end() :]
+                # increase the chunk size, but don't go over the max
+                chunk_size = min(chunk_size + step_chunk_size, stop_chunk_size)
+                # we found a separator, so we can stop looking and yield the partial result
+                changed = True
+                break
        if changed:
            yield ret

    # add the leftover chunks
    for entry in ret:
        entry["content"] += entry["chunk"]
    yield ret
 
 
 @retry_if(openai_should_retry)
diff --git a/daras_ai_v2/search_ref.py b/daras_ai_v2/search_ref.py
index 1baf64b47..7cc810b1a 100644
--- a/daras_ai_v2/search_ref.py
+++ b/daras_ai_v2/search_ref.py
@@ -57,150 +57,178 @@ def render_text_with_refs(text: str, references: list[SearchReference]):
     return html
 
 
-def apply_response_template(
+def apply_response_formattings_prefix(
     output_text: list[str],
     references: list[SearchReference],
     citation_style: CitationStyles | None = CitationStyles.number,
-):
+) -> list[dict[int, SearchReference]]:
+    all_refs_list = [{}] * len(output_text)
     for i, text in enumerate(output_text):
-        formatted = ""
-        all_refs = {}
-
-        for snippet, ref_map in parse_refs(text, references):
-            match citation_style:
-                case CitationStyles.number | CitationStyles.number_plaintext:
-                    cites = " ".join(f"[{ref_num}]" for ref_num in ref_map.keys())
-                case CitationStyles.title:
-                    cites = " ".join(f"[{ref['title']}]" for ref in ref_map.values())
-                case CitationStyles.url:
-                    cites = " ".join(f"[{ref['url']}]" for ref in ref_map.values())
-                case CitationStyles.symbol | CitationStyles.symbol_plaintext:
-                    cites = " ".join(
-                        f"[{generate_footnote_symbol(ref_num - 1)}]"
-                        for ref_num in ref_map.keys()
-                    )
+        all_refs_list[i], output_text[i] = format_citations(
+            text, references, citation_style
+        )
+    return all_refs_list
 
-                case CitationStyles.markdown:
-                    cites = " ".join(ref_to_markdown(ref) for ref in ref_map.values())
-                case CitationStyles.html:
-                    cites = " ".join(ref_to_html(ref) for ref in ref_map.values())
-                case CitationStyles.slack_mrkdwn:
-                    cites = " ".join(
-                        ref_to_slack_mrkdwn(ref) for ref in ref_map.values()
-                    )
-                case CitationStyles.plaintext:
-                    cites = " ".join(
-                        f'[{ref["title"]} {ref["url"]}]'
-                        for ref_num, ref in ref_map.items()
-                    )
-                case CitationStyles.number_markdown:
-                    cites = " ".join(
-                        markdown_link(f"[{ref_num}]", ref["url"])
-                        for ref_num, ref in ref_map.items()
-                    )
-                case CitationStyles.number_html:
-                    cites = " ".join(
-                        html_link(f"[{ref_num}]", ref["url"])
-                        for ref_num, ref in ref_map.items()
-                    )
-                case CitationStyles.number_slack_mrkdwn:
-                    cites = " ".join(
-                        slack_mrkdwn_link(f"[{ref_num}]", ref["url"])
-                        for ref_num, ref in ref_map.items()
-                    )
 
+def apply_response_formattings_suffix(
+    all_refs_list: list[dict[int, SearchReference]],
+    output_text: list[str],
+    citation_style: CitationStyles | None = CitationStyles.number,
+):
+    for i, text in enumerate(output_text):
+        output_text[i] = format_jinja_response_template(
+            all_refs_list[i],
+            format_footnotes(all_refs_list[i], text, citation_style),
+        )
 
-                case CitationStyles.symbol_markdown:
-                    cites = " ".join(
-                        markdown_link(
-                            f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
-                        )
-                        for ref_num, ref in ref_map.items()
-                    )
-                case CitationStyles.symbol_html:
-                    cites = " ".join(
-                        html_link(
-                            f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
-                        )
-                        for ref_num, ref in ref_map.items()
-                    )
-                case CitationStyles.symbol_slack_mrkdwn:
-                    cites = " ".join(
-                        slack_mrkdwn_link(
-                            f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
-                        )
-                        for ref_num, ref in ref_map.items()
-                    )
-                case None:
-                    cites = ""
-                case _:
-                    raise ValueError(f"Unknown citation style: {citation_style}")
-            formatted += snippet + " " + cites + " "
-            all_refs.update(ref_map)
 
+def format_citations(
+    text: str,
+    references: list[SearchReference],
+    citation_style: CitationStyles | None = CitationStyles.number,
+) -> tuple[dict[int, SearchReference], str]:
+    all_refs = {}
+    formatted = ""
+    for snippet, ref_map in parse_refs(text, references):
         match citation_style:
+            case CitationStyles.number | CitationStyles.number_plaintext:
+                cites = " ".join(f"[{ref_num}]" for ref_num in ref_map.keys())
+            case CitationStyles.title:
+                cites = " ".join(f"[{ref['title']}]" for ref in ref_map.values())
+            case CitationStyles.url:
+                cites = " ".join(f"[{ref['url']}]" for ref in ref_map.values())
+            case CitationStyles.symbol | CitationStyles.symbol_plaintext:
+                cites = " ".join(
+                    f"[{generate_footnote_symbol(ref_num - 1)}]"
+                    for ref_num in ref_map.keys()
+                )
+
+            case CitationStyles.markdown:
+                cites = " ".join(ref_to_markdown(ref) for ref in ref_map.values())
+            case CitationStyles.html:
+                cites = " ".join(ref_to_html(ref) for ref in ref_map.values())
+            case CitationStyles.slack_mrkdwn:
+                cites = " ".join(ref_to_slack_mrkdwn(ref) for ref in ref_map.values())
+            case CitationStyles.plaintext:
+                cites = " ".join(
+                    f'[{ref["title"]} {ref["url"]}]' for ref_num, ref in ref_map.items()
+                )
+
             case CitationStyles.number_markdown:
-                formatted += "\n\n"
-                formatted += "\n".join(
-                    f"[{ref_num}] {ref_to_markdown(ref)}"
-                    for ref_num, ref in sorted(all_refs.items())
+                cites = " ".join(
+                    markdown_link(f"[{ref_num}]", ref["url"])
+                    for ref_num, ref in ref_map.items()
                 )
             case CitationStyles.number_html:
-                formatted += "<br><br>"
-                formatted += "<br>".join(
-                    f"[{ref_num}] {ref_to_html(ref)}"
-                    for ref_num, ref in sorted(all_refs.items())
+                cites = " ".join(
+                    html_link(f"[{ref_num}]", ref["url"])
+                    for ref_num, ref in ref_map.items()
                 )
             case CitationStyles.number_slack_mrkdwn:
-                formatted += "\n\n"
-                formatted += "\n".join(
-                    f"[{ref_num}] {ref_to_slack_mrkdwn(ref)}"
-                    for ref_num, ref in sorted(all_refs.items())
-                )
-            case CitationStyles.number_plaintext:
-                formatted += "\n\n"
-                formatted += "\n".join(
-                    f'{ref_num}. {ref["title"]} {ref["url"]}'
-                    for ref_num, ref in sorted(all_refs.items())
+                cites = " ".join(
+                    slack_mrkdwn_link(f"[{ref_num}]", ref["url"])
+                    for ref_num, ref in ref_map.items()
                 )
 
             case CitationStyles.symbol_markdown:
-                formatted += "\n\n"
-                formatted += "\n".join(
-                    f"{generate_footnote_symbol(ref_num - 1)} {ref_to_markdown(ref)}"
-                    for ref_num, ref in sorted(all_refs.items())
+                cites = " ".join(
+                    markdown_link(
+                        f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
+                    )
+                    for ref_num, ref in ref_map.items()
                 )
             case CitationStyles.symbol_html:
-                formatted += "<br><br>"
-                formatted += "<br>".join(
-                    f"{generate_footnote_symbol(ref_num - 1)} {ref_to_html(ref)}"
-                    for ref_num, ref in sorted(all_refs.items())
+                cites = " ".join(
+                    html_link(f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"])
+                    for ref_num, ref in ref_map.items()
                 )
             case CitationStyles.symbol_slack_mrkdwn:
-                formatted += "\n\n"
-                formatted += "\n".join(
-                    f"{generate_footnote_symbol(ref_num - 1)} {ref_to_slack_mrkdwn(ref)}"
-                    for ref_num, ref in sorted(all_refs.items())
-                )
-            case CitationStyles.symbol_plaintext:
-                formatted += "\n\n"
-                formatted += "\n".join(
-                    f'{generate_footnote_symbol(ref_num - 1)}. {ref["title"]} {ref["url"]}'
-                    for ref_num, ref in sorted(all_refs.items())
-                )
-
-        for ref_num, ref in all_refs.items():
-            try:
-                template = ref["response_template"]
-            except KeyError:
-                pass
-            else:
-                formatted = jinja2.Template(template).render(
-                    **ref,
-                    output_text=formatted,
-                    ref_num=ref_num,
+                cites = " ".join(
+                    slack_mrkdwn_link(
+                        f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
+                    )
+                    for ref_num, ref in ref_map.items()
                 )
-        output_text[i] = formatted
+            case None:
+                cites = ""
+            case _:
+                raise ValueError(f"Unknown citation style: {citation_style}")
+        formatted += " ".join(filter(None, [snippet, cites]))
+        all_refs.update(ref_map)
+    return all_refs, formatted
+
+
+def format_footnotes(
+    all_refs: dict[int, SearchReference], formatted: str, citation_style: CitationStyles
+) -> str:
+    match citation_style:
+        case CitationStyles.number_markdown:
+            formatted += "\n\n"
+            formatted += "\n".join(
+                f"[{ref_num}] {ref_to_markdown(ref)}"
+                for ref_num, ref in sorted(all_refs.items())
+            )
+        case CitationStyles.number_html:
+            formatted += "<br><br>"
+            formatted += "<br>".join(
+                f"[{ref_num}] {ref_to_html(ref)}"
+                for ref_num, ref in sorted(all_refs.items())
+            )
+        case CitationStyles.number_slack_mrkdwn:
+            formatted += "\n\n"
+            formatted += "\n".join(
+                f"[{ref_num}] {ref_to_slack_mrkdwn(ref)}"
+                for ref_num, ref in sorted(all_refs.items())
+            )
+        case CitationStyles.number_plaintext:
+            formatted += "\n\n"
+            formatted += "\n".join(
+                f'{ref_num}. {ref["title"]} {ref["url"]}'
+                for ref_num, ref in sorted(all_refs.items())
+            )
+
+        case CitationStyles.symbol_markdown:
+            formatted += "\n\n"
+            formatted += "\n".join(
+                f"{generate_footnote_symbol(ref_num - 1)} {ref_to_markdown(ref)}"
+                for ref_num, ref in sorted(all_refs.items())
+            )
+        case CitationStyles.symbol_html:
+            formatted += "<br><br>"
+            formatted += "<br>".join(
+                f"{generate_footnote_symbol(ref_num - 1)} {ref_to_html(ref)}"
+                for ref_num, ref in sorted(all_refs.items())
+            )
+        case CitationStyles.symbol_slack_mrkdwn:
+            formatted += "\n\n"
+            formatted += "\n".join(
+                f"{generate_footnote_symbol(ref_num - 1)} {ref_to_slack_mrkdwn(ref)}"
+                for ref_num, ref in sorted(all_refs.items())
+            )
+        case CitationStyles.symbol_plaintext:
+            formatted += "\n\n"
+            formatted += "\n".join(
+                f'{generate_footnote_symbol(ref_num - 1)}. {ref["title"]} {ref["url"]}'
+                for ref_num, ref in sorted(all_refs.items())
+            )
+    return formatted
+
+
+def format_jinja_response_template(
+    all_refs: dict[int, SearchReference], formatted: str
+) -> str:
+    for ref_num, ref in all_refs.items():
+        try:
+            template = ref["response_template"]
+        except KeyError:
+            pass
+        else:
+            formatted = jinja2.Template(template).render(
+                **ref,
+                output_text=formatted,
+                ref_num=ref_num,
+            )
+    return formatted
 
 
 search_ref_pat = re.compile(r"\[" r"[\d\s\.\,\[\]\$\{\}]+" r"\]")
diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py
index 5cf08c721..394d55b55 100644
--- a/daras_ai_v2/vector_search.py
+++ b/daras_ai_v2/vector_search.py
@@ -12,6 +12,7 @@
 import requests
 from furl import furl
 from googleapiclient.errors import HttpError
+from loguru import logger
 from pydantic import BaseModel, Field
 from rank_bm25 import BM25Okapi
 
@@ -32,7 +33,7 @@
 )
 from daras_ai_v2.exceptions import raise_for_status
 from daras_ai_v2.fake_user_agents import FAKE_USER_AGENTS
-from daras_ai_v2.functional import flatmap_parallel
+from daras_ai_v2.functional import flatmap_parallel, map_parallel
 from daras_ai_v2.gdrive_downloader import (
     gdrive_download,
     is_gdrive_url,
@@ -87,20 +88,23 @@ def get_top_k_references(
     Returns:
         the top k documents
     """
-    yield "Getting embeddings..."
+    yield "Checking docs..."
     input_docs = request.documents or []
+    doc_metas = map_parallel(doc_url_to_metadata, input_docs)
+
+    yield "Getting embeddings..."
     embeds: list[tuple[SearchReference, np.ndarray]] = flatmap_parallel(
-        lambda f_url: doc_url_to_embeds(
+        lambda f_url, doc_meta: get_embeds_for_doc(
             f_url=f_url,
+            doc_meta=doc_meta,
             max_context_words=request.max_context_words,
             scroll_jump=request.scroll_jump,
             selected_asr_model=request.selected_asr_model,
             google_translate_target=request.google_translate_target,
         ),
         input_docs,
-        max_workers=10,
+        doc_metas,
     )
-
     dense_query_embeds = openai_embedding_create([request.search_query])[0]
 
     yield "Searching documents..."
@@ -129,12 +133,21 @@ def get_top_k_references(
         dense_ranks = np.zeros(len(embeds))
 
     if sparse_weight:
+        yield "Getting sparse scores..."
         # get sparse scores
-        tokenized_corpus = [
-            bm25_tokenizer(ref["title"]) + bm25_tokenizer(ref["snippet"])
-            for ref, _ in embeds
-        ]
-        bm25 = BM25Okapi(tokenized_corpus, k1=2, b=0.3)
+        bm25_corpus = flatmap_parallel(
+            lambda f_url, doc_meta: get_bm25_embeds_for_doc(
+                f_url=f_url,
+                doc_meta=doc_meta,
+                max_context_words=request.max_context_words,
+                scroll_jump=request.scroll_jump,
+                selected_asr_model=request.selected_asr_model,
+                google_translate_target=request.google_translate_target,
+            ),
+            input_docs,
+            doc_metas,
+        )
+        bm25 = BM25Okapi(bm25_corpus, k1=2, b=0.3)
         if request.keyword_query and isinstance(request.keyword_query, list):
             sparse_query_tokenized = [item.lower() for item in request.keyword_query]
         else:
@@ -150,6 +163,13 @@ def get_top_k_references(
     else:
         sparse_ranks = np.zeros(len(embeds))
 
+    # just in case sparse and dense ranks are different lengths, truncate to the shorter one
+    if len(sparse_ranks) != len(dense_ranks):
+        logger.warning(
+            f"sparse and dense ranks are different lengths, truncating... {len(sparse_ranks)=} {len(dense_ranks)=} {len(embeds)=}"
+        )
+        sparse_ranks = sparse_ranks[: len(dense_ranks)]
+        dense_ranks = dense_ranks[: len(sparse_ranks)]
     # RRF formula: 1 / (k + rank)
     k = 60
     rrf_scores = (
@@ -159,11 +179,7 @@ def get_top_k_references(
 
     # Final ranking
     max_references = min(request.max_references, len(rrf_scores))
     top_k = np.argpartition(rrf_scores, -max_references)[-max_references:]
-    final_ranks = sorted(
-        top_k,
-        key=lambda idx: rrf_scores[idx],
-        reverse=True,
-    )
+    final_ranks = sorted(top_k, key=rrf_scores.__getitem__, reverse=True)
 
     references = [embeds[idx][0] | {"score": rrf_scores[idx]} for idx in final_ranks]
 
@@ -211,38 +227,6 @@ def references_as_prompt(references: list[SearchReference], sep="\n\n") -> str:
     )
 
 
-def doc_url_to_embeds(
-    *,
-    f_url: str,
-    max_context_words: int,
-    scroll_jump: int,
-    selected_asr_model: str = None,
-    google_translate_target: str = None,
-) -> list[tuple[SearchReference, np.ndarray]]:
-    """
-    Get document embeddings for a given document url.
-
-    Args:
-        f_url: document url
-        max_context_words: max number of words to include in each chunk
-        scroll_jump: number of words to scroll by
-        google_translate_target: target language for google translate
-        selected_asr_model: selected ASR model (used for audio files)
-
-    Returns:
-        list of (SearchReference, embeddings vector) tuples
-    """
-    doc_meta = doc_url_to_metadata(f_url)
-    return get_embeds_for_doc(
-        f_url=f_url,
-        doc_meta=doc_meta,
-        max_context_words=max_context_words,
-        scroll_jump=scroll_jump,
-        selected_asr_model=selected_asr_model,
-        google_translate_target=google_translate_target,
-    )
-
-
 class DocMetadata(typing.NamedTuple):
     name: str
     etag: str | None
@@ -321,6 +305,35 @@ def doc_url_to_file_metadata(f_url: str) -> FileMetadata:
     )
 
 
+@redis_cache_decorator
+def get_bm25_embeds_for_doc(
+    *,
+    f_url: str,
+    doc_meta: DocMetadata,
+    max_context_words: int,
+    scroll_jump: int,
+    google_translate_target: str = None,
+    selected_asr_model: str = None,
+):
+    pages = doc_url_to_text_pages(
+        f_url=f_url,
+        doc_meta=doc_meta,
+        selected_asr_model=selected_asr_model,
+        google_translate_target=google_translate_target,
+    )
+    refs = pages_to_split_refs(
+        pages=pages,
+        f_url=f_url,
+        doc_meta=doc_meta,
+        max_context_words=max_context_words,
+        scroll_jump=scroll_jump,
+    )
+    tokenized_corpus = [
+        bm25_tokenizer(ref["title"]) + bm25_tokenizer(ref["snippet"]) for ref in refs
+    ]
+    return tokenized_corpus
+
+
 @redis_cache_decorator
 def get_embeds_for_doc(
     *,
@@ -345,18 +358,44 @@ def get_embeds_for_doc(
     Returns:
         list of (metadata, embeddings) tuples
     """
-    import pandas as pd
-
     pages = doc_url_to_text_pages(
         f_url=f_url,
         doc_meta=doc_meta,
         selected_asr_model=selected_asr_model,
         google_translate_target=google_translate_target,
     )
+    refs = pages_to_split_refs(
+        pages=pages,
+        f_url=f_url,
+        doc_meta=doc_meta,
+        max_context_words=max_context_words,
+        scroll_jump=scroll_jump,
+    )
+    texts = [m["title"] + " | " + m["snippet"] for m in refs]
+    # get doc embeds in batches
+    batch_size = 16  # azure openai limits
+    embeds = flatmap_parallel(
+        openai_embedding_create,
+        [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)],
+        max_workers=2,
+    )
+    return list(zip(refs, embeds))
+
+
+def pages_to_split_refs(
+    *,
+    pages,
+    f_url: str,
+    doc_meta: DocMetadata,
+    max_context_words: int,
+    scroll_jump: int,
+) -> list[SearchReference]:
+    import pandas as pd
+
     chunk_size = int(max_context_words * 2)
     chunk_overlap = int(max_context_words * 2 / scroll_jump)
     if isinstance(pages, pd.DataFrame):
-        metas = []
+        refs = []
         # treat each row as a separate document
         for idx, row in pages.iterrows():
             row = dict(row)
@@ -372,7 +411,7 @@ def get_embeds_for_doc(
                 )
             else:
                 continue
-            metas += [
+            refs += [
                 {
                     "title": doc_meta.name,
                     "url": f_url,
@@ -385,7 +424,7 @@ def get_embeds_for_doc(
             ]
     else:
         # split the text into chunks
-        metas = [
+        refs = [
             {
                 "title": (
                     doc_meta.name + (f", page {doc.end + 1}" if len(pages) > 1 else "")
@@ -403,15 +442,7 @@ def get_embeds_for_doc(
                 pages, chunk_size=chunk_size, chunk_overlap=chunk_overlap
             )
         ]
-    # get doc embeds in batches
-    batch_size = 16  # azure openai limits
-    texts = [m["title"] + " | " + m["snippet"] for m in metas]
-    embeds = flatmap_parallel(
-        openai_embedding_create,
-        [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)],
-        max_workers=5,
-    )
-    return list(zip(metas, embeds))
+    return refs
 
 
 sections_re = re.compile(r"(\s*[\r\n\f\v]|^)(\w+)\=", re.MULTILINE)
diff --git a/gooey_ui/pubsub.py b/gooey_ui/pubsub.py
index ca3dfcb1a..5e210c956 100644
--- a/gooey_ui/pubsub.py
+++ b/gooey_ui/pubsub.py
@@ -42,7 +42,11 @@ def realtime_push(channel: str, value: typing.Any = "ping"):
     msg = json.dumps(jsonable_encoder(value))
     r.set(channel, msg)
     r.publish(channel, json.dumps(time()))
-    logger.info(f"publish {channel=}")
+    if isinstance(value, dict):
+        run_status = value.get("__run_status")
+        logger.info(f"publish {channel=} {run_status=}")
+    else:
+        logger.info(f"publish {channel=}")
 
 
 # def use_state(
diff --git a/recipes/CompareLLM.py b/recipes/CompareLLM.py
index f1be5a937..eac5ef885 100644
--- a/recipes/CompareLLM.py
+++ b/recipes/CompareLLM.py
@@ -1,10 +1,9 @@
 import random
 import typing
-
-import gooey_ui as st
 from pydantic import BaseModel
+
+import gooey_ui as st
 from bots.models import Workflow
 from daras_ai_v2.base import BasePage
 from daras_ai_v2.enum_selector_widget import enum_multiselect
@@ -12,6 +11,7 @@
     run_language_model,
     LargeLanguageModels,
     llm_price,
+    SUPERSCRIPT,
 )
 from daras_ai_v2.language_model_settings_widgets import language_model_settings
 from daras_ai_v2.loom_video_widget import youtube_video
@@ -94,9 +94,9 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
         state["output_text"] = output_text = {}
 
         for selected_model in request.selected_models:
-            yield f"Running {LargeLanguageModels[selected_model].value}..."
-
-            output_text[selected_model] = run_language_model(
+            model = LargeLanguageModels[selected_model]
+            yield f"Running {model.value}..."
+            ret = run_language_model(
                 model=selected_model,
                 quality=request.quality,
                 num_outputs=request.num_outputs,
@@ -104,7 +104,11 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
                 prompt=prompt,
                 max_tokens=request.max_tokens,
                 avoid_repetition=request.avoid_repetition,
+                stream=True,
             )
+            for i, item in enumerate(ret):
+                output_text[selected_model] = item
+                yield f"Streaming{str(i + 1).translate(SUPERSCRIPT)} {model.value}..."
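For reference, the `SUPERSCRIPT` translation table imported above turns the chunk counter into a compact progress marker (model name here is hypothetical):

```python
SUPERSCRIPT = str.maketrans("0123456789", "⁰¹²³⁴⁵⁶⁷⁸⁹")

for i in range(3):
    # str.translate maps each digit to its superscript counterpart
    print(f"Streaming{str(i + 1).translate(SUPERSCRIPT)} GPT 4...")
# Streaming¹ GPT 4...
# Streaming² GPT 4...
# Streaming³ GPT 4...
```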
 
     def render_output(self):
         self._render_outputs(st.session_state, 450)
diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py
index d18b5a54c..0b95464e1 100644
--- a/recipes/DocSearch.py
+++ b/recipes/DocSearch.py
@@ -23,8 +23,9 @@
 from daras_ai_v2.search_ref import (
     SearchReference,
     render_output_with_refs,
-    apply_response_template,
     CitationStyles,
+    apply_response_formattings_prefix,
+    apply_response_formattings_suffix,
 )
 from daras_ai_v2.vector_search import (
     DocSearchRequest,
@@ -194,9 +195,12 @@ def run_v2(
         citation_style = (
             request.citation_style and CitationStyles[request.citation_style]
         ) or None
-        apply_response_template(
+        all_refs_list = apply_response_formattings_prefix(
             response.output_text, response.references, citation_style
         )
+        apply_response_formattings_suffix(
+            all_refs_list, response.output_text, citation_style
+        )
 
     def get_raw_price(self, state: dict) -> float:
         name = state.get("selected_model")
diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py
index faa88d533..20eaa8537 100644
--- a/recipes/VideoBots.py
+++ b/recipes/VideoBots.py
@@ -50,6 +50,7 @@
     get_entry_images,
     get_entry_text,
     format_chat_entry,
+    SUPERSCRIPT,
 )
 from daras_ai_v2.language_model_settings_widgets import language_model_settings
 from daras_ai_v2.lipsync_settings_widgets import lipsync_settings
@@ -58,7 +59,12 @@
 from daras_ai_v2.query_generator import generate_final_search_query
 from daras_ai_v2.query_params import gooey_get_query_params
 from daras_ai_v2.query_params_util import extract_query_params
-from daras_ai_v2.search_ref import apply_response_template, parse_refs, CitationStyles
+from daras_ai_v2.search_ref import (
+    parse_refs,
+    CitationStyles,
+    apply_response_formattings_prefix,
+    apply_response_formattings_suffix,
+)
 from daras_ai_v2.text_output_widget import text_output
 from daras_ai_v2.text_to_speech_settings_widgets import (
     TextToSpeechProviders,
@@ -805,7 +811,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
         yield f"Running {model.value}..."
 
         if is_chat_model:
-            output_text = run_language_model(
+            chunks = run_language_model(
                 model=request.selected_model,
                 messages=[
                     {"role": s["role"], "content": s["content"]}
@@ -816,12 +822,13 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
                 temperature=request.sampling_temperature,
                 avoid_repetition=request.avoid_repetition,
                 tools=request.tools,
+                stream=True,
             )
         else:
             prompt = "\n".join(
                 format_chatml_message(entry) for entry in prompt_messages
             )
-            output_text = run_language_model(
+            chunks = run_language_model(
                 model=request.selected_model,
                 prompt=prompt,
                 max_tokens=max_allowed_tokens,
@@ -830,43 +837,52 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
                 temperature=request.sampling_temperature,
                 avoid_repetition=request.avoid_repetition,
                 stop=[CHATML_END_TOKEN, CHATML_START_TOKEN],
+                stream=True,
             )
-        if request.tools:
-            output_text, tool_call_choices = output_text
-            state["output_documents"] = output_documents = []
-            for tool_calls in tool_call_choices:
-                for call in tool_calls:
-                    result = yield from exec_tool_call(call)
-                    output_documents.append(result)
-
-        # save model response
-        state["raw_output_text"] = [
-            "".join(snippet for snippet, _ in parse_refs(text, references))
-            for text in output_text
-        ]
-
-        # translate response text
-        if request.user_language and request.user_language != "en":
-            yield f"Translating response to {request.user_language}..."
-            output_text = run_google_translate(
-                texts=output_text,
-                source_language="en",
-                target_language=request.user_language,
-                glossary_url=request.output_glossary_document,
-            )
-        state["raw_tts_text"] = [
+        citation_style = (
+            request.citation_style and CitationStyles[request.citation_style]
+        ) or None
+        all_refs_list = []
+        for i, output_text in enumerate(chunks):
+            if request.tools:
+                output_text, tool_call_choices = output_text
+                state["output_documents"] = output_documents = []
+                for tool_calls in tool_call_choices:
+                    for call in tool_calls:
+                        result = yield from exec_tool_call(call)
+                        output_documents.append(result)
+
+            # save model response
+            state["raw_output_text"] = [
                 "".join(snippet for snippet, _ in parse_refs(text, references))
                 for text in output_text
             ]
-        if references:
-            citation_style = (
-                request.citation_style and CitationStyles[request.citation_style]
-            ) or None
-            apply_response_template(output_text, references, citation_style)
-
-        state["output_text"] = output_text
+            # translate response text
+            if request.user_language and request.user_language != "en":
+                yield f"Translating response to {request.user_language}..."
+                output_text = run_google_translate(
+                    texts=output_text,
+                    source_language="en",
+                    target_language=request.user_language,
+                    glossary_url=request.output_glossary_document,
+                )
+            state["raw_tts_text"] = [
+                "".join(snippet for snippet, _ in parse_refs(text, references))
+                for text in output_text
+            ]
+
+            if references:
+                all_refs_list = apply_response_formattings_prefix(
+                    output_text, references, citation_style
+                )
+            state["output_text"] = output_text
+            yield f"Streaming{str(i + 1).translate(SUPERSCRIPT)} {model.value}..."
+        if all_refs_list:
+            apply_response_formattings_suffix(
+                all_refs_list, state["output_text"], citation_style
+            )
 
         state["output_audio"] = []
         state["output_video"] = []
diff --git a/url_shortener/admin.py b/url_shortener/admin.py
index f4d6bb4c0..e7f102a7a 100644
--- a/url_shortener/admin.py
+++ b/url_shortener/admin.py
@@ -137,3 +137,6 @@ class VisitorClickInfoAdmin(admin.ModelAdmin):
     ordering = ["-created_at"]
     autocomplete_fields = ["shortened_url"]
     actions = [export_to_csv, export_to_excel]
+    readonly_fields = [
+        "created_at",
+    ]
diff --git a/url_shortener/routers.py b/url_shortener/routers.py
index fba39ffa9..1b3991cf8 100644
--- a/url_shortener/routers.py
+++ b/url_shortener/routers.py
@@ -21,10 +21,11 @@ def url_shortener(hashid: str, request: Request):
         return Response(status_code=410, content="This link has expired")
     # increment the click count
     ShortenedURL.objects.filter(id=surl.id).update(clicks=F("clicks") + 1)
-    if surl.enable_analytics:
-        save_click_info.delay(
-            surl.id, request.client.host, request.headers.get("user-agent", "")
-        )
+    # disable because iplist.cc is down
+    # if surl.enable_analytics:
+    #     save_click_info.delay(
+    #         surl.id, request.client.host, request.headers.get("user-agent", "")
+    #     )
     if surl.url:
         return RedirectResponse(
             url=surl.url, status_code=303  # because youtu.be redirects are 303
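Overall, the streaming path in this PR rests on the chunker added in language_model.py: buffer incoming deltas, flush only at a separator boundary, and grow the flush threshold so early yields arrive quickly while later ones batch up. A distilled, self-contained sketch of that idea (the real code uses `default_separators` from `text_splitter` and tracks one entry per completion choice):

```python
import re


def stream_chunked(deltas, separators=(r"\n\n", r"\n", r" "),
                   start=50, step=150, stop=400):
    content, chunk, chunk_size = "", "", start
    for delta in deltas:
        chunk += delta
        if len(chunk) < chunk_size:
            continue  # buffer too small, wait for more data
        for sep in separators:
            matches = list(re.finditer(sep, chunk))
            if not matches:
                continue  # no match, try the next separator
            part = chunk[: matches[-1].end()]
            if len(part) < chunk_size:
                continue  # not enough text before the separator
            content, chunk = content + part, chunk[matches[-1].end():]
            # grow the flush threshold, capped at `stop`
            chunk_size = min(chunk_size + step, stop)
            yield content
            break
    yield content + chunk  # flush whatever is left over


# e.g.: for partial in stream_chunked(iter(["Hello ", "world.\n\n", "Bye."])): print(partial)
```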