diff --git a/bots/admin.py b/bots/admin.py
index 94eb48997..49321ba35 100644
--- a/bots/admin.py
+++ b/bots/admin.py
@@ -219,14 +219,16 @@ class PublishedRunAdmin(admin.ModelAdmin):
"view_user",
"open_in_gooey",
"linked_saved_run",
+ "view_runs",
"created_at",
"updated_at",
]
list_filter = ["workflow", "visibility", "created_by__is_paying"]
- search_fields = ["workflow", "published_run_id"]
+ search_fields = ["workflow", "published_run_id", "title", "notes"]
autocomplete_fields = ["saved_run", "created_by", "last_edited_by"]
readonly_fields = [
"open_in_gooey",
+ "view_runs",
"created_at",
"updated_at",
]
@@ -243,19 +245,28 @@ def linked_saved_run(self, published_run: PublishedRun):
linked_saved_run.short_description = "Linked Run"
+ @admin.display(description="View Runs")
+ def view_runs(self, published_run: PublishedRun):
+ return list_related_html_url(
+ SavedRun.objects.filter(parent_version__published_run=published_run),
+ query_param="parent_version__published_run__id__exact",
+ instance_id=published_run.id,
+ show_add=False,
+ )
+
@admin.register(SavedRun)
class SavedRunAdmin(admin.ModelAdmin):
list_display = [
"__str__",
- "example_id",
"run_id",
"view_user",
- "created_at",
+ "open_in_gooey",
+ "view_parent_published_run",
"run_time",
- "updated_at",
"price",
- "preview_input",
+ "created_at",
+ "updated_at",
]
list_filter = ["workflow"]
search_fields = ["workflow", "example_id", "run_id", "uid"]
@@ -278,6 +289,11 @@ class SavedRunAdmin(admin.ModelAdmin):
django.db.models.JSONField: {"widget": JSONEditorWidget},
}
+ def lookup_allowed(self, key, value):
+ if key in ["parent_version__published_run__id__exact"]:
+ return True
+ return super().lookup_allowed(key, value)
+
def view_user(self, saved_run: SavedRun):
return change_obj_url(
AppUser.objects.get(uid=saved_run.uid),
@@ -291,9 +307,10 @@ def view_bots(self, saved_run: SavedRun):
view_bots.short_description = "View Bots"
- @admin.display(description="Input")
- def preview_input(self, saved_run: SavedRun):
- return truncate_text_words(BasePage.preview_input(saved_run.state) or "", 100)
+ @admin.display(description="View Published Run")
+ def view_parent_published_run(self, saved_run: SavedRun):
+ pr = saved_run.parent_published_run()
+ return pr and change_obj_url(pr)
@admin.register(PublishedRunVersion)
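
Note: the "View Runs" column and the lookup_allowed override below work as a pair. The column links to a filtered SavedRun changelist, and Django rejects unlisted changelist lookups (raising DisallowedModelAdminLookup) unless the admin whitelists them. A minimal sketch of what such a link boils down to — list_related_html_url is this codebase's helper, and the URL reversing here is a plausible reconstruction, not its actual implementation:

    from django.urls import reverse
    from django.utils.html import format_html

    def related_changelist_link(qs, query_param, instance_id):
        # e.g. /admin/bots/savedrun/?parent_version__published_run__id__exact=42
        url = reverse("admin:bots_savedrun_changelist") + f"?{query_param}={instance_id}"
        # Without the lookup_allowed() whitelist on SavedRunAdmin, opening this
        # URL would raise DisallowedModelAdminLookup.
        return format_html('<a href="{}">{} runs</a>', url, qs.count())
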
diff --git a/bots/models.py b/bots/models.py
index 61c56abc3..0714f187b 100644
--- a/bots/models.py
+++ b/bots/models.py
@@ -280,7 +280,15 @@ class Meta:
]
def __str__(self):
- return self.get_app_url()
+ from daras_ai_v2.breadcrumbs import get_title_breadcrumbs
+
+ title = get_title_breadcrumbs(
+ Workflow(self.workflow).page_cls, self, self.parent_published_run()
+ ).h1_title
+ return title or self.get_app_url()
+
+ def parent_published_run(self) -> typing.Optional["PublishedRun"]:
+ return self.parent_version and self.parent_version.published_run
def get_app_url(self):
workflow = Workflow(self.workflow)
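
Note: parent_published_run can return None — "a and b" short-circuits to the falsy a when no parent version exists, hence the typing.Optional annotation. A tiny sketch of the chaining:

    class _SavedRunStub:
        parent_version = None  # no parent version recorded

        def parent_published_run(self):
            # "a and b" evaluates to a (here: None) when a is falsy
            return self.parent_version and self.parent_version.published_run

    assert _SavedRunStub().parent_published_run() is None
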
diff --git a/daras_ai_v2/all_pages.py b/daras_ai_v2/all_pages.py
index f4a317416..096150662 100644
--- a/daras_ai_v2/all_pages.py
+++ b/daras_ai_v2/all_pages.py
@@ -108,7 +108,7 @@ def normalize_slug(page_slug):
normalize_slug(slug): page
for page in (all_api_pages + all_hidden_pages)
for slug in page.slug_versions
-}
+} | {str(page.workflow.value): page for page in (all_api_pages + all_hidden_pages)}
workflow_map: dict[Workflow, typing.Type[BasePage]] = {
page.workflow: page for page in all_api_pages
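
Note: the | union above extends page_slug_map so a page resolves by its numeric Workflow value as well as by any of its slugs. A quick sketch of the merge semantics (toy values, not the real map):

    by_slug = {"compare-llm": "CompareLLMPage"}
    by_workflow_value = {"5": "CompareLLMPage"}  # str(page.workflow.value)
    # dict union (Python 3.9+); the right-hand side wins on key clashes
    page_slug_map = by_slug | by_workflow_value
    assert page_slug_map["5"] == page_slug_map["compare-llm"]
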
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 519b70a28..7187178b7 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -923,7 +923,7 @@ def get_runs_from_query_params(
) -> tuple[SavedRun, PublishedRun | None]:
if run_id and uid:
sr = cls.run_doc_sr(run_id, uid)
- pr = (sr and sr.parent_version and sr.parent_version.published_run) or None
+ pr = sr.parent_published_run()
else:
pr = cls.get_published_run(published_run_id=example_id or "")
sr = pr.saved_run
@@ -940,9 +940,7 @@ def get_pr_from_query_params(
) -> PublishedRun | None:
if run_id and uid:
sr = cls.get_sr_from_query_params(example_id, run_id, uid)
- return (
- sr and sr.parent_version and sr.parent_version.published_run
- ) or None
+ return sr.parent_published_run()
elif example_id:
return cls.get_published_run(published_run_id=example_id)
else:
diff --git a/daras_ai_v2/functional.py b/daras_ai_v2/functional.py
index 625e9c026..f8415ec87 100644
--- a/daras_ai_v2/functional.py
+++ b/daras_ai_v2/functional.py
@@ -8,8 +8,8 @@
def flatapply_parallel(
- fn: typing.Callable[[T], list[R]],
- *iterables: typing.Sequence[T],
+ fn: typing.Callable[..., list[R]],
+ *iterables,
max_workers: int = None,
message: str = "",
) -> typing.Generator[str, None, list[R]]:
@@ -20,8 +20,8 @@ def flatapply_parallel(
def apply_parallel(
- fn: typing.Callable[[T], R],
- *iterables: typing.Sequence[T],
+ fn: typing.Callable[..., R],
+ *iterables,
max_workers: int = None,
message: str = "",
) -> typing.Generator[str, None, list[R]]:
@@ -42,8 +42,8 @@ def apply_parallel(
def fetch_parallel(
- fn: typing.Callable[[T], R],
- *iterables: typing.Sequence[T],
+ fn: typing.Callable[..., R],
+ *iterables,
max_workers: int = None,
) -> typing.Generator[R, None, None]:
assert iterables, "fetch_parallel() requires at least one iterable"
@@ -57,16 +57,16 @@ def fetch_parallel(
def flatmap_parallel(
- fn: typing.Callable[[T], list[R]],
- *iterables: typing.Sequence[T],
+ fn: typing.Callable[..., list[R]],
+ *iterables,
max_workers: int = None,
) -> list[R]:
return flatten(map_parallel(fn, *iterables, max_workers=max_workers))
def map_parallel(
- fn: typing.Callable[[T], R],
- *iterables: typing.Sequence[T],
+ fn: typing.Callable[..., R],
+ *iterables,
max_workers: int = None,
) -> list[R]:
assert iterables, "map_parallel() requires at least one iterable"
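
Note: the loosened annotations reflect that these helpers, like the builtin map(), accept several iterables and call fn with one argument drawn from each, so fn is no longer unary. A rough thread-pool equivalent, assuming arguments are zipped pairwise:

    from concurrent.futures import ThreadPoolExecutor

    def map_parallel_sketch(fn, *iterables, max_workers=None):
        assert iterables, "requires at least one iterable"
        with ThreadPoolExecutor(max_workers=max_workers or 4) as pool:
            # pool.map, like builtin map, zips the iterables: fn(x1, y1), fn(x2, y2), ...
            return list(pool.map(fn, *iterables))

    urls = ["a.pdf", "b.pdf"]
    metas = [{"etag": "x"}, {"etag": "y"}]
    pairs = map_parallel_sketch(lambda f_url, doc_meta: (f_url, doc_meta["etag"]), urls, metas)
    assert pairs == [("a.pdf", "x"), ("b.pdf", "y")]
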
diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py
index 070aa1c22..f28e3ac3f 100644
--- a/daras_ai_v2/language_model.py
+++ b/daras_ai_v2/language_model.py
@@ -18,7 +18,11 @@
from django.conf import settings
from jinja2.lexer import whitespace_re
from loguru import logger
-from openai.types.chat import ChatCompletionContentPartParam
+from openai import Stream
+from openai.types.chat import (
+ ChatCompletionContentPartParam,
+ ChatCompletionChunk,
+)
from daras_ai_v2.asr import get_google_auth_session
from daras_ai_v2.exceptions import raise_for_status
@@ -27,7 +31,10 @@
from daras_ai_v2.redis_cache import (
get_redis_cache,
)
-from daras_ai_v2.text_splitter import default_length_function
+from daras_ai_v2.text_splitter import (
+ default_length_function,
+ default_separators,
+)
DEFAULT_SYSTEM_MSG = "You are an intelligent AI assistant. Follow the instructions as closely as possible."
@@ -38,6 +45,9 @@
CHATML_ROLE_ASSISTANT = "assistant"
CHATML_ROLE_USER = "user"
+# nice for showing streaming progress
+SUPERSCRIPT = str.maketrans("0123456789", "⁰¹²³⁴⁵⁶⁷⁸⁹")
+
class LLMApis(Enum):
vertex_ai = "Vertex AI"
@@ -327,8 +337,13 @@ def run_language_model(
stop: list[str] = None,
avoid_repetition: bool = False,
tools: list[LLMTools] = None,
+ stream: bool = False,
response_format_type: typing.Literal["text", "json_object"] = None,
-) -> list[str] | tuple[list[str], list[list[dict]]] | list[dict]:
+) -> (
+ list[str]
+ | tuple[list[str], list[list[dict]]]
+ | typing.Generator[list[str], None, None]
+):
assert bool(prompt) != bool(
messages
), "Pleave provide exactly one of { prompt, messages }"
@@ -336,10 +351,9 @@ def run_language_model(
model: LargeLanguageModels = LargeLanguageModels[str(model)]
api = llm_api[model]
model_name = llm_model_names[model]
+ is_chatml = False
if model.is_chat_model():
- if messages:
- is_chatml = False
- else:
+ if not messages:
# if input is chatml, convert it into json messages
is_chatml, messages = parse_chatml(prompt) # type: ignore
messages = messages or []
@@ -349,7 +363,7 @@ def run_language_model(
format_chat_entry(role=entry["role"], content=get_entry_text(entry))
for entry in messages
]
- result = _run_chat_model(
+ entries = _run_chat_model(
api=api,
model=model_name,
messages=messages, # type: ignore
@@ -360,28 +374,18 @@ def run_language_model(
avoid_repetition=avoid_repetition,
tools=tools,
response_format_type=response_format_type,
+ # we can't stream with tools or json yet
+ stream=stream and not (tools or response_format_type),
)
- if response_format_type == "json_object":
- out_content = [json.loads(entry["content"]) for entry in result]
- else:
- out_content = [
- # return messages back as either chatml or json messages
- (
- format_chatml_message(entry)
- if is_chatml
- else (entry.get("content") or "").strip()
- )
- for entry in result
- ]
- if tools:
- return out_content, [(entry.get("tool_calls") or []) for entry in result]
+ if stream:
+ return _stream_llm_outputs(entries, is_chatml, response_format_type, tools)
else:
- return out_content
+ return _parse_entries(entries, is_chatml, response_format_type, tools)
else:
if tools:
raise ValueError("Only OpenAI chat models support Tools")
logger.info(f"{model_name=}, {len(prompt)=}, {max_tokens=}, {temperature=}")
- result = _run_text_model(
+ msgs = _run_text_model(
api=api,
model=model_name,
prompt=prompt,
@@ -392,7 +396,41 @@ def run_language_model(
avoid_repetition=avoid_repetition,
quality=quality,
)
- return [msg.strip() for msg in result]
+ ret = [msg.strip() for msg in msgs]
+ if stream:
+ ret = [ret]
+ return ret
+
+
+def _stream_llm_outputs(result, is_chatml, response_format_type, tools):
+ if isinstance(result, list): # compatibility with non-streaming apis
+ result = [result]
+ for entries in result:
+ yield _parse_entries(entries, is_chatml, response_format_type, tools)
+
+
+def _parse_entries(
+ entries: list[dict],
+ is_chatml: bool,
+ response_format_type: typing.Literal["text", "json_object"] | None,
+ tools: list[dict] | None,
+):
+ if response_format_type == "json_object":
+ ret = [json.loads(entry["content"]) for entry in entries]
+ else:
+ ret = [
+ # return messages back as either chatml or json messages
+ (
+ format_chatml_message(entry)
+ if is_chatml
+ else (entry.get("content") or "").strip()
+ )
+ for entry in entries
+ ]
+ if tools:
+ return ret, [(entry.get("tool_calls") or []) for entry in entries]
+ else:
+ return ret
def _run_text_model(
@@ -443,7 +481,8 @@ def _run_chat_model(
avoid_repetition: bool,
tools: list[LLMTools] | None,
response_format_type: typing.Literal["text", "json_object"] | None,
-) -> list[ConversationEntry]:
+ stream: bool = False,
+) -> list[ConversationEntry] | typing.Generator[list[ConversationEntry], None, None]:
match api:
case LLMApis.openai:
return _run_openai_chat(
@@ -456,6 +495,7 @@ def _run_chat_model(
temperature=temperature,
tools=tools,
response_format_type=response_format_type,
+ stream=stream,
)
case LLMApis.vertex_ai:
if tools:
@@ -494,7 +534,8 @@ def _run_openai_chat(
avoid_repetition: bool,
tools: list[LLMTools] | None,
response_format_type: typing.Literal["text", "json_object"] | None,
-) -> list[ConversationEntry]:
+ stream: bool = False,
+) -> list[ConversationEntry] | typing.Generator[list[ConversationEntry], None, None]:
from openai._types import NOT_GIVEN
if avoid_repetition:
@@ -523,11 +564,72 @@ def _run_openai_chat(
if response_format_type
else NOT_GIVEN
),
+ stream=stream,
)
for model_str in model
],
)
- return [choice.message.dict() for choice in r.choices]
+ if stream:
+ return _stream_openai_chunked(r)
+ else:
+ return [choice.message.dict() for choice in r.choices]
+
+
+def _stream_openai_chunked(
+ r: Stream[ChatCompletionChunk],
+ start_chunk_size: int = 50,
+ stop_chunk_size: int = 400,
+ step_chunk_size: int = 150,
+):
+ ret = []
+ chunk_size = start_chunk_size
+
+ for completion_chunk in r:
+ changed = False
+ for choice in completion_chunk.choices:
+ try:
+ entry = ret[choice.index]
+ except IndexError:
+ # initialize the entry
+ entry = choice.delta.dict() | {"content": "", "chunk": ""}
+ ret.append(entry)
+
+ # append the delta to the current chunk
+ if not choice.delta.content:
+ continue
+ entry["chunk"] += choice.delta.content
+ # if the chunk is too small, we need to wait for more data
+ chunk = entry["chunk"]
+ if len(chunk) < chunk_size:
+ continue
+
+ # iterate through the separators and find the best one that matches
+ for sep in default_separators[:-1]:
+ # find the last occurrence of the separator
+ match = None
+ for match in re.finditer(sep, chunk):
+ pass
+ if not match:
+ continue # no match, try the next separator or wait for more data
+ # append text before the separator to the content
+ part = chunk[: match.end()]
+ if len(part) < chunk_size:
+ continue # not enough text, try the next separator or wait for more data
+ entry["content"] += part
+ # set text after the separator as the next chunk
+ entry["chunk"] = chunk[match.end() :]
+ # increase the chunk size, but don't go over the max
+ chunk_size = min(chunk_size + step_chunk_size, stop_chunk_size)
+ # we found a separator, so we can stop looking and yield the partial result
+ changed = True
+ break
+ if changed:
+ yield ret
+
+ # add the leftover chunks
+ for entry in ret:
+ entry["content"] += entry["chunk"]
+ yield ret
@retry_if(openai_should_retry)
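
Note: to make the rechunking above concrete — deltas accumulate in entry["chunk"] until at least chunk_size characters are buffered and one of default_separators (paragraph breaks first, then finer splits) matches far enough in; everything up to the last separator match is flushed into entry["content"] and yielded, and the threshold ramps 50 → 200 → 350 → 400 so the first words render quickly while later updates arrive in larger batches. A toy trace of one flush, assuming "\n\n" is the first separator tried:

    import re

    chunk = (
        "Paris, the capital of France, sits on the Seine in northern France.\n\n"
        "It is known for the Eiffel Tower and"
    )
    chunk_size = 50  # start_chunk_size

    match = None
    for match in re.finditer("\n\n", chunk):  # keep only the last occurrence
        pass
    part = chunk[: match.end()]  # well past chunk_size, so this flushes
    content, leftover = part, chunk[match.end():]
    chunk_size = min(chunk_size + 150, 400)  # step_chunk_size / stop_chunk_size

    assert leftover.startswith("It is known")
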
diff --git a/daras_ai_v2/search_ref.py b/daras_ai_v2/search_ref.py
index 1baf64b47..7cc810b1a 100644
--- a/daras_ai_v2/search_ref.py
+++ b/daras_ai_v2/search_ref.py
@@ -57,150 +57,178 @@ def render_text_with_refs(text: str, references: list[SearchReference]):
return html
-def apply_response_template(
+def apply_response_formattings_prefix(
output_text: list[str],
references: list[SearchReference],
citation_style: CitationStyles | None = CitationStyles.number,
-):
+) -> list[dict[int, SearchReference]]:
+ all_refs_list = [{}] * len(output_text)
for i, text in enumerate(output_text):
- formatted = ""
- all_refs = {}
-
- for snippet, ref_map in parse_refs(text, references):
- match citation_style:
- case CitationStyles.number | CitationStyles.number_plaintext:
- cites = " ".join(f"[{ref_num}]" for ref_num in ref_map.keys())
- case CitationStyles.title:
- cites = " ".join(f"[{ref['title']}]" for ref in ref_map.values())
- case CitationStyles.url:
- cites = " ".join(f"[{ref['url']}]" for ref in ref_map.values())
- case CitationStyles.symbol | CitationStyles.symbol_plaintext:
- cites = " ".join(
- f"[{generate_footnote_symbol(ref_num - 1)}]"
- for ref_num in ref_map.keys()
- )
+ all_refs_list[i], output_text[i] = format_citations(
+ text, references, citation_style
+ )
+ return all_refs_list
- case CitationStyles.markdown:
- cites = " ".join(ref_to_markdown(ref) for ref in ref_map.values())
- case CitationStyles.html:
- cites = " ".join(ref_to_html(ref) for ref in ref_map.values())
- case CitationStyles.slack_mrkdwn:
- cites = " ".join(
- ref_to_slack_mrkdwn(ref) for ref in ref_map.values()
- )
- case CitationStyles.plaintext:
- cites = " ".join(
- f'[{ref["title"]} {ref["url"]}]'
- for ref_num, ref in ref_map.items()
- )
- case CitationStyles.number_markdown:
- cites = " ".join(
- markdown_link(f"[{ref_num}]", ref["url"])
- for ref_num, ref in ref_map.items()
- )
- case CitationStyles.number_html:
- cites = " ".join(
- html_link(f"[{ref_num}]", ref["url"])
- for ref_num, ref in ref_map.items()
- )
- case CitationStyles.number_slack_mrkdwn:
- cites = " ".join(
- slack_mrkdwn_link(f"[{ref_num}]", ref["url"])
- for ref_num, ref in ref_map.items()
- )
+def apply_response_formattings_suffix(
+ all_refs_list: list[dict[int, SearchReference]],
+ output_text: list[str],
+ citation_style: CitationStyles | None = CitationStyles.number,
+):
+ for i, text in enumerate(output_text):
+ output_text[i] = format_jinja_response_template(
+ all_refs_list[i],
+ format_footnotes(all_refs_list[i], text, citation_style),
+ )
- case CitationStyles.symbol_markdown:
- cites = " ".join(
- markdown_link(
- f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
- )
- for ref_num, ref in ref_map.items()
- )
- case CitationStyles.symbol_html:
- cites = " ".join(
- html_link(
- f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
- )
- for ref_num, ref in ref_map.items()
- )
- case CitationStyles.symbol_slack_mrkdwn:
- cites = " ".join(
- slack_mrkdwn_link(
- f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
- )
- for ref_num, ref in ref_map.items()
- )
- case None:
- cites = ""
- case _:
- raise ValueError(f"Unknown citation style: {citation_style}")
- formatted += snippet + " " + cites + " "
- all_refs.update(ref_map)
+def format_citations(
+ text: str,
+ references: list[SearchReference],
+ citation_style: CitationStyles | None = CitationStyles.number,
+) -> tuple[dict[int, SearchReference], str]:
+ all_refs = {}
+ formatted = ""
+ for snippet, ref_map in parse_refs(text, references):
match citation_style:
+ case CitationStyles.number | CitationStyles.number_plaintext:
+ cites = " ".join(f"[{ref_num}]" for ref_num in ref_map.keys())
+ case CitationStyles.title:
+ cites = " ".join(f"[{ref['title']}]" for ref in ref_map.values())
+ case CitationStyles.url:
+ cites = " ".join(f"[{ref['url']}]" for ref in ref_map.values())
+ case CitationStyles.symbol | CitationStyles.symbol_plaintext:
+ cites = " ".join(
+ f"[{generate_footnote_symbol(ref_num - 1)}]"
+ for ref_num in ref_map.keys()
+ )
+
+ case CitationStyles.markdown:
+ cites = " ".join(ref_to_markdown(ref) for ref in ref_map.values())
+ case CitationStyles.html:
+ cites = " ".join(ref_to_html(ref) for ref in ref_map.values())
+ case CitationStyles.slack_mrkdwn:
+ cites = " ".join(ref_to_slack_mrkdwn(ref) for ref in ref_map.values())
+ case CitationStyles.plaintext:
+ cites = " ".join(
+ f'[{ref["title"]} {ref["url"]}]' for ref_num, ref in ref_map.items()
+ )
+
case CitationStyles.number_markdown:
- formatted += "\n\n"
- formatted += "\n".join(
- f"[{ref_num}] {ref_to_markdown(ref)}"
- for ref_num, ref in sorted(all_refs.items())
+ cites = " ".join(
+ markdown_link(f"[{ref_num}]", ref["url"])
+ for ref_num, ref in ref_map.items()
)
case CitationStyles.number_html:
- formatted += "
"
- formatted += "
".join(
- f"[{ref_num}] {ref_to_html(ref)}"
- for ref_num, ref in sorted(all_refs.items())
+ cites = " ".join(
+ html_link(f"[{ref_num}]", ref["url"])
+ for ref_num, ref in ref_map.items()
)
case CitationStyles.number_slack_mrkdwn:
- formatted += "\n\n"
- formatted += "\n".join(
- f"[{ref_num}] {ref_to_slack_mrkdwn(ref)}"
- for ref_num, ref in sorted(all_refs.items())
- )
- case CitationStyles.number_plaintext:
- formatted += "\n\n"
- formatted += "\n".join(
- f'{ref_num}. {ref["title"]} {ref["url"]}'
- for ref_num, ref in sorted(all_refs.items())
+ cites = " ".join(
+ slack_mrkdwn_link(f"[{ref_num}]", ref["url"])
+ for ref_num, ref in ref_map.items()
)
case CitationStyles.symbol_markdown:
- formatted += "\n\n"
- formatted += "\n".join(
- f"{generate_footnote_symbol(ref_num - 1)} {ref_to_markdown(ref)}"
- for ref_num, ref in sorted(all_refs.items())
+ cites = " ".join(
+ markdown_link(
+ f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
+ )
+ for ref_num, ref in ref_map.items()
)
case CitationStyles.symbol_html:
- formatted += "
"
- formatted += "
".join(
- f"{generate_footnote_symbol(ref_num - 1)} {ref_to_html(ref)}"
- for ref_num, ref in sorted(all_refs.items())
+ cites = " ".join(
+ html_link(f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"])
+ for ref_num, ref in ref_map.items()
)
case CitationStyles.symbol_slack_mrkdwn:
- formatted += "\n\n"
- formatted += "\n".join(
- f"{generate_footnote_symbol(ref_num - 1)} {ref_to_slack_mrkdwn(ref)}"
- for ref_num, ref in sorted(all_refs.items())
- )
- case CitationStyles.symbol_plaintext:
- formatted += "\n\n"
- formatted += "\n".join(
- f'{generate_footnote_symbol(ref_num - 1)}. {ref["title"]} {ref["url"]}'
- for ref_num, ref in sorted(all_refs.items())
- )
-
- for ref_num, ref in all_refs.items():
- try:
- template = ref["response_template"]
- except KeyError:
- pass
- else:
- formatted = jinja2.Template(template).render(
- **ref,
- output_text=formatted,
- ref_num=ref_num,
+ cites = " ".join(
+ slack_mrkdwn_link(
+ f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
+ )
+ for ref_num, ref in ref_map.items()
)
- output_text[i] = formatted
+ case None:
+ cites = ""
+ case _:
+ raise ValueError(f"Unknown citation style: {citation_style}")
+ formatted += " ".join(filter(None, [snippet, cites]))
+ all_refs.update(ref_map)
+ return all_refs, formatted
+
+
+def format_footnotes(
+ all_refs: dict[int, SearchReference], formatted: str, citation_style: CitationStyles
+) -> str:
+ match citation_style:
+ case CitationStyles.number_markdown:
+ formatted += "\n\n"
+ formatted += "\n".join(
+ f"[{ref_num}] {ref_to_markdown(ref)}"
+ for ref_num, ref in sorted(all_refs.items())
+ )
+ case CitationStyles.number_html:
+ formatted += "
"
+ formatted += "
".join(
+ f"[{ref_num}] {ref_to_html(ref)}"
+ for ref_num, ref in sorted(all_refs.items())
+ )
+ case CitationStyles.number_slack_mrkdwn:
+ formatted += "\n\n"
+ formatted += "\n".join(
+ f"[{ref_num}] {ref_to_slack_mrkdwn(ref)}"
+ for ref_num, ref in sorted(all_refs.items())
+ )
+ case CitationStyles.number_plaintext:
+ formatted += "\n\n"
+ formatted += "\n".join(
+ f'{ref_num}. {ref["title"]} {ref["url"]}'
+ for ref_num, ref in sorted(all_refs.items())
+ )
+
+ case CitationStyles.symbol_markdown:
+ formatted += "\n\n"
+ formatted += "\n".join(
+ f"{generate_footnote_symbol(ref_num - 1)} {ref_to_markdown(ref)}"
+ for ref_num, ref in sorted(all_refs.items())
+ )
+ case CitationStyles.symbol_html:
+ formatted += "
"
+ formatted += "
".join(
+ f"{generate_footnote_symbol(ref_num - 1)} {ref_to_html(ref)}"
+ for ref_num, ref in sorted(all_refs.items())
+ )
+ case CitationStyles.symbol_slack_mrkdwn:
+ formatted += "\n\n"
+ formatted += "\n".join(
+ f"{generate_footnote_symbol(ref_num - 1)} {ref_to_slack_mrkdwn(ref)}"
+ for ref_num, ref in sorted(all_refs.items())
+ )
+ case CitationStyles.symbol_plaintext:
+ formatted += "\n\n"
+ formatted += "\n".join(
+ f'{generate_footnote_symbol(ref_num - 1)}. {ref["title"]} {ref["url"]}'
+ for ref_num, ref in sorted(all_refs.items())
+ )
+ return formatted
+
+
+def format_jinja_response_template(
+ all_refs: dict[int, SearchReference], formatted: str
+) -> str:
+ for ref_num, ref in all_refs.items():
+ try:
+ template = ref["response_template"]
+ except KeyError:
+ pass
+ else:
+ formatted = jinja2.Template(template).render(
+ **ref,
+ output_text=formatted,
+ ref_num=ref_num,
+ )
+ return formatted
search_ref_pat = re.compile(r"\[" r"[\d\s\.\,\[\]\$\{\}]+" r"\]")
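
Note: the split into prefix/suffix passes supports streaming — apply_response_formattings_prefix rewrites inline citations and can safely run on every partial update, while apply_response_formattings_suffix appends the footnote block and renders any per-reference jinja response_template exactly once, on the final text. A hedged usage sketch (reference fields assumed from SearchReference):

    output_text = ["The Eiffel Tower is in Paris [1]."]
    references = [
        {"url": "https://en.wikipedia.org/wiki/Paris", "title": "Paris", "snippet": "...", "score": 1.0}
    ]
    style = CitationStyles.number_markdown

    all_refs_list = apply_response_formattings_prefix(output_text, references, style)
    # ... stream partial updates to the UI here ...
    apply_response_formattings_suffix(all_refs_list, output_text, style)
    # output_text[0] now carries an inline markdown citation plus a trailing
    # "[1] [Paris](https://en.wikipedia.org/wiki/Paris)" footnote line.
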
diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py
index 5cf08c721..394d55b55 100644
--- a/daras_ai_v2/vector_search.py
+++ b/daras_ai_v2/vector_search.py
@@ -12,6 +12,7 @@
import requests
from furl import furl
from googleapiclient.errors import HttpError
+from loguru import logger
from pydantic import BaseModel, Field
from rank_bm25 import BM25Okapi
@@ -32,7 +33,7 @@
)
from daras_ai_v2.exceptions import raise_for_status
from daras_ai_v2.fake_user_agents import FAKE_USER_AGENTS
-from daras_ai_v2.functional import flatmap_parallel
+from daras_ai_v2.functional import flatmap_parallel, map_parallel
from daras_ai_v2.gdrive_downloader import (
gdrive_download,
is_gdrive_url,
@@ -87,20 +88,23 @@ def get_top_k_references(
Returns:
the top k documents
"""
- yield "Getting embeddings..."
+ yield "Checking docs..."
input_docs = request.documents or []
+ doc_metas = map_parallel(doc_url_to_metadata, input_docs)
+
+ yield "Getting embeddings..."
embeds: list[tuple[SearchReference, np.ndarray]] = flatmap_parallel(
- lambda f_url: doc_url_to_embeds(
+ lambda f_url, doc_meta: get_embeds_for_doc(
f_url=f_url,
+ doc_meta=doc_meta,
max_context_words=request.max_context_words,
scroll_jump=request.scroll_jump,
selected_asr_model=request.selected_asr_model,
google_translate_target=request.google_translate_target,
),
input_docs,
- max_workers=10,
+ doc_metas,
)
-
dense_query_embeds = openai_embedding_create([request.search_query])[0]
yield "Searching documents..."
@@ -129,12 +133,21 @@ def get_top_k_references(
dense_ranks = np.zeros(len(embeds))
if sparse_weight:
+ yield "Getting sparse scores..."
# get sparse scores
- tokenized_corpus = [
- bm25_tokenizer(ref["title"]) + bm25_tokenizer(ref["snippet"])
- for ref, _ in embeds
- ]
- bm25 = BM25Okapi(tokenized_corpus, k1=2, b=0.3)
+ bm25_corpus = flatmap_parallel(
+ lambda f_url, doc_meta: get_bm25_embeds_for_doc(
+ f_url=f_url,
+ doc_meta=doc_meta,
+ max_context_words=request.max_context_words,
+ scroll_jump=request.scroll_jump,
+ selected_asr_model=request.selected_asr_model,
+ google_translate_target=request.google_translate_target,
+ ),
+ input_docs,
+ doc_metas,
+ )
+ bm25 = BM25Okapi(bm25_corpus, k1=2, b=0.3)
if request.keyword_query and isinstance(request.keyword_query, list):
sparse_query_tokenized = [item.lower() for item in request.keyword_query]
else:
@@ -150,6 +163,13 @@ def get_top_k_references(
else:
sparse_ranks = np.zeros(len(embeds))
+ # just in case sparse and dense ranks are different lengths, truncate to the shorter one
+ if len(sparse_ranks) != len(dense_ranks):
+ logger.warning(
+ f"sparse and dense ranks are different lengths, truncating... {len(sparse_ranks)=} {len(dense_ranks)=} {len(embeds)=}"
+ )
+ sparse_ranks = sparse_ranks[: len(dense_ranks)]
+ dense_ranks = dense_ranks[: len(sparse_ranks)]
# RRF formula: 1 / (k + rank)
k = 60
rrf_scores = (
@@ -159,11 +179,7 @@ def get_top_k_references(
# Final ranking
max_references = min(request.max_references, len(rrf_scores))
top_k = np.argpartition(rrf_scores, -max_references)[-max_references:]
- final_ranks = sorted(
- top_k,
- key=lambda idx: rrf_scores[idx],
- reverse=True,
- )
+ final_ranks = sorted(top_k, key=rrf_scores.__getitem__, reverse=True)
references = [embeds[idx][0] | {"score": rrf_scores[idx]} for idx in final_ranks]
@@ -211,38 +227,6 @@ def references_as_prompt(references: list[SearchReference], sep="\n\n") -> str:
)
-def doc_url_to_embeds(
- *,
- f_url: str,
- max_context_words: int,
- scroll_jump: int,
- selected_asr_model: str = None,
- google_translate_target: str = None,
-) -> list[tuple[SearchReference, np.ndarray]]:
- """
- Get document embeddings for a given document url.
-
- Args:
- f_url: document url
- max_context_words: max number of words to include in each chunk
- scroll_jump: number of words to scroll by
- google_translate_target: target language for google translate
- selected_asr_model: selected ASR model (used for audio files)
-
- Returns:
- list of (SearchReference, embeddings vector) tuples
- """
- doc_meta = doc_url_to_metadata(f_url)
- return get_embeds_for_doc(
- f_url=f_url,
- doc_meta=doc_meta,
- max_context_words=max_context_words,
- scroll_jump=scroll_jump,
- selected_asr_model=selected_asr_model,
- google_translate_target=google_translate_target,
- )
-
-
class DocMetadata(typing.NamedTuple):
name: str
etag: str | None
@@ -321,6 +305,35 @@ def doc_url_to_file_metadata(f_url: str) -> FileMetadata:
)
+@redis_cache_decorator
+def get_bm25_embeds_for_doc(
+ *,
+ f_url: str,
+ doc_meta: DocMetadata,
+ max_context_words: int,
+ scroll_jump: int,
+ google_translate_target: str = None,
+ selected_asr_model: str = None,
+):
+ pages = doc_url_to_text_pages(
+ f_url=f_url,
+ doc_meta=doc_meta,
+ selected_asr_model=selected_asr_model,
+ google_translate_target=google_translate_target,
+ )
+ refs = pages_to_split_refs(
+ pages=pages,
+ f_url=f_url,
+ doc_meta=doc_meta,
+ max_context_words=max_context_words,
+ scroll_jump=scroll_jump,
+ )
+ tokenized_corpus = [
+ bm25_tokenizer(ref["title"]) + bm25_tokenizer(ref["snippet"]) for ref in refs
+ ]
+ return tokenized_corpus
+
+
@redis_cache_decorator
def get_embeds_for_doc(
*,
@@ -345,18 +358,44 @@ def get_embeds_for_doc(
Returns:
list of (metadata, embeddings) tuples
"""
- import pandas as pd
-
pages = doc_url_to_text_pages(
f_url=f_url,
doc_meta=doc_meta,
selected_asr_model=selected_asr_model,
google_translate_target=google_translate_target,
)
+ refs = pages_to_split_refs(
+ pages=pages,
+ f_url=f_url,
+ doc_meta=doc_meta,
+ max_context_words=max_context_words,
+ scroll_jump=scroll_jump,
+ )
+ texts = [m["title"] + " | " + m["snippet"] for m in refs]
+ # get doc embeds in batches
+ batch_size = 16 # azure openai limits
+ embeds = flatmap_parallel(
+ openai_embedding_create,
+ [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)],
+ max_workers=2,
+ )
+ return list(zip(refs, embeds))
+
+
+def pages_to_split_refs(
+ *,
+ pages,
+ f_url: str,
+ doc_meta: DocMetadata,
+ max_context_words: int,
+ scroll_jump: int,
+) -> list[SearchReference]:
+ import pandas as pd
+
chunk_size = int(max_context_words * 2)
chunk_overlap = int(max_context_words * 2 / scroll_jump)
if isinstance(pages, pd.DataFrame):
- metas = []
+ refs = []
# treat each row as a separate document
for idx, row in pages.iterrows():
row = dict(row)
@@ -372,7 +411,7 @@ def get_embeds_for_doc(
)
else:
continue
- metas += [
+ refs += [
{
"title": doc_meta.name,
"url": f_url,
@@ -385,7 +424,7 @@ def get_embeds_for_doc(
]
else:
# split the text into chunks
- metas = [
+ refs = [
{
"title": (
doc_meta.name + (f", page {doc.end + 1}" if len(pages) > 1 else "")
@@ -403,15 +442,7 @@ def get_embeds_for_doc(
pages, chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
]
- # get doc embeds in batches
- batch_size = 16 # azure openai limits
- texts = [m["title"] + " | " + m["snippet"] for m in metas]
- embeds = flatmap_parallel(
- openai_embedding_create,
- [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)],
- max_workers=5,
- )
- return list(zip(metas, embeds))
+ return refs
sections_re = re.compile(r"(\s*[\r\n\f\v]|^)(\w+)\=", re.MULTILINE)
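
Note: for intuition on the fusion step — reciprocal-rank fusion scores each chunk as a weighted sum of 1/(k + rank) over the dense and sparse rankings, so a chunk only needs to rank well in one of the two lists to surface. A worked example with k = 60 (the weights here are illustrative; the real ones come from the request):

    import numpy as np

    k = 60
    dense_ranks = np.array([0, 1, 2])   # rank 0 = best dense (embedding) match
    sparse_ranks = np.array([2, 0, 1])  # rank 0 = best sparse (BM25) match
    dense_weight, sparse_weight = 0.7, 0.3

    rrf_scores = dense_weight / (k + dense_ranks) + sparse_weight / (k + sparse_ranks)
    top = np.argsort(rrf_scores)[::-1]  # doc 0 wins: best dense rank, decent sparse rank
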
diff --git a/gooey_ui/pubsub.py b/gooey_ui/pubsub.py
index ca3dfcb1a..5e210c956 100644
--- a/gooey_ui/pubsub.py
+++ b/gooey_ui/pubsub.py
@@ -42,7 +42,11 @@ def realtime_push(channel: str, value: typing.Any = "ping"):
msg = json.dumps(jsonable_encoder(value))
r.set(channel, msg)
r.publish(channel, json.dumps(time()))
- logger.info(f"publish {channel=}")
+ if isinstance(value, dict):
+ run_status = value.get("__run_status")
+ logger.info(f"publish {channel=} {run_status=}")
+ else:
+ logger.info(f"publish {channel=}")
# def use_state(
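
Note: the log change assumes the recipes publish their progress as a dict carrying a "__run_status" key; the payload shape and channel name below are illustrative, not taken from this codebase:

    realtime_push(
        "gooey-gui/state/run-abc123",  # hypothetical channel name
        {"__run_status": "Streaming³ gpt-4...", "output_text": ["partial answer"]},
    )
    # logs: publish channel='gooey-gui/state/run-abc123' run_status='Streaming³ gpt-4...'
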
diff --git a/recipes/CompareLLM.py b/recipes/CompareLLM.py
index f1be5a937..eac5ef885 100644
--- a/recipes/CompareLLM.py
+++ b/recipes/CompareLLM.py
@@ -1,10 +1,9 @@
import random
import typing
-
-import gooey_ui as st
from pydantic import BaseModel
+import gooey_ui as st
from bots.models import Workflow
from daras_ai_v2.base import BasePage
from daras_ai_v2.enum_selector_widget import enum_multiselect
@@ -12,6 +11,7 @@
run_language_model,
LargeLanguageModels,
llm_price,
+ SUPERSCRIPT,
)
from daras_ai_v2.language_model_settings_widgets import language_model_settings
from daras_ai_v2.loom_video_widget import youtube_video
@@ -94,9 +94,9 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
state["output_text"] = output_text = {}
for selected_model in request.selected_models:
- yield f"Running {LargeLanguageModels[selected_model].value}..."
-
- output_text[selected_model] = run_language_model(
+ model = LargeLanguageModels[selected_model]
+ yield f"Running {model.value}..."
+ ret = run_language_model(
model=selected_model,
quality=request.quality,
num_outputs=request.num_outputs,
@@ -104,7 +104,11 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
prompt=prompt,
max_tokens=request.max_tokens,
avoid_repetition=request.avoid_repetition,
+ stream=True,
)
+ for i, item in enumerate(ret):
+ output_text[selected_model] = item
+ yield f"Streaming{str(i + 1).translate(SUPERSCRIPT)} {model.value}..."
def render_output(self):
self._render_outputs(st.session_state, 450)
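
Note: with stream=True, run_language_model returns a generator that yields the output list repeatedly as it grows, so each loop iteration above overwrites output_text[selected_model] with a longer partial result and the yielded status string re-renders in the UI. The superscript counter is just a character translation:

    SUPERSCRIPT = str.maketrans("0123456789", "⁰¹²³⁴⁵⁶⁷⁸⁹")
    assert "Streaming" + str(12).translate(SUPERSCRIPT) == "Streaming¹²"
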
diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py
index d18b5a54c..0b95464e1 100644
--- a/recipes/DocSearch.py
+++ b/recipes/DocSearch.py
@@ -23,8 +23,9 @@
from daras_ai_v2.search_ref import (
SearchReference,
render_output_with_refs,
- apply_response_template,
CitationStyles,
+ apply_response_formattings_prefix,
+ apply_response_formattings_suffix,
)
from daras_ai_v2.vector_search import (
DocSearchRequest,
@@ -194,9 +195,12 @@ def run_v2(
citation_style = (
request.citation_style and CitationStyles[request.citation_style]
) or None
- apply_response_template(
+ all_refs_list = apply_response_formattings_prefix(
response.output_text, response.references, citation_style
)
+ apply_response_formattings_suffix(
+ all_refs_list, response.output_text, citation_style
+ )
def get_raw_price(self, state: dict) -> float:
name = state.get("selected_model")
diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py
index faa88d533..20eaa8537 100644
--- a/recipes/VideoBots.py
+++ b/recipes/VideoBots.py
@@ -50,6 +50,7 @@
get_entry_images,
get_entry_text,
format_chat_entry,
+ SUPERSCRIPT,
)
from daras_ai_v2.language_model_settings_widgets import language_model_settings
from daras_ai_v2.lipsync_settings_widgets import lipsync_settings
@@ -58,7 +59,12 @@
from daras_ai_v2.query_generator import generate_final_search_query
from daras_ai_v2.query_params import gooey_get_query_params
from daras_ai_v2.query_params_util import extract_query_params
-from daras_ai_v2.search_ref import apply_response_template, parse_refs, CitationStyles
+from daras_ai_v2.search_ref import (
+ parse_refs,
+ CitationStyles,
+ apply_response_formattings_prefix,
+ apply_response_formattings_suffix,
+)
from daras_ai_v2.text_output_widget import text_output
from daras_ai_v2.text_to_speech_settings_widgets import (
TextToSpeechProviders,
@@ -805,7 +811,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
yield f"Running {model.value}..."
if is_chat_model:
- output_text = run_language_model(
+ chunks = run_language_model(
model=request.selected_model,
messages=[
{"role": s["role"], "content": s["content"]}
@@ -816,12 +822,13 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
temperature=request.sampling_temperature,
avoid_repetition=request.avoid_repetition,
tools=request.tools,
+ stream=True,
)
else:
prompt = "\n".join(
format_chatml_message(entry) for entry in prompt_messages
)
- output_text = run_language_model(
+ chunks = run_language_model(
model=request.selected_model,
prompt=prompt,
max_tokens=max_allowed_tokens,
@@ -830,43 +837,52 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
temperature=request.sampling_temperature,
avoid_repetition=request.avoid_repetition,
stop=[CHATML_END_TOKEN, CHATML_START_TOKEN],
+ stream=True,
)
- if request.tools:
- output_text, tool_call_choices = output_text
- state["output_documents"] = output_documents = []
- for tool_calls in tool_call_choices:
- for call in tool_calls:
- result = yield from exec_tool_call(call)
- output_documents.append(result)
-
- # save model response
- state["raw_output_text"] = [
- "".join(snippet for snippet, _ in parse_refs(text, references))
- for text in output_text
- ]
-
- # translate response text
- if request.user_language and request.user_language != "en":
- yield f"Translating response to {request.user_language}..."
- output_text = run_google_translate(
- texts=output_text,
- source_language="en",
- target_language=request.user_language,
- glossary_url=request.output_glossary_document,
- )
- state["raw_tts_text"] = [
+ citation_style = (
+ request.citation_style and CitationStyles[request.citation_style]
+ ) or None
+ all_refs_list = []
+ for i, output_text in enumerate(chunks):
+ if request.tools:
+ output_text, tool_call_choices = output_text
+ state["output_documents"] = output_documents = []
+ for tool_calls in tool_call_choices:
+ for call in tool_calls:
+ result = yield from exec_tool_call(call)
+ output_documents.append(result)
+
+ # save model response
+ state["raw_output_text"] = [
"".join(snippet for snippet, _ in parse_refs(text, references))
for text in output_text
]
- if references:
- citation_style = (
- request.citation_style and CitationStyles[request.citation_style]
- ) or None
- apply_response_template(output_text, references, citation_style)
-
- state["output_text"] = output_text
+ # translate response text
+ if request.user_language and request.user_language != "en":
+ yield f"Translating response to {request.user_language}..."
+ output_text = run_google_translate(
+ texts=output_text,
+ source_language="en",
+ target_language=request.user_language,
+ glossary_url=request.output_glossary_document,
+ )
+ state["raw_tts_text"] = [
+ "".join(snippet for snippet, _ in parse_refs(text, references))
+ for text in output_text
+ ]
+
+ if references:
+ all_refs_list = apply_response_formattings_prefix(
+ output_text, references, citation_style
+ )
+ state["output_text"] = output_text
+ yield f"Streaming{str(i + 1).translate(SUPERSCRIPT)} {model.value}..."
+ if all_refs_list:
+ apply_response_formattings_suffix(
+ all_refs_list, state["output_text"], citation_style
+ )
state["output_audio"] = []
state["output_video"] = []
diff --git a/url_shortener/admin.py b/url_shortener/admin.py
index f4d6bb4c0..e7f102a7a 100644
--- a/url_shortener/admin.py
+++ b/url_shortener/admin.py
@@ -137,3 +137,6 @@ class VisitorClickInfoAdmin(admin.ModelAdmin):
ordering = ["-created_at"]
autocomplete_fields = ["shortened_url"]
actions = [export_to_csv, export_to_excel]
+ readonly_fields = [
+ "created_at",
+ ]
diff --git a/url_shortener/routers.py b/url_shortener/routers.py
index fba39ffa9..1b3991cf8 100644
--- a/url_shortener/routers.py
+++ b/url_shortener/routers.py
@@ -21,10 +21,11 @@ def url_shortener(hashid: str, request: Request):
return Response(status_code=410, content="This link has expired")
# increment the click count
ShortenedURL.objects.filter(id=surl.id).update(clicks=F("clicks") + 1)
- if surl.enable_analytics:
- save_click_info.delay(
- surl.id, request.client.host, request.headers.get("user-agent", "")
- )
+ # disabled because iplist.cc is down
+ # if surl.enable_analytics:
+ # save_click_info.delay(
+ # surl.id, request.client.host, request.headers.get("user-agent", "")
+ # )
if surl.url:
return RedirectResponse(
url=surl.url, status_code=303 # because youtu.be redirects are 303