Skip to content

Commit

Permalink
use api_enum method on GooeyEnum class, use it with LargeLanguageModels
Browse files Browse the repository at this point in the history
  • Loading branch information
nikochiko committed Sep 12, 2024
1 parent 325cfd2 commit 48abd59
Show file tree
Hide file tree
Showing 13 changed files with 44 additions and 40 deletions.
3 changes: 3 additions & 0 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ class BasePage:
)

class RequestModel(BaseModel):
class Config:
use_enum_values = True

functions: list[RecipeFunction] | None = Field(
title="🧩 Developer Tools and Functions",
)
Expand Down
23 changes: 23 additions & 0 deletions daras_ai_v2/custom_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,29 @@ def from_db(cls, db_value) -> typing_extensions.Self:
def api_choices(cls):
return typing.Literal[tuple(e.name for e in cls)]

@classmethod
@property
def api_enum(cls):
"""
Enum that is useful as a type in API requests.
Maps `name`->`name` for all values.
The title (same as the Enum class name) will be
used as the new Enum's title. This will be passed
on to the OpenAPI schema and the generated SDK.
"""
# this cache is a hack to get around a bug where
# dynamic Enums with the same name crash when
# generating the OpenAPI spec
if not hasattr(cls, "_cached_api_enum"):
cls._cached_api_enum = {}
if cls.__name__ not in cls._cached_api_enum:
cls._cached_api_enum[cls.__name__] = Enum(
cls.__name__, {e.name: e.name for e in cls}
)

return cls._cached_api_enum[cls.__name__]

