From 072559db01e12ec12d826275c2d18e0b3d45318c Mon Sep 17 00:00:00 2001 From: Leonid Kuligin Date: Tue, 20 Feb 2024 10:41:08 +0100 Subject: [PATCH] added cost callback --- .../langchain_google_vertexai/_utils.py | 27 +++++- .../langchain_google_vertexai/callbacks.py | 66 +++++++++++++++ .../langchain_google_vertexai/chat_models.py | 49 ++++++++--- .../langchain_google_vertexai/llms.py | 45 +++++++--- .../tests/integration_tests/test_callbacks.py | 84 +++++++++++++++++++ .../integration_tests/test_chat_models.py | 14 ++-- .../tests/integration_tests/test_llms.py | 9 ++ 7 files changed, 260 insertions(+), 34 deletions(-) create mode 100644 libs/vertexai/langchain_google_vertexai/callbacks.py create mode 100644 libs/vertexai/tests/integration_tests/test_callbacks.py diff --git a/libs/vertexai/langchain_google_vertexai/_utils.py b/libs/vertexai/langchain_google_vertexai/_utils.py index 1af44f0c..5b8c59a3 100644 --- a/libs/vertexai/langchain_google_vertexai/_utils.py +++ b/libs/vertexai/langchain_google_vertexai/_utils.py @@ -13,13 +13,13 @@ CallbackManagerForLLMRun, ) from langchain_core.language_models.llms import create_base_retry_decorator -from vertexai.generative_models._generative_models import ( # type: ignore[import-untyped] +from vertexai.generative_models import ( # type: ignore[import-untyped] Candidate, + Image, ) from vertexai.language_models import ( # type: ignore[import-untyped] TextGenerationResponse, ) -from vertexai.preview.generative_models import Image # type: ignore[import-untyped] def create_retry_decorator( @@ -102,6 +102,7 @@ def get_generation_info( is_gemini: bool, *, stream: bool = False, + usage_metadata: Optional[Dict] = None, ) -> Dict[str, Any]: if is_gemini: # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body @@ -121,11 +122,33 @@ def get_generation_info( else None ), } + if usage_metadata: + info["usage_metadata"] = usage_metadata # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body else: info = dataclasses.asdict(candidate) info.pop("text") info = {k: v for k, v in info.items() if not k.startswith("_")} + if usage_metadata: + info_usage_metadata = {} + output_usage = usage_metadata.get("tokenMetadata", {}).get( + "outputTokenCount", {} + ) + info_usage_metadata["candidates_billable_characters"] = output_usage.get( + "totalBillableCharacters" + ) + info_usage_metadata["candidates_token_count"] = output_usage.get( + "totalTokens" + ) + input_usage = usage_metadata.get("tokenMetadata", {}).get( + "inputTokenCount", {} + ) + info_usage_metadata["prompt_billable_characters"] = input_usage.get( + "totalBillableCharacters" + ) + info_usage_metadata["prompt_token_count"] = input_usage.get("totalTokens") + info["usage_metadata"] = {k: v for k, v in info_usage_metadata.items() if v} + if stream: # Remove non-streamable types, like bools. info.pop("is_blocked") diff --git a/libs/vertexai/langchain_google_vertexai/callbacks.py b/libs/vertexai/langchain_google_vertexai/callbacks.py new file mode 100644 index 00000000..37315eda --- /dev/null +++ b/libs/vertexai/langchain_google_vertexai/callbacks.py @@ -0,0 +1,66 @@ +import threading +from typing import Any, Dict, List + +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.outputs import LLMResult + + +class VertexAICallbackHandler(BaseCallbackHandler): + """Callback Handler that tracks VertexAI info.""" + + prompt_tokens: int = 0 + prompt_characters: int = 0 + completion_tokens: int = 0 + completion_characters: int = 0 + successful_requests: int = 0 + + def __init__(self) -> None: + super().__init__() + self._lock = threading.Lock() + + def __repr__(self) -> str: + return ( + f"\tPrompt tokens: {self.prompt_tokens}\n" + f"\tPrompt characters: {self.prompt_characters}\n" + f"\tCompletion tokens: {self.completion_tokens}\n" + f"\tCompletion characters: {self.completion_characters}\n" + f"Successful requests: {self.successful_requests}\n" + ) + + @property + def always_verbose(self) -> bool: + """Whether to call verbose callbacks even if verbose is False.""" + return True + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Runs when LLM starts running.""" + pass + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Runs on new LLM token. Only available when streaming is enabled.""" + pass + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Collects token usage.""" + completion_tokens, prompt_tokens = 0, 0 + completion_characters, prompt_characters = 0, 0 + for generations in response.generations: + if len(generations) > 0 and generations[0].generation_info: + usage_metadata = generations[0].generation_info.get( + "usage_metadata", {} + ) + completion_tokens += usage_metadata.get("candidates_token_count", 0) + prompt_tokens += usage_metadata.get("prompt_token_count", 0) + completion_characters += usage_metadata.get( + "candidates_billable_characters", 0 + ) + prompt_characters += usage_metadata.get("prompt_billable_characters", 0) + + with self._lock: + self.prompt_characters += prompt_characters + self.prompt_tokens += prompt_tokens + self.completion_characters += completion_characters + self.completion_tokens += completion_tokens + self.successful_requests += 1 diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py index 62613078..a33d0764 100644 --- a/libs/vertexai/langchain_google_vertexai/chat_models.py +++ b/libs/vertexai/langchain_google_vertexai/chat_models.py @@ -398,7 +398,11 @@ def _generate( generations = [ ChatGeneration( message=_parse_response_candidate(c), - generation_info=get_generation_info(c, self._is_gemini_model), + generation_info=get_generation_info( + c, + self._is_gemini_model, + usage_metadata=response.to_dict().get("usage_metadata"), + ), ) for c in response.candidates ] @@ -413,7 +417,11 @@ def _generate( generations = [ ChatGeneration( message=AIMessage(content=r.text), - generation_info=get_generation_info(r, self._is_gemini_model), + generation_info=get_generation_info( + r, + self._is_gemini_model, + usage_metadata=response.raw_prediction_response.metadata, + ), ) for r in response.candidates ] @@ -470,7 +478,11 @@ async def _agenerate( generations = [ ChatGeneration( message=_parse_response_candidate(c), - generation_info=get_generation_info(c, self._is_gemini_model), + generation_info=get_generation_info( + c, + self._is_gemini_model, + usage_metadata=response.to_dict().get("usage_metadata"), + ), ) for c in response.candidates ] @@ -485,7 +497,11 @@ async def _agenerate( generations = [ ChatGeneration( message=AIMessage(content=r.text), - generation_info=get_generation_info(r, self._is_gemini_model), + generation_info=get_generation_info( + r, + self._is_gemini_model, + usage_metadata=response.raw_prediction_response.metadata, + ), ) for r in response.candidates ] @@ -526,7 +542,12 @@ def _stream( message=AIMessageChunk( content=message.content, additional_kwargs=message.additional_kwargs, - ) + ), + generation_info=get_generation_info( + response.candidates[0], + self._is_gemini_model, + usage_metadata=response.to_dict().get("usage_metadata"), + ), ) else: question = _get_question(messages) @@ -536,13 +557,17 @@ def _stream( params["examples"] = _parse_examples(examples) chat = self._start_chat(history, **params) responses = chat.send_message_streaming(question.content, **params) - for response in responses: - if run_manager: - run_manager.on_llm_new_token(response.text) - yield ChatGenerationChunk( - message=AIMessageChunk(content=response.text), - generation_info=get_generation_info(response, self._is_gemini_model), - ) + for response in responses: + if run_manager: + run_manager.on_llm_new_token(response.text) + yield ChatGenerationChunk( + message=AIMessageChunk(content=response.text), + generation_info=get_generation_info( + response, + self._is_gemini_model, + usage_metadata=response.raw_prediction_response.metadata, + ), + ) def _start_chat( self, history: _ChatHistory, **kwargs: Any diff --git a/libs/vertexai/langchain_google_vertexai/llms.py b/libs/vertexai/langchain_google_vertexai/llms.py index 3844c649..8b6f1062 100644 --- a/libs/vertexai/langchain_google_vertexai/llms.py +++ b/libs/vertexai/langchain_google_vertexai/llms.py @@ -20,6 +20,7 @@ from langchain_core.outputs import Generation, GenerationChunk, LLMResult from langchain_core.pydantic_v1 import BaseModel, Field, root_validator from vertexai.generative_models import ( # type: ignore[import-untyped] + Candidate, GenerativeModel, Image, ) @@ -327,12 +328,19 @@ def validate_environment(cls, values: Dict) -> Dict: raise ValueError("Only one candidate can be generated with streaming!") return values - def _response_to_generation( - self, response: TextGenerationResponse, *, stream: bool = False + def _candidate_to_generation( + self, + response: Union[Candidate, TextGenerationResponse], + *, + stream: bool = False, + usage_metadata: Optional[Dict] = None, ) -> GenerationChunk: """Converts a stream response to a generation chunk.""" generation_info = get_generation_info( - response, self._is_gemini_model, stream=stream + response, + self._is_gemini_model, + stream=stream, + usage_metadata=usage_metadata, ) try: text = response.text @@ -373,8 +381,15 @@ def _generate( run_manager=run_manager, **params, ) + if self._is_gemini_model: + usage_metadata = res.to_dict().get("usage_metadata") + else: + usage_metadata = res.raw_prediction_response.metadata generations.append( - [self._response_to_generation(r) for r in res.candidates] + [ + self._candidate_to_generation(r, usage_metadata=usage_metadata) + for r in res.candidates + ] ) return LLMResult(generations=generations) @@ -395,8 +410,15 @@ async def _agenerate( run_manager=run_manager, **params, ) + if self._is_gemini_model: + usage_metadata = res.to_dict().get("usage_metadata") + else: + usage_metadata = res.raw_prediction_response.metadata generations.append( - [self._response_to_generation(r) for r in res.candidates] + [ + self._candidate_to_generation(r, usage_metadata=usage_metadata) + for r in res.candidates + ] ) return LLMResult(generations=generations) @@ -416,14 +438,13 @@ def _stream( run_manager=run_manager, **params, ): - # Gemini models return GenerationResponse even when streaming, which has a - # candidates field. - stream_resp = ( - stream_resp - if isinstance(stream_resp, TextGenerationResponse) - else stream_resp.candidates[0] + usage_metadata = None + if self._is_gemini_model: + usage_metadata = stream_resp.to_dict().get("usage_metadata") + stream_resp = stream_resp.candidates[0] + chunk = self._candidate_to_generation( + stream_resp, stream=True, usage_metadata=usage_metadata ) - chunk = self._response_to_generation(stream_resp, stream=True) yield chunk if run_manager: run_manager.on_llm_new_token( diff --git a/libs/vertexai/tests/integration_tests/test_callbacks.py b/libs/vertexai/tests/integration_tests/test_callbacks.py new file mode 100644 index 00000000..f0b9e84c --- /dev/null +++ b/libs/vertexai/tests/integration_tests/test_callbacks.py @@ -0,0 +1,84 @@ +import pytest +from langchain_core.messages import HumanMessage + +from langchain_google_vertexai.callbacks import VertexAICallbackHandler +from langchain_google_vertexai.chat_models import ChatVertexAI +from langchain_google_vertexai.llms import VertexAI + + +@pytest.mark.parametrize( + "model_name", + ["gemini-pro", "text-bison@001", "code-bison@001"], +) +def test_llm_invoke(model_name: str) -> None: + vb = VertexAICallbackHandler() + llm = VertexAI(model_name=model_name, temperature=0.0, callbacks=[vb]) + _ = llm.invoke("2+2") + assert vb.successful_requests == 1 + assert vb.prompt_tokens > 0 + assert vb.completion_tokens > 0 + prompt_tokens = vb.prompt_tokens + completion_tokens = vb.completion_tokens + _ = llm.invoke("2+2") + assert vb.successful_requests == 2 + assert vb.prompt_tokens > prompt_tokens + assert vb.completion_tokens > completion_tokens + + +@pytest.mark.parametrize( + "model_name", + ["gemini-pro", "chat-bison@001", "codechat-bison@001"], +) +def test_chat_call(model_name: str) -> None: + vb = VertexAICallbackHandler() + llm = ChatVertexAI(model_name=model_name, temperature=0.0, callbacks=[vb]) + message = HumanMessage(content="Hello") + _ = llm([message]) + assert vb.successful_requests == 1 + assert vb.prompt_tokens > 0 + assert vb.completion_tokens > 0 + prompt_tokens = vb.prompt_tokens + completion_tokens = vb.completion_tokens + _ = llm([message]) + assert vb.successful_requests == 2 + assert vb.prompt_tokens > prompt_tokens + assert vb.completion_tokens > completion_tokens + + +@pytest.mark.parametrize( + "model_name", + ["gemini-pro", "text-bison@001", "code-bison@001"], +) +def test_invoke_config(model_name: str) -> None: + vb = VertexAICallbackHandler() + llm = VertexAI(model_name=model_name, temperature=0.0) + llm.invoke("2+2", config={"callbacks": [vb]}) + assert vb.successful_requests == 1 + assert vb.prompt_tokens > 0 + assert vb.completion_tokens > 0 + prompt_tokens = vb.prompt_tokens + completion_tokens = vb.completion_tokens + llm.invoke("2+2", config={"callbacks": [vb]}) + assert vb.successful_requests == 2 + assert vb.prompt_tokens > prompt_tokens + assert vb.completion_tokens > completion_tokens + + +def test_llm_stream() -> None: + vb = VertexAICallbackHandler() + llm = VertexAI(model_name="gemini-pro", temperature=0.0, callbacks=[vb]) + for _ in llm.stream("2+2"): + pass + assert vb.successful_requests == 1 + assert vb.prompt_tokens > 0 + assert vb.completion_tokens > 0 + + +def test_chat_stream() -> None: + vb = VertexAICallbackHandler() + llm = ChatVertexAI(model_name="gemini-pro", temperature=0.0, callbacks=[vb]) + for _ in llm.stream("2+2"): + pass + assert vb.successful_requests == 1 + assert vb.completion_tokens > 0 + assert vb.completion_tokens > 0 diff --git a/libs/vertexai/tests/integration_tests/test_chat_models.py b/libs/vertexai/tests/integration_tests/test_chat_models.py index 53c391fe..c3721af3 100644 --- a/libs/vertexai/tests/integration_tests/test_chat_models.py +++ b/libs/vertexai/tests/integration_tests/test_chat_models.py @@ -68,14 +68,12 @@ async def test_vertexai_agenerate(model_name: str) -> None: sync_generation = cast(ChatGeneration, sync_response.generations[0][0]) async_generation = cast(ChatGeneration, response.generations[0][0]) - # assert some properties to make debugging easier - - # xfail: this is not equivalent with temp=0 right now - # assert sync_generation.message.content == async_generation.message.content - assert sync_generation.generation_info == async_generation.generation_info - - # xfail: content is not same right now - # assert sync_generation == async_generation + usage_metadata = sync_generation.generation_info["usage_metadata"] # type: ignore + assert int(usage_metadata["prompt_token_count"]) > 0 + assert int(usage_metadata["candidates_token_count"]) > 0 + usage_metadata = async_generation.generation_info["usage_metadata"] # type: ignore + assert int(usage_metadata["prompt_token_count"]) > 0 + assert int(usage_metadata["candidates_token_count"]) > 0 @pytest.mark.parametrize("model_name", ["chat-bison@001", "gemini-pro"]) diff --git a/libs/vertexai/tests/integration_tests/test_llms.py b/libs/vertexai/tests/integration_tests/test_llms.py index ae10d937..29f29db2 100644 --- a/libs/vertexai/tests/integration_tests/test_llms.py +++ b/libs/vertexai/tests/integration_tests/test_llms.py @@ -55,6 +55,9 @@ def test_vertex_generate(model_name: str) -> None: output = llm.generate(["Say foo:"]) assert isinstance(output, LLMResult) assert len(output.generations) == 1 + usage_metadata = output.generations[0][0].generation_info["usage_metadata"] # type: ignore + assert int(usage_metadata["prompt_token_count"]) == 3 + assert int(usage_metadata["candidates_token_count"]) > 0 @pytest.mark.xfail(reason="VertexAI doesn't always respect number of candidates") @@ -73,12 +76,18 @@ def test_vertex_generate_code() -> None: assert isinstance(output, LLMResult) assert len(output.generations) == 1 assert len(output.generations[0]) == 2 + usage_metadata = output.generations[0][0].generation_info["usage_metadata"] # type: ignore + assert int(usage_metadata["prompt_token_count"]) == 3 + assert int(usage_metadata["candidates_token_count"]) > 1 async def test_vertex_agenerate() -> None: llm = VertexAI(temperature=0) output = await llm.agenerate(["Please say foo:"]) assert isinstance(output, LLMResult) + usage_metadata = output.generations[0][0].generation_info["usage_metadata"] # type: ignore + assert int(usage_metadata["prompt_token_count"]) == 4 + assert int(usage_metadata["candidates_token_count"]) > 0 @pytest.mark.parametrize(