diff --git a/libs/vertexai/langchain_google_vertexai/gemma.py b/libs/vertexai/langchain_google_vertexai/gemma.py
index c24cb0b12..ee04922c4 100644
--- a/libs/vertexai/langchain_google_vertexai/gemma.py
+++ b/libs/vertexai/langchain_google_vertexai/gemma.py
@@ -53,6 +53,21 @@ def gemma_messages_to_prompt(history: List[BaseMessage]) -> str:
     return "".join(messages)
 
 
+def _parse_gemma_chat_response(response: str) -> str:
+    """Removes chat history from the response."""
+    pattern = "model\n"
+    pos = response.rfind(pattern)
+    if pos == -1:
+        return response
+    else:
+        text = response[(pos + len(pattern)) :]
+        pos = text.find("user\n")
+        if pos > 0:
+            return text[:pos]
+        else:
+            return text
+
+
 class _GemmaBase(BaseModel):
     max_tokens: Optional[int] = None
     """The maximum number of tokens to generate."""
@@ -98,6 +113,9 @@ class GemmaChatVertexAIModelGarden(_GemmaBase, _BaseVertexAIModelGarden, BaseCha
         "top_k",
         "max_tokens",
     ]
+    parse_response: bool = False
+    """Whether to post-process the chat response and clean repetitions
+    or multi-turn statements."""
 
     @property
     def _llm_type(self) -> str:
@@ -120,6 +138,8 @@ def _generate(
         request["prompt"] = gemma_messages_to_prompt(messages)
         output = self.client.predict(endpoint=self.endpoint_path, instances=[request])
         text = output.predictions[0]
+        if self.parse_response or kwargs.get("parse_response"):
+            text = _parse_gemma_chat_response(text)
         if stop:
             text = enforce_stop_tokens(text, stop)
         generations = [
@@ -142,7 +162,9 @@ async def _agenerate(
         output = await self.async_client.predict(
             endpoint=self.endpoint_path, instances=[request]
         )
         text = output.predictions[0]
+        if self.parse_response or kwargs.get("parse_response"):
+            text = _parse_gemma_chat_response(text)
         if stop:
             text = enforce_stop_tokens(text, stop)
         generations = [
@@ -183,6 +205,11 @@ def _default_params(self) -> Dict[str, Any]:
         params = {"max_length": self.max_tokens}
         return {k: v for k, v in params.items() if v is not None}
 
+    def _get_params(self, **kwargs) -> Dict[str, Any]:
+        mapping = {"max_tokens": "max_length"}
+        params = {mapping[k]: v for k, v in kwargs.items() if k in mapping}
+        return {**self._default_params, **params}
+
 
 class GemmaLocalKaggle(_GemmaLocalKaggleBase, BaseLLM):
     """Local gemma chat model loaded from Kaggle."""
@@ -195,7 +222,7 @@ def _generate(
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
-        params = {"max_length": self.max_tokens} if self.max_tokens else {}
+        params = self._get_params(**kwargs)
         results = self.client.generate(prompts, **params)
         results = [results] if isinstance(results, str) else results
         if stop:
@@ -209,6 +236,10 @@ def _llm_type(self) -> str:
 
 
 class GemmaChatLocalKaggle(_GemmaLocalKaggleBase, BaseChatModel):
+    parse_response: bool = False
+    """Whether to post-process the chat response and clean repetitions
+    or multi-turn statements."""
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -216,9 +247,11 @@ def _generate(
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params = {"max_length": self.max_tokens} if self.max_tokens else {}
+        params = self._get_params(**kwargs)
         prompt = gemma_messages_to_prompt(messages)
         text = self.client.generate(prompt, **params)
+        if self.parse_response or kwargs.get("parse_response"):
+            text = _parse_gemma_chat_response(text)
         if stop:
             text = enforce_stop_tokens(text, stop)
         generation = ChatGeneration(message=AIMessage(content=text))
@@ -268,9 +301,15 @@ def _default_params(self) -> Dict[str, Any]:
         params = {"max_length": self.max_tokens}
         return {k: v for k, v in params.items() if v is not None}
 
+    def _get_params(self, **kwargs) -> Dict[str, Any]:
+        mapping = {"max_tokens": "max_length"}
+        params = {mapping[k]: v for k, v in kwargs.items() if k in mapping}
+        return {**self._default_params, **params}
+
     def _run(self, prompt: str, **kwargs: Any) -> str:
         inputs = self.tokenizer(prompt, return_tensors="pt")
-        generate_ids = self.client.generate(inputs.input_ids, **kwargs)
+        params = self._get_params(**kwargs)
+        generate_ids = self.client.generate(inputs.input_ids, **params)
         return self.tokenizer.batch_decode(
             generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
         )[0]
@@ -287,8 +326,7 @@ def _generate(
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
-        params = {"max_length": self.max_tokens} if self.max_tokens else {}
-        results = [self._run(prompt, **params) for prompt in prompts]
+        results = [self._run(prompt, **kwargs) for prompt in prompts]
        if stop:
             results = [enforce_stop_tokens(text, stop) for text in results]
         return LLMResult(generations=[[Generation(text=text)] for text in results])
@@ -300,7 +338,9 @@ def _llm_type(self) -> str:
 
 
 class GemmaChatLocalHF(_GemmaLocalHFBase, BaseChatModel):
-    """Local gemma chat model loaded from HuggingFace."""
+    parse_response: bool = False
+    """Whether to post-process the chat response and clean repetitions
+    or multi-turn statements."""
 
     def _generate(
         self,
@@ -309,9 +349,10 @@ def _generate(
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params = {"max_length": self.max_tokens} if self.max_tokens else {}
         prompt = gemma_messages_to_prompt(messages)
-        text = self._run(prompt, **params)
+        text = self._run(prompt, **kwargs)
+        if self.parse_response or kwargs.get("parse_response"):
+            text = _parse_gemma_chat_response(text)
         if stop:
             text = enforce_stop_tokens(text, stop)
         generation = ChatGeneration(message=AIMessage(content=text))
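
For reviewers, a minimal usage sketch (not part of the patch) of how the new `parse_response` flag is meant to be exercised; the Kaggle model handle and prompt text are illustrative assumptions, not taken from the diff:

```python
# Illustrative sketch only; assumes local access to a Keras/Kaggle Gemma checkpoint.
from langchain_core.messages import HumanMessage
from langchain_google_vertexai import GemmaChatLocalKaggle

# `model_name` is an assumed Kaggle handle; substitute the one you have access to.
llm = GemmaChatLocalKaggle(model_name="gemma_2b_en", parse_response=True)

message = HumanMessage(content="Hi! Who are you?")

# With parse_response=True, the echoed conversation history up to the last
# "model\n" marker is stripped, so `answer.content` holds only the final turn.
answer = llm.invoke([message], max_tokens=30)
print(answer.content)

# The flag can also be enabled per call via kwargs instead of the constructor:
answer = llm.invoke([message], max_tokens=30, parse_response=True)
```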