Skip to content

Commit

Permalink
Add/deprecate various Gemini models
Browse the repository at this point in the history
  • Loading branch information
pkelaita committed Dec 13, 2024
1 parent 475a2f4 commit f643fe6
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 20 deletions.
6 changes: 1 addition & 5 deletions l2m2/client/base_llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,11 +521,7 @@ async def _call_google(
data: Dict[str, Any] = {}

if system_prompt is not None:
# Earlier models don't support system prompts, so prepend it to the prompt
if model_id not in ["gemini-1.5-pro"]:
prompt = f"{system_prompt}\n{prompt}"
else:
data["system_instruction"] = {"parts": {"text": system_prompt}}
data["system_instruction"] = {"parts": {"text": system_prompt}}

messages: List[Dict[str, Any]] = []
if isinstance(memory, ChatMemory):
Expand Down
46 changes: 41 additions & 5 deletions l2m2/model_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,9 @@ class ModelEntry(TypedDict):
"extras": {"json_mode_arg": {"response_format": {"type": "json_object"}}},
},
},
"gemini-1.5-pro": {
"gemini-2.0-flash": {
"google": {
"model_id": "gemini-1.5-pro",
"model_id": "gemini-2.0-flash-exp",
"params": {
"temperature": {
"default": PROVIDER_DEFAULT,
Expand All @@ -205,9 +205,9 @@ class ModelEntry(TypedDict):
"extras": {"json_mode_arg": {"response_mime_type": "application/json"}},
},
},
"gemini-1.0-pro": {
"gemini-1.5-flash": {
"google": {
"model_id": "gemini-1.0-pro",
"model_id": "gemini-1.5-flash",
"params": {
"temperature": {
"default": PROVIDER_DEFAULT,
Expand All @@ -220,7 +220,43 @@ class ModelEntry(TypedDict):
"max": 8192,
},
},
"extras": {},
"extras": {"json_mode_arg": {"response_mime_type": "application/json"}},
},
},
"gemini-1.5-flash-8b": {
"google": {
"model_id": "gemini-1.5-flash-8b",
"params": {
"temperature": {
"default": PROVIDER_DEFAULT,
"max": 2.0,
},
"max_tokens": {
"custom_key": "max_output_tokens",
"default": PROVIDER_DEFAULT,
# https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models
"max": 8192,
},
},
"extras": {"json_mode_arg": {"response_mime_type": "application/json"}},
},
},
"gemini-1.5-pro": {
"google": {
"model_id": "gemini-1.5-pro",
"params": {
"temperature": {
"default": PROVIDER_DEFAULT,
"max": 2.0,
},
"max_tokens": {
"custom_key": "max_output_tokens",
"default": PROVIDER_DEFAULT,
# https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models
"max": 8192,
},
},
"extras": {"json_mode_arg": {"response_mime_type": "application/json"}},
},
},
"claude-3.5-sonnet": {
Expand Down
10 changes: 0 additions & 10 deletions tests/l2m2/client/test_base_llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,16 +267,6 @@ async def test_call_google_1_5(mock_get_extra_message, mock_llm_post, llm_client
await _generic_test_call(llm_client, "google", "gemini-1.5-pro")


@pytest.mark.asyncio
@patch(LLM_POST_PATH)
@patch(GET_EXTRA_MESSAGE_PATH)
async def test_call_google_1_0(mock_get_extra_message, mock_llm_post, llm_client):
    """Exercise a call to Google's ``gemini-1.0-pro`` via the shared test harness.

    The extra-message helper and the HTTP POST helper are both patched, so no
    network traffic occurs; the canned payload mirrors the Gemini
    ``generateContent`` response shape (candidates -> content -> parts -> text).
    """
    mock_get_extra_message.return_value = "extra message"
    # Minimal well-formed Gemini response with a single candidate.
    mock_llm_post.return_value = {
        "candidates": [{"content": {"parts": [{"text": "response"}]}}]
    }
    await _generic_test_call(llm_client, "google", "gemini-1.0-pro")


@pytest.mark.asyncio
@patch(LLM_POST_PATH)
@patch(GET_EXTRA_MESSAGE_PATH)
Expand Down

0 comments on commit f643fe6

Please sign in to comment.