Skip to content

Commit

Permalink
genai[major]: upgrade pydantic (#481)
Browse files Browse the repository at this point in the history
TODO:

- resolve warnings
- run integration tests, fix failures, fix warnings
  • Loading branch information
baskaryan authored Sep 11, 2024
2 parents d32b8fb + 290a90b commit c98e754
Show file tree
Hide file tree
Showing 20 changed files with 349 additions and 191 deletions.
4 changes: 3 additions & 1 deletion libs/genai/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ test tests integration_test integration_tests:
check_imports: $(shell find langchain_google_genai -name '*.py')
poetry run python ./scripts/check_imports.py $^

test_watch:
poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)

# Run unit tests and generate a coverage report.
coverage:
poetry run pytest --cov \
Expand All @@ -36,7 +39,6 @@ lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
./scripts/check_pydantic.sh .
./scripts/lint_imports.sh
poetry run ruff check .
poetry run ruff format $(PYTHON_FILES) --diff
Expand Down
24 changes: 20 additions & 4 deletions libs/genai/langchain_google_genai/_function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@
import google.ai.generativelanguage_v1beta.types as gapic
import proto # type: ignore[import]
from google.generativeai.types.content_types import ToolDict # type: ignore[import]
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import BaseTool
from langchain_core.tools import tool as callable_as_lc_tool
from langchain_core.utils.function_calling import (
FunctionDescription,
convert_to_openai_tool,
)
from langchain_core.utils.json_schema import dereference_refs
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -205,7 +206,7 @@ def _format_to_gapic_function_declaration(
function = cast(dict, tool)
function["parameters"] = {}
else:
if "parameters" in tool and tool["parameters"].get("properties"):
if "parameters" in tool and tool["parameters"].get("properties"): # type: ignore[index]
function = convert_to_openai_tool(cast(dict, tool))["function"]
else:
function = cast(dict, tool)
Expand All @@ -232,7 +233,14 @@ def _format_base_tool_to_function_declaration(
),
)

schema = tool.args_schema.schema()
if issubclass(tool.args_schema, BaseModel):
schema = tool.args_schema.model_json_schema()
elif issubclass(tool.args_schema, BaseModelV1):
schema = tool.args_schema.schema()
else:
raise NotImplementedError(
f"args_schema must be a Pydantic BaseModel, got {tool.args_schema}."
)
parameters = _dict_to_gapic_schema(schema)

