diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index baaa74f637b31..233fa8700a436 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -480,6 +480,15 @@ def build_extra(cls, values: Dict[str, Any]) -> Any: values = _build_model_kwargs(values, all_required_field_names) return values + @model_validator(mode="before") + @classmethod + def validate_temperature(cls, values: Dict[str, Any]) -> Any: + """Currently o1 models only allow temperature=1.""" + model = values.get("model_name") or values.get("model") or "" + if model.startswith("o1") and "temperature" not in values: + values["temperature"] = 1 + return values + @model_validator(mode="after") def validate_environment(self) -> Self: """Validate that api key and python package exists in environment.""" diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index cc03698e5ef93..c6065a137e5c7 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -35,6 +35,13 @@ def test_openai_model_param() -> None: assert llm.model_name == "foo" +def test_openai_o1_temperature() -> None: + llm = ChatOpenAI(model="o1-preview") + assert llm.temperature == 1 + llm = ChatOpenAI(model_name="o1-mini") # type: ignore[call-arg] + assert llm.temperature == 1 + + def test_function_message_dict_to_function_message() -> None: content = json.dumps({"result": "Example #1"}) name = "test_function"