Token tracking
NolanTrem committed Dec 2, 2024
1 parent 17a34e6 commit 8cc3481
Showing 2 changed files with 66 additions and 1 deletion.
44 changes: 43 additions & 1 deletion py/core/agent/base.py
@@ -1,6 +1,8 @@
import asyncio
import logging

from abc import ABCMeta
from litellm import token_counter
from typing import AsyncGenerator, Generator, Optional

from core.base.abstractions import (
@@ -61,13 +63,36 @@ async def arun(
        for message in messages:
            await self.conversation.add_message(message)

        print(f"Messages: {messages}")

        token_usage = {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        }

        while not self._completed:
            messages_list = await self.conversation.get_messages()
            generation_config = self.get_generation_config(messages_list[-1])

            prompt_tokens = token_counter(
                model=generation_config.model, messages=messages_list
            )
            token_usage["prompt_tokens"] += prompt_tokens

            response = await self.llm_provider.aget_completion(
                messages_list,
                generation_config,
            )

            # assistant_message = response.choices[0].message
            # completion_tokens = token_counter(model=generation_config.model, messages=[assistant_message])
            # token_usage['completion_tokens'] += completion_tokens

            # token_usage['total_tokens'] = token_usage['prompt_tokens'] + token_usage['completion_tokens']

            token_usage["total_tokens"] = token_usage["prompt_tokens"]

            await self.process_llm_response(response, *args, **kwargs)

        # Get the output messages
@@ -76,6 +101,7 @@ async def arun(

        output_messages = []
        for message_2 in all_messages:
            print(f"Type of message_2: {type(message_2)}")
            if (
                message_2.get("content")
                and message_2.get("content") != messages[-1].content
@@ -85,7 +111,7 @@ async def arun(
                    break
        output_messages.reverse()

        return output_messages
        return {"messages": output_messages, "token_usage": token_usage}

    async def process_llm_response(
        self, response: LLMChatCompletion, *args, **kwargs
@@ -135,15 +161,31 @@ async def arun(  # type: ignore
            generation_config = self.get_generation_config(
                messages_list[-1], stream=True
            )
            print(f"Messages: {messages_list}")
            tokens = token_counter(
                model=generation_config.model,
                messages=messages_list,
            )
            print(f"Tokens: {tokens}")
            stream = self.llm_provider.aget_completion_stream(
                messages_list,
                generation_config,
            )

            chunk_content = ""
            async for proc_chunk in self.process_llm_response(
                stream, *args, **kwargs
            ):
                chunk_content += proc_chunk
                yield proc_chunk

            print(f"Chunk content: {chunk_content}")
            tokens = token_counter(
                model=generation_config.model,
                messages=[{"message": chunk_content}],
            )
            print(f"Tokens: {tokens}")

    def run(
        self, system_instruction, messages, *args, **kwargs
    ) -> Generator[str, None, None]:
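For reference, the prompt-side accounting added to `arun` above reduces to the following standalone pattern — a minimal sketch, assuming litellm is installed and that the model name resolves to a tokenizer litellm recognizes; the model name and messages below are placeholders, not values taken from this repository.

from litellm import token_counter

# OpenAI-style message dicts with "role"/"content" keys, which token_counter expects.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize the attached report."},
]

token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}

# token_counter returns an int estimate for the full message list.
prompt_tokens = token_counter(model="gpt-4o-mini", messages=messages)
token_usage["prompt_tokens"] += prompt_tokens

# Completion-side counting is still commented out in this commit, so the
# total currently mirrors the prompt side only.
token_usage["total_tokens"] = token_usage["prompt_tokens"]

print(token_usage)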
23 changes: 23 additions & 0 deletions py/core/main/services/retrieval_service.py
@@ -5,6 +5,7 @@
from uuid import UUID

from fastapi import HTTPException
from litellm import token_counter

from core import R2RStreamingRAGAgent
from core.base import (
@@ -331,11 +332,25 @@ async def agent(

        current_message = messages[-1]  # type: ignore

        print(f"The current message is: {current_message}")

        tokens = token_counter(
            model=rag_generation_config.model,
            messages=[
                {
                    "role": current_message.role,
                    "content": current_message.content,
                }
            ],
        )
        print(f"Tokens: {tokens}")

        # Save the new message to the conversation
        message_id = await self.logging_connection.add_message(
            conversation_id,  # type: ignore
            current_message,  # type: ignore
            parent_id=str(ids[-2]) if (ids and len(ids) > 1) else None,  # type: ignore
            metadata={"usage": tokens},
        )

        if rag_generation_config.stream:
@@ -348,6 +363,10 @@ async def agent(
                value=latency,
            )

            print(
                f"Messages: {messages} \n\nSystem Instruction: {task_prompt_override} \n\nVector Search Settings: {vector_search_settings} \n\nKG Search Settings: {kg_search_settings} \n\nRAG Generation Config: {rag_generation_config} \n\nInclude Title If Available: {include_title_if_available}"
            )

            async def stream_response():
                async with manage_run(self.run_manager, "rag_agent"):
                    agent = R2RStreamingRAGAgent(
@@ -370,6 +389,10 @@ async def stream_response():

            return stream_response()

        print(
            f"Messages: {messages} \n\nSystem Instruction: {task_prompt_override} \n\nVector Search Settings: {vector_search_settings} \n\nKG Search Settings: {kg_search_settings} \n\nRAG Generation Config: {rag_generation_config} \n\nInclude Title If Available: {include_title_if_available}"
        )

        results = await self.agents.rag_agent.arun(
            messages=messages,
            system_instruction=task_prompt_override,
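A possible follow-up to the metadata={"usage": tokens} field written above: summing that per-message value to get a per-conversation total. The helper below is a hypothetical sketch, not part of this commit; it only assumes each stored message dict carries an integer under metadata["usage"].

from typing import Any

def total_conversation_tokens(messages: list[dict[str, Any]]) -> int:
    # Sum the integer "usage" entries that retrieval_service.py now attaches
    # to each saved message; messages without metadata count as zero.
    total = 0
    for message in messages:
        usage = (message.get("metadata") or {}).get("usage", 0)
        total += int(usage)
    return total

# Example with the shape this commit writes (an int under "usage"):
history = [
    {"role": "user", "content": "Hi", "metadata": {"usage": 9}},
    {"role": "assistant", "content": "Hello!", "metadata": {"usage": 12}},
]
assert total_conversation_tokens(history) == 21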
