Skip to content

Commit

Permalink
Allowing additional params for OpenAIEmbeddings. (langchain-ai#7752)
Browse files Browse the repository at this point in the history
(langchain-ai#7654)

---------

Co-authored-by: Bagatur <[email protected]>
  • Loading branch information
hanit-com and baskaryan authored Jul 18, 2023
1 parent 8622681 commit 0d23c0c
Show file tree
Hide file tree
Showing 11 changed files with 110 additions and 18 deletions.
4 changes: 2 additions & 2 deletions langchain/chat_models/jinachat.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
HumanMessage,
SystemMessage,
)
from langchain.utils import get_from_dict_or_env
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -155,7 +155,7 @@ class Config:
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = cls._all_required_field_names()
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
Expand Down
4 changes: 2 additions & 2 deletions langchain/chat_models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
HumanMessage,
SystemMessage,
)
from langchain.utils import get_from_dict_or_env
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names

if TYPE_CHECKING:
import tiktoken
Expand Down Expand Up @@ -205,7 +205,7 @@ class Config:
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = cls._all_required_field_names()
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
Expand Down
34 changes: 32 additions & 2 deletions langchain/embeddings/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import logging
import warnings
from typing import (
Any,
Callable,
Expand All @@ -16,7 +17,7 @@
)

import numpy as np
from pydantic import BaseModel, Extra, root_validator
from pydantic import BaseModel, Extra, Field, root_validator
from tenacity import (
AsyncRetrying,
before_sleep_log,
Expand All @@ -27,7 +28,7 @@
)

from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -193,12 +194,40 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
when tiktoken is called, you can specify a model name to use here."""
show_progress_bar: bool = False
"""Whether to show a progress bar when embedding."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""

class Config:
    """Configuration for this pydantic object."""

    # NOTE(review): unknown constructor kwargs appear to be moved into
    # `model_kwargs` by the pre root validator `build_extra`; presumably
    # that runs before this setting is enforced — confirm.
    extra = Extra.forbid

