vertex[major]: upgrade pydantic (#475)
In preparation for the langchain 0.3 release.

Todo:

- [x] fix gapic format test + chain integration tests (important)
- [x] fix ser/des standard tests
- [x] fix warnings
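
One migration pattern repeats across the diff below. As a reviewer's crib sheet, here is a minimal before/after sketch of that pattern (the `Example` class is illustrative, not code from this repo):

```python
# Before (pydantic v1 via langchain_core.pydantic_v1):
#
#     from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
#
#     class Example(BaseModel):
#         model_name: str = Field(default="gemini-pro", alias="model")
#
#         class Config:
#             allow_population_by_field_name = True
#             arbitrary_types_allowed = True
#
#         @root_validator(pre=True)
#         def validate_params(cls, values: dict) -> dict:
#             return values

# After (pydantic v2):
from pydantic import BaseModel, ConfigDict, Field, model_validator


class Example(BaseModel):
    model_name: str = Field(default="gemini-pro", alias="model")

    model_config = ConfigDict(
        populate_by_name=True,        # was: allow_population_by_field_name
        arbitrary_types_allowed=True,
        protected_namespaces=(),      # silence v2's "model_" namespace warning
    )

    @model_validator(mode="before")   # was: @root_validator(pre=True)
    @classmethod                      # v2 requires the explicit classmethod
    def validate_params(cls, values: dict) -> dict:
        return values
```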
baskaryan authored Sep 11, 2024
2 parents c98e754 + 97b32d8 commit 2abd5f7
Showing 25 changed files with 659 additions and 358 deletions.
4 changes: 3 additions & 1 deletion libs/vertexai/Makefile
@@ -11,6 +11,9 @@ integration_test integration_tests: TEST_FILE = tests/integration_tests/
 test tests integration_test integration_tests:
	poetry run pytest --release $(TEST_FILE)
 
+test_watch:
+	poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)
+
 # Run unit tests and generate a coverage report.
 coverage:
	poetry run pytest --cov \
@@ -33,7 +36,6 @@ lint_tests: PYTHON_FILES=tests
 lint_tests: MYPY_CACHE=.mypy_cache_test
 
 lint lint_diff lint_package lint_tests:
-	./scripts/check_pydantic.sh .
	./scripts/lint_imports.sh
	poetry run ruff check .
	poetry run ruff format $(PYTHON_FILES) --diff
7 changes: 4 additions & 3 deletions libs/vertexai/langchain_google_vertexai/_anthropic_parsers.py
@@ -4,16 +4,17 @@
 from langchain_core.messages.tool import tool_call
 from langchain_core.output_parsers import BaseGenerationOutputParser
 from langchain_core.outputs import ChatGeneration, Generation
-from langchain_core.pydantic_v1 import BaseModel
+from pydantic import BaseModel, ConfigDict
 
 
 class ToolsOutputParser(BaseGenerationOutputParser):
     first_tool_only: bool = False
     args_only: bool = False
     pydantic_schemas: Optional[List[Type[BaseModel]]] = None
 
-    class Config:
-        extra = "forbid"
+    model_config = ConfigDict(
+        extra="forbid",
+    )
 
     def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
         """Parse a list of candidate model Generations into a specific format.
(Changed-file header not captured in this excerpt.)
@@ -26,9 +26,9 @@
     ToolMessage,
 )
 from langchain_core.messages.ai import UsageMetadata
-from langchain_core.pydantic_v1 import BaseModel
 from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import convert_to_openai_tool
+from pydantic import BaseModel
 
 if TYPE_CHECKING:
     from anthropic.types import RawMessageStreamEvent  # type: ignore
34 changes: 18 additions & 16 deletions libs/vertexai/langchain_google_vertexai/_base.py
@@ -22,7 +22,8 @@
 from google.protobuf import json_format
 from google.protobuf.struct_pb2 import Value
 from langchain_core.outputs import Generation, LLMResult
-from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+from typing_extensions import Self
 from vertexai.generative_models._generative_models import (  # type: ignore
     SafetySettingsType,
 )
@@ -91,14 +92,15 @@ class _VertexAIBase(BaseModel):
     "when making API calls. If not provided, credentials will be ascertained from "
     "the environment."
 
-    class Config:
-        """Configuration for this pydantic object."""
-
-        allow_population_by_field_name = True
-        arbitrary_types_allowed = True
+    model_config = ConfigDict(
+        populate_by_name=True,
+        arbitrary_types_allowed=True,
+        protected_namespaces=(),
+    )
 
-    @root_validator(pre=True)
-    def validate_params_base(cls, values: dict) -> dict:
+    @model_validator(mode="before")
+    @classmethod
+    def validate_params_base(cls, values: dict) -> Any:
         if "model" in values and "model_name" not in values:
             values["model_name"] = values.pop("model")
         if values.get("project") is None:
@@ -108,7 +110,7 @@ def validate_params_base(cls, values: dict) -> dict:
         if values.get("api_endpoint"):
             api_endpoint = values["api_endpoint"]
         else:
-            location = values.get("location", cls.__fields__["location"].default)
+            location = values.get("location", cls.model_fields["location"].default)
             api_endpoint = f"{location}-{constants.PREDICTION_API_BASE_PATH}"
         client_options = ClientOptions(api_endpoint=api_endpoint)
         if values.get("client_cert_source"):
@@ -311,26 +313,26 @@ class _BaseVertexAIModelGarden(_VertexAIBase):
     single_example_per_request: bool = True
     "LLM endpoint currently serves only the first example in the request"
 
-    @root_validator(pre=False, skip_on_failure=True)
-    def validate_environment(cls, values: Dict) -> Dict:
+    @model_validator(mode="after")
+    def validate_environment(self) -> Self:
         """Validate that the python package exists in environment."""
 
-        if not values["project"]:
+        if not self.project:
             raise ValueError(
                 "A GCP project should be provided to run inference on Model Garden!"
             )
 
         client_options = ClientOptions(
-            api_endpoint=f"{values['location']}-aiplatform.googleapis.com"
+            api_endpoint=f"{self.location}-aiplatform.googleapis.com"
         )
         client_info = get_client_info(module="vertex-ai-model-garden")
-        values["client"] = PredictionServiceClient(
+        self.client = PredictionServiceClient(
             client_options=client_options, client_info=client_info
         )
-        values["async_client"] = PredictionServiceAsyncClient(
+        self.async_client = PredictionServiceAsyncClient(
             client_options=client_options, client_info=client_info
         )
-        return values
+        return self
 
     @property
     def endpoint_path(self) -> str:
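
`@root_validator(pre=False, skip_on_failure=True)` becomes `@model_validator(mode="after")`: the method now receives a fully constructed instance, assigns attributes on `self` instead of writing into a `values` dict, and returns `Self`. A reduced sketch with the Vertex clients stubbed out (class and project id are illustrative):

```python
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict, model_validator
from typing_extensions import Self


class ModelGardenSketch(BaseModel):
    project: Optional[str] = None
    location: str = "us-central1"
    client: Any = None  # stands in for PredictionServiceClient

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        if not self.project:
            raise ValueError(
                "A GCP project should be provided to run inference on Model Garden!"
            )
        # The real code builds a PredictionServiceClient here; use a stub.
        self.client = object()
        return self


garden = ModelGardenSketch(project="my-project")  # hypothetical project id
assert garden.client is not None
```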
11 changes: 8 additions & 3 deletions libs/vertexai/langchain_google_vertexai/chains.py
@@ -13,8 +13,8 @@
     StrOutputParser,
 )
 from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
-from langchain_core.pydantic_v1 import BaseModel
 from langchain_core.runnables import Runnable
+from pydantic import BaseModel
 
 from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser
 
@@ -51,7 +51,12 @@ def _create_structured_runnable_extra_step(
     *,
     prompt: Optional[BasePromptTemplate] = None,
 ) -> Runnable:
-    names = [schema.schema()["title"] for schema in functions]
+    names = [
+        schema.model_json_schema()["title"]
+        if hasattr(schema, "model_json_schema")
+        else schema.schema()["title"]
+        for schema in functions
+    ]
     if hasattr(llm, "is_gemini_advanced") and llm._is_gemini_advanced:  # type: ignore
         llm_with_functions = llm.bind(
             functions=functions,
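
The rewritten `names` comprehension above tolerates both schema flavors during the transition: pydantic v2 classes expose `model_json_schema()`, while v1-style classes only have `.schema()`. Demonstrated in isolation:

```python
from pydantic import BaseModel


class RecordPerson(BaseModel):
    """Record some identifying information about a person."""

    name: str


functions = [RecordPerson]
names = [
    schema.model_json_schema()["title"]
    if hasattr(schema, "model_json_schema")
    else schema.schema()["title"]
    for schema in functions
]
print(names)  # -> ['RecordPerson']
```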
@@ -111,7 +116,7 @@ def create_structured_runnable(
 
         from langchain_google_vertexai import ChatVertexAI, create_structured_runnable
         from langchain_core.prompts import ChatPromptTemplate
-        from langchain_core.pydantic_v1 import BaseModel, Field
+        from pydantic import BaseModel, Field
 
         class RecordPerson(BaseModel):
86 changes: 48 additions & 38 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -62,7 +62,7 @@
 )
 from langchain_core.output_parsers.openai_tools import parse_tool_calls
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.pydantic_v1 import BaseModel, root_validator, Field
+from pydantic import BaseModel, Field, model_validator
 from langchain_core.runnables import Runnable, RunnablePassthrough, RunnableGenerator
 from langchain_core.utils.function_calling import convert_to_openai_tool
 from langchain_core.utils.pydantic import is_basemodel_subclass
@@ -124,6 +124,9 @@
     _format_to_gapic_tool,
     _ToolType,
 )
+from pydantic import ConfigDict
+from typing_extensions import Self
 
 
 logger = logging.getLogger(__name__)

@@ -762,7 +765,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
     Tool calling:
         .. code-block:: python
 
-            from langchain_core.pydantic_v1 import BaseModel, Field
+            from pydantic import BaseModel, Field
 
             class GetWeather(BaseModel):
                 '''Get the current weather in a given location'''
@@ -800,7 +803,7 @@ class GetPopulation(BaseModel):
 
             from typing import Optional
 
-            from langchain_core.pydantic_v1 import BaseModel, Field
+            from pydantic import BaseModel, Field
 
             class Joke(BaseModel):
                 '''Joke to tell user.'''
@@ -1024,11 +1027,10 @@ def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
             kwargs["model_name"] = model_name
         super().__init__(**kwargs)
 
-    class Config:
-        """Configuration for this pydantic object."""
-
-        allow_population_by_field_name = True
-        arbitrary_types_allowed = True
+    model_config = ConfigDict(
+        populate_by_name=True,
+        arbitrary_types_allowed=True,
+    )
 
     @classmethod
     def is_lc_serializable(self) -> bool:
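
`populate_by_name=True` is the v2 spelling of `allow_population_by_field_name`: fields declared with an alias stay settable by their Python name as well. A minimal illustration (the alias is hypothetical, not the actual `ChatVertexAI` field definition):

```python
from pydantic import BaseModel, ConfigDict, Field


class ChatSketch(BaseModel):
    model_name: str = Field(default="gemini-pro", alias="model")

    model_config = ConfigDict(
        populate_by_name=True,
        protected_namespaces=(),  # needed because the field starts with "model_"
    )


print(ChatSketch(model="gemini-1.5-pro").model_name)       # via the alias
print(ChatSketch(model_name="gemini-1.5-pro").model_name)  # via the field name
```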
@@ -1039,57 +1041,65 @@ def get_lc_namespace(cls) -> List[str]:
         """Get the namespace of the langchain object."""
         return ["langchain", "chat_models", "vertexai"]
 
-    @root_validator(pre=False, skip_on_failure=True)
-    def validate_environment(cls, values: Dict) -> Dict:
+    @model_validator(mode="after")
+    def validate_environment(self) -> Self:
         """Validate that the python package exists in environment."""
-        safety_settings = values.get("safety_settings")
-        tuned_model_name = values.get("tuned_model_name")
-        values["model_family"] = GoogleModelFamily(values["model_name"])
+        safety_settings = self.safety_settings
+        tuned_model_name = self.tuned_model_name
+        self.model_family = GoogleModelFamily(self.model_name)
 
-        if values["model_name"] == "chat-bison-default":
+        if self.model_name == "chat-bison-default":
             logger.warning(
                 "Model_name will become a required arg for VertexAIEmbeddings "
                 "starting from Sep-01-2024. Currently the default is set to "
                 "chat-bison"
             )
-            values["model_name"] = "chat-bison"
+            self.model_name = "chat-bison"
 
-        if values.get("full_model_name") is not None:
+        if self.full_model_name is not None:
             pass
-        elif values.get("tuned_model_name") is not None:
-            values["full_model_name"] = _format_model_name(
-                values["tuned_model_name"],
-                location=values["location"],
-                project=values["project"],
+        elif self.tuned_model_name is not None:
+            self.full_model_name = _format_model_name(
+                self.tuned_model_name,
+                location=self.location,
+                project=cast(str, self.project),
             )
         else:
-            values["full_model_name"] = _format_model_name(
-                values["model_name"],
-                location=values["location"],
-                project=values["project"],
+            self.full_model_name = _format_model_name(
+                self.model_name,
+                location=self.location,
+                project=cast(str, self.project),
             )
 
-        if safety_settings and not is_gemini_model(values["model_family"]):
+        if safety_settings and not is_gemini_model(self.model_family):
             raise ValueError("Safety settings are only supported for Gemini models")
 
         if tuned_model_name:
-            generative_model_name = values["tuned_model_name"]
+            generative_model_name = self.tuned_model_name
         else:
-            generative_model_name = values["model_name"]
-
-        if not is_gemini_model(values["model_family"]):
-            cls._init_vertexai(values)
-            if values["model_family"] == GoogleModelFamily.CODEY:
+            generative_model_name = self.model_name
+
+        if not is_gemini_model(self.model_family):
+            values = {
+                "project": self.project,
+                "location": self.location,
+                "credentials": self.credentials,
+                "api_transport": self.api_transport,
+                "api_endpoint": self.api_endpoint,
+                "default_metadata": self.default_metadata,
+            }
+            self._init_vertexai(values)
+            if self.model_family == GoogleModelFamily.CODEY:
                 model_cls = CodeChatModel
                 model_cls_preview = PreviewCodeChatModel
             else:
                 model_cls = ChatModel
                 model_cls_preview = PreviewChatModel
-            values["client"] = model_cls.from_pretrained(generative_model_name)
-            values["client_preview"] = model_cls_preview.from_pretrained(
+            self.client = model_cls.from_pretrained(generative_model_name)
+            self.client_preview = model_cls_preview.from_pretrained(
                 generative_model_name
             )
-        return values
+        return self
 
     @property
     def _is_gemini_advanced(self) -> bool:
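
`cast(str, self.project)` appears because `project` is typed `Optional[str]`: `cast` is a no-op at runtime that only narrows the type for mypy. For reference (the helper below is a hypothetical stand-in for `_format_model_name`, not its actual signature):

```python
from typing import Optional, cast


def format_model_name_sketch(model: str, location: str, project: str) -> str:
    # Hypothetical stand-in with a non-optional `project` parameter.
    return f"projects/{project}/locations/{location}/models/{model}"


project: Optional[str] = "my-project"
# cast() performs no runtime validation; it only informs static type checkers.
full_name = format_model_name_sketch("gemini-pro", "us-central1", cast(str, project))
```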
@@ -1647,7 +1657,7 @@ def with_structured_output(
         Example: Pydantic schema, exclude raw:
             .. code-block:: python
 
-                from langchain_core.pydantic_v1 import BaseModel
+                from pydantic import BaseModel
                 from langchain_google_vertexai import ChatVertexAI
 
                 class AnswerWithJustification(BaseModel):
@@ -1666,7 +1676,7 @@ class AnswerWithJustification(BaseModel):
         Example: Pydantic schema, include raw:
             .. code-block:: python
 
-                from langchain_core.pydantic_v1 import BaseModel
+                from pydantic import BaseModel
                 from langchain_google_vertexai import ChatVertexAI
 
                 class AnswerWithJustification(BaseModel):
@@ -1687,7 +1697,7 @@ class AnswerWithJustification(BaseModel):
         Example: Dict schema, exclude raw:
             .. code-block:: python
 
-                from langchain_core.pydantic_v1 import BaseModel
+                from pydantic import BaseModel
                 from langchain_core.utils.function_calling import convert_to_openai_function
                 from langchain_google_vertexai import ChatVertexAI
36 changes: 23 additions & 13 deletions libs/vertexai/langchain_google_vertexai/embeddings.py
@@ -17,7 +17,8 @@
 from google.cloud.aiplatform import telemetry
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.llms import create_base_retry_decorator
-from langchain_core.pydantic_v1 import root_validator
+from pydantic import ConfigDict, model_validator
+from typing_extensions import Self
 from vertexai.language_models import (  # type: ignore
     TextEmbeddingInput,
     TextEmbeddingModel,
@@ -100,24 +101,33 @@ class VertexAIEmbeddings(_VertexAICommon, Embeddings):
     # Instance context
     instance: Dict[str, Any] = {}  #: :meta private:
 
-    @root_validator(pre=False, skip_on_failure=True)
-    def validate_environment(cls, values: Dict) -> Dict:
+    model_config = ConfigDict(
+        extra="forbid",
+        protected_namespaces=(),
+    )
+
+    @model_validator(mode="after")
+    def validate_environment(self) -> Self:
         """Validates that the python package exists in environment."""
-        cls._init_vertexai(values)
-        _, user_agent = get_user_agent(f"{cls.__name__}_{values['model_name']}")  # type: ignore
+        values = {
+            "project": self.project,
+            "location": self.location,
+            "credentials": self.credentials,
+            "api_transport": self.api_transport,
+            "api_endpoint": self.api_endpoint,
+            "default_metadata": self.default_metadata,
+        }
+        self._init_vertexai(values)
+        _, user_agent = get_user_agent(f"{self.__class__.__name__}_{self.model_name}")
         with telemetry.tool_context_manager(user_agent):
             if (
-                GoogleEmbeddingModelType(values["model_name"])
+                GoogleEmbeddingModelType(self.model_name)
                 == GoogleEmbeddingModelType.MULTIMODAL
             ):
-                values["client"] = MultiModalEmbeddingModel.from_pretrained(
-                    values["model_name"]
-                )
+                self.client = MultiModalEmbeddingModel.from_pretrained(self.model_name)
             else:
-                values["client"] = TextEmbeddingModel.from_pretrained(
-                    values["model_name"]
-                )
-        return values
+                self.client = TextEmbeddingModel.from_pretrained(self.model_name)
+        return self
 
     def __init__(
         self,
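
`protected_namespaces=()` ties into the "fix warnings" todo item: pydantic v2 reserves the `model_` prefix and warns at class-definition time about fields like `model_name` unless the protected-namespace check is disabled. A minimal reproduction (illustrative classes, not from this repo):

```python
import warnings

from pydantic import BaseModel, ConfigDict

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")

    class Warns(BaseModel):
        model_name: str = "textembedding-gecko"  # conflicts with "model_" namespace

print([str(w.message) for w in caught])  # v2 emits a protected-namespace UserWarning


class Silent(BaseModel):
    model_name: str = "textembedding-gecko"

    model_config = ConfigDict(protected_namespaces=())  # no warning
```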
(Diff truncated: the remaining changed files of the 25 total are not shown in this excerpt.)