diff --git a/libs/vertexai/langchain_google_vertexai/llms.py b/libs/vertexai/langchain_google_vertexai/llms.py
index 9257269c..59fb147b 100644
--- a/libs/vertexai/langchain_google_vertexai/llms.py
+++ b/libs/vertexai/langchain_google_vertexai/llms.py
@@ -4,14 +4,6 @@
 from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, List, Optional, Union
 
 import vertexai  # type: ignore[import-untyped]
-from google.api_core.client_options import ClientOptions
-from google.cloud.aiplatform.gapic import (
-    PredictionServiceAsyncClient,
-    PredictionServiceClient,
-)
-from google.cloud.aiplatform.models import Prediction
-from google.protobuf import json_format
-from google.protobuf.struct_pb2 import Value
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
@@ -53,7 +45,6 @@
 from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory
 from langchain_google_vertexai._utils import (
     create_retry_decorator,
-    get_client_info,
     get_generation_info,
     is_codey_model,
     is_gemini_model,
@@ -487,125 +478,3 @@ async def _astream(
             await run_manager.on_llm_new_token(
                 chunk.text, chunk=chunk, verbose=self.verbose
             )
-
-
-class VertexAIModelGarden(_VertexAIBase, BaseLLM):
-    """Large language models served from Vertex AI Model Garden."""
-
-    client: Any = None  #: :meta private:
-    async_client: Any = None  #: :meta private:
-    endpoint_id: str
-    "A name of an endpoint where the model has been deployed."
-    allowed_model_args: Optional[List[str]] = None
-    "Allowed optional args to be passed to the model."
-    prompt_arg: str = "prompt"
-    result_arg: Optional[str] = "generated_text"
-    "Set result_arg to None if output of the model is expected to be a string."
-    "Otherwise, if it's a dict, provided an argument that contains the result."
-
-    @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
-        """Validate that the python package exists in environment."""
-
-        if not values["project"]:
-            raise ValueError(
-                "A GCP project should be provided to run inference on Model Garden!"
-            )
-
-        client_options = ClientOptions(
-            api_endpoint=f"{values['location']}-aiplatform.googleapis.com"
-        )
-        client_info = get_client_info(module="vertex-ai-model-garden")
-        values["client"] = PredictionServiceClient(
-            client_options=client_options, client_info=client_info
-        )
-        values["async_client"] = PredictionServiceAsyncClient(
-            client_options=client_options, client_info=client_info
-        )
-        return values
-
-    @property
-    def endpoint_path(self) -> str:
-        return self.client.endpoint_path(
-            project=self.project, location=self.location, endpoint=self.endpoint_id
-        )
-
-    @property
-    def _llm_type(self) -> str:
-        return "vertexai_model_garden"
-
-    def _prepare_request(self, prompts: List[str], **kwargs: Any) -> List["Value"]:
-        instances = []
-        for prompt in prompts:
-            if self.allowed_model_args:
-                instance = {
-                    k: v for k, v in kwargs.items() if k in self.allowed_model_args
-                }
-            else:
-                instance = {}
-            instance[self.prompt_arg] = prompt
-            instances.append(instance)
-
-        predict_instances = [
-            json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
-        ]
-        return predict_instances
-
-    def _generate(
-        self,
-        prompts: List[str],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> LLMResult:
-        """Run the LLM on the given prompt and input."""
-        instances = self._prepare_request(prompts, **kwargs)
-        response = self.client.predict(endpoint=self.endpoint_path, instances=instances)
-        return self._parse_response(response)
-
-    def _parse_response(self, predictions: "Prediction") -> LLMResult:
-        generations: List[List[Generation]] = []
-        for result in predictions.predictions:
-            generations.append(
-                [
-                    Generation(text=self._parse_prediction(prediction))
-                    for prediction in result
-                ]
-            )
-        return LLMResult(generations=generations)
-
-    def _parse_prediction(self, prediction: Any) -> str:
-        if isinstance(prediction, str):
-            return prediction
-
-        if self.result_arg:
-            try:
-                return prediction[self.result_arg]
-            except KeyError:
-                if isinstance(prediction, str):
-                    error_desc = (
-                        "Provided non-None `result_arg` (result_arg="
-                        f"{self.result_arg}). But got prediction of type "
-                        f"{type(prediction)} instead of dict. Most probably, you"
-                        "need to set `result_arg=None` during VertexAIModelGarden "
-                        "initialization."
-                    )
-                    raise ValueError(error_desc)
-                else:
-                    raise ValueError(f"{self.result_arg} key not found in prediction!")
-
-        return prediction
-
-    async def _agenerate(
-        self,
-        prompts: List[str],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> LLMResult:
-        """Run the LLM on the given prompt and input."""
-        instances = self._prepare_request(prompts, **kwargs)
-        response = await self.async_client.predict(
-            endpoint=self.endpoint_path, instances=instances
-        )
-        return self._parse_response(response)
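For reference, a minimal usage sketch of the VertexAIModelGarden class deleted above, grounded in the fields its removed definition declares (endpoint_id, allowed_model_args, prompt_arg, result_arg, plus project/location inherited from _VertexAIBase). The project and endpoint values are placeholders, and the sketch assumes the class remains importable from the package root after this change; its new location is not shown in this diff.

# Hypothetical usage sketch, not part of this diff; placeholder values are marked.
from langchain_google_vertexai import VertexAIModelGarden  # assumed re-export

llm = VertexAIModelGarden(
    project="my-gcp-project",     # placeholder GCP project ID
    location="us-central1",       # region used to build the regional API endpoint
    endpoint_id="1234567890",     # placeholder Model Garden endpoint ID
    result_arg="generated_text",  # default; set to None if the model returns a bare string
)
print(llm.invoke("What is Vertex AI Model Garden?"))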