@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """Build extra kwargs from additional params that were passed in.

    Every key in ``values`` that is not a declared field name (or field
    alias) of the class is moved into ``model_kwargs`` and a warning is
    emitted for it.

    Args:
        values: Incoming values handed to this pre root validator.

    Returns:
        The same ``values`` mapping, with unknown keys removed and
        collected under ``values["model_kwargs"]``.

    Raises:
        ValueError: If a key is supplied both directly and inside
            ``model_kwargs``, or if ``model_kwargs`` contains the name of
            a declared field.
    """
    all_required_field_names = get_pydantic_field_names(cls)
    # Copy so that folding unknown kwargs in below never mutates the dict
    # object the caller passed as ``model_kwargs``.
    extra = dict(values.get("model_kwargs", {}))
    # Iterate over a snapshot of the keys: ``values`` is popped from inside
    # the loop.
    for field_name in list(values):
        if field_name in extra:
            raise ValueError(f"Found {field_name} supplied twice.")
        if field_name not in all_required_field_names:
            warnings.warn(
                f"WARNING! {field_name} is not default parameter.\n"
                f"{field_name} was transferred to model_kwargs.\n"
                f"Please confirm that {field_name} is what you intended."
            )
            extra[field_name] = values.pop(field_name)

    # Declared fields must be passed explicitly, not hidden in model_kwargs.
    invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
    if invalid_model_kwargs:
        raise ValueError(
            f"Parameters {invalid_model_kwargs} should be specified explicitly. "
            f"Instead they were passed in as part of `model_kwargs` parameter."
        )

    values["model_kwargs"] = extra
    return values

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
Expand Down Expand Up @@ -261,6 +290,7 @@ def _invocation_params(self) -> Dict:
"api_base": self.openai_api_base,
"api_type": self.openai_api_type,
"api_version": self.openai_api_version,
**self.model_kwargs,
}
if self.openai_api_type in ("azure", "azure_ad", "azuread"):
openai_args["engine"] = self.deployment
Expand Down
6 changes: 3 additions & 3 deletions langchain/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)
from langchain.llms.base import BaseLLM, create_base_retry_decorator
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -186,13 +186,13 @@ class Config:
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = cls._all_required_field_names()
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
logger.warning(
warnings.warn(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
Expand Down
12 changes: 6 additions & 6 deletions langchain/schema/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from langchain.schema.messages import BaseMessage, get_buffer_string
from langchain.schema.output import LLMResult
from langchain.schema.prompt import PromptValue
from langchain.utils import get_pydantic_field_names

if TYPE_CHECKING:
from langchain.callbacks.manager import Callbacks
Expand Down Expand Up @@ -246,9 +247,8 @@ def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:

@classmethod
def _all_required_field_names(cls) -> Set:
    """DEPRECATED: Kept for backwards compatibility.

    Use get_pydantic_field_names.

    Returns:
        Whatever ``get_pydantic_field_names(cls)`` returns for this class.
    """
    # Single source of truth: delegate instead of duplicating the field walk.
    return get_pydantic_field_names(cls)
15 changes: 14 additions & 1 deletion langchain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import importlib
import os
from importlib.metadata import version
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

from packaging.version import parse
from requests import HTTPError, Response
Expand Down Expand Up @@ -183,3 +183,16 @@ def check_package_version(
f"Expected {package} version to be >= {gte_version}. Received "
f"{imported_version}."
)


def get_pydantic_field_names(pydantic_cls: Any) -> Set:
    """Collect every name under which a pydantic class's fields are known.

    Args:
        pydantic_cls: Pydantic class exposing a ``__fields__`` mapping.

    Returns:
        Set containing each field's name, plus its alias for every field
        that reports ``has_alias``.
    """
    model_fields = list(pydantic_cls.__fields__.values())
    names = {model_field.name for model_field in model_fields}
    aliases = {
        model_field.alias for model_field in model_fields if model_field.has_alias
    }
    return names | aliases
4 changes: 2 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ extended_testing = [
"openai",
"sympy",
"rapidfuzz",
"openai",
"rank_bm25",
]

Expand Down
Empty file.
20 changes: 20 additions & 0 deletions tests/unit_tests/embeddings/test_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import os

import pytest

from langchain.embeddings.openai import OpenAIEmbeddings

# Placeholder key set at import time; presumably it lets OpenAIEmbeddings be
# constructed in these tests without real credentials — confirm.
os.environ["OPENAI_API_KEY"] = "foo"


@pytest.mark.requires("openai")
def test_openai_invalid_model_kwargs() -> None:
    """Passing ``model`` through ``model_kwargs`` raises ``ValueError``."""
    smuggled_field = {"model": "foo"}
    with pytest.raises(ValueError):
        OpenAIEmbeddings(model_kwargs=smuggled_field)


@pytest.mark.requires("openai")
def test_openai_incorrect_field() -> None:
    """An unknown kwarg triggers a warning and ends up in ``model_kwargs``."""
    unknown_kwargs = {"foo": "bar"}
    with pytest.warns(match="not default parameter"):
        embeddings = OpenAIEmbeddings(**unknown_kwargs)
    assert embeddings.model_kwargs == {"foo": "bar"}
28 changes: 28 additions & 0 deletions tests/unit_tests/llms/test_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import os

import pytest

from langchain.llms.openai import OpenAI

# Placeholder key set at import time; presumably it lets OpenAI be
# constructed in these tests without real credentials — confirm.
os.environ["OPENAI_API_KEY"] = "foo"


@pytest.mark.requires("openai")
def test_openai_model_param() -> None:
    """Both ``model`` and ``model_name`` kwargs set ``model_name``."""
    for init_kwargs in ({"model": "foo"}, {"model_name": "foo"}):
        llm = OpenAI(**init_kwargs)
        assert llm.model_name == "foo"


@pytest.mark.requires("openai")
def test_openai_invalid_model_kwargs() -> None:
    """Passing ``model_name`` through ``model_kwargs`` raises ``ValueError``."""
    smuggled_field = {"model_name": "foo"}
    with pytest.raises(ValueError):
        OpenAI(model_kwargs=smuggled_field)


@pytest.mark.requires("openai")
def test_openai_incorrect_field() -> None:
    """An unknown kwarg triggers a warning and ends up in ``model_kwargs``."""
    unknown_kwargs = {"foo": "bar"}
    with pytest.warns(match="not default parameter"):
        llm = OpenAI(**unknown_kwargs)
    assert llm.model_kwargs == {"foo": "bar"}

0 comments on commit 0d23c0c

Please sign in to comment.