community[major]: upgrade pydantic (#485)
baskaryan authored Sep 11, 2024
2 parents 2abd5f7 + cf72c5d commit c7a212b
Showing 23 changed files with 330 additions and 306 deletions.
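The change is mechanical but repetitive: imports move from the `langchain_core.pydantic_v1` compatibility shim to pydantic v2 proper, `class Config` blocks become `model_config = ConfigDict(...)`, `@root_validator(pre=False, skip_on_failure=True)` methods that took and returned a `values` dict become `@model_validator(mode="after")` methods that mutate and return `self`, and the `check_pydantic.sh` lint step that guarded the shim is dropped from the Makefile. A minimal sketch of the before/after, assuming pydantic v2 (the `Example` model is illustrative, not part of the commit):

    from typing import Optional

    from pydantic import BaseModel, ConfigDict, model_validator
    from typing_extensions import Self

    class Example(BaseModel):
        # v1: class Config: arbitrary_types_allowed = True
        model_config = ConfigDict(arbitrary_types_allowed=True)

        name: str
        label: Optional[str] = None

        # v1: @root_validator(pre=False, skip_on_failure=True)
        #     def fill_label(cls, values: dict) -> dict: ...; return values
        # v2: an "after" validator runs on the constructed instance and returns it.
        @model_validator(mode="after")
        def fill_label(self) -> Self:
            if self.label is None:
                self.label = self.name.title()
            return self

    assert Example(name="vector store").label == "Vector Store"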
1 change: 0 additions & 1 deletion libs/community/Makefile
@@ -33,7 +33,6 @@ lint_tests: PYTHON_FILES=tests
 lint_tests: MYPY_CACHE=.mypy_cache_test
 
 lint lint_diff lint_package lint_tests:
-	./scripts/check_pydantic.sh .
 	./scripts/lint_imports.sh
 	poetry run ruff .
 	poetry run ruff format $(PYTHON_FILES) --diff
@@ -13,8 +13,9 @@
 from langchain_community.vectorstores.utils import maximal_marginal_relevance
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
-from langchain_core.pydantic_v1 import BaseModel, ConfigDict, root_validator
 from langchain_core.vectorstores import VectorStore
+from pydantic import BaseModel, ConfigDict, model_validator
+from typing_extensions import Self
 
 from langchain_google_community._utils import get_client_info
 from langchain_google_community.bq_storage_vectorstores.utils import (
@@ -75,8 +76,9 @@ class BaseBigQueryVectorStore(VectorStore, BaseModel, ABC):
     _logger: Any = None
     _full_table_id: Optional[str] = None
 
-    class Config:
-        arbitrary_types_allowed = True
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True,
+    )
 
     @abstractmethod
     def sync_data(self) -> None:
@@ -113,8 +115,8 @@ def _similarity_search_by_vectors_with_scores_and_embeddings(
     ) -> list[list[list[Any]]]:
         ...
 
-    @root_validator(pre=False, skip_on_failure=True)
-    def validate_vals(cls, values: dict) -> dict:
+    @model_validator(mode="after")
+    def validate_vals(self) -> Self:
         try:
             import pandas  # noqa: F401
             from google.cloud import bigquery  # type: ignore[attr-defined]
@@ -127,41 +129,37 @@ def validate_vals(cls, values: dict) -> dict:
                 "Please, install feature store dependency group: "
                 "`pip install langchain-google-community[featurestore]`"
             )
-        values["_logger"] = base.Logger(__name__)
-        values["_bq_client"] = bigquery.Client(
-            project=values["project_id"],
-            location=values["location"],
-            credentials=values["credentials"],
+        self._logger = base.Logger(__name__)
+        self._bq_client = bigquery.Client(
+            project=self.project_id,
+            location=self.location,
+            credentials=self.credentials,
             client_info=get_client_info(module="bigquery-vector-search"),
         )
-        if values["embedding_dimension"] is None:
-            values["embedding_dimension"] = len(values["embedding"].embed_query("test"))
-        full_table_id = (
-            f"{values['project_id']}.{values['dataset_name']}.{values['table_name']}"
-        )
-        values["_full_table_id"] = full_table_id
-        temp_dataset_id = f"{values['dataset_name']}_temp"
+        if self.embedding_dimension is None:
+            self.embedding_dimension = len(self.embedding.embed_query("test"))
+        full_table_id = f"{self.project_id}.{self.dataset_name}.{self.table_name}"
+        self._full_table_id = full_table_id
+        temp_dataset_id = f"{self.dataset_name}_temp"
         if not check_bq_dataset_exists(
-            client=values["_bq_client"], dataset_id=values["dataset_name"]
+            client=self._bq_client, dataset_id=self.dataset_name
         ):
-            values["_bq_client"].create_dataset(
-                dataset=values["dataset_name"], exists_ok=True
-            )
+            self._bq_client.create_dataset(dataset=self.dataset_name, exists_ok=True)
         if not check_bq_dataset_exists(
-            client=values["_bq_client"], dataset_id=temp_dataset_id
+            client=self._bq_client, dataset_id=temp_dataset_id
         ):
-            values["_bq_client"].create_dataset(dataset=temp_dataset_id, exists_ok=True)
+            self._bq_client.create_dataset(dataset=temp_dataset_id, exists_ok=True)
         table_ref = bigquery.TableReference.from_string(full_table_id)
-        values["_bq_client"].create_table(table_ref, exists_ok=True)
-        values["_logger"].info(
+        self._bq_client.create_table(table_ref, exists_ok=True)
+        self._logger.info(
             f"BigQuery table {full_table_id} "
             f"initialized/validated as persistent storage. "
             f"Access via BigQuery console:\n "
-            f"https://console.cloud.google.com/bigquery?project={values['project_id']}"
-            f"&ws=!1m5!1m4!4m3!1s{values['project_id']}!2s{values['dataset_name']}!3s"
-            f"{values['table_name']}"
+            f"https://console.cloud.google.com/bigquery?project={self.project_id}"
+            f"&ws=!1m5!1m4!4m3!1s{self.project_id}!2s{self.dataset_name}!3s"
+            f"{self.table_name}"
         )
-        return values
+        return self
 
     @property
     def embeddings(self) -> Optional[Embeddings]:
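One detail that makes this hunk work: underscore-prefixed annotations such as `_logger` and `_bq_client` are private attributes in pydantic v2, so the "after" validator can assign them directly on `self` instead of writing into a `values` dict. A minimal sketch of that idiom, assuming pydantic v2 (the `Store` model is illustrative, not from the commit):

    import logging
    from typing import Any

    from pydantic import BaseModel, ConfigDict, model_validator
    from typing_extensions import Self

    class Store(BaseModel):
        model_config = ConfigDict(arbitrary_types_allowed=True)

        project_id: str
        # Leading-underscore annotations become pydantic v2 private attributes:
        # excluded from validation and serialization, but assignable on self.
        _logger: Any = None

        @model_validator(mode="after")
        def init_runtime_state(self) -> Self:
            self._logger = logging.getLogger(__name__)
            return self

    Store(project_id="demo")._logger.info("initialized")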
@@ -6,11 +6,13 @@
 from google.api_core.exceptions import ClientError
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
-from langchain_core.pydantic_v1 import root_validator
+from pydantic import model_validator
 
 if TYPE_CHECKING:
     from google.cloud.bigquery.table import Table
 
+from typing_extensions import Self
+
 from langchain_google_community.bq_storage_vectorstores._base import (
     BaseBigQueryVectorStore,
 )
@@ -114,67 +116,67 @@ def get_documents(
             docs.append(doc)
         return docs
 
-    @root_validator(pre=False, skip_on_failure=True)
-    def initialize_bq_vector_index(cls, values: dict) -> dict:
+    @model_validator(mode="after")
+    def initialize_bq_vector_index(self) -> Self:
         """
         A vector index in BigQuery table enables efficient
         approximate vector search.
         """
         from google.cloud import bigquery  # type: ignore[attr-defined]
 
-        values["_creating_index"] = values.get("_creating_index", False)
-        values["_have_index"] = values.get("_have_index", False)
-        values["_last_index_check"] = values.get("_last_index_check", datetime.min)
+        self._creating_index = self._creating_index
+        self._have_index = self._have_index
+        self._last_index_check = self._last_index_check
 
-        if values.get("_have_index") or values.get("_creating_index"):
-            return values
+        if self._have_index or self._creating_index:
+            return self
 
-        table = values["_bq_client"].get_table(values["_full_table_id"])  # type: ignore[union-attr]
+        table = self._bq_client.get_table(self._full_table_id)  # type: ignore[union-attr]
 
         # Update existing table schema
         schema = table.schema.copy()
         if schema:  ## Check if table has a schema
-            values["table_schema"] = {field.name: field.field_type for field in schema}
+            self.table_schema = {field.name: field.field_type for field in schema}
 
         if (table.num_rows or 0) < MIN_INDEX_ROWS:
-            values["_logger"].debug("Not enough rows to create a vector index.")
-            return values
+            self._logger.debug("Not enough rows to create a vector index.")
+            return self
 
-        if datetime.utcnow() - values["_last_index_check"] < INDEX_CHECK_INTERVAL:
-            return values
+        if datetime.utcnow() - self._last_index_check < INDEX_CHECK_INTERVAL:
+            return self
 
         with _vector_table_lock:
-            values["_last_index_check"] = datetime.utcnow()
+            self._last_index_check = datetime.utcnow()
             # Check if index exists, create if necessary
             check_query = (
-                f"SELECT 1 FROM `{values['project_id']}."
-                f"{values['dataset_name']}"
+                f"SELECT 1 FROM `{self.project_id}."
+                f"{self.dataset_name}"
                 ".INFORMATION_SCHEMA.VECTOR_INDEXES` WHERE"
-                f" table_name = '{values['table_name']}'"
+                f" table_name = '{self.table_name}'"
             )
-            job = values["_bq_client"].query(  # type: ignore[union-attr]
+            job = self._bq_client.query(  # type: ignore[union-attr]
                 check_query, api_method=bigquery.enums.QueryApiMethod.QUERY
             )
             if job.result().total_rows == 0:
                 # Need to create an index. Make it in a separate thread.
-                values["_logger"].debug("Trying to create a vector index.")
+                self._logger.debug("Trying to create a vector index.")
                 Thread(
                     target=_create_bq_index,
                     kwargs={
-                        "bq_client": values["_bq_client"],
-                        "table_name": values["table_name"],
-                        "full_table_id": values["_full_table_id"],
-                        "embedding_field": values["embedding_field"],
-                        "distance_type": values["distance_type"],
-                        "logger": values["_logger"],
+                        "bq_client": self._bq_client,
+                        "table_name": self.table_name,
+                        "full_table_id": self._full_table_id,
+                        "embedding_field": self.embedding_field,
+                        "distance_type": self.distance_type,
+                        "logger": self._logger,
                     },
                     daemon=True,
                 ).start()
 
             else:
-                values["_logger"].debug("Vector index already exists.")
-                values["_have_index"] = True
-        return values
+                self._logger.debug("Vector index already exists.")
+                self._have_index = True
+        return self
 
     def _similarity_search_by_vectors_with_scores_and_embeddings(
         self,
@@ -565,7 +567,9 @@ def to_vertex_fs_vector_store(self, **kwargs: Any) -> Any:
             VertexFSVectorStore,
         )
 
-        base_params = self.dict(include=BaseBigQueryVectorStore.__fields__.keys())
+        base_params = self.model_dump(
+            include=set(BaseBigQueryVectorStore.model_fields.keys())
+        )
         base_params["embedding"] = self.embedding
         all_params = {**base_params, **kwargs}
         fs_obj = VertexFSVectorStore(**all_params)
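The dump call above tracks two v2 renames: `.dict()` becomes `.model_dump()` and the class-level `__fields__` becomes `model_fields`; the keys view is wrapped in `set()`, presumably to match the set-based `include` type that v2 expects. A minimal sketch of the idiom, with hypothetical stand-in models:

    from pydantic import BaseModel

    class Base(BaseModel):
        a: int = 1

    class Child(Base):
        b: int = 2

    child = Child()
    # Dump only the fields declared on Base, mirroring the diff above.
    base_params = child.model_dump(include=set(Base.model_fields.keys()))
    assert base_params == {"a": 1}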
@@ -13,7 +13,8 @@
 )
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
-from langchain_core.pydantic_v1 import root_validator
+from pydantic import model_validator
+from typing_extensions import Self
 
 from langchain_google_community._utils import get_client_info, get_user_agent
 from langchain_google_community.bq_storage_vectorstores._base import (
@@ -80,8 +81,8 @@ class VertexFSVectorStore(BaseBigQueryVectorStore):
     feature_view: Any = None
     _admin_client: Any = None
 
-    @root_validator(pre=False, skip_on_failure=True)
-    def _initialize_bq_vector_index(cls, values: dict) -> dict:
+    @model_validator(mode="after")
+    def _initialize_bq_vector_index(self) -> Self:
         import vertexai
         from google.cloud.aiplatform_v1beta1 import (
             FeatureOnlineStoreAdminServiceClient,
@@ -91,49 +92,43 @@ def _initialize_bq_vector_index(cls, values: dict) -> dict:
             utils,  # type: ignore[import-untyped]
         )
 
-        vertexai.init(project=values["project_id"], location=values["location"])
-        values["_user_agent"] = get_user_agent(
-            f"{USER_AGENT_PREFIX}-VertexFSVectorStore"
-        )[1]
-
-        if values["algorithm_config"] is None:
-            values["algorithm_config"] = utils.TreeAhConfig()
-        if values["distance_measure_type"] is None:
-            values[
-                "distance_measure_type"
-            ] = utils.DistanceMeasureType.DOT_PRODUCT_DISTANCE
-        if values.get("online_store_name") is None:
-            values["online_store_name"] = values["dataset_name"]
-        if values.get("view_name") is None:
-            values["view_name"] = values["table_name"]
-
-        api_endpoint = f"{values['location']}-aiplatform.googleapis.com"
-        values["_admin_client"] = FeatureOnlineStoreAdminServiceClient(
+        vertexai.init(project=self.project_id, location=self.location)
+        self._user_agent = get_user_agent(f"{USER_AGENT_PREFIX}-VertexFSVectorStore")[1]
+
+        if self.algorithm_config is None:
+            self.algorithm_config = utils.TreeAhConfig()
+        if self.distance_measure_type is None:
+            self.distance_measure_type = utils.DistanceMeasureType.DOT_PRODUCT_DISTANCE
+        if self.online_store_name is None:
+            self.online_store_name = self.dataset_name
+        if self.view_name is None:
+            self.view_name = self.table_name
+
+        api_endpoint = f"{self.location}-aiplatform.googleapis.com"
+        self._admin_client = FeatureOnlineStoreAdminServiceClient(
             client_options={"api_endpoint": api_endpoint},
-            client_info=get_client_info(module=values["_user_agent"]),
+            client_info=get_client_info(module=self._user_agent),
         )
-        values["online_store"] = _create_online_store(
-            project_id=values["project_id"],
-            location=values["location"],
-            online_store_name=values["online_store_name"],
-            _admin_client=values["_admin_client"],
-            _logger=values["_logger"],
+        self.online_store = _create_online_store(
+            project_id=self.project_id,
+            location=self.location,
+            online_store_name=self.online_store_name,
+            _admin_client=self._admin_client,
+            _logger=self._logger,
         )
-        gca_resource = values["online_store"].gca_resource
+        gca_resource = self.online_store.gca_resource
         endpoint = gca_resource.dedicated_serving_endpoint.public_endpoint_domain_name
-        values["_search_client"] = FeatureOnlineStoreServiceClient(
+        self._search_client = FeatureOnlineStoreServiceClient(
             client_options={"api_endpoint": endpoint},
-            client_info=get_client_info(module=values["_user_agent"]),
-        )
-        values["feature_view"] = _get_feature_view(
-            values["online_store"], values["view_name"]
+            client_info=get_client_info(module=self._user_agent),
         )
+        self.feature_view = _get_feature_view(self.online_store, self.view_name)
 
-        values["_logger"].info(
+        self._logger.info(
             "VertexFSVectorStore initialized with Feature Store Vector Search. \n"
             "Optional batch serving available via .to_bq_vector_store() method."
         )
-        return values
+        return self
 
     def _init_store(self) -> None:
         from google.cloud.aiplatform_v1beta1 import FeatureOnlineStoreServiceClient
@@ -518,7 +513,9 @@ def to_bq_vector_store(self, **kwargs: Any) -> Any:
             BigQueryVectorStore,
         )
 
-        base_params = self.dict(include=BaseBigQueryVectorStore.__fields__.keys())
+        base_params = self.model_dump(
+            include=set(BaseBigQueryVectorStore.model_fields.keys())
+        )
         base_params["embedding"] = self.embedding
         all_params = {**base_params, **kwargs}
         bq_obj = BigQueryVectorStore(**all_params)
@@ -4,9 +4,9 @@
 
 from langchain_core.callbacks import CallbackManagerForRetrieverRun
 from langchain_core.documents import Document
-from langchain_core.pydantic_v1 import root_validator
 from langchain_core.retrievers import BaseRetriever
 from langchain_core.utils import get_from_dict_or_env
+from pydantic import model_validator
 
 from langchain_google_community._utils import get_client_info
 
@@ -42,8 +42,9 @@ class DocumentAIWarehouseRetriever(BaseRetriever):
     """The limit on the number of documents returned."""
     client: "DocumentServiceClient" = None  #: :meta private:
 
-    @root_validator(pre=False, skip_on_failure=True)
-    def validate_environment(cls, values: Dict) -> Dict:
+    @model_validator(mode="before")
+    @classmethod
+    def validate_environment(cls, values: Dict) -> Any:
         """Validates the environment."""
         try:  # noqa: F401
             from google.cloud.contentwarehouse_v1 import DocumentServiceClient
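Where a validator needs the raw input dict before field validation, as here and in drive.py below, the commit uses a "before" validator, which in pydantic v2 is stacked with `@classmethod` and returns the possibly-modified values. A minimal sketch of the pattern, assuming pydantic v2 (`Retriever` is a hypothetical stand-in, not the class from the diff):

    from typing import Any, Dict

    from pydantic import BaseModel, model_validator

    class Retriever(BaseModel):
        project: str

        @model_validator(mode="before")
        @classmethod
        def validate_environment(cls, values: Dict[str, Any]) -> Any:
            # Runs before field validation, so missing keys can still be filled in.
            values.setdefault("project", "default-project")
            return values

    assert Retriever().project == "default-project"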
9 changes: 5 additions & 4 deletions libs/community/langchain_google_community/drive.py
@@ -13,7 +13,7 @@
 
 from langchain_core.document_loaders import BaseLoader
 from langchain_core.documents import Document
-from langchain_core.pydantic_v1 import BaseModel, root_validator, validator
+from pydantic import BaseModel, field_validator, model_validator
 
 SCOPES = ["https://www.googleapis.com/auth/drive.file"]
 
@@ -194,8 +194,9 @@ def _get_identity_metadata_from_id(self, id: str) -> List[str]:
 
         return authorized_identities
 
-    @root_validator
-    def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+    @model_validator(mode="before")
+    @classmethod
+    def validate_inputs(cls, values: Dict[str, Any]) -> Any:
         """Validate that either folder_id or document_ids is set, but not both."""
         if values.get("folder_id") and (
             values.get("document_ids") or values.get("file_ids")
@@ -241,7 +242,7 @@ def full_form(x: str) -> str:
             values["file_types"] = [full_form(file_type) for file_type in file_types]
         return values
 
-    @validator("credentials_path")
+    @field_validator("credentials_path")
     def validate_credentials_path(cls, v: Any, **kwargs: Any) -> Any:
         """Validate that credentials_path exists."""
         if not v.exists():
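Per-field checks like the one above move from v1's `@validator` to v2's `@field_validator`, conventionally stacked with `@classmethod`. A minimal sketch of the v2 form (the `Loader` model is illustrative only, not the loader from the diff):

    from pathlib import Path

    from pydantic import BaseModel, field_validator

    class Loader(BaseModel):
        credentials_path: Path

        @field_validator("credentials_path")
        @classmethod
        def validate_credentials_path(cls, v: Path) -> Path:
            # Reject paths that do not exist on disk.
            if not v.exists():
                raise ValueError(f"credentials_path {v} does not exist")
            return v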
2 changes: 1 addition & 1 deletion libs/community/langchain_google_community/gmail/base.py
@@ -4,8 +4,8 @@
 
 from typing import TYPE_CHECKING
 
-from langchain_core.pydantic_v1 import Field
 from langchain_core.tools import BaseTool
+from pydantic import Field
 
 from langchain_google_community.gmail.utils import build_resource_service
 