return gapic.FunctionDeclaration(
Expand All @@ -247,7 +255,15 @@ def _convert_pydantic_to_genai_function(
tool_name: Optional[str] = None,
tool_description: Optional[str] = None,
) -> gapic.FunctionDeclaration:
schema = dereference_refs(pydantic_model.schema())
if issubclass(pydantic_model, BaseModel):
schema = pydantic_model.model_json_schema()
elif issubclass(pydantic_model, BaseModelV1):
schema = pydantic_model.schema()
else:
raise NotImplementedError(
f"pydantic_model must be a Pydantic BaseModel, got {pydantic_model}"
)
schema = dereference_refs(schema)
schema.pop("definitions", None)
function_declaration = gapic.FunctionDeclaration(
name=tool_name if tool_name else schema.get("title"),
Expand Down
66 changes: 36 additions & 30 deletions libs/genai/langchain_google_genai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,24 @@
parse_tool_calls,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.utils import secret_from_env
from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import (
BaseModel,
ConfigDict,
Field,
SecretStr,
model_validator,
)
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from typing_extensions import Self

from langchain_google_genai._common import (
GoogleGenerativeAIError,
Expand Down Expand Up @@ -696,7 +703,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
Tool calling:
.. code-block:: python
from langchain_core.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field
class GetWeather(BaseModel):
Expand Down Expand Up @@ -741,7 +748,7 @@ class GetPopulation(BaseModel):
from typing import Optional
from langchain_core.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field
class Joke(BaseModel):
Expand Down Expand Up @@ -832,8 +839,9 @@ class Joke(BaseModel):
Gemini does not support system messages; any unsupported messages will
raise an error."""

class Config:
allow_population_by_field_name = True
model_config = ConfigDict(
populate_by_name=True,
)

@property
def lc_secrets(self) -> Dict[str, str]:
Expand All @@ -847,38 +855,36 @@ def _llm_type(self) -> str:
def is_lc_serializable(self) -> bool:
return True

@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validates params and passes them to google-generativeai package."""
if (
values.get("temperature") is not None
and not 0 <= values["temperature"] <= 1
):
if self.temperature is not None and not 0 <= self.temperature <= 1:
raise ValueError("temperature must be in the range [0.0, 1.0]")

if values.get("top_p") is not None and not 0 <= values["top_p"] <= 1:
if self.top_p is not None and not 0 <= self.top_p <= 1:
raise ValueError("top_p must be in the range [0.0, 1.0]")

if values.get("top_k") is not None and values["top_k"] <= 0:
if self.top_k is not None and self.top_k <= 0:
raise ValueError("top_k must be positive")

if not values["model"].startswith("models/"):
values["model"] = f"models/{values['model']}"
if not self.model.startswith("models/"):
self.model = f"models/{self.model}"

additional_headers = values.get("additional_headers") or {}
values["default_metadata"] = tuple(additional_headers.items())
additional_headers = self.additional_headers or {}
self.default_metadata = tuple(additional_headers.items())
client_info = get_client_info("ChatGoogleGenerativeAI")
google_api_key = None
if not values.get("credentials"):
google_api_key = values.get("google_api_key")
if isinstance(google_api_key, SecretStr):
google_api_key = google_api_key.get_secret_value()
transport: Optional[str] = values.get("transport")
values["client"] = genaix.build_generative_service(
credentials=values.get("credentials"),
if not self.credentials:
if isinstance(self.google_api_key, SecretStr):
google_api_key = self.google_api_key.get_secret_value()
else:
google_api_key = self.google_api_key
transport: Optional[str] = self.transport
self.client = genaix.build_generative_service(
credentials=self.credentials,
api_key=google_api_key,
client_info=client_info,
client_options=values.get("client_options"),
client_options=self.client_options,
transport=transport,
)

Expand All @@ -888,17 +894,17 @@ def validate_environment(cls, values: Dict) -> Dict:
# this check ensures that async client is only initialized
# within an asyncio event loop to avoid the error
if _is_event_loop_running():
values["async_client"] = genaix.build_generative_async_service(
credentials=values.get("credentials"),
self.async_client = genaix.build_generative_async_service(
credentials=self.credentials,
api_key=google_api_key,
client_info=client_info,
client_options=values.get("client_options"),
client_options=self.client_options,
transport=transport,
)
else:
values["async_client"] = None
self.async_client = None

return values
return self

@property
def _identifying_params(self) -> Dict[str, Any]:
Expand Down
22 changes: 12 additions & 10 deletions libs/genai/langchain_google_genai/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
EmbedContentRequest,
)
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import secret_from_env
from pydantic import BaseModel, Field, SecretStr, model_validator
from typing_extensions import Self

from langchain_google_genai._common import (
GoogleGenerativeAIError,
Expand Down Expand Up @@ -82,21 +83,22 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
"Example: `{'timeout': 10}`",
)

@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validates params and passes them to google-generativeai package."""
google_api_key = values.get("google_api_key")
if isinstance(google_api_key, SecretStr):
google_api_key = google_api_key.get_secret_value()
if isinstance(self.google_api_key, SecretStr):
google_api_key: Optional[str] = self.google_api_key.get_secret_value()
else:
google_api_key = self.google_api_key
client_info = get_client_info("GoogleGenerativeAIEmbeddings")

values["client"] = build_generative_service(
credentials=values.get("credentials"),
self.client = build_generative_service(
credentials=self.credentials,
api_key=google_api_key,
client_info=client_info,
client_options=values.get("client_options"),
client_options=self.client_options,
)
return values
return self

@staticmethod
def _split_by_punctuation(text: str) -> List[str]:
Expand Down
2 changes: 1 addition & 1 deletion libs/genai/langchain_google_genai/genai_aqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from typing import Any, List, Optional

import google.ai.generativelanguage as genai
from langchain_core.pydantic_v1 import BaseModel, PrivateAttr
from langchain_core.runnables import RunnableSerializable
from langchain_core.runnables.config import RunnableConfig
from pydantic import BaseModel, PrivateAttr

from . import _genai_extension as genaix

Expand Down
4 changes: 2 additions & 2 deletions libs/genai/langchain_google_genai/google_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
import google.ai.generativelanguage as genai
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, PrivateAttr
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
from langchain_core.vectorstores import VectorStore
from pydantic import BaseModel, PrivateAttr

from . import _genai_extension as genaix
from .genai_aqa import (
Expand Down Expand Up @@ -467,7 +467,7 @@ def as_aqa(self, **kwargs: Any) -> Runnable[str, AqaOutput]:
return (
RunnablePassthrough[str]()
| {
"prompt": RunnablePassthrough(),
"prompt": RunnablePassthrough(), # type: ignore[dict-item]
"passages": self.as_retriever(),
}
| RunnableLambda(_toAqaInput)
Expand Down
46 changes: 24 additions & 22 deletions libs/genai/langchain_google_genai/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from langchain_core.language_models import LangSmithParams, LanguageModelInput
from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import secret_from_env
from pydantic import BaseModel, Field, SecretStr, model_validator
from typing_extensions import Self

from langchain_google_genai._enums import (
HarmBlockThreshold,
Expand Down Expand Up @@ -216,57 +217,58 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):

client: Any = None #: :meta private:

@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validates params and passes them to google-generativeai package."""
if values.get("credentials"):
if self.credentials:
genai.configure(
credentials=values.get("credentials"),
transport=values.get("transport"),
client_options=values.get("client_options"),
credentials=self.credentials,
transport=self.transport,
client_options=self.client_options,
)
else:
google_api_key = values.get("google_api_key")
if isinstance(google_api_key, SecretStr):
google_api_key = google_api_key.get_secret_value()
if isinstance(self.google_api_key, SecretStr):
google_api_key: Optional[str] = self.google_api_key.get_secret_value()
else:
google_api_key = self.google_api_key
genai.configure(
api_key=google_api_key,
transport=values.get("transport"),
client_options=values.get("client_options"),
transport=self.transport,
client_options=self.client_options,
)

model_name = values["model"]
model_name = self.model

safety_settings = values["safety_settings"]
safety_settings = self.safety_settings

if safety_settings and (
not GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI
):
raise ValueError("Safety settings are only supported for Gemini models")

if GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI:
values["client"] = genai.GenerativeModel(
self.client = genai.GenerativeModel(
model_name=model_name, safety_settings=safety_settings
)
else:
values["client"] = genai
self.client = genai

if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
if self.temperature is not None and not 0 <= self.temperature <= 1:
raise ValueError("temperature must be in the range [0.0, 1.0]")

if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
if self.top_p is not None and not 0 <= self.top_p <= 1:
raise ValueError("top_p must be in the range [0.0, 1.0]")

if values["top_k"] is not None and values["top_k"] <= 0:
if self.top_k is not None and self.top_k <= 0:
raise ValueError("top_k must be positive")

if values["max_output_tokens"] is not None and values["max_output_tokens"] <= 0:
if self.max_output_tokens is not None and self.max_output_tokens <= 0:
raise ValueError("max_output_tokens must be greater than zero")

if values["timeout"] is not None and values["timeout"] <= 0:
if self.timeout is not None and self.timeout <= 0:
raise ValueError("timeout must be greater than zero")

return values
return self

def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
Expand Down
Loading

0 comments on commit c98e754

Please sign in to comment.