From 8be34db6a313edaae0710e24fcd57efdee276461 Mon Sep 17 00:00:00 2001 From: Leonid Kuligin Date: Tue, 27 Feb 2024 19:01:31 +0100 Subject: [PATCH] added post-processing for local gemma (#40) * added post-processing for local gemma * fixes after review --- .../langchain_google_vertexai/gemma.py | 55 ++++++++++++++++--- 1 file changed, 47 insertions(+), 8 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/gemma.py b/libs/vertexai/langchain_google_vertexai/gemma.py index c24cb0b1..4a453d9a 100644 --- a/libs/vertexai/langchain_google_vertexai/gemma.py +++ b/libs/vertexai/langchain_google_vertexai/gemma.py @@ -53,6 +53,19 @@ def gemma_messages_to_prompt(history: List[BaseMessage]) -> str: return "".join(messages) +def _parse_gemma_chat_response(response: str) -> str: + """Removes chat history from the response.""" + pattern = "model\n" + pos = response.rfind(pattern) + if pos == -1: + return response + text = response[(pos + len(pattern)) :] + pos = text.find("user\n") + if pos > 0: + return text[:pos] + return text + + class _GemmaBase(BaseModel): max_tokens: Optional[int] = None """The maximum number of tokens to generate.""" @@ -98,6 +111,9 @@ class GemmaChatVertexAIModelGarden(_GemmaBase, _BaseVertexAIModelGarden, BaseCha "top_k", "max_tokens", ] + parse_response: bool = False + """Whether to post-process the chat response and clean repeations """ + """or multi-turn statements.""" @property def _llm_type(self) -> str: @@ -120,6 +136,8 @@ def _generate( request["prompt"] = gemma_messages_to_prompt(messages) output = self.client.predict(endpoint=self.endpoint_path, instances=[request]) text = output.predictions[0] + if self.parse_response or kwargs.get("parse_response"): + text = _parse_gemma_chat_response(text) if stop: text = enforce_stop_tokens(text, stop) generations = [ @@ -143,6 +161,8 @@ async def _agenerate( endpoint=self.endpoint_path, instances=[request] ) text = output.predictions[0] + if self.parse_response or kwargs.get("parse_response"): + text = _parse_gemma_chat_response(text) if stop: text = enforce_stop_tokens(text, stop) generations = [ @@ -183,6 +203,11 @@ def _default_params(self) -> Dict[str, Any]: params = {"max_length": self.max_tokens} return {k: v for k, v in params.items() if v is not None} + def _get_params(self, **kwargs) -> Dict[str, Any]: + mapping = {"max_tokens": "max_length"} + params = {mapping[k]: v for k, v in kwargs.items() if k in mapping} + return {**self._default_params, **params} + class GemmaLocalKaggle(_GemmaLocalKaggleBase, BaseLLM): """Local gemma chat model loaded from Kaggle.""" @@ -195,7 +220,7 @@ def _generate( **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompt and input.""" - params = {"max_length": self.max_tokens} if self.max_tokens else {} + params = self._get_params(**kwargs) results = self.client.generate(prompts, **params) results = [results] if isinstance(results, str) else results if stop: @@ -209,6 +234,10 @@ def _llm_type(self) -> str: class GemmaChatLocalKaggle(_GemmaLocalKaggleBase, BaseChatModel): + parse_response: bool = False + """Whether to post-process the chat response and clean repeations """ + """or multi-turn statements.""" + def _generate( self, messages: List[BaseMessage], @@ -216,9 +245,11 @@ def _generate( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - params = {"max_length": self.max_tokens} if self.max_tokens else {} + params = self._get_params(**kwargs) prompt = gemma_messages_to_prompt(messages) text = self.client.generate(prompt, **params) + if self.parse_response or kwargs.get("parse_response"): + text = _parse_gemma_chat_response(text) if stop: text = enforce_stop_tokens(text, stop) generation = ChatGeneration(message=AIMessage(content=text)) @@ -268,9 +299,15 @@ def _default_params(self) -> Dict[str, Any]: params = {"max_length": self.max_tokens} return {k: v for k, v in params.items() if v is not None} + def _get_params(self, **kwargs) -> Dict[str, Any]: + mapping = {"max_tokens": "max_length"} + params = {mapping[k]: v for k, v in kwargs.items() if k in mapping} + return {**self._default_params, **params} + def _run(self, prompt: str, **kwargs: Any) -> str: inputs = self.tokenizer(prompt, return_tensors="pt") - generate_ids = self.client.generate(inputs.input_ids, **kwargs) + params = self._get_params(**kwargs) + generate_ids = self.client.generate(inputs.input_ids, **params) return self.tokenizer.batch_decode( generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False )[0] @@ -287,8 +324,7 @@ def _generate( **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompt and input.""" - params = {"max_length": self.max_tokens} if self.max_tokens else {} - results = [self._run(prompt, **params) for prompt in prompts] + results = [self._run(prompt, **kwargs) for prompt in prompts] if stop: results = [enforce_stop_tokens(text, stop) for text in results] return LLMResult(generations=[[Generation(text=text)] for text in results]) @@ -300,7 +336,9 @@ def _llm_type(self) -> str: class GemmaChatLocalHF(_GemmaLocalHFBase, BaseChatModel): - """Local gemma chat model loaded from HuggingFace.""" + parse_response: bool = False + """Whether to post-process the chat response and clean repeations """ + """or multi-turn statements.""" def _generate( self, @@ -309,9 +347,10 @@ def _generate( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - params = {"max_length": self.max_tokens} if self.max_tokens else {} prompt = gemma_messages_to_prompt(messages) - text = self._run(prompt, **params) + text = self._run(prompt, **kwargs) + if self.parse_response or kwargs.get("parse_response"): + text = _parse_gemma_chat_response(text) if stop: text = enforce_stop_tokens(text, stop) generation = ChatGeneration(message=AIMessage(content=text))