Skip to content

Commit

Permalink
extracting model response
Browse files Browse the repository at this point in the history
  • Loading branch information
Heiko Hotz committed Mar 18, 2024
1 parent 90bcf52 commit 7e7065e
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions libs/vertexai/langchain_google_vertexai/model_garden.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,28 @@

from langchain_google_vertexai._base import _BaseVertexAIModelGarden

import re


def extract_model_response(text, prompt):

Check failure on line 18 in libs/vertexai/langchain_google_vertexai/model_garden.py

View workflow job for this annotation

GitHub Actions / cd libs/vertexai / - / make lint #3.8

Ruff (I001)

langchain_google_vertexai/model_garden.py:1:1: I001 Import block is un-sorted or un-formatted

Check failure on line 18 in libs/vertexai/langchain_google_vertexai/model_garden.py

View workflow job for this annotation

GitHub Actions / cd libs/vertexai / - / make lint #3.11

Ruff (I001)

langchain_google_vertexai/model_garden.py:1:1: I001 Import block is un-sorted or un-formatted
# Remove the "Prompt:\n<prompt>\n" section from the start
prompt_section_pattern = re.compile(r'^Prompt:\n' + re.escape(prompt) + r'\n', re.DOTALL)

Check failure on line 20 in libs/vertexai/langchain_google_vertexai/model_garden.py

View workflow job for this annotation

GitHub Actions / cd libs/vertexai / - / make lint #3.8

Ruff (E501)

langchain_google_vertexai/model_garden.py:20:89: E501 Line too long (93 > 88)

Check failure on line 20 in libs/vertexai/langchain_google_vertexai/model_garden.py

View workflow job for this annotation

GitHub Actions / cd libs/vertexai / - / make lint #3.11

Ruff (E501)

langchain_google_vertexai/model_garden.py:20:89: E501 Line too long (93 > 88)
text_without_prompt_section = prompt_section_pattern.sub('', text, count=1)

# Define the output section starting pattern to look for
output_start_pattern = re.compile(r'^Output:\n', re.DOTALL)

# Check if the section immediately following "Output:\n" is the prompt
if re.match(output_start_pattern.pattern + re.escape(prompt), text_without_prompt_section):

Check failure on line 27 in libs/vertexai/langchain_google_vertexai/model_garden.py

View workflow job for this annotation

GitHub Actions / cd libs/vertexai / - / make lint #3.8

Ruff (E501)

langchain_google_vertexai/model_garden.py:27:89: E501 Line too long (95 > 88)

Check failure on line 27 in libs/vertexai/langchain_google_vertexai/model_garden.py

View workflow job for this annotation

GitHub Actions / cd libs/vertexai / - / make lint #3.11

Ruff (E501)

langchain_google_vertexai/model_garden.py:27:89: E501 Line too long (95 > 88)
# If the prompt is indeed repeated, remove "Output:\n<prompt>\n"
output_without_repeated_prompt = re.sub(output_start_pattern.pattern + re.escape(prompt), '', text_without_prompt_section, count=1)

Check failure on line 29 in libs/vertexai/langchain_google_vertexai/model_garden.py

View workflow job for this annotation

GitHub Actions / cd libs/vertexai / - / make lint #3.8

Ruff (E501)

langchain_google_vertexai/model_garden.py:29:89: E501 Line too long (139 > 88)

Check failure on line 29 in libs/vertexai/langchain_google_vertexai/model_garden.py

View workflow job for this annotation

GitHub Actions / cd libs/vertexai / - / make lint #3.11

Ruff (E501)

langchain_google_vertexai/model_garden.py:29:89: E501 Line too long (139 > 88)
else:
# If the prompt is not repeated, simply remove "Output:\n" to start extracting the model response

Check failure on line 31 in libs/vertexai/langchain_google_vertexai/model_garden.py

View workflow job for this annotation

GitHub Actions / cd libs/vertexai / - / make lint #3.8

Ruff (E501)

langchain_google_vertexai/model_garden.py:31:89: E501 Line too long (105 > 88)

Check failure on line 31 in libs/vertexai/langchain_google_vertexai/model_garden.py

View workflow job for this annotation

GitHub Actions / cd libs/vertexai / - / make lint #3.11

Ruff (E501)

langchain_google_vertexai/model_garden.py:31:89: E501 Line too long (105 > 88)
output_without_repeated_prompt = re.sub(output_start_pattern.pattern, '', text_without_prompt_section, count=1)

Check failure on line 32 in libs/vertexai/langchain_google_vertexai/model_garden.py

View workflow job for this annotation

GitHub Actions / cd libs/vertexai / - / make lint #3.8

Ruff (E501)

langchain_google_vertexai/model_garden.py:32:89: E501 Line too long (119 > 88)

Check failure on line 32 in libs/vertexai/langchain_google_vertexai/model_garden.py

View workflow job for this annotation

GitHub Actions / cd libs/vertexai / - / make lint #3.11

Ruff (E501)

langchain_google_vertexai/model_garden.py:32:89: E501 Line too long (119 > 88)

# Return the cleaned output section, which is the model response
return output_without_repeated_prompt.strip()


class VertexAIModelGarden(_BaseVertexAIModelGarden, BaseLLM):
"""Large language models served from Vertex AI Model Garden."""
Expand All @@ -38,6 +60,10 @@ def _generate(
)

response = self.client.predict(endpoint=self.endpoint_path, instances=instances)

if not kwargs.get("keep_original_response", False):
response.predictions[0] = extract_model_response(response.predictions[0], prompts[0])

Check failure on line 65 in libs/vertexai/langchain_google_vertexai/model_garden.py

View workflow job for this annotation

GitHub Actions / cd libs/vertexai / - / make lint #3.8

Ruff (E501)

langchain_google_vertexai/model_garden.py:65:89: E501 Line too long (97 > 88)

Check failure on line 65 in libs/vertexai/langchain_google_vertexai/model_garden.py

View workflow job for this annotation

GitHub Actions / cd libs/vertexai / - / make lint #3.11

Ruff (E501)

langchain_google_vertexai/model_garden.py:65:89: E501 Line too long (97 > 88)

return self._parse_response(response)

async def _agenerate(
Expand Down Expand Up @@ -69,4 +95,8 @@ async def _agenerate(
response = await self.async_client.predict(
endpoint=self.endpoint_path, instances=instances
)

if not kwargs.get("keep_original_response", False):
response.predictions[0] = extract_model_response(response.predictions[0], prompts[0])

Check failure on line 100 in libs/vertexai/langchain_google_vertexai/model_garden.py

View workflow job for this annotation

GitHub Actions / cd libs/vertexai / - / make lint #3.8

Ruff (E501)

langchain_google_vertexai/model_garden.py:100:89: E501 Line too long (97 > 88)

Check failure on line 100 in libs/vertexai/langchain_google_vertexai/model_garden.py

View workflow job for this annotation

GitHub Actions / cd libs/vertexai / - / make lint #3.11

Ruff (E501)

langchain_google_vertexai/model_garden.py:100:89: E501 Line too long (97 > 88)

return self._parse_response(response)

0 comments on commit 7e7065e

Please sign in to comment.