@classmethod
def from_api(cls, name: str) -> typing_extensions.Self:
for e in cls:
Expand Down
3 changes: 2 additions & 1 deletion daras_ai_v2/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from daras_ai.image_input import gs_url_to_uri, bytes_to_cv2_img, cv2_img_to_bytes
from daras_ai_v2.asr import get_google_auth_session
from daras_ai_v2.custom_enum import GooeyEnum
from daras_ai_v2.exceptions import raise_for_status, UserError
from daras_ai_v2.gpu_server import call_celery_task
from daras_ai_v2.text_splitter import (
Expand Down Expand Up @@ -72,7 +73,7 @@ class LLMSpec(typing.NamedTuple):
supports_json: bool = False


class LargeLanguageModels(Enum):
class LargeLanguageModels(GooeyEnum):
# https://platform.openai.com/docs/models/gpt-4o
gpt_4_o = LLMSpec(
label="GPT-4o (openai)",
Expand Down
4 changes: 1 addition & 3 deletions recipes/BulkEval.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,7 @@ class RequestModelBase(BasePage.RequestModel):
""",
)

selected_model: (
typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None
)
selected_model: LargeLanguageModels.api_enum | None

class RequestModel(LanguageModelSettings, RequestModelBase):
pass
Expand Down
8 changes: 2 additions & 6 deletions recipes/CompareLLM.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,13 @@ class CompareLLMPage(BasePage):

class RequestModelBase(BasePage.RequestModel):
input_prompt: str | None
selected_models: list[
typing.Literal[tuple(e.name for e in LargeLanguageModels)]
]
selected_models: list[LargeLanguageModels.api_enum] | None

class RequestModel(LanguageModelSettings, RequestModelBase):
pass

class ResponseModel(BaseModel):
output_text: dict[
typing.Literal[tuple(e.name for e in LargeLanguageModels)], list[str]
]
output_text: dict[LargeLanguageModels.api_choices, list[str]]

def preview_image(self, state: dict) -> str | None:
return DEFAULT_COMPARE_LM_META_IMG
Expand Down
4 changes: 1 addition & 3 deletions recipes/DocExtract.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,7 @@ class RequestModelBase(BasePage.RequestModel):

task_instructions: str | None

selected_model: (
typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None
)
selected_model: LargeLanguageModels.api_enum | None

class RequestModel(LanguageModelSettings, RequestModelBase):
pass
Expand Down
6 changes: 2 additions & 4 deletions recipes/DocSearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import typing

from furl import furl
from pydantic import BaseModel
from pydantic import BaseModel, Field

import gooey_gui as gui
from bots.models import Workflow
Expand Down Expand Up @@ -72,9 +72,7 @@ class RequestModelBase(DocSearchRequest, BasePage.RequestModel):
task_instructions: str | None
query_instructions: str | None

selected_model: (
typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None
)
selected_model: LargeLanguageModels.api_enum | None

citation_style: typing.Literal[tuple(e.name for e in CitationStyles)] | None

Expand Down
7 changes: 2 additions & 5 deletions recipes/DocSummary.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import Enum

from daras_ai_v2.pydantic_validation import FieldHttpUrl
from pydantic import BaseModel, Field
from pydantic import BaseModel

import gooey_gui as gui
from bots.models import Workflow
Expand All @@ -16,7 +16,6 @@
LargeLanguageModels,
run_language_model,
calc_gpt_tokens,
ResponseFormatType,
)
from daras_ai_v2.language_model_settings_widgets import (
language_model_settings,
Expand Down Expand Up @@ -69,9 +68,7 @@ class RequestModelBase(BasePage.RequestModel):
task_instructions: str | None
merge_instructions: str | None

selected_model: (
typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None
)
selected_model: LargeLanguageModels.api_enum | None

chain_type: typing.Literal[tuple(e.name for e in CombineDocumentsChains)] | None

Expand Down
6 changes: 2 additions & 4 deletions recipes/GoogleGPT.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import typing

from furl import furl
from pydantic import BaseModel
from pydantic import BaseModel, Field

import gooey_gui as gui
from bots.models import Workflow
Expand Down Expand Up @@ -82,9 +82,7 @@ class RequestModelBase(BasePage.RequestModel):
task_instructions: str | None
query_instructions: str | None

selected_model: (
typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None
)
selected_model: LargeLanguageModels.api_enum | None

max_search_urls: int | None

Expand Down
4 changes: 1 addition & 3 deletions recipes/SEOSummary.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,7 @@ class RequestModelBase(BaseModel):

enable_html: bool | None

selected_model: (
typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None
)
selected_model: LargeLanguageModels.api_enum | None

max_search_urls: int | None

Expand Down
6 changes: 2 additions & 4 deletions recipes/SmartGPT.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import typing

import jinja2.sandbox
from pydantic import BaseModel
from pydantic import BaseModel, Field

import gooey_gui as gui
from bots.models import Workflow
Expand Down Expand Up @@ -39,9 +39,7 @@ class RequestModelBase(BasePage.RequestModel):
reflexion_prompt: str | None
dera_prompt: str | None

selected_model: (
typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None
)
selected_model: LargeLanguageModels.api_enum | None

class RequestModel(LanguageModelSettings, RequestModelBase):
pass
Expand Down
6 changes: 2 additions & 4 deletions recipes/SocialLookupEmail.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import typing

import requests
from pydantic import BaseModel
from pydantic import BaseModel, Field

import gooey_gui as gui
from bots.models import Workflow
Expand Down Expand Up @@ -55,9 +55,7 @@ class RequestModelBase(BasePage.RequestModel):
# domain: str | None
# key_words: str | None

selected_model: (
typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None
)
selected_model: LargeLanguageModels.api_enum | None

class RequestModel(LanguageModelSettings, RequestModelBase):
pass
Expand Down
4 changes: 1 addition & 3 deletions recipes/VideoBots.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,7 @@ class RequestModelBase(BasePage.RequestModel):
bot_script: str | None

# llm model
selected_model: (
typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None
)
selected_model: LargeLanguageModels.api_enum | None
document_model: str | None = Field(
title="🩻 Photo / Document Intelligence",
description="When your copilot users upload a photo or pdf, what kind of document are they mostly likely to upload? "
Expand Down

0 comments on commit 48abd59

Please sign in to comment.