From 8a8ddda86c0713837614f7a438171e3deaf1ec68 Mon Sep 17 00:00:00 2001 From: sakher Date: Sun, 15 Sep 2024 10:44:13 +0100 Subject: [PATCH] Add support for new models (4o and o1) --- .../elsuite/already_said_that/scripts/make_plots.py | 12 ++++++++++++ evals/elsuite/track_the_stat/scripts/make_plots.py | 12 ++++++++++++ evals/registry.py | 9 +++++++-- evals/registry_test.py | 6 ++++++ 4 files changed, 37 insertions(+), 2 deletions(-) diff --git a/evals/elsuite/already_said_that/scripts/make_plots.py b/evals/elsuite/already_said_that/scripts/make_plots.py index ede36291ec..6142c7061e 100644 --- a/evals/elsuite/already_said_that/scripts/make_plots.py +++ b/evals/elsuite/already_said_that/scripts/make_plots.py @@ -23,6 +23,9 @@ def zero_if_none(input_num): "cot/gpt-3.5-turbo", "gpt-3.5-turbo", "gpt-4-base", + "gpt-4o", + "o1-preview", + "o1-mini", "gemini-pro", "mixtral-8x7b-instruct", "llama-2-70b-chat", @@ -35,6 +38,9 @@ def zero_if_none(input_num): "cot/gpt-3.5-turbo", "gpt-3.5-turbo", "gpt-4-base", + "gpt-4o", + "o1-preview", + "o1-mini", ] @@ -154,6 +160,12 @@ def get_model(spec): return "gpt-3.5-turbo" elif "gpt-4-base" in spec["completion_fns"][0]: return "gpt-4-base" + elif "gpt-4o" in spec["completion_fns"][0]: + return "gpt-4o" + elif "o1-preview" in spec["completion_fns"][0]: + return "o1-preview" + elif "o1-mini" in spec["completion_fns"][0]: + return "o1-mini" elif "gemini-pro" in spec["completion_fns"][0]: return "gemini-pro" elif "mixtral-8x7b-instruct" in spec["completion_fns"][0]: diff --git a/evals/elsuite/track_the_stat/scripts/make_plots.py b/evals/elsuite/track_the_stat/scripts/make_plots.py index b40e4a3586..d34bad5f36 100644 --- a/evals/elsuite/track_the_stat/scripts/make_plots.py +++ b/evals/elsuite/track_the_stat/scripts/make_plots.py @@ -20,6 +20,9 @@ def zero_if_none(input_num): MODELS = [ "gpt-4-0125-preview", "gpt-4-base", + "gpt-4o", + "o1-preview", + "o1-mini", "gpt-3.5-turbo-0125", "gemini-pro-1.0", "mixtral-8x7b-instruct", @@ -32,6 
+35,9 @@ def zero_if_none(input_num): "gpt-4-0125-preview", "gpt-3.5-turbo-0125", "gpt-4-base", + "gpt-4o", + "o1-preview", + "o1-mini", ] STAT_TO_LABEL = { @@ -54,6 +60,12 @@ def get_model(spec): return "gpt-3.5-turbo-0125" elif "gpt-4-base" in spec["completion_fns"][0]: return "gpt-4-base" + elif "gpt-4o" in spec["completion_fns"][0]: + return "gpt-4o" + elif "o1-preview" in spec["completion_fns"][0]: + return "o1-preview" + elif "o1-mini" in spec["completion_fns"][0]: + return "o1-mini" elif "gemini-pro" in spec["completion_fns"][0]: return "gemini-pro-1.0" elif "mixtral-8x7b-instruct" in spec["completion_fns"][0]: diff --git a/evals/registry.py b/evals/registry.py index 2d1c0fee1d..5db892211d 100644 --- a/evals/registry.py +++ b/evals/registry.py @@ -42,6 +42,8 @@ def n_ctx_from_model_name(model_name: str) -> Optional[int]: ("gpt-3.5-turbo-", 4096), ("gpt-4-32k-", 32768), ("gpt-4-", 8192), + ("gpt-4o-", 128_000), + ("o1-", 128_000), ] MODEL_NAME_TO_N_CTX: dict[str, int] = { "ada": 2048, @@ -65,6 +67,9 @@ def n_ctx_from_model_name(model_name: str) -> Optional[int]: "gpt-4-1106-preview": 128_000, "gpt-4-turbo-preview": 128_000, "gpt-4-0125-preview": 128_000, + "gpt-4o": 128_000, + "o1-preview": 128_000, + "o1-mini": 128_000, } # first, look for an exact match @@ -84,12 +89,12 @@ def is_chat_model(model_name: str) -> bool: if model_name in {"gpt-4-base"} or model_name.startswith("gpt-3.5-turbo-instruct"): return False - CHAT_MODEL_NAMES = {"gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-4-32k"} + CHAT_MODEL_NAMES = {"gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-4-32k", "gpt-4o", "o1-preview", "o1-mini"} if model_name in CHAT_MODEL_NAMES: return True - for model_prefix in {"gpt-3.5-turbo-", "gpt-4-"}: + for model_prefix in {"gpt-3.5-turbo-", "gpt-4-", "gpt-4o-", "o1-"}: if model_name.startswith(model_prefix): return True diff --git a/evals/registry_test.py b/evals/registry_test.py index ef05316220..6736cbbddb 100644 --- a/evals/registry_test.py +++ 
b/evals/registry_test.py @@ -6,6 +6,9 @@ def test_n_ctx_from_model_name(): assert n_ctx_from_model_name("gpt-3.5-turbo-0613") == 4096 assert n_ctx_from_model_name("gpt-3.5-turbo-16k") == 16384 assert n_ctx_from_model_name("gpt-3.5-turbo-16k-0613") == 16384 + assert n_ctx_from_model_name("gpt-4o") == 128_000 + assert n_ctx_from_model_name("o1-preview") == 128_000 + assert n_ctx_from_model_name("o1-mini") == 128_000 assert n_ctx_from_model_name("gpt-4") == 8192 assert n_ctx_from_model_name("gpt-4-0613") == 8192 assert n_ctx_from_model_name("gpt-4-32k") == 32768 @@ -27,6 +30,9 @@ def test_is_chat_model(): assert is_chat_model("gpt-4-0613") assert is_chat_model("gpt-4-32k") assert is_chat_model("gpt-4-32k-0613") + assert is_chat_model("gpt-4o") + assert is_chat_model("o1-preview") + assert is_chat_model("o1-mini") assert not is_chat_model("text-davinci-003") assert not is_chat_model("gpt4-base") assert not is_chat_model("code-davinci-002")