From 63deccfc10cb890a0a9b66d7eca6644235f26e14 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 2 Aug 2024 03:36:15 +0530 Subject: [PATCH 01/38] Add x-fern-sdk-return-value for all status routes --- routers/api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/routers/api.py b/routers/api.py index aab15806c..58a0870c3 100644 --- a/routers/api.py +++ b/routers/api.py @@ -283,6 +283,7 @@ def run_api_form_async( operation_id="status__" + page_cls.slug_versions[0], tags=[page_cls.title], name=page_cls.title + " (v3 status)", + openapi_extra={"x-fern-sdk-return-value": "output"}, ) @app.get( os.path.join(endpoint, "status"), From 53a8036e409b7f1bd155f0e854b3f5e771af38f8 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 2 Aug 2024 03:37:35 +0530 Subject: [PATCH 02/38] fern: ignore v2 sync APIs --- routers/api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/routers/api.py b/routers/api.py index 58a0870c3..04ce9f84c 100644 --- a/routers/api.py +++ b/routers/api.py @@ -148,6 +148,7 @@ def script_to_api(page_cls: typing.Type[BasePage]): operation_id=page_cls.slug_versions[0], tags=[page_cls.title], name=page_cls.title + " (v2 sync)", + openapi_extra={"x-fern-ignore": True}, ) @app.post( endpoint, From e318e0b865a52bd3067b1769856824aa84c9958d Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 2 Aug 2024 03:52:48 +0530 Subject: [PATCH 03/38] add query param for example_id to v2 sync and v3 async --- routers/api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/routers/api.py b/routers/api.py index 04ce9f84c..94011f07c 100644 --- a/routers/api.py +++ b/routers/api.py @@ -162,6 +162,7 @@ def script_to_api(page_cls: typing.Type[BasePage]): def run_api_json( request: Request, page_request: request_model, + example_id: str | None = None, user: AppUser = Depends(api_auth_header), ): return _run_api( @@ -226,6 +227,7 @@ def run_api_json_async( request: Request, response: Response, page_request: request_model, + example_id: str | None = None, user: AppUser = Depends(api_auth_header), ): ret = _run_api( From 88dd5b540a9417d6c4b9012e975d36bd29eebc2e Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 2 Aug 2024 16:16:45 +0530 Subject: [PATCH 04/38] Add method-name and group-name to openapi schema --- daras_ai_v2/base.py | 7 +++++++ routers/api.py | 10 +++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index a90f2504b..43bf49a08 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -303,6 +303,13 @@ def sentry_event_set_user(self, event, hint): } return event + @classmethod + def get_openapi_extra(cls) -> dict[str, typing.Any]: + return { + "x-fern-sdk-group-name": cls.slug_versions[-1].title().replace("-", ""), + "x-fern-sdk-method-name": "status", + } + def refresh_state(self): _, run_id, uid = extract_query_params(gui.get_query_params()) channel = self.realtime_channel_name(run_id, uid) diff --git a/routers/api.py b/routers/api.py index 94011f07c..65484cd5b 100644 --- a/routers/api.py +++ b/routers/api.py @@ -5,7 +5,7 @@ import typing from types import SimpleNamespace -from fastapi import APIRouter +from fastapi import APIRouter, Query from fastapi import Depends from fastapi import Form from fastapi import HTTPException @@ -215,6 +215,7 @@ def run_api_form( name=page_cls.title + " (v3 async)", tags=[page_cls.title], status_code=202, + openapi_extra=page_cls.get_openapi_extra(), ) @app.post( os.path.join(endpoint, "async"), @@ -227,7 +228,7 @@ def run_api_json_async( request: Request, response: Response, page_request: request_model, - example_id: str | None = None, + example_id: str | None = Query(default=None), user: AppUser = Depends(api_auth_header), ): ret = _run_api( @@ -286,7 +287,10 @@ def run_api_form_async( operation_id="status__" + page_cls.slug_versions[0], tags=[page_cls.title], name=page_cls.title + " (v3 status)", - openapi_extra={"x-fern-sdk-return-value": "output"}, + openapi_extra={ + "x-fern-sdk-return-value": "output", + **page_cls.get_openapi_extra(), + }, ) @app.get( os.path.join(endpoint, "status"), From 4aa6f68f75e79b295b42749eb4f308a25c1171a2 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 12 Sep 2024 15:27:06 +0530 Subject: [PATCH 05/38] Add SDK names and OpenAPI extras for endpoints --- recipes/BulkEval.py | 1 + recipes/BulkRunner.py | 2 ++ recipes/ChyronPlant.py | 6 ++++++ recipes/CompareLLM.py | 7 ++++--- recipes/CompareText2Img.py | 1 + recipes/CompareUpscaler.py | 1 + recipes/DeforumSD.py | 1 + recipes/DocExtract.py | 1 + recipes/DocSearch.py | 1 + recipes/DocSummary.py | 1 + recipes/EmailFaceInpainting.py | 1 + recipes/FaceInpainting.py | 1 + recipes/Functions.py | 1 + recipes/GoogleGPT.py | 9 +++------ recipes/GoogleImageGen.py | 1 + recipes/ImageSegmentation.py | 1 + recipes/Img2Img.py | 1 + recipes/LetterWriter.py | 5 +++++ recipes/Lipsync.py | 1 + recipes/LipsyncTTS.py | 1 + recipes/ObjectInpainting.py | 1 + recipes/QRCodeGenerator.py | 2 +- recipes/RelatedQnA.py | 1 + recipes/RelatedQnADoc.py | 1 + recipes/SEOSummary.py | 1 + recipes/SmartGPT.py | 1 + recipes/SocialLookupEmail.py | 1 + recipes/Text2Audio.py | 1 + recipes/TextToSpeech.py | 1 + recipes/Translation.py | 1 + recipes/VideoBots.py | 10 ++++++++++ recipes/asr_page.py | 1 + recipes/embeddings_page.py | 1 + routers/api.py | 8 +++----- 34 files changed, 60 insertions(+), 15 deletions(-) diff --git a/recipes/BulkEval.py b/recipes/BulkEval.py index 7f92a53e1..3510e396f 100644 --- a/recipes/BulkEval.py +++ b/recipes/BulkEval.py @@ -139,6 +139,7 @@ class BulkEvalPage(BasePage): title = "Evaluator" workflow = Workflow.BULK_EVAL slug_versions = ["bulk-eval", "eval"] + sdk_method_name = "eval" explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/aad314f0-9a97-11ee-8318-02420a0001c7/W.I.9.png.png" diff --git a/recipes/BulkRunner.py b/recipes/BulkRunner.py index f700effac..97077e523 100644 --- a/recipes/BulkRunner.py +++ b/recipes/BulkRunner.py @@ -43,6 +43,8 @@ class BulkRunnerPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/87f35df4-88d7-11ee-aac9-02420a00016b/Bulk%20Runner.png.png" workflow = Workflow.BULK_RUNNER slug_versions = ["bulk-runner", "bulk"] + sdk_method_name = "bulkRun" + price = 1 class RequestModel(BasePage.RequestModel): diff --git a/recipes/ChyronPlant.py b/recipes/ChyronPlant.py index fef4714b9..0777b93c7 100644 --- a/recipes/ChyronPlant.py +++ b/recipes/ChyronPlant.py @@ -1,4 +1,5 @@ import gooey_gui as gui +from gooey_gui.components import typing from pydantic import BaseModel from bots.models import Workflow @@ -12,6 +13,7 @@ class ChyronPlantPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/aeb83ee8-889e-11ee-93dc-02420a000143/Youtube%20transcripts%20GPT%20extractions.png.png" workflow = Workflow.CHYRON_PLANT slug_versions = ["ChyronPlant"] + sdk_method_name = "" class RequestModel(BasePage.RequestModel): midi_notes: str @@ -30,6 +32,10 @@ class ResponseModel(BaseModel): midi_translation: str chyron_output: str + @classmethod + def get_openapi_extra(cls) -> dict[str, typing.Any]: + return {"x-fern-ignore": True} + def render_form_v2(self): gui.text_input( """ diff --git a/recipes/CompareLLM.py b/recipes/CompareLLM.py index 513421146..99390009f 100644 --- a/recipes/CompareLLM.py +++ b/recipes/CompareLLM.py @@ -30,6 +30,7 @@ class CompareLLMPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/ae42015e-88d7-11ee-aac9-02420a00016b/Compare%20LLMs.png.png" workflow = Workflow.COMPARE_LLM slug_versions = ["CompareLLM", "llm", "compare-large-language-models"] + sdk_method_name = "llm" functions_in_settings = False @@ -43,9 +44,9 @@ class CompareLLMPage(BasePage): class RequestModelBase(BasePage.RequestModel): input_prompt: str | None - selected_models: ( - list[typing.Literal[tuple(e.name for e in LargeLanguageModels)]] | None - ) + selected_models: list[ + typing.Literal[tuple(e.name for e in LargeLanguageModels)] + ] class RequestModel(LanguageModelSettings, RequestModelBase): pass diff --git a/recipes/CompareText2Img.py b/recipes/CompareText2Img.py index f41a46170..0f132eb04 100644 --- a/recipes/CompareText2Img.py +++ b/recipes/CompareText2Img.py @@ -39,6 +39,7 @@ class CompareText2ImgPage(BasePage): "text2img", "compare-ai-image-generators", ] + sdk_method_name = "textToImage" sane_defaults = { "guidance_scale": 7.5, diff --git a/recipes/CompareUpscaler.py b/recipes/CompareUpscaler.py index 4684ab309..dafcb9977 100644 --- a/recipes/CompareUpscaler.py +++ b/recipes/CompareUpscaler.py @@ -20,6 +20,7 @@ class CompareUpscalerPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/64393e0c-88db-11ee-b428-02420a000168/AI%20Image%20Upscaler.png.png" workflow = Workflow.COMPARE_UPSCALER slug_versions = ["compare-ai-upscalers"] + sdk_method_name = "upscale" class RequestModel(BasePage.RequestModel): input_image: FieldHttpUrl | None = Field(None, description="Input Image") diff --git a/recipes/DeforumSD.py b/recipes/DeforumSD.py index ef4b186b5..4cd930cfe 100644 --- a/recipes/DeforumSD.py +++ b/recipes/DeforumSD.py @@ -166,6 +166,7 @@ class DeforumSDPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/media/users/kxmNIYAOJbfOURxHBKNCWeUSKiP2/dd88c110-88d6-11ee-9b4f-2b58bd50e819/animation.gif" workflow = Workflow.DEFORUM_SD slug_versions = ["DeforumSD", "animation-generator"] + sdk_method_name = "animate" sane_defaults = dict( zoom="0: (1.004)", diff --git a/recipes/DocExtract.py b/recipes/DocExtract.py index 323e3eab9..63c741e38 100644 --- a/recipes/DocExtract.py +++ b/recipes/DocExtract.py @@ -86,6 +86,7 @@ class DocExtractPage(BasePage): "youtube-bot", "doc-extract", ] + sdk_method_name = "synthesizeData" price = 500 class RequestModelBase(BasePage.RequestModel): diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index 04e034dde..81ec19151 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -52,6 +52,7 @@ class DocSearchPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/cbbb4dc6-88d7-11ee-bf6c-02420a000166/Search%20your%20docs%20with%20gpt.png.png" workflow = Workflow.DOC_SEARCH slug_versions = ["doc-search"] + sdk_method_name = "rag" sane_defaults = { "sampling_temperature": 0.1, diff --git a/recipes/DocSummary.py b/recipes/DocSummary.py index 18412e197..04f739af1 100644 --- a/recipes/DocSummary.py +++ b/recipes/DocSummary.py @@ -49,6 +49,7 @@ class DocSummaryPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/1f858a7a-88d8-11ee-a658-02420a000163/Summarize%20your%20docs%20with%20gpt.png.png" workflow = Workflow.DOC_SUMMARY slug_versions = ["doc-summary"] + sdk_method_name = "docSummary" price = 225 diff --git a/recipes/EmailFaceInpainting.py b/recipes/EmailFaceInpainting.py index 3ebc161a9..408a52715 100644 --- a/recipes/EmailFaceInpainting.py +++ b/recipes/EmailFaceInpainting.py @@ -27,6 +27,7 @@ class EmailFaceInpaintingPage(FaceInpaintingPage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/ec0df5aa-9521-11ee-93d3-02420a0001e5/Email%20Profile%20Lookup.png.png" workflow = Workflow.EMAIL_FACE_INPAINTING slug_versions = ["EmailFaceInpainting", "ai-image-from-email-lookup"] + sdk_method_name = "imageFromEmail" sane_defaults = { "num_outputs": 1, diff --git a/recipes/FaceInpainting.py b/recipes/FaceInpainting.py index 8770740fb..ae76609fd 100644 --- a/recipes/FaceInpainting.py +++ b/recipes/FaceInpainting.py @@ -33,6 +33,7 @@ class FaceInpaintingPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/10c2ce06-88da-11ee-b428-02420a000168/ai%20image%20with%20a%20face.png.png" workflow = Workflow.FACE_INPAINTING slug_versions = ["FaceInpainting", "face-in-ai-generated-photo"] + sdk_method_name = "portrait" sane_defaults = { "num_outputs": 1, diff --git a/recipes/Functions.py b/recipes/Functions.py index 356381343..9a40559eb 100644 --- a/recipes/Functions.py +++ b/recipes/Functions.py @@ -22,6 +22,7 @@ class FunctionsPage(BasePage): title = "Functions" workflow = Workflow.FUNCTIONS slug_versions = ["functions", "tools", "function", "fn", "functions"] + sdk_method_name = "functions" show_settings = False price = 1 diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py index 6a9eaf999..8dc4288cf 100644 --- a/recipes/GoogleGPT.py +++ b/recipes/GoogleGPT.py @@ -1,7 +1,7 @@ import typing from furl import furl -from pydantic import BaseModel, Field +from pydantic import BaseModel import gooey_gui as gui from bots.models import Workflow @@ -11,11 +11,7 @@ doc_search_advanced_settings, ) from daras_ai_v2.embedding_model import EmbeddingModels -from daras_ai_v2.language_model import ( - run_language_model, - LargeLanguageModels, - ResponseFormatType, -) +from daras_ai_v2.language_model import run_language_model, LargeLanguageModels from daras_ai_v2.language_model_settings_widgets import ( language_model_settings, language_model_selector, @@ -51,6 +47,7 @@ class GoogleGPTPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/28649544-9406-11ee-bba3-02420a0001cc/Websearch%20GPT%20option%202.png.png" workflow = Workflow.GOOGLE_GPT slug_versions = ["google-gpt"] + sdk_method_name = "webSearchLLM" price = 175 diff --git a/recipes/GoogleImageGen.py b/recipes/GoogleImageGen.py index a94e812ce..b719d2b58 100644 --- a/recipes/GoogleImageGen.py +++ b/recipes/GoogleImageGen.py @@ -37,6 +37,7 @@ class GoogleImageGenPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/eb23c078-88da-11ee-aa86-02420a000165/web%20search%20render.png.png" workflow = Workflow.GOOGLE_IMAGE_GEN slug_versions = ["GoogleImageGen", "render-images-with-ai"] + sdk_method_name = "imageFromWebSearch" sane_defaults = dict( num_outputs=1, diff --git a/recipes/ImageSegmentation.py b/recipes/ImageSegmentation.py index f886c2b4e..7db0a16c6 100644 --- a/recipes/ImageSegmentation.py +++ b/recipes/ImageSegmentation.py @@ -37,6 +37,7 @@ class ImageSegmentationPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/06fc595e-88db-11ee-b428-02420a000168/AI%20Background%20Remover.png.png" workflow = Workflow.IMAGE_SEGMENTATION slug_versions = ["ImageSegmentation", "remove-image-background-with-ai"] + sdk_method_name = "removeBackground" sane_defaults = { "mask_threshold": 0.5, diff --git a/recipes/Img2Img.py b/recipes/Img2Img.py index 97de89ab7..7cb9283a7 100644 --- a/recipes/Img2Img.py +++ b/recipes/Img2Img.py @@ -27,6 +27,7 @@ class Img2ImgPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/bcc9351a-88d9-11ee-bf6c-02420a000166/Edit%20an%20image%20with%20AI%201.png.png" workflow = Workflow.IMG_2_IMG slug_versions = ["Img2Img", "ai-photo-editor"] + sdk_method_name = "remixImage" sane_defaults = { "num_outputs": 1, diff --git a/recipes/LetterWriter.py b/recipes/LetterWriter.py index ff39fb3aa..53a0296af 100644 --- a/recipes/LetterWriter.py +++ b/recipes/LetterWriter.py @@ -18,6 +18,7 @@ class LetterWriterPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/aeb83ee8-889e-11ee-93dc-02420a000143/Youtube%20transcripts%20GPT%20extractions.png.png" workflow = Workflow.LETTER_WRITER slug_versions = ["LetterWriter"] + sdk_method_name = "" class RequestModel(BasePage.RequestModel): action_id: str @@ -46,6 +47,10 @@ class ResponseModel(BaseModel): generated_input_prompt: str final_prompt: str + @classmethod + def get_openapi_extra(cls) -> dict[str, typing.Any]: + return {"x-fern-ignore": True} + def render_description(self): gui.write( """ diff --git a/recipes/Lipsync.py b/recipes/Lipsync.py index df1354272..3fdf852f7 100644 --- a/recipes/Lipsync.py +++ b/recipes/Lipsync.py @@ -22,6 +22,7 @@ class LipsyncPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/f33e6332-88d8-11ee-89f9-02420a000169/Lipsync%20TTS.png.png" workflow = Workflow.LIPSYNC slug_versions = ["Lipsync"] + sdk_method_name = "lipsync" class RequestModel(LipsyncSettings, BasePage.RequestModel): selected_model: typing.Literal[tuple(e.name for e in LipsyncModel)] = ( diff --git a/recipes/LipsyncTTS.py b/recipes/LipsyncTTS.py index 79983f36e..1200da5c3 100644 --- a/recipes/LipsyncTTS.py +++ b/recipes/LipsyncTTS.py @@ -23,6 +23,7 @@ class LipsyncTTSPage(LipsyncPage, TextToSpeechPage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/1acfa370-88d9-11ee-bf6c-02420a000166/Lipsync%20with%20audio%201.png.png" workflow = Workflow.LIPSYNC_TTS slug_versions = ["LipsyncTTS", "lipsync-maker"] + sdk_method_name = "lipsyncTTS" sane_defaults = { "elevenlabs_model": "eleven_multilingual_v2", diff --git a/recipes/ObjectInpainting.py b/recipes/ObjectInpainting.py index 3ec1f6f89..712399047 100644 --- a/recipes/ObjectInpainting.py +++ b/recipes/ObjectInpainting.py @@ -33,6 +33,7 @@ class ObjectInpaintingPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/f07b731e-88d9-11ee-a658-02420a000163/W.I.3.png.png" workflow = Workflow.OBJECT_INPAINTING slug_versions = ["ObjectInpainting", "product-photo-background-generator"] + sdk_method_name = "productImage" sane_defaults = { "mask_threshold": 0.7, diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index cfa9ce013..12696e167 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -37,7 +37,6 @@ Schedulers, ) from daras_ai_v2.vcard import VCARD -from recipes.EmailFaceInpainting import get_photo_for_email from recipes.SocialLookupEmail import get_profile_for_email from url_shortener.models import ShortenedURL from daras_ai_v2.enum_selector_widget import enum_multiselect @@ -58,6 +57,7 @@ class QRCodeGeneratorPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/03d6538e-88d5-11ee-ad97-02420a00016c/W.I.2.png.png" workflow = Workflow.QR_CODE slug_versions = ["art-qr-code", "qr", "qr-code"] + sdk_method_name = "qrCode" sane_defaults = dict( num_outputs=2, diff --git a/recipes/RelatedQnA.py b/recipes/RelatedQnA.py index 6372ce65e..0da240910 100644 --- a/recipes/RelatedQnA.py +++ b/recipes/RelatedQnA.py @@ -28,6 +28,7 @@ class RelatedQnAPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/37b0ba22-88d6-11ee-b549-02420a000167/People%20also%20ask.png.png" workflow = Workflow.RELATED_QNA_MAKER slug_versions = ["related-qna-maker"] + sdk_method_name = "seoPeopleAlsoAsk" price = 75 diff --git a/recipes/RelatedQnADoc.py b/recipes/RelatedQnADoc.py index 3f8c2d2d8..9dbdf3e5a 100644 --- a/recipes/RelatedQnADoc.py +++ b/recipes/RelatedQnADoc.py @@ -27,6 +27,7 @@ class RelatedQnADocPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/aeb83ee8-889e-11ee-93dc-02420a000143/Youtube%20transcripts%20GPT%20extractions.png.png" workflow = Workflow.RELATED_QNA_MAKER_DOC slug_versions = ["related-qna-maker-doc"] + sdk_method_name = "seoPeopleAlsoAskDoc" price = 100 diff --git a/recipes/SEOSummary.py b/recipes/SEOSummary.py index 1b4e65365..a6b3968e7 100644 --- a/recipes/SEOSummary.py +++ b/recipes/SEOSummary.py @@ -61,6 +61,7 @@ class SEOSummaryPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/85f38b42-88d6-11ee-ad97-02420a00016c/Create%20SEO%20optimized%20content%20option%202.png.png" workflow = Workflow.SEO_SUMMARY slug_versions = ["SEOSummary", "seo-paragraph-generator"] + sdk_method_name = "seoContent" def preview_image(self, state: dict) -> str | None: return SEO_SUMMARY_DEFAULT_META_IMG diff --git a/recipes/SmartGPT.py b/recipes/SmartGPT.py index fa4066fc7..75a75e965 100644 --- a/recipes/SmartGPT.py +++ b/recipes/SmartGPT.py @@ -29,6 +29,7 @@ class SmartGPTPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/ffd24ad8-88d7-11ee-a658-02420a000163/SmartGPT.png.png" workflow = Workflow.SMART_GPT slug_versions = ["SmartGPT"] + sdk_method_name = "smartGPT" price = 20 class RequestModelBase(BasePage.RequestModel): diff --git a/recipes/SocialLookupEmail.py b/recipes/SocialLookupEmail.py index bc3a0dea1..98a79b0c8 100644 --- a/recipes/SocialLookupEmail.py +++ b/recipes/SocialLookupEmail.py @@ -32,6 +32,7 @@ class SocialLookupEmailPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/5fbd475a-88d7-11ee-aac9-02420a00016b/personalized%20email.png.png" workflow = Workflow.SOCIAL_LOOKUP_EMAIL slug_versions = ["SocialLookupEmail", "email-writer-with-profile-lookup"] + sdk_method_name = "personalizeEmail" sane_defaults = { "selected_model": LargeLanguageModels.gpt_4.name, diff --git a/recipes/Text2Audio.py b/recipes/Text2Audio.py index f3199a95b..f585de99e 100644 --- a/recipes/Text2Audio.py +++ b/recipes/Text2Audio.py @@ -32,6 +32,7 @@ class Text2AudioPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/a4481d58-88d9-11ee-aa86-02420a000165/Text%20guided%20audio%20generator.png.png" workflow = Workflow.TEXT_2_AUDIO slug_versions = ["text2audio"] + sdk_method_name = "textToMusic" sane_defaults = dict( seed=42, diff --git a/recipes/TextToSpeech.py b/recipes/TextToSpeech.py index dbc877a17..5aba6d4df 100644 --- a/recipes/TextToSpeech.py +++ b/recipes/TextToSpeech.py @@ -69,6 +69,7 @@ class TextToSpeechPage(BasePage): "text2speech", "compare-text-to-speech-engines", ] + sdk_method_name = "textToSpeech" sane_defaults = { "tts_provider": TextToSpeechProviders.GOOGLE_TTS.value, diff --git a/recipes/Translation.py b/recipes/Translation.py index b8061beb2..a58edb96e 100644 --- a/recipes/Translation.py +++ b/recipes/Translation.py @@ -38,6 +38,7 @@ class TranslationPage(BasePage): title = "Compare AI Translations" workflow = Workflow.TRANSLATION slug_versions = ["translate", "translation", "compare-ai-translation"] + sdk_method_name = "translate" class RequestModelBase(BasePage.RequestModel): texts: list[str] = Field([]) diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 41e7a33a8..8b5593af7 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -137,6 +137,9 @@ class VideoBotsPage(BasePage): workflow = Workflow.VIDEO_BOTS slug_versions = ["video-bots", "bots", "copilot"] + sdk_group_name = "copilot" + sdk_method_name = "completion" + functions_in_settings = False sane_defaults = { @@ -287,6 +290,13 @@ class ResponseModel(BaseModel): finish_reason: list[str] | None + @classmethod + def get_openapi_extra(cls) -> dict[str, typing.Any]: + return { + "x-sdk-group-name": cls.sdk_group_name, + "x-sdk-method-name": cls.sdk_method_name, + } + def preview_image(self, state: dict) -> str | None: return DEFAULT_COPILOT_META_IMG diff --git a/recipes/asr_page.py b/recipes/asr_page.py index 58d49ffa4..6590e8a72 100644 --- a/recipes/asr_page.py +++ b/recipes/asr_page.py @@ -38,6 +38,7 @@ class AsrPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/5fb7e5f6-88d9-11ee-aa86-02420a000165/Speech.png.png" workflow = Workflow.ASR slug_versions = ["asr", "speech"] + sdk_method_name = "speechRecognition" sane_defaults = dict(output_format=AsrOutputFormat.text.name) diff --git a/recipes/embeddings_page.py b/recipes/embeddings_page.py index e65580681..842677e2a 100644 --- a/recipes/embeddings_page.py +++ b/recipes/embeddings_page.py @@ -17,6 +17,7 @@ class EmbeddingsPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/aeb83ee8-889e-11ee-93dc-02420a000143/Youtube%20transcripts%20GPT%20extractions.png.png" workflow = Workflow.EMBEDDINGS slug_versions = ["embeddings", "embed", "text-embedings"] + sdk_method_name = "embed" price = 1 class RequestModel(BasePage.RequestModel): diff --git a/routers/api.py b/routers/api.py index c1d2b7b99..dcbc923b4 100644 --- a/routers/api.py +++ b/routers/api.py @@ -152,6 +152,7 @@ def script_to_api(page_cls: typing.Type[BasePage]): tags=[page_cls.title], name=page_cls.title + " (v2 sync)", openapi_extra={"x-fern-ignore": True}, + include_in_schema=False, ) def run_api_json( request: Request, @@ -258,10 +259,7 @@ def run_api_form_async( operation_id="status__" + page_cls.slug_versions[0], tags=[page_cls.title], name=page_cls.title + " (v3 status)", - openapi_extra={ - "x-fern-sdk-return-value": "output", - **page_cls.get_openapi_extra(), - }, + openapi_extra={"x-fern-ignore": True}, ) def get_run_status( run_id: str, @@ -445,6 +443,6 @@ def get_balance(user: AppUser = Depends(api_auth_header)): return BalanceResponse(balance=user.balance) -@app.get("/status") +@app.get("/status", openapi_extra={"x-fern-ignore": True}) async def health(): return "OK" From 14344aa030b4a8d6e7479728c5772b0ee8353558 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 12 Sep 2024 15:28:07 +0530 Subject: [PATCH 06/38] Fix get_openapi_extra method on BasePage to use sdk_method_name --- daras_ai_v2/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index d4b38a784..ce2500c83 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -120,6 +120,7 @@ class BasePage: title: str workflow: Workflow slug_versions: list[str] + sdk_method_name: str sane_defaults: dict = {} @@ -307,8 +308,7 @@ def sentry_event_set_user(self, event, hint): @classmethod def get_openapi_extra(cls) -> dict[str, typing.Any]: return { - "x-fern-sdk-group-name": cls.slug_versions[-1].title().replace("-", ""), - "x-fern-sdk-method-name": "status", + "x-fern-sdk-method-name": cls.sdk_method_name, } def refresh_state(self): From 0c715a418e9c206754d6842a2f489e998840d5dc Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 12 Sep 2024 15:29:02 +0530 Subject: [PATCH 07/38] fix method name for get_balance endpoint --- routers/api.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/routers/api.py b/routers/api.py index dcbc923b4..dfb55315a 100644 --- a/routers/api.py +++ b/routers/api.py @@ -438,7 +438,12 @@ class BalanceResponse(BaseModel): balance: int = Field(description="Current balance in credits") -@app.get("/v1/balance/", response_model=BalanceResponse, tags=["Misc"]) +@app.get( + "/v1/balance/", + response_model=BalanceResponse, + tags=["Misc"], + openapi_extra={"x-fern-sdk-method-name": "getBalance"}, +) def get_balance(user: AppUser = Depends(api_auth_header)): return BalanceResponse(balance=user.balance) From c0ac5d9363508c9beeadd128762ba0d1ba166e61 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 12 Sep 2024 15:29:31 +0530 Subject: [PATCH 08/38] fix: ignore broadcast and bot APIs for SDK generation --- routers/bots_api.py | 2 ++ routers/broadcast_api.py | 1 + 2 files changed, 3 insertions(+) diff --git a/routers/bots_api.py b/routers/bots_api.py index 780a7b918..f913708ad 100644 --- a/routers/bots_api.py +++ b/routers/bots_api.py @@ -85,6 +85,7 @@ class CreateStreamResponse(BaseModel): operation_id=VideoBotsPage.slug_versions[0] + "__stream_create", tags=["Copilot Integrations"], name="Copilot Integrations Create Stream", + openapi_extra={"x-fern-ignore": True}, ) def stream_create(request: CreateStreamRequest, response: Response): request_id = str(uuid.uuid4()) @@ -173,6 +174,7 @@ class StreamError(BaseModel): operation_id=VideoBotsPage.slug_versions[0] + "__stream", tags=["Copilot Integrations"], name="Copilot integrations Stream Response", + openapi_extra={"x-fern-ignore": True}, ) def stream_response(request_id: str): r = get_redis_cache().getdel(f"gooey/stream-init/v1/{request_id}") diff --git a/routers/broadcast_api.py b/routers/broadcast_api.py index cbf1ca07c..82cf8ea59 100644 --- a/routers/broadcast_api.py +++ b/routers/broadcast_api.py @@ -51,6 +51,7 @@ class BotBroadcastRequestModel(BaseModel): operation_id=VideoBotsPage.slug_versions[0] + "__broadcast", tags=["Misc"], name=f"Send Broadcast Message", + openapi_extra={"x-fern-ignore": True}, ) @app.post( f"/v2/{VideoBotsPage.slug_versions[0]}/broadcast/send", From 9d99420d2ae12ac6017d11924438788e75b7c1bb Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 12 Sep 2024 15:33:32 +0530 Subject: [PATCH 09/38] use GooeyEnum for LargeLanguageModels with api_enum() method to generated named Enums --- daras_ai_v2/custom_enum.py | 25 +++++++++++++++++++++++++ daras_ai_v2/language_model.py | 3 ++- recipes/BulkEval.py | 4 +--- recipes/CompareLLM.py | 8 ++------ recipes/DocExtract.py | 4 +--- recipes/DocSearch.py | 4 +--- recipes/DocSummary.py | 4 +--- recipes/GoogleGPT.py | 4 +--- recipes/SEOSummary.py | 4 +--- recipes/SmartGPT.py | 4 +--- recipes/SocialLookupEmail.py | 4 +--- recipes/VideoBots.py | 4 +--- 12 files changed, 38 insertions(+), 34 deletions(-) diff --git a/daras_ai_v2/custom_enum.py b/daras_ai_v2/custom_enum.py index b9aacb843..c42c0a0a1 100644 --- a/daras_ai_v2/custom_enum.py +++ b/daras_ai_v2/custom_enum.py @@ -23,6 +23,31 @@ def from_db(cls, db_value) -> typing_extensions.Self: def api_choices(cls): return typing.Literal[tuple(e.name for e in cls)] + @classmethod + def api_enum( + cls, name: str | None = None, include_deprecated: bool = False + ) -> Enum: + """ + Dynamic Enum that maps from `model.name` to `model.name`. + + Preferred over `api_choices` because the Enum's name propagates to the + OpenAPI schema and to the SDK. + + The default enum maps from label->model. + """ + try: + deprecated_items = cls._deprecated() + except AttributeError: + deprecated_items = set() + return Enum( + name or cls.__name__, + { + e.name: e.name + for e in cls + if include_deprecated or e not in deprecated_items + }, + ) + @classmethod def from_api(cls, name: str) -> typing_extensions.Self: for e in cls: diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index d6eb526d2..f02b2c923 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -15,6 +15,7 @@ vertex_ai_should_retry, try_all, ) +from daras_ai_v2.custom_enum import GooeyEnum from django.conf import settings from loguru import logger from openai.types.chat import ( @@ -72,7 +73,7 @@ class LLMSpec(typing.NamedTuple): supports_json: bool = False -class LargeLanguageModels(Enum): +class LargeLanguageModels(GooeyEnum): # https://platform.openai.com/docs/models/gpt-4o gpt_4_o = LLMSpec( label="GPT-4o (openai)", diff --git a/recipes/BulkEval.py b/recipes/BulkEval.py index 3510e396f..42b3e9d0a 100644 --- a/recipes/BulkEval.py +++ b/recipes/BulkEval.py @@ -186,9 +186,7 @@ class RequestModelBase(BasePage.RequestModel): """, ) - selected_model: ( - typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None - ) + selected_model: LargeLanguageModels.api_enum() | None class RequestModel(LanguageModelSettings, RequestModelBase): pass diff --git a/recipes/CompareLLM.py b/recipes/CompareLLM.py index 99390009f..c14d2f29b 100644 --- a/recipes/CompareLLM.py +++ b/recipes/CompareLLM.py @@ -44,17 +44,13 @@ class CompareLLMPage(BasePage): class RequestModelBase(BasePage.RequestModel): input_prompt: str | None - selected_models: list[ - typing.Literal[tuple(e.name for e in LargeLanguageModels)] - ] + selected_models: list[LargeLanguageModels.api_enum()] | None class RequestModel(LanguageModelSettings, RequestModelBase): pass class ResponseModel(BaseModel): - output_text: dict[ - typing.Literal[tuple(e.name for e in LargeLanguageModels)], list[str] - ] + output_text: dict[LargeLanguageModels.api_enum(), list[str]] def preview_image(self, state: dict) -> str | None: return DEFAULT_COMPARE_LM_META_IMG diff --git a/recipes/DocExtract.py b/recipes/DocExtract.py index 63c741e38..66e7a0a55 100644 --- a/recipes/DocExtract.py +++ b/recipes/DocExtract.py @@ -105,9 +105,7 @@ class RequestModelBase(BasePage.RequestModel): task_instructions: str | None - selected_model: ( - typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None - ) + selected_model: LargeLanguageModels.api_enum() | None class RequestModel(LanguageModelSettings, RequestModelBase): pass diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index 81ec19151..ea2c21558 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -72,9 +72,7 @@ class RequestModelBase(DocSearchRequest, BasePage.RequestModel): task_instructions: str | None query_instructions: str | None - selected_model: ( - typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None - ) + selected_model: LargeLanguageModels.api_enum() | None citation_style: typing.Literal[tuple(e.name for e in CitationStyles)] | None diff --git a/recipes/DocSummary.py b/recipes/DocSummary.py index 04f739af1..d64680e3c 100644 --- a/recipes/DocSummary.py +++ b/recipes/DocSummary.py @@ -69,9 +69,7 @@ class RequestModelBase(BasePage.RequestModel): task_instructions: str | None merge_instructions: str | None - selected_model: ( - typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None - ) + selected_model: LargeLanguageModels.api_enum() | None chain_type: typing.Literal[tuple(e.name for e in CombineDocumentsChains)] | None diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py index 8dc4288cf..470aec1e3 100644 --- a/recipes/GoogleGPT.py +++ b/recipes/GoogleGPT.py @@ -82,9 +82,7 @@ class RequestModelBase(BasePage.RequestModel): task_instructions: str | None query_instructions: str | None - selected_model: ( - typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None - ) + selected_model: LargeLanguageModels.api_enum() | None max_search_urls: int | None diff --git a/recipes/SEOSummary.py b/recipes/SEOSummary.py index a6b3968e7..abc8020c4 100644 --- a/recipes/SEOSummary.py +++ b/recipes/SEOSummary.py @@ -100,9 +100,7 @@ class RequestModelBase(BaseModel): enable_html: bool | None - selected_model: ( - typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None - ) + selected_model: LargeLanguageModels.api_enum() | None max_search_urls: int | None diff --git a/recipes/SmartGPT.py b/recipes/SmartGPT.py index 75a75e965..74445fb64 100644 --- a/recipes/SmartGPT.py +++ b/recipes/SmartGPT.py @@ -39,9 +39,7 @@ class RequestModelBase(BasePage.RequestModel): reflexion_prompt: str | None dera_prompt: str | None - selected_model: ( - typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None - ) + selected_model: LargeLanguageModels.api_enum() | None class RequestModel(LanguageModelSettings, RequestModelBase): pass diff --git a/recipes/SocialLookupEmail.py b/recipes/SocialLookupEmail.py index 98a79b0c8..aab4a8432 100644 --- a/recipes/SocialLookupEmail.py +++ b/recipes/SocialLookupEmail.py @@ -55,9 +55,7 @@ class RequestModelBase(BasePage.RequestModel): # domain: str | None # key_words: str | None - selected_model: ( - typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None - ) + selected_model: LargeLanguageModels.api_enum() | None class RequestModel(LanguageModelSettings, RequestModelBase): pass diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 8b5593af7..5439bc1db 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -197,9 +197,7 @@ class RequestModelBase(BasePage.RequestModel): bot_script: str | None # llm model - selected_model: ( - typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None - ) + selected_model: LargeLanguageModels.api_enum() | None document_model: str | None = Field( title="🩻 Photo / Document Intelligence", description="When your copilot users upload a photo or pdf, what kind of document are they mostly likely to upload? " From e1edb2ccc69faf1ee123cb4c24cc2dfe4ffb165c Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 12 Sep 2024 15:34:04 +0530 Subject: [PATCH 10/38] fix sdk method name for copilot endpoint --- recipes/VideoBots.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 5439bc1db..9ff17cdad 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -291,8 +291,8 @@ class ResponseModel(BaseModel): @classmethod def get_openapi_extra(cls) -> dict[str, typing.Any]: return { - "x-sdk-group-name": cls.sdk_group_name, - "x-sdk-method-name": cls.sdk_method_name, + "x-fern-sdk-group-name": cls.sdk_group_name, + "x-fern-sdk-method-name": cls.sdk_method_name, } def preview_image(self, state: dict) -> str | None: From 325cfd24fdbf0f1eda33520bc7b371f28cd814b9 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 12 Sep 2024 18:05:24 +0530 Subject: [PATCH 11/38] Revert "use GooeyEnum for LargeLanguageModels with api_enum() method to generated named Enums" This reverts commit 9d99420d2ae12ac6017d11924438788e75b7c1bb. --- daras_ai_v2/custom_enum.py | 25 ------------------------- daras_ai_v2/language_model.py | 3 +-- recipes/BulkEval.py | 4 +++- recipes/CompareLLM.py | 8 ++++++-- recipes/DocExtract.py | 4 +++- recipes/DocSearch.py | 4 +++- recipes/DocSummary.py | 4 +++- recipes/GoogleGPT.py | 4 +++- recipes/SEOSummary.py | 4 +++- recipes/SmartGPT.py | 4 +++- recipes/SocialLookupEmail.py | 4 +++- recipes/VideoBots.py | 4 +++- 12 files changed, 34 insertions(+), 38 deletions(-) diff --git a/daras_ai_v2/custom_enum.py b/daras_ai_v2/custom_enum.py index c42c0a0a1..b9aacb843 100644 --- a/daras_ai_v2/custom_enum.py +++ b/daras_ai_v2/custom_enum.py @@ -23,31 +23,6 @@ def from_db(cls, db_value) -> typing_extensions.Self: def api_choices(cls): return typing.Literal[tuple(e.name for e in cls)] - @classmethod - def api_enum( - cls, name: str | None = None, include_deprecated: bool = False - ) -> Enum: - """ - Dynamic Enum that maps from `model.name` to `model.name`. - - Preferred over `api_choices` because the Enum's name propagates to the - OpenAPI schema and to the SDK. - - The default enum maps from label->model. - """ - try: - deprecated_items = cls._deprecated() - except AttributeError: - deprecated_items = set() - return Enum( - name or cls.__name__, - { - e.name: e.name - for e in cls - if include_deprecated or e not in deprecated_items - }, - ) - @classmethod def from_api(cls, name: str) -> typing_extensions.Self: for e in cls: diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index f02b2c923..d6eb526d2 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -15,7 +15,6 @@ vertex_ai_should_retry, try_all, ) -from daras_ai_v2.custom_enum import GooeyEnum from django.conf import settings from loguru import logger from openai.types.chat import ( @@ -73,7 +72,7 @@ class LLMSpec(typing.NamedTuple): supports_json: bool = False -class LargeLanguageModels(GooeyEnum): +class LargeLanguageModels(Enum): # https://platform.openai.com/docs/models/gpt-4o gpt_4_o = LLMSpec( label="GPT-4o (openai)", diff --git a/recipes/BulkEval.py b/recipes/BulkEval.py index 42b3e9d0a..3510e396f 100644 --- a/recipes/BulkEval.py +++ b/recipes/BulkEval.py @@ -186,7 +186,9 @@ class RequestModelBase(BasePage.RequestModel): """, ) - selected_model: LargeLanguageModels.api_enum() | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) class RequestModel(LanguageModelSettings, RequestModelBase): pass diff --git a/recipes/CompareLLM.py b/recipes/CompareLLM.py index c14d2f29b..99390009f 100644 --- a/recipes/CompareLLM.py +++ b/recipes/CompareLLM.py @@ -44,13 +44,17 @@ class CompareLLMPage(BasePage): class RequestModelBase(BasePage.RequestModel): input_prompt: str | None - selected_models: list[LargeLanguageModels.api_enum()] | None + selected_models: list[ + typing.Literal[tuple(e.name for e in LargeLanguageModels)] + ] class RequestModel(LanguageModelSettings, RequestModelBase): pass class ResponseModel(BaseModel): - output_text: dict[LargeLanguageModels.api_enum(), list[str]] + output_text: dict[ + typing.Literal[tuple(e.name for e in LargeLanguageModels)], list[str] + ] def preview_image(self, state: dict) -> str | None: return DEFAULT_COMPARE_LM_META_IMG diff --git a/recipes/DocExtract.py b/recipes/DocExtract.py index 66e7a0a55..63c741e38 100644 --- a/recipes/DocExtract.py +++ b/recipes/DocExtract.py @@ -105,7 +105,9 @@ class RequestModelBase(BasePage.RequestModel): task_instructions: str | None - selected_model: LargeLanguageModels.api_enum() | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) class RequestModel(LanguageModelSettings, RequestModelBase): pass diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index ea2c21558..81ec19151 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -72,7 +72,9 @@ class RequestModelBase(DocSearchRequest, BasePage.RequestModel): task_instructions: str | None query_instructions: str | None - selected_model: LargeLanguageModels.api_enum() | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) citation_style: typing.Literal[tuple(e.name for e in CitationStyles)] | None diff --git a/recipes/DocSummary.py b/recipes/DocSummary.py index d64680e3c..04f739af1 100644 --- a/recipes/DocSummary.py +++ b/recipes/DocSummary.py @@ -69,7 +69,9 @@ class RequestModelBase(BasePage.RequestModel): task_instructions: str | None merge_instructions: str | None - selected_model: LargeLanguageModels.api_enum() | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) chain_type: typing.Literal[tuple(e.name for e in CombineDocumentsChains)] | None diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py index 470aec1e3..8dc4288cf 100644 --- a/recipes/GoogleGPT.py +++ b/recipes/GoogleGPT.py @@ -82,7 +82,9 @@ class RequestModelBase(BasePage.RequestModel): task_instructions: str | None query_instructions: str | None - selected_model: LargeLanguageModels.api_enum() | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) max_search_urls: int | None diff --git a/recipes/SEOSummary.py b/recipes/SEOSummary.py index abc8020c4..a6b3968e7 100644 --- a/recipes/SEOSummary.py +++ b/recipes/SEOSummary.py @@ -100,7 +100,9 @@ class RequestModelBase(BaseModel): enable_html: bool | None - selected_model: LargeLanguageModels.api_enum() | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) max_search_urls: int | None diff --git a/recipes/SmartGPT.py b/recipes/SmartGPT.py index 74445fb64..75a75e965 100644 --- a/recipes/SmartGPT.py +++ b/recipes/SmartGPT.py @@ -39,7 +39,9 @@ class RequestModelBase(BasePage.RequestModel): reflexion_prompt: str | None dera_prompt: str | None - selected_model: LargeLanguageModels.api_enum() | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) class RequestModel(LanguageModelSettings, RequestModelBase): pass diff --git a/recipes/SocialLookupEmail.py b/recipes/SocialLookupEmail.py index aab4a8432..98a79b0c8 100644 --- a/recipes/SocialLookupEmail.py +++ b/recipes/SocialLookupEmail.py @@ -55,7 +55,9 @@ class RequestModelBase(BasePage.RequestModel): # domain: str | None # key_words: str | None - selected_model: LargeLanguageModels.api_enum() | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) class RequestModel(LanguageModelSettings, RequestModelBase): pass diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 9ff17cdad..dfa52045e 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -197,7 +197,9 @@ class RequestModelBase(BasePage.RequestModel): bot_script: str | None # llm model - selected_model: LargeLanguageModels.api_enum() | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) document_model: str | None = Field( title="🩻 Photo / Document Intelligence", description="When your copilot users upload a photo or pdf, what kind of document are they mostly likely to upload? " From 48abd5973c70a601692e2e5b760eee4cb2f36e9b Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 12 Sep 2024 19:18:37 +0530 Subject: [PATCH 12/38] use api_enum method on GooeyEnum class, use it with LargeLanguageModels --- daras_ai_v2/base.py | 3 +++ daras_ai_v2/custom_enum.py | 23 +++++++++++++++++++++++ daras_ai_v2/language_model.py | 3 ++- recipes/BulkEval.py | 4 +--- recipes/CompareLLM.py | 8 ++------ recipes/DocExtract.py | 4 +--- recipes/DocSearch.py | 6 ++---- recipes/DocSummary.py | 7 ++----- recipes/GoogleGPT.py | 6 ++---- recipes/SEOSummary.py | 4 +--- recipes/SmartGPT.py | 6 ++---- recipes/SocialLookupEmail.py | 6 ++---- recipes/VideoBots.py | 4 +--- 13 files changed, 44 insertions(+), 40 deletions(-) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index ce2500c83..e65dab0bb 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -138,6 +138,9 @@ class BasePage: ) class RequestModel(BaseModel): + class Config: + use_enum_values = True + functions: list[RecipeFunction] | None = Field( title="🧩 Developer Tools and Functions", ) diff --git a/daras_ai_v2/custom_enum.py b/daras_ai_v2/custom_enum.py index b9aacb843..56255d26d 100644 --- a/daras_ai_v2/custom_enum.py +++ b/daras_ai_v2/custom_enum.py @@ -23,6 +23,29 @@ def from_db(cls, db_value) -> typing_extensions.Self: def api_choices(cls): return typing.Literal[tuple(e.name for e in cls)] + @classmethod + @property + def api_enum(cls): + """ + Enum that is useful as a type in API requests. + + Maps `name`->`name` for all values. + The title (same as the Enum class name) will be + used as the new Enum's title. This will be passed + on to the OpenAPI schema and the generated SDK. + """ + # this cache is a hack to get around a bug where + # dynamic Enums with the same name crash when + # generating the OpenAPI spec + if not hasattr(cls, "_cached_api_enum"): + cls._cached_api_enum = {} + if cls.__name__ not in cls._cached_api_enum: + cls._cached_api_enum[cls.__name__] = Enum( + cls.__name__, {e.name: e.name for e in cls} + ) + + return cls._cached_api_enum[cls.__name__] + @classmethod def from_api(cls, name: str) -> typing_extensions.Self: for e in cls: diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index d6eb526d2..8bbf650fc 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -25,6 +25,7 @@ from daras_ai.image_input import gs_url_to_uri, bytes_to_cv2_img, cv2_img_to_bytes from daras_ai_v2.asr import get_google_auth_session +from daras_ai_v2.custom_enum import GooeyEnum from daras_ai_v2.exceptions import raise_for_status, UserError from daras_ai_v2.gpu_server import call_celery_task from daras_ai_v2.text_splitter import ( @@ -72,7 +73,7 @@ class LLMSpec(typing.NamedTuple): supports_json: bool = False -class LargeLanguageModels(Enum): +class LargeLanguageModels(GooeyEnum): # https://platform.openai.com/docs/models/gpt-4o gpt_4_o = LLMSpec( label="GPT-4o (openai)", diff --git a/recipes/BulkEval.py b/recipes/BulkEval.py index 3510e396f..62aabe8c3 100644 --- a/recipes/BulkEval.py +++ b/recipes/BulkEval.py @@ -186,9 +186,7 @@ class RequestModelBase(BasePage.RequestModel): """, ) - selected_model: ( - typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None - ) + selected_model: LargeLanguageModels.api_enum | None class RequestModel(LanguageModelSettings, RequestModelBase): pass diff --git a/recipes/CompareLLM.py b/recipes/CompareLLM.py index 99390009f..0f46f4e48 100644 --- a/recipes/CompareLLM.py +++ b/recipes/CompareLLM.py @@ -44,17 +44,13 @@ class CompareLLMPage(BasePage): class RequestModelBase(BasePage.RequestModel): input_prompt: str | None - selected_models: list[ - typing.Literal[tuple(e.name for e in LargeLanguageModels)] - ] + selected_models: list[LargeLanguageModels.api_enum] | None class RequestModel(LanguageModelSettings, RequestModelBase): pass class ResponseModel(BaseModel): - output_text: dict[ - typing.Literal[tuple(e.name for e in LargeLanguageModels)], list[str] - ] + output_text: dict[LargeLanguageModels.api_choices, list[str]] def preview_image(self, state: dict) -> str | None: return DEFAULT_COMPARE_LM_META_IMG diff --git a/recipes/DocExtract.py b/recipes/DocExtract.py index 63c741e38..a0d210fe0 100644 --- a/recipes/DocExtract.py +++ b/recipes/DocExtract.py @@ -105,9 +105,7 @@ class RequestModelBase(BasePage.RequestModel): task_instructions: str | None - selected_model: ( - typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None - ) + selected_model: LargeLanguageModels.api_enum | None class RequestModel(LanguageModelSettings, RequestModelBase): pass diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index 81ec19151..591574a6b 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -2,7 +2,7 @@ import typing from furl import furl -from pydantic import BaseModel +from pydantic import BaseModel, Field import gooey_gui as gui from bots.models import Workflow @@ -72,9 +72,7 @@ class RequestModelBase(DocSearchRequest, BasePage.RequestModel): task_instructions: str | None query_instructions: str | None - selected_model: ( - typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None - ) + selected_model: LargeLanguageModels.api_enum | None citation_style: typing.Literal[tuple(e.name for e in CitationStyles)] | None diff --git a/recipes/DocSummary.py b/recipes/DocSummary.py index 04f739af1..d34368914 100644 --- a/recipes/DocSummary.py +++ b/recipes/DocSummary.py @@ -2,7 +2,7 @@ from enum import Enum from daras_ai_v2.pydantic_validation import FieldHttpUrl -from pydantic import BaseModel, Field +from pydantic import BaseModel import gooey_gui as gui from bots.models import Workflow @@ -16,7 +16,6 @@ LargeLanguageModels, run_language_model, calc_gpt_tokens, - ResponseFormatType, ) from daras_ai_v2.language_model_settings_widgets import ( language_model_settings, @@ -69,9 +68,7 @@ class RequestModelBase(BasePage.RequestModel): task_instructions: str | None merge_instructions: str | None - selected_model: ( - typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None - ) + selected_model: LargeLanguageModels.api_enum | None chain_type: typing.Literal[tuple(e.name for e in CombineDocumentsChains)] | None diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py index 8dc4288cf..5686fd58a 100644 --- a/recipes/GoogleGPT.py +++ b/recipes/GoogleGPT.py @@ -1,7 +1,7 @@ import typing from furl import furl -from pydantic import BaseModel +from pydantic import BaseModel, Field import gooey_gui as gui from bots.models import Workflow @@ -82,9 +82,7 @@ class RequestModelBase(BasePage.RequestModel): task_instructions: str | None query_instructions: str | None - selected_model: ( - typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None - ) + selected_model: LargeLanguageModels.api_enum | None max_search_urls: int | None diff --git a/recipes/SEOSummary.py b/recipes/SEOSummary.py index a6b3968e7..6e6031206 100644 --- a/recipes/SEOSummary.py +++ b/recipes/SEOSummary.py @@ -100,9 +100,7 @@ class RequestModelBase(BaseModel): enable_html: bool | None - selected_model: ( - typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None - ) + selected_model: LargeLanguageModels.api_enum | None max_search_urls: int | None diff --git a/recipes/SmartGPT.py b/recipes/SmartGPT.py index 75a75e965..7141132c8 100644 --- a/recipes/SmartGPT.py +++ b/recipes/SmartGPT.py @@ -1,7 +1,7 @@ import typing import jinja2.sandbox -from pydantic import BaseModel +from pydantic import BaseModel, Field import gooey_gui as gui from bots.models import Workflow @@ -39,9 +39,7 @@ class RequestModelBase(BasePage.RequestModel): reflexion_prompt: str | None dera_prompt: str | None - selected_model: ( - typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None - ) + selected_model: LargeLanguageModels.api_enum | None class RequestModel(LanguageModelSettings, RequestModelBase): pass diff --git a/recipes/SocialLookupEmail.py b/recipes/SocialLookupEmail.py index 98a79b0c8..f1e3a3bd5 100644 --- a/recipes/SocialLookupEmail.py +++ b/recipes/SocialLookupEmail.py @@ -2,7 +2,7 @@ import typing import requests -from pydantic import BaseModel +from pydantic import BaseModel, Field import gooey_gui as gui from bots.models import Workflow @@ -55,9 +55,7 @@ class RequestModelBase(BasePage.RequestModel): # domain: str | None # key_words: str | None - selected_model: ( - typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None - ) + selected_model: LargeLanguageModels.api_enum | None class RequestModel(LanguageModelSettings, RequestModelBase): pass diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index dfa52045e..05f7c2dd0 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -197,9 +197,7 @@ class RequestModelBase(BasePage.RequestModel): bot_script: str | None # llm model - selected_model: ( - typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None - ) + selected_model: LargeLanguageModels.api_enum | None document_model: str | None = Field( title="🩻 Photo / Document Intelligence", description="When your copilot users upload a photo or pdf, what kind of document are they mostly likely to upload? " From 90fad1269f60ab02814ff85c3db91f79b0385b33 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 12 Sep 2024 20:32:38 +0530 Subject: [PATCH 13/38] use .api_enum in ResponseModel for CompareLLM --- recipes/CompareLLM.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/recipes/CompareLLM.py b/recipes/CompareLLM.py index 0f46f4e48..55d317176 100644 --- a/recipes/CompareLLM.py +++ b/recipes/CompareLLM.py @@ -50,7 +50,10 @@ class RequestModel(LanguageModelSettings, RequestModelBase): pass class ResponseModel(BaseModel): - output_text: dict[LargeLanguageModels.api_choices, list[str]] + class Config: + use_enum_values = True + + output_text: dict[LargeLanguageModels.api_enum, list[str]] def preview_image(self, state: dict) -> str | None: return DEFAULT_COMPARE_LM_META_IMG From 1f5150e6062e6374ebd4e0f9ef5f1799ce70d2de Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 12 Sep 2024 20:53:41 +0530 Subject: [PATCH 14/38] Use .api_enum with LipsyncModel and rename LipsyncModel -> LipsyncModels We use plural models everywhere else --- daras_ai_v2/lipsync_api.py | 4 ++-- daras_ai_v2/lipsync_settings_widgets.py | 6 +++--- recipes/Lipsync.py | 18 ++++++++---------- recipes/LipsyncTTS.py | 8 +++----- recipes/VideoBots.py | 8 +++----- 5 files changed, 19 insertions(+), 25 deletions(-) diff --git a/daras_ai_v2/lipsync_api.py b/daras_ai_v2/lipsync_api.py index 4e05002c6..38416e166 100644 --- a/daras_ai_v2/lipsync_api.py +++ b/daras_ai_v2/lipsync_api.py @@ -1,15 +1,15 @@ import typing -from enum import Enum from loguru import logger from pydantic import BaseModel, Field +from daras_ai_v2.custom_enum import GooeyEnum from daras_ai_v2.exceptions import UserError, GPUError from daras_ai_v2.gpu_server import call_celery_task_outfile from daras_ai_v2.pydantic_validation import FieldHttpUrl -class LipsyncModel(Enum): +class LipsyncModels(GooeyEnum): Wav2Lip = "Rudrabha/Wav2Lip" SadTalker = "OpenTalker/SadTalker" diff --git a/daras_ai_v2/lipsync_settings_widgets.py b/daras_ai_v2/lipsync_settings_widgets.py index 515be000b..b06032686 100644 --- a/daras_ai_v2/lipsync_settings_widgets.py +++ b/daras_ai_v2/lipsync_settings_widgets.py @@ -1,15 +1,15 @@ import gooey_gui as gui from daras_ai_v2.field_render import field_label_val -from daras_ai_v2.lipsync_api import LipsyncModel, SadTalkerSettings +from daras_ai_v2.lipsync_api import LipsyncModels, SadTalkerSettings def lipsync_settings(selected_model: str): match selected_model: - case LipsyncModel.Wav2Lip.name: + case LipsyncModels.Wav2Lip.name: wav2lip_settings() gui.session_state.pop("sadtalker_settings", None) - case LipsyncModel.SadTalker.name: + case LipsyncModels.SadTalker.name: settings = SadTalkerSettings.parse_obj( gui.session_state.setdefault( "sadtalker_settings", SadTalkerSettings().dict() diff --git a/recipes/Lipsync.py b/recipes/Lipsync.py index 3fdf852f7..84ffe5cca 100644 --- a/recipes/Lipsync.py +++ b/recipes/Lipsync.py @@ -8,7 +8,7 @@ from daras_ai_v2.base import BasePage from daras_ai_v2.enum_selector_widget import enum_selector from daras_ai_v2.lipsync_api import run_wav2lip, run_sadtalker, LipsyncSettings -from daras_ai_v2.lipsync_settings_widgets import lipsync_settings, LipsyncModel +from daras_ai_v2.lipsync_settings_widgets import lipsync_settings, LipsyncModels from daras_ai_v2.loom_video_widget import youtube_video from daras_ai_v2.pydantic_validation import FieldHttpUrl @@ -25,9 +25,7 @@ class LipsyncPage(BasePage): sdk_method_name = "lipsync" class RequestModel(LipsyncSettings, BasePage.RequestModel): - selected_model: typing.Literal[tuple(e.name for e in LipsyncModel)] = ( - LipsyncModel.Wav2Lip.name - ) + selected_model: LipsyncModels.api_enum = LipsyncModels.Wav2Lip.name input_audio: FieldHttpUrl = None class ResponseModel(BaseModel): @@ -56,7 +54,7 @@ def render_form_v2(self): ) enum_selector( - LipsyncModel, + LipsyncModels, label="###### Lipsync Model", key="selected_model", use_selectbox=True, @@ -72,10 +70,10 @@ def render_settings(self): def run(self, state: dict) -> typing.Iterator[str | None]: request = self.RequestModel.parse_obj(state) - model = LipsyncModel[request.selected_model] + model = LipsyncModels[request.selected_model] yield f"Running {model.value}..." match model: - case LipsyncModel.Wav2Lip: + case LipsyncModels.Wav2Lip: state["output_video"] = run_wav2lip( face=request.input_face, audio=request.input_audio, @@ -86,7 +84,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: request.face_padding_right or 0, ), ) - case LipsyncModel.SadTalker: + case LipsyncModels.SadTalker: state["output_video"] = run_sadtalker( request.sadtalker_settings, face=request.input_face, @@ -121,7 +119,7 @@ def preview_description(self, state: dict) -> str: def get_cost_note(self) -> str | None: multiplier = ( 3 - if gui.session_state.get("lipsync_model") == LipsyncModel.SadTalker.name + if gui.session_state.get("lipsync_model") == LipsyncModels.SadTalker.name else 1 ) return f"{CREDITS_PER_MB * multiplier} credits per MB" @@ -141,6 +139,6 @@ def get_raw_price(self, state: dict) -> float: total_mb = total_bytes / 1024 / 1024 multiplier = ( - 3 if state.get("lipsync_model") == LipsyncModel.SadTalker.name else 1 + 3 if state.get("lipsync_model") == LipsyncModels.SadTalker.name else 1 ) return total_mb * CREDITS_PER_MB * multiplier diff --git a/recipes/LipsyncTTS.py b/recipes/LipsyncTTS.py index 1200da5c3..49226995e 100644 --- a/recipes/LipsyncTTS.py +++ b/recipes/LipsyncTTS.py @@ -6,7 +6,7 @@ import gooey_gui as gui from bots.models import Workflow from daras_ai_v2.enum_selector_widget import enum_selector -from daras_ai_v2.lipsync_api import LipsyncSettings, LipsyncModel +from daras_ai_v2.lipsync_api import LipsyncSettings, LipsyncModels from daras_ai_v2.loom_video_widget import youtube_video from daras_ai_v2.safety_checker import safety_checker from daras_ai_v2.text_to_speech_settings_widgets import ( @@ -32,9 +32,7 @@ class LipsyncTTSPage(LipsyncPage, TextToSpeechPage): } class RequestModel(LipsyncSettings, TextToSpeechPage.RequestModel): - selected_model: typing.Literal[tuple(e.name for e in LipsyncModel)] = ( - LipsyncModel.Wav2Lip.name - ) + selected_model: LipsyncModels.api_enum = LipsyncModels.Wav2Lip.name class ResponseModel(BaseModel): audio_url: str | None @@ -70,7 +68,7 @@ def render_form_v2(self): ) enum_selector( - LipsyncModel, + LipsyncModels, label="###### Lipsync Model", key="selected_model", use_selectbox=True, diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 05f7c2dd0..8c268e634 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -79,7 +79,7 @@ language_model_selector, LanguageModelSettings, ) -from daras_ai_v2.lipsync_api import LipsyncSettings, LipsyncModel +from daras_ai_v2.lipsync_api import LipsyncSettings, LipsyncModels from daras_ai_v2.lipsync_settings_widgets import lipsync_settings from daras_ai_v2.loom_video_widget import youtube_video from daras_ai_v2.prompt_vars import render_prompt_vars @@ -251,9 +251,7 @@ class RequestModelBase(BasePage.RequestModel): """, ) - lipsync_model: typing.Literal[tuple(e.name for e in LipsyncModel)] = ( - LipsyncModel.Wav2Lip.name - ) + lipsync_model: LipsyncModels.api_enum = LipsyncModels.Wav2Lip.name tools: list[LLMTools] | None = Field( title="🛠️ Tools", @@ -381,7 +379,7 @@ def render_form_v2(self): key="input_face", ) enum_selector( - LipsyncModel, + LipsyncModels, label="###### Lipsync Model", key="lipsync_model", use_selectbox=True, From 146c771832f6b9619a498dab8dc3aab745776415 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 12 Sep 2024 20:56:52 +0530 Subject: [PATCH 15/38] use GooeyEnum for AnimationModels --- recipes/DeforumSD.py | 31 +++++++++++++++++++++-------- scripts/init_self_hosted_pricing.py | 2 +- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/recipes/DeforumSD.py b/recipes/DeforumSD.py index 4cd930cfe..db47dabe1 100644 --- a/recipes/DeforumSD.py +++ b/recipes/DeforumSD.py @@ -1,8 +1,8 @@ import typing import uuid +from daras_ai_v2.custom_enum import GooeyEnum from daras_ai_v2.pydantic_validation import FieldHttpUrl -from django.db.models import TextChoices from pydantic import BaseModel from typing_extensions import TypedDict @@ -18,17 +18,28 @@ DEFAULT_DEFORUMSD_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/7dc25196-93fe-11ee-9e3a-02420a0001ce/AI%20Animation%20generator.jpg.png" -class AnimationModels(TextChoices): - protogen_2_2 = ("Protogen_V2.2.ckpt", "Protogen V2.2 (darkstorm2150)") - epicdream = ("epicdream.safetensors", "epiCDream (epinikion)") +class AnimationModel(typing.NamedTuple): + model_id: str + label: str -class _AnimationPrompt(TypedDict): +class AnimationModels(AnimationModel, GooeyEnum): + protogen_2_2 = AnimationModel( + model_id="Protogen_V2.2.ckpt", + label="Protogen V2.2 (darkstorm2150)", + ) + epicdream = AnimationModel( + model_id="epicdream.safetensors", + label="epiCDream (epinikion)", + ) + + +class AnimationPrompt(TypedDict): frame: str prompt: str -AnimationPrompts = list[_AnimationPrompt] +AnimationPrompts = list[AnimationPrompt] CREDITS_PER_FRAME = 1.5 MODEL_ESTIMATED_TIME_PER_FRAME = 2.4 # seconds @@ -186,7 +197,7 @@ class RequestModel(BasePage.RequestModel): animation_prompts: AnimationPrompts max_frames: int | None - selected_model: typing.Literal[tuple(e.name for e in AnimationModels)] | None + selected_model: AnimationModels.api_enum | None animation_mode: str | None zoom: str | None @@ -456,11 +467,15 @@ def run(self, state: dict): if not self.request.user.disable_safety_checker: safety_checker(text=self.preview_input(state)) + print("selected_model", request.selected_model) + print(f'{state["selected_model"]=}') + print(f"{type(request.selected_model)=}") + try: state["output_video"] = call_celery_task_outfile( "deforum", pipeline=dict( - model_id=AnimationModels[request.selected_model].value, + model_id=AnimationModels[request.selected_model].model_id, seed=request.seed, ), inputs=dict( diff --git a/scripts/init_self_hosted_pricing.py b/scripts/init_self_hosted_pricing.py index d1df62a58..09dd7b8e3 100644 --- a/scripts/init_self_hosted_pricing.py +++ b/scripts/init_self_hosted_pricing.py @@ -18,7 +18,7 @@ def run(): for model in AnimationModels: - add_model(model.value, model.name) + add_model(model.model_id, model.name) for model_enum, model_ids in [ (Text2ImgModels, text2img_model_ids), (Img2ImgModels, img2img_model_ids), From e8a342b673fd1b8330d2b8645191fada8abfb159 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 12 Sep 2024 21:12:00 +0530 Subject: [PATCH 16/38] fix: AsrModels to use GooeyEnum --- daras_ai_v2/asr.py | 3 ++- recipes/DocExtract.py | 2 +- recipes/DocSummary.py | 2 +- recipes/VideoBots.py | 2 +- recipes/asr_page.py | 2 +- 5 files changed, 6 insertions(+), 5 deletions(-) diff --git a/daras_ai_v2/asr.py b/daras_ai_v2/asr.py index 699e9fe2a..ef715aa99 100644 --- a/daras_ai_v2/asr.py +++ b/daras_ai_v2/asr.py @@ -14,6 +14,7 @@ from daras_ai.image_input import upload_file_from_bytes, gs_url_to_uri from daras_ai_v2 import settings from daras_ai_v2.azure_asr import azure_asr +from daras_ai_v2.custom_enum import GooeyEnum from daras_ai_v2.exceptions import ( raise_for_status, UserError, @@ -204,7 +205,7 @@ GHANA_NLP_MAXLEN = 500 -class AsrModels(Enum): +class AsrModels(GooeyEnum): whisper_large_v2 = "Whisper Large v2 (openai)" whisper_large_v3 = "Whisper Large v3 (openai)" whisper_hindi_large_v2 = "Whisper Hindi Large v2 (Bhashini)" diff --git a/recipes/DocExtract.py b/recipes/DocExtract.py index a0d210fe0..9a41f4f19 100644 --- a/recipes/DocExtract.py +++ b/recipes/DocExtract.py @@ -94,7 +94,7 @@ class RequestModelBase(BasePage.RequestModel): sheet_url: FieldHttpUrl | None - selected_asr_model: typing.Literal[tuple(e.name for e in AsrModels)] | None + selected_asr_model: AsrModels.api_enum | None # language: str | None google_translate_target: str | None glossary_document: FieldHttpUrl | None = Field( diff --git a/recipes/DocSummary.py b/recipes/DocSummary.py index d34368914..8e348d96d 100644 --- a/recipes/DocSummary.py +++ b/recipes/DocSummary.py @@ -72,7 +72,7 @@ class RequestModelBase(BasePage.RequestModel): chain_type: typing.Literal[tuple(e.name for e in CombineDocumentsChains)] | None - selected_asr_model: typing.Literal[tuple(e.name for e in AsrModels)] | None + selected_asr_model: AsrModels.api_enum | None google_translate_target: str | None class RequestModel(LanguageModelSettings, RequestModelBase): diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 8c268e634..ff9621b66 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -221,7 +221,7 @@ class RequestModelBase(BasePage.RequestModel): citation_style: typing.Literal[tuple(e.name for e in CitationStyles)] | None use_url_shortener: bool | None - asr_model: typing.Literal[tuple(e.name for e in AsrModels)] | None = Field( + asr_model: AsrModels.api_enum | None = Field( title="Speech-to-Text Provider", description="Choose a model to transcribe incoming audio messages to text.", ) diff --git a/recipes/asr_page.py b/recipes/asr_page.py index 6590e8a72..435cd206d 100644 --- a/recipes/asr_page.py +++ b/recipes/asr_page.py @@ -44,7 +44,7 @@ class AsrPage(BasePage): class RequestModelBase(BasePage.RequestModel): documents: list[FieldHttpUrl] - selected_model: typing.Literal[tuple(e.name for e in AsrModels)] | None + selected_model: AsrModels.api_enum | None language: str | None translation_model: ( From cbb3a091222486c604b1078660ec0d636df63110 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 13 Sep 2024 14:47:03 +0530 Subject: [PATCH 17/38] Use GooeyEnum for EmbeddingModels --- daras_ai_v2/embedding_model.py | 12 ++++-------- daras_ai_v2/vector_search.py | 5 ++++- embeddings/models.py | 2 +- recipes/GoogleGPT.py | 2 +- recipes/VideoBots.py | 2 +- recipes/embeddings_page.py | 2 +- 6 files changed, 12 insertions(+), 13 deletions(-) diff --git a/daras_ai_v2/embedding_model.py b/daras_ai_v2/embedding_model.py index 9f5f3ae1b..5f2927bcc 100644 --- a/daras_ai_v2/embedding_model.py +++ b/daras_ai_v2/embedding_model.py @@ -1,7 +1,6 @@ import hashlib import io import typing -from enum import Enum from functools import partial import numpy as np @@ -13,6 +12,7 @@ from jinja2.lexer import whitespace_re from loguru import logger +from daras_ai_v2.custom_enum import GooeyEnum from daras_ai_v2.gpu_server import call_celery_task from daras_ai_v2.language_model import get_openai_client from daras_ai_v2.redis_cache import ( @@ -25,7 +25,7 @@ class EmbeddingModel(typing.NamedTuple): label: str -class EmbeddingModels(Enum): +class EmbeddingModels(EmbeddingModel, GooeyEnum): openai_3_large = EmbeddingModel( model_id=("openai-text-embedding-3-large-prod-ca-1", "text-embedding-3-large"), label="Text Embedding 3 Large (OpenAI)", @@ -65,12 +65,8 @@ class EmbeddingModels(Enum): ) @property - def model_id(self) -> typing.Iterable[str] | str: - return self.value.model_id - - @property - def label(self) -> str: - return self.value.label + def db_value(self): + return self.name @classmethod def get(cls, key, default=None): diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py index 36de65f29..fe694d56b 100644 --- a/daras_ai_v2/vector_search.py +++ b/daras_ai_v2/vector_search.py @@ -71,6 +71,9 @@ class DocSearchRequest(BaseModel): + class Config: + use_enum_values = True + search_query: str keyword_query: str | list[str] | None @@ -82,7 +85,7 @@ class DocSearchRequest(BaseModel): doc_extract_url: str | None - embedding_model: typing.Literal[tuple(e.name for e in EmbeddingModels)] | None + embedding_model: EmbeddingModels.api_enum | None dense_weight: float | None = Field( ge=0.0, le=1.0, diff --git a/embeddings/models.py b/embeddings/models.py index f30b1a0bf..f5f350bbc 100644 --- a/embeddings/models.py +++ b/embeddings/models.py @@ -28,7 +28,7 @@ class EmbeddedFile(models.Model): selected_asr_model = models.CharField(max_length=100, blank=True) embedding_model = models.CharField( max_length=100, - choices=[(model.name, model.label) for model in EmbeddingModels], + choices=EmbeddingModels.db_choices(), default=EmbeddingModels.openai_3_large.name, ) diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py index 5686fd58a..ca43253af 100644 --- a/recipes/GoogleGPT.py +++ b/recipes/GoogleGPT.py @@ -90,7 +90,7 @@ class RequestModelBase(BasePage.RequestModel): max_context_words: int | None scroll_jump: int | None - embedding_model: typing.Literal[tuple(e.name for e in EmbeddingModels)] | None + embedding_model: EmbeddingModels.api_enum | None dense_weight: float | None = DocSearchRequest.__fields__[ "dense_weight" ].field_info diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index ff9621b66..9ff802a20 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -213,7 +213,7 @@ class RequestModelBase(BasePage.RequestModel): max_context_words: int | None scroll_jump: int | None - embedding_model: typing.Literal[tuple(e.name for e in EmbeddingModels)] | None + embedding_model: EmbeddingModels.api_enum | None dense_weight: float | None = DocSearchRequest.__fields__[ "dense_weight" ].field_info diff --git a/recipes/embeddings_page.py b/recipes/embeddings_page.py index 842677e2a..d3cff0244 100644 --- a/recipes/embeddings_page.py +++ b/recipes/embeddings_page.py @@ -22,7 +22,7 @@ class EmbeddingsPage(BasePage): class RequestModel(BasePage.RequestModel): texts: list[str] - selected_model: typing.Literal[tuple(e.name for e in EmbeddingModels)] | None + selected_model: EmbeddingModels.api_enum | None class ResponseModel(BaseModel): embeddings: list[list[float]] From 76fe9aa8fe912267ac12fa24c2f785e268983467 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 13 Sep 2024 14:48:26 +0530 Subject: [PATCH 18/38] Refactor GooeyEnum to separate .name and .api_value We have a different name vs api_value for SerpSearchLocation --- daras_ai_v2/custom_enum.py | 51 ++++++++++++++++++++++++++------------ 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/daras_ai_v2/custom_enum.py b/daras_ai_v2/custom_enum.py index 56255d26d..1f781afb9 100644 --- a/daras_ai_v2/custom_enum.py +++ b/daras_ai_v2/custom_enum.py @@ -1,12 +1,38 @@ +import functools import typing from enum import Enum import typing_extensions + +def cached_classmethod(func: typing.Callable): + """ + This cache is a hack to get around a bug where + dynamic Enums with the same name cause a crash + when generating the OpenAPI spec. + """ + + @functools.wraps(func) + def wrapper(cls): + if not hasattr(cls, "_cached_classmethod"): + cls._cached_classmethod = {} + if id(func) not in cls._cached_classmethod: + cls._cached_classmethod[id(func)] = func(cls) + + return cls._cached_classmethod[id(func)] + + return wrapper + + T = typing.TypeVar("T", bound="GooeyEnum") class GooeyEnum(Enum): + @property + def api_value(self): + # api_value is usually the name + return self.name + @classmethod def db_choices(cls): return [(e.db_value, e.label) for e in cls] @@ -21,34 +47,27 @@ def from_db(cls, db_value) -> typing_extensions.Self: @classmethod @property def api_choices(cls): - return typing.Literal[tuple(e.name for e in cls)] + return typing.Literal[tuple(e.api_value for e in cls)] @classmethod @property + @cached_classmethod def api_enum(cls): """ Enum that is useful as a type in API requests. - Maps `name`->`name` for all values. + Maps `api_value`->`api_value` (default: `name`->`name`) + for all values. + The title (same as the Enum class name) will be used as the new Enum's title. This will be passed on to the OpenAPI schema and the generated SDK. """ - # this cache is a hack to get around a bug where - # dynamic Enums with the same name crash when - # generating the OpenAPI spec - if not hasattr(cls, "_cached_api_enum"): - cls._cached_api_enum = {} - if cls.__name__ not in cls._cached_api_enum: - cls._cached_api_enum[cls.__name__] = Enum( - cls.__name__, {e.name: e.name for e in cls} - ) - - return cls._cached_api_enum[cls.__name__] + return Enum(cls.__name__, {e.api_value: e.api_value for e in cls}) @classmethod - def from_api(cls, name: str) -> typing_extensions.Self: + def from_api(cls, api_value: str) -> typing_extensions.Self: for e in cls: - if e.name == name: + if e.api_value == api_value: return e - raise ValueError(f"Invalid {cls.__name__} {name=}") + raise ValueError(f"Invalid {cls.__name__} {api_value=}") From aacf0b8baf12b01a90c161f99c93fca456dae49c Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 13 Sep 2024 14:51:55 +0530 Subject: [PATCH 19/38] Use GooeyEnum for SerpSearchType & SerpSearchLocation --- daras_ai_v2/serp_search.py | 6 ++-- daras_ai_v2/serp_search_locations.py | 48 ++++++++++++++++++---------- recipes/GoogleGPT.py | 16 +++++----- recipes/GoogleImageGen.py | 12 +++---- recipes/RelatedQnA.py | 11 +++---- recipes/RelatedQnADoc.py | 11 +++---- recipes/SEOSummary.py | 12 +++---- 7 files changed, 63 insertions(+), 53 deletions(-) diff --git a/daras_ai_v2/serp_search.py b/daras_ai_v2/serp_search.py index b95586461..7dea257de 100644 --- a/daras_ai_v2/serp_search.py +++ b/daras_ai_v2/serp_search.py @@ -20,7 +20,7 @@ def get_related_questions_from_serp_api( ) -> tuple[dict, list[str]]: data = call_serp_api( search_query, - search_type=SerpSearchType.SEARCH, + search_type=SerpSearchType.search, search_location=search_location, ) items = data.get("peopleAlsoAsk", []) or data.get("relatedSearches", []) @@ -66,10 +66,10 @@ def call_serp_api( search_location: SerpSearchLocation, ) -> dict: r = requests.post( - "https://google.serper.dev/" + search_type.value, + "https://google.serper.dev/" + search_type.api_value, json=dict( q=query, - gl=search_location.value, + gl=search_location.api_value, ), headers={"X-API-KEY": settings.SERPER_API_KEY}, ) diff --git a/daras_ai_v2/serp_search_locations.py b/daras_ai_v2/serp_search_locations.py index 2c26dc8a9..b2049e1ef 100644 --- a/daras_ai_v2/serp_search_locations.py +++ b/daras_ai_v2/serp_search_locations.py @@ -1,8 +1,11 @@ -from django.db.models import TextChoices +import typing + from pydantic import BaseModel from pydantic import Field import gooey_gui as gui +from daras_ai_v2.custom_enum import GooeyEnum +from daras_ai_v2.enum_selector_widget import enum_selector from daras_ai_v2.field_render import field_title_desc @@ -26,10 +29,10 @@ def serp_search_settings(): def serp_search_type_selectbox(key="serp_search_type"): - gui.selectbox( + enum_selector( + SerpSearchType, f"###### {field_title_desc(GoogleSearchMixin, key)}", - options=SerpSearchType, - format_func=lambda x: x.label, + use_selectbox=True, key=key, ) @@ -37,22 +40,35 @@ def serp_search_type_selectbox(key="serp_search_type"): def serp_search_location_selectbox(key="serp_search_location"): gui.selectbox( f"###### {field_title_desc(GoogleSearchMixin, key)}", - options=SerpSearchLocation, - format_func=lambda x: f"{x.label} ({x.value})", + options=[e.api_value for e in SerpSearchLocations], + format_func=lambda e: f"{SerpSearchLocations.from_api(e).label} ({e})", key=key, - value=SerpSearchLocation.UNITED_STATES, + value=SerpSearchLocations.UNITED_STATES.name, ) -class SerpSearchType(TextChoices): - SEARCH = "search", "🔎 Search" - IMAGES = "images", "📷 Images" - VIDEOS = "videos", "🎥 Videos" - PLACES = "places", "📍 Places" - NEWS = "news", "📰 News" +class SerpSearchType(GooeyEnum): + search = "🔎 Search" + images = "📷 Images" + videos = "🎥 Videos" + places = "📍 Places" + news = "📰 News" + + @property + def label(self): + return self.value + + @property + def api_value(self): + return self.name + + +class SerpSearchLocation(typing.NamedTuple): + api_value: str + label: str -class SerpSearchLocation(TextChoices): +class SerpSearchLocations(SerpSearchLocation, GooeyEnum): AFGHANISTAN = "af", "Afghanistan" ALBANIA = "al", "Albania" ALGERIA = "dz", "Algeria" @@ -304,7 +320,7 @@ class SerpSearchLocation(TextChoices): class GoogleSearchLocationMixin(BaseModel): - serp_search_location: SerpSearchLocation | None = Field( + serp_search_location: SerpSearchLocations.api_enum | None = Field( title="Web Search Location", ) scaleserp_locations: list[str] | None = Field( @@ -313,7 +329,7 @@ class GoogleSearchLocationMixin(BaseModel): class GoogleSearchMixin(GoogleSearchLocationMixin, BaseModel): - serp_search_type: SerpSearchType | None = Field( + serp_search_type: SerpSearchType.api_enum | None = Field( title="Web Search Type", ) scaleserp_search_field: str | None = Field( diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py index ca43253af..28e41b34a 100644 --- a/recipes/GoogleGPT.py +++ b/recipes/GoogleGPT.py @@ -1,7 +1,7 @@ import typing from furl import furl -from pydantic import BaseModel, Field +from pydantic import BaseModel import gooey_gui as gui from bots.models import Workflow @@ -27,9 +27,9 @@ from daras_ai_v2.serp_search import get_links_from_serp_api from daras_ai_v2.serp_search_locations import ( GoogleSearchMixin, - serp_search_settings, - SerpSearchLocation, + SerpSearchLocations, SerpSearchType, + serp_search_settings, ) from daras_ai_v2.vector_search import render_sources_widget from recipes.DocSearch import ( @@ -56,8 +56,8 @@ class GoogleGPTPage(BasePage): keywords="outdoor rugs,8x10 rugs,rug sizes,checkered rugs,5x7 rugs", title="Ruggable", company_url="https://ruggable.com", - serp_search_type=SerpSearchType.SEARCH, - serp_search_location=SerpSearchLocation.UNITED_STATES, + serp_search_type=SerpSearchType.search, + serp_search_location=SerpSearchLocations.UNITED_STATES.name, enable_html=False, selected_model=LargeLanguageModels.text_davinci_003.name, sampling_temperature=0.8, @@ -211,7 +211,7 @@ def run_v2( self, request: "GoogleGPTPage.RequestModel", response: "GoogleGPTPage.ResponseModel", - ): + ) -> typing.Iterator[str | None]: model = LargeLanguageModels[request.selected_model] query_instructions = (request.query_instructions or "").strip() @@ -231,8 +231,8 @@ def run_v2( ) response.serp_results, links = get_links_from_serp_api( response.final_search_query, - search_type=request.serp_search_type, - search_location=request.serp_search_location, + search_type=SerpSearchType.from_api(request.serp_search_type), + search_location=SerpSearchLocations.from_api(request.serp_search_location), ) # extract links & their corresponding titles link_titles = {item.url: f"{item.title} | {item.snippet}" for item in links} diff --git a/recipes/GoogleImageGen.py b/recipes/GoogleImageGen.py index b719d2b58..7b5636dd8 100644 --- a/recipes/GoogleImageGen.py +++ b/recipes/GoogleImageGen.py @@ -17,10 +17,10 @@ from daras_ai_v2.loom_video_widget import youtube_video from daras_ai_v2.serp_search import call_serp_api from daras_ai_v2.serp_search_locations import ( - serp_search_location_selectbox, GoogleSearchLocationMixin, + SerpSearchLocations, SerpSearchType, - SerpSearchLocation, + serp_search_location_selectbox, ) from daras_ai_v2.stable_diffusion import ( img2img, @@ -47,8 +47,8 @@ class GoogleImageGenPage(BasePage): sd_2_upscaling=False, seed=42, image_guidance_scale=1.2, - serp_search_type=SerpSearchType.SEARCH, - serp_search_location=SerpSearchLocation.UNITED_STATES, + serp_search_type=SerpSearchType.search, + serp_search_location=SerpSearchLocations.UNITED_STATES.name, ) class RequestModel(GoogleSearchLocationMixin, BasePage.RequestModel): @@ -113,8 +113,8 @@ def run(self, state: dict): serp_results = call_serp_api( request.search_query, - search_type=SerpSearchType.IMAGES, - search_location=request.serp_search_location, + search_type=SerpSearchType.images, + search_location=SerpSearchLocations.from_api(request.serp_search_location), ) image_urls = [ link diff --git a/recipes/RelatedQnA.py b/recipes/RelatedQnA.py index 0da240910..ee1a064e4 100644 --- a/recipes/RelatedQnA.py +++ b/recipes/RelatedQnA.py @@ -8,10 +8,7 @@ LargeLanguageModels, ) from daras_ai_v2.serp_search import get_related_questions_from_serp_api -from daras_ai_v2.serp_search_locations import ( - SerpSearchLocation, - SerpSearchType, -) +from daras_ai_v2.serp_search_locations import SerpSearchLocations, SerpSearchType from recipes.DocSearch import render_doc_search_step, EmptySearchResults from recipes.GoogleGPT import GoogleGPTPage from recipes.RelatedQnADoc import render_qna_outputs @@ -37,8 +34,8 @@ class RelatedQnAPage(BasePage): max_context_words=200, scroll_jump=5, dense_weight=1.0, - serp_search_type=SerpSearchType.SEARCH, - serp_search_location=SerpSearchLocation.UNITED_STATES, + serp_search_type=SerpSearchType.search, + serp_search_location=SerpSearchLocations.UNITED_STATES.name, ) class RequestModel(GoogleGPTPage.RequestModel): @@ -118,7 +115,7 @@ def run_v2( related_questions, ) = get_related_questions_from_serp_api( request.search_query, - search_location=request.serp_search_location, + search_location=SerpSearchLocations.from_api(request.serp_search_location), ) all_questions = [request.search_query] + related_questions[:9] diff --git a/recipes/RelatedQnADoc.py b/recipes/RelatedQnADoc.py index 9dbdf3e5a..3e4fab620 100644 --- a/recipes/RelatedQnADoc.py +++ b/recipes/RelatedQnADoc.py @@ -7,10 +7,7 @@ from daras_ai_v2.language_model import LargeLanguageModels from daras_ai_v2.search_ref import CitationStyles from daras_ai_v2.serp_search import get_related_questions_from_serp_api -from daras_ai_v2.serp_search_locations import ( - SerpSearchLocation, - SerpSearchType, -) +from daras_ai_v2.serp_search_locations import SerpSearchLocations, SerpSearchType from daras_ai_v2.vector_search import render_sources_widget from recipes.DocSearch import DocSearchPage, render_doc_search_step, EmptySearchResults from recipes.GoogleGPT import render_output_with_refs, GoogleSearchMixin @@ -34,8 +31,8 @@ class RelatedQnADocPage(BasePage): sane_defaults = dict( citation_style=CitationStyles.number.name, dense_weight=1.0, - serp_search_type=SerpSearchType.SEARCH, - serp_search_location=SerpSearchLocation.UNITED_STATES, + serp_search_type=SerpSearchType.search, + serp_search_location=SerpSearchLocations.UNITED_STATES.name, ) class RequestModel(GoogleSearchMixin, DocSearchPage.RequestModel): @@ -112,7 +109,7 @@ def run_v2( related_questions, ) = get_related_questions_from_serp_api( request.search_query, - search_location=request.serp_search_location, + search_location=SerpSearchLocations.from_api(request.serp_search_location), ) all_questions = [request.search_query] + related_questions[:9] diff --git a/recipes/SEOSummary.py b/recipes/SEOSummary.py index 6e6031206..ad40e3d94 100644 --- a/recipes/SEOSummary.py +++ b/recipes/SEOSummary.py @@ -28,9 +28,9 @@ from daras_ai_v2.scrollable_html_widget import scrollable_html from daras_ai_v2.serp_search import get_links_from_serp_api from daras_ai_v2.serp_search_locations import ( - serp_search_settings, - SerpSearchLocation, + SerpSearchLocations, SerpSearchType, + serp_search_settings, ) from daras_ai_v2.settings import EXTERNAL_REQUEST_TIMEOUT_SEC from recipes.GoogleGPT import GoogleSearchMixin @@ -74,8 +74,8 @@ def preview_description(self, state: dict) -> str: keywords="outdoor rugs,8x10 rugs,rug sizes,checkered rugs,5x7 rugs", title="Ruggable", company_url="https://ruggable.com", - serp_search_type=SerpSearchType.SEARCH, - serp_search_location=SerpSearchLocation.UNITED_STATES, + serp_search_type=SerpSearchType.search, + serp_search_location=SerpSearchLocations.UNITED_STATES.name, enable_html=False, selected_model=LargeLanguageModels.text_davinci_003.name, sampling_temperature=0.8, @@ -274,7 +274,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: serp_results, links = get_links_from_serp_api( request.search_query, search_type=request.serp_search_type, - search_location=request.serp_search_location, + search_location=SerpSearchLocations.from_api(request.serp_search_location), ) state["serp_results"] = serp_results state["search_urls"] = [it.url for it in links] @@ -314,7 +314,7 @@ def _crosslink_keywords(output_content, request): lambda keyword: get_links_from_serp_api( f"site:{host} {keyword}", search_type=request.serp_search_type, - search_location=request.serp_search_location, + search_location=SerpSearchLocations.from_api(request.serp_search_location), )[1], relevant_keywords, ) From 3f8d305faa76c427f9ccbdc1dee09e00dbf7c6f8 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 13 Sep 2024 15:14:44 +0530 Subject: [PATCH 20/38] Add title for QRCode VCard field to VCard --- recipes/QRCodeGenerator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index 12696e167..8a8d4921f 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -8,7 +8,7 @@ from django.core.exceptions import ValidationError from django.core.validators import URLValidator from furl import furl -from pydantic import BaseModel +from pydantic import BaseModel, Field from pyzbar import pyzbar import gooey_gui as gui @@ -82,7 +82,7 @@ def __init__(self, *args, **kwargs): class RequestModel(BasePage.RequestModel): qr_code_data: str | None qr_code_input_image: FieldHttpUrl | None - qr_code_vcard: VCARD | None + qr_code_vcard: VCARD | None = Field(title="VCard") qr_code_file: FieldHttpUrl | None use_url_shortener: bool | None From 83cacc677e85965c9a205d59d537abf53bab1933 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 13 Sep 2024 15:29:16 +0530 Subject: [PATCH 21/38] fix serp types for SEOSummary recipe --- recipes/SEOSummary.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/recipes/SEOSummary.py b/recipes/SEOSummary.py index ad40e3d94..c53d03333 100644 --- a/recipes/SEOSummary.py +++ b/recipes/SEOSummary.py @@ -91,6 +91,9 @@ def preview_description(self, state: dict) -> str: ) class RequestModelBase(BaseModel): + class Config: + use_enum_values = True + search_query: str keywords: str title: str @@ -273,7 +276,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: serp_results, links = get_links_from_serp_api( request.search_query, - search_type=request.serp_search_type, + search_type=SerpSearchType.from_api(request.serp_search_type), search_location=SerpSearchLocations.from_api(request.serp_search_location), ) state["serp_results"] = serp_results @@ -313,7 +316,7 @@ def _crosslink_keywords(output_content, request): all_results = map_parallel( lambda keyword: get_links_from_serp_api( f"site:{host} {keyword}", - search_type=request.serp_search_type, + search_type=SerpSearchType.from_api(request.serp_search_type), search_location=SerpSearchLocations.from_api(request.serp_search_location), )[1], relevant_keywords, From bee12c4d4c9c3339da162c4db0433ca42399bcdc Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 13 Sep 2024 15:36:21 +0530 Subject: [PATCH 22/38] Use Enum for ResponseFormatType --- daras_ai_v2/language_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index 8bbf650fc..ce7d85d4c 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -459,7 +459,9 @@ def get_entry_text(entry: ConversationEntry) -> str: ) -ResponseFormatType = typing.Literal["text", "json_object"] +class ResponseFormatType(str, GooeyEnum): + text = "text" + json_object = "json_object" def run_language_model( From c8f296d62933a5f9cac3b5ed93dce9aa7dd693ef Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 13 Sep 2024 16:27:41 +0530 Subject: [PATCH 23/38] Use GooeyEnum for ControlNetModels --- daras_ai_v2/img_model_settings_widgets.py | 13 ++- daras_ai_v2/stable_diffusion.py | 105 +++++++++++++--------- recipes/Img2Img.py | 4 +- recipes/QRCodeGenerator.py | 8 +- 4 files changed, 72 insertions(+), 58 deletions(-) diff --git a/daras_ai_v2/img_model_settings_widgets.py b/daras_ai_v2/img_model_settings_widgets.py index 7f62fb878..733465cd1 100644 --- a/daras_ai_v2/img_model_settings_widgets.py +++ b/daras_ai_v2/img_model_settings_widgets.py @@ -6,7 +6,6 @@ InpaintingModels, Img2ImgModels, ControlNetModels, - controlnet_model_explanations, Schedulers, ) @@ -130,9 +129,7 @@ def controlnet_settings( if not models: return - if extra_explanations is None: - extra_explanations = {} - explanations = controlnet_model_explanations | extra_explanations + extra_explanations = extra_explanations or {} state_values = gui.session_state.get("controlnet_conditioning_scale", []) new_values = [] @@ -157,7 +154,9 @@ def controlnet_settings( pass new_values.append( controlnet_weight_setting( - selected_controlnet_model=model, explanations=explanations, key=key + selected_controlnet_model=model, + extra_explanations=extra_explanations, + key=key, ), ) gui.session_state["controlnet_conditioning_scale"] = new_values @@ -166,13 +165,13 @@ def controlnet_settings( def controlnet_weight_setting( *, selected_controlnet_model: str, - explanations: dict[ControlNetModels, str], + extra_explanations: dict[ControlNetModels, str], key: str = "controlnet_conditioning_scale", ): model = ControlNetModels[selected_controlnet_model] return gui.slider( label=f""" - {explanations[model]}. + {extra_explanations.get(model, model.explanation)}. """, key=key, min_value=CONTROLNET_CONDITIONING_SCALE_RANGE[0], diff --git a/daras_ai_v2/stable_diffusion.py b/daras_ai_v2/stable_diffusion.py index f59b044c3..99d3e6608 100644 --- a/daras_ai_v2/stable_diffusion.py +++ b/daras_ai_v2/stable_diffusion.py @@ -13,6 +13,7 @@ resize_img_fit, get_downscale_factor, ) +from daras_ai_v2.custom_enum import GooeyEnum from daras_ai_v2.exceptions import ( raise_for_status, UserError, @@ -119,47 +120,68 @@ def _deprecated(cls): } -class ControlNetModels(Enum): - sd_controlnet_canny = "Canny" - sd_controlnet_depth = "Depth" - sd_controlnet_hed = "HED Boundary" - sd_controlnet_mlsd = "M-LSD Straight Line" - sd_controlnet_normal = "Normal Map" - sd_controlnet_openpose = "Human Pose" - sd_controlnet_scribble = "Scribble" - sd_controlnet_seg = "Image Segmentation" - sd_controlnet_tile = "Tiling" - sd_controlnet_brightness = "Brightness" - control_v1p_sd15_qrcode_monster_v2 = "QR Monster V2" - - -controlnet_model_explanations = { - ControlNetModels.sd_controlnet_canny: "Canny edge detection", - ControlNetModels.sd_controlnet_depth: "Depth estimation", - ControlNetModels.sd_controlnet_hed: "HED edge detection", - ControlNetModels.sd_controlnet_mlsd: "M-LSD straight line detection", - ControlNetModels.sd_controlnet_normal: "Normal map estimation", - ControlNetModels.sd_controlnet_openpose: "Human pose estimation", - ControlNetModels.sd_controlnet_scribble: "Scribble", - ControlNetModels.sd_controlnet_seg: "Image segmentation", - ControlNetModels.sd_controlnet_tile: "Tiling: to preserve small details", - ControlNetModels.sd_controlnet_brightness: "Brightness: to increase contrast naturally", - ControlNetModels.control_v1p_sd15_qrcode_monster_v2: "QR Monster: make beautiful QR codes that still scan with a controlnet specifically trained for this purpose", -} +class ControlNetModel(typing.NamedTuple): + label: str + model_id: str + explanation: str -controlnet_model_ids = { - ControlNetModels.sd_controlnet_canny: "lllyasviel/sd-controlnet-canny", - ControlNetModels.sd_controlnet_depth: "lllyasviel/sd-controlnet-depth", - ControlNetModels.sd_controlnet_hed: "lllyasviel/sd-controlnet-hed", - ControlNetModels.sd_controlnet_mlsd: "lllyasviel/sd-controlnet-mlsd", - ControlNetModels.sd_controlnet_normal: "lllyasviel/sd-controlnet-normal", - ControlNetModels.sd_controlnet_openpose: "lllyasviel/sd-controlnet-openpose", - ControlNetModels.sd_controlnet_scribble: "lllyasviel/sd-controlnet-scribble", - ControlNetModels.sd_controlnet_seg: "lllyasviel/sd-controlnet-seg", - ControlNetModels.sd_controlnet_tile: "lllyasviel/control_v11f1e_sd15_tile", - ControlNetModels.sd_controlnet_brightness: "ioclab/control_v1p_sd15_brightness", - ControlNetModels.control_v1p_sd15_qrcode_monster_v2: "monster-labs/control_v1p_sd15_qrcode_monster/v2", -} + +class ControlNetModels(ControlNetModel, GooeyEnum): + sd_controlnet_canny = ControlNetModel( + label="Canny", + explanation="Canny edge detection", + model_id="lllyasviel/sd-controlnet-canny", + ) + sd_controlnet_depth = ControlNetModel( + label="Depth", + explanation="Depth estimation", + model_id="lllyasviel/sd-controlnet-depth", + ) + sd_controlnet_hed = ControlNetModel( + label="HED Boundary", + explanation="HED edge detection", + model_id="lllyasviel/sd-controlnet-hed", + ) + sd_controlnet_mlsd = ControlNetModel( + label="M-LSD Straight Line", + explanation="M-LSD straight line detection", + model_id="lllyasviel/sd-controlnet-mlsd", + ) + sd_controlnet_normal = ControlNetModel( + label="Normal Map", + explanation="Normal map estimation", + model_id="lllyasviel/sd-controlnet-normal", + ) + sd_controlnet_openpose = ControlNetModel( + label="Human Pose", + explanation="Human pose estimation", + model_id="lllyasviel/sd-controlnet-openpose", + ) + sd_controlnet_scribble = ControlNetModel( + label="Scribble", + explanation="Scribble", + model_id="lllyasviel/sd-controlnet-scribble", + ) + sd_controlnet_seg = ControlNetModel( + label="Image Segmentation", + explanation="Image segmentation", + model_id="lllyasviel/sd-controlnet-seg", + ) + sd_controlnet_tile = ControlNetModel( + label="Tiling", + explanation="Tiling: to preserve small details", + model_id="lllyasviel/control_v11f1e_sd15_tile", + ) + sd_controlnet_brightness = ControlNetModel( + label="Brightness", + explanation="Brightness: to increase contrast naturally", + model_id="ioclab/control_v1p_sd15_brightness", + ) + control_v1p_sd15_qrcode_monster_v2 = ControlNetModel( + label="QR Monster V2", + explanation="QR Monster: make beautiful QR codes that still scan with a controlnet specifically trained for this purpose", + model_id="monster-labs/control_v1p_sd15_qrcode_monster/v2", + ) class Schedulers(models.TextChoices): @@ -463,8 +485,7 @@ def controlnet( ), "disable_safety_checker": True, "controlnet_model_id": [ - controlnet_model_ids[ControlNetModels[model]] - for model in selected_controlnet_model + ControlNetModels[model].model_id for model in selected_controlnet_model ], }, inputs={ diff --git a/recipes/Img2Img.py b/recipes/Img2Img.py index 7cb9283a7..12d2ede32 100644 --- a/recipes/Img2Img.py +++ b/recipes/Img2Img.py @@ -48,9 +48,7 @@ class RequestModel(BasePage.RequestModel): selected_model: typing.Literal[tuple(e.name for e in Img2ImgModels)] | None selected_controlnet_model: ( - list[typing.Literal[tuple(e.name for e in ControlNetModels)]] - | typing.Literal[tuple(e.name for e in ControlNetModels)] - | None + list[ControlNetModels.api_enum] | ControlNetModels.api_enum | None ) negative_prompt: str | None diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index 8a8d4921f..b57905885 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -90,18 +90,14 @@ class RequestModel(BasePage.RequestModel): text_prompt: str negative_prompt: str | None image_prompt: str | None - image_prompt_controlnet_models: ( - list[typing.Literal[tuple(e.name for e in ControlNetModels)], ...] | None - ) + image_prompt_controlnet_models: list[ControlNetModels.api_enum] | None image_prompt_strength: float | None image_prompt_scale: float | None image_prompt_pos_x: float | None image_prompt_pos_y: float | None selected_model: typing.Literal[tuple(e.name for e in Text2ImgModels)] | None - selected_controlnet_model: ( - list[typing.Literal[tuple(e.name for e in ControlNetModels)], ...] | None - ) + selected_controlnet_model: list[ControlNetModels.api_enum] | None output_width: int | None output_height: int | None From 8509c61fe00ef12bf3587e74be4a7b0618502d86 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 13 Sep 2024 19:02:48 +0530 Subject: [PATCH 24/38] Use GooeyEnum for Text2ImgModels, Img2ImgModels, UpscalerModels --- daras_ai_v2/stable_diffusion.py | 176 +++++++++++++++++----------- daras_ai_v2/upscaler_models.py | 4 +- recipes/CompareText2Img.py | 13 +- recipes/CompareUpscaler.py | 20 ++-- recipes/EmailFaceInpainting.py | 2 +- recipes/FaceInpainting.py | 2 +- recipes/GoogleImageGen.py | 2 +- recipes/Img2Img.py | 2 +- recipes/ObjectInpainting.py | 2 +- recipes/QRCodeGenerator.py | 4 +- scripts/init_self_hosted_pricing.py | 28 ++--- usage_costs/models.py | 6 +- 12 files changed, 144 insertions(+), 117 deletions(-) diff --git a/daras_ai_v2/stable_diffusion.py b/daras_ai_v2/stable_diffusion.py index 99d3e6608..60419cc33 100644 --- a/daras_ai_v2/stable_diffusion.py +++ b/daras_ai_v2/stable_diffusion.py @@ -28,98 +28,140 @@ SD_IMG_MAX_SIZE = (768, 768) -class InpaintingModels(Enum): - sd_2 = "Stable Diffusion v2.1 (stability.ai)" - runway_ml = "Stable Diffusion v1.5 (RunwayML)" - dall_e = "Dall-E (OpenAI)" +class InpaintingModel(typing.NamedTuple): + model_id: str | None + label: str + + +class InpaintingModels(InpaintingModel, GooeyEnum): + sd_2 = InpaintingModel( + label="Stable Diffusion v2.1 (stability.ai)", + model_id="stabilityai/stable-diffusion-2-inpainting", + ) + runway_ml = InpaintingModel( + label="Stable Diffusion v1.5 (RunwayML)", + model_id="runwayml/stable-diffusion-inpainting", + ) + dall_e = InpaintingModel(label="Dall-E (OpenAI)", model_id="dall-e-2") - jack_qiao = "Stable Diffusion v1.4 [Deprecated] (Jack Qiao)" + jack_qiao = InpaintingModel( + label="Stable Diffusion v1.4 [Deprecated] (Jack Qiao)", model_id=None + ) @classmethod def _deprecated(cls): return {cls.jack_qiao} -inpaint_model_ids = { - InpaintingModels.sd_2: "stabilityai/stable-diffusion-2-inpainting", - InpaintingModels.runway_ml: "runwayml/stable-diffusion-inpainting", -} +class Text2ImgModel(typing.NamedTuple): + model_id: str | None + label: str -class Text2ImgModels(Enum): +class Text2ImgModels(Text2ImgModel, GooeyEnum): # sd_1_4 = "SD v1.4 (RunwayML)" # Host this too? - dream_shaper = "DreamShaper (Lykon)" - dreamlike_2 = "Dreamlike Photoreal 2.0 (dreamlike.art)" - sd_2 = "Stable Diffusion v2.1 (stability.ai)" - sd_1_5 = "Stable Diffusion v1.5 (RunwayML)" + dream_shaper = Text2ImgModel( + label="DreamShaper (Lykon)", model_id="Lykon/DreamShaper" + ) + dreamlike_2 = Text2ImgModel( + label="Dreamlike Photoreal 2.0 (dreamlike.art)", + model_id="dreamlike-art/dreamlike-photoreal-2.0", + ) + sd_2 = Text2ImgModel( + label="Stable Diffusion v2.1 (stability.ai)", + model_id="stabilityai/stable-diffusion-2-1", + ) + sd_1_5 = Text2ImgModel( + label="Stable Diffusion v1.5 (RunwayML)", + model_id="runwayml/stable-diffusion-v1-5", + ) - dall_e = "DALL·E 2 (OpenAI)" - dall_e_3 = "DALL·E 3 (OpenAI)" + dall_e = Text2ImgModel(label="DALL·E 2 (OpenAI)", model_id="dall-e-2") + dall_e_3 = Text2ImgModel(label="DALL·E 3 (OpenAI)", model_id="dall-e-3") - openjourney_2 = "Open Journey v2 beta (PromptHero)" - openjourney = "Open Journey (PromptHero)" - analog_diffusion = "Analog Diffusion (wavymulder)" - protogen_5_3 = "Protogen v5.3 (darkstorm2150)" + openjourney_2 = Text2ImgModel( + label="Open Journey v2 beta (PromptHero)", model_id="prompthero/openjourney-v2" + ) + openjourney = Text2ImgModel( + label="Open Journey (PromptHero)", model_id="prompthero/openjourney" + ) + analog_diffusion = Text2ImgModel( + label="Analog Diffusion (wavymulder)", model_id="wavymulder/Analog-Diffusion" + ) + protogen_5_3 = Text2ImgModel( + label="Protogen v5.3 (darkstorm2150)", + model_id="darkstorm2150/Protogen_v5.3_Official_Release", + ) - jack_qiao = "Stable Diffusion v1.4 [Deprecated] (Jack Qiao)" - rodent_diffusion_1_5 = "Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)" - deepfloyd_if = "DeepFloyd IF [Deprecated] (stability.ai)" + jack_qiao = Text2ImgModel( + label="Stable Diffusion v1.4 [Deprecated] (Jack Qiao)", model_id=None + ) + rodent_diffusion_1_5 = Text2ImgModel( + label="Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)", model_id=None + ) + deepfloyd_if = Text2ImgModel( + label="DeepFloyd IF [Deprecated] (stability.ai)", model_id=None + ) @classmethod def _deprecated(cls): return {cls.jack_qiao, cls.deepfloyd_if, cls.rodent_diffusion_1_5} -text2img_model_ids = { - Text2ImgModels.sd_1_5: "runwayml/stable-diffusion-v1-5", - Text2ImgModels.sd_2: "stabilityai/stable-diffusion-2-1", - Text2ImgModels.dream_shaper: "Lykon/DreamShaper", - Text2ImgModels.analog_diffusion: "wavymulder/Analog-Diffusion", - Text2ImgModels.openjourney: "prompthero/openjourney", - Text2ImgModels.openjourney_2: "prompthero/openjourney-v2", - Text2ImgModels.dreamlike_2: "dreamlike-art/dreamlike-photoreal-2.0", - Text2ImgModels.protogen_5_3: "darkstorm2150/Protogen_v5.3_Official_Release", -} -dall_e_model_ids = { - Text2ImgModels.dall_e: "dall-e-2", - Text2ImgModels.dall_e_3: "dall-e-3", -} +class Img2ImgModel(typing.NamedTuple): + model_id: str | None + label: str -class Img2ImgModels(Enum): - dream_shaper = "DreamShaper (Lykon)" - dreamlike_2 = "Dreamlike Photoreal 2.0 (dreamlike.art)" - sd_2 = "Stable Diffusion v2.1 (stability.ai)" - sd_1_5 = "Stable Diffusion v1.5 (RunwayML)" +class Img2ImgModels(Img2ImgModel, GooeyEnum): + dream_shaper = Img2ImgModel( + label="DreamShaper (Lykon)", model_id="Lykon/DreamShaper" + ) + dreamlike_2 = Img2ImgModel( + label="Dreamlike Photoreal 2.0 (dreamlike.art)", + model_id="dreamlike-art/dreamlike-photoreal-2.0", + ) + sd_2 = Img2ImgModel( + label="Stable Diffusion v2.1 (stability.ai)", + model_id="stabilityai/stable-diffusion-2-1", + ) + sd_1_5 = Img2ImgModel( + label="Stable Diffusion v1.5 (RunwayML)", + model_id="runwayml/stable-diffusion-v1-5", + ) - dall_e = "Dall-E (OpenAI)" + dall_e = Img2ImgModel(label="Dall-E (OpenAI)", model_id=None) - instruct_pix2pix = "✨ InstructPix2Pix (Tim Brooks)" - openjourney_2 = "Open Journey v2 beta (PromptHero) 🐢" - openjourney = "Open Journey (PromptHero) 🐢" - analog_diffusion = "Analog Diffusion (wavymulder) 🐢" - protogen_5_3 = "Protogen v5.3 (darkstorm2150) 🐢" + instruct_pix2pix = Img2ImgModel( + label="✨ InstructPix2Pix (Tim Brooks)", model_id=None + ) + openjourney_2 = Img2ImgModel( + label="Open Journey v2 beta (PromptHero) 🐢", + model_id="prompthero/openjourney-v2", + ) + openjourney = Img2ImgModel( + label="Open Journey (PromptHero) 🐢", model_id="prompthero/openjourney" + ) + analog_diffusion = Img2ImgModel( + label="Analog Diffusion (wavymulder) 🐢", model_id="wavymulder/Analog-Diffusion" + ) + protogen_5_3 = Img2ImgModel( + label="Protogen v5.3 (darkstorm2150) 🐢", + model_id="darkstorm2150/Protogen_v5.3_Official_Release", + ) - jack_qiao = "Stable Diffusion v1.4 [Deprecated] (Jack Qiao)" - rodent_diffusion_1_5 = "Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)" + jack_qiao = Img2ImgModel( + label="Stable Diffusion v1.4 [Deprecated] (Jack Qiao)", model_id=None + ) + rodent_diffusion_1_5 = Img2ImgModel( + label="Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)", model_id=None + ) @classmethod def _deprecated(cls): return {cls.jack_qiao, cls.rodent_diffusion_1_5} -img2img_model_ids = { - Img2ImgModels.sd_2: "stabilityai/stable-diffusion-2-1", - Img2ImgModels.sd_1_5: "runwayml/stable-diffusion-v1-5", - Img2ImgModels.dream_shaper: "Lykon/DreamShaper", - Img2ImgModels.openjourney: "prompthero/openjourney", - Img2ImgModels.openjourney_2: "prompthero/openjourney-v2", - Img2ImgModels.analog_diffusion: "wavymulder/Analog-Diffusion", - Img2ImgModels.protogen_5_3: "darkstorm2150/Protogen_v5.3_Official_Release", - Img2ImgModels.dreamlike_2: "dreamlike-art/dreamlike-photoreal-2.0", -} - - class ControlNetModel(typing.NamedTuple): label: str model_id: str @@ -315,7 +357,7 @@ def text2img( width, height = _get_dall_e_3_img_size(width, height) with capture_openai_content_policy_violation(): response = client.images.generate( - model=dall_e_model_ids[Text2ImgModels[selected_model]], + model=Text2ImgModels[selected_model].model_id, n=1, # num_outputs, not supported yet prompt=prompt, response_format="b64_json", @@ -342,7 +384,7 @@ def text2img( return call_sd_multi( "diffusion.text2img", pipeline={ - "model_id": text2img_model_ids[Text2ImgModels[selected_model]], + "model_id": Text2ImgModels[selected_model].model_id, "scheduler": Schedulers[scheduler].label if scheduler else None, "disable_safety_checker": True, "seed": seed, @@ -435,7 +477,7 @@ def img2img( return call_sd_multi( "diffusion.img2img", pipeline={ - "model_id": img2img_model_ids[Img2ImgModels[selected_model]], + "model_id": Img2ImgModels[selected_model].model_id, # "scheduler": "UniPCMultistepScheduler", "disable_safety_checker": True, "seed": seed, @@ -478,7 +520,7 @@ def controlnet( return call_sd_multi( "diffusion.controlnet", pipeline={ - "model_id": text2img_model_ids[Text2ImgModels[selected_model]], + "model_id": Text2ImgModels[selected_model].model_id, "seed": seed, "scheduler": ( Schedulers[scheduler].label if scheduler else "UniPCMultistepScheduler" @@ -556,7 +598,7 @@ def inpainting( out_imgs_urls = call_sd_multi( "diffusion.inpaint", pipeline={ - "model_id": inpaint_model_ids[InpaintingModels[selected_model]], + "model_id": InpaintingModels[selected_model].model_id, "seed": seed, # "scheduler": Schedulers[scheduler].label # if scheduler diff --git a/daras_ai_v2/upscaler_models.py b/daras_ai_v2/upscaler_models.py index 11d8ff225..f5d46a40a 100644 --- a/daras_ai_v2/upscaler_models.py +++ b/daras_ai_v2/upscaler_models.py @@ -1,11 +1,11 @@ import typing -from enum import Enum from pathlib import Path import replicate import requests from daras_ai.image_input import upload_file_from_bytes +from daras_ai_v2.custom_enum import GooeyEnum from daras_ai_v2.exceptions import UserError from daras_ai_v2.gpu_server import call_celery_task_outfile from daras_ai_v2.pydantic_validation import FieldHttpUrl @@ -19,7 +19,7 @@ class UpscalerModel(typing.NamedTuple): is_bg_model: bool = False -class UpscalerModels(UpscalerModel, Enum): +class UpscalerModels(UpscalerModel, GooeyEnum): gfpgan_1_4 = UpscalerModel( model_id="GFPGANv1.4", label="GFPGAN v1.4 (Tencent ARC)", diff --git a/recipes/CompareText2Img.py b/recipes/CompareText2Img.py index 0f132eb04..3dd323da4 100644 --- a/recipes/CompareText2Img.py +++ b/recipes/CompareText2Img.py @@ -66,19 +66,14 @@ class RequestModel(BasePage.RequestModel): seed: int | None sd_2_upscaling: bool | None - selected_models: ( - list[typing.Literal[tuple(e.name for e in Text2ImgModels)]] | None - ) + selected_models: list[Text2ImgModels.api_enum] | None scheduler: typing.Literal[tuple(e.name for e in Schedulers)] | None edit_instruction: str | None image_guidance_scale: float | None class ResponseModel(BaseModel): - output_images: dict[ - typing.Literal[tuple(e.name for e in Text2ImgModels)], - list[FieldHttpUrl], - ] + output_images: dict[Text2ImgModels.api_enum, list[FieldHttpUrl]] @classmethod def get_example_preferred_fields(cls, state: dict) -> list[str]: @@ -193,7 +188,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: state["output_images"] = output_images = {} for selected_model in request.selected_models: - yield f"Running {Text2ImgModels[selected_model].value}..." + yield f"Running {Text2ImgModels[selected_model].label}..." output_images[selected_model] = text2img( selected_model=selected_model, @@ -254,7 +249,7 @@ def _render_outputs(self, state): output_images: dict = state.get("output_images", {}).get(key, []) for img in output_images: gui.image( - img, caption=Text2ImgModels[key].value, show_download_button=True + img, caption=Text2ImgModels[key].label, show_download_button=True ) def preview_description(self, state: dict) -> str: diff --git a/recipes/CompareUpscaler.py b/recipes/CompareUpscaler.py index dafcb9977..e7032095c 100644 --- a/recipes/CompareUpscaler.py +++ b/recipes/CompareUpscaler.py @@ -30,21 +30,22 @@ class RequestModel(BasePage.RequestModel): description="The final upsampling scale of the image", ge=1, le=4 ) - selected_models: ( - list[typing.Literal[tuple(e.name for e in UpscalerModels)]] | None - ) + selected_models: list[UpscalerModels.api_enum] | None selected_bg_model: ( typing.Literal[tuple(e.name for e in UpscalerModels if e.is_bg_model)] | None + ) = Field( + title="Selected Background Model", + **{"x-fern-type-name": "BackgroundUpscalerModels"}, ) class ResponseModel(BaseModel): - output_images: dict[ - typing.Literal[tuple(e.name for e in UpscalerModels)], FieldHttpUrl - ] = Field({}, description="Output Images") - output_videos: dict[ - typing.Literal[tuple(e.name for e in UpscalerModels)], FieldHttpUrl - ] = Field({}, description="Output Videos") + output_images: dict[UpscalerModels.api_enum, FieldHttpUrl] = Field( + default_factory=dict, description="Output Images" + ) + output_videos: dict[UpscalerModels.api_enum, FieldHttpUrl] = Field( + default_factory=dict, description="Output Videos" + ) def validate_form_v2(self): assert gui.session_state.get( @@ -70,6 +71,7 @@ def run_v2( for selected_model in request.selected_models: model = UpscalerModels[selected_model] yield f"Running {model.label}..." + print(f"{request.input_image=}, {request.input_video=}") if request.input_image: response.output_images[selected_model] = run_upscaler_model( selected_model=model, diff --git a/recipes/EmailFaceInpainting.py b/recipes/EmailFaceInpainting.py index 408a52715..0fcf39835 100644 --- a/recipes/EmailFaceInpainting.py +++ b/recipes/EmailFaceInpainting.py @@ -50,7 +50,7 @@ class RequestModel(BasePage.RequestModel): face_pos_x: float | None face_pos_y: float | None - selected_model: typing.Literal[tuple(e.name for e in InpaintingModels)] | None + selected_model: InpaintingModels.api_enum | None negative_prompt: str | None diff --git a/recipes/FaceInpainting.py b/recipes/FaceInpainting.py index ae76609fd..a28041d37 100644 --- a/recipes/FaceInpainting.py +++ b/recipes/FaceInpainting.py @@ -53,7 +53,7 @@ class RequestModel(BasePage.RequestModel): face_pos_x: float | None face_pos_y: float | None - selected_model: typing.Literal[tuple(e.name for e in InpaintingModels)] | None + selected_model: InpaintingModels.api_enum | None negative_prompt: str | None diff --git a/recipes/GoogleImageGen.py b/recipes/GoogleImageGen.py index 7b5636dd8..af64c45cd 100644 --- a/recipes/GoogleImageGen.py +++ b/recipes/GoogleImageGen.py @@ -55,7 +55,7 @@ class RequestModel(GoogleSearchLocationMixin, BasePage.RequestModel): search_query: str text_prompt: str - selected_model: typing.Literal[tuple(e.name for e in Img2ImgModels)] | None + selected_model: Img2ImgModels.api_enum | None negative_prompt: str | None diff --git a/recipes/Img2Img.py b/recipes/Img2Img.py index 12d2ede32..e7140a632 100644 --- a/recipes/Img2Img.py +++ b/recipes/Img2Img.py @@ -46,7 +46,7 @@ class RequestModel(BasePage.RequestModel): input_image: FieldHttpUrl text_prompt: str | None - selected_model: typing.Literal[tuple(e.name for e in Img2ImgModels)] | None + selected_model: Img2ImgModels.api_enum | None selected_controlnet_model: ( list[ControlNetModels.api_enum] | ControlNetModels.api_enum | None ) diff --git a/recipes/ObjectInpainting.py b/recipes/ObjectInpainting.py index 712399047..65f902ca4 100644 --- a/recipes/ObjectInpainting.py +++ b/recipes/ObjectInpainting.py @@ -56,7 +56,7 @@ class RequestModel(BasePage.RequestModel): mask_threshold: float | None - selected_model: typing.Literal[tuple(e.name for e in InpaintingModels)] | None + selected_model: InpaintingModels.api_enum | None negative_prompt: str | None diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index b57905885..2c39cdbd5 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -96,7 +96,7 @@ class RequestModel(BasePage.RequestModel): image_prompt_pos_x: float | None image_prompt_pos_y: float | None - selected_model: typing.Literal[tuple(e.name for e in Text2ImgModels)] | None + selected_model: Text2ImgModels.api_enum | None selected_controlnet_model: list[ControlNetModels.api_enum] | None output_width: int | None @@ -482,7 +482,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: state["raw_images"] = raw_images = [] - yield f"Running {Text2ImgModels[request.selected_model].value}..." + yield f"Running {Text2ImgModels[request.selected_model].label}..." if isinstance(request.selected_controlnet_model, str): request.selected_controlnet_model = [request.selected_controlnet_model] init_images = [image] * len(request.selected_controlnet_model) diff --git a/scripts/init_self_hosted_pricing.py b/scripts/init_self_hosted_pricing.py index 09dd7b8e3..5d7ef48b8 100644 --- a/scripts/init_self_hosted_pricing.py +++ b/scripts/init_self_hosted_pricing.py @@ -1,14 +1,7 @@ from decimal import Decimal from daras_ai_v2.gpu_server import build_queue_name -from daras_ai_v2.stable_diffusion import ( - Text2ImgModels, - Img2ImgModels, - InpaintingModels, - text2img_model_ids, - img2img_model_ids, - inpaint_model_ids, -) +from daras_ai_v2.stable_diffusion import Text2ImgModels, Img2ImgModels, InpaintingModels from recipes.DeforumSD import AnimationModels from usage_costs.models import ModelPricing from usage_costs.models import ModelSku, ModelCategory, ModelProvider @@ -17,20 +10,15 @@ def run(): - for model in AnimationModels: - add_model(model.model_id, model.name) - for model_enum, model_ids in [ - (Text2ImgModels, text2img_model_ids), - (Img2ImgModels, img2img_model_ids), - (InpaintingModels, inpaint_model_ids), + for model_enum in [ + AnimationModels, + Text2ImgModels, + Img2ImgModels, + InpaintingModels, ]: for m in model_enum: - if "dall_e" in m.name: - continue - try: - add_model(model_ids[m], m.name) - except KeyError: - pass + if "dall_e" not in m.name and m.model_id: + add_model(m.model_id, m.name) add_model("wav2lip_gan.pth", "wav2lip") diff --git a/usage_costs/models.py b/usage_costs/models.py index bfea9ec61..2d1b38271 100644 --- a/usage_costs/models.py +++ b/usage_costs/models.py @@ -63,9 +63,9 @@ def get_model_choices(): return ( [(api.name, api.value) for api in LargeLanguageModels] + [(model.name, model.label) for model in AnimationModels] - + [(model.name, model.value) for model in Text2ImgModels] - + [(model.name, model.value) for model in Img2ImgModels] - + [(model.name, model.value) for model in InpaintingModels] + + [(model.name, model.label) for model in Text2ImgModels] + + [(model.name, model.label) for model in Img2ImgModels] + + [(model.name, model.label) for model in InpaintingModels] + [("wav2lip", "LipSync (wav2lip)")] ) From 4f60610f4f680c75cabddab30ec3982b0e60ceff Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 13 Sep 2024 19:13:33 +0530 Subject: [PATCH 25/38] use GooeyEnum for translation models --- daras_ai_v2/asr.py | 2 +- recipes/Translation.py | 6 +++--- recipes/VideoBots.py | 4 +--- recipes/asr_page.py | 4 +--- 4 files changed, 6 insertions(+), 10 deletions(-) diff --git a/daras_ai_v2/asr.py b/daras_ai_v2/asr.py index ef715aa99..9354b5477 100644 --- a/daras_ai_v2/asr.py +++ b/daras_ai_v2/asr.py @@ -291,7 +291,7 @@ class TranslationModel(typing.NamedTuple): supports_auto_detect: bool = False -class TranslationModels(TranslationModel, Enum): +class TranslationModels(TranslationModel, GooeyEnum): google = TranslationModel( label="Google Translate", supports_glossary=True, diff --git a/recipes/Translation.py b/recipes/Translation.py index a58edb96e..19ce68795 100644 --- a/recipes/Translation.py +++ b/recipes/Translation.py @@ -43,9 +43,9 @@ class TranslationPage(BasePage): class RequestModelBase(BasePage.RequestModel): texts: list[str] = Field([]) - selected_model: ( - typing.Literal[tuple(e.name for e in TranslationModels)] - ) | None = Field(TranslationModels.google.name) + selected_model: TranslationModels.api_enum | None = Field( + TranslationModels.google.name + ) class RequestModel(TranslationOptions, RequestModelBase): pass diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 9ff802a20..419b06242 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -230,9 +230,7 @@ class RequestModelBase(BasePage.RequestModel): description="Choose a language to transcribe incoming audio messages to text.", ) - translation_model: ( - typing.Literal[tuple(e.name for e in TranslationModels)] | None - ) + translation_model: TranslationModels.api_enum | None user_language: str | None = Field( title="User Language", description="Choose a language to translate incoming text & audio messages to English and responses back to your selected language. Useful for low-resource languages.", diff --git a/recipes/asr_page.py b/recipes/asr_page.py index 435cd206d..52ef08aee 100644 --- a/recipes/asr_page.py +++ b/recipes/asr_page.py @@ -47,9 +47,7 @@ class RequestModelBase(BasePage.RequestModel): selected_model: AsrModels.api_enum | None language: str | None - translation_model: ( - typing.Literal[tuple(e.name for e in TranslationModels)] | None - ) + translation_model: TranslationModels.api_enum | None output_format: typing.Literal[tuple(e.name for e in AsrOutputFormat)] | None From 1f70c0a6c987a62be7f46e648f4596c684f599dd Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 13 Sep 2024 19:13:49 +0530 Subject: [PATCH 26/38] Use GooeyEnum for segmentation models --- daras_ai_v2/image_segmentation.py | 5 ++--- recipes/ImageSegmentation.py | 6 ++---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/daras_ai_v2/image_segmentation.py b/daras_ai_v2/image_segmentation.py index 099832979..004264afc 100644 --- a/daras_ai_v2/image_segmentation.py +++ b/daras_ai_v2/image_segmentation.py @@ -1,14 +1,13 @@ -from enum import Enum - import requests +from daras_ai_v2.custom_enum import GooeyEnum from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.gpu_server import ( call_celery_task_outfile, ) -class ImageSegmentationModels(Enum): +class ImageSegmentationModels(str, GooeyEnum): dis = "Dichotomous Image Segmentation" u2net = "U²-Net" diff --git a/recipes/ImageSegmentation.py b/recipes/ImageSegmentation.py index 7db0a16c6..c9e16d988 100644 --- a/recipes/ImageSegmentation.py +++ b/recipes/ImageSegmentation.py @@ -16,7 +16,7 @@ ) from daras_ai_v2.base import BasePage from daras_ai_v2.enum_selector_widget import enum_selector -from daras_ai_v2.image_segmentation import u2net, ImageSegmentationModels, dis +from daras_ai_v2.image_segmentation import ImageSegmentationModels, dis, u2net from daras_ai_v2.img_io import opencv_to_pil, pil_to_bytes from daras_ai_v2.loom_video_widget import youtube_video from daras_ai_v2.polygon_fitter import ( @@ -51,9 +51,7 @@ class ImageSegmentationPage(BasePage): class RequestModel(BasePage.RequestModel): input_image: FieldHttpUrl - selected_model: ( - typing.Literal[tuple(e.name for e in ImageSegmentationModels)] | None - ) + selected_model: ImageSegmentationModels.api_enum | None mask_threshold: float | None rect_persepective_transform: bool | None From 00b6bfc2475a9bef503d7df5dadb370ed1eabf7a Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 13 Sep 2024 19:18:15 +0530 Subject: [PATCH 27/38] Use GooeyEnum for AsrOutputFormat --- daras_ai_v2/asr.py | 3 +-- recipes/asr_page.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/daras_ai_v2/asr.py b/daras_ai_v2/asr.py index 9354b5477..44a05749d 100644 --- a/daras_ai_v2/asr.py +++ b/daras_ai_v2/asr.py @@ -3,7 +3,6 @@ import os.path import tempfile import typing -from enum import Enum import gooey_gui as gui import requests @@ -278,7 +277,7 @@ class AsrOutputJson(typing_extensions.TypedDict): chunks: typing_extensions.NotRequired[list[AsrChunk]] -class AsrOutputFormat(Enum): +class AsrOutputFormat(GooeyEnum): text = "Text" json = "JSON" srt = "SRT" diff --git a/recipes/asr_page.py b/recipes/asr_page.py index 52ef08aee..79d679a23 100644 --- a/recipes/asr_page.py +++ b/recipes/asr_page.py @@ -49,7 +49,7 @@ class RequestModelBase(BasePage.RequestModel): translation_model: TranslationModels.api_enum | None - output_format: typing.Literal[tuple(e.name for e in AsrOutputFormat)] | None + output_format: AsrOutputFormat.api_enum | None google_translate_target: str | None = Field( deprecated=True, From d38fc4f258771bde9152c819d6e795d2eaa4d64c Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 13 Sep 2024 19:18:30 +0530 Subject: [PATCH 28/38] Use GooeyEnum for CitationStyles --- daras_ai_v2/search_ref.py | 4 ++-- recipes/DocSearch.py | 2 +- recipes/VideoBots.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/daras_ai_v2/search_ref.py b/daras_ai_v2/search_ref.py index 0ed45386e..6ea91368e 100644 --- a/daras_ai_v2/search_ref.py +++ b/daras_ai_v2/search_ref.py @@ -1,10 +1,10 @@ import re import typing -from enum import Enum import jinja2 from typing_extensions import TypedDict +from daras_ai_v2.custom_enum import GooeyEnum from daras_ai_v2.exceptions import UserError from daras_ai_v2.scrollable_html_widget import scrollable_html @@ -16,7 +16,7 @@ class SearchReference(TypedDict): score: float -class CitationStyles(Enum): +class CitationStyles(GooeyEnum): number = "Numbers ( [1] [2] [3] ..)" title = "Source Title ( [Source 1] [Source 2] [Source 3] ..)" url = "Source URL ( [https://source1.com] [https://source2.com] [https://source3.com] ..)" diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index 591574a6b..f758933b2 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -74,7 +74,7 @@ class RequestModelBase(DocSearchRequest, BasePage.RequestModel): selected_model: LargeLanguageModels.api_enum | None - citation_style: typing.Literal[tuple(e.name for e in CitationStyles)] | None + citation_style: CitationStyles.api_enum | None class RequestModel(LanguageModelSettings, RequestModelBase): pass diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 419b06242..d760409dc 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -218,7 +218,7 @@ class RequestModelBase(BasePage.RequestModel): "dense_weight" ].field_info - citation_style: typing.Literal[tuple(e.name for e in CitationStyles)] | None + citation_style: CitationStyles.api_enum | None use_url_shortener: bool | None asr_model: AsrModels.api_enum | None = Field( From dbe84f3f8295a021d32ccc4983a1b99a6a3365ff Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 13 Sep 2024 19:51:39 +0530 Subject: [PATCH 29/38] Use GooeyEnum for Scheduler and TextToSpeechProviders --- daras_ai_v2/stable_diffusion.py | 78 ++++++++++--------- .../text_to_speech_settings_widgets.py | 2 +- recipes/CompareText2Img.py | 2 +- recipes/QRCodeGenerator.py | 2 +- recipes/TextToSpeech.py | 2 +- 5 files changed, 47 insertions(+), 39 deletions(-) diff --git a/daras_ai_v2/stable_diffusion.py b/daras_ai_v2/stable_diffusion.py index 60419cc33..42aa786bb 100644 --- a/daras_ai_v2/stable_diffusion.py +++ b/daras_ai_v2/stable_diffusion.py @@ -1,6 +1,5 @@ import io import typing -from enum import Enum import requests from PIL import Image @@ -226,44 +225,51 @@ class ControlNetModels(ControlNetModel, GooeyEnum): ) -class Schedulers(models.TextChoices): - singlestep_dpm_solver = ( - "DPM", - "DPMSolverSinglestepScheduler", +class Scheduler(typing.NamedTuple): + label: str + model_id: str + + +class Schedulers(Scheduler, GooeyEnum): + singlestep_dpm_solver = Scheduler( + label="DPM", + model_id="DPMSolverSinglestepScheduler", + ) + multistep_dpm_solver = Scheduler( + label="DPM Multistep", model_id="DPMSolverMultistepScheduler" ) - multistep_dpm_solver = "DPM Multistep", "DPMSolverMultistepScheduler" - dpm_sde = ( - "DPM SDE", - "DPMSolverSDEScheduler", + dpm_sde = Scheduler( + label="DPM SDE", + model_id="DPMSolverSDEScheduler", ) - dpm_discrete = ( - "DPM Discrete", - "KDPM2DiscreteScheduler", + dpm_discrete = Scheduler( + label="DPM Discrete", + model_id="KDPM2DiscreteScheduler", ) - dpm_discrete_ancestral = ( - "DPM Anscetral", - "KDPM2AncestralDiscreteScheduler", + dpm_discrete_ancestral = Scheduler( + label="DPM Anscetral", + model_id="KDPM2AncestralDiscreteScheduler", ) - unipc = "UniPC", "UniPCMultistepScheduler" - lms_discrete = ( - "LMS", - "LMSDiscreteScheduler", + unipc = Scheduler(label="UniPC", model_id="UniPCMultistepScheduler") + lms_discrete = Scheduler( + label="LMS", + model_id="LMSDiscreteScheduler", ) - heun = ( - "Heun", - "HeunDiscreteScheduler", + heun = Scheduler( + label="Heun", + model_id="HeunDiscreteScheduler", ) - euler = "Euler", "EulerDiscreteScheduler" - euler_ancestral = ( - "Euler ancestral", - "EulerAncestralDiscreteScheduler", + euler = Scheduler("Euler", "EulerDiscreteScheduler") + euler_ancestral = Scheduler( + label="Euler ancestral", + model_id="EulerAncestralDiscreteScheduler", ) - pndm = "PNDM", "PNDMScheduler" - ddpm = "DDPM", "DDPMScheduler" - ddim = "DDIM", "DDIMScheduler" - deis = ( - "DEIS", - "DEISMultistepScheduler", + pndm = Scheduler(label="PNDM", model_id="PNDMScheduler") + ddpm = Scheduler(label="DDPM", model_id="DDPMScheduler") + ddim = Scheduler(label="DDIM", model_id="DDIMScheduler") + deis = Scheduler( + label="DEIS", + model_id="DEISMultistepScheduler", ) @@ -385,7 +391,7 @@ def text2img( "diffusion.text2img", pipeline={ "model_id": Text2ImgModels[selected_model].model_id, - "scheduler": Schedulers[scheduler].label if scheduler else None, + "scheduler": Schedulers[scheduler].model_id if scheduler else None, "disable_safety_checker": True, "seed": seed, }, @@ -523,7 +529,9 @@ def controlnet( "model_id": Text2ImgModels[selected_model].model_id, "seed": seed, "scheduler": ( - Schedulers[scheduler].label if scheduler else "UniPCMultistepScheduler" + Schedulers[scheduler].model_id + if scheduler + else Schedulers.unipc.model_id ), "disable_safety_checker": True, "controlnet_model_id": [ @@ -600,7 +608,7 @@ def inpainting( pipeline={ "model_id": InpaintingModels[selected_model].model_id, "seed": seed, - # "scheduler": Schedulers[scheduler].label + # "scheduler": Schedulers[scheduler].model_id # if scheduler # else "UniPCMultistepScheduler", "disable_safety_checker": True, diff --git a/daras_ai_v2/text_to_speech_settings_widgets.py b/daras_ai_v2/text_to_speech_settings_widgets.py index 5214d5cc5..65b5b56ea 100644 --- a/daras_ai_v2/text_to_speech_settings_widgets.py +++ b/daras_ai_v2/text_to_speech_settings_widgets.py @@ -46,7 +46,7 @@ class OpenAI_TTS_Voices(GooeyEnum): shimmer = "shimmer" -class TextToSpeechProviders(Enum): +class TextToSpeechProviders(GooeyEnum): GOOGLE_TTS = "Google Text-to-Speech" ELEVEN_LABS = "Eleven Labs" UBERDUCK = "Uberduck.ai" diff --git a/recipes/CompareText2Img.py b/recipes/CompareText2Img.py index 3dd323da4..5c7b7b3f0 100644 --- a/recipes/CompareText2Img.py +++ b/recipes/CompareText2Img.py @@ -67,7 +67,7 @@ class RequestModel(BasePage.RequestModel): sd_2_upscaling: bool | None selected_models: list[Text2ImgModels.api_enum] | None - scheduler: typing.Literal[tuple(e.name for e in Schedulers)] | None + scheduler: Schedulers.api_enum | None edit_instruction: str | None image_guidance_scale: float | None diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index 2c39cdbd5..e91b1a03c 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -107,7 +107,7 @@ class RequestModel(BasePage.RequestModel): num_outputs: int | None quality: int | None - scheduler: typing.Literal[tuple(e.name for e in Schedulers)] | None + scheduler: Schedulers.api_enum | None seed: int | None diff --git a/recipes/TextToSpeech.py b/recipes/TextToSpeech.py index 5aba6d4df..6a66e8cc0 100644 --- a/recipes/TextToSpeech.py +++ b/recipes/TextToSpeech.py @@ -31,7 +31,7 @@ class TextToSpeechSettings(BaseModel): - tts_provider: typing.Literal[tuple(e.name for e in TextToSpeechProviders)] | None + tts_provider: TextToSpeechProviders.api_enum | None uberduck_voice_name: str | None uberduck_speaking_rate: float | None From 3192e762f54bbf205c9bc217f771c18882bd5308 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 13 Sep 2024 19:59:24 +0530 Subject: [PATCH 30/38] Use GooeyEnum for text2audio models and combine document chain type --- recipes/DocSummary.py | 6 +++--- recipes/Text2Audio.py | 10 ++++------ 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/recipes/DocSummary.py b/recipes/DocSummary.py index 8e348d96d..1711de053 100644 --- a/recipes/DocSummary.py +++ b/recipes/DocSummary.py @@ -1,6 +1,6 @@ import typing -from enum import Enum +from daras_ai_v2.custom_enum import GooeyEnum from daras_ai_v2.pydantic_validation import FieldHttpUrl from pydantic import BaseModel @@ -37,7 +37,7 @@ DEFAULT_DOC_SUMMARY_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/f35796d2-93fe-11ee-b86c-02420a0001c7/Summarize%20with%20GPT.jpg.png" -class CombineDocumentsChains(Enum): +class CombineDocumentsChains(GooeyEnum): map_reduce = "Map Reduce" # refine = "Refine" # stuff = "Stuffing (Only works for small documents)" @@ -70,7 +70,7 @@ class RequestModelBase(BasePage.RequestModel): selected_model: LargeLanguageModels.api_enum | None - chain_type: typing.Literal[tuple(e.name for e in CombineDocumentsChains)] | None + chain_type: CombineDocumentsChains.api_enum | None selected_asr_model: AsrModels.api_enum | None google_translate_target: str | None diff --git a/recipes/Text2Audio.py b/recipes/Text2Audio.py index f585de99e..302bf236f 100644 --- a/recipes/Text2Audio.py +++ b/recipes/Text2Audio.py @@ -1,6 +1,6 @@ import typing -from enum import Enum +from daras_ai_v2.custom_enum import GooeyEnum from daras_ai_v2.pydantic_validation import FieldHttpUrl from pydantic import BaseModel @@ -18,7 +18,7 @@ DEFAULT_TEXT2AUDIO_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/85cf8ea4-9457-11ee-bd77-02420a0001ce/Text%20guided%20audio.jpg.png" -class Text2AudioModels(Enum): +class Text2AudioModels(GooeyEnum): audio_ldm = "AudioLDM (CVSSP)" @@ -51,13 +51,11 @@ class RequestModel(BasePage.RequestModel): seed: int | None sd_2_upscaling: bool | None - selected_models: ( - list[typing.Literal[tuple(e.name for e in Text2AudioModels)]] | None - ) + selected_models: list[Text2AudioModels.api_enum] | None class ResponseModel(BaseModel): output_audios: dict[ - typing.Literal[tuple(e.name for e in Text2AudioModels)], + Text2AudioModels.api_enum, list[FieldHttpUrl], ] From 25fe3f269baf161b30a38a522756dcef10524022 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 13 Sep 2024 20:06:47 +0530 Subject: [PATCH 31/38] Rename Text2Img->TextToImage and Img2Img->ImageToImage for parity with SDK names --- daras_ai_v2/img_model_settings_widgets.py | 40 +++++------ daras_ai_v2/stable_diffusion.py | 82 +++++++++++------------ recipes/CompareText2Img.py | 16 ++--- recipes/GoogleImageGen.py | 10 +-- recipes/Img2Img.py | 10 +-- recipes/QRCodeGenerator.py | 18 ++--- usage_costs/models.py | 6 +- 7 files changed, 92 insertions(+), 90 deletions(-) diff --git a/daras_ai_v2/img_model_settings_widgets.py b/daras_ai_v2/img_model_settings_widgets.py index 733465cd1..3b9705479 100644 --- a/daras_ai_v2/img_model_settings_widgets.py +++ b/daras_ai_v2/img_model_settings_widgets.py @@ -2,9 +2,9 @@ from daras_ai_v2.enum_selector_widget import enum_selector, enum_multiselect from daras_ai_v2.stable_diffusion import ( - Text2ImgModels, + TextToImageModels, InpaintingModels, - Img2ImgModels, + ImageToImageModels, ControlNetModels, Schedulers, ) @@ -36,10 +36,10 @@ def img_model_settings( negative_prompt_setting(selected_model) num_outputs_setting(selected_model) - if models_enum is not Img2ImgModels: + if models_enum is not ImageToImageModels: output_resolution_setting() - if models_enum is Text2ImgModels: + if models_enum is TextToImageModels: sd_2_upscaling_setting() col1, col2 = gui.columns(2) @@ -48,11 +48,11 @@ def img_model_settings( guidance_scale_setting(selected_model) with col2: - if models_enum is Img2ImgModels and not gui.session_state.get( + if models_enum is ImageToImageModels and not gui.session_state.get( "selected_controlnet_model" ): prompt_strength_setting(selected_model) - if selected_model == Img2ImgModels.instruct_pix2pix.name: + if selected_model == ImageToImageModels.instruct_pix2pix.name: instruct_pix2pix_settings() if show_scheduler: @@ -72,10 +72,10 @@ def model_selector( high_explanation: str = "At {high} the control nets will be applied tightly to the prompted visual, possibly overriding the prompt", ): controlnet_unsupported_models = [ - Img2ImgModels.instruct_pix2pix.name, - Img2ImgModels.dall_e.name, - Img2ImgModels.jack_qiao.name, - Img2ImgModels.sd_2.name, + ImageToImageModels.instruct_pix2pix.name, + ImageToImageModels.dall_e.name, + ImageToImageModels.jack_qiao.name, + ImageToImageModels.sd_2.name, ] col1, col2 = gui.columns(2) with col1: @@ -95,12 +95,12 @@ def model_selector( """ ) if ( - models_enum is Img2ImgModels + models_enum is ImageToImageModels and gui.session_state.get("selected_model") in controlnet_unsupported_models ): if "selected_controlnet_model" in gui.session_state: gui.session_state["selected_controlnet_model"] = None - elif models_enum is Img2ImgModels: + elif models_enum is ImageToImageModels: enum_multiselect( ControlNetModels, label=controlnet_explanation, @@ -214,7 +214,7 @@ def quality_setting(selected_models=None): return if any( [ - selected_model in [Text2ImgModels.dall_e_3.name] + selected_model in [TextToImageModels.dall_e_3.name] for selected_model in selected_models ] ): @@ -374,8 +374,8 @@ def sd_2_upscaling_setting(): def scheduler_setting(selected_model: str = None): if selected_model in [ - Text2ImgModels.dall_e.name, - Text2ImgModels.jack_qiao, + TextToImageModels.dall_e.name, + TextToImageModels.jack_qiao, ]: return enum_selector( @@ -394,8 +394,8 @@ def scheduler_setting(selected_model: str = None): def guidance_scale_setting(selected_model: str = None): if selected_model in [ - Text2ImgModels.dall_e.name, - Text2ImgModels.jack_qiao, + TextToImageModels.dall_e.name, + TextToImageModels.jack_qiao, ]: return gui.slider( @@ -434,8 +434,8 @@ def instruct_pix2pix_settings(): def prompt_strength_setting(selected_model: str = None): if selected_model in [ - Img2ImgModels.dall_e.name, - Img2ImgModels.instruct_pix2pix.name, + ImageToImageModels.dall_e.name, + ImageToImageModels.instruct_pix2pix.name, ]: return @@ -457,7 +457,7 @@ def prompt_strength_setting(selected_model: str = None): def negative_prompt_setting(selected_model: str = None): - if selected_model in [Text2ImgModels.dall_e.name]: + if selected_model in [TextToImageModels.dall_e.name]: return gui.text_area( diff --git a/daras_ai_v2/stable_diffusion.py b/daras_ai_v2/stable_diffusion.py index 42aa786bb..a6ee14422 100644 --- a/daras_ai_v2/stable_diffusion.py +++ b/daras_ai_v2/stable_diffusion.py @@ -52,53 +52,53 @@ def _deprecated(cls): return {cls.jack_qiao} -class Text2ImgModel(typing.NamedTuple): +class TextToImageModel(typing.NamedTuple): model_id: str | None label: str -class Text2ImgModels(Text2ImgModel, GooeyEnum): +class TextToImageModels(TextToImageModel, GooeyEnum): # sd_1_4 = "SD v1.4 (RunwayML)" # Host this too? - dream_shaper = Text2ImgModel( + dream_shaper = TextToImageModel( label="DreamShaper (Lykon)", model_id="Lykon/DreamShaper" ) - dreamlike_2 = Text2ImgModel( + dreamlike_2 = TextToImageModel( label="Dreamlike Photoreal 2.0 (dreamlike.art)", model_id="dreamlike-art/dreamlike-photoreal-2.0", ) - sd_2 = Text2ImgModel( + sd_2 = TextToImageModel( label="Stable Diffusion v2.1 (stability.ai)", model_id="stabilityai/stable-diffusion-2-1", ) - sd_1_5 = Text2ImgModel( + sd_1_5 = TextToImageModel( label="Stable Diffusion v1.5 (RunwayML)", model_id="runwayml/stable-diffusion-v1-5", ) - dall_e = Text2ImgModel(label="DALL·E 2 (OpenAI)", model_id="dall-e-2") - dall_e_3 = Text2ImgModel(label="DALL·E 3 (OpenAI)", model_id="dall-e-3") + dall_e = TextToImageModel(label="DALL·E 2 (OpenAI)", model_id="dall-e-2") + dall_e_3 = TextToImageModel(label="DALL·E 3 (OpenAI)", model_id="dall-e-3") - openjourney_2 = Text2ImgModel( + openjourney_2 = TextToImageModel( label="Open Journey v2 beta (PromptHero)", model_id="prompthero/openjourney-v2" ) - openjourney = Text2ImgModel( + openjourney = TextToImageModel( label="Open Journey (PromptHero)", model_id="prompthero/openjourney" ) - analog_diffusion = Text2ImgModel( + analog_diffusion = TextToImageModel( label="Analog Diffusion (wavymulder)", model_id="wavymulder/Analog-Diffusion" ) - protogen_5_3 = Text2ImgModel( + protogen_5_3 = TextToImageModel( label="Protogen v5.3 (darkstorm2150)", model_id="darkstorm2150/Protogen_v5.3_Official_Release", ) - jack_qiao = Text2ImgModel( + jack_qiao = TextToImageModel( label="Stable Diffusion v1.4 [Deprecated] (Jack Qiao)", model_id=None ) - rodent_diffusion_1_5 = Text2ImgModel( + rodent_diffusion_1_5 = TextToImageModel( label="Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)", model_id=None ) - deepfloyd_if = Text2ImgModel( + deepfloyd_if = TextToImageModel( label="DeepFloyd IF [Deprecated] (stability.ai)", model_id=None ) @@ -107,52 +107,52 @@ def _deprecated(cls): return {cls.jack_qiao, cls.deepfloyd_if, cls.rodent_diffusion_1_5} -class Img2ImgModel(typing.NamedTuple): +class ImageToImageModel(typing.NamedTuple): model_id: str | None label: str -class Img2ImgModels(Img2ImgModel, GooeyEnum): - dream_shaper = Img2ImgModel( +class ImageToImageModels(ImageToImageModel, GooeyEnum): + dream_shaper = ImageToImageModel( label="DreamShaper (Lykon)", model_id="Lykon/DreamShaper" ) - dreamlike_2 = Img2ImgModel( + dreamlike_2 = ImageToImageModel( label="Dreamlike Photoreal 2.0 (dreamlike.art)", model_id="dreamlike-art/dreamlike-photoreal-2.0", ) - sd_2 = Img2ImgModel( + sd_2 = ImageToImageModel( label="Stable Diffusion v2.1 (stability.ai)", model_id="stabilityai/stable-diffusion-2-1", ) - sd_1_5 = Img2ImgModel( + sd_1_5 = ImageToImageModel( label="Stable Diffusion v1.5 (RunwayML)", model_id="runwayml/stable-diffusion-v1-5", ) - dall_e = Img2ImgModel(label="Dall-E (OpenAI)", model_id=None) + dall_e = ImageToImageModel(label="Dall-E (OpenAI)", model_id=None) - instruct_pix2pix = Img2ImgModel( + instruct_pix2pix = ImageToImageModel( label="✨ InstructPix2Pix (Tim Brooks)", model_id=None ) - openjourney_2 = Img2ImgModel( + openjourney_2 = ImageToImageModel( label="Open Journey v2 beta (PromptHero) 🐢", model_id="prompthero/openjourney-v2", ) - openjourney = Img2ImgModel( + openjourney = ImageToImageModel( label="Open Journey (PromptHero) 🐢", model_id="prompthero/openjourney" ) - analog_diffusion = Img2ImgModel( + analog_diffusion = ImageToImageModel( label="Analog Diffusion (wavymulder) 🐢", model_id="wavymulder/Analog-Diffusion" ) - protogen_5_3 = Img2ImgModel( + protogen_5_3 = ImageToImageModel( label="Protogen v5.3 (darkstorm2150) 🐢", model_id="darkstorm2150/Protogen_v5.3_Official_Release", ) - jack_qiao = Img2ImgModel( + jack_qiao = ImageToImageModel( label="Stable Diffusion v1.4 [Deprecated] (Jack Qiao)", model_id=None ) - rodent_diffusion_1_5 = Img2ImgModel( + rodent_diffusion_1_5 = ImageToImageModel( label="Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)", model_id=None ) @@ -352,18 +352,18 @@ def text2img( dall_e_3_quality: str | None = None, dall_e_3_style: str | None = None, ): - if selected_model != Text2ImgModels.dall_e_3.name: + if selected_model != TextToImageModels.dall_e_3.name: _resolution_check(width, height, max_size=(1024, 1024)) match selected_model: - case Text2ImgModels.dall_e_3.name: + case TextToImageModels.dall_e_3.name: from openai import OpenAI client = OpenAI() width, height = _get_dall_e_3_img_size(width, height) with capture_openai_content_policy_violation(): response = client.images.generate( - model=Text2ImgModels[selected_model].model_id, + model=TextToImageModels[selected_model].model_id, n=1, # num_outputs, not supported yet prompt=prompt, response_format="b64_json", @@ -372,7 +372,7 @@ def text2img( size=f"{width}x{height}", ) out_imgs = [b64_img_decode(part.b64_json) for part in response.data] - case Text2ImgModels.dall_e.name: + case TextToImageModels.dall_e.name: from openai import OpenAI edge = _get_dall_e_img_size(width, height) @@ -390,7 +390,7 @@ def text2img( return call_sd_multi( "diffusion.text2img", pipeline={ - "model_id": Text2ImgModels[selected_model].model_id, + "model_id": TextToImageModels[selected_model].model_id, "scheduler": Schedulers[scheduler].model_id if scheduler else None, "disable_safety_checker": True, "seed": seed, @@ -452,7 +452,7 @@ def img2img( _resolution_check(width, height) match selected_model: - case Img2ImgModels.dall_e.name: + case ImageToImageModels.dall_e.name: from openai import OpenAI edge = _get_dall_e_img_size(width, height) @@ -483,7 +483,7 @@ def img2img( return call_sd_multi( "diffusion.img2img", pipeline={ - "model_id": Img2ImgModels[selected_model].model_id, + "model_id": ImageToImageModels[selected_model].model_id, # "scheduler": "UniPCMultistepScheduler", "disable_safety_checker": True, "seed": seed, @@ -526,7 +526,7 @@ def controlnet( return call_sd_multi( "diffusion.controlnet", pipeline={ - "model_id": Text2ImgModels[selected_model].model_id, + "model_id": TextToImageModels[selected_model].model_id, "seed": seed, "scheduler": ( Schedulers[scheduler].model_id @@ -553,13 +553,13 @@ def controlnet( def add_prompt_prefix(prompt: str, selected_model: str) -> str: match selected_model: - case Text2ImgModels.openjourney.name: + case TextToImageModels.openjourney.name: prompt = "mdjrny-v4 style " + prompt - case Text2ImgModels.analog_diffusion.name: + case TextToImageModels.analog_diffusion.name: prompt = "analog style " + prompt - case Text2ImgModels.protogen_5_3.name: + case TextToImageModels.protogen_5_3.name: prompt = "modelshoot style " + prompt - case Text2ImgModels.dreamlike_2.name: + case TextToImageModels.dreamlike_2.name: prompt = "photo, " + prompt return prompt diff --git a/recipes/CompareText2Img.py b/recipes/CompareText2Img.py index 5c7b7b3f0..ee0e985fb 100644 --- a/recipes/CompareText2Img.py +++ b/recipes/CompareText2Img.py @@ -20,7 +20,7 @@ from daras_ai_v2.loom_video_widget import youtube_video from daras_ai_v2.safety_checker import safety_checker from daras_ai_v2.stable_diffusion import ( - Text2ImgModels, + TextToImageModels, text2img, instruct_pix2pix, sd_upscale, @@ -66,14 +66,14 @@ class RequestModel(BasePage.RequestModel): seed: int | None sd_2_upscaling: bool | None - selected_models: list[Text2ImgModels.api_enum] | None + selected_models: list[TextToImageModels.api_enum] | None scheduler: Schedulers.api_enum | None edit_instruction: str | None image_guidance_scale: float | None class ResponseModel(BaseModel): - output_images: dict[Text2ImgModels.api_enum, list[FieldHttpUrl]] + output_images: dict[TextToImageModels.api_enum, list[FieldHttpUrl]] @classmethod def get_example_preferred_fields(cls, state: dict) -> list[str]: @@ -120,7 +120,7 @@ def render_form_v2(self): """ ) enum_multiselect( - Text2ImgModels, + TextToImageModels, key="selected_models", ) @@ -188,7 +188,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: state["output_images"] = output_images = {} for selected_model in request.selected_models: - yield f"Running {Text2ImgModels[selected_model].label}..." + yield f"Running {TextToImageModels[selected_model].label}..." output_images[selected_model] = text2img( selected_model=selected_model, @@ -249,7 +249,7 @@ def _render_outputs(self, state): output_images: dict = state.get("output_images", {}).get(key, []) for img in output_images: gui.image( - img, caption=Text2ImgModels[key].label, show_download_button=True + img, caption=TextToImageModels[key].label, show_download_button=True ) def preview_description(self, state: dict) -> str: @@ -260,9 +260,9 @@ def get_raw_price(self, state: dict) -> int: total = 0 for name in selected_models: match name: - case Text2ImgModels.deepfloyd_if.name: + case TextToImageModels.deepfloyd_if.name: total += 5 - case Text2ImgModels.dall_e.name | Text2ImgModels.dall_e_3.name: + case TextToImageModels.dall_e.name | TextToImageModels.dall_e_3.name: total += 15 case _: total += 2 diff --git a/recipes/GoogleImageGen.py b/recipes/GoogleImageGen.py index af64c45cd..e882852f5 100644 --- a/recipes/GoogleImageGen.py +++ b/recipes/GoogleImageGen.py @@ -24,7 +24,7 @@ ) from daras_ai_v2.stable_diffusion import ( img2img, - Img2ImgModels, + ImageToImageModels, SD_IMG_MAX_SIZE, instruct_pix2pix, ) @@ -55,7 +55,7 @@ class RequestModel(GoogleSearchLocationMixin, BasePage.RequestModel): search_query: str text_prompt: str - selected_model: Img2ImgModels.api_enum | None + selected_model: ImageToImageModels.api_enum | None negative_prompt: str | None @@ -153,7 +153,7 @@ def run(self, state: dict): yield "Generating Images..." - if request.selected_model == Img2ImgModels.instruct_pix2pix.name: + if request.selected_model == ImageToImageModels.instruct_pix2pix.name: state["output_images"] = instruct_pix2pix( prompt=request.text_prompt, num_outputs=request.num_outputs, @@ -186,7 +186,7 @@ def render_form_v2(self): """, key="search_query", ) - model_selector(Img2ImgModels) + model_selector(ImageToImageModels) gui.text_area( """ #### 👩‍💻 Prompt @@ -200,7 +200,7 @@ def render_usage_guide(self): youtube_video("rnjvtaYYe8g") def render_settings(self): - img_model_settings(Img2ImgModels, render_model_selector=False) + img_model_settings(ImageToImageModels, render_model_selector=False) serp_search_location_selectbox() def render_output(self): diff --git a/recipes/Img2Img.py b/recipes/Img2Img.py index e7140a632..5e276ec44 100644 --- a/recipes/Img2Img.py +++ b/recipes/Img2Img.py @@ -10,7 +10,7 @@ from daras_ai_v2.img_model_settings_widgets import img_model_settings from daras_ai_v2.loom_video_widget import youtube_video from daras_ai_v2.stable_diffusion import ( - Img2ImgModels, + ImageToImageModels, img2img, SD_IMG_MAX_SIZE, instruct_pix2pix, @@ -46,7 +46,7 @@ class RequestModel(BasePage.RequestModel): input_image: FieldHttpUrl text_prompt: str | None - selected_model: Img2ImgModels.api_enum | None + selected_model: ImageToImageModels.api_enum | None selected_controlnet_model: ( list[ControlNetModels.api_enum] | ControlNetModels.api_enum | None ) @@ -123,7 +123,7 @@ def render_description(self): ) def render_settings(self): - img_model_settings(Img2ImgModels) + img_model_settings(ImageToImageModels) def render_usage_guide(self): youtube_video("narcZNyuNAg") @@ -160,7 +160,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: yield "Generating Image..." - if request.selected_model == Img2ImgModels.instruct_pix2pix.name: + if request.selected_model == ImageToImageModels.instruct_pix2pix.name: state["output_images"] = instruct_pix2pix( prompt=request.text_prompt, num_outputs=request.num_outputs, @@ -204,7 +204,7 @@ def preview_description(self, state: dict) -> str: def get_raw_price(self, state: dict) -> int: selected_model = state.get("selected_model") match selected_model: - case Img2ImgModels.dall_e.name: + case ImageToImageModels.dall_e.name: unit_price = 20 case _: unit_price = 5 diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index e91b1a03c..34ea1f076 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -30,10 +30,10 @@ from daras_ai_v2.repositioning import reposition_object, repositioning_preview_widget from daras_ai_v2.safety_checker import safety_checker from daras_ai_v2.stable_diffusion import ( - Text2ImgModels, + TextToImageModels, controlnet, ControlNetModels, - Img2ImgModels, + ImageToImageModels, Schedulers, ) from daras_ai_v2.vcard import VCARD @@ -96,7 +96,7 @@ class RequestModel(BasePage.RequestModel): image_prompt_pos_x: float | None image_prompt_pos_y: float | None - selected_model: Text2ImgModels.api_enum | None + selected_model: TextToImageModels.api_enum | None selected_controlnet_model: list[ControlNetModels.api_enum] | None output_width: int | None @@ -293,7 +293,7 @@ def render_settings(self): ) img_model_settings( - Img2ImgModels, + ImageToImageModels, show_scheduler=True, require_controlnet=True, extra_explanations={ @@ -482,7 +482,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: state["raw_images"] = raw_images = [] - yield f"Running {Text2ImgModels[request.selected_model].label}..." + yield f"Running {TextToImageModels[request.selected_model].label}..." if isinstance(request.selected_controlnet_model, str): request.selected_controlnet_model = [request.selected_controlnet_model] init_images = [image] * len(request.selected_controlnet_model) @@ -543,12 +543,14 @@ def preview_description(self, state: dict) -> str: """ def get_raw_price(self, state: dict) -> int: - selected_model = state.get("selected_model", Text2ImgModels.dream_shaper.name) + selected_model = state.get( + "selected_model", TextToImageModels.dream_shaper.name + ) total = 5 match selected_model: - case Text2ImgModels.deepfloyd_if.name: + case TextToImageModels.deepfloyd_if.name: total += 3 - case Text2ImgModels.dall_e.name: + case TextToImageModels.dall_e.name: total += 10 return total * state.get("num_outputs", 1) diff --git a/usage_costs/models.py b/usage_costs/models.py index 2d1b38271..37ce69256 100644 --- a/usage_costs/models.py +++ b/usage_costs/models.py @@ -58,13 +58,13 @@ class ModelProvider(models.IntegerChoices): def get_model_choices(): from daras_ai_v2.language_model import LargeLanguageModels from recipes.DeforumSD import AnimationModels - from daras_ai_v2.stable_diffusion import Text2ImgModels, Img2ImgModels + from daras_ai_v2.stable_diffusion import TextToImageModels, ImageToImageModels return ( [(api.name, api.value) for api in LargeLanguageModels] + [(model.name, model.label) for model in AnimationModels] - + [(model.name, model.label) for model in Text2ImgModels] - + [(model.name, model.label) for model in Img2ImgModels] + + [(model.name, model.label) for model in TextToImageModels] + + [(model.name, model.label) for model in ImageToImageModels] + [(model.name, model.label) for model in InpaintingModels] + [("wav2lip", "LipSync (wav2lip)")] ) From 4c46fe6cc25ff51aed21cb7014d18d99f328643c Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 13 Sep 2024 20:07:23 +0530 Subject: [PATCH 32/38] Rename for parity with SDK --- scripts/init_self_hosted_pricing.py | 10 +++++++--- scripts/run_all_diffusion.py | 26 +++++++++++++------------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/scripts/init_self_hosted_pricing.py b/scripts/init_self_hosted_pricing.py index 5d7ef48b8..51d89a446 100644 --- a/scripts/init_self_hosted_pricing.py +++ b/scripts/init_self_hosted_pricing.py @@ -1,7 +1,11 @@ from decimal import Decimal from daras_ai_v2.gpu_server import build_queue_name -from daras_ai_v2.stable_diffusion import Text2ImgModels, Img2ImgModels, InpaintingModels +from daras_ai_v2.stable_diffusion import ( + TextToImageModels, + ImageToImageModels, + InpaintingModels, +) from recipes.DeforumSD import AnimationModels from usage_costs.models import ModelPricing from usage_costs.models import ModelSku, ModelCategory, ModelProvider @@ -12,8 +16,8 @@ def run(): for model_enum in [ AnimationModels, - Text2ImgModels, - Img2ImgModels, + TextToImageModels, + ImageToImageModels, InpaintingModels, ]: for m in model_enum: diff --git a/scripts/run_all_diffusion.py b/scripts/run_all_diffusion.py index 183260c5d..c6132b953 100644 --- a/scripts/run_all_diffusion.py +++ b/scripts/run_all_diffusion.py @@ -20,9 +20,9 @@ from daras_ai_v2.stable_diffusion import ( controlnet, ControlNetModels, - Img2ImgModels, + ImageToImageModels, text2img, - Text2ImgModels, + TextToImageModels, img2img, instruct_pix2pix, sd_upscale, @@ -34,7 +34,7 @@ # def fn(): # text2img( -# selected_model=Img2ImgModels.sd_1_5.name, +# selected_model=ImageToImageModels.sd_1_5.name, # prompt=get_random_string(100, string.ascii_letters), # num_outputs=1, # num_inference_steps=1, @@ -45,7 +45,7 @@ # # r = requests.get(GpuEndpoints.sd_multi / "magic") # # raise_for_status(r) # # img2img( -# # selected_model=Img2ImgModels.sd_1_5.name, +# # selected_model=ImageToImageModels.sd_1_5.name, # # prompt=get_random_string(100, string.ascii_letters), # # num_outputs=1, # # init_image=random_img, @@ -55,7 +55,7 @@ # # ) # # controlnet( # # selected_controlnet_model=ControlNetModels.sd_controlnet_depth.name, -# # selected_model=Img2ImgModels.sd_1_5.name, +# # selected_model=ImageToImageModels.sd_1_5.name, # # prompt=get_random_string(100, string.ascii_letters), # # num_outputs=1, # # init_image=random_img, @@ -72,11 +72,11 @@ # exit() tasks = [] -for model in Img2ImgModels: +for model in ImageToImageModels: if model in [ - Img2ImgModels.instruct_pix2pix, - Img2ImgModels.dall_e, - Img2ImgModels.jack_qiao, + ImageToImageModels.instruct_pix2pix, + ImageToImageModels.dall_e, + ImageToImageModels.jack_qiao, ]: continue print(model) @@ -96,7 +96,7 @@ ) for controlnet_model in ControlNetModels: if model in [ - Img2ImgModels.sd_2, + ImageToImageModels.sd_2, ]: continue print(controlnet_model) @@ -115,10 +115,10 @@ ) ) -for model in Text2ImgModels: +for model in TextToImageModels: if model in [ - Text2ImgModels.dall_e, - Text2ImgModels.jack_qiao, + TextToImageModels.dall_e, + TextToImageModels.jack_qiao, ]: continue print(model) From 55e5ea9c7b2b80c5696b5ad5b06229071dffd6cc Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 13 Sep 2024 20:08:21 +0530 Subject: [PATCH 33/38] x-fern-type-name for VCard field in QRCode --- recipes/QRCodeGenerator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index 34ea1f076..7f61f8d76 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -82,7 +82,9 @@ def __init__(self, *args, **kwargs): class RequestModel(BasePage.RequestModel): qr_code_data: str | None qr_code_input_image: FieldHttpUrl | None - qr_code_vcard: VCARD | None = Field(title="VCard") + qr_code_vcard: VCARD | None = Field( + title="VCard", **{"x-fern-type-name": "VCard"} + ) qr_code_file: FieldHttpUrl | None use_url_shortener: bool | None From e48b197a2e9ee4d5a506445a606a923372a51389 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 13 Sep 2024 20:24:33 +0530 Subject: [PATCH 34/38] more renaming and remove api_choices in favor of api_enum --- daras_ai_v2/custom_enum.py | 5 ----- daras_ai_v2/vector_search.py | 4 +++- functions/models.py | 4 ++-- recipes/Img2Img.py | 4 ++-- recipes/TextToSpeech.py | 4 ++-- 5 files changed, 9 insertions(+), 12 deletions(-) diff --git a/daras_ai_v2/custom_enum.py b/daras_ai_v2/custom_enum.py index 1f781afb9..270d4834c 100644 --- a/daras_ai_v2/custom_enum.py +++ b/daras_ai_v2/custom_enum.py @@ -44,11 +44,6 @@ def from_db(cls, db_value) -> typing_extensions.Self: return e raise ValueError(f"Invalid {cls.__name__} {db_value=}") - @classmethod - @property - def api_choices(cls): - return typing.Literal[tuple(e.api_value for e in cls)] - @classmethod @property @cached_classmethod diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py index fe694d56b..ba0c4ce59 100644 --- a/daras_ai_v2/vector_search.py +++ b/daras_ai_v2/vector_search.py @@ -75,7 +75,9 @@ class Config: use_enum_values = True search_query: str - keyword_query: str | list[str] | None + keyword_query: str | list[str] | None = Field( + **{"x-fern-type-name": "KeywordQuery"} + ) documents: list[str] | None diff --git a/functions/models.py b/functions/models.py index be1be7a5e..711805d64 100644 --- a/functions/models.py +++ b/functions/models.py @@ -22,7 +22,7 @@ class RecipeFunction(BaseModel): title="URL", description="The URL of the [function](https://gooey.ai/functions) to call.", ) - trigger: FunctionTrigger.api_choices = Field( + trigger: FunctionTrigger.api_enum = Field( title="Trigger", description="When to run this function. `pre` runs before the recipe, `post` runs after the recipe.", ) @@ -30,7 +30,7 @@ class RecipeFunction(BaseModel): class CalledFunctionResponse(BaseModel): url: str - trigger: FunctionTrigger.api_choices + trigger: FunctionTrigger.api_enum return_value: typing.Any @classmethod diff --git a/recipes/Img2Img.py b/recipes/Img2Img.py index 5e276ec44..601a1d378 100644 --- a/recipes/Img2Img.py +++ b/recipes/Img2Img.py @@ -2,7 +2,7 @@ from daras_ai_v2.pydantic_validation import FieldHttpUrl import requests -from pydantic import BaseModel +from pydantic import BaseModel, Field import gooey_gui as gui from bots.models import Workflow @@ -49,7 +49,7 @@ class RequestModel(BasePage.RequestModel): selected_model: ImageToImageModels.api_enum | None selected_controlnet_model: ( list[ControlNetModels.api_enum] | ControlNetModels.api_enum | None - ) + ) = Field(**{"x-fern-type-name": "SelectedControlNetModels"}) negative_prompt: str | None num_outputs: int | None diff --git a/recipes/TextToSpeech.py b/recipes/TextToSpeech.py index 6a66e8cc0..0488a6fba 100644 --- a/recipes/TextToSpeech.py +++ b/recipes/TextToSpeech.py @@ -55,8 +55,8 @@ class TextToSpeechSettings(BaseModel): azure_voice_name: str | None - openai_voice_name: OpenAI_TTS_Voices.api_choices | None - openai_tts_model: OpenAI_TTS_Models.api_choices | None + openai_voice_name: OpenAI_TTS_Voices.api_enum | None + openai_tts_model: OpenAI_TTS_Models.api_enum | None class TextToSpeechPage(BasePage): From 25abd67e3d29be13499ea06bddfd8d73180de194 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 13 Sep 2024 20:48:18 +0530 Subject: [PATCH 35/38] Fix defaults for serp_search_type and serp_search_location --- daras_ai_v2/serp_search_locations.py | 6 +----- recipes/GoogleGPT.py | 4 ++-- recipes/GoogleImageGen.py | 4 ++-- recipes/RelatedQnA.py | 4 ++-- recipes/RelatedQnADoc.py | 4 ++-- recipes/SEOSummary.py | 4 ++-- 6 files changed, 11 insertions(+), 15 deletions(-) diff --git a/daras_ai_v2/serp_search_locations.py b/daras_ai_v2/serp_search_locations.py index b2049e1ef..bcb39c053 100644 --- a/daras_ai_v2/serp_search_locations.py +++ b/daras_ai_v2/serp_search_locations.py @@ -43,7 +43,7 @@ def serp_search_location_selectbox(key="serp_search_location"): options=[e.api_value for e in SerpSearchLocations], format_func=lambda e: f"{SerpSearchLocations.from_api(e).label} ({e})", key=key, - value=SerpSearchLocations.UNITED_STATES.name, + value=SerpSearchLocations.UNITED_STATES.api_value, ) @@ -58,10 +58,6 @@ class SerpSearchType(GooeyEnum): def label(self): return self.value - @property - def api_value(self): - return self.name - class SerpSearchLocation(typing.NamedTuple): api_value: str diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py index 4e03c0bea..93d7a0ed2 100644 --- a/recipes/GoogleGPT.py +++ b/recipes/GoogleGPT.py @@ -56,8 +56,8 @@ class GoogleGPTPage(BasePage): keywords="outdoor rugs,8x10 rugs,rug sizes,checkered rugs,5x7 rugs", title="Ruggable", company_url="https://ruggable.com", - serp_search_type=SerpSearchType.search, - serp_search_location=SerpSearchLocations.UNITED_STATES.name, + serp_search_type=SerpSearchType.search.name, + serp_search_location=SerpSearchLocations.UNITED_STATES.api_value, enable_html=False, selected_model=LargeLanguageModels.text_davinci_003.name, sampling_temperature=0.8, diff --git a/recipes/GoogleImageGen.py b/recipes/GoogleImageGen.py index e882852f5..61488b82d 100644 --- a/recipes/GoogleImageGen.py +++ b/recipes/GoogleImageGen.py @@ -47,8 +47,8 @@ class GoogleImageGenPage(BasePage): sd_2_upscaling=False, seed=42, image_guidance_scale=1.2, - serp_search_type=SerpSearchType.search, - serp_search_location=SerpSearchLocations.UNITED_STATES.name, + serp_search_type=SerpSearchType.search.name, + serp_search_location=SerpSearchLocations.UNITED_STATES.api_value, ) class RequestModel(GoogleSearchLocationMixin, BasePage.RequestModel): diff --git a/recipes/RelatedQnA.py b/recipes/RelatedQnA.py index ee1a064e4..f43c84097 100644 --- a/recipes/RelatedQnA.py +++ b/recipes/RelatedQnA.py @@ -34,8 +34,8 @@ class RelatedQnAPage(BasePage): max_context_words=200, scroll_jump=5, dense_weight=1.0, - serp_search_type=SerpSearchType.search, - serp_search_location=SerpSearchLocations.UNITED_STATES.name, + serp_search_type=SerpSearchType.search.name, + serp_search_location=SerpSearchLocations.UNITED_STATES.api_value, ) class RequestModel(GoogleGPTPage.RequestModel): diff --git a/recipes/RelatedQnADoc.py b/recipes/RelatedQnADoc.py index 3e4fab620..8362a7046 100644 --- a/recipes/RelatedQnADoc.py +++ b/recipes/RelatedQnADoc.py @@ -31,8 +31,8 @@ class RelatedQnADocPage(BasePage): sane_defaults = dict( citation_style=CitationStyles.number.name, dense_weight=1.0, - serp_search_type=SerpSearchType.search, - serp_search_location=SerpSearchLocations.UNITED_STATES.name, + serp_search_type=SerpSearchType.search.name, + serp_search_location=SerpSearchLocations.UNITED_STATES.api_value, ) class RequestModel(GoogleSearchMixin, DocSearchPage.RequestModel): diff --git a/recipes/SEOSummary.py b/recipes/SEOSummary.py index c53d03333..2274b4662 100644 --- a/recipes/SEOSummary.py +++ b/recipes/SEOSummary.py @@ -74,8 +74,8 @@ def preview_description(self, state: dict) -> str: keywords="outdoor rugs,8x10 rugs,rug sizes,checkered rugs,5x7 rugs", title="Ruggable", company_url="https://ruggable.com", - serp_search_type=SerpSearchType.search, - serp_search_location=SerpSearchLocations.UNITED_STATES.name, + serp_search_type=SerpSearchType.search.name, + serp_search_location=SerpSearchLocations.UNITED_STATES.api_value, enable_html=False, selected_model=LargeLanguageModels.text_davinci_003.name, sampling_temperature=0.8, From d75faeccc51ef741ef9d41e0119009e0f5d21ae8 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 13 Sep 2024 20:58:07 +0530 Subject: [PATCH 36/38] remove debug print statements --- recipes/DeforumSD.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/recipes/DeforumSD.py b/recipes/DeforumSD.py index db47dabe1..eba951bf4 100644 --- a/recipes/DeforumSD.py +++ b/recipes/DeforumSD.py @@ -467,10 +467,6 @@ def run(self, state: dict): if not self.request.user.disable_safety_checker: safety_checker(text=self.preview_input(state)) - print("selected_model", request.selected_model) - print(f'{state["selected_model"]=}') - print(f"{type(request.selected_model)=}") - try: state["output_video"] = call_celery_task_outfile( "deforum", From 0713227a3404df07c84e1c58d5016fc4fa4fbaf4 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 16 Sep 2024 13:18:22 +0530 Subject: [PATCH 37/38] Revert "Add openapi params for fern bearer auth, hide healthcheck from fern" This reverts commit 1b5f46f95f5138cc23669e3f08953917ad0b9b27. --- auth/token_authentication.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/auth/token_authentication.py b/auth/token_authentication.py index a1281faa7..483e291b6 100644 --- a/auth/token_authentication.py +++ b/auth/token_authentication.py @@ -1,5 +1,3 @@ -from typing import Any - from fastapi import Request from fastapi.exceptions import HTTPException from fastapi.openapi.models import HTTPBase as HTTPBaseModel, SecuritySchemeType @@ -57,7 +55,7 @@ class APIAuth(SecurityBase): ### Usage: ```python - api_auth = APIAuth(scheme_name="bearer", description="Bearer $GOOEY_API_KEY") + api_auth = APIAuth(scheme_name="Bearer", description="Bearer $GOOEY_API_KEY") @app.get("/api/users") def get_users(authenticated_user: AppUser = Depends(api_auth)): @@ -65,14 +63,9 @@ def get_users(authenticated_user: AppUser = Depends(api_auth)): ``` """ - def __init__( - self, scheme_name: str, description: str, openapi_extra: dict[str, Any] = None - ): + def __init__(self, scheme_name: str, description: str): self.model = HTTPBaseModel( - type=SecuritySchemeType.http, - scheme=scheme_name, - description=description, - **(openapi_extra or {}), + type=SecuritySchemeType.http, scheme=scheme_name, description=description ) self.scheme_name = scheme_name self.description = description @@ -95,9 +88,7 @@ def __call__(self, request: Request) -> AppUser: return authenticate_credentials(auth[1]) -auth_scheme = "bearer" +auth_scheme = "Bearer" api_auth_header = APIAuth( - scheme_name=auth_scheme, - description=f"{auth_scheme} $GOOEY_API_KEY", - openapi_extra={"x-fern-bearer": {"name": "apiKey", "env": "GOOEY_API_KEY"}}, + scheme_name=auth_scheme, description=f"{auth_scheme} $GOOEY_API_KEY" ) From 9a5f74bf88ecf9b7c90c6e52a0410fcb12ddef07 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 16 Sep 2024 13:20:02 +0530 Subject: [PATCH 38/38] Revert "Refactor auth code to output auth scheme in OpenAPI spec" This reverts commit 6442d2deb868a59edd74959b14523aa9aad2571b. --- auth/token_authentication.py | 91 +++++++++++------------------- daras_ai_v2/api_examples_widget.py | 38 ++++++------- 2 files changed, 51 insertions(+), 78 deletions(-) diff --git a/auth/token_authentication.py b/auth/token_authentication.py index 483e291b6..b33bbbbd0 100644 --- a/auth/token_authentication.py +++ b/auth/token_authentication.py @@ -1,27 +1,39 @@ -from fastapi import Request +import threading + +from fastapi import Header from fastapi.exceptions import HTTPException -from fastapi.openapi.models import HTTPBase as HTTPBaseModel, SecuritySchemeType -from fastapi.security.base import SecurityBase -from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN from app_users.models import AppUser from auth.auth_backend import authlocal from daras_ai_v2 import db from daras_ai_v2.crypto import PBKDF2PasswordHasher +auth_keyword = "Bearer" -class AuthenticationError(HTTPException): - status_code = HTTP_401_UNAUTHORIZED - - def __init__(self, msg: str): - super().__init__(status_code=self.status_code, detail={"error": msg}) +def api_auth_header( + authorization: str = Header( + alias="Authorization", + description=f"{auth_keyword} $GOOEY_API_KEY", + ), +) -> AppUser: + if authlocal: + return authlocal[0] + return authenticate(authorization) -class AuthorizationError(HTTPException): - status_code = HTTP_403_FORBIDDEN - def __init__(self, msg: str): - super().__init__(status_code=self.status_code, detail={"error": msg}) +def authenticate(auth_token: str) -> AppUser: + auth = auth_token.split() + if not auth or auth[0].lower() != auth_keyword.lower(): + msg = "Invalid Authorization header." + raise HTTPException(status_code=401, detail={"error": msg}) + if len(auth) == 1: + msg = "Invalid Authorization header. No credentials provided." + raise HTTPException(status_code=401, detail={"error": msg}) + elif len(auth) > 2: + msg = "Invalid Authorization header. Token string should not contain spaces." + raise HTTPException(status_code=401, detail={"error": msg}) + return authenticate_credentials(auth[1]) def authenticate_credentials(token: str) -> AppUser: @@ -36,7 +48,12 @@ def authenticate_credentials(token: str) -> AppUser: .get()[0] ) except IndexError: - raise AuthorizationError("Invalid API Key.") + raise HTTPException( + status_code=403, + detail={ + "error": "Invalid API Key.", + }, + ) uid = doc.get("uid") user = AppUser.objects.get_or_create_from_uid(uid)[0] @@ -45,50 +62,6 @@ def authenticate_credentials(token: str) -> AppUser: "Your Gooey.AI account has been disabled for violating our Terms of Service. " "Contact us at support@gooey.ai if you think this is a mistake." ) - raise AuthenticationError(msg) + raise HTTPException(status_code=401, detail={"error": msg}) return user - - -class APIAuth(SecurityBase): - """ - ### Usage: - - ```python - api_auth = APIAuth(scheme_name="Bearer", description="Bearer $GOOEY_API_KEY") - - @app.get("/api/users") - def get_users(authenticated_user: AppUser = Depends(api_auth)): - ... - ``` - """ - - def __init__(self, scheme_name: str, description: str): - self.model = HTTPBaseModel( - type=SecuritySchemeType.http, scheme=scheme_name, description=description - ) - self.scheme_name = scheme_name - self.description = description - - def __call__(self, request: Request) -> AppUser: - if authlocal: # testing only! - return authlocal[0] - - auth = request.headers.get("Authorization", "").split() - if not auth or auth[0].lower() != self.scheme_name.lower(): - raise AuthenticationError("Invalid Authorization header.") - if len(auth) == 1: - raise AuthenticationError( - "Invalid Authorization header. No credentials provided." - ) - elif len(auth) > 2: - raise AuthenticationError( - "Invalid Authorization header. Token string should not contain spaces." - ) - return authenticate_credentials(auth[1]) - - -auth_scheme = "Bearer" -api_auth_header = APIAuth( - scheme_name=auth_scheme, description=f"{auth_scheme} $GOOEY_API_KEY" -) diff --git a/daras_ai_v2/api_examples_widget.py b/daras_ai_v2/api_examples_widget.py index ef54058c2..9cd9beeb1 100644 --- a/daras_ai_v2/api_examples_widget.py +++ b/daras_ai_v2/api_examples_widget.py @@ -6,7 +6,7 @@ from furl import furl import gooey_gui as gui -from auth.token_authentication import auth_scheme +from auth.token_authentication import auth_keyword from daras_ai_v2 import settings from daras_ai_v2.doc_search_settings_widgets import is_user_uploaded_url @@ -48,12 +48,12 @@ def api_example_generator( if as_form_data: curl_code = r""" curl %(api_url)s \ - -H "Authorization: %(auth_scheme)s $GOOEY_API_KEY" \ + -H "Authorization: %(auth_keyword)s $GOOEY_API_KEY" \ %(files)s \ -F json=%(json)s """ % dict( api_url=shlex.quote(api_url), - auth_scheme=auth_scheme, + auth_keyword=auth_keyword, files=" \\\n ".join( f"-F {key}=@{shlex.quote(filename)}" for key, filename in filenames ), @@ -62,12 +62,12 @@ def api_example_generator( else: curl_code = r""" curl %(api_url)s \ - -H "Authorization: %(auth_scheme)s $GOOEY_API_KEY" \ + -H "Authorization: %(auth_keyword)s $GOOEY_API_KEY" \ -H 'Content-Type: application/json' \ -d %(json)s """ % dict( api_url=shlex.quote(api_url), - auth_scheme=auth_scheme, + auth_keyword=auth_keyword, json=shlex.quote(json.dumps(request_body, indent=2)), ) if as_async: @@ -77,7 +77,7 @@ def api_example_generator( ) while true; do - result=$(curl $status_url -H "Authorization: %(auth_scheme)s $GOOEY_API_KEY") + result=$(curl $status_url -H "Authorization: %(auth_keyword)s $GOOEY_API_KEY") status=$(echo $result | jq -r '.status') if [ "$status" = "completed" ]; then echo $result @@ -91,7 +91,7 @@ def api_example_generator( """ % dict( curl_code=indent(curl_code.strip(), " " * 2), api_url=shlex.quote(api_url), - auth_scheme=auth_scheme, + auth_keyword=auth_keyword, json=shlex.quote(json.dumps(request_body, indent=2)), ) @@ -128,7 +128,7 @@ def api_example_generator( response = requests.post( "%(api_url)s", headers={ - "Authorization": "%(auth_scheme)s " + os.environ["GOOEY_API_KEY"], + "Authorization": "%(auth_keyword)s " + os.environ["GOOEY_API_KEY"], }, files=files, data={"json": json.dumps(payload)}, @@ -140,7 +140,7 @@ def api_example_generator( ), json=repr(request_body), api_url=api_url, - auth_scheme=auth_scheme, + auth_keyword=auth_keyword, ) else: py_code = r""" @@ -152,14 +152,14 @@ def api_example_generator( response = requests.post( "%(api_url)s", headers={ - "Authorization": "%(auth_scheme)s " + os.environ["GOOEY_API_KEY"], + "Authorization": "%(auth_keyword)s " + os.environ["GOOEY_API_KEY"], }, json=payload, ) assert response.ok, response.content """ % dict( api_url=api_url, - auth_scheme=auth_scheme, + auth_keyword=auth_keyword, json=repr(request_body), ) if as_async: @@ -168,7 +168,7 @@ def api_example_generator( status_url = response.headers["Location"] while True: - response = requests.get(status_url, headers={"Authorization": "%(auth_scheme)s " + os.environ["GOOEY_API_KEY"]}) + response = requests.get(status_url, headers={"Authorization": "%(auth_keyword)s " + os.environ["GOOEY_API_KEY"]}) assert response.ok, response.content result = response.json() if result["status"] == "completed": @@ -181,7 +181,7 @@ def api_example_generator( sleep(3) """ % dict( api_url=api_url, - auth_scheme=auth_scheme, + auth_keyword=auth_keyword, ) else: py_code += r""" @@ -229,7 +229,7 @@ def api_example_generator( const response = await fetch("%(api_url)s", { method: "POST", headers: { - "Authorization": "%(auth_scheme)s " + process.env["GOOEY_API_KEY"], + "Authorization": "%(auth_keyword)s " + process.env["GOOEY_API_KEY"], }, body: formData, }); @@ -243,7 +243,7 @@ def api_example_generator( " " * 2, ), api_url=api_url, - auth_scheme=auth_scheme, + auth_keyword=auth_keyword, ) else: @@ -256,14 +256,14 @@ def api_example_generator( const response = await fetch("%(api_url)s", { method: "POST", headers: { - "Authorization": "%(auth_scheme)s " + process.env["GOOEY_API_KEY"], + "Authorization": "%(auth_keyword)s " + process.env["GOOEY_API_KEY"], "Content-Type": "application/json", }, body: JSON.stringify(payload), }); """ % dict( api_url=api_url, - auth_scheme=auth_scheme, + auth_keyword=auth_keyword, json=json.dumps(request_body, indent=2), ) @@ -280,7 +280,7 @@ def api_example_generator( const response = await fetch(status_url, { method: "GET", headers: { - "Authorization": "%(auth_scheme)s " + process.env["GOOEY_API_KEY"], + "Authorization": "%(auth_keyword)s " + process.env["GOOEY_API_KEY"], }, }); if (!response.ok) { @@ -299,7 +299,7 @@ def api_example_generator( } }""" % dict( api_url=api_url, - auth_scheme=auth_scheme, + auth_keyword=auth_keyword, ) else: js_code += """