Commit 3b9e070
removed unused class since VertexAIModelGarden has been moved
lkuligin committed Feb 28, 2024
1 parent 6af9e3b commit 3b9e070
Showing 1 changed file with 0 additions and 131 deletions.
131 changes: 0 additions & 131 deletions libs/vertexai/langchain_google_vertexai/llms.py
@@ -4,14 +4,6 @@
from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, List, Optional, Union

import vertexai  # type: ignore[import-untyped]
from google.api_core.client_options import ClientOptions
from google.cloud.aiplatform.gapic import (
    PredictionServiceAsyncClient,
    PredictionServiceClient,
)
from google.cloud.aiplatform.models import Prediction
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value
from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
@@ -53,7 +45,6 @@
from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory
from langchain_google_vertexai._utils import (
    create_retry_decorator,
    get_client_info,
    get_generation_info,
    is_codey_model,
    is_gemini_model,
@@ -487,125 +478,3 @@ async def _astream(
                await run_manager.on_llm_new_token(
                    chunk.text, chunk=chunk, verbose=self.verbose
                )


class VertexAIModelGarden(_VertexAIBase, BaseLLM):
    """Large language models served from Vertex AI Model Garden."""

    client: Any = None  #: :meta private:
    async_client: Any = None  #: :meta private:
    endpoint_id: str
    "The name of the endpoint where the model has been deployed."
    allowed_model_args: Optional[List[str]] = None
    "Allowed optional args to be passed to the model."
    prompt_arg: str = "prompt"
    result_arg: Optional[str] = "generated_text"
    "Set result_arg to None if the output of the model is expected to be a string."
    "Otherwise, if it's a dict, provide the name of the key that contains the result."

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python package exists in the environment."""

        if not values["project"]:
            raise ValueError(
                "A GCP project should be provided to run inference on Model Garden!"
            )

        client_options = ClientOptions(
            api_endpoint=f"{values['location']}-aiplatform.googleapis.com"
        )
        client_info = get_client_info(module="vertex-ai-model-garden")
        values["client"] = PredictionServiceClient(
            client_options=client_options, client_info=client_info
        )
        values["async_client"] = PredictionServiceAsyncClient(
            client_options=client_options, client_info=client_info
        )
        return values

    @property
    def endpoint_path(self) -> str:
        return self.client.endpoint_path(
            project=self.project, location=self.location, endpoint=self.endpoint_id
        )

    @property
    def _llm_type(self) -> str:
        return "vertexai_model_garden"

    def _prepare_request(self, prompts: List[str], **kwargs: Any) -> List["Value"]:
        instances = []
        for prompt in prompts:
            if self.allowed_model_args:
                instance = {
                    k: v for k, v in kwargs.items() if k in self.allowed_model_args
                }
            else:
                instance = {}
            instance[self.prompt_arg] = prompt
            instances.append(instance)

        predict_instances = [
            json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
        ]
        return predict_instances

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        instances = self._prepare_request(prompts, **kwargs)
        response = self.client.predict(endpoint=self.endpoint_path, instances=instances)
        return self._parse_response(response)

    def _parse_response(self, predictions: "Prediction") -> LLMResult:
        generations: List[List[Generation]] = []
        for result in predictions.predictions:
            generations.append(
                [
                    Generation(text=self._parse_prediction(prediction))
                    for prediction in result
                ]
            )
        return LLMResult(generations=generations)

    def _parse_prediction(self, prediction: Any) -> str:
        if isinstance(prediction, str):
            return prediction

        if self.result_arg:
            try:
                return prediction[self.result_arg]
            except KeyError:
                if isinstance(prediction, str):
                    error_desc = (
                        "Provided non-None `result_arg` (result_arg="
                        f"{self.result_arg}). But got prediction of type "
                        f"{type(prediction)} instead of dict. Most probably, you "
                        "need to set `result_arg=None` during VertexAIModelGarden "
                        "initialization."
                    )
                    raise ValueError(error_desc)
                else:
                    raise ValueError(f"{self.result_arg} key not found in prediction!")

        return prediction

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        instances = self._prepare_request(prompts, **kwargs)
        response = await self.async_client.predict(
            endpoint=self.endpoint_path, instances=instances
        )
        return self._parse_response(response)
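
For anyone updating imports after this change: the class was moved within the package, not deleted from it, so it should still be importable from the package root. A minimal usage sketch, assuming the top-level export is unchanged (the project, location, and endpoint ID below are made-up placeholders, not values from this repo):

# Sketch only: assumes VertexAIModelGarden is still exported from the
# package root after the move; project/location/endpoint values are made up.
from langchain_google_vertexai import VertexAIModelGarden

llm = VertexAIModelGarden(
    project="my-gcp-project",    # hypothetical GCP project ID
    location="us-central1",      # region where the endpoint is deployed
    endpoint_id="1234567890",    # hypothetical Model Garden endpoint ID
)
print(llm.invoke("What is Vertex AI Model Garden?"))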

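The least obvious step in the deleted code is _prepare_request: each prompt becomes a plain dict (plus any kwargs whitelisted by allowed_model_args), and json_format.ParseDict wraps each dict into a protobuf Value, the instance type the prediction service expects. A standalone sketch of just that conversion (the payload keys here are illustrative, not a fixed schema):

# Sketch of the dict -> protobuf Value conversion used by _prepare_request;
# "prompt" and "max_tokens" are example keys, not a required schema.
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

instance_dict = {"prompt": "Hello, Model Garden!", "max_tokens": 128}
instance = json_format.ParseDict(instance_dict, Value())
print(type(instance))  # <class 'google.protobuf.struct_pb2.Value'>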