Merge pull request #18 from langchain-ai/cost_cb
added cost callback
lkuligin authored Feb 21, 2024
2 parents baab1e3 + 3e5da91 commit 302f264
Showing 7 changed files with 264 additions and 38 deletions.
27 changes: 25 additions & 2 deletions libs/vertexai/langchain_google_vertexai/_utils.py
@@ -13,13 +13,13 @@
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import create_base_retry_decorator
from vertexai.generative_models._generative_models import ( # type: ignore[import-untyped]
from vertexai.generative_models import ( # type: ignore[import-untyped]
Candidate,
Image,
)
from vertexai.language_models import ( # type: ignore[import-untyped]
TextGenerationResponse,
)
from vertexai.preview.generative_models import Image # type: ignore[import-untyped]


def create_retry_decorator(
@@ -102,6 +102,7 @@ def get_generation_info(
is_gemini: bool,
*,
stream: bool = False,
usage_metadata: Optional[Dict] = None,
) -> Dict[str, Any]:
if is_gemini:
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body
@@ -121,11 +122,33 @@
else None
),
}
if usage_metadata:
info["usage_metadata"] = usage_metadata
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body
else:
info = dataclasses.asdict(candidate)
info.pop("text")
info = {k: v for k, v in info.items() if not k.startswith("_")}
if usage_metadata:
info_usage_metadata = {}
output_usage = usage_metadata.get("tokenMetadata", {}).get(
"outputTokenCount", {}
)
info_usage_metadata["candidates_billable_characters"] = output_usage.get(
"totalBillableCharacters"
)
info_usage_metadata["candidates_token_count"] = output_usage.get(
"totalTokens"
)
input_usage = usage_metadata.get("tokenMetadata", {}).get(
"inputTokenCount", {}
)
info_usage_metadata["prompt_billable_characters"] = input_usage.get(
"totalBillableCharacters"
)
info_usage_metadata["prompt_token_count"] = input_usage.get("totalTokens")
info["usage_metadata"] = {k: v for k, v in info_usage_metadata.items() if v}

if stream:
# Remove non-streamable types, like bools.
info.pop("is_blocked")
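For context, the non-Gemini branch above flattens the nested token metadata returned by the text/chat models into flat keys. A minimal, self-contained sketch of that flattening, with the nested key names taken from the hunk above and the numeric values purely illustrative:

# Illustrative values only: example of the PaLM-style metadata shape consumed above.
usage_metadata = {
    "tokenMetadata": {
        "inputTokenCount": {"totalTokens": 7, "totalBillableCharacters": 23},
        "outputTokenCount": {"totalTokens": 3, "totalBillableCharacters": 11},
    }
}

output_usage = usage_metadata.get("tokenMetadata", {}).get("outputTokenCount", {})
input_usage = usage_metadata.get("tokenMetadata", {}).get("inputTokenCount", {})
info_usage_metadata = {
    "candidates_billable_characters": output_usage.get("totalBillableCharacters"),
    "candidates_token_count": output_usage.get("totalTokens"),
    "prompt_billable_characters": input_usage.get("totalBillableCharacters"),
    "prompt_token_count": input_usage.get("totalTokens"),
}
# Falsy entries are dropped, mirroring the `if v` filter in the diff.
print({k: v for k, v in info_usage_metadata.items() if v})
# {'candidates_billable_characters': 11, 'candidates_token_count': 3,
#  'prompt_billable_characters': 23, 'prompt_token_count': 7}
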
66 changes: 66 additions & 0 deletions libs/vertexai/langchain_google_vertexai/callbacks.py
@@ -0,0 +1,66 @@
import threading
from typing import Any, Dict, List

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult


class VertexAICallbackHandler(BaseCallbackHandler):
"""Callback Handler that tracks VertexAI info."""

prompt_tokens: int = 0
prompt_characters: int = 0
completion_tokens: int = 0
completion_characters: int = 0
successful_requests: int = 0

def __init__(self) -> None:
super().__init__()
self._lock = threading.Lock()

def __repr__(self) -> str:
return (
f"\tPrompt tokens: {self.prompt_tokens}\n"
f"\tPrompt characters: {self.prompt_characters}\n"
f"\tCompletion tokens: {self.completion_tokens}\n"
f"\tCompletion characters: {self.completion_characters}\n"
f"Successful requests: {self.successful_requests}\n"
)

@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return True

def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Runs when LLM starts running."""
pass

def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Runs on new LLM token. Only available when streaming is enabled."""
pass

def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Collects token usage."""
completion_tokens, prompt_tokens = 0, 0
completion_characters, prompt_characters = 0, 0
for generations in response.generations:
if len(generations) > 0 and generations[0].generation_info:
usage_metadata = generations[0].generation_info.get(
"usage_metadata", {}
)
completion_tokens += usage_metadata.get("candidates_token_count", 0)
prompt_tokens += usage_metadata.get("prompt_token_count", 0)
completion_characters += usage_metadata.get(
"candidates_billable_characters", 0
)
prompt_characters += usage_metadata.get("prompt_billable_characters", 0)

with self._lock:
self.prompt_characters += prompt_characters
self.prompt_tokens += prompt_tokens
self.completion_characters += completion_characters
self.completion_tokens += completion_tokens
self.successful_requests += 1
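
A rough usage sketch for the new handler, mirroring the imports used in the integration tests below; it assumes Vertex AI credentials and a project are already configured, and the model name is only an example:

from langchain_google_vertexai.callbacks import VertexAICallbackHandler
from langchain_google_vertexai.llms import VertexAI

handler = VertexAICallbackHandler()
llm = VertexAI(model_name="gemini-pro", temperature=0.0, callbacks=[handler])

llm.invoke("What is 2 + 2?")
llm.invoke("And 3 + 3?")

# Counters accumulate across successful requests.
print(handler)
print(handler.prompt_tokens, handler.completion_tokens, handler.successful_requests)
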
57 changes: 41 additions & 16 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -397,10 +397,14 @@ def _generate(
)
generations = [
ChatGeneration(
message=_parse_response_candidate(c),
generation_info=get_generation_info(c, self._is_gemini_model),
message=_parse_response_candidate(candidate),
generation_info=get_generation_info(
candidate,
self._is_gemini_model,
usage_metadata=response.to_dict().get("usage_metadata"),
),
)
for c in response.candidates
for candidate in response.candidates
]
else:
question = _get_question(messages)
@@ -412,10 +416,14 @@
response = chat.send_message(question.content, **msg_params)
generations = [
ChatGeneration(
message=AIMessage(content=r.text),
generation_info=get_generation_info(r, self._is_gemini_model),
message=AIMessage(content=candidate.text),
generation_info=get_generation_info(
candidate,
self._is_gemini_model,
usage_metadata=response.raw_prediction_response.metadata,
),
)
for r in response.candidates
for candidate in response.candidates
]
return ChatResult(generations=generations)

@@ -470,7 +478,11 @@ async def _agenerate(
generations = [
ChatGeneration(
message=_parse_response_candidate(c),
generation_info=get_generation_info(c, self._is_gemini_model),
generation_info=get_generation_info(
c,
self._is_gemini_model,
usage_metadata=response.to_dict().get("usage_metadata"),
),
)
for c in response.candidates
]
@@ -485,7 +497,11 @@
generations = [
ChatGeneration(
message=AIMessage(content=r.text),
generation_info=get_generation_info(r, self._is_gemini_model),
generation_info=get_generation_info(
r,
self._is_gemini_model,
usage_metadata=response.raw_prediction_response.metadata,
),
)
for r in response.candidates
]
@@ -526,7 +542,12 @@ def _stream(
message=AIMessageChunk(
content=message.content,
additional_kwargs=message.additional_kwargs,
)
),
generation_info=get_generation_info(
response.candidates[0],
self._is_gemini_model,
usage_metadata=response.to_dict().get("usage_metadata"),
),
)
else:
question = _get_question(messages)
@@ -536,13 +557,17 @@
params["examples"] = _parse_examples(examples)
chat = self._start_chat(history, **params)
responses = chat.send_message_streaming(question.content, **params)
for response in responses:
if run_manager:
run_manager.on_llm_new_token(response.text)
yield ChatGenerationChunk(
message=AIMessageChunk(content=response.text),
generation_info=get_generation_info(response, self._is_gemini_model),
)
for response in responses:
if run_manager:
run_manager.on_llm_new_token(response.text)
yield ChatGenerationChunk(
message=AIMessageChunk(content=response.text),
generation_info=get_generation_info(
response,
self._is_gemini_model,
usage_metadata=response.raw_prediction_response.metadata,
),
)

def _start_chat(
self, history: _ChatHistory, **kwargs: Any
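With these changes each chat generation's generation_info carries the usage metadata in addition to the existing safety and citation fields. A hedged sketch of inspecting it (again assuming configured credentials and an example model name):

from langchain_core.messages import HumanMessage
from langchain_google_vertexai.chat_models import ChatVertexAI

chat = ChatVertexAI(model_name="gemini-pro", temperature=0.0)
result = chat.generate([[HumanMessage(content="Hello")]])

# Each ChatGeneration's generation_info should now include a
# "usage_metadata" entry alongside the existing fields.
info = result.generations[0][0].generation_info or {}
print(info.get("usage_metadata"))
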
45 changes: 33 additions & 12 deletions libs/vertexai/langchain_google_vertexai/llms.py
@@ -20,6 +20,7 @@
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from vertexai.generative_models import ( # type: ignore[import-untyped]
Candidate,
GenerativeModel,
Image,
)
@@ -327,12 +328,19 @@ def validate_environment(cls, values: Dict) -> Dict:
raise ValueError("Only one candidate can be generated with streaming!")
return values

def _response_to_generation(
self, response: TextGenerationResponse, *, stream: bool = False
def _candidate_to_generation(
self,
response: Union[Candidate, TextGenerationResponse],
*,
stream: bool = False,
usage_metadata: Optional[Dict] = None,
) -> GenerationChunk:
"""Converts a stream response to a generation chunk."""
generation_info = get_generation_info(
response, self._is_gemini_model, stream=stream
response,
self._is_gemini_model,
stream=stream,
usage_metadata=usage_metadata,
)
try:
text = response.text
@@ -373,8 +381,15 @@ def _generate(
run_manager=run_manager,
**params,
)
if self._is_gemini_model:
usage_metadata = res.to_dict().get("usage_metadata")
else:
usage_metadata = res.raw_prediction_response.metadata
generations.append(
[self._response_to_generation(r) for r in res.candidates]
[
self._candidate_to_generation(r, usage_metadata=usage_metadata)
for r in res.candidates
]
)
return LLMResult(generations=generations)

@@ -395,8 +410,15 @@ async def _agenerate(
run_manager=run_manager,
**params,
)
if self._is_gemini_model:
usage_metadata = res.to_dict().get("usage_metadata")
else:
usage_metadata = res.raw_prediction_response.metadata
generations.append(
[self._response_to_generation(r) for r in res.candidates]
[
self._candidate_to_generation(r, usage_metadata=usage_metadata)
for r in res.candidates
]
)
return LLMResult(generations=generations)

@@ -416,14 +438,13 @@ def _stream(
run_manager=run_manager,
**params,
):
# Gemini models return GenerationResponse even when streaming, which has a
# candidates field.
stream_resp = (
stream_resp
if isinstance(stream_resp, TextGenerationResponse)
else stream_resp.candidates[0]
usage_metadata = None
if self._is_gemini_model:
usage_metadata = stream_resp.to_dict().get("usage_metadata")
stream_resp = stream_resp.candidates[0]
chunk = self._candidate_to_generation(
stream_resp, stream=True, usage_metadata=usage_metadata
)
chunk = self._response_to_generation(stream_resp, stream=True)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
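The completion model follows the same pattern: usage metadata is read from response.to_dict() for Gemini models and from raw_prediction_response.metadata otherwise, then attached to each candidate's generation_info. A small sketch of checking it (same assumptions as above):

from langchain_google_vertexai.llms import VertexAI

llm = VertexAI(model_name="gemini-pro", temperature=0.0)
result = llm.generate(["What is 2 + 2?"])

# usage_metadata is now attached to each candidate's generation_info
# for both Gemini and PaLM-style models.
info = result.generations[0][0].generation_info or {}
print(info.get("usage_metadata"))
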
84 changes: 84 additions & 0 deletions libs/vertexai/tests/integration_tests/test_callbacks.py
@@ -0,0 +1,84 @@
import pytest
from langchain_core.messages import HumanMessage

from langchain_google_vertexai.callbacks import VertexAICallbackHandler
from langchain_google_vertexai.chat_models import ChatVertexAI
from langchain_google_vertexai.llms import VertexAI


@pytest.mark.parametrize(
"model_name",
["gemini-pro", "text-bison@001", "code-bison@001"],
)
def test_llm_invoke(model_name: str) -> None:
vb = VertexAICallbackHandler()
llm = VertexAI(model_name=model_name, temperature=0.0, callbacks=[vb])
_ = llm.invoke("2+2")
assert vb.successful_requests == 1
assert vb.prompt_tokens > 0
assert vb.completion_tokens > 0
prompt_tokens = vb.prompt_tokens
completion_tokens = vb.completion_tokens
_ = llm.invoke("2+2")
assert vb.successful_requests == 2
assert vb.prompt_tokens > prompt_tokens
assert vb.completion_tokens > completion_tokens


@pytest.mark.parametrize(
"model_name",
["gemini-pro", "chat-bison@001", "codechat-bison@001"],
)
def test_chat_call(model_name: str) -> None:
vb = VertexAICallbackHandler()
llm = ChatVertexAI(model_name=model_name, temperature=0.0, callbacks=[vb])
message = HumanMessage(content="Hello")
_ = llm([message])
assert vb.successful_requests == 1
assert vb.prompt_tokens > 0
assert vb.completion_tokens > 0
prompt_tokens = vb.prompt_tokens
completion_tokens = vb.completion_tokens
_ = llm([message])
assert vb.successful_requests == 2
assert vb.prompt_tokens > prompt_tokens
assert vb.completion_tokens > completion_tokens


@pytest.mark.parametrize(
"model_name",
["gemini-pro", "text-bison@001", "code-bison@001"],
)
def test_invoke_config(model_name: str) -> None:
vb = VertexAICallbackHandler()
llm = VertexAI(model_name=model_name, temperature=0.0)
llm.invoke("2+2", config={"callbacks": [vb]})
assert vb.successful_requests == 1
assert vb.prompt_tokens > 0
assert vb.completion_tokens > 0
prompt_tokens = vb.prompt_tokens
completion_tokens = vb.completion_tokens
llm.invoke("2+2", config={"callbacks": [vb]})
assert vb.successful_requests == 2
assert vb.prompt_tokens > prompt_tokens
assert vb.completion_tokens > completion_tokens


def test_llm_stream() -> None:
vb = VertexAICallbackHandler()
llm = VertexAI(model_name="gemini-pro", temperature=0.0, callbacks=[vb])
for _ in llm.stream("2+2"):
pass
assert vb.successful_requests == 1
assert vb.prompt_tokens > 0
assert vb.completion_tokens > 0


def test_chat_stream() -> None:
vb = VertexAICallbackHandler()
llm = ChatVertexAI(model_name="gemini-pro", temperature=0.0, callbacks=[vb])
for _ in llm.stream("2+2"):
pass
assert vb.successful_requests == 1
assert vb.prompt_tokens > 0
assert vb.completion_tokens > 0