diff --git a/libs/databricks/langchain_databricks/chat_models.py b/libs/databricks/langchain_databricks/chat_models.py
index d0c54a0..dcbdbed 100644
--- a/libs/databricks/langchain_databricks/chat_models.py
+++ b/libs/databricks/langchain_databricks/chat_models.py
@@ -35,6 +35,7 @@
     ToolMessage,
     ToolMessageChunk,
 )
+from langchain_core.messages.ai import UsageMetadata
 from langchain_core.messages.tool import tool_call_chunk
 from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
 from langchain_core.output_parsers.base import OutputParserLike
@@ -157,6 +158,30 @@ class ChatDatabricks(BaseChatModel):
                 id='run-4cef851f-6223-424f-ad26-4a54e5852aa5'
             )
 
+    To get token usage returned when streaming, pass the ``stream_usage`` kwarg:
+
+        .. code-block:: python
+
+            stream = llm.stream(messages, stream_usage=True)
+            next(stream).usage_metadata
+
+        .. code-block:: python
+
+            {"input_tokens": 28, "output_tokens": 5, "total_tokens": 33}
+
+        Alternatively, setting ``stream_usage`` when instantiating the model can be
+        useful when incorporating ``ChatDatabricks`` into LCEL chains, or when using
+        methods like ``.with_structured_output``, which generate chains under the
+        hood.
+
+        .. code-block:: python
+
+            llm = ChatDatabricks(
+                endpoint="databricks-meta-llama-3-1-405b-instruct",
+                stream_usage=True
+            )
+            structured_llm = llm.with_structured_output(...)
+
     Async:
         .. code-block:: python
 
@@ -229,6 +254,10 @@ class GetPopulation(BaseModel):
     max_tokens: Optional[int] = None
     """The maximum number of tokens to generate."""
     extra_params: Optional[Dict[str, Any]] = None
     """Any extra parameters to pass to the endpoint."""
+    stream_usage: bool = False
+    """Whether to include usage metadata in streaming output. If True, additional
+    message chunks will be generated during the stream including usage metadata.
+    """
     client: Optional[BaseDeploymentClient] = Field(
         default=None, exclude=True
@@ -301,8 +330,12 @@ def _stream(
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        *,
+        stream_usage: Optional[bool] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
+        if stream_usage is None:
+            stream_usage = self.stream_usage
         data = self._prepare_inputs(messages, stop, **kwargs)
         first_chunk_role = None
         for chunk in self.client.predict_stream(endpoint=self.endpoint, inputs=data):  # type: ignore
@@ -313,8 +346,19 @@ def _stream(
                 if first_chunk_role is None:
                     first_chunk_role = chunk_delta.get("role")
 
+                if stream_usage and (usage := chunk.get("usage")):
+                    input_tokens = usage.get("prompt_tokens", 0)
+                    output_tokens = usage.get("completion_tokens", 0)
+                    usage = {
+                        "input_tokens": input_tokens,
+                        "output_tokens": output_tokens,
+                        "total_tokens": input_tokens + output_tokens,
+                    }
+                else:
+                    usage = None
+
                 chunk_message = _convert_dict_to_message_chunk(
-                    chunk_delta, first_chunk_role
+                    chunk_delta, first_chunk_role, usage=usage
                 )
 
                 generation_info = {}
@@ -759,7 +803,9 @@ def _convert_dict_to_message(_dict: Dict) -> BaseMessage:
 
 
 def _convert_dict_to_message_chunk(
-    _dict: Mapping[str, Any], default_role: str
+    _dict: Mapping[str, Any],
+    default_role: str,
+    usage: Optional[Dict[str, Any]] = None,
 ) -> BaseMessageChunk:
     role = _dict.get("role", default_role)
     content = _dict.get("content")
@@ -790,11 +836,13 @@ def _convert_dict_to_message_chunk(
                 ]
             except KeyError:
                 pass
+        usage_metadata = UsageMetadata(**usage) if usage else None  # type: ignore
         return AIMessageChunk(
             content=content,
             additional_kwargs=additional_kwargs,
             id=_dict.get("id"),
             tool_call_chunks=tool_call_chunks,
+            usage_metadata=usage_metadata,
         )
     else:
         return ChatMessageChunk(content=content, role=role)
diff --git a/libs/databricks/tests/integration_tests/test_chat_models.py b/libs/databricks/tests/integration_tests/test_chat_models.py
index 0a7f491..d2fe498 100644
--- a/libs/databricks/tests/integration_tests/test_chat_models.py
+++ b/libs/databricks/tests/integration_tests/test_chat_models.py
@@ -116,6 +116,38 @@ def on_llm_new_token(self, *args, **kwargs):
     assert last_chunk.response_metadata["finish_reason"] == "stop"
 
 
+def test_chat_databricks_stream_with_usage():
+    class FakeCallbackHandler(BaseCallbackHandler):
+        def __init__(self):
+            self.chunk_counts = 0
+
+        def on_llm_new_token(self, *args, **kwargs):
+            self.chunk_counts += 1
+
+    callback = FakeCallbackHandler()
+
+    chat = ChatDatabricks(
+        endpoint=_TEST_ENDPOINT,
+        temperature=0,
+        stop=["Python"],
+        max_tokens=100,
+        stream_usage=True,
+    )
+
+    chunks = list(chat.stream("How to learn Python?", config={"callbacks": [callback]}))
+    assert len(chunks) > 0
+    assert all(isinstance(chunk, AIMessageChunk) for chunk in chunks)
+    assert all("Python" not in chunk.content for chunk in chunks)
+    assert callback.chunk_counts == len(chunks)
+
+    last_chunk = chunks[-1]
+    assert last_chunk.response_metadata["finish_reason"] == "stop"
+    assert last_chunk.usage_metadata is not None
+    assert last_chunk.usage_metadata["input_tokens"] > 0
+    assert last_chunk.usage_metadata["output_tokens"] > 0
+    assert last_chunk.usage_metadata["total_tokens"] > 0
+
+
 @pytest.mark.asyncio
 async def test_chat_databricks_ainvoke():
     chat = ChatDatabricks(
diff --git a/libs/databricks/tests/unit_tests/test_chat_models.py b/libs/databricks/tests/unit_tests/test_chat_models.py
index 37579da..0f684b3 100644
--- a/libs/databricks/tests/unit_tests/test_chat_models.py
+++ b/libs/databricks/tests/unit_tests/test_chat_models.py
@@ -188,6 +188,43 @@ def test_chat_model_stream(llm: ChatDatabricks) -> None:
         assert chunk.content == expected["choices"][0]["delta"]["content"]  # type: ignore[index]
 
 
+def test_chat_model_stream_with_usage(llm: ChatDatabricks) -> None:
+    def _assert_usage(chunk, expected):
+        usage = chunk.usage_metadata
+        assert usage is not None
+        assert usage["input_tokens"] == expected["usage"]["prompt_tokens"]
+        assert usage["output_tokens"] == expected["usage"]["completion_tokens"]
+        assert usage["total_tokens"] == usage["input_tokens"] + usage["output_tokens"]
+
+    # Method 1: Pass stream_usage=True to the stream() call
+    res = llm.stream(
+        [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "36939 * 8922.4"},
+        ],
+        stream_usage=True,
+    )
+    for chunk, expected in zip(res, _MOCK_STREAM_RESPONSE):
+        assert chunk.content == expected["choices"][0]["delta"]["content"]  # type: ignore[index]
+        _assert_usage(chunk, expected)
+
+    # Method 2: Pass stream_usage=True to the constructor
+    llm_with_usage = ChatDatabricks(
+        endpoint="databricks-meta-llama-3-70b-instruct",
+        target_uri="databricks",
+        stream_usage=True,
+    )
+    res = llm_with_usage.stream(
+        [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "36939 * 8922.4"},
+        ],
+    )
+    for chunk, expected in zip(res, _MOCK_STREAM_RESPONSE):
+        assert chunk.content == expected["choices"][0]["delta"]["content"]  # type: ignore[index]
+        _assert_usage(chunk, expected)
+
+
 class GetWeather(BaseModel):
     """Get the current weather in a given location"""
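
For reviewers, a minimal end-to-end sketch of the behavior this patch adds, assuming a Databricks workspace with a serving endpoint that reports token usage (the endpoint name and prompt below are placeholders, not part of the change). With ``stream_usage`` enabled, only chunks whose payload carries ``usage`` get ``usage_metadata`` set; as in the integration test above, this is typically the final chunk.

```python
from langchain_databricks import ChatDatabricks

# Enable usage streaming at construction time; it can equally be passed
# per call as llm.stream(..., stream_usage=True).
llm = ChatDatabricks(
    endpoint="databricks-meta-llama-3-1-405b-instruct",  # placeholder endpoint
    stream_usage=True,
)

usage = None
for chunk in llm.stream("Briefly explain streaming token usage."):
    # Most chunks carry only content; keep the last usage_metadata seen.
    if chunk.usage_metadata is not None:
        usage = chunk.usage_metadata

# e.g. {'input_tokens': ..., 'output_tokens': ..., 'total_tokens': ...}
print(usage)
```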