extracting model response #73

Closed
wants to merge 1 commit into from
30 changes: 30 additions & 0 deletions libs/vertexai/langchain_google_vertexai/model_garden.py
@@ -1,17 +1,39 @@
from __future__ import annotations

import asyncio
import re
from typing import Any, List, Optional

from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, LLMResult

from langchain_google_vertexai._base import _BaseVertexAIModelGarden


def extract_model_response(text: str, prompt: str) -> str:
    """Strip the echoed prompt and "Output:" header from a raw prediction."""
    # Remove the "Prompt:\n<prompt>\n" section from the start
    prompt_section_pattern = re.compile(
        r"^Prompt:\n" + re.escape(prompt) + r"\n", re.DOTALL
    )
    text_without_prompt_section = prompt_section_pattern.sub("", text, count=1)

    # Define the output section starting pattern to look for
    output_start_pattern = re.compile(r"^Output:\n", re.DOTALL)

    # Check if the section immediately following "Output:\n" is the prompt
    if re.match(
        output_start_pattern.pattern + re.escape(prompt),
        text_without_prompt_section,
    ):
        # If the prompt is indeed repeated, remove "Output:\n<prompt>\n"
        output_without_repeated_prompt = re.sub(
            output_start_pattern.pattern + re.escape(prompt),
            "",
            text_without_prompt_section,
            count=1,
        )
    else:
        # If the prompt is not repeated, remove only "Output:\n" before
        # extracting the model response
        output_without_repeated_prompt = re.sub(
            output_start_pattern.pattern, "", text_without_prompt_section, count=1
        )

    # Return the cleaned output section, which is the model response
    return output_without_repeated_prompt.strip()
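To illustrate the intended behavior, here is a standalone sketch of the same prompt-stripping logic (the helper is restated inline so the example runs without the package above; the sample response strings are illustrative, since the exact Model Garden echo format is an assumption):

```python
import re


def extract_model_response(text: str, prompt: str) -> str:
    """Strip an echoed "Prompt:" section and the "Output:" header."""
    # Drop a leading "Prompt:\n<prompt>\n" block, if present
    text = re.sub(r"^Prompt:\n" + re.escape(prompt) + r"\n", "", text, count=1)
    # If "Output:\n" is immediately followed by the prompt again, drop both;
    # otherwise drop only the "Output:\n" header
    if re.match(r"^Output:\n" + re.escape(prompt), text):
        text = re.sub(r"^Output:\n" + re.escape(prompt), "", text, count=1)
    else:
        text = re.sub(r"^Output:\n", "", text, count=1)
    return text.strip()


# Response that echoes the prompt twice (after "Prompt:" and after "Output:")
raw = "Prompt:\nWhat is 2+2?\nOutput:\nWhat is 2+2?\n4"
print(extract_model_response(raw, "What is 2+2?"))  # prints "4"

# Response that echoes the prompt only once
raw = "Prompt:\nWhat is 2+2?\nOutput:\n4"
print(extract_model_response(raw, "What is 2+2?"))  # prints "4"
```

Note that `re.escape` is what makes this safe for prompts containing regex metacharacters such as `?` or `+`.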


class VertexAIModelGarden(_BaseVertexAIModelGarden, BaseLLM):
"""Large language models served from Vertex AI Model Garden."""
@@ -38,6 +60,10 @@
        )

        response = self.client.predict(endpoint=self.endpoint_path, instances=instances)

        if not kwargs.get("keep_original_response", False):
            response.predictions[0] = extract_model_response(
                response.predictions[0], prompts[0]
            )

        return self._parse_response(response)
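The gating logic added to `_generate` can be sketched with a stubbed response object (`FakeResponse`, `postprocess`, and the simplified split-based cleanup are illustrative stand-ins, not part of the library; the real code calls `extract_model_response`):

```python
class FakeResponse:
    """Stand-in for a Vertex AI prediction response (illustrative only)."""

    def __init__(self, predictions):
        self.predictions = predictions


def postprocess(response, prompts, **kwargs):
    # Mirrors the new block: strip the echoed prompt unless the caller
    # opts out with keep_original_response=True. Simplified cleanup here.
    if not kwargs.get("keep_original_response", False):
        response.predictions[0] = response.predictions[0].split("Output:\n", 1)[-1]
    return response


resp = FakeResponse(["Prompt:\nHi\nOutput:\nHello!"])
postprocess(resp, ["Hi"])
print(resp.predictions[0])  # prints "Hello!"
```

One consequence of this design worth noting: only `predictions[0]` is rewritten, so callers sending multiple prompts in one batch would receive the remaining predictions uncleaned unless they pass `keep_original_response=True` and post-process themselves.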

async def _agenerate(
@@ -69,4 +95,8 @@
        response = await self.async_client.predict(
            endpoint=self.endpoint_path, instances=instances
        )

        if not kwargs.get("keep_original_response", False):
            response.predictions[0] = extract_model_response(
                response.predictions[0], prompts[0]
            )

        return self._parse_response(response)