From 7e7065e7bbe5a934db82085980f436fca4ebc80d Mon Sep 17 00:00:00 2001 From: Heiko Hotz Date: Mon, 18 Mar 2024 19:28:46 +0000 Subject: [PATCH] extracting model response --- .../langchain_google_vertexai/model_garden.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/libs/vertexai/langchain_google_vertexai/model_garden.py b/libs/vertexai/langchain_google_vertexai/model_garden.py index 6469fb42..9114ef5c 100644 --- a/libs/vertexai/langchain_google_vertexai/model_garden.py +++ b/libs/vertexai/langchain_google_vertexai/model_garden.py @@ -12,6 +12,28 @@ from langchain_google_vertexai._base import _BaseVertexAIModelGarden +import re + + +def extract_model_response(text, prompt): + # Remove the "Prompt:\n\n" section from the start + prompt_section_pattern = re.compile(r'^Prompt:\n' + re.escape(prompt) + r'\n', re.DOTALL) + text_without_prompt_section = prompt_section_pattern.sub('', text, count=1) + + # Define the output section starting pattern to look for + output_start_pattern = re.compile(r'^Output:\n', re.DOTALL) + + # Check if the section immediately following "Output:\n" is the prompt + if re.match(output_start_pattern.pattern + re.escape(prompt), text_without_prompt_section): + # If the prompt is indeed repeated, remove "Output:\n\n" + output_without_repeated_prompt = re.sub(output_start_pattern.pattern + re.escape(prompt), '', text_without_prompt_section, count=1) + else: + # If the prompt is not repeated, simply remove "Output:\n" to start extracting the model response + output_without_repeated_prompt = re.sub(output_start_pattern.pattern, '', text_without_prompt_section, count=1) + + # Return the cleaned output section, which is the model response + return output_without_repeated_prompt.strip() + class VertexAIModelGarden(_BaseVertexAIModelGarden, BaseLLM): """Large language models served from Vertex AI Model Garden.""" @@ -38,6 +60,10 @@ def _generate( ) response = self.client.predict(endpoint=self.endpoint_path, instances=instances) + + if not kwargs.get("keep_original_response", False): + response.predictions[0] = extract_model_response(response.predictions[0], prompts[0]) + return self._parse_response(response) async def _agenerate( @@ -69,4 +95,8 @@ async def _agenerate( response = await self.async_client.predict( endpoint=self.endpoint_path, instances=instances ) + + if not kwargs.get("keep_original_response", False): + response.predictions[0] = extract_model_response(response.predictions[0], prompts[0]) + return self._parse_response(response)