Add o1 models
pkelaita committed Dec 17, 2024
1 parent ced52cc commit 94b3d6e
Showing 3 changed files with 93 additions and 27 deletions.
7 changes: 6 additions & 1 deletion l2m2/client/base_llm_client.py
@@ -687,9 +687,14 @@ async def _generic_openai_spec_call(
         """Generic call method for providers who follow the OpenAI API spec."""
         supports_native_json_mode = "json_mode_arg" in extras

+        # For o1 and newer, use "developer" messages instead of "system"
+        system_key = "system"
+        if provider == "openai" and model_id in ["o1", "o1-preview", "o1-mini"]:
+            system_key = "developer"
+
         messages = []
         if system_prompt is not None:
-            messages.append({"role": "system", "content": system_prompt})
+            messages.append({"role": system_key, "content": system_prompt})
         if isinstance(memory, ChatMemory):
             messages.extend(memory.unpack("role", "content", "user", "assistant"))
         messages.append({"role": "user", "content": prompt})
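
In plain terms, the hunk above only changes which role carries the system prompt. A minimal, self-contained sketch of the resulting behavior (the helper name build_messages is illustrative and not part of l2m2):

# Illustrative sketch only -- mirrors the hunk above, not l2m2's internal API.
O1_FAMILY = ["o1", "o1-preview", "o1-mini"]


def build_messages(provider, model_id, system_prompt, prompt):
    # o1-family models take a "developer" message where older models take "system".
    system_key = "developer" if provider == "openai" and model_id in O1_FAMILY else "system"
    messages = []
    if system_prompt is not None:
        messages.append({"role": system_key, "content": system_prompt})
    messages.append({"role": "user", "content": prompt})
    return messages


print(build_messages("openai", "o1", "Be terse.", "Hi"))
# [{'role': 'developer', 'content': 'Be terse.'}, {'role': 'user', 'content': 'Hi'}]
print(build_messages("openai", "gpt-4o", "Be terse.", "Hi"))
# [{'role': 'system', 'content': 'Be terse.'}, {'role': 'user', 'content': 'Hi'}]
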
51 changes: 51 additions & 0 deletions l2m2/model_info.py
@@ -157,6 +157,57 @@ class ModelEntry(TypedDict):
             "extras": {"json_mode_arg": {"response_format": {"type": "json_object"}}},
         },
     },
+    "o1": {
+        "openai": {
+            "model_id": "o1",
+            "params": {
+                "temperature": {
+                    "default": PROVIDER_DEFAULT,
+                    "max": 1.0,
+                },
+                "max_tokens": {
+                    "custom_key": "max_completion_tokens",
+                    "default": PROVIDER_DEFAULT,
+                    "max": 4096,
+                },
+            },
+            "extras": {"json_mode_arg": {"response_format": {"type": "json_object"}}},
+        },
+    },
+    "o1-preview": {
+        "openai": {
+            "model_id": "o1-preview",
+            "params": {
+                "temperature": {
+                    "default": PROVIDER_DEFAULT,
+                    "max": 1.0,
+                },
+                "max_tokens": {
+                    "custom_key": "max_completion_tokens",
+                    "default": PROVIDER_DEFAULT,
+                    "max": 4096,
+                },
+            },
+            "extras": {"json_mode_arg": {"response_format": {"type": "json_object"}}},
+        },
+    },
+    "o1-mini": {
+        "openai": {
+            "model_id": "o1-mini",
+            "params": {
+                "temperature": {
+                    "default": PROVIDER_DEFAULT,
+                    "max": 1.0,
+                },
+                "max_tokens": {
+                    "custom_key": "max_completion_tokens",
+                    "default": PROVIDER_DEFAULT,
+                    "max": 4096,
+                },
+            },
+            "extras": {"json_mode_arg": {"response_format": {"type": "json_object"}}},
+        },
+    },
     "gpt-4-turbo": {
         "openai": {
             "model_id": "gpt-4-turbo-2024-04-09",
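
The custom_key entries are the notable part: OpenAI's o1-family endpoints take max_completion_tokens rather than max_tokens, so the generic max_tokens parameter has to be renamed when the request payload is built. A rough sketch of how such a remap could be applied (apply_params is illustrative, not l2m2's actual helper):

# Illustrative only: applying a "custom_key" remap like the entries above.
from typing import Any, Dict


def apply_params(param_spec: Dict[str, Dict[str, Any]], user_params: Dict[str, Any]) -> Dict[str, Any]:
    payload: Dict[str, Any] = {}
    for name, value in user_params.items():
        spec = param_spec[name]
        # Rename the user-facing param to the provider's key when one is declared,
        # e.g. max_tokens -> max_completion_tokens for the o1 models.
        payload[spec.get("custom_key", name)] = value
    return payload


o1_spec = {"max_tokens": {"custom_key": "max_completion_tokens", "max": 4096}}
print(apply_params(o1_spec, {"max_tokens": 1024}))
# {'max_completion_tokens': 1024}
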
62 changes: 36 additions & 26 deletions tests/l2m2/client/test_base_llm_client.py
@@ -73,7 +73,7 @@ async def test_init_with_providers():
         "cohere": "test-key-cohere",
     }
     assert llm_client.active_providers == {"openai", "cohere"}
-    assert "gpt-4-turbo" in llm_client.active_models
+    assert "gpt-4o" in llm_client.active_models
     assert "command-r" in llm_client.active_models
     assert "claude-3-opus" not in llm_client.active_models

@@ -89,7 +89,7 @@ async def test_init_with_env_providers():
         "cohere": "test-key-cohere",
     }
     assert llm_client.active_providers == {"openai", "cohere"}
-    assert "gpt-4-turbo" in llm_client.active_models
+    assert "gpt-4o" in llm_client.active_models
     assert "command-r" in llm_client.active_models
     assert "claude-3-opus" not in llm_client.active_models

@@ -111,7 +111,7 @@ async def test_init_with_env_providers_override():
         "anthropic": "new-key-anthropic",
     }
     assert llm_client.active_providers == {"openai", "cohere", "anthropic"}
-    assert "gpt-4-turbo" in llm_client.active_models
+    assert "gpt-4o" in llm_client.active_models
     assert "command-r" in llm_client.active_models
     assert "claude-3-opus" in llm_client.active_models

@@ -127,7 +127,7 @@ def test_getters(llm_client):
     assert llm_client.get_active_providers() == {"openai", "cohere"}

     active_models = llm_client.get_active_models()
-    assert "gpt-4-turbo" in active_models
+    assert "gpt-4o" in active_models
     assert "command-r" in active_models
     assert "claude-3-opus" not in active_models

@@ -143,7 +143,7 @@ def test_add_provider(llm_client):
 def test_add_provider(llm_client):
     llm_client.add_provider("openai", "test-key-openai")
     assert "openai" in llm_client.active_providers
-    assert "gpt-4-turbo" in llm_client.active_models
+    assert "gpt-4o" in llm_client.active_models


 def test_add_provider_invalid(llm_client):
@@ -165,7 +165,7 @@ def test_remove_provider(llm_client):

     assert "openai" not in llm_client.active_providers
     assert "anthropic" in llm_client.active_providers
-    assert "gpt-4-turbo" not in llm_client.active_models
+    assert "gpt-4o" not in llm_client.active_models
     assert "claude-3-opus" in llm_client.active_models


@@ -253,18 +253,30 @@ async def test_call_openai(mock_get_extra_message, mock_llm_post, llm_client):
     mock_get_extra_message.return_value = "extra message"
     mock_return_value = {"choices": [{"message": {"content": "response"}}]}
     mock_llm_post.return_value = mock_return_value
-    await _generic_test_call(llm_client, "openai", "gpt-4-turbo")
+    await _generic_test_call(llm_client, "openai", "gpt-4o")


-# Need to test gemini 1.0 and 1.5 separately because of different system prompt handling
+# Need to test this separately because of the different system prompt handling
 @pytest.mark.asyncio
 @patch(LLM_POST_PATH)
 @patch(GET_EXTRA_MESSAGE_PATH)
-async def test_call_google_1_5(mock_get_extra_message, mock_llm_post, llm_client):
+async def test_call_openai_o1_or_newer(
+    mock_get_extra_message, mock_llm_post, llm_client
+):
+    mock_get_extra_message.return_value = "extra message"
+    mock_return_value = {"choices": [{"message": {"content": "response"}}]}
+    mock_llm_post.return_value = mock_return_value
+    await _generic_test_call(llm_client, "openai", "o1")
+
+
+@pytest.mark.asyncio
+@patch(LLM_POST_PATH)
+@patch(GET_EXTRA_MESSAGE_PATH)
+async def test_call_google(mock_get_extra_message, mock_llm_post, llm_client):
     mock_get_extra_message.return_value = "extra message"
     mock_return_value = {"candidates": [{"content": {"parts": [{"text": "response"}]}}]}
     mock_llm_post.return_value = mock_return_value
-    await _generic_test_call(llm_client, "google", "gemini-1.5-pro")
+    await _generic_test_call(llm_client, "google", "gemini-2.0-flash")


 @pytest.mark.asyncio
@@ -340,7 +352,7 @@ async def test_call_google_gemini_fails(mock_llm_post, llm_client):
 @pytest.mark.asyncio
 async def test_call_valid_model_not_active(llm_client):
     with pytest.raises(ValueError):
-        await llm_client.call(prompt="Hello", model="gpt-4-turbo")
+        await llm_client.call(prompt="Hello", model="gpt-4o")


 @pytest.mark.asyncio
@@ -353,16 +365,14 @@ async def test_call_invalid_model(llm_client):
 async def test_call_tokens_too_large(llm_client):
     llm_client.add_provider("openai", "fake-api-key")
     with pytest.raises(ValueError):
-        await llm_client.call(
-            prompt="Hello", model="gpt-4-turbo", max_tokens=float("inf")
-        )
+        await llm_client.call(prompt="Hello", model="gpt-4o", max_tokens=float("inf"))


 @pytest.mark.asyncio
 async def test_call_temperature_too_high(llm_client):
     llm_client.add_provider("openai", "fake-api-key")
     with pytest.raises(ValueError):
-        await llm_client.call(prompt="Hello", model="gpt-4-turbo", temperature=3.0)
+        await llm_client.call(prompt="Hello", model="gpt-4o", temperature=3.0)


 # -- Tests for call_custom -- #
@@ -499,7 +509,7 @@ async def test_chat_memory(mock_call_openai, llm_client_mem_chat):
     memory.add_user_message("A")
     memory.add_agent_message("B")

-    response = await llm_client_mem_chat.call(prompt="C", model="gpt-4-turbo")
+    response = await llm_client_mem_chat.call(prompt="C", model="gpt-4o")
     assert response == "response"
     assert memory.unpack("role", "content", "user", "assistant") == [
         {"role": "user", "content": "A"},
@@ -542,14 +552,14 @@ async def test_external_memory_system_prompt(mock_call_openai, llm_client_mem_ex

     memory.set_contents("stuff")

-    await llm_client_mem_ext_sys.call(prompt="Hello", model="gpt-4-turbo")
+    await llm_client_mem_ext_sys.call(prompt="Hello", model="gpt-4o")
     assert mock_call_openai.call_args.kwargs["data"]["messages"] == [
         {"role": "system", "content": "stuff"},
         {"role": "user", "content": "Hello"},
     ]

     await llm_client_mem_ext_sys.call(
-        system_prompt="system-123", prompt="Hello", model="gpt-4-turbo"
+        system_prompt="system-123", prompt="Hello", model="gpt-4o"
     )
     assert mock_call_openai.call_args.kwargs["data"]["messages"] == [
         {"role": "system", "content": "system-123\nstuff"},
@@ -568,13 +578,13 @@ async def test_external_memory_user_prompt(mock_call_openai, llm_client_mem_ext_

     memory.set_contents("stuff")

-    await llm_client_mem_ext_usr.call(prompt="Hello", model="gpt-4-turbo")
+    await llm_client_mem_ext_usr.call(prompt="Hello", model="gpt-4o")
     assert mock_call_openai.call_args.kwargs["data"]["messages"] == [
         {"role": "user", "content": "Hello\nstuff"},
     ]

     await llm_client_mem_ext_usr.call(
-        system_prompt="system-123", prompt="Hello", model="gpt-4-turbo"
+        system_prompt="system-123", prompt="Hello", model="gpt-4o"
     )
     assert mock_call_openai.call_args.kwargs["data"]["messages"] == [
         {"role": "system", "content": "system-123"},
@@ -627,12 +637,12 @@ async def test_alt_memory(mock_call_openai, llm_client):
     m2 = ChatMemory()
     llm_client.load_memory(ChatMemory())

-    await llm_client.call(prompt="A", model="gpt-4-turbo", alt_memory=m1)
-    await llm_client.call(prompt="X", model="gpt-4-turbo", alt_memory=m2)
-    await llm_client.call(prompt="B", model="gpt-4-turbo", alt_memory=m1)
-    await llm_client.call(prompt="Y", model="gpt-4-turbo", alt_memory=m2)
-    await llm_client.call(prompt="C", model="gpt-4-turbo", alt_memory=m1)
-    await llm_client.call(prompt="Z", model="gpt-4-turbo", alt_memory=m2)
+    await llm_client.call(prompt="A", model="gpt-4o", alt_memory=m1)
+    await llm_client.call(prompt="X", model="gpt-4o", alt_memory=m2)
+    await llm_client.call(prompt="B", model="gpt-4o", alt_memory=m1)
+    await llm_client.call(prompt="Y", model="gpt-4o", alt_memory=m2)
+    await llm_client.call(prompt="C", model="gpt-4o", alt_memory=m1)
+    await llm_client.call(prompt="Z", model="gpt-4o", alt_memory=m2)

     assert m1.unpack("role", "content", "user", "assistant") == [
         {"role": "user", "content": "A"},
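
With the OpenAI provider active, the new models are called like any other. A hedged usage sketch based on the client API exercised in the tests above (add_provider, awaited call); the LLMClient import path and constructor are assumptions not confirmed by this commit:

import asyncio

from l2m2.client import LLMClient  # assumed import path


async def main():
    client = LLMClient()  # assumed constructor
    client.add_provider("openai", "sk-...")  # placeholder API key

    # For "o1", the system prompt is sent with the "developer" role and
    # max_tokens is forwarded as max_completion_tokens, per the diff above.
    response = await client.call(
        prompt="Summarize this commit in one sentence.",
        system_prompt="Be concise.",
        model="o1",
        max_tokens=256,
    )
    print(response)


asyncio.run(main())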
