Commit 9baa904

Merge branch 'master' into org-billing

nikochiko committed Sep 2, 2024
2 parents e538574 + 45e5f32
Showing 13 changed files with 431 additions and 111 deletions.
2 changes: 1 addition & 1 deletion celeryapp/tasks.py
@@ -97,7 +97,7 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False
     except Exception as e:
         if isinstance(e, UserError):
             sentry_level = e.sentry_level
-            logger.warning(e)
+            logger.warning("\n".join(map(str, [e, e.__cause__])))
         else:
             sentry_level = "error"
             traceback.print_exc()
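
Note on the tasks.py change: UserError typically wraps a lower-level exception via raise ... from ..., so the root cause lands on e.__cause__; joining both into one message keeps that cause visible in warning-level logs. A minimal, self-contained sketch of the pattern (the UserError class here is a stand-in, not the repo's implementation):

import logging

logger = logging.getLogger(__name__)


class UserError(Exception):
    # stand-in for the repo's user-facing error type
    sentry_level = "warning"


try:
    try:
        int("not-a-number")
    except ValueError as cause:
        raise UserError("Please enter a valid number") from cause
except UserError as e:
    # e.__cause__ is the original ValueError; log both on separate lines
    logger.warning("\n".join(map(str, [e, e.__cause__])))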
11 changes: 9 additions & 2 deletions daras_ai_v2/asr.py
@@ -5,12 +5,12 @@
 import typing
 from enum import Enum
 
+import gooey_gui as gui
 import requests
 import typing_extensions
 from django.db.models import F
 from furl import furl
 
-import gooey_gui as gui
 from daras_ai.image_input import upload_file_from_bytes, gs_url_to_uri
 from daras_ai_v2 import settings
 from daras_ai_v2.azure_asr import azure_asr
@@ -31,6 +31,7 @@
 from daras_ai_v2.google_asr import gcp_asr_v1
 from daras_ai_v2.gpu_server import call_celery_task
 from daras_ai_v2.redis_cache import redis_cache_decorator
+from daras_ai_v2.scraping_proxy import SCRAPING_PROXIES, get_scraping_proxy_cert_path
 from daras_ai_v2.text_splitter import text_splitter
 
 TRANSLATE_BATCH_SIZE = 8
@@ -988,13 +989,19 @@ def download_youtube_to_wav(youtube_url: str) -> bytes:
     with _yt_dlp_lock, tempfile.TemporaryDirectory() as tmpdir:
         infile = os.path.join(tmpdir, "infile")
         outfile = os.path.join(tmpdir, "outfile.wav")
+        proxy_args = []
+        if proxy := SCRAPING_PROXIES.get("https"):
+            proxy_args += ["--proxy", proxy]
+        if cert := get_scraping_proxy_cert_path():
+            proxy_args += ["--client-certificate-key", cert]
         # run yt-dlp to download audio
         call_cmd(
-            "yt-dlp",
+            "yt-dlp", "-v",
             "--no-playlist",
             "--max-downloads", "1",
             "--format", "bestaudio",
             "--output", infile,
+            *proxy_args,
             youtube_url,
             # ignore MaxDownloadsReached - https://github.com/ytdl-org/youtube-dl/blob/a452f9437c8a3048f75fc12f75bcfd3eed78430f/youtube_dl/__init__.py#L468
             ok_returncodes={101},
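
The proxy flags above are only appended when a scraping proxy is configured, so unproxied environments keep the old behavior; --proxy and --client-certificate-key are standard yt-dlp options. A standalone sketch of the same invocation using plain subprocess instead of the repo's call_cmd helper (assumed here to shell out and tolerate the listed return codes):

import subprocess


def download_audio(youtube_url: str, outfile: str,
                   proxy: str | None = None, cert_key: str | None = None) -> None:
    proxy_args = []
    if proxy:
        proxy_args += ["--proxy", proxy]
    if cert_key:
        proxy_args += ["--client-certificate-key", cert_key]
    result = subprocess.run([
        "yt-dlp", "-v",
        "--no-playlist",
        "--max-downloads", "1",
        "--format", "bestaudio",
        "--output", outfile,
        *proxy_args,
        youtube_url,
    ])
    # yt-dlp exits with 101 when --max-downloads cuts the run short
    if result.returncode not in (0, 101):
        raise subprocess.CalledProcessError(result.returncode, "yt-dlp")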
100 changes: 90 additions & 10 deletions daras_ai_v2/language_model.py
@@ -1,4 +1,5 @@
 import base64
+import json
 import mimetypes
 import re
 import typing
@@ -75,7 +76,7 @@ class LargeLanguageModels(Enum):
     # https://platform.openai.com/docs/models/gpt-4o
     gpt_4_o = LLMSpec(
         label="GPT-4o (openai)",
-        model_id=("openai-gpt-4o-prod-eastus2-1", "gpt-4o"),
+        model_id="gpt-4o-2024-08-06",
         llm_api=LLMApis.openai,
         context_window=128_000,
         price=10,
@@ -92,6 +93,14 @@ class LargeLanguageModels(Enum):
         is_vision_model=True,
         supports_json=True,
     )
+    chatgpt_4_o = LLMSpec(
+        label="ChatGPT-4o (openai) 🧪",
+        model_id="chatgpt-4o-latest",
+        llm_api=LLMApis.openai,
+        context_window=128_000,
+        price=10,
+        is_vision_model=True,
+    )
     # https://platform.openai.com/docs/models/gpt-4-turbo-and-gpt-4
     gpt_4_turbo_vision = LLMSpec(
         label="GPT-4 Turbo with Vision (openai)",
@@ -232,13 +241,23 @@ class LargeLanguageModels(Enum):
     )
 
     # https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
+    gemini_1_5_flash = LLMSpec(
+        label="Gemini 1.5 Flash (Google)",
+        model_id="gemini-1.5-flash",
+        llm_api=LLMApis.gemini,
+        context_window=1_048_576,
+        price=15,
+        is_vision_model=True,
+        supports_json=True,
+    )
     gemini_1_5_pro = LLMSpec(
         label="Gemini 1.5 Pro (Google)",
-        model_id="gemini-1.5-pro-preview-0409",
+        model_id="gemini-1.5-pro",
         llm_api=LLMApis.gemini,
-        context_window=1_000_000,
+        context_window=2_097_152,
         price=15,
         is_vision_model=True,
+        supports_json=True,
     )
     gemini_1_pro_vision = LLMSpec(
         label="Gemini 1.0 Pro Vision (Google)",
@@ -280,6 +299,7 @@ class LargeLanguageModels(Enum):
         context_window=200_000,
         price=15,
         is_vision_model=True,
+        supports_json=True,
     )
     claude_3_opus = LLMSpec(
         label="Claude 3 Opus [L] (Anthropic)",
@@ -288,6 +308,7 @@ class LargeLanguageModels(Enum):
         context_window=200_000,
         price=75,
         is_vision_model=True,
+        supports_json=True,
     )
     claude_3_sonnet = LLMSpec(
         label="Claude 3 Sonnet [M] (Anthropic)",
@@ -296,6 +317,7 @@ class LargeLanguageModels(Enum):
         context_window=200_000,
         price=15,
         is_vision_model=True,
+        supports_json=True,
     )
     claude_3_haiku = LLMSpec(
         label="Claude 3 Haiku [S] (Anthropic)",
@@ -304,6 +326,7 @@ class LargeLanguageModels(Enum):
         context_window=200_000,
         price=2,
         is_vision_model=True,
+        supports_json=True,
     )
 
     sea_lion_7b_instruct = LLMSpec(
@@ -666,6 +689,7 @@ def _run_chat_model(
                 messages=messages,
                 max_output_tokens=min(max_tokens, 1024),  # because of Vertex AI limits
                 temperature=temperature,
+                response_format_type=response_format_type,
             )
         case LLMApis.palm2:
             if tools:
@@ -696,6 +720,7 @@ def _run_chat_model(
                 max_tokens=max_tokens,
                 temperature=temperature,
                 stop=stop,
+                response_format_type=response_format_type,
             )
         case LLMApis.self_hosted:
             return [
@@ -785,6 +810,7 @@ def _run_anthropic_chat(
     max_tokens: int,
     temperature: float,
     stop: list[str] | None,
+    response_format_type: ResponseFormatType | None,
 ):
     import anthropic
     from usage_costs.cost_utils import record_cost_auto
@@ -818,6 +844,27 @@ def _run_anthropic_chat(
         content = get_entry_text(msg)
         anthropic_msgs.append({"role": role, "content": content})
 
+    if response_format_type == "json_object":
+        kwargs = dict(
+            tools=[
+                {
+                    "name": "json_output",
+                    "input_schema": {
+                        "type": "object",
+                        "properties": {
+                            "response": {
+                                "type": "object",
+                                "description": "The response to the user's prompt as a JSON object.",
+                            },
+                        },
+                    },
+                }
+            ],
+            tool_choice={"type": "tool", "name": "json_output"},
+        )
+    else:
+        kwargs = {}
+
     client = anthropic.Anthropic()
     response = client.messages.create(
         model=model,
@@ -826,6 +873,7 @@ def _run_anthropic_chat(
         messages=anthropic_msgs,
         stop_sequences=stop,
         temperature=temperature,
+        **kwargs,
     )
 
     record_cost_auto(
@@ -839,9 +887,35 @@ def _run_anthropic_chat(
         quantity=response.usage.output_tokens,
     )
 
+    if response_format_type == "json_object":
+        if response.stop_reason == "max_tokens":
+            raise UserError(
+                "Claude’s response got cut off due to hitting the max_tokens limit, and the truncated response contains an incomplete tool use block. "
+                "Please retry the request with a higher max_tokens value to get the full tool use. "
+            ) from anthropic.AnthropicError(
+                f"Hit {response.stop_reason=} when generating JSON: {response.content=}"
+            )
+        if response.stop_reason != "tool_use":
+            raise UserError(
+                f"Claude was unable to generate a JSON response. Please retry the request with a different prompt, or try a different model."
+            ) from anthropic.AnthropicError(
+                f"Failed to generate JSON response: {response.stop_reason=} {response.content}"
+            )
+        for entry in response.content:
+            if entry.type != "tool_use":
+                continue
+            response = entry.input
+            if isinstance(response, dict):
+                response = response.get("response", {})
+            return [
+                {
+                    "role": CHATML_ROLE_ASSISTANT,
+                    "content": json.dumps(response),
+                }
+            ]
     return [
         {
-            "role": CHATML_ROLE_USER,
+            "role": CHATML_ROLE_ASSISTANT,
             "content": "".join(entry.text for entry in response.content),
         }
     ]
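
Because the Claude Messages API has no native JSON response format, the change above emulates one: it declares a single json_output tool, forces it with tool_choice, and then reads the model's structured arguments back out of the tool_use content block. A self-contained sketch of that pattern (model name and schema are illustrative):

import json

import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment
response = client.messages.create(
    model="claude-3-haiku-20240307",
    max_tokens=1024,
    messages=[{"role": "user", "content": "List three primes as a JSON object."}],
    tools=[
        {
            "name": "json_output",
            "input_schema": {
                "type": "object",
                "properties": {
                    "response": {
                        "type": "object",
                        "description": "The answer as a JSON object.",
                    },
                },
            },
        }
    ],
    # force a call to json_output instead of a prose reply
    tool_choice={"type": "tool", "name": "json_output"},
)
for entry in response.content:
    if entry.type == "tool_use":
        # entry.input holds the tool arguments, i.e. the JSON payload
        print(json.dumps(entry.input.get("response", {})))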
@@ -1212,6 +1286,7 @@ def _run_gemini_pro(
     messages: list[ConversationEntry],
     max_output_tokens: int,
     temperature: float,
+    response_format_type: ResponseFormatType | None,
 ):
     contents = []
     for entry in messages:
@@ -1244,6 +1319,7 @@ def _run_gemini_pro(
         contents=contents,
         max_output_tokens=max_output_tokens,
         temperature=temperature,
+        response_format_type=response_format_type,
     )
     return [{"role": CHATML_ROLE_ASSISTANT, "content": msg}]
 
@@ -1292,18 +1368,22 @@ def _call_gemini_api(
     contents: list[dict],
     max_output_tokens: int,
     temperature: float,
-    stop: list[str] = None,
+    stop: list[str] | None = None,
+    response_format_type: ResponseFormatType | None = None,
 ) -> str:
     session, project = get_google_auth_session()
+    generation_config = {
+        "temperature": temperature,
+        "maxOutputTokens": max_output_tokens,
+        "stopSequences": stop or [],
+    }
+    if response_format_type == "json_object":
+        generation_config["response_mime_type"] = "application/json"
     r = session.post(
         f"https://{settings.GCP_REGION}-aiplatform.googleapis.com/v1/projects/{project}/locations/{settings.GCP_REGION}/publishers/google/models/{model_id}:generateContent",
         json={
             "contents": contents,
-            "generation_config": {
-                "temperature": temperature,
-                "maxOutputTokens": max_output_tokens,
-                "stopSequences": stop or [],
-            },
+            "generation_config": generation_config,
         },
     )
     raise_for_status(r)
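
For Gemini on Vertex AI, the equivalent switch is the generation config's response MIME type: application/json constrains decoding to valid JSON, with no tool-calling workaround needed. A minimal REST sketch against generateContent (project, region, and token are placeholders; the protobuf JSON layer accepts both response_mime_type and responseMimeType):

import requests

PROJECT = "my-gcp-project"  # placeholder
REGION = "us-central1"  # placeholder
TOKEN = "..."  # e.g. output of: gcloud auth print-access-token

r = requests.post(
    f"https://{REGION}-aiplatform.googleapis.com/v1/projects/{PROJECT}"
    f"/locations/{REGION}/publishers/google/models/gemini-1.5-flash:generateContent",
    headers={"Authorization": f"Bearer {TOKEN}"},
    json={
        "contents": [
            {"role": "user", "parts": [{"text": "List three primes as JSON."}]}
        ],
        "generation_config": {
            "temperature": 0.0,
            "response_mime_type": "application/json",
        },
    },
)
r.raise_for_status()
print(r.json()["candidates"][0]["content"]["parts"][0]["text"])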
44 changes: 44 additions & 0 deletions daras_ai_v2/scraping_proxy.py
@@ -0,0 +1,44 @@
+import random
+
+import requests
+from furl import furl
+
+from daras_ai_v2 import settings
+from daras_ai_v2.fake_user_agents import FAKE_USER_AGENTS
+
+if settings.SCRAPING_PROXY_HOST:
+    SCRAPING_PROXIES = {
+        scheme: str(
+            furl(
+                scheme="http",
+                origin=settings.SCRAPING_PROXY_HOST,
+                username=settings.SCRAPING_PROXY_USERNAME,
+                password=settings.SCRAPING_PROXY_PASSWORD,
+            ),
+        )
+        for scheme in ["http", "https"]
+    }
+else:
+    SCRAPING_PROXIES = {}
+
+
+def get_scraping_proxy_cert_path() -> str | None:
+    if not settings.SCRAPING_PROXY_CERT_URL:
+        return None
+
+    path = settings.BASE_DIR / "proxy_ca_crt.pem"
+    if not path.exists():
+        settings.logger.info(f"Downloading proxy cert to {path}")
+        path.write_bytes(requests.get(settings.SCRAPING_PROXY_CERT_URL).content)
+
+    return str(path)
+
+
+def requests_scraping_kwargs() -> dict:
+    """Return kwargs for requests library to use scraping proxy and fake user agent."""
+    return dict(
+        headers={"User-Agent": random.choice(FAKE_USER_AGENTS)},
+        proxies=SCRAPING_PROXIES,
+        verify=get_scraping_proxy_cert_path(),
+        # verify=False,
+    )
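
The new module keeps proxy routing, the proxy's CA bundle, and user-agent rotation behind one helper, so callers can splat the kwargs straight into requests. A usage sketch (the URL is illustrative):

import requests

from daras_ai_v2.scraping_proxy import requests_scraping_kwargs

# routes via SCRAPING_PROXIES, verifies against the downloaded proxy cert,
# and sends a randomly chosen fake User-Agent header
r = requests.get("https://example.com/page", **requests_scraping_kwargs())
r.raise_for_status()
print(r.text[:200])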
5 changes: 5 additions & 0 deletions daras_ai_v2/settings.py
@@ -403,3 +403,8 @@
 ORG_INVITATION_EMAIL_COOLDOWN_INTERVAL = config(
     "ORG_INVITATION_EMAIL_COOLDOWN_INTERVAL", 60 * 60 * 24, cast=int  # 24 hours
 )
+
+SCRAPING_PROXY_HOST = config("SCRAPING_PROXY_HOST", "")
+SCRAPING_PROXY_USERNAME = config("SCRAPING_PROXY_USERNAME", "")
+SCRAPING_PROXY_PASSWORD = config("SCRAPING_PROXY_PASSWORD", "")
+SCRAPING_PROXY_CERT_URL = config("SCRAPING_PROXY_CERT_URL", "")