Skip to content

Commit

Permalink
Merge branch 'master' into Generated-at-needs-correct-time
Browse files Browse the repository at this point in the history
  • Loading branch information
clr-li committed Feb 28, 2024
2 parents ec06c97 + 8ac0e71 commit 32939ca
Show file tree
Hide file tree
Showing 25 changed files with 171 additions and 69 deletions.
17 changes: 13 additions & 4 deletions bots/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,16 +294,25 @@ class SavedRunAdmin(admin.ModelAdmin):
django.db.models.JSONField: {"widget": JSONEditorWidget},
}

def get_queryset(self, request):
return (
super()
.get_queryset(request)
.prefetch_related(
"parent_version",
"parent_version__published_run",
"parent_version__published_run__saved_run",
)
)

def lookup_allowed(self, key, value):
if key in ["parent_version__published_run__id__exact"]:
return True
return super().lookup_allowed(key, value)

def view_user(self, saved_run: SavedRun):
return change_obj_url(
AppUser.objects.get(uid=saved_run.uid),
label=f"{saved_run.uid}",
)
user = AppUser.objects.get(uid=saved_run.uid)
return change_obj_url(user)

view_user.short_description = "View User"

Expand Down
51 changes: 37 additions & 14 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def _render_social_buttons(self, show_button_text: bool = False):

copy_to_clipboard_button(
f'<i class="fa-regular fa-link"></i>{button_text}',
value=self._get_current_app_url(),
value=self.get_tab_url(self.tab),
type="secondary",
className="mb-0 ms-lg-2",
)
Expand Down Expand Up @@ -1791,8 +1791,8 @@ def run_as_api_tab(self):
as_async = st.checkbox("##### Run Async")
as_form_data = st.checkbox("##### Upload Files via Form Data")

