diff --git a/l2m2/client/base_llm_client.py b/l2m2/client/base_llm_client.py
index a471afc..ef85993 100644
--- a/l2m2/client/base_llm_client.py
+++ b/l2m2/client/base_llm_client.py
@@ -521,11 +521,7 @@ async def _call_google(
         data: Dict[str, Any] = {}
 
         if system_prompt is not None:
-            # Earlier models don't support system prompts, so prepend it to the prompt
-            if model_id not in ["gemini-1.5-pro"]:
-                prompt = f"{system_prompt}\n{prompt}"
-            else:
-                data["system_instruction"] = {"parts": {"text": system_prompt}}
+            data["system_instruction"] = {"parts": {"text": system_prompt}}
 
         messages: List[Dict[str, Any]] = []
         if isinstance(memory, ChatMemory):
diff --git a/l2m2/model_info.py b/l2m2/model_info.py
index ec9ae65..12fe10e 100644
--- a/l2m2/model_info.py
+++ b/l2m2/model_info.py
@@ -187,9 +187,9 @@ class ModelEntry(TypedDict):
             "extras": {"json_mode_arg": {"response_format": {"type": "json_object"}}},
        },
    },
-    "gemini-1.5-pro": {
+    "gemini-2.0-flash": {
         "google": {
-            "model_id": "gemini-1.5-pro",
+            "model_id": "gemini-2.0-flash-exp",
             "params": {
                 "temperature": {
                     "default": PROVIDER_DEFAULT,
@@ -205,9 +205,9 @@ class ModelEntry(TypedDict):
             "extras": {"json_mode_arg": {"response_mime_type": "application/json"}},
        },
    },
-    "gemini-1.0-pro": {
+    "gemini-1.5-flash": {
         "google": {
-            "model_id": "gemini-1.0-pro",
+            "model_id": "gemini-1.5-flash",
             "params": {
                 "temperature": {
                     "default": PROVIDER_DEFAULT,
@@ -220,7 +220,43 @@ class ModelEntry(TypedDict):
                     "max": 8192,
                },
            },
-            "extras": {},
+            "extras": {"json_mode_arg": {"response_mime_type": "application/json"}},
+        },
+    },
+    "gemini-1.5-flash-8b": {
+        "google": {
+            "model_id": "gemini-1.5-flash-8b",
+            "params": {
+                "temperature": {
+                    "default": PROVIDER_DEFAULT,
+                    "max": 2.0,
+                },
+                "max_tokens": {
+                    "custom_key": "max_output_tokens",
+                    "default": PROVIDER_DEFAULT,
+                    # https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models
+                    "max": 8192,
+                },
+            },
+            "extras": {"json_mode_arg": {"response_mime_type": "application/json"}},
+        },
+    },
+    "gemini-1.5-pro": {
+        "google": {
+            "model_id": "gemini-1.5-pro",
+            "params": {
+                "temperature": {
+                    "default": PROVIDER_DEFAULT,
+                    "max": 2.0,
+                },
+                "max_tokens": {
+                    "custom_key": "max_output_tokens",
+                    "default": PROVIDER_DEFAULT,
+                    # https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models
+                    "max": 8192,
+                },
+            },
+            "extras": {"json_mode_arg": {"response_mime_type": "application/json"}},
        },
    },
     "claude-3.5-sonnet": {
diff --git a/tests/l2m2/client/test_base_llm_client.py b/tests/l2m2/client/test_base_llm_client.py
index 1a6363f..92b58e5 100644
--- a/tests/l2m2/client/test_base_llm_client.py
+++ b/tests/l2m2/client/test_base_llm_client.py
@@ -267,16 +267,6 @@ async def test_call_google_1_5(mock_get_extra_message, mock_llm_post, llm_client):
     await _generic_test_call(llm_client, "google", "gemini-1.5-pro")
 
 
-@pytest.mark.asyncio
-@patch(LLM_POST_PATH)
-@patch(GET_EXTRA_MESSAGE_PATH)
-async def test_call_google_1_0(mock_get_extra_message, mock_llm_post, llm_client):
-    mock_get_extra_message.return_value = "extra message"
-    mock_return_value = {"candidates": [{"content": {"parts": [{"text": "response"}]}}]}
-    mock_llm_post.return_value = mock_return_value
-    await _generic_test_call(llm_client, "google", "gemini-1.0-pro")
-
-
 @pytest.mark.asyncio
 @patch(LLM_POST_PATH)
 @patch(GET_EXTRA_MESSAGE_PATH)
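
The behavioral change in _call_google: every Gemini model now receives the system prompt as a native system_instruction instead of having it prepended to the user prompt. Dropping the branch is safe because gemini-1.0-pro, the only listed model that relied on prompt prepending, is removed from model_info.py in the same change, and all remaining Gemini models support system instructions. Below is a minimal sketch of the request body this produces; only the system_instruction line is taken verbatim from the diff, while build_google_payload is a hypothetical helper and the "contents" shape is an assumption based on the public Gemini generateContent REST API, not something this diff shows.

    from typing import Any, Dict, List, Optional

    def build_google_payload(prompt: str, system_prompt: Optional[str]) -> Dict[str, Any]:
        """Hypothetical helper mirroring the simplified logic in _call_google."""
        data: Dict[str, Any] = {}
        if system_prompt is not None:
            # Shape copied from the diff: no per-model branching remains, so
            # gemini-1.5-flash, gemini-1.5-flash-8b, gemini-1.5-pro, and
            # gemini-2.0-flash all receive a native system instruction.
            data["system_instruction"] = {"parts": {"text": system_prompt}}
        # Assumed "contents" structure per the Gemini generateContent API.
        messages: List[Dict[str, Any]] = [{"role": "user", "parts": [{"text": prompt}]}]
        data["contents"] = messages
        return data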