Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

openai[patch]: remove optional defaults #29097

Merged
merged 4 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions libs/partners/openai/langchain_openai/chat_models/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning
timeout: Union[float, Tuple[float, float], Any, None]
Timeout for requests.
max_retries: int
max_retries: Optional[int]
Max number of retries.
organization: Optional[str]
OpenAI organization ID. If not passed in will be read from env
Expand Down Expand Up @@ -586,9 +586,9 @@ def is_lc_serializable(cls) -> bool:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
if self.n < 1:
if self.n is not None and self.n < 1:
raise ValueError("n must be at least 1.")
if self.n > 1 and self.streaming:
elif self.n is not None and self.n > 1 and self.streaming:
raise ValueError("n must be 1 when streaming.")

if self.disabled_params is None:
Expand Down Expand Up @@ -641,10 +641,12 @@ def validate_environment(self) -> Self:
"organization": self.openai_organization,
"base_url": self.openai_api_base,
"timeout": self.request_timeout,
"max_retries": self.max_retries,
"default_headers": self.default_headers,
"default_query": self.default_query,
}
if self.max_retries is not None:
client_params["max_retries"] = self.max_retries

if not self.client:
sync_specific = {"http_client": self.http_client}
self.root_client = openai.AzureOpenAI(**client_params, **sync_specific) # type: ignore[arg-type]
Expand Down
20 changes: 11 additions & 9 deletions libs/partners/openai/langchain_openai/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ class BaseChatOpenAI(BaseChatModel):
root_async_client: Any = Field(default=None, exclude=True) #: :meta private:
model_name: str = Field(default="gpt-3.5-turbo", alias="model")
"""Model name to use."""
temperature: float = 0.7
temperature: Optional[float] = None
"""What sampling temperature to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
Expand All @@ -430,7 +430,7 @@ class BaseChatOpenAI(BaseChatModel):
)
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
None."""
max_retries: int = 2
max_retries: Optional[int] = None
"""Maximum number of retries to make when generating."""
presence_penalty: Optional[float] = None
"""Penalizes repeated tokens."""
Expand All @@ -448,7 +448,7 @@ class BaseChatOpenAI(BaseChatModel):
"""Modify the likelihood of specified tokens appearing in the completion."""
streaming: bool = False
"""Whether to stream the results or not."""
n: int = 1
n: Optional[int] = None
"""Number of chat completions to generate for each prompt."""
top_p: Optional[float] = None
"""Total probability mass of tokens to consider at each step."""
Expand Down Expand Up @@ -532,9 +532,9 @@ def validate_temperature(cls, values: Dict[str, Any]) -> Any:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
if self.n < 1:
if self.n is not None and self.n < 1:
raise ValueError("n must be at least 1.")
if self.n > 1 and self.streaming:
elif self.n is not None and self.n > 1 and self.streaming:
raise ValueError("n must be 1 when streaming.")

# Check OPENAI_ORGANIZATION for backwards compatibility.
Expand All @@ -551,10 +551,12 @@ def validate_environment(self) -> Self:
"organization": self.openai_organization,
"base_url": self.openai_api_base,
"timeout": self.request_timeout,
"max_retries": self.max_retries,
"default_headers": self.default_headers,
"default_query": self.default_query,
}
if self.max_retries is not None:
client_params["max_retries"] = self.max_retries

if self.openai_proxy and (self.http_client or self.http_async_client):
openai_proxy = self.openai_proxy
http_client = self.http_client
Expand Down Expand Up @@ -609,14 +611,14 @@ def _default_params(self) -> Dict[str, Any]:
"stop": self.stop or None, # also exclude empty list for this
"max_tokens": self.max_tokens,
"extra_body": self.extra_body,
"n": self.n,
"temperature": self.temperature,
"reasoning_effort": self.reasoning_effort,
}

params = {
"model": self.model_name,
"stream": self.streaming,
"n": self.n,
"temperature": self.temperature,
**{k: v for k, v in exclude_if_none.items() if v is not None},
**self.model_kwargs,
}
Expand Down Expand Up @@ -1565,7 +1567,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]

timeout: Union[float, Tuple[float, float], Any, None]
Timeout for requests.
max_retries: int
max_retries: Optional[int]
Max number of retries.
api_key: Optional[str]
OpenAI API key. If not passed in will be read from env var OPENAI_API_KEY.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
}),
'max_retries': 2,
'max_tokens': 100,
'n': 1,
'openai_api_key': dict({
'id': list([
'AZURE_OPENAI_API_KEY',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
'max_retries': 2,
'max_tokens': 100,
'model_name': 'gpt-3.5-turbo',
'n': 1,
'openai_api_key': dict({
'id': list([
'OPENAI_API_KEY',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -877,8 +877,6 @@ def test__get_request_payload() -> None:
],
"model": "gpt-4o-2024-08-06",
"stream": False,
"n": 1,
"temperature": 0.7,
}
payload = llm._get_request_payload(messages)
assert payload == expected
Expand Down
3 changes: 3 additions & 0 deletions libs/partners/xai/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ integration_test integration_tests: TEST_FILE=tests/integration_tests/
test tests:
poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)

test_watch:
poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)

integration_test integration_tests:
poetry run pytest $(TEST_FILE)

Expand Down
7 changes: 4 additions & 3 deletions libs/partners/xai/langchain_xai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,9 @@ def _get_ls_params(
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
if self.n < 1:
if self.n is not None and self.n < 1:
raise ValueError("n must be at least 1.")
if self.n > 1 and self.streaming:
if self.n is not None and self.n > 1 and self.streaming:
raise ValueError("n must be 1 when streaming.")

client_params: dict = {
Expand All @@ -331,10 +331,11 @@ def validate_environment(self) -> Self:
),
"base_url": self.xai_api_base,
"timeout": self.request_timeout,
"max_retries": self.max_retries,
"default_headers": self.default_headers,
"default_query": self.default_query,
}
if self.max_retries is not None:
client_params["max_retries"] = self.max_retries

if client_params["api_key"] is None:
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
'max_retries': 2,
'max_tokens': 100,
'model_name': 'grok-beta',
'n': 1,
'request_timeout': 60.0,
'stop': list([
]),
Expand Down
Loading