request_body = get_example_request_body(
self.RequestModel, st.session_state, include_all=include_all
request_body = self.get_example_request_body(
st.session_state, include_all=include_all
)
response_body = self.get_example_response_body(
st.session_state, as_async=as_async, include_all=include_all
Expand Down Expand Up @@ -1838,7 +1838,27 @@ def get_price_roundoff(self, state: dict) -> int:
return max(1, math.ceil(self.get_raw_price(state)))

def get_raw_price(self, state: dict) -> float:
return self.price
return self.price * state.get("num_outputs", 1)

@classmethod
def get_example_preferred_fields(cls, state: dict) -> list[str]:
"""
Fields that are not required, but are preferred to be shown in the example.
"""
return []

@classmethod
def get_example_request_body(
cls,
state: dict,
include_all: bool = False,
) -> dict:
return extract_model_fields(
cls.RequestModel,
state,
include_all=include_all,
preferred_fields=cls.get_example_preferred_fields(state),
)

def get_example_response_body(
self,
Expand All @@ -1854,25 +1874,22 @@ def get_example_response_body(
run_id=run_id,
uid=self.request.user and self.request.user.uid,
)
output = extract_model_fields(self.ResponseModel, state, include_all=True)
if as_async:
return dict(
run_id=run_id,
web_url=web_url,
created_at=created_at,
run_time_sec=st.session_state.get(StateKeys.run_time, 0),
status="completed",
output=get_example_request_body(
self.ResponseModel, state, include_all=include_all
),
output=output,
)
else:
return dict(
id=run_id,
url=web_url,
created_at=created_at,
output=get_example_request_body(
self.ResponseModel, state, include_all=include_all
),
output=output,
)

def additional_notes(self) -> str | None:
Expand Down Expand Up @@ -1924,15 +1941,21 @@ def render_output_caption():
)


def get_example_request_body(
request_model: typing.Type[BaseModel],
def extract_model_fields(
model: typing.Type[BaseModel],
state: dict,
include_all: bool = False,
preferred_fields: list[str] = None,
) -> dict:
"""Only returns required fields unless include_all is set to True."""
return {
field_name: state.get(field_name)
for field_name, field in request_model.__fields__.items()
if include_all or field.required
for field_name, field in model.__fields__.items()
if (
include_all
or field.required
or (preferred_fields and field_name in preferred_fields)
)
}


Expand Down
4 changes: 3 additions & 1 deletion daras_ai_v2/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
HASHIDS_SALT = config("HASHIDS_SALT", default="")

ALLOWED_HOSTS = ["*"]
INTERNAL_IPS = ["127.0.0.1"]
INTERNAL_IPS = ["127.0.0.1", "localhost"]
SECURE_PROXY_SSL_HEADER = ("HTTP_X_FORWARDED_PROTO", "https")

# Application definition
Expand All @@ -48,6 +48,7 @@
"django.contrib.staticfiles",
"bots",
"django_extensions",
# "debug_toolbar",
# the order matters, since we want to override the admin templates
"django.forms", # needed to override admin forms
"django.contrib.admin",
Expand All @@ -67,6 +68,7 @@
"django.contrib.auth.middleware.AuthenticationMiddleware",
"django.contrib.messages.middleware.MessageMiddleware",
"django.middleware.clickjacking.XFrameOptionsMiddleware",
# "debug_toolbar.middleware.DebugToolbarMiddleware",
]

ROOT_URLCONF = "gooeysite.urls"
Expand Down
3 changes: 2 additions & 1 deletion gooeysite/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
"""

from django.contrib import admin
from django.urls import path
from django.urls import path, include

urlpatterns = [
# path("__debug__/", include("debug_toolbar.urls")),
path("", admin.site.urls),
]
4 changes: 4 additions & 0 deletions recipes/CompareLLM.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def preview_image(self, state: dict) -> str | None:
def preview_description(self, state: dict) -> str:
return "Which language model works best your prompt? Compare your text generations across multiple large language models (LLMs) like OpenAI's evolving and latest ChatGPT engines and others like Curie, Ada, Babbage."

@classmethod
def get_example_preferred_fields(cls, state: dict) -> list[str]:
return ["input_prompt", "selected_models"]

def render_form_v2(self):
st.text_area(
"""
Expand Down
6 changes: 5 additions & 1 deletion recipes/CompareText2Img.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ class ResponseModel(BaseModel):
typing.Literal[tuple(e.name for e in Text2ImgModels)], list[str]
]

@classmethod
def get_example_preferred_fields(cls, state: dict) -> list[str]:
return ["selected_models"]

def preview_image(self, state: dict) -> str | None:
return DEFAULT_COMPARE_TEXT2IMG_META_IMG

Expand Down Expand Up @@ -264,4 +268,4 @@ def get_raw_price(self, state: dict) -> int:
total += 15
case _:
total += 2
return total
return total * state.get("num_outputs", 1)
9 changes: 7 additions & 2 deletions recipes/DocSearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ class ResponseModel(BaseModel):
final_prompt: str
final_search_query: str | None

@classmethod
def get_example_preferred_fields(self, state: dict) -> list[str]:
return ["documents"]

def render_form_v2(self):
st.text_area("#### Search Query", key="search_query")
document_uploader("#### Documents")
Expand Down Expand Up @@ -205,9 +209,10 @@ def run_v2(
def get_raw_price(self, state: dict) -> float:
name = state.get("selected_model")
try:
return llm_price[LargeLanguageModels[name]] * 2
unit_price = llm_price[LargeLanguageModels[name]] * 2
except KeyError:
return 10
unit_price = 10
return unit_price * state.get("num_outputs", 1)


def render_documents(state, label="**Documents**", *, key="documents"):
Expand Down
4 changes: 4 additions & 0 deletions recipes/DocSummary.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ class ResponseModel(BaseModel):
prompt_tree: PromptTree | None
final_prompt: str

@classmethod
def get_example_preferred_fields(cls, state: dict) -> list[str]:
return ["task_instructions", "merge_instructions"]

def preview_image(self, state: dict) -> str | None:
return DEFAULT_DOC_SUMMARY_META_IMG

Expand Down
4 changes: 4 additions & 0 deletions recipes/EmailFaceInpainting.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ class ResponseModel(BaseModel):
output_images: list[str]
email_sent: bool = False

@classmethod
def get_example_preferred_fields(self, state: dict) -> list[str]:
return ["email_address"]

def preview_image(self, state: dict) -> str | None:
return DEFAULT_EMAIL_FACE_INPAINTING_META_IMG

Expand Down
6 changes: 4 additions & 2 deletions recipes/FaceInpainting.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,8 @@ def get_raw_price(self, state: dict) -> int:
selected_model = state.get("selected_model")
match selected_model:
case InpaintingModels.dall_e.name:
return 20
unit_price = 20
case _:
return 5
unit_price = 5

return unit_price * state.get("num_outputs", 1)
10 changes: 8 additions & 2 deletions recipes/Img2Img.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ class RequestModel(BaseModel):
class ResponseModel(BaseModel):
output_images: list[str]

@classmethod
def get_example_preferred_fields(self, state: dict) -> list[str]:
return ["text_prompt"]

def preview_image(self, state: dict) -> str | None:
return DEFAULT_IMG2IMG_META_IMG

Expand Down Expand Up @@ -202,6 +206,8 @@ def get_raw_price(self, state: dict) -> int:
selected_model = state.get("selected_model")
match selected_model:
case Img2ImgModels.dall_e.name:
return 20
unit_price = 20
case _:
return 5
unit_price = 5

return unit_price * state.get("num_outputs", 1)
6 changes: 4 additions & 2 deletions recipes/ObjectInpainting.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,8 @@ def get_raw_price(self, state: dict) -> int:
selected_model = state.get("selected_model")
match selected_model:
case InpaintingModels.dall_e.name:
return 20
unit_price = 20
case _:
return 5
unit_price = 5

return unit_price * state.get("num_outputs", 1)
11 changes: 10 additions & 1 deletion recipes/QRCodeGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,15 @@ def related_workflows(self) -> list:
EmailFaceInpaintingPage,
]

@classmethod
def get_example_preferred_fields(cls, state: dict) -> list[str]:
if state.get("qr_code_file"):
return ["qr_code_file"]
elif state.get("qr_code_input_image"):
return ["qr_code_input_image"]
else:
return ["qr_code_data"]

def render_form_v2(self):
st.text_area(
"""
Expand Down Expand Up @@ -735,7 +744,7 @@ def extract_qr_code_data(img: np.ndarray) -> str:
return info


class InvalidQRCode(AssertionError):
class InvalidQRCode(UserError):
pass


Expand Down
4 changes: 4 additions & 0 deletions recipes/TextToSpeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ class ResponseModel(BaseModel):
def fallback_preivew_image(self) -> str | None:
return DEFAULT_TTS_META_IMG

@classmethod
def get_example_preferred_fields(cls, state: dict) -> list[str]:
return ["tts_provider"]

def preview_description(self, state: dict) -> str:
return "Input your text, pick a voice & a Text-to-Speech AI engine to create audio. Compare the best voice generators from Google, UberDuck.ai & more to add automated voices to your podcast, YouTube videos, website, or app."

Expand Down
39 changes: 27 additions & 12 deletions recipes/VideoBots.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os.path
import typing

from django.db.models import QuerySet
from django.db.models import QuerySet, Q
from furl import furl
from pydantic import BaseModel, Field

Expand Down Expand Up @@ -601,11 +601,13 @@ def get_raw_price(self, state: dict):
"raw_tts_text", state.get("raw_output_text", [])
)
tts_state = {"text_prompt": "".join(output_text_list)}
return super().get_raw_price(state) + TextToSpeechPage().get_raw_price(
total = super().get_raw_price(state) + TextToSpeechPage().get_raw_price(
tts_state
)
case _:
return super().get_raw_price(state)
total = super().get_raw_price(state)

return total * state.get("num_outputs", 1)

def additional_notes(self):
tts_provider = st.session_state.get("tts_provider")
Expand Down Expand Up @@ -975,11 +977,11 @@ def render_selected_tab(self, selected_tab):
unsafe_allow_html=True,
)

st.write("---")
st.text_input(
"###### 🤖 [Landbot](https://landbot.io/) URL", key="landbot_url"
)
show_landbot_widget()
# st.write("---")
# st.text_input(
# "###### 🤖 [Landbot](https://landbot.io/) URL", key="landbot_url"
# )
# show_landbot_widget()

def messenger_bot_integration(self):
from routers.facebook_api import ig_connect_url, fb_connect_url
Expand Down Expand Up @@ -1028,15 +1030,28 @@ def messenger_bot_integration(self):

st.button("🔄 Refresh")

current_run, published_run = self.get_runs_from_query_params(
*extract_query_params(gooey_get_query_params())
) # type: ignore

integrations_q = Q(billing_account_uid=self.request.user.uid)

# show admins all the bots connected to the current run
if self.is_current_user_admin():
integrations_q |= Q(saved_run=current_run)
if published_run:
integrations_q |= Q(
saved_run__example_id=published_run.published_run_id
)
integrations_q |= Q(published_run=published_run)

integrations: QuerySet[BotIntegration] = BotIntegration.objects.filter(
billing_account_uid=self.request.user.uid
integrations_q
).order_by("platform", "-created_at")

if not integrations:
return

current_run, published_run = self.get_runs_from_query_params(
*extract_query_params(gooey_get_query_params())
)
for bi in integrations:
is_connected = (bi.saved_run == current_run) or (
(
Expand Down
Loading

0 comments on commit 32939ca

Please sign in to comment.