From 3bf8510b437b89e18f8927f6252ab4ade81947c4 Mon Sep 17 00:00:00 2001 From: Alexander Joham <73483450+alexjoham@users.noreply.github.com> Date: Wed, 23 Oct 2024 20:57:57 +0200 Subject: [PATCH] Track token usage of iris requests (#165) --- app/common/PipelineEnum.py | 16 ++++++ app/common/message_converters.py | 2 +- app/{domain => common}/pyris_message.py | 3 ++ app/common/token_usage_dto.py | 18 +++++++ app/domain/__init__.py | 1 - .../chat_pipeline_execution_base_data_dto.py | 2 +- .../chat/chat_pipeline_execution_dto.py | 2 +- app/domain/chat/interaction_suggestion_dto.py | 2 +- app/domain/status/status_update_dto.py | 2 + ...xt_exercise_chat_pipeline_execution_dto.py | 3 +- app/llm/external/model.py | 2 +- app/llm/external/ollama.py | 22 ++++++-- app/llm/external/openai_chat.py | 29 ++++++++--- .../langchain/iris_langchain_chat_model.py | 16 +++++- .../request_handler/basic_request_handler.py | 2 +- .../capability_request_handler.py | 7 ++- .../request_handler_interface.py | 2 +- app/pipeline/chat/code_feedback_pipeline.py | 8 ++- app/pipeline/chat/course_chat_pipeline.py | 13 +++-- app/pipeline/chat/exercise_chat_pipeline.py | 51 ++++++++++++++++--- .../chat/interaction_suggestion_pipeline.py | 7 ++- app/pipeline/chat/lecture_chat_pipeline.py | 6 ++- .../competency_extraction_pipeline.py | 10 ++-- app/pipeline/lecture_ingestion_pipeline.py | 22 ++++++-- app/pipeline/pipeline.py | 9 ++++ app/pipeline/shared/citation_pipeline.py | 3 ++ app/pipeline/shared/reranker_pipeline.py | 5 +- app/pipeline/shared/summary_pipeline.py | 1 + app/pipeline/text_exercise_chat_pipeline.py | 2 +- app/retrieval/lecture_retrieval.py | 19 ++++++- app/web/status/status_update.py | 8 ++- 31 files changed, 246 insertions(+), 49 deletions(-) create mode 100644 app/common/PipelineEnum.py rename app/{domain => common}/pyris_message.py (82%) create mode 100644 app/common/token_usage_dto.py diff --git a/app/common/PipelineEnum.py b/app/common/PipelineEnum.py new file mode 100644 index 00000000..3d8e101e --- /dev/null +++ b/app/common/PipelineEnum.py @@ -0,0 +1,16 @@ +from enum import Enum + + +class PipelineEnum(str, Enum): + IRIS_CODE_FEEDBACK = "IRIS_CODE_FEEDBACK" + IRIS_CHAT_COURSE_MESSAGE = "IRIS_CHAT_COURSE_MESSAGE" + IRIS_CHAT_EXERCISE_MESSAGE = "IRIS_CHAT_EXERCISE_MESSAGE" + IRIS_INTERACTION_SUGGESTION = "IRIS_INTERACTION_SUGGESTION" + IRIS_CHAT_LECTURE_MESSAGE = "IRIS_CHAT_LECTURE_MESSAGE" + IRIS_COMPETENCY_GENERATION = "IRIS_COMPETENCY_GENERATION" + IRIS_CITATION_PIPELINE = "IRIS_CITATION_PIPELINE" + IRIS_RERANKER_PIPELINE = "IRIS_RERANKER_PIPELINE" + IRIS_SUMMARY_PIPELINE = "IRIS_SUMMARY_PIPELINE" + IRIS_LECTURE_RETRIEVAL_PIPELINE = "IRIS_LECTURE_RETRIEVAL_PIPELINE" + IRIS_LECTURE_INGESTION = "IRIS_LECTURE_INGESTION" + NOT_SET = "NOT_SET" diff --git a/app/common/message_converters.py b/app/common/message_converters.py index 671dd565..d96886e5 100644 --- a/app/common/message_converters.py +++ b/app/common/message_converters.py @@ -4,7 +4,7 @@ from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage from app.domain.data.text_message_content_dto import TextMessageContentDTO -from app.domain.pyris_message import PyrisMessage, IrisMessageRole +from app.common.pyris_message import PyrisMessage, IrisMessageRole def convert_iris_message_to_langchain_message( diff --git a/app/domain/pyris_message.py b/app/common/pyris_message.py similarity index 82% rename from app/domain/pyris_message.py rename to app/common/pyris_message.py index 056f77ef..f18e636a 100644 --- a/app/domain/pyris_message.py +++ b/app/common/pyris_message.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, ConfigDict, Field from app.domain.data.message_content_dto import MessageContentDTO +from app.common.token_usage_dto import TokenUsageDTO class IrisMessageRole(str, Enum): @@ -16,6 +17,8 @@ class IrisMessageRole(str, Enum): class PyrisMessage(BaseModel): model_config = ConfigDict(populate_by_name=True) + token_usage: TokenUsageDTO = Field(default_factory=TokenUsageDTO) + sent_at: datetime | None = Field(alias="sentAt", default=None) sender: IrisMessageRole contents: List[MessageContentDTO] = [] diff --git a/app/common/token_usage_dto.py b/app/common/token_usage_dto.py new file mode 100644 index 00000000..a0ee3eda --- /dev/null +++ b/app/common/token_usage_dto.py @@ -0,0 +1,18 @@ +from pydantic import BaseModel, Field + +from app.common.PipelineEnum import PipelineEnum + + +class TokenUsageDTO(BaseModel): + model_info: str = Field(alias="model", default="") + num_input_tokens: int = Field(alias="numInputTokens", default=0) + cost_per_input_token: float = Field(alias="costPerMillionInputToken", default=0) + num_output_tokens: int = Field(alias="numOutputTokens", default=0) + cost_per_output_token: float = Field(alias="costPerMillionOutputToken", default=0) + pipeline: PipelineEnum = Field(alias="pipelineId", default=PipelineEnum.NOT_SET) + + def __str__(self): + return ( + f"{self.model_info}: {self.num_input_tokens} input cost: {self.cost_per_input_token}," + f" {self.num_output_tokens} output cost: {self.cost_per_output_token}, pipeline: {self.pipeline} " + ) diff --git a/app/domain/__init__.py b/app/domain/__init__.py index 27fd881d..e7b03301 100644 --- a/app/domain/__init__.py +++ b/app/domain/__init__.py @@ -12,6 +12,5 @@ from app.domain.chat.course_chat.course_chat_pipeline_execution_dto import ( CourseChatPipelineExecutionDTO, ) -from .pyris_message import PyrisMessage, IrisMessageRole from app.domain.data import image_message_content_dto from app.domain.feature_dto import FeatureDTO diff --git a/app/domain/chat/chat_pipeline_execution_base_data_dto.py b/app/domain/chat/chat_pipeline_execution_base_data_dto.py index e0677c76..a9bfd8d2 100644 --- a/app/domain/chat/chat_pipeline_execution_base_data_dto.py +++ b/app/domain/chat/chat_pipeline_execution_base_data_dto.py @@ -3,7 +3,7 @@ from pydantic import Field, BaseModel from app.domain import PipelineExecutionSettingsDTO -from app.domain.pyris_message import PyrisMessage +from app.common.pyris_message import PyrisMessage from app.domain.data.user_dto import UserDTO from app.domain.status.stage_dto import StageDTO diff --git a/app/domain/chat/chat_pipeline_execution_dto.py b/app/domain/chat/chat_pipeline_execution_dto.py index e3e63284..a92c8332 100644 --- a/app/domain/chat/chat_pipeline_execution_dto.py +++ b/app/domain/chat/chat_pipeline_execution_dto.py @@ -3,7 +3,7 @@ from pydantic import Field from app.domain import PipelineExecutionDTO -from app.domain.pyris_message import PyrisMessage +from app.common.pyris_message import PyrisMessage from app.domain.data.user_dto import UserDTO diff --git a/app/domain/chat/interaction_suggestion_dto.py b/app/domain/chat/interaction_suggestion_dto.py index 3835ce81..e905f83f 100644 --- a/app/domain/chat/interaction_suggestion_dto.py +++ b/app/domain/chat/interaction_suggestion_dto.py @@ -2,7 +2,7 @@ from pydantic import Field, BaseModel -from app.domain import PyrisMessage +from app.common.pyris_message import PyrisMessage class InteractionSuggestionPipelineExecutionDTO(BaseModel): diff --git a/app/domain/status/status_update_dto.py b/app/domain/status/status_update_dto.py index bb6dc3a6..80848a21 100644 --- a/app/domain/status/status_update_dto.py +++ b/app/domain/status/status_update_dto.py @@ -2,8 +2,10 @@ from pydantic import BaseModel +from app.common.token_usage_dto import TokenUsageDTO from ...domain.status.stage_dto import StageDTO class StatusUpdateDTO(BaseModel): stages: List[StageDTO] + tokens: List[TokenUsageDTO] = [] diff --git a/app/domain/text_exercise_chat_pipeline_execution_dto.py b/app/domain/text_exercise_chat_pipeline_execution_dto.py index 65e8871c..ed77892c 100644 --- a/app/domain/text_exercise_chat_pipeline_execution_dto.py +++ b/app/domain/text_exercise_chat_pipeline_execution_dto.py @@ -1,6 +1,7 @@ from pydantic import BaseModel, Field -from app.domain import PipelineExecutionDTO, PyrisMessage +from app.common.pyris_message import PyrisMessage +from app.domain import PipelineExecutionDTO from app.domain.data.text_exercise_dto import TextExerciseDTO diff --git a/app/llm/external/model.py b/app/llm/external/model.py index 47b90962..3fba9e6f 100644 --- a/app/llm/external/model.py +++ b/app/llm/external/model.py @@ -1,7 +1,7 @@ from abc import ABCMeta, abstractmethod from pydantic import BaseModel -from ...domain import PyrisMessage +from ...common.pyris_message import PyrisMessage from ...llm import CompletionArguments from ...llm.capability import CapabilityList diff --git a/app/llm/external/ollama.py b/app/llm/external/ollama.py index 832df17c..1b89f3c4 100644 --- a/app/llm/external/ollama.py +++ b/app/llm/external/ollama.py @@ -6,10 +6,11 @@ from ollama import Client, Message from ...common.message_converters import map_role_to_str, map_str_to_role +from ...common.pyris_message import PyrisMessage +from ...common.token_usage_dto import TokenUsageDTO from ...domain.data.json_message_content_dto import JsonMessageContentDTO from ...domain.data.text_message_content_dto import TextMessageContentDTO from ...domain.data.image_message_content_dto import ImageMessageContentDTO -from ...domain import PyrisMessage from ...llm import CompletionArguments from ...llm.external.model import ChatModel, CompletionModel, EmbeddingModel @@ -57,15 +58,23 @@ def convert_to_ollama_messages(messages: list[PyrisMessage]) -> list[Message]: return messages_to_return -def convert_to_iris_message(message: Message) -> PyrisMessage: +def convert_to_iris_message( + message: Message, num_input_tokens: int, num_output_tokens: int, model: str +) -> PyrisMessage: """ Convert a Message to a PyrisMessage """ contents = [TextMessageContentDTO(text_content=message["content"])] + tokens = TokenUsageDTO( + numInputTokens=num_input_tokens, + numOutputTokens=num_output_tokens, + model=model, + ) return PyrisMessage( sender=map_str_to_role(message["role"]), contents=contents, - send_at=datetime.now(), + sentAt=datetime.now(), + token_usage=tokens, ) @@ -108,7 +117,12 @@ def chat( format="json" if arguments.response_format == "JSON" else "", options=self.options, ) - return convert_to_iris_message(response["message"]) + return convert_to_iris_message( + response.get("message"), + response.get("prompt_eval_count", 0), + response.get("eval_count", 0), + response.get("model", self.model), + ) def embed(self, text: str) -> list[float]: response = self._client.embeddings( diff --git a/app/llm/external/openai_chat.py b/app/llm/external/openai_chat.py index c3575919..75b6f3b2 100644 --- a/app/llm/external/openai_chat.py +++ b/app/llm/external/openai_chat.py @@ -1,7 +1,7 @@ import logging import time from datetime import datetime -from typing import Literal, Any +from typing import Literal, Any, Optional from openai import ( OpenAI, @@ -12,12 +12,14 @@ ContentFilterFinishReasonError, ) from openai.lib.azure import AzureOpenAI +from openai.types import CompletionUsage from openai.types.chat import ChatCompletionMessage, ChatCompletionMessageParam from openai.types.shared_params import ResponseFormatJSONObject from ...common.message_converters import map_str_to_role, map_role_to_str from app.domain.data.text_message_content_dto import TextMessageContentDTO -from ...domain import PyrisMessage +from ...common.pyris_message import PyrisMessage +from ...common.token_usage_dto import TokenUsageDTO from ...domain.data.image_message_content_dto import ImageMessageContentDTO from ...domain.data.json_message_content_dto import JsonMessageContentDTO from ...llm import CompletionArguments @@ -67,15 +69,28 @@ def convert_to_open_ai_messages( return openai_messages -def convert_to_iris_message(message: ChatCompletionMessage) -> PyrisMessage: +def convert_to_iris_message( + message: ChatCompletionMessage, usage: Optional[CompletionUsage], model: str +) -> PyrisMessage: """ Convert a ChatCompletionMessage to a PyrisMessage """ - return PyrisMessage( + num_input_tokens = getattr(usage, "prompt_tokens", 0) + num_output_tokens = getattr(usage, "completion_tokens", 0) + + tokens = TokenUsageDTO( + model=model, + numInputTokens=num_input_tokens, + numOutputTokens=num_output_tokens, + ) + + message = PyrisMessage( sender=map_str_to_role(message.role), contents=[TextMessageContentDTO(textContent=message.content)], - send_at=datetime.now(), + sentAt=datetime.now(), + token_usage=tokens, ) + return message class OpenAIChatModel(ChatModel): @@ -113,13 +128,15 @@ def chat( max_tokens=arguments.max_tokens, ) choice = response.choices[0] + usage = response.usage + model = response.model if choice.finish_reason == "content_filter": # I figured that an openai error would be automatically raised if the content filter activated, # but it seems that that is not the case. # We don't want to retry because the same message will likely be rejected again. # Raise an exception to trigger the global error handler and report a fatal error to the client. raise ContentFilterFinishReasonError() - return convert_to_iris_message(choice.message) + return convert_to_iris_message(choice.message, usage, model) except ( APIError, APITimeoutError, diff --git a/app/llm/langchain/iris_langchain_chat_model.py b/app/llm/langchain/iris_langchain_chat_model.py index 9dc85d38..c8b1c6da 100644 --- a/app/llm/langchain/iris_langchain_chat_model.py +++ b/app/llm/langchain/iris_langchain_chat_model.py @@ -1,3 +1,4 @@ +import logging from typing import List, Optional, Any from langchain_core.callbacks import CallbackManagerForLLMRun @@ -5,13 +6,14 @@ BaseChatModel, ) from langchain_core.messages import BaseMessage -from langchain_core.outputs import ChatResult -from langchain_core.outputs.chat_generation import ChatGeneration +from langchain_core.outputs import ChatResult, ChatGeneration +from app.common.PipelineEnum import PipelineEnum from ...common import ( convert_iris_message_to_langchain_message, convert_langchain_message_to_iris_message, ) +from app.common.token_usage_dto import TokenUsageDTO from ...llm import RequestHandler, CompletionArguments @@ -20,6 +22,8 @@ class IrisLangchainChatModel(BaseChatModel): request_handler: RequestHandler completion_args: CompletionArguments + tokens: TokenUsageDTO = None + logger = logging.getLogger(__name__) def __init__( self, @@ -43,6 +47,14 @@ def _generate( iris_message = self.request_handler.chat(iris_messages, self.completion_args) base_message = convert_iris_message_to_langchain_message(iris_message) chat_generation = ChatGeneration(message=base_message) + self.tokens = TokenUsageDTO( + model=iris_message.token_usage.model_info, + numInputTokens=iris_message.token_usage.num_input_tokens, + costPerMillionInputToken=iris_message.token_usage.cost_per_input_token, + numOutputTokens=iris_message.token_usage.num_output_tokens, + costPerMillionOutputToken=iris_message.token_usage.cost_per_output_token, + pipeline=PipelineEnum.NOT_SET, + ) return ChatResult(generations=[chat_generation]) @property diff --git a/app/llm/request_handler/basic_request_handler.py b/app/llm/request_handler/basic_request_handler.py index 5756346f..1342a71c 100644 --- a/app/llm/request_handler/basic_request_handler.py +++ b/app/llm/request_handler/basic_request_handler.py @@ -1,6 +1,6 @@ from typing import Optional -from app.domain import PyrisMessage +from app.common.pyris_message import PyrisMessage from app.domain.data.image_message_content_dto import ImageMessageContentDTO from app.llm.request_handler import RequestHandler from app.llm.completion_arguments import CompletionArguments diff --git a/app/llm/request_handler/capability_request_handler.py b/app/llm/request_handler/capability_request_handler.py index 1ed05b3d..97d6a36f 100644 --- a/app/llm/request_handler/capability_request_handler.py +++ b/app/llm/request_handler/capability_request_handler.py @@ -1,6 +1,6 @@ from enum import Enum -from app.domain import PyrisMessage +from app.common.pyris_message import PyrisMessage from app.llm.capability import RequirementList from app.llm.external.model import ( ChatModel, @@ -44,7 +44,10 @@ def chat( self, messages: list[PyrisMessage], arguments: CompletionArguments ) -> PyrisMessage: llm = self._select_model(ChatModel) - return llm.chat(messages, arguments) + message = llm.chat(messages, arguments) + message.token_usage.cost_per_input_token = llm.capabilities.input_cost.value + message.token_usage.cost_per_output_token = llm.capabilities.output_cost.value + return message def embed(self, text: str) -> list[float]: llm = self._select_model(EmbeddingModel) diff --git a/app/llm/request_handler/request_handler_interface.py b/app/llm/request_handler/request_handler_interface.py index 390a4cbc..89dccedb 100644 --- a/app/llm/request_handler/request_handler_interface.py +++ b/app/llm/request_handler/request_handler_interface.py @@ -1,7 +1,7 @@ from abc import ABCMeta, abstractmethod from typing import Optional -from ...domain import PyrisMessage +from ...common.pyris_message import PyrisMessage from ...domain.data.image_message_content_dto import ImageMessageContentDTO from ...llm import CompletionArguments diff --git a/app/pipeline/chat/code_feedback_pipeline.py b/app/pipeline/chat/code_feedback_pipeline.py index 8ed5d9ba..90c27ecc 100644 --- a/app/pipeline/chat/code_feedback_pipeline.py +++ b/app/pipeline/chat/code_feedback_pipeline.py @@ -8,11 +8,13 @@ from langsmith import traceable from pydantic import BaseModel -from ...domain import PyrisMessage +from ...common.pyris_message import PyrisMessage from ...domain.data.build_log_entry import BuildLogEntryDTO from ...domain.data.feedback_dto import FeedbackDTO +from app.common.token_usage_dto import TokenUsageDTO from ...llm import CapabilityRequestHandler, RequirementList from ...llm import CompletionArguments +from app.common.PipelineEnum import PipelineEnum from ...llm.langchain import IrisLangchainChatModel from ...pipeline import Pipeline from ...web.status.status_update import StatusCallback @@ -40,6 +42,7 @@ class CodeFeedbackPipeline(Pipeline): callback: StatusCallback default_prompt: PromptTemplate output_parser: StrOutputParser + tokens: TokenUsageDTO def __init__(self, callback: Optional[StatusCallback] = None): super().__init__(implementation_id="code_feedback_pipeline_reference_impl") @@ -141,4 +144,7 @@ def __call__( } ) ) + token_usage = self.llm.tokens + token_usage.pipeline = PipelineEnum.IRIS_CODE_FEEDBACK + self.tokens = token_usage return response.replace("{", "{{").replace("}", "}}") diff --git a/app/pipeline/chat/course_chat_pipeline.py b/app/pipeline/chat/course_chat_pipeline.py index d2928df7..b21616d0 100644 --- a/app/pipeline/chat/course_chat_pipeline.py +++ b/app/pipeline/chat/course_chat_pipeline.py @@ -22,7 +22,7 @@ from .lecture_chat_pipeline import LectureChatPipeline from ..shared.citation_pipeline import CitationPipeline from ...common import convert_iris_message_to_langchain_message -from ...domain import PyrisMessage +from ...common.pyris_message import PyrisMessage from ...llm import CapabilityRequestHandler, RequirementList from ..prompts.iris_course_chat_prompts import ( tell_iris_initial_system_prompt, @@ -41,6 +41,7 @@ elicit_begin_agent_jol_prompt, ) from ...domain import CourseChatPipelineExecutionDTO +from app.common.PipelineEnum import PipelineEnum from ...retrieval.lecture_retrieval import LectureRetrieval from ...vector_database.database import VectorDatabase from ...vector_database.lecture_schema import LectureSchema @@ -107,6 +108,7 @@ def __init__(self, callback: CourseChatStatusCallback, variant: str = "default") # Create the pipeline self.pipeline = self.llm | StrOutputParser() + self.tokens = [] def __repr__(self): return f"{self.__class__.__name__}(llm={self.llm})" @@ -406,14 +408,18 @@ def lecture_content_retrieval() -> str: self.callback.in_progress() for step in agent_executor.iter(params): print("STEP:", step) + self._append_tokens( + self.llm.tokens, PipelineEnum.IRIS_CHAT_COURSE_MESSAGE + ) if step.get("output", None): out = step["output"] if self.retrieved_paragraphs: self.callback.in_progress("Augmenting response ...") out = self.citation_pipeline(self.retrieved_paragraphs, out) + self.tokens.extend(self.citation_pipeline.tokens) - self.callback.done("Response created", final_result=out) + self.callback.done("Response created", final_result=out, tokens=self.tokens) # try: # # if out: @@ -440,7 +446,8 @@ def lecture_content_retrieval() -> str: ) traceback.print_exc() self.callback.error( - "An error occurred while running the course chat pipeline." + "An error occurred while running the course chat pipeline.", + tokens=self.tokens, ) def should_allow_lecture_tool(self, course_id: int) -> bool: diff --git a/app/pipeline/chat/exercise_chat_pipeline.py b/app/pipeline/chat/exercise_chat_pipeline.py index 8043e9ad..386cd27a 100644 --- a/app/pipeline/chat/exercise_chat_pipeline.py +++ b/app/pipeline/chat/exercise_chat_pipeline.py @@ -24,8 +24,8 @@ from ..shared.citation_pipeline import CitationPipeline from ..shared.reranker_pipeline import RerankerPipeline from ...common import convert_iris_message_to_langchain_message +from ...common.pyris_message import PyrisMessage from ...domain import ExerciseChatPipelineExecutionDTO -from ...domain import PyrisMessage from ...domain.chat.interaction_suggestion_dto import ( InteractionSuggestionPipelineExecutionDTO, ) @@ -34,6 +34,7 @@ from ...domain.data.programming_submission_dto import ProgrammingSubmissionDTO from ...llm import CapabilityRequestHandler, RequirementList from ...llm import CompletionArguments +from app.common.PipelineEnum import PipelineEnum from ...llm.langchain import IrisLangchainChatModel from ...retrieval.lecture_retrieval import LectureRetrieval from ...vector_database.database import VectorDatabase @@ -78,6 +79,7 @@ def __init__(self, callback: ExerciseChatStatusCallback): self.code_feedback_pipeline = CodeFeedbackPipeline() self.pipeline = self.llm | StrOutputParser() self.citation_pipeline = CitationPipeline() + self.tokens = [] def __repr__(self): return f"{self.__class__.__name__}(llm={self.llm})" @@ -98,7 +100,9 @@ def __call__(self, dto: ExerciseChatPipelineExecutionDTO): ) self._run_exercise_chat_pipeline(dto, should_execute_lecture_pipeline), self.callback.done( - "Generated response", final_result=self.exercise_chat_response + "Generated response", + final_result=self.exercise_chat_response, + tokens=self.tokens, ) try: @@ -112,7 +116,15 @@ def __call__(self, dto: ExerciseChatPipelineExecutionDTO): suggestion_dto.last_message = self.exercise_chat_response suggestion_dto.problem_statement = dto.exercise.problem_statement suggestions = self.suggestion_pipeline(suggestion_dto) - self.callback.done(final_result=None, suggestions=suggestions) + if self.suggestion_pipeline.tokens is not None: + tokens = [self.suggestion_pipeline.tokens] + else: + tokens = [] + self.callback.done( + final_result=None, + suggestions=suggestions, + tokens=tokens, + ) else: # This should never happen but whatever self.callback.skip( @@ -125,11 +137,15 @@ def __call__(self, dto: ExerciseChatPipelineExecutionDTO): ) traceback.print_exc() self.callback.error( - "Generating interaction suggestions failed.", exception=e + "Generating interaction suggestions failed.", + exception=e, + tokens=self.tokens, ) except Exception as e: traceback.print_exc() - self.callback.error(f"Failed to generate response: {e}", exception=e) + self.callback.error( + f"Failed to generate response: {e}", exception=e, tokens=self.tokens + ) def _run_exercise_chat_pipeline( self, @@ -200,6 +216,8 @@ def _run_exercise_chat_pipeline( if submission: try: feedback = future_feedback.result() + if self.code_feedback_pipeline.tokens is not None: + self.tokens.append(self.code_feedback_pipeline.tokens) self.prompt += SystemMessagePromptTemplate.from_template( "Another AI has checked the code of the student and has found the following issues. " "Use this information to help the student. " @@ -212,7 +230,9 @@ def _run_exercise_chat_pipeline( ) except Exception as e: self.callback.error( - f"Failed to look up files in the repository: {e}", exception=e + f"Failed to look up files in the repository: {e}", + exception=e, + tokens=self.tokens, ) return @@ -220,13 +240,20 @@ def _run_exercise_chat_pipeline( if should_execute_lecture_pipeline: try: self.retrieved_lecture_chunks = future_lecture.result() + if ( + self.retriever.tokens is not None + and len(self.retriever.tokens) > 0 + ): + self.tokens.extend(self.retriever.tokens) if len(self.retrieved_lecture_chunks) > 0: self._add_relevant_chunks_to_prompt( self.retrieved_lecture_chunks ) except Exception as e: self.callback.error( - f"Failed to retrieve lecture chunks: {e}", exception=e + f"Failed to retrieve lecture chunks: {e}", + exception=e, + tokens=self.tokens, ) return @@ -252,6 +279,9 @@ def _run_exercise_chat_pipeline( .with_config({"run_name": "Response Drafting"}) .invoke({}) ) + self._append_tokens( + self.llm.tokens, PipelineEnum.IRIS_CHAT_EXERCISE_MESSAGE + ) self.callback.done() self.prompt = ChatPromptTemplate.from_messages( [ @@ -266,6 +296,9 @@ def _run_exercise_chat_pipeline( .with_config({"run_name": "Response Refining"}) .invoke({}) ) + self._append_tokens( + self.llm.tokens, PipelineEnum.IRIS_CHAT_EXERCISE_MESSAGE + ) if "!ok!" in guide_response: print("Response is ok and not rewritten!!!") @@ -274,7 +307,9 @@ def _run_exercise_chat_pipeline( print("Response is rewritten.") self.exercise_chat_response = guide_response except Exception as e: - self.callback.error(f"Failed to create response: {e}", exception=e) + self.callback.error( + f"Failed to create response: {e}", exception=e, tokens=self.tokens + ) # print stack trace traceback.print_exc() return "Failed to generate response" diff --git a/app/pipeline/chat/interaction_suggestion_pipeline.py b/app/pipeline/chat/interaction_suggestion_pipeline.py index 86635166..620728de 100644 --- a/app/pipeline/chat/interaction_suggestion_pipeline.py +++ b/app/pipeline/chat/interaction_suggestion_pipeline.py @@ -13,10 +13,11 @@ from pydantic.v1 import Field, BaseModel from ...common import convert_iris_message_to_langchain_message -from ...domain import PyrisMessage from app.domain.chat.interaction_suggestion_dto import ( InteractionSuggestionPipelineExecutionDTO, ) +from app.common.token_usage_dto import TokenUsageDTO +from ...common.pyris_message import PyrisMessage from ...llm import CapabilityRequestHandler, RequirementList from ..prompts.iris_interaction_suggestion_prompts import ( course_chat_begin_prompt, @@ -34,6 +35,7 @@ ) from ...llm import CompletionArguments +from app.common.PipelineEnum import PipelineEnum from ...llm.langchain import IrisLangchainChatModel from ..pipeline import Pipeline @@ -52,6 +54,7 @@ class InteractionSuggestionPipeline(Pipeline): pipeline: Runnable prompt: ChatPromptTemplate variant: str + tokens: TokenUsageDTO def __init__(self, variant: str = "default"): super().__init__(implementation_id="interaction_suggestion_pipeline") @@ -164,6 +167,8 @@ def __call__( self.prompt = ChatPromptTemplate.from_messages(prompt_val) response: dict = (self.prompt | self.pipeline).invoke({}) + self.tokens = self.llm.tokens + self.tokens.pipeline = PipelineEnum.IRIS_INTERACTION_SUGGESTION return response["questions"] except Exception as e: logger.error( diff --git a/app/pipeline/chat/lecture_chat_pipeline.py b/app/pipeline/chat/lecture_chat_pipeline.py index 51693009..22eb8c7a 100644 --- a/app/pipeline/chat/lecture_chat_pipeline.py +++ b/app/pipeline/chat/lecture_chat_pipeline.py @@ -11,11 +11,12 @@ from ..shared.citation_pipeline import CitationPipeline from ...common import convert_iris_message_to_langchain_message -from ...domain import PyrisMessage +from ...common.pyris_message import PyrisMessage from ...domain.chat.lecture_chat.lecture_chat_pipeline_execution_dto import ( LectureChatPipelineExecutionDTO, ) from ...llm import CapabilityRequestHandler, RequirementList +from app.common.PipelineEnum import PipelineEnum from ...retrieval.lecture_retrieval import LectureRetrieval from ...vector_database.database import VectorDatabase from ...vector_database.lecture_schema import LectureSchema @@ -74,6 +75,7 @@ def __init__(self): self.retriever = LectureRetrieval(self.db.client) self.pipeline = self.llm | StrOutputParser() self.citation_pipeline = CitationPipeline() + self.tokens = [] def __repr__(self): return f"{self.__class__.__name__}(llm={self.llm})" @@ -114,9 +116,11 @@ def __call__(self, dto: LectureChatPipelineExecutionDTO): self.prompt = ChatPromptTemplate.from_messages(prompt_val) try: response = (self.prompt | self.pipeline).invoke({}) + self._append_tokens(self.llm.tokens, PipelineEnum.IRIS_CHAT_LECTURE_MESSAGE) response_with_citation = self.citation_pipeline( retrieved_lecture_chunks, response ) + self.tokens.extend(self.citation_pipeline.tokens) logger.info(f"Response from lecture chat pipeline: {response}") return response_with_citation except Exception as e: diff --git a/app/pipeline/competency_extraction_pipeline.py b/app/pipeline/competency_extraction_pipeline.py index a2288ab5..12efb65f 100644 --- a/app/pipeline/competency_extraction_pipeline.py +++ b/app/pipeline/competency_extraction_pipeline.py @@ -6,10 +6,10 @@ ChatPromptTemplate, ) +from app.common.PipelineEnum import PipelineEnum +from app.common.pyris_message import PyrisMessage, IrisMessageRole from app.domain import ( CompetencyExtractionPipelineExecutionDTO, - PyrisMessage, - IrisMessageRole, ) from app.domain.data.text_message_content_dto import TextMessageContentDTO from app.domain.data.competency_dto import Competency @@ -38,6 +38,7 @@ def __init__(self, callback: Optional[CompetencyExtractionCallback] = None): ) ) self.output_parser = PydanticOutputParser(pydantic_object=Competency) + self.tokens = [] def __call__( self, @@ -76,6 +77,9 @@ def __call__( response = self.request_handler.chat( [prompt], CompletionArguments(temperature=0.4) ) + self._append_tokens( + response.token_usage, PipelineEnum.IRIS_COMPETENCY_GENERATION + ) response = response.contents[0].text_content generated_competencies: list[Competency] = [] @@ -98,4 +102,4 @@ def __call__( continue logger.debug(f"Generated competency: {competency}") generated_competencies.append(competency) - self.callback.done(final_result=generated_competencies) + self.callback.done(final_result=generated_competencies, tokens=self.tokens) diff --git a/app/pipeline/lecture_ingestion_pipeline.py b/app/pipeline/lecture_ingestion_pipeline.py index 0b468a41..73d6371b 100644 --- a/app/pipeline/lecture_ingestion_pipeline.py +++ b/app/pipeline/lecture_ingestion_pipeline.py @@ -10,7 +10,7 @@ from weaviate import WeaviateClient from weaviate.classes.query import Filter from . import Pipeline -from ..domain import IrisMessageRole, PyrisMessage +from ..common.pyris_message import PyrisMessage, IrisMessageRole from ..domain.data.image_message_content_dto import ImageMessageContentDTO from ..domain.data.lecture_unit_dto import LectureUnitDTO @@ -18,6 +18,7 @@ IngestionPipelineExecutionDto, ) from ..domain.data.text_message_content_dto import TextMessageContentDTO +from app.common.PipelineEnum import PipelineEnum from ..llm.langchain import IrisLangchainChatModel from ..vector_database.lecture_schema import init_lecture_schema, LectureSchema from ..ingestion.abstract_ingestion import AbstractIngestion @@ -112,6 +113,7 @@ def __init__( request_handler=request_handler, completion_args=completion_args ) self.pipeline = self.llm | StrOutputParser() + self.tokens = [] def __call__(self) -> bool: try: @@ -139,7 +141,7 @@ def __call__(self) -> bool: self.callback.done("Lecture Chunking and interpretation Finished") self.callback.in_progress("Ingesting lecture chunks into database...") self.batch_update(chunks) - self.callback.done("Lecture Ingestion Finished") + self.callback.done("Lecture Ingestion Finished", tokens=self.tokens) logger.info( f"Lecture ingestion pipeline finished Successfully for course " f"{self.dto.lecture_units[0].course_name}" @@ -148,7 +150,9 @@ def __call__(self) -> bool: except Exception as e: logger.error(f"Error updating lecture unit: {e}") self.callback.error( - f"Failed to ingest lectures into the database: {e}", exception=e + f"Failed to ingest lectures into the database: {e}", + exception=e, + tokens=self.tokens, ) return False @@ -170,7 +174,9 @@ def batch_update(self, chunks): except Exception as e: logger.error(f"Error updating lecture unit: {e}") self.callback.error( - f"Failed to ingest lectures into the database: {e}", exception=e + f"Failed to ingest lectures into the database: {e}", + exception=e, + tokens=self.tokens, ) def chunk_data( @@ -245,6 +251,9 @@ def interpret_image( response = self.llm_vision.chat( [iris_message], CompletionArguments(temperature=0, max_tokens=512) ) + self._append_tokens( + response.token_usage, PipelineEnum.IRIS_LECTURE_INGESTION + ) except Exception as e: logger.error(f"Error interpreting image: {e}") return None @@ -273,9 +282,11 @@ def merge_page_content_and_image_interpretation( image_interpretation=image_interpretation, ) prompt = ChatPromptTemplate.from_messages(prompt_val) - return clean( + clean_output = clean( (prompt | self.pipeline).invoke({}), bullets=True, extra_whitespace=True ) + self._append_tokens(self.llm.tokens, PipelineEnum.IRIS_LECTURE_INGESTION) + return clean_output def get_course_language(self, page_content: str) -> str: """ @@ -292,6 +303,7 @@ def get_course_language(self, page_content: str) -> str: response = self.llm_chat.chat( [iris_message], CompletionArguments(temperature=0, max_tokens=20) ) + self._append_tokens(response.token_usage, PipelineEnum.IRIS_LECTURE_INGESTION) return response.contents[0].text_content def delete_old_lectures(self): diff --git a/app/pipeline/pipeline.py b/app/pipeline/pipeline.py index 8f2249b7..428dcf62 100644 --- a/app/pipeline/pipeline.py +++ b/app/pipeline/pipeline.py @@ -1,10 +1,15 @@ from abc import ABCMeta +from typing import List + +from app.common.token_usage_dto import TokenUsageDTO +from app.common.PipelineEnum import PipelineEnum class Pipeline(metaclass=ABCMeta): """Abstract class for all pipelines""" implementation_id: str + tokens: List[TokenUsageDTO] def __init__(self, implementation_id=None, **kwargs): self.implementation_id = implementation_id @@ -27,3 +32,7 @@ def __init_subclass__(cls, **kwargs): raise NotImplementedError( "Subclasses of Pipeline interface must implement the __call__ method." ) + + def _append_tokens(self, tokens: TokenUsageDTO, pipeline: PipelineEnum) -> None: + tokens.pipeline = pipeline + self.tokens.append(tokens) diff --git a/app/pipeline/shared/citation_pipeline.py b/app/pipeline/shared/citation_pipeline.py index 6a4aab38..b630bd4d 100644 --- a/app/pipeline/shared/citation_pipeline.py +++ b/app/pipeline/shared/citation_pipeline.py @@ -7,6 +7,7 @@ from langchain_core.runnables import Runnable from app.llm import CapabilityRequestHandler, RequirementList, CompletionArguments +from app.common.PipelineEnum import PipelineEnum from app.llm.langchain import IrisLangchainChatModel from app.pipeline import Pipeline @@ -38,6 +39,7 @@ def __init__(self): with open(prompt_file_path, "r") as file: self.prompt_str = file.read() self.pipeline = self.llm | StrOutputParser() + self.tokens = [] def __repr__(self): return f"{self.__class__.__name__}(llm={self.llm})" @@ -83,6 +85,7 @@ def __call__( response = (self.default_prompt | self.pipeline).invoke( {"Answer": answer, "Paragraphs": paras} ) + self._append_tokens(self.llm.tokens, PipelineEnum.IRIS_CITATION_PIPELINE) if response == "!NONE!": return answer print(response) diff --git a/app/pipeline/shared/reranker_pipeline.py b/app/pipeline/shared/reranker_pipeline.py index 178bb4e6..f915be32 100644 --- a/app/pipeline/shared/reranker_pipeline.py +++ b/app/pipeline/shared/reranker_pipeline.py @@ -7,8 +7,9 @@ from langchain_core.runnables import Runnable from langsmith import traceable -from app.domain import PyrisMessage +from app.common.pyris_message import PyrisMessage from app.llm import CapabilityRequestHandler, RequirementList, CompletionArguments +from app.common.PipelineEnum import PipelineEnum from app.llm.langchain import IrisLangchainChatModel from app.pipeline import Pipeline from app.pipeline.chat.output_models.output_models.selected_paragraphs import ( @@ -56,6 +57,7 @@ def __init__(self): ) logger.debug(self.output_parser.get_format_instructions()) self.pipeline = self.llm | self.output_parser + self.tokens = [] def __repr__(self): return f"{self.__class__.__name__}(llm={self.llm})" @@ -108,4 +110,5 @@ def __call__( prompt = self.default_prompt response = (prompt | self.pipeline).invoke(data) + self._append_tokens(self.llm.tokens, PipelineEnum.IRIS_RERANKER_PIPELINE) return response.selected_paragraphs diff --git a/app/pipeline/shared/summary_pipeline.py b/app/pipeline/shared/summary_pipeline.py index 382881a2..6a7f49a0 100644 --- a/app/pipeline/shared/summary_pipeline.py +++ b/app/pipeline/shared/summary_pipeline.py @@ -45,6 +45,7 @@ def __init__(self): ) # Create the pipeline self.pipeline = self.prompt | self.llm | StrOutputParser() + self.tokens = [] def __repr__(self): return f"{self.__class__.__name__}(llm={self.llm})" diff --git a/app/pipeline/text_exercise_chat_pipeline.py b/app/pipeline/text_exercise_chat_pipeline.py index 5d27fc71..9bcf2431 100644 --- a/app/pipeline/text_exercise_chat_pipeline.py +++ b/app/pipeline/text_exercise_chat_pipeline.py @@ -2,9 +2,9 @@ from datetime import datetime from typing import Optional, List, Tuple +from app.common.pyris_message import PyrisMessage, IrisMessageRole from app.llm import CapabilityRequestHandler, RequirementList, CompletionArguments from app.pipeline import Pipeline -from app.domain import PyrisMessage, IrisMessageRole from app.domain.text_exercise_chat_pipeline_execution_dto import ( TextExerciseChatPipelineExecutionDTO, ) diff --git a/app/retrieval/lecture_retrieval.py b/app/retrieval/lecture_retrieval.py index df832ebc..7bcd8ce0 100644 --- a/app/retrieval/lecture_retrieval.py +++ b/app/retrieval/lecture_retrieval.py @@ -6,10 +6,12 @@ from weaviate.classes.query import Filter from ..common import convert_iris_message_to_langchain_message +from app.common.token_usage_dto import TokenUsageDTO +from app.common.PipelineEnum import PipelineEnum +from ..common.pyris_message import PyrisMessage from ..llm.langchain import IrisLangchainChatModel from ..pipeline import Pipeline -from app.domain import PyrisMessage from app.llm import ( BasicRequestHandler, CompletionArguments, @@ -81,6 +83,8 @@ class LectureRetrieval(Pipeline): Class for retrieving lecture data from the database. """ + tokens: List[TokenUsageDTO] + def __init__(self, client: WeaviateClient, **kwargs): super().__init__(implementation_id="lecture_retrieval_pipeline") request_handler = CapabilityRequestHandler( @@ -98,6 +102,7 @@ def __init__(self, client: WeaviateClient, **kwargs): self.pipeline = self.llm | StrOutputParser() self.collection = init_lecture_schema(client) self.reranker_pipeline = RerankerPipeline() + self.tokens = [] @traceable(name="Full Lecture Retrieval") def __call__( @@ -236,6 +241,9 @@ def rewrite_student_query( prompt = ChatPromptTemplate.from_messages(prompt_val) try: response = (prompt | self.pipeline).invoke({}) + token_usage = self.llm.tokens + token_usage.pipeline = PipelineEnum.IRIS_LECTURE_RETRIEVAL_PIPELINE + self.tokens.append(self.llm.tokens) logger.info(f"Response from exercise chat pipeline: {response}") return response except Exception as e: @@ -273,6 +281,9 @@ def rewrite_student_query_with_exercise_context( prompt = ChatPromptTemplate.from_messages(prompt_val) try: response = (prompt | self.pipeline).invoke({}) + token_usage = self.llm.tokens + token_usage.pipeline = PipelineEnum.IRIS_LECTURE_RETRIEVAL_PIPELINE + self.tokens.append(self.llm.tokens) logger.info(f"Response from exercise chat pipeline: {response}") return response except Exception as e: @@ -308,6 +319,9 @@ def rewrite_elaborated_query( prompt = ChatPromptTemplate.from_messages(prompt_val) try: response = (prompt | self.pipeline).invoke({}) + token_usage = self.llm.tokens + token_usage.pipeline = PipelineEnum.IRIS_LECTURE_RETRIEVAL_PIPELINE + self.tokens.append(self.llm.tokens) logger.info(f"Response from retirval pipeline: {response}") return response except Exception as e: @@ -347,6 +361,9 @@ def rewrite_elaborated_query_with_exercise_context( ) try: response = (prompt | self.pipeline).invoke({}) + token_usage = self.llm.tokens + token_usage.pipeline = PipelineEnum.IRIS_LECTURE_RETRIEVAL_PIPELINE + self.tokens.append(self.llm.tokens) logger.info(f"Response from exercise chat pipeline: {response}") return response except Exception as e: diff --git a/app/web/status/status_update.py b/app/web/status/status_update.py index 30a73f13..a122b6e7 100644 --- a/app/web/status/status_update.py +++ b/app/web/status/status_update.py @@ -5,6 +5,7 @@ import requests from abc import ABC +from app.common.token_usage_dto import TokenUsageDTO from app.domain.status.competency_extraction_status_update_dto import ( CompetencyExtractionStatusUpdateDTO, ) @@ -99,6 +100,7 @@ def done( message: Optional[str] = None, final_result: Optional[str] = None, suggestions: Optional[List[str]] = None, + tokens: Optional[List[TokenUsageDTO]] = None, next_stage_message: Optional[str] = None, start_next_stage: bool = True, ): @@ -110,6 +112,7 @@ def done( self.stage.state = StageStateEnum.DONE self.stage.message = message self.status.result = final_result + self.status.tokens = tokens or self.status.tokens if hasattr(self.status, "suggestions"): self.status.suggestions = suggestions next_stage = self.get_next_stage() @@ -121,7 +124,9 @@ def done( self.stage.state = StageStateEnum.IN_PROGRESS self.on_status_update() - def error(self, message: str, exception=None): + def error( + self, message: str, exception=None, tokens: Optional[List[TokenUsageDTO]] = None + ): """ Transition the current stage to ERROR and update the status. Set all later stages to SKIPPED if an error occurs. @@ -129,6 +134,7 @@ def error(self, message: str, exception=None): self.stage.state = StageStateEnum.ERROR self.stage.message = message self.status.result = None + self.status.tokens = tokens or self.status.tokens # Set all subsequent stages to SKIPPED if an error occurs rest_of_index = ( self.current_stage_index + 1