diff --git a/README.md b/README.md index 89b0c5c34..322e83718 100644 --- a/README.md +++ b/README.md @@ -108,6 +108,12 @@ ngrok http 8080 5. Copy the temporary access token there and set env var `WHATSAPP_ACCESS_TOKEN = XXXX` +**(Optional) Use the test script to send yourself messages** + +```bash +python manage.py runscript test_wa_msg_send --script-args 104696745926402 +918764022384 +``` +Replace `+918764022384` with your number and `104696745926402` with the test number ID ## Dangerous postgres commands @@ -145,6 +151,9 @@ pg_restore --no-privileges --no-owner -d $PGDATABASE $fname cid=$(docker ps | grep gooey-api-prod | cut -d " " -f 1 | head -1) # exec the script to create the fixture docker exec -it $cid poetry run ./manage.py runscript create_fixture +``` + +```bash # copy the fixture outside container docker cp $cid:/app/fixture.json . # print the absolute path @@ -178,3 +187,4 @@ rsync -P -a @captain.us-1.gooey.ai:/home//fixture.json . createdb -T template0 $PGDATABASE pg_dump $SOURCE_DATABASE | psql -q $PGDATABASE ``` + diff --git a/app_users/admin.py b/app_users/admin.py index 191197eb6..ffe8ccdc6 100644 --- a/app_users/admin.py +++ b/app_users/admin.py @@ -41,6 +41,7 @@ class AppUserAdmin(admin.ModelAdmin): "view_transactions", "open_in_firebase", "open_in_stripe", + "low_balance_email_sent_at", ] @admin.display(description="User Runs") diff --git a/app_users/migrations/0012_appuser_low_balance_email_sent_at.py b/app_users/migrations/0012_appuser_low_balance_email_sent_at.py new file mode 100644 index 000000000..efc2beaf0 --- /dev/null +++ b/app_users/migrations/0012_appuser_low_balance_email_sent_at.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.7 on 2024-02-14 07:23 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app_users', '0011_appusertransaction_charged_amount_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='appuser', + name='low_balance_email_sent_at', + field=models.DateTimeField(blank=True, null=True), + ), + ] diff --git a/app_users/migrations/0013_appusertransaction_app_users_a_user_id_9b2e8d_idx_and_more.py b/app_users/migrations/0013_appusertransaction_app_users_a_user_id_9b2e8d_idx_and_more.py new file mode 100644 index 000000000..b992887bb --- /dev/null +++ b/app_users/migrations/0013_appusertransaction_app_users_a_user_id_9b2e8d_idx_and_more.py @@ -0,0 +1,21 @@ +# Generated by Django 4.2.7 on 2024-02-28 14:16 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app_users', '0012_appuser_low_balance_email_sent_at'), + ] + + operations = [ + migrations.AddIndex( + model_name='appusertransaction', + index=models.Index(fields=['user', 'amount', '-created_at'], name='app_users_a_user_id_9b2e8d_idx'), + ), + migrations.AddIndex( + model_name='appusertransaction', + index=models.Index(fields=['-created_at'], name='app_users_a_created_3c27fe_idx'), + ), + ] diff --git a/app_users/models.py b/app_users/models.py index 576ea0390..9299fba47 100644 --- a/app_users/models.py +++ b/app_users/models.py @@ -89,6 +89,8 @@ class AppUser(models.Model): stripe_customer_id = models.CharField(max_length=255, default="", blank=True) is_paying = models.BooleanField("paid", default=False) + low_balance_email_sent_at = models.DateTimeField(null=True, blank=True) + created_at = models.DateTimeField( "created", editable=False, blank=True, default=timezone.now ) @@ -207,7 +209,11 @@ def search_stripe_customer(self) -> stripe.Customer | None: if 
not self.uid: return None if self.stripe_customer_id: - return stripe.Customer.retrieve(self.stripe_customer_id) + try: + return stripe.Customer.retrieve(self.stripe_customer_id) + except stripe.error.InvalidRequestError as e: + if e.http_status != 404: + raise try: customer = stripe.Customer.search( query=f'metadata["uid"]:"{self.uid}"' @@ -263,6 +269,10 @@ class AppUserTransaction(models.Model): class Meta: verbose_name = "Transaction" + indexes = [ + models.Index(fields=["user", "amount", "-created_at"]), + models.Index(fields=["-created_at"]), + ] def __str__(self): return f"{self.invoice_id} ({self.amount})" diff --git a/bots/admin.py b/bots/admin.py index 49321ba35..23d0a9dd5 100644 --- a/bots/admin.py +++ b/bots/admin.py @@ -5,7 +5,7 @@ from django import forms from django.conf import settings from django.contrib import admin -from django.db.models import Max, Count, F +from django.db.models import Max, Count, F, Sum from django.template import loader from django.utils import dateformat from django.utils.safestring import mark_safe @@ -28,8 +28,6 @@ WorkflowMetadata, ) from bots.tasks import create_personal_channels_for_all_members -from daras_ai.image_input import truncate_text_words -from daras_ai_v2.base import BasePage from gooeysite.custom_actions import export_to_excel, export_to_csv from gooeysite.custom_filters import ( related_json_field_summary, @@ -168,6 +166,7 @@ class BotIntegrationAdmin(admin.ModelAdmin): "Settings", { "fields": [ + "streaming_enabled", "show_feedback_buttons", "analysis_run", "view_analysis_results", @@ -265,10 +264,14 @@ class SavedRunAdmin(admin.ModelAdmin): "view_parent_published_run", "run_time", "price", + "is_api_call", "created_at", "updated_at", ] - list_filter = ["workflow"] + list_filter = [ + "workflow", + "is_api_call", + ] search_fields = ["workflow", "example_id", "run_id", "uid"] autocomplete_fields = ["parent_version"] @@ -277,10 +280,12 @@ class SavedRunAdmin(admin.ModelAdmin): "parent", "view_bots", "price", + "view_usage_cost", "transaction", "created_at", "updated_at", "run_time", + "is_api_call", ] actions = [export_to_csv, export_to_excel] @@ -289,16 +294,25 @@ class SavedRunAdmin(admin.ModelAdmin): django.db.models.JSONField: {"widget": JSONEditorWidget}, } + def get_queryset(self, request): + return ( + super() + .get_queryset(request) + .prefetch_related( + "parent_version", + "parent_version__published_run", + "parent_version__published_run__saved_run", + ) + ) + def lookup_allowed(self, key, value): if key in ["parent_version__published_run__id__exact"]: return True return super().lookup_allowed(key, value) def view_user(self, saved_run: SavedRun): - return change_obj_url( - AppUser.objects.get(uid=saved_run.uid), - label=f"{saved_run.uid}", - ) + user = AppUser.objects.get(uid=saved_run.uid) + return change_obj_url(user) view_user.short_description = "View User" @@ -312,6 +326,15 @@ def view_parent_published_run(self, saved_run: SavedRun): pr = saved_run.parent_published_run() return pr and change_obj_url(pr) + @admin.display(description="Usage Costs") + def view_usage_cost(self, saved_run: SavedRun): + total_cost = saved_run.usage_costs.aggregate(total_cost=Sum("dollar_amount"))[ + "total_cost" + ] + return list_related_html_url( + saved_run.usage_costs, extra_label=f"${total_cost.normalize()}" + ) + @admin.register(PublishedRunVersion) class PublishedRunVersionAdmin(admin.ModelAdmin): @@ -491,6 +514,7 @@ class MessageAdmin(admin.ModelAdmin): "prev_msg_content", "prev_msg_display_content", "prev_msg_saved_run", + 
"response_time", ] ordering = ["created_at"] actions = [export_to_csv, export_to_excel] @@ -550,6 +574,7 @@ def get_fieldsets(self, request, msg: Message = None): "Analysis", { "fields": [ + "response_time", "analysis_result", "analysis_run", "question_answered", diff --git a/bots/admin_links.py b/bots/admin_links.py index 94086e348..c06601953 100644 --- a/bots/admin_links.py +++ b/bots/admin_links.py @@ -35,6 +35,7 @@ def list_related_html_url( query_param: str = None, instance_id: int = None, show_add: bool = True, + extra_label: str = None, ) -> typing.Optional[str]: num = manager.all().count() @@ -60,6 +61,8 @@ def list_related_html_url( ).url label = f"{num} {meta.verbose_name if num == 1 else meta.verbose_name_plural}" + if extra_label: + label = f"{label} ({extra_label})" if show_add: add_related_url = furl( diff --git a/bots/migrations/0056_botintegration_streaming_enabled.py b/bots/migrations/0056_botintegration_streaming_enabled.py new file mode 100644 index 000000000..dcf1366fa --- /dev/null +++ b/bots/migrations/0056_botintegration_streaming_enabled.py @@ -0,0 +1,20 @@ +# Generated by Django 4.2.7 on 2024-01-31 19:14 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("bots", "0055_workflowmetadata"), + ] + + operations = [ + migrations.AddField( + model_name="botintegration", + name="streaming_enabled", + field=models.BooleanField( + default=False, + help_text="If set, the bot will stream messages to the frontend", + ), + ), + ] diff --git a/bots/migrations/0057_message_response_time_and_more.py b/bots/migrations/0057_message_response_time_and_more.py new file mode 100644 index 000000000..5bdecd39c --- /dev/null +++ b/bots/migrations/0057_message_response_time_and_more.py @@ -0,0 +1,23 @@ +# Generated by Django 4.2.7 on 2024-02-05 15:11 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('bots', '0056_botintegration_streaming_enabled'), + ] + + operations = [ + migrations.AddField( + model_name='message', + name='response_time', + field=models.DurationField(default=None, help_text='The time it took for the bot to respond to the corresponding user message', null=True), + ), + migrations.AlterField( + model_name='botintegration', + name='streaming_enabled', + field=models.BooleanField(default=False, help_text='If set, the bot will stream messages to the frontend (Slack only)'), + ), + ] diff --git a/bots/migrations/0058_alter_savedrun_unique_together_and_more.py b/bots/migrations/0058_alter_savedrun_unique_together_and_more.py new file mode 100644 index 000000000..a4e137127 --- /dev/null +++ b/bots/migrations/0058_alter_savedrun_unique_together_and_more.py @@ -0,0 +1,21 @@ +# Generated by Django 4.2.7 on 2024-02-06 18:24 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('bots', '0057_message_response_time_and_more'), + ] + + operations = [ + migrations.AlterUniqueTogether( + name='savedrun', + unique_together={('run_id', 'uid'), ('workflow', 'example_id')}, + ), + migrations.AddIndex( + model_name='savedrun', + index=models.Index(fields=['run_id', 'uid'], name='bots_savedr_run_id_7b0b34_idx'), + ), + ] diff --git a/bots/migrations/0059_savedrun_is_api_call.py b/bots/migrations/0059_savedrun_is_api_call.py new file mode 100644 index 000000000..dc93057fc --- /dev/null +++ b/bots/migrations/0059_savedrun_is_api_call.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.7 on 2024-02-12 07:54 + +from django.db 
import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('bots', '0058_alter_savedrun_unique_together_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='savedrun', + name='is_api_call', + field=models.BooleanField(default=False), + ), + ] diff --git a/bots/migrations/0060_conversation_reset_at.py b/bots/migrations/0060_conversation_reset_at.py new file mode 100644 index 000000000..10cd847b6 --- /dev/null +++ b/bots/migrations/0060_conversation_reset_at.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.7 on 2024-02-20 16:49 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('bots', '0059_savedrun_is_api_call'), + ] + + operations = [ + migrations.AddField( + model_name='conversation', + name='reset_at', + field=models.DateTimeField(blank=True, default=None, null=True), + ), + ] diff --git a/bots/models.py b/bots/models.py index fd6071f58..534ccdd7a 100644 --- a/bots/models.py +++ b/bots/models.py @@ -254,13 +254,15 @@ class SavedRun(models.Model): page_title = models.TextField(default="", blank=True, help_text="(Deprecated)") page_notes = models.TextField(default="", blank=True, help_text="(Deprecated)") + is_api_call = models.BooleanField(default=False) + objects = SavedRunQuerySet.as_manager() class Meta: ordering = ["-updated_at"] unique_together = [ ["workflow", "example_id"], - ["workflow", "run_id", "uid"], + ["run_id", "uid"], ] constraints = [ models.CheckConstraint( @@ -273,6 +275,7 @@ class Meta: models.Index(fields=["-created_at"]), models.Index(fields=["-updated_at"]), models.Index(fields=["workflow"]), + models.Index(fields=["run_id", "uid"]), models.Index(fields=["workflow", "run_id", "uid"]), models.Index(fields=["workflow", "example_id", "run_id", "uid"]), models.Index(fields=["workflow", "example_id", "hidden"]), @@ -571,6 +574,11 @@ class BotIntegration(models.Model): help_text="If provided, the message content will be analyzed for this bot using this saved run", ) + streaming_enabled = models.BooleanField( + default=False, + help_text="If set, the bot will stream messages to the frontend (Slack only)", + ) + created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) @@ -694,8 +702,8 @@ def to_df_format( .replace(tzinfo=None) ) row |= { - "Last Sent": last_time.strftime("%b %d, %Y %I:%M %p"), - "First Sent": first_time.strftime("%b %d, %Y %I:%M %p"), + "Last Sent": last_time.strftime(settings.SHORT_DATETIME_FORMAT), + "First Sent": first_time.strftime(settings.SHORT_DATETIME_FORMAT), "A7": not convo.d7(), "A30": not convo.d30(), "R1": last_time - first_time < datetime.timedelta(days=1), @@ -712,7 +720,26 @@ def to_df_format( "Bot": str(convo.bot_integration), } rows.append(row) - df = pd.DataFrame.from_records(rows) + df = pd.DataFrame.from_records( + rows, + columns=[ + "Name", + "Messages", + "Correct Answers", + "Thumbs up", + "Thumbs down", + "Last Sent", + "First Sent", + "A7", + "A30", + "R1", + "R7", + "R30", + "Delta Hours", + "Created At", + "Bot", + ], + ) return df @@ -799,6 +826,7 @@ class Conversation(models.Model): ) created_at = models.DateTimeField(auto_now_add=True) + reset_at = models.DateTimeField(null=True, blank=True, default=None) objects = ConversationQuerySet.as_manager() @@ -901,14 +929,30 @@ def to_df_format( "Message (EN)": message.content, "Sent": message.created_at.astimezone(tz) .replace(tzinfo=None) - .strftime("%b %d, %Y %I:%M %p"), - "Feedback": 
message.feedbacks.first().get_display_text() - if message.feedbacks.first() - else None, # only show first feedback as per Sean's request + .strftime(settings.SHORT_DATETIME_FORMAT), + "Feedback": ( + message.feedbacks.first().get_display_text() + if message.feedbacks.first() + else None + ), # only show first feedback as per Sean's request "Analysis JSON": message.analysis_result, + "Run Time": ( + message.saved_run.run_time if message.saved_run else 0 + ), # user messages have no run/run_time } rows.append(row) - df = pd.DataFrame.from_records(rows) + df = pd.DataFrame.from_records( + rows, + columns=[ + "Name", + "Role", + "Message (EN)", + "Sent", + "Feedback", + "Analysis JSON", + "Run Time", + ], + ) return df def to_df_analysis_format( @@ -916,24 +960,31 @@ def to_df_analysis_format( ) -> "pd.DataFrame": import pandas as pd - qs = self.filter(role=CHATML_ROLE_USER).prefetch_related("feedbacks") + qs = self.filter(role=CHATML_ROLE_ASSISSTANT).prefetch_related("feedbacks") rows = [] for message in qs[:row_limit]: message: Message row = { "Name": message.conversation.get_display_name(), - "Question (EN)": message.content, - "Answer (EN)": message.get_next_by_created_at().content, + "Question (EN)": message.get_previous_by_created_at().content, + "Answer (EN)": message.content, "Sent": message.created_at.astimezone(tz) .replace(tzinfo=None) - .strftime("%b %d, %Y %I:%M %p"), + .strftime(settings.SHORT_DATETIME_FORMAT), "Analysis JSON": message.analysis_result, } rows.append(row) - df = pd.DataFrame.from_records(rows) + df = pd.DataFrame.from_records( + rows, + columns=["Name", "Question (EN)", "Answer (EN)", "Sent", "Analysis JSON"], + ) return df - def as_llm_context(self, limit: int = 100) -> list["ConversationEntry"]: + def as_llm_context( + self, limit: int = 50, reset_at: datetime.datetime = None + ) -> list["ConversationEntry"]: + if reset_at: + self = self.filter(created_at__gt=reset_at) msgs = self.order_by("-created_at").prefetch_related("attachments")[:limit] entries = [None] * len(msgs) for i, msg in enumerate(reversed(msgs)): @@ -1012,6 +1063,12 @@ class Message(models.Model): help_text="Subject of given question (DEPRECATED)", ) + response_time = models.DurationField( + default=None, + null=True, + help_text="The time it took for the bot to respond to the corresponding user message", + ) + _analysis_started = False objects = MessageQuerySet.as_manager() @@ -1102,20 +1159,33 @@ def to_df_format( "Question Sent": feedback.message.get_previous_by_created_at() .created_at.astimezone(tz) .replace(tzinfo=None) - .strftime("%b %d, %Y %I:%M %p"), + .strftime(settings.SHORT_DATETIME_FORMAT), "Answer (EN)": feedback.message.content, "Answer Sent": feedback.message.created_at.astimezone(tz) .replace(tzinfo=None) - .strftime("%b %d, %Y %I:%M %p"), + .strftime(settings.SHORT_DATETIME_FORMAT), "Rating": Feedback.Rating(feedback.rating).label, "Feedback (EN)": feedback.text_english, "Feedback Sent": feedback.created_at.astimezone(tz) .replace(tzinfo=None) - .strftime("%b %d, %Y %I:%M %p"), + .strftime(settings.SHORT_DATETIME_FORMAT), "Question Answered": feedback.message.question_answered, } rows.append(row) - df = pd.DataFrame.from_records(rows) + df = pd.DataFrame.from_records( + rows, + columns=[ + "Name", + "Question (EN)", + "Question Sent", + "Answer (EN)", + "Answer Sent", + "Rating", + "Feedback (EN)", + "Feedback Sent", + "Question Answered", + ], + ) return df diff --git a/bots/tasks.py b/bots/tasks.py index 9afe2c343..4e00b4e78 100644 --- a/bots/tasks.py +++ b/bots/tasks.py @@ 
-1,4 +1,5 @@ import json +from json import JSONDecodeError from celery import shared_task from django.db.models import QuerySet @@ -19,6 +20,7 @@ SlackBot, ) from daras_ai_v2.vector_search import references_as_prompt +from gooeysite.bg_db_conn import get_celery_result_db_safe from recipes.VideoBots import ReplyButton @@ -57,15 +59,22 @@ def msg_analysis(msg_id: int): Message.objects.filter(id=msg_id).update(analysis_run=sr) # wait for the result - result.get(disable_sync_subtasks=False) + get_celery_result_db_safe(result) sr.refresh_from_db() # if failed, raise error if sr.error_msg: raise RuntimeError(sr.error_msg) # save the result as json + output_text = flatten(sr.state["output_text"].values())[0] + try: + analysis_result = json.loads(output_text) + except JSONDecodeError: + analysis_result = { + "error": "Failed to parse the analysis result. Please check your script.", + } Message.objects.filter(id=msg_id).update( - analysis_result=json.loads(flatten(sr.state["output_text"].values())[0]), + analysis_result=analysis_result, ) @@ -128,7 +137,7 @@ def send_broadcast_msg( channel_is_personal=convo.slack_channel_is_personal, username=bi.name, token=bi.slack_access_token, - ) + )[0] case _: raise NotImplementedError( f"Platform {bi.platform} doesn't support broadcasts yet" diff --git a/bots/tests.py b/bots/tests.py index 21eef78f8..30105e648 100644 --- a/bots/tests.py +++ b/bots/tests.py @@ -92,3 +92,66 @@ def test_create_bot_integration_conversation_message(transactional_db): assert message_b.role == CHATML_ROLE_ASSISSTANT assert message_b.content == "Red, green, and yellow grow the best." assert message_b.display_content == "Red, green, and yellow grow the best." + + +def test_stats_get_tabular_data_invalid_sorting_options(transactional_db): + from recipes.VideoBotsStats import VideoBotsStatsPage + + page = VideoBotsStatsPage() + + # setup + run_url = "https://my_run_url" + bi = BotIntegration.objects.create( + name="My Bot Integration", + saved_run=None, + billing_account_uid="fdnacsFSBQNKVW8z6tzhBLHKpAm1", # digital green's account id + user_language="en", + show_feedback_buttons=True, + platform=Platform.WHATSAPP, + wa_phone_number="my_whatsapp_number", + wa_phone_number_id="my_whatsapp_number_id", + ) + convos = Conversation.objects.filter(bot_integration=bi) + msgs = Message.objects.filter(conversation__in=convos) + + # valid option but no data + df = page.get_tabular_data( + bi, run_url, convos, msgs, "Answered Successfully", "Name" + ) + assert df.shape[0] == 0 + assert "Name" in df.columns + + # valid option and data + convo = Conversation.objects.create( + bot_integration=bi, + state=ConvoState.INITIAL, + wa_phone_number="+919876543210", + ) + Message.objects.create( + conversation=convo, + role=CHATML_ROLE_USER, + content="What types of chilies can be grown in Mumbai?", + display_content="What types of chilies can be grown in Mumbai?", + ) + Message.objects.create( + conversation=convo, + role=CHATML_ROLE_ASSISSTANT, + content="Red, green, and yellow grow the best.", + display_content="Red, green, and yellow grow the best.", + analysis_result={"Answered": True}, + ) + convos = Conversation.objects.filter(bot_integration=bi) + msgs = Message.objects.filter(conversation__in=convos) + assert msgs.count() == 2 + df = page.get_tabular_data( + bi, run_url, convos, msgs, "Answered Successfully", "Name" + ) + assert df.shape[0] == 1 + assert "Name" in df.columns + + # invalid sort option should be ignored + df = page.get_tabular_data( + bi, run_url, convos, msgs, "Answered 
Successfully", "Invalid" + ) + assert df.shape[0] == 1 + assert "Name" in df.columns diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 1868c7a18..a4a6ffe07 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -1,21 +1,30 @@ +import datetime +import html import traceback import typing from time import time from types import SimpleNamespace +import requests import sentry_sdk +from django.db.models import Sum +from django.utils import timezone +from fastapi import HTTPException import gooey_ui as st -from app_users.models import AppUser +from app_users.models import AppUser, AppUserTransaction from bots.models import SavedRun from celeryapp.celeryconfig import app from daras_ai.image_input import truncate_text_words from daras_ai_v2 import settings -from daras_ai_v2.base import StateKeys, err_msg_for_exc, BasePage +from daras_ai_v2.base import StateKeys, BasePage +from daras_ai_v2.exceptions import UserError from daras_ai_v2.send_email import send_email_via_postmark +from daras_ai_v2.send_email import send_low_balance_email from daras_ai_v2.settings import templates from gooey_ui.pubsub import realtime_push from gooey_ui.state import set_query_params +from gooeysite.bg_db_conn import db_middleware, next_db_safe @app.task @@ -31,7 +40,19 @@ def gui_runner( is_api_call: bool = False, ): page = page_cls(request=SimpleNamespace(user=AppUser.objects.get(id=user_id))) + + def event_processor(event, hint): + event["request"] = { + "method": "POST", + "url": page.app_url(query_params=query_params), + "data": state, + } + return event + + page.setup_sentry(event_processor=event_processor) + sr = page.run_doc_sr(run_id, uid) + sr.is_api_call = is_api_call st.set_session_state(state) run_time = 0 @@ -39,6 +60,7 @@ def gui_runner( error_msg = None set_query_params(query_params or {}) + @db_middleware def save(done=False): if done: # clear run status @@ -80,7 +102,7 @@ def save(done=False): start_time = time() try: # advance the generator (to further progress of run()) - yield_val = next(gen) + yield_val = next_db_safe(gen) # increment total time taken after every iteration run_time += time() - start_time continue @@ -92,8 +114,25 @@ def save(done=False): # render errors nicely except Exception as e: run_time += time() - start_time - traceback.print_exc() - sentry_sdk.capture_exception(e) + + if isinstance(e, HTTPException) and e.status_code == 402: + error_msg = page.generate_credit_error_message( + example_id=query_params.get("example_id"), + run_id=run_id, + uid=uid, + ) + try: + raise UserError(error_msg) from e + except UserError as e: + sentry_sdk.capture_exception(e, level=e.sentry_level) + break + + if isinstance(e, UserError): + sentry_level = e.sentry_level + else: + sentry_level = "error" + traceback.print_exc() + sentry_sdk.capture_exception(e, level=sentry_level) error_msg = err_msg_for_exc(e) break finally: @@ -102,6 +141,69 @@ def save(done=False): save(done=True) if not is_api_call: send_email_on_completion(page, sr) + run_low_balance_email_check(uid) + + +def err_msg_for_exc(e: Exception): + if isinstance(e, requests.HTTPError): + response: requests.Response = e.response + try: + err_body = response.json() + except requests.JSONDecodeError: + err_str = response.text + else: + format_exc = err_body.get("format_exc") + if format_exc: + print("⚡️ " + format_exc) + err_type = err_body.get("type") + err_str = err_body.get("str") + if err_type and err_str: + return f"(GPU) {err_type}: {err_str}" + err_str = str(err_body) + return f"(HTTP {response.status_code}) 
{html.escape(err_str[:1000])}" + elif isinstance(e, HTTPException): + return f"(HTTP {e.status_code}) {e.detail})" + elif isinstance(e, UserError): + return e.message + else: + return f"{type(e).__name__}: {e}" + + +def run_low_balance_email_check(uid: str): + # don't send email if feature is disabled + if not settings.LOW_BALANCE_EMAIL_ENABLED: + return + user = AppUser.objects.get(uid=uid) + # don't send email if user is not paying or has enough balance + if not user.is_paying or user.balance > settings.LOW_BALANCE_EMAIL_CREDITS: + return + last_purchase = ( + AppUserTransaction.objects.filter(user=user, amount__gt=0) + .order_by("-created_at") + .first() + ) + email_date_cutoff = timezone.now() - datetime.timedelta( + days=settings.LOW_BALANCE_EMAIL_DAYS + ) + # send email if user has not been sent email in last X days or last purchase was after last email sent + if ( + # user has not been sent any email + not user.low_balance_email_sent_at + # user was sent email before X days + or (user.low_balance_email_sent_at < email_date_cutoff) + # user has made a purchase after last email sent + or (last_purchase and last_purchase.created_at > user.low_balance_email_sent_at) + ): + # calculate total credits consumed in last X days + total_credits_consumed = abs( + AppUserTransaction.objects.filter( + user=user, amount__lt=0, created_at__gte=email_date_cutoff + ).aggregate(Sum("amount"))["amount__sum"] + or 0 + ) + send_low_balance_email(user=user, total_credits_consumed=total_credits_consumed) + user.low_balance_email_sent_at = timezone.now() + user.save(update_fields=["low_balance_email_sent_at"]) def send_email_on_completion(page: BasePage, sr: SavedRun): diff --git a/conftest.py b/conftest.py index e8534d5f2..f2003dc0c 100644 --- a/conftest.py +++ b/conftest.py @@ -9,6 +9,22 @@ from auth import auth_backend from celeryapp import app from daras_ai_v2.base import BasePage +from daras_ai_v2.send_email import pytest_outbox + + +def flaky(fn): + max_tries = 5 + + @wraps(fn) + def wrapper(*args, **kwargs): + for i in range(max_tries): + try: + return fn(*args, **kwargs) + except Exception: + if i == max_tries - 1: + raise + + return wrapper @pytest.fixture(scope="session") @@ -44,11 +60,12 @@ def _mock_gui_runner( @pytest.fixture -def threadpool_subtest(subtests, max_workers: int = 8): +def threadpool_subtest(subtests, max_workers: int = 128): ts = [] - def submit(fn, *args, **kwargs): - msg = "--".join(map(str, [*args, *kwargs.values()])) + def submit(fn, *args, msg=None, **kwargs): + if not msg: + msg = "--".join(map(str, [*args, *kwargs.values()])) @wraps(fn) def runner(*args, **kwargs): @@ -67,6 +84,11 @@ def runner(*args, **kwargs): t.join() +@pytest.fixture(autouse=True) +def clear_pytest_outbox(): + pytest_outbox.clear() + + # class DummyDatabaseBlocker(pytest_django.plugin._DatabaseBlocker): # class _dj_db_wrapper: # def ensure_connection(self): diff --git a/daras_ai/extract_face.py b/daras_ai/extract_face.py index e81b35d6e..aeb107a45 100644 --- a/daras_ai/extract_face.py +++ b/daras_ai/extract_face.py @@ -1,5 +1,7 @@ import numpy as np +from daras_ai_v2.exceptions import UserError + def extract_and_reposition_face_cv2( orig_img, @@ -118,7 +120,7 @@ def face_oval_hull_generator(image_cv2): results = face_mesh.process(cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)) if not results.multi_face_landmarks: - raise ValueError("Face not found") + raise UserError("Face not found") for landmark_list in results.multi_face_landmarks: idx_to_coordinates = build_idx_to_coordinates_dict( diff --git 
a/daras_ai/image_input.py b/daras_ai/image_input.py index 2a7a70417..5e61fcba9 100644 --- a/daras_ai/image_input.py +++ b/daras_ai/image_input.py @@ -11,6 +11,7 @@ from furl import furl from daras_ai_v2 import settings +from daras_ai_v2.exceptions import UserError def resize_img_pad(img_bytes: bytes, size: tuple[int, int]) -> bytes: @@ -90,7 +91,7 @@ def bytes_to_cv2_img(img_bytes: bytes, greyscale=False) -> np.ndarray: flags = cv2.IMREAD_COLOR img_cv2 = cv2.imdecode(np.frombuffer(img_bytes, dtype=np.uint8), flags=flags) if not img_exists(img_cv2): - raise ValueError("Bad Image") + raise UserError("Bad Image") return img_cv2 @@ -112,7 +113,9 @@ def safe_filename(filename: str) -> str: return out -def truncate_filename(text: str, maxlen: int = 100, sep: str = "...") -> str: +def truncate_filename( + text: str | bytes, maxlen: int = 100, sep: str | bytes = "..." +) -> str | bytes: if len(text) <= maxlen: return text assert len(sep) <= maxlen diff --git a/daras_ai_v2/asr.py b/daras_ai_v2/asr.py index 81562a49f..4403dd32b 100644 --- a/daras_ai_v2/asr.py +++ b/daras_ai_v2/asr.py @@ -1,9 +1,7 @@ -import json import os.path -import subprocess +import os.path import tempfile from enum import Enum -from time import sleep import langcodes import requests @@ -14,7 +12,14 @@ import gooey_ui as st from daras_ai.image_input import upload_file_from_bytes, gs_url_to_uri from daras_ai_v2 import settings -from daras_ai_v2.exceptions import raise_for_status +from daras_ai_v2.azure_asr import azure_asr +from daras_ai_v2.exceptions import ( + raise_for_status, + UserError, + ffmpeg, + call_cmd, + ffprobe, +) from daras_ai_v2.functional import map_parallel from daras_ai_v2.gdrive_downloader import ( is_gdrive_url, @@ -22,25 +27,75 @@ gdrive_metadata, url_to_gdrive_file_id, ) +from daras_ai_v2.google_asr import gcp_asr_v1 from daras_ai_v2.gpu_server import call_celery_task from daras_ai_v2.redis_cache import redis_cache_decorator SHORT_FILE_CUTOFF = 5 * 1024 * 1024 # 1 MB - TRANSLITERATION_SUPPORTED = {"ar", "bn", " gu", "hi", "ja", "kn", "ru", "ta", "te"} -# below CHIRP list was found experimentally since the supported languages list by google is actually wrong: -CHIRP_SUPPORTED = {"af-ZA", "sq-AL", "am-ET", "ar-EG", "hy-AM", "as-IN", "ast-ES", "az-AZ", "eu-ES", "be-BY", "bs-BA", "bg-BG", "my-MM", "ca-ES", "ceb-PH", "ckb-IQ", "zh-Hans-CN", "yue-Hant-HK", "hr-HR", "cs-CZ", "da-DK", "nl-NL", "en-AU", "en-IN", "en-GB", "en-US", "et-EE", "fil-PH", "fi-FI", "fr-CA", "fr-FR", "gl-ES", "ka-GE", "de-DE", "el-GR", "gu-IN", "ha-NG", "iw-IL", "hi-IN", "hu-HU", "is-IS", "id-ID", "it-IT", "ja-JP", "jv-ID", "kea-CV", "kam-KE", "kn-IN", "kk-KZ", "km-KH", "ko-KR", "ky-KG", "lo-LA", "lv-LV", "ln-CD", "lt-LT", "luo-KE", "lb-LU", "mk-MK", "ms-MY", "ml-IN", "mt-MT", "mi-NZ", "mr-IN", "mn-MN", "ne-NP", "ny-MW", "oc-FR", "ps-AF", "fa-IR", "pl-PL", "pt-BR", "pa-Guru-IN", "ro-RO", "ru-RU", "nso-ZA", "sr-RS", "sn-ZW", "sd-IN", "si-LK", "sk-SK", "sl-SI", "so-SO", "es-ES", "es-US", "su-ID", "sw", "sv-SE", "tg-TJ", "ta-IN", "te-IN", "th-TH", "tr-TR", "uk-UA", "ur-PK", "uz-UZ", "vi-VN", "cy-GB", "wo-SN", "yo-NG", "zu-ZA"} # fmt: skip - -WHISPER_SUPPORTED = {"af", "ar", "hy", "az", "be", "bs", "bg", "ca", "zh", "hr", "cs", "da", "nl", "en", "et", "fi", "fr", "gl", "de", "el", "he", "hi", "hu", "is", "id", "it", "ja", "kn", "kk", "ko", "lv", "lt", "mk", "ms", "mr", "mi", "ne", "no", "fa", "pl", "pt", "ro", "ru", "sr", "sk", "sl", "es", "sw", "sv", "tl", "ta", "th", "tr", "uk", "ur", "vi", "cy"} # fmt: skip +# 
https://cloud.google.com/speech-to-text/docs/speech-to-text-supported-languages +GCP_V1_SUPPORTED = { + "af-ZA", "sq-AL", "am-ET", "ar-DZ", "ar-BH", "ar-EG", "ar-IQ", "ar-IL", "ar-JO", "ar-KW", "ar-LB", "ar-MR", "ar-MA", + "ar-OM", "ar-QA", "ar-SA", "ar-PS", "ar-SY", "ar-TN", "ar-AE", "ar-YE", "hy-AM", "az-AZ", "eu-ES", "bn-BD", "bn-IN", + "bs-BA", "bg-BG", "my-MM", "ca-ES", "yue-Hant-HK", "zh", "zh-TW", "hr-HR", "cs-CZ", + "da-DK", "nl-BE", "nl-NL", "en-AU", "en-CA", "en-GH", "en-HK", "en-IN", "en-IE", "en-KE", "en-NZ", "en-NG", "en-PK", + "en-PH", "en-SG", "en-ZA", "en-TZ", "en-GB", "en-US", "et-EE", "fil-PH", "fi-FI", "fr-BE", "fr-CA", "fr-FR", + "fr-CH", "gl-ES", "ka-GE", "de-AT", "de-DE", "de-CH", "el-GR", "gu-IN", "iw-IL", "hi-IN", "hu-HU", "is-IS", "id-ID", + "it-IT", "it-CH", "ja-JP", "jv-ID", "kn-IN", "kk-KZ", "km-KH", "ko-KR", "lo-LA", "lv-LV", "lt-LT", "mk-MK", "ms-MY", + "ml-IN", "mr-IN", "mn-MN", "ne-NP", "no-NO", "fa-IR", "pl-PL", "pt-BR", "pt-PT", "pa-Guru-IN", "ro-RO", "ru-RU", + "sr-RS", "si-LK", "sk-SK", "sl-SI", "es-AR", "es-BO", "es-CL", "es-CO", "es-CR", "es-DO", "es-EC", "es-SV", "es-GT", + "es-HN", "es-MX", "es-NI", "es-PA", "es-PY", "es-PE", "es-PR", "es-ES", "es-US", "es-UY", "es-VE", "su-ID", "sw-KE", + "sw-TZ", "sv-SE", "ta-IN", "ta-MY", "ta-SG", "ta-LK", "te-IN", "th-TH", "tr-TR", "uk-UA", "ur-IN", "ur-PK", "uz-UZ", + "vi-VN", "zu-ZA", +} # fmt: skip + +# https://cloud.google.com/speech-to-text/v2/docs/speech-to-text-supported-languages +CHIRP_SUPPORTED = { + "af-ZA", "sq-AL", "am-ET", "ar-EG", "hy-AM", "as-IN", "ast-ES", "az-AZ", "eu-ES", "be-BY", "bs-BA", "bg-BG", + "my-MM", "ca-ES", "ceb-PH", "ckb-IQ", "yue-Hant-HK", "zh-TW", "hr-HR", "cs-CZ", "da-DK", "nl-NL", + "en-AU", "en-IN", "en-GB", "en-US", "et-EE", "fil-PH", "fi-FI", "fr-CA", "fr-FR", "gl-ES", "ka-GE", "de-DE", + "el-GR", "gu-IN", "ha-NG", "iw-IL", "hi-IN", "hu-HU", "is-IS", "id-ID", "it-IT", "ja-JP", "jv-ID", "kea-CV", + "kam-KE", "kn-IN", "kk-KZ", "km-KH", "ko-KR", "ky-KG", "lo-LA", "lv-LV", "ln-CD", "lt-LT", "luo-KE", "lb-LU", + "mk-MK", "ms-MY", "ml-IN", "mt-MT", "mi-NZ", "mr-IN", "mn-MN", "ne-NP", "no-NO", "ny-MW", "oc-FR", "ps-AF", "fa-IR", + "pl-PL", "pt-BR", "pa-Guru-IN", "ro-RO", "ru-RU", "nso-ZA", "sr-RS", "sn-ZW", "sd-IN", "si-LK", "sk-SK", "sl-SI", + "so-SO", "es-ES", "es-US", "su-ID", "sw", "sv-SE", "tg-TJ", "ta-IN", "te-IN", "th-TH", "tr-TR", "uk-UA", "ur-PK", + "uz-UZ", "vi-VN", "cy-GB", "wo-SN", "yo-NG", "zu-ZA" +} # fmt: skip + +WHISPER_SUPPORTED = { + "af", "ar", "hy", "az", "be", "bs", "bg", "ca", "zh", "hr", "cs", "da", "nl", "en", "et", "fi", "fr", "gl", "de", + "el", "he", "hi", "hu", "is", "id", "it", "ja", "kn", "kk", "ko", "lv", "lt", "mk", "ms", "mr", "mi", "ne", "no", + "fa", "pl", "pt", "ro", "ru", "sr", "sk", "sl", "es", "sw", "sv", "tl", "ta", "th", "tr", "uk", "ur", "vi", "cy" +} # fmt: skip # See page 14 of https://scontent-sea1-1.xx.fbcdn.net/v/t39.2365-6/369747868_602316515432698_2401716319310287708_n.pdf?_nc_cat=106&ccb=1-7&_nc_sid=3c67a6&_nc_ohc=_5cpNOcftdYAX8rCrVo&_nc_ht=scontent-sea1-1.xx&oh=00_AfDVkx7XubifELxmB_Un-yEYMJavBHFzPnvTbTlalbd_1Q&oe=65141B39 # For now, below are listed the languages that support ASR. Note that Seamless only accepts ISO 639-3 codes. 
-SEAMLESS_SUPPORTED = {"afr", "amh", "arb", "ary", "arz", "asm", "ast", "azj", "bel", "ben", "bos", "bul", "cat", "ceb", "ces", "ckb", "cmn", "cym", "dan", "deu", "ell", "eng", "est", "eus", "fin", "fra", "gaz", "gle", "glg", "guj", "heb", "hin", "hrv", "hun", "hye", "ibo", "ind", "isl", "ita", "jav", "jpn", "kam", "kan", "kat", "kaz", "kea", "khk", "khm", "kir", "kor", "lao", "lit", "ltz", "lug", "luo", "lvs", "mai", "mal", "mar", "mkd", "mlt", "mni", "mya", "nld", "nno", "nob", "npi", "nya", "oci", "ory", "pan", "pbt", "pes", "pol", "por", "ron", "rus", "slk", "slv", "sna", "snd", "som", "spa", "srp", "swe", "swh", "tam", "tel", "tgk", "tgl", "tha", "tur", "ukr", "urd", "uzn", "vie", "xho", "yor", "yue", "zlm", "zul"} # fmt: skip - -AZURE_SUPPORTED = {"af-ZA", "am-ET", "ar-AE", "ar-BH", "ar-DZ", "ar-EG", "ar-IL", "ar-IQ", "ar-JO", "ar-KW", "ar-LB", "ar-LY", "ar-MA", "ar-OM", "ar-PS", "ar-QA", "ar-SA", "ar-SY", "ar-TN", "ar-YE", "az-AZ", "bg-BG", "bn-IN", "bs-BA", "ca-ES", "cs-CZ", "cy-GB", "da-DK", "de-AT", "de-CH", "de-DE", "el-GR", "en-AU", "en-CA", "en-GB", "en-GH", "en-HK", "en-IE", "en-IN", "en-KE", "en-NG", "en-NZ", "en-PH", "en-SG", "en-TZ", "en-US", "en-ZA", "es-AR", "es-BO", "es-CL", "es-CO", "es-CR", "es-CU", "es-DO", "es-EC", "es-ES", "es-GQ", "es-GT", "es-HN", "es-MX", "es-NI", "es-PA", "es-PE", "es-PR", "es-PY", "es-SV", "es-US", "es-UY", "es-VE", "et-EE", "eu-ES", "fa-IR", "fi-FI", "fil-PH", "fr-BE", "fr-CA", "fr-CH", "fr-FR", "ga-IE", "gl-ES", "gu-IN", "he-IL", "hi-IN", "hr-HR", "hu-HU", "hy-AM", "id-ID", "is-IS", "it-CH", "it-IT", "ja-JP", "jv-ID", "ka-GE", "kk-KZ", "km-KH", "kn-IN", "ko-KR", "lo-LA", "lt-LT", "lv-LV", "mk-MK", "ml-IN", "mn-MN", "mr-IN", "ms-MY", "mt-MT", "my-MM", "nb-NO", "ne-NP", "nl-BE", "nl-NL", "pa-IN", "pl-PL", "ps-AF", "pt-BR", "pt-PT", "ro-RO", "ru-RU", "si-LK", "sk-SK", "sl-SI", "so-SO", "sq-AL", "sr-RS", "sv-SE", "sw-KE", "sw-TZ", "ta-IN", "te-IN", "th-TH", "tr-TR", "uk-UA", "ur-IN", "uz-UZ", "vi-VN", "wuu-CN", "yue-CN", "zh-CN", "zh-CN-shandong", "zh-CN-sichuan", "zh-HK", "zh-TW", "zu-ZA"} # fmt: skip -MAX_POLLS = 100 +SEAMLESS_SUPPORTED = { + "afr", "amh", "arb", "ary", "arz", "asm", "ast", "azj", "bel", "ben", "bos", "bul", "cat", "ceb", "ces", "ckb", + "cmn", "cym", "dan", "deu", "ell", "eng", "est", "eus", "fin", "fra", "gaz", "gle", "glg", "guj", "heb", "hin", + "hrv", "hun", "hye", "ibo", "ind", "isl", "ita", "jav", "jpn", "kam", "kan", "kat", "kaz", "kea", "khk", "khm", + "kir", "kor", "lao", "lit", "ltz", "lug", "luo", "lvs", "mai", "mal", "mar", "mkd", "mlt", "mni", "mya", "nld", + "nno", "nob", "npi", "nya", "oci", "ory", "pan", "pbt", "pes", "pol", "por", "ron", "rus", "slk", "slv", "sna", + "snd", "som", "spa", "srp", "swe", "swh", "tam", "tel", "tgk", "tgl", "tha", "tur", "ukr", "urd", "uzn", "vie", + "xho", "yor", "yue", "zlm", "zul" +} # fmt: skip + +AZURE_SUPPORTED = { + "af-ZA", "am-ET", "ar-AE", "ar-BH", "ar-DZ", "ar-EG", "ar-IL", "ar-IQ", "ar-JO", "ar-KW", "ar-LB", "ar-LY", "ar-MA", + "ar-OM", "ar-PS", "ar-QA", "ar-SA", "ar-SY", "ar-TN", "ar-YE", "az-AZ", "bg-BG", "bn-IN", "bs-BA", "ca-ES", "cs-CZ", + "cy-GB", "da-DK", "de-AT", "de-CH", "de-DE", "el-GR", "en-AU", "en-CA", "en-GB", "en-GH", "en-HK", "en-IE", "en-IN", + "en-KE", "en-NG", "en-NZ", "en-PH", "en-SG", "en-TZ", "en-US", "en-ZA", "es-AR", "es-BO", "es-CL", "es-CO", "es-CR", + "es-CU", "es-DO", "es-EC", "es-ES", "es-GQ", "es-GT", "es-HN", "es-MX", "es-NI", "es-PA", "es-PE", "es-PR", "es-PY", + "es-SV", "es-US", "es-UY", "es-VE", "et-EE", "eu-ES", "fa-IR", "fi-FI", 
"fil-PH", "fr-BE", "fr-CA", "fr-CH", + "fr-FR", "ga-IE", "gl-ES", "gu-IN", "he-IL", "hi-IN", "hr-HR", "hu-HU", "hy-AM", "id-ID", "is-IS", "it-CH", "it-IT", + "ja-JP", "jv-ID", "ka-GE", "kk-KZ", "km-KH", "kn-IN", "ko-KR", "lo-LA", "lt-LT", "lv-LV", "mk-MK", "ml-IN", "mn-MN", + "mr-IN", "ms-MY", "mt-MT", "my-MM", "nb-NO", "ne-NP", "nl-BE", "nl-NL", "pa-IN", "pl-PL", "ps-AF", "pt-BR", "pt-PT", + "ro-RO", "ru-RU", "si-LK", "sk-SK", "sl-SI", "so-SO", "sq-AL", "sr-RS", "sv-SE", "sw-KE", "sw-TZ", "ta-IN", "te-IN", + "th-TH", "tr-TR", "uk-UA", "ur-IN", "uz-UZ", "vi-VN", "wuu-CN", "yue-CN", "zh-CN", "zh-CN-shandong", + "zh-CN-sichuan", "zh-HK", "zh-TW", "zu-ZA" +} # fmt: skip # https://deepgram.com/product/languages for the "general" model: # DEEPGRAM_SUPPORTED = {"nl","en","en-AU","en-US","en-GB","en-NZ","en-IN","fr","fr-CA","de","hi","hi-Latn","id","it","ja","ko","cmn-Hans-CN","cmn-Hant-TW","no","pl","pt","pt-PT","pt-BR","ru","es","es-419","sv","tr","uk"} # fmt: skip @@ -56,15 +111,14 @@ class AsrModels(Enum): nemo_english = "Conformer English (ai4bharat.org)" nemo_hindi = "Conformer Hindi (ai4bharat.org)" vakyansh_bhojpuri = "Vakyansh Bhojpuri (Open-Speech-EkStep)" - usm = "Chirp / USM (Google)" + gcp_v1 = "Google Cloud V1" + usm = "Chirp / USM (Google V2)" deepgram = "Deepgram" azure = "Azure Speech" seamless_m4t = "Seamless M4T (Facebook Research)" def supports_auto_detect(self) -> bool: - return self not in { - self.azure, - } + return self not in {self.azure, self.gcp_v1} asr_model_ids = { @@ -89,6 +143,7 @@ def supports_auto_detect(self) -> bool: asr_supported_languages = { AsrModels.whisper_large_v3: WHISPER_SUPPORTED, AsrModels.whisper_large_v2: WHISPER_SUPPORTED, + AsrModels.gcp_v1: GCP_V1_SUPPORTED, AsrModels.usm: CHIRP_SUPPORTED, AsrModels.deepgram: DEEPGRAM_SUPPORTED, AsrModels.seamless_m4t: SEAMLESS_SUPPORTED, @@ -128,7 +183,7 @@ def google_translate_language_selector( label: the label to display key: the key to save the selected language to in the session state """ - languages = google_translate_languages() + languages = google_translate_target_languages() options = list(languages.keys()) if allow_none: options.insert(0, None) @@ -141,8 +196,8 @@ def google_translate_language_selector( ) -@redis_cache_decorator -def google_translate_languages() -> dict[str, str]: +@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY) +def google_translate_target_languages() -> dict[str, str]: """ Get list of supported languages for Google Translate. :return: Dictionary of language codes and display names. @@ -162,8 +217,8 @@ def google_translate_languages() -> dict[str, str]: } -@redis_cache_decorator -def google_translate_input_languages() -> dict[str, str]: +@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY) +def google_translate_source_languages() -> dict[str, str]: """ Get list of supported languages for Google Translate. :return: Dictionary of language codes and display names. 
@@ -246,11 +301,11 @@ def run_google_translate( if source_language: source_language = langcodes.Language.get(source_language).to_tag() source_language = get_language_in_collection( - source_language, google_translate_input_languages().keys() + source_language, google_translate_source_languages().keys() ) # this will default to autodetect if language is not found as supported target_language = langcodes.Language.get(target_language).to_tag() target_language: str | None = get_language_in_collection( - target_language, google_translate_languages().keys() + target_language, google_translate_target_languages().keys() ) if not target_language: raise ValueError(f"Unsupported target language: {target_language!r}") @@ -467,6 +522,8 @@ def run_asr( src_lang=language, ), ) + elif selected_model == AsrModels.gcp_v1: + return gcp_asr_v1(audio_url, language) elif selected_model == AsrModels.usm: location = settings.GCP_REGION @@ -575,7 +632,7 @@ def run_asr( assert data.get("chunks"), f"{selected_model.value} can't generate VTT" return generate_vtt(data["chunks"]) case _: - raise ValueError(f"Invalid output format: {output_format}") + raise UserError(f"Invalid output format: {output_format}") def _get_or_create_recognizer( @@ -611,64 +668,6 @@ def _get_or_create_recognizer( return recognizer -def azure_asr(audio_url: str, language: str): - # transcription from audio url only supported via rest api or cli - # Start by initializing a request - payload = { - "contentUrls": [ - audio_url, - ], - "displayName": "Gooey Transcription", - "model": None, - "properties": { - "wordLevelTimestampsEnabled": False, - }, - "locale": language or "en-US", - } - r = requests.post( - str(furl(settings.AZURE_SPEECH_ENDPOINT) / "speechtotext/v3.1/transcriptions"), - headers={ - "Ocp-Apim-Subscription-Key": settings.AZURE_SPEECH_KEY, - "Content-Type": "application/json", - }, - json=payload, - ) - raise_for_status(r) - uri = r.json()["self"] - - # poll for results - for _ in range(MAX_POLLS): - r = requests.get( - uri, - headers={ - "Ocp-Apim-Subscription-Key": settings.AZURE_SPEECH_KEY, - }, - ) - if not r.ok or not r.json()["status"] == "Succeeded": - sleep(5) - continue - r = requests.get( - r.json()["links"]["files"], - headers={ - "Ocp-Apim-Subscription-Key": settings.AZURE_SPEECH_KEY, - }, - ) - raise_for_status(r) - transcriptions = [] - for value in r.json()["values"]: - if value["kind"] != "Transcription": - continue - r = requests.get( - value["links"]["contentUrl"], - headers={"Ocp-Apim-Subscription-Key": settings.AZURE_SPEECH_KEY}, - ) - raise_for_status(r) - combined_phrases = r.json().get("combinedRecognizedPhrases") or [{}] - transcriptions += [combined_phrases[0].get("display", "")] - return "\n".join(transcriptions) - assert False, "Max polls exceeded, Azure speech did not yield a response" - - # 16kHz, 16-bit, mono FFMPEG_WAV_ARGS = ["-vn", "-acodec", "pcm_s16le", "-ac", "1", "-ar", "16000"] @@ -683,7 +682,7 @@ def download_youtube_to_wav(youtube_url: str) -> tuple[str, int]: infile = os.path.join(tmpdir, "infile") outfile = os.path.join(tmpdir, "outfile.wav") # run yt-dlp to download audio - args = [ + call_cmd( "yt-dlp", "--no-playlist", "--format", @@ -691,13 +690,9 @@ def download_youtube_to_wav(youtube_url: str) -> tuple[str, int]: "--output", infile, youtube_url, - ] - print("\t$ " + " ".join(args)) - subprocess.check_call(args) + ) # convert audio to single channel wav - args = ["ffmpeg", "-y", "-i", infile, *FFMPEG_WAV_ARGS, outfile] - print("\t$ " + " ".join(args)) - subprocess.check_call(args) + 
ffmpeg("-i", infile, *FFMPEG_WAV_ARGS, outfile) # read wav file into memory with open(outfile, "rb") as f: wavdata = f.read() @@ -728,43 +723,12 @@ def audio_bytes_to_wav(audio_bytes: bytes) -> tuple[bytes | None, int]: with tempfile.NamedTemporaryFile(suffix=".wav") as outfile: # convert audio to single channel wav - args = [ - "ffmpeg", - "-y", - "-i", - infile.name, - *FFMPEG_WAV_ARGS, - outfile.name, - ] - print("\t$ " + " ".join(args)) - try: - subprocess.check_output(args, stderr=subprocess.STDOUT) - except subprocess.CalledProcessError as e: - ffmpeg_output_error = ValueError(e.output, e) - raise ValueError( - "Invalid audio file. Could not convert audio to wav format. Please confirm the file is not corrupted and has a supported format (google 'ffmpeg supported audio file types')" - ) from ffmpeg_output_error + ffmpeg("-i", infile.name, *FFMPEG_WAV_ARGS, outfile.name) return outfile.read(), os.path.getsize(outfile.name) def check_wav_audio_format(filename: str) -> bool: - args = [ - "ffprobe", - "-v", - "quiet", - "-print_format", - "json", - "-show_streams", - filename, - ] - print("\t$ " + " ".join(args)) - try: - data = json.loads(subprocess.check_output(args, stderr=subprocess.STDOUT)) - except subprocess.CalledProcessError as e: - ffmpeg_output_error = ValueError(e.output, e) - raise ValueError( - "Invalid audio file. Please confirm the file is not corrupted and has a supported format (google 'ffmpeg supported audio file types')" - ) from ffmpeg_output_error + data = ffprobe(filename) return ( len(data["streams"]) == 1 and data["streams"][0]["codec_name"] == "pcm_s16le" diff --git a/daras_ai_v2/azure_asr.py b/daras_ai_v2/azure_asr.py new file mode 100644 index 000000000..aed873b03 --- /dev/null +++ b/daras_ai_v2/azure_asr.py @@ -0,0 +1,96 @@ +import datetime +from time import sleep + +import requests +from furl import furl + +from daras_ai_v2 import settings +from daras_ai_v2.exceptions import ( + raise_for_status, +) +from daras_ai_v2.redis_cache import redis_cache_decorator + +# 20 mins timeout +MAX_POLLS = 200 +POLL_INTERVAL = 6 + + +def azure_asr(audio_url: str, language: str): + # Start by initializing a request + # https://eastus.dev.cognitive.microsoft.com/docs/services/speech-to-text-api-v3-1/operations/Transcriptions_Create + language = language or "en-US" + r = requests.post( + str(furl(settings.AZURE_SPEECH_ENDPOINT) / "speechtotext/v3.1/transcriptions"), + headers=azure_auth_header(), + json={ + "contentUrls": [audio_url], + "displayName": f"Gooey Transcription {datetime.datetime.now().isoformat()} {language=} {audio_url=}", + "model": azure_get_latest_model(language), + "properties": { + "wordLevelTimestampsEnabled": False, + # "displayFormWordLevelTimestampsEnabled": True, + # "diarizationEnabled": False, + # "punctuationMode": "DictatedAndAutomatic", + # "profanityFilterMode": "Masked", + }, + "locale": language, + }, + ) + raise_for_status(r) + uri = r.json()["self"] + + # poll for results + for _ in range(MAX_POLLS): + r = requests.get(uri, headers=azure_auth_header()) + if not r.ok or not r.json()["status"] == "Succeeded": + sleep(POLL_INTERVAL) + continue + r = requests.get(r.json()["links"]["files"], headers=azure_auth_header()) + raise_for_status(r) + transcriptions = [] + for value in r.json()["values"]: + if value["kind"] != "Transcription": + continue + r = requests.get(value["links"]["contentUrl"], headers=azure_auth_header()) + raise_for_status(r) + combined_phrases = r.json().get("combinedRecognizedPhrases") or [{}] + transcriptions += 
[combined_phrases[0].get("display", "")] + return "\n".join(transcriptions) + + raise RuntimeError("Max polls exceeded, Azure speech did not yield a response") + + +@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY) +def azure_get_latest_model(language: str) -> dict | None: + # https://eastus.dev.cognitive.microsoft.com/docs/services/speech-to-text-api-v3-1/operations/Models_ListBaseModels + r = requests.get( + str(furl(settings.AZURE_SPEECH_ENDPOINT) / "speechtotext/v3.1/models/base"), + headers=azure_auth_header(), + params={"filter": f"locale eq '{language}'"}, + ) + raise_for_status(r) + data = r.json()["values"] + try: + models = sorted( + data, + key=lambda m: datetime.datetime.strptime( + m["createdDateTime"], "%Y-%m-%dT%H:%M:%SZ" + ), + reverse=True, + ) + # ignore date parsing errors + except ValueError: + models = data + models.reverse() + for model in models: + if "whisper" in model["displayName"].lower(): + # whisper is pretty slow on azure, so we ignore it + continue + # return the latest model + return {"self": model["self"]} + + +def azure_auth_header(): + return { + "Ocp-Apim-Subscription-Key": settings.AZURE_SPEECH_KEY, + } diff --git a/daras_ai_v2/azure_doc_extract.py b/daras_ai_v2/azure_doc_extract.py index b14179dba..878dc5733 100644 --- a/daras_ai_v2/azure_doc_extract.py +++ b/daras_ai_v2/azure_doc_extract.py @@ -16,15 +16,17 @@ auth_headers = {"Ocp-Apim-Subscription-Key": settings.AZURE_FORM_RECOGNIZER_KEY} -def azure_doc_extract_pages(pdf_url: str, model_id: str = "prebuilt-layout"): - result = azure_form_recognizer(pdf_url, model_id) +def azure_doc_extract_pages( + pdf_url: str, model_id: str = "prebuilt-layout", params: dict = None +): + result = azure_form_recognizer(pdf_url, model_id, params) return [ records_to_text(extract_records(result, page["pageNumber"])) for page in result["pages"] ] -@redis_cache_decorator +@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY) def azure_form_recognizer_models() -> dict[str, str]: r = requests.get( str( @@ -38,14 +40,14 @@ def azure_form_recognizer_models() -> dict[str, str]: return {value["modelId"]: value["description"] for value in r.json()["value"]} -@redis_cache_decorator -def azure_form_recognizer(url: str, model_id: str): +@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY) +def azure_form_recognizer(url: str, model_id: str, params: dict = None): r = requests.post( str( furl(settings.AZURE_FORM_RECOGNIZER_ENDPOINT) / f"formrecognizer/documentModels/{model_id}:analyze" ), - params={"api-version": "2023-07-31"}, + params={"api-version": "2023-07-31"} | (params or {}), headers=auth_headers, json={"urlSource": url}, ) @@ -67,7 +69,7 @@ def azure_form_recognizer(url: str, model_id: str): def extract_records(result: dict, page_num: int) -> list[dict]: table_polys = extract_tables(result, page_num) records = [] - for para in result["paragraphs"]: + for para in result.get("paragraphs", []): try: if para["boundingRegions"][0]["pageNumber"] != page_num: continue diff --git a/daras_ai_v2/azure_image_moderation.py b/daras_ai_v2/azure_image_moderation.py index 30da7a561..c9afc871a 100644 --- a/daras_ai_v2/azure_image_moderation.py +++ b/daras_ai_v2/azure_image_moderation.py @@ -1,7 +1,5 @@ -from typing import Any - -from furl import furl import requests +from furl import furl from daras_ai_v2 import settings from daras_ai_v2.exceptions import raise_for_status @@ -11,7 +9,7 @@ def get_auth_headers(): return {"Ocp-Apim-Subscription-Key": settings.AZURE_IMAGE_MODERATION_KEY} -def 
run_moderator(image_url: str, cache: bool) -> dict[str, Any]: +def is_image_nsfw(image_url: str, cache: bool = False) -> bool: url = str( furl(settings.AZURE_IMAGE_MODERATION_ENDPOINT) / "contentmoderator/moderate/v1.0/ProcessImage/Evaluate" @@ -22,10 +20,9 @@ def run_moderator(image_url: str, cache: bool) -> dict[str, Any]: headers=get_auth_headers(), json={"DataRepresentation": "URL", "Value": image_url}, ) + if r.status_code == 400 and ( + b"Image Size Error" in r.content or b"Image Error" in r.content + ): + return False raise_for_status(r) - return r.json() - - -def is_image_nsfw(image_url: str, cache: bool = False) -> bool: - response = run_moderator(image_url=image_url, cache=cache) - return response["IsImageAdultClassified"] + return r.json().get("IsImageAdultClassified", False) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index b65a65dae..cae24bd49 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -13,7 +13,6 @@ from time import sleep from types import SimpleNamespace -import requests import sentry_sdk from django.utils import timezone from django.utils.text import slugify @@ -180,24 +179,52 @@ def get_tab_url(self, tab: str) -> str: tab_name=MenuTabs.paths[tab], ) - def setup_render(self): + def setup_sentry(self, event_processor: typing.Callable = None): + def add_user_to_event(event, hint): + user = self.request and self.request.user + if not user: + return event + event["user"] = { + "id": user.id, + "name": user.display_name, + "email": user.email, + "data": { + field: getattr(user, field) + for field in [ + "uid", + "phone_number", + "photo_url", + "balance", + "is_paying", + "is_anonymous", + "is_disabled", + "disable_safety_checker", + "created_at", + ] + }, + } + return event + with sentry_sdk.configure_scope() as scope: scope.set_extra("base_url", self.app_url()) scope.set_transaction_name( "/" + self.slug_versions[0], source=TRANSACTION_SOURCE_ROUTE ) + scope.add_event_processor(add_user_to_event) + if event_processor: + scope.add_event_processor(event_processor) def refresh_state(self): _, run_id, uid = extract_query_params(gooey_get_query_params()) - channel = f"gooey-outputs/{self.slug_versions[0]}/{uid}/{run_id}" + channel = self.realtime_channel_name(run_id, uid) output = realtime_pull([channel])[0] if output: st.session_state.update(output) def render(self): - self.setup_render() + self.setup_sentry() - if self.get_run_state() == RecipeRunState.running: + if self.get_run_state(st.session_state) == RecipeRunState.running: self.refresh_state() else: realtime_clear_subs() @@ -243,7 +270,12 @@ def _render_header(self): if tbreadcrumbs: with st.tag("div", className="me-3 mb-1 mb-lg-0 py-2 py-lg-0"): - render_breadcrumbs(tbreadcrumbs) + render_breadcrumbs( + tbreadcrumbs, + is_api_call=( + current_run.is_api_call and self.tab == MenuTabs.run + ), + ) author = self.run_user or current_run.get_creator() if not is_root_example: @@ -316,7 +348,7 @@ def _render_social_buttons(self, show_button_text: bool = False): copy_to_clipboard_button( f'{button_text}', - value=self._get_current_app_url(), + value=self.get_tab_url(self.tab), type="secondary", className="mb-0 ms-lg-2", ) @@ -996,9 +1028,7 @@ def get_or_create_root_published_run(cls) -> PublishedRun: workflow=cls.workflow, published_run_id="", defaults={ - "saved_run": lambda: cls.run_doc_sr( - run_id="", uid="", create=True, parent=None, parent_version=None - ), + "saved_run": lambda: cls.run_doc_sr(run_id="", uid="", create=True), "created_by": None, "last_edited_by": None, "title": cls.title, @@ 
-1022,15 +1052,11 @@ def run_doc_sr( run_id: str, uid: str, create: bool = False, - parent: SavedRun | None = None, - parent_version: PublishedRunVersion | None = None, + defaults: dict = None, ) -> SavedRun: config = dict(workflow=cls.workflow, uid=uid, run_id=run_id) if create: - return SavedRun.objects.get_or_create( - **config, - defaults=dict(parent=parent, parent_version=parent_version), - )[0] + return SavedRun.objects.get_or_create(**config, defaults=defaults)[0] else: return SavedRun.objects.get(**config) @@ -1264,7 +1290,9 @@ def _render_report_button(self): if not (self.request.user and run_id and uid): return - reported = st.button("❗Report") + reported = st.button( + ' Report', type="tertiary" + ) if not reported: return @@ -1306,12 +1334,13 @@ def _render_input_col(self): ) return submitted - def get_run_state(self) -> RecipeRunState: - if st.session_state.get(StateKeys.run_status): + @classmethod + def get_run_state(cls, state: dict[str, typing.Any]) -> RecipeRunState: + if state.get(StateKeys.run_status): return RecipeRunState.running - elif st.session_state.get(StateKeys.error_msg): + elif state.get(StateKeys.error_msg): return RecipeRunState.failed - elif st.session_state.get(StateKeys.run_time): + elif state.get(StateKeys.run_time): return RecipeRunState.completed else: # when user is at a recipe root, and not running anything @@ -1330,7 +1359,7 @@ def _render_output_col(self, submitted: bool): self._render_before_output() - run_state = self.get_run_state() + run_state = self.get_run_state(st.session_state) match run_state: case RecipeRunState.completed: self._render_completed_output() @@ -1344,12 +1373,11 @@ def _render_output_col(self, submitted: bool): # render outputs self.render_output() - if run_state != "waiting": + if run_state != RecipeRunState.running: self._render_after_output() def _render_completed_output(self): - run_time = st.session_state.get(StateKeys.run_time, 0) - st.success(f"Success! Run Time: `{run_time:.2f}` seconds.") + pass def _render_failed_output(self): err_msg = st.session_state.get(StateKeys.error_msg) @@ -1365,12 +1393,10 @@ def render_extra_waiting_output(self): if not estimated_run_time: return if created_at := st.session_state.get("created_at"): - if isinstance(created_at, datetime.datetime): - start_time = created_at - else: - start_time = datetime.datetime.fromisoformat(created_at) + if isinstance(created_at, str): + created_at = datetime.datetime.fromisoformat(created_at) with st.countdown_timer( - end_time=start_time + datetime.timedelta(seconds=estimated_run_time), + end_time=created_at + datetime.timedelta(seconds=estimated_run_time), delay_text="Sorry for the wait. Your run is taking longer than we expected.", ): if self.is_current_user_owner() and self.request.user.email: @@ -1407,7 +1433,7 @@ def should_submit_after_login(self) -> bool: and not self.request.user.is_anonymous ) - def create_new_run(self): + def create_new_run(self, is_api_call: bool = False): st.session_state[StateKeys.run_status] = "Starting..." 
st.session_state.pop(StateKeys.error_msg, None) st.session_state.pop(StateKeys.run_time, None) @@ -1442,8 +1468,11 @@ def create_new_run(self): run_id, uid, create=True, - parent=parent, - parent_version=parent_version, + defaults=dict( + parent=parent, + parent_version=parent_version, + is_api_call=is_api_call, + ), ).set(self.state_to_doc(st.session_state)) return None, run_id, uid @@ -1457,13 +1486,16 @@ def call_runner_task(self, example_id, run_id, uid, is_api_call=False): run_id=run_id, uid=uid, state=st.session_state, - channel=f"gooey-outputs/{self.slug_versions[0]}/{uid}/{run_id}", + channel=self.realtime_channel_name(run_id, uid), query_params=self.clean_query_params( example_id=example_id, run_id=run_id, uid=uid ), is_api_call=is_api_call, ) + def realtime_channel_name(self, run_id, uid): + return f"gooey-outputs/{self.slug_versions[0]}/{uid}/{run_id}" + def generate_credit_error_message(self, example_id, run_id, uid) -> str: account_url = furl(settings.APP_BASE_URL) / "account/" if self.request.user.is_anonymous: @@ -1508,22 +1540,17 @@ def clear_outputs(self): st.session_state.pop(field_name, None) def _render_after_output(self): - col1, col2, col3 = st.columns([1, 1, 1], responsive=False) - col2.node.props[ - "className" - ] += " d-flex justify-content-center align-items-center" - col3.node.props["className"] += " d-flex justify-content-end align-items-center" + self._render_report_button() + if "seed" in self.RequestModel.schema_json(): - seed = st.session_state.get("seed") - with col1: - st.caption(f"*Seed\\\n`{seed}`*") - with col2: - randomize = st.button("♻️ Regenerate") - if randomize: - st.session_state[StateKeys.pressed_randomize] = True - st.experimental_rerun() - with col3: - self._render_report_button() + randomize = st.button( + ' Regenerate', type="tertiary" + ) + if randomize: + st.session_state[StateKeys.pressed_randomize] = True + st.experimental_rerun() + + render_output_caption() def state_to_doc(self, state: dict): ret = { @@ -1791,8 +1818,8 @@ def run_as_api_tab(self): as_async = st.checkbox("##### Run Async") as_form_data = st.checkbox("##### Upload Files via Form Data") - request_body = get_example_request_body( - self.RequestModel, st.session_state, include_all=include_all + request_body = self.get_example_request_body( + st.session_state, include_all=include_all ) response_body = self.get_example_response_body( st.session_state, as_async=as_async, include_all=include_all @@ -1838,7 +1865,27 @@ def get_price_roundoff(self, state: dict) -> int: return max(1, math.ceil(self.get_raw_price(state))) def get_raw_price(self, state: dict) -> float: - return self.price + return self.price * state.get("num_outputs", 1) + + @classmethod + def get_example_preferred_fields(cls, state: dict) -> list[str]: + """ + Fields that are not required, but are preferred to be shown in the example. 
+ """ + return [] + + @classmethod + def get_example_request_body( + cls, + state: dict, + include_all: bool = False, + ) -> dict: + return extract_model_fields( + cls.RequestModel, + state, + include_all=include_all, + preferred_fields=cls.get_example_preferred_fields(state), + ) def get_example_response_body( self, @@ -1854,6 +1901,7 @@ def get_example_response_body( run_id=run_id, uid=self.request.user and self.request.user.uid, ) + output = extract_model_fields(self.ResponseModel, state, include_all=True) if as_async: return dict( run_id=run_id, @@ -1861,18 +1909,14 @@ def get_example_response_body( created_at=created_at, run_time_sec=st.session_state.get(StateKeys.run_time, 0), status="completed", - output=get_example_request_body( - self.ResponseModel, state, include_all=include_all - ), + output=output, ) else: return dict( id=run_id, url=web_url, created_at=created_at, - output=get_example_request_body( - self.ResponseModel, state, include_all=include_all - ), + output=output, ) def additional_notes(self) -> str | None: @@ -1900,15 +1944,41 @@ def is_current_user_owner(self) -> bool: ) -def get_example_request_body( - request_model: typing.Type[BaseModel], +def render_output_caption(): + caption = "" + + run_time = st.session_state.get(StateKeys.run_time, 0) + if run_time: + caption += f'Generated in {run_time :.2f}s' + + if seed := st.session_state.get("seed"): + caption += f' with seed {seed} ' + + created_at = st.session_state.get(StateKeys.created_at, datetime.datetime.today()) + if created_at: + if isinstance(created_at, str): + created_at = datetime.datetime.fromisoformat(created_at) + format_created_at = created_at.strftime(settings.SHORT_DATETIME_FORMAT) + caption += f' at {format_created_at}' + + st.caption(caption, unsafe_allow_html=True) + + +def extract_model_fields( + model: typing.Type[BaseModel], state: dict, include_all: bool = False, + preferred_fields: list[str] = None, ) -> dict: + """Only returns required fields unless include_all is set to True.""" return { field_name: state.get(field_name) - for field_name, field in request_model.__fields__.items() - if include_all or field.required + for field_name, field in model.__fields__.items() + if ( + include_all + or field.required + or (preferred_fields and field_name in preferred_fields) + ) } @@ -1928,27 +1998,6 @@ def extract_nested_str(obj) -> str: return "" -def err_msg_for_exc(e): - if isinstance(e, requests.HTTPError): - response: requests.Response = e.response - try: - err_body = response.json() - except requests.JSONDecodeError: - err_str = response.text - else: - format_exc = err_body.get("format_exc") - if format_exc: - print("⚡️ " + format_exc) - err_type = err_body.get("type") - err_str = err_body.get("str") - if err_type and err_str: - return f"(GPU) {err_type}: {err_str}" - err_str = str(err_body) - return f"(HTTP {response.status_code}) {html.escape(err_str[:1000])}" - else: - return f"{type(e).__name__}: {e}" - - def force_redirect(url: str): # note: assumes sanitized URLs st.html( diff --git a/daras_ai_v2/bot_integration_widgets.py b/daras_ai_v2/bot_integration_widgets.py index 3880f9fe7..73c6ccad0 100644 --- a/daras_ai_v2/bot_integration_widgets.py +++ b/daras_ai_v2/bot_integration_widgets.py @@ -19,11 +19,19 @@ def general_integration_settings(bi: BotIntegration): st.session_state[f"_bi_user_language_{bi.id}"] = BotIntegration._meta.get_field( "user_language" ).default - st.session_state[ - f"_bi_show_feedback_buttons_{bi.id}" - ] = BotIntegration._meta.get_field("show_feedback_buttons").default + 
st.session_state[f"_bi_streaming_enabled_{bi.id}"] = ( + BotIntegration._meta.get_field("streaming_enabled").default + ) + st.session_state[f"_bi_show_feedback_buttons_{bi.id}"] = ( + BotIntegration._meta.get_field("show_feedback_buttons").default + ) st.session_state[f"_bi_analysis_url_{bi.id}"] = None + bi.streaming_enabled = st.checkbox( + "**📡 Streaming Enabled**", + value=bi.streaming_enabled, + key=f"_bi_streaming_enabled_{bi.id}", + ) bi.show_feedback_buttons = st.checkbox( "**👍🏾 👎🏽 Show Feedback Buttons**", value=bi.show_feedback_buttons, diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index dfc00c794..7075e3b2d 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -1,9 +1,11 @@ import mimetypes import traceback import typing +from datetime import datetime from urllib.parse import parse_qs from django.db import transaction +from django.utils import timezone from fastapi import HTTPException, Request from furl import furl from sentry_sdk import capture_exception @@ -20,11 +22,13 @@ MessageAttachment, ) from daras_ai_v2.asr import AsrModels, run_google_translate -from daras_ai_v2.base import BasePage +from daras_ai_v2.base import BasePage, RecipeRunState, StateKeys from daras_ai_v2.language_model import CHATML_ROLE_USER, CHATML_ROLE_ASSISTANT from daras_ai_v2.vector_search import doc_url_to_file_metadata -from gooeysite.bg_db_conn import db_middleware +from gooey_ui.pubsub import realtime_subscribe +from gooeysite.bg_db_conn import db_middleware, get_celery_result_db_safe from recipes.VideoBots import VideoBotsPage, ReplyButton +from routers.api import submit_api_call PAGE_NOT_CONNECTED_ERROR = ( "💔 Looks like you haven't connected this page to a gooey.ai workflow. " @@ -47,7 +51,7 @@ """.strip() ERROR_MSG = """ -`{0!r}` +`{}` ⚠️ Sorry, I ran into an error while processing your request. Please try again, or type "Reset" to start over. """.strip() @@ -60,6 +64,8 @@ TAPPED_SKIP_MSG = "🌱 Alright. What else can I help you with?" 
+SLACK_MAX_SIZE = 3000 + async def request_json(request: Request): return await request.json() @@ -80,33 +86,13 @@ class BotInterface: input_type: str language: str show_feedback_buttons: bool = False + streaming_enabled: bool = False + can_update_message: bool = False convo: Conversation recieved_msg_id: str = None input_glossary: str | None = None output_glossary: str | None = None - def send_msg_or_default( - self, - *, - text: str | None = None, - audio: str = None, - video: str = None, - buttons: list[ReplyButton] = None, - documents: list[str] = None, - should_translate: bool = False, - default: str = DEFAULT_RESPONSE, - ): - if not (text or audio or video or documents): - text = default - return self.send_msg( - text=text, - audio=audio, - video=video, - buttons=buttons, - documents=documents, - should_translate=should_translate, - ) - def send_msg( self, *, @@ -116,6 +102,7 @@ def send_msg( buttons: list[ReplyButton] = None, documents: list[str] = None, should_translate: bool = False, + update_msg_id: str = None, ) -> str | None: raise NotImplementedError @@ -166,6 +153,7 @@ def _unpack_bot_integration(self): self.billing_account_uid = bi.billing_account_uid self.language = bi.user_language self.show_feedback_buttons = bi.show_feedback_buttons + self.streaming_enabled = bi.streaming_enabled def get_interactive_msg_info(self) -> tuple[str, str]: raise NotImplementedError("This bot does not support interactive messages.") @@ -202,6 +190,7 @@ def _on_msg(bot: BotInterface): speech_run = None input_images = None input_documents = None + recieved_time: datetime = timezone.now() if not bot.page_cls: bot.send_msg(text=PAGE_NOT_CONNECTED_ERROR) return @@ -265,8 +254,8 @@ def _on_msg(bot: BotInterface): return # handle reset keyword if input_text.lower() == RESET_KEYWORD: - # clear saved messages - bot.convo.messages.all().delete() + # record the reset time so we don't send context + bot.convo.reset_at = timezone.now() # reset convo state bot.convo.state = ConvoState.INITIAL bot.convo.save() @@ -286,6 +275,7 @@ def _on_msg(bot: BotInterface): input_documents=input_documents, input_text=input_text, speech_run=speech_run, + recieved_time=recieved_time, ) @@ -324,41 +314,115 @@ def _process_and_send_msg( input_images: list[str] | None, input_documents: list[str] | None, input_text: str, + recieved_time: datetime, speech_run: str | None, ): - try: - # # mock testing - # msgs_to_save, response_audio, response_text, response_video = _echo( - # bot, input_text - # ) - # make API call to gooey bots to get the response - response, url = _process_msg( - page_cls=bot.page_cls, - api_user=billing_account_user, - query_params=bot.query_params, - convo=bot.convo, - input_text=input_text, - user_language=bot.language, - speech_run=speech_run, - input_images=input_images, - input_documents=input_documents, - ) - except HTTPException as e: - traceback.print_exc() - capture_exception(e) - # send error msg as repsonse - bot.send_msg(text=ERROR_MSG.format(e)) - return + # get latest messages for context + saved_msgs = bot.convo.messages.all().as_llm_context(reset_at=bot.convo.reset_at) - # send the response to the user - msg_id = bot.send_msg_or_default( - text=response.output_text and response.output_text[0], - audio=response.output_audio and response.output_audio[0], - video=response.output_video and response.output_video[0], - documents=response.output_documents or [], - buttons=_feedback_start_buttons() if bot.show_feedback_buttons else None, + # # mock testing + # result = _mock_api_output(input_text) + 
page, result, run_id, uid = submit_api_call( + page_cls=bot.page_cls, + user=billing_account_user, + request_body={ + "input_prompt": input_text, + "input_images": input_images, + "input_documents": input_documents, + "messages": saved_msgs, + "user_language": bot.language, + }, + query_params=bot.query_params, ) + if bot.show_feedback_buttons: + buttons = _feedback_start_buttons() + else: + buttons = None + + update_msg_id = None # this is the message id to update during streaming + sent_msg_id = None # this is the message id to record in the db + last_idx = 0 # this is the last index of the text sent to the user + if bot.streaming_enabled: + # subscribe to the realtime channel for updates + channel = page.realtime_channel_name(run_id, uid) + with realtime_subscribe(channel) as realtime_gen: + for state in realtime_gen: + run_state = page.get_run_state(state) + run_status = state.get(StateKeys.run_status) or "" + # check for errors + if run_state == RecipeRunState.failed: + err_msg = state.get(StateKeys.error_msg) + bot.send_msg(text=ERROR_MSG.format(err_msg)) + return # abort + if run_state != RecipeRunState.running: + break # we're done running, abort + text = state.get("output_text") and state.get("output_text")[0] + if not text: + # if no text, send the run status + if bot.can_update_message: + update_msg_id = bot.send_msg( + text=run_status, update_msg_id=update_msg_id + ) + continue # no text, wait for the next update + streaming_done = not run_status.lower().startswith("streaming") + # send the response to the user + if bot.can_update_message: + update_msg_id = bot.send_msg( + text=text.strip() + "...", + update_msg_id=update_msg_id, + buttons=buttons if streaming_done else None, + ) + last_idx = len(text) + else: + next_chunk = text[last_idx:] + last_idx = len(text) + if not next_chunk: + continue # no chunk, wait for the next update + update_msg_id = bot.send_msg( + text=next_chunk, + buttons=buttons if streaming_done else None, + ) + if streaming_done and not bot.can_update_message: + # if we send the buttons, this is the ID we need to record in the db for lookups later when the button is pressed + sent_msg_id = update_msg_id + # don't show buttons again + buttons = None + if streaming_done: + break # we're done streaming, abort + + # wait for the celery task to finish + get_celery_result_db_safe(result) + # get the final state from db + state = page.run_doc_sr(run_id, uid).to_dict() + # check for errors + err_msg = state.get(StateKeys.error_msg) + if err_msg: + bot.send_msg(text=ERROR_MSG.format(err_msg)) + return + + text = (state.get("output_text") and state.get("output_text")[0]) or "" + audio = state.get("output_audio") and state.get("output_audio")[0] + video = state.get("output_video") and state.get("output_video")[0] + documents = state.get("output_documents") or [] + # check for empty response + if not (text or audio or video or documents or buttons): + bot.send_msg(text=DEFAULT_RESPONSE) + return + # if in-place updates are enabled, update the message, otherwise send the remaining text + if not bot.can_update_message: + text = text[last_idx:] + # send the response to the user if there is any remaining + if text or audio or video or documents or buttons: + update_msg_id = bot.send_msg( + text=text, + audio=audio, + video=video, + documents=documents, + buttons=buttons, + update_msg_id=update_msg_id, + ) + # save msgs to db _save_msgs( bot=bot, @@ -366,9 +430,10 @@ def _process_and_send_msg( input_documents=input_documents, input_text=input_text, speech_run=speech_run, - 
platform_msg_id=msg_id, - response=response, - url=url, + platform_msg_id=sent_msg_id or update_msg_id, + response=VideoBotsPage.ResponseModel.parse_obj(state), + url=page.app_url(run_id=run_id, uid=uid), + received_time=recieved_time, ) @@ -381,6 +446,7 @@ def _save_msgs( platform_msg_id: str | None, response: VideoBotsPage.ResponseModel, url: str, + received_time: datetime, ): # create messages for future context user_msg = Message( @@ -389,11 +455,14 @@ def _save_msgs( role=CHATML_ROLE_USER, content=response.raw_input_text, display_content=input_text, - saved_run=SavedRun.objects.get_or_create( - workflow=Workflow.ASR, **furl(speech_run).query.params - )[0] - if speech_run - else None, + saved_run=( + SavedRun.objects.get_or_create( + workflow=Workflow.ASR, **furl(speech_run).query.params + )[0] + if speech_run + else None + ), + response_time=timezone.now() - received_time, ) attachments = [] for f_url in (input_images or []) + (input_documents or []): @@ -410,6 +479,7 @@ def _save_msgs( saved_run=SavedRun.objects.get_or_create( workflow=Workflow.VIDEO_BOTS, **furl(url).query.params )[0], + response_time=timezone.now() - received_time, ) # save the messages & attachments with transaction.atomic(): @@ -420,45 +490,6 @@ def _save_msgs( assistant_msg.save() -def _process_msg( - *, - page_cls, - api_user: AppUser, - query_params: dict, - convo: Conversation, - input_images: list[str] | None, - input_documents: list[str] | None, - input_text: str, - user_language: str, - speech_run: str | None, -) -> tuple[VideoBotsPage.ResponseModel, str]: - from routers.api import call_api - - # get latest messages for context (upto 100) - saved_msgs = convo.messages.all().as_llm_context() - - # # mock testing - # result = _mock_api_output(input_text) - - # call the api with provided input - result = call_api( - page_cls=page_cls, - user=api_user, - request_body={ - "input_prompt": input_text, - "input_images": input_images, - "input_documents": input_documents, - "messages": saved_msgs, - "user_language": user_language, - }, - query_params=query_params, - ) - # parse result - response = page_cls.ResponseModel.parse_obj(result["output"]) - url = result.get("url", "") - return response, url - - def _handle_interactive_msg(bot: BotInterface): try: button_id, context_msg_id = bot.get_interactive_msg_info() @@ -478,13 +509,25 @@ def _handle_interactive_msg(bot: BotInterface): return if button_id == ButtonIds.feedback_thumbs_up: rating = Feedback.Rating.RATING_THUMBS_UP - bot.convo.state = ConvoState.ASK_FOR_FEEDBACK_THUMBS_UP - response_text = FEEDBACK_THUMBS_UP_MSG + # bot.convo.state = ConvoState.ASK_FOR_FEEDBACK_THUMBS_UP + # response_text = FEEDBACK_THUMBS_UP_MSG else: rating = Feedback.Rating.RATING_THUMBS_DOWN - bot.convo.state = ConvoState.ASK_FOR_FEEDBACK_THUMBS_DOWN - response_text = FEEDBACK_THUMBS_DOWN_MSG + # bot.convo.state = ConvoState.ASK_FOR_FEEDBACK_THUMBS_DOWN + # response_text = FEEDBACK_THUMBS_DOWN_MSG + response_text = FEEDBACK_CONFIRMED_MSG.format( + bot_name=str(bot.convo.bot_integration.name) + ) bot.convo.save() + # save the feedback + Feedback.objects.create(message=context_msg, rating=rating) + # send a confirmation msg + post click buttons + bot.send_msg( + text=response_text, + # buttons=_feedback_post_click_buttons(), + should_translate=True, + ) + # handle skip case ButtonIds.action_skip: bot.send_msg(text=TAPPED_SKIP_MSG, should_translate=True) @@ -492,6 +535,7 @@ def _handle_interactive_msg(bot: BotInterface): bot.convo.state = ConvoState.INITIAL bot.convo.save() return + # 
not sure what button was pressed, ignore case _: bot_name = str(bot.convo.bot_integration.name) @@ -503,14 +547,6 @@ def _handle_interactive_msg(bot: BotInterface): bot.convo.state = ConvoState.INITIAL bot.convo.save() return - # save the feedback - Feedback.objects.create(message=context_msg, rating=rating) - # send a confirmation msg + post click buttons - bot.send_msg( - text=response_text, - buttons=_feedback_post_click_buttons(), - should_translate=True, - ) def _handle_audio_msg(billing_account_user, bot: BotInterface): @@ -533,8 +569,12 @@ def _handle_audio_msg(billing_account_user, bot: BotInterface): selected_model = AsrModels.whisper_telugu_large_v2.name case "bho": selected_model = AsrModels.vakyansh_bhojpuri.name - case "en": - selected_model = AsrModels.usm.name + case "sw": + selected_model = AsrModels.seamless_m4t.name + language = "swh" + # case "en": + # selected_model = AsrModels.usm.name + # language = "am-et" case _: selected_model = AsrModels.whisper_large_v2.name diff --git a/daras_ai_v2/breadcrumbs.py b/daras_ai_v2/breadcrumbs.py index 96dc919e5..94e63cb80 100644 --- a/daras_ai_v2/breadcrumbs.py +++ b/daras_ai_v2/breadcrumbs.py @@ -31,7 +31,7 @@ def has_breadcrumbs(self): return bool(self.root_title or self.published_title) -def render_breadcrumbs(breadcrumbs: TitleBreadCrumbs): +def render_breadcrumbs(breadcrumbs: TitleBreadCrumbs, *, is_api_call: bool = False): st.html( """