From e241d457b86e97ad4df94fa91e11db80ca28d552 Mon Sep 17 00:00:00 2001
From: Alexander Joham
Date: Fri, 11 Oct 2024 19:40:11 +0200
Subject: [PATCH 1/2] Fix lint errors

---
 app/domain/data/token_usage_dto.py             |  2 +-
 app/llm/external/LLMTokenCount.py              |  8 +++++++-
 app/llm/external/ollama.py                     | 11 +++++++++--
 app/llm/external/openai_chat.py                | 15 ++++++++++-----
 app/llm/langchain/iris_langchain_chat_model.py | 10 ++++++----
 app/pipeline/chat/exercise_chat_pipeline.py    | 12 ++++++++----
 app/pipeline/competency_extraction_pipeline.py | 11 ++++++-----
 7 files changed, 47 insertions(+), 22 deletions(-)

diff --git a/app/domain/data/token_usage_dto.py b/app/domain/data/token_usage_dto.py
index cc98c8af..c7e1868c 100644
--- a/app/domain/data/token_usage_dto.py
+++ b/app/domain/data/token_usage_dto.py
@@ -7,4 +7,4 @@ class TokenUsageDTO(BaseModel):
     model_info: str
     num_input_tokens: int
     num_output_tokens: int
-    pipeline: PipelineEnum
\ No newline at end of file
+    pipeline: PipelineEnum
diff --git a/app/llm/external/LLMTokenCount.py b/app/llm/external/LLMTokenCount.py
index e82b02af..7570eddb 100644
--- a/app/llm/external/LLMTokenCount.py
+++ b/app/llm/external/LLMTokenCount.py
@@ -8,7 +8,13 @@ class LLMTokenCount:
     num_output_tokens: int
     pipeline: PipelineEnum

-    def __init__(self, model_info: str, num_input_tokens: int, num_output_tokens: int, pipeline: PipelineEnum):
+    def __init__(
+        self,
+        model_info: str,
+        num_input_tokens: int,
+        num_output_tokens: int,
+        pipeline: PipelineEnum,
+    ):
         self.model_info = model_info
         self.num_input_tokens = num_input_tokens
         self.num_output_tokens = num_output_tokens
diff --git a/app/llm/external/ollama.py b/app/llm/external/ollama.py
index 89d126a6..8474b8a1 100644
--- a/app/llm/external/ollama.py
+++ b/app/llm/external/ollama.py
@@ -57,7 +57,9 @@ def convert_to_ollama_messages(messages: list[PyrisMessage]) -> list[Message]:
     return messages_to_return


-def convert_to_iris_message(message: Message, num_input_tokens: int, num_output_tokens: int, model: str) -> PyrisMessage:
+def convert_to_iris_message(
+    message: Message, num_input_tokens: int, num_output_tokens: int, model: str
+) -> PyrisMessage:
     """
     Convert a Message to a PyrisMessage
     """
@@ -111,7 +113,12 @@ def chat(
             format="json" if arguments.response_format == "JSON" else "",
             options=self.options,
         )
-        return convert_to_iris_message(response["message"], response["prompt_eval_count"], response["eval_count"], response["model"])
+        return convert_to_iris_message(
+            response["message"],
+            response["prompt_eval_count"],
+            response["eval_count"],
+            response["model"],
+        )

     def embed(self, text: str) -> list[float]:
         response = self._client.embeddings(
diff --git a/app/llm/external/openai_chat.py b/app/llm/external/openai_chat.py
index da5a8c2e..99d6c4b6 100644
--- a/app/llm/external/openai_chat.py
+++ b/app/llm/external/openai_chat.py
@@ -62,12 +62,14 @@ def convert_to_open_ai_messages(
     return openai_messages


-def convert_to_iris_message(message: ChatCompletionMessage, usage: Optional[CompletionUsage], model: str) -> PyrisMessage:
+def convert_to_iris_message(
+    message: ChatCompletionMessage, usage: Optional[CompletionUsage], model: str
+) -> PyrisMessage:
     """
     Convert a ChatCompletionMessage to a PyrisMessage
     """
-    num_input_tokens = getattr(usage, 'prompt_tokens', -1)
-    num_output_tokens = getattr(usage, 'completion_tokens', -1)
+    num_input_tokens = getattr(usage, "prompt_tokens", -1)
+    num_output_tokens = getattr(usage, "completion_tokens", -1)

     message = PyrisMessage(
         sender=map_str_to_role(message.role),
@@ -75,10 +77,11 @@ def convert_to_iris_message(message: ChatCompletionMessage, usage: Optional[Comp
         send_at=datetime.now(),
         num_input_tokens=num_input_tokens,
         num_output_tokens=num_output_tokens,
-        model_info=model
+        model_info=model,
     )
     return message

+
 class OpenAIChatModel(ChatModel):
     model: str
     api_key: str
@@ -110,7 +113,9 @@ def chat(
                     temperature=arguments.temperature,
                     max_tokens=arguments.max_tokens,
                 )
-                return convert_to_iris_message(response.choices[0].message, response.usage, response.model)
+                return convert_to_iris_message(
+                    response.choices[0].message, response.usage, response.model
+                )
             except Exception as e:
                 wait_time = initial_delay * (backoff_factor**attempt)
                 logging.warning(f"Exception on attempt {attempt + 1}: {e}")
diff --git a/app/llm/langchain/iris_langchain_chat_model.py b/app/llm/langchain/iris_langchain_chat_model.py
index c5cd9273..50c1cb0a 100644
--- a/app/llm/langchain/iris_langchain_chat_model.py
+++ b/app/llm/langchain/iris_langchain_chat_model.py
@@ -45,10 +45,12 @@ def _generate(
         iris_message = self.request_handler.chat(iris_messages, self.completion_args)
         base_message = convert_iris_message_to_langchain_message(iris_message)
         chat_generation = ChatGeneration(message=base_message)
-        self.tokens = LLMTokenCount(model_info=iris_message.model_info,
-                                    num_input_tokens=iris_message.num_input_tokens,
-                                    num_output_tokens=iris_message.num_output_tokens,
-                                    pipeline=PipelineEnum.NOT_SET)
+        self.tokens = LLMTokenCount(
+            model_info=iris_message.model_info,
+            num_input_tokens=iris_message.num_input_tokens,
+            num_output_tokens=iris_message.num_output_tokens,
+            pipeline=PipelineEnum.NOT_SET,
+        )
         return ChatResult(generations=[chat_generation])

     @property
diff --git a/app/pipeline/chat/exercise_chat_pipeline.py b/app/pipeline/chat/exercise_chat_pipeline.py
index 2aee5beb..d2c0ca04 100644
--- a/app/pipeline/chat/exercise_chat_pipeline.py
+++ b/app/pipeline/chat/exercise_chat_pipeline.py
@@ -10,7 +10,6 @@
 )
 from langchain_core.runnables import Runnable
 from langsmith import traceable, get_current_run_tree
-from sipbuild.generator.parser.tokens import tokens
 from weaviate.collections.classes.filters import Filter

 from .code_feedback_pipeline import CodeFeedbackPipeline
@@ -35,7 +34,6 @@
 from ...domain.data.programming_submission_dto import ProgrammingSubmissionDTO
 from ...llm import CapabilityRequestHandler, RequirementList
 from ...llm import CompletionArguments
-from ...llm.external.LLMTokenCount import LLMTokenCount
 from ...llm.external.PipelineEnum import PipelineEnum
 from ...llm.langchain import IrisLangchainChatModel
 from ...retrieval.lecture_retrieval import LectureRetrieval
@@ -102,7 +100,9 @@ def __call__(self, dto: ExerciseChatPipelineExecutionDTO):
             )
             self._run_exercise_chat_pipeline(dto, should_execute_lecture_pipeline),
             self.callback.done(
-                "Generated response", final_result=self.exercise_chat_response, tokens=self.tokens
+                "Generated response",
+                final_result=self.exercise_chat_response,
+                tokens=self.tokens,
             )

             try:
@@ -116,7 +116,11 @@ def __call__(self, dto: ExerciseChatPipelineExecutionDTO):
                 suggestion_dto.last_message = self.exercise_chat_response
                 suggestion_dto.problem_statement = dto.exercise.problem_statement
                 suggestions = self.suggestion_pipeline(suggestion_dto)
-                self.callback.done(final_result=None, suggestions=suggestions, tokens=[self.suggestion_pipeline.tokens])
+                self.callback.done(
+                    final_result=None,
+                    suggestions=suggestions,
+                    tokens=[self.suggestion_pipeline.tokens],
+                )
             else:
                 # This should never happen but whatever
                 self.callback.skip(
diff --git a/app/pipeline/competency_extraction_pipeline.py b/app/pipeline/competency_extraction_pipeline.py
index df4d2e30..68265a48 100644
--- a/app/pipeline/competency_extraction_pipeline.py
+++ b/app/pipeline/competency_extraction_pipeline.py
@@ -5,7 +5,6 @@
 from langchain_core.prompts import (
     ChatPromptTemplate,
 )
-from sipbuild.generator.parser.tokens import tokens

 from app.domain import (
     CompetencyExtractionPipelineExecutionDTO,
@@ -80,10 +79,12 @@ def __call__(
         response = self.request_handler.chat(
             [prompt], CompletionArguments(temperature=0.4)
         )
-        num_tokens = LLMTokenCount(model_info=response.model_info,
-                                   num_input_tokens=response.num_input_tokens,
-                                   num_output_tokens=response.num_output_tokens,
-                                   pipeline=PipelineEnum.IRIS_COMPETENCY_GENERATION)
+        num_tokens = LLMTokenCount(
+            model_info=response.model_info,
+            num_input_tokens=response.num_input_tokens,
+            num_output_tokens=response.num_output_tokens,
+            pipeline=PipelineEnum.IRIS_COMPETENCY_GENERATION,
+        )
         self.tokens.append(num_tokens)

         response = response.contents[0].text_content

From 4502e30e52f6117e20b14abc9dd6af9d96a526d5 Mon Sep 17 00:00:00 2001
From: Alexander Joham
Date: Fri, 11 Oct 2024 19:42:52 +0200
Subject: [PATCH 2/2] Fix last lint error

---
 app/retrieval/lecture_retrieval.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/app/retrieval/lecture_retrieval.py b/app/retrieval/lecture_retrieval.py
index c80f94bb..f75a139e 100644
--- a/app/retrieval/lecture_retrieval.py
+++ b/app/retrieval/lecture_retrieval.py
@@ -2,7 +2,6 @@
 from typing import List

 from langsmith import traceable
-from sipbuild.generator.parser.tokens import tokens
 from weaviate import WeaviateClient
 from weaviate.classes.query import Filter

@@ -82,6 +81,7 @@ class LectureRetrieval(Pipeline):
     """
     Class for retrieving lecture data from the database.
    """
+
     tokens: LLMTokenCount

     def __init__(self, client: WeaviateClient, **kwargs):
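Usage note (not part of the patches): every call site these commits reformat follows the same black-enforced layout, and each LLMTokenCount the pipelines build is collected into a token list that is later handed to a callback. The sketch below mirrors that pattern in a minimal, self-contained form. FakeResponse and its field values are invented stand-ins for a real PyrisMessage, and the PipelineEnum and LLMTokenCount definitions here only approximate the ones in app/llm/external/.

    from dataclasses import dataclass
    from enum import Enum


    # Stand-ins mirroring the classes the patches touch; the real definitions
    # live in app/llm/external/PipelineEnum.py and LLMTokenCount.py.
    class PipelineEnum(str, Enum):
        NOT_SET = "NOT_SET"
        IRIS_COMPETENCY_GENERATION = "IRIS_COMPETENCY_GENERATION"


    class LLMTokenCount:
        def __init__(
            self,
            model_info: str,
            num_input_tokens: int,
            num_output_tokens: int,
            pipeline: PipelineEnum,
        ):
            self.model_info = model_info
            self.num_input_tokens = num_input_tokens
            self.num_output_tokens = num_output_tokens
            self.pipeline = pipeline


    # Hypothetical response carrying the same attributes a PyrisMessage
    # exposes after these changes; the values are invented for the example.
    @dataclass
    class FakeResponse:
        model_info: str = "some-model"
        num_input_tokens: int = 1200
        num_output_tokens: int = 300


    response = FakeResponse()

    # The call-site layout black produces for long calls: one keyword
    # argument per line, a trailing comma, and the closing parenthesis
    # dedented -- exactly as in competency_extraction_pipeline.py above.
    num_tokens = LLMTokenCount(
        model_info=response.model_info,
        num_input_tokens=response.num_input_tokens,
        num_output_tokens=response.num_output_tokens,
        pipeline=PipelineEnum.IRIS_COMPETENCY_GENERATION,
    )

    # Pipelines accumulate these records and pass them to callback.done().
    tokens: list[LLMTokenCount] = [num_tokens]
    print(tokens[0].model_info, tokens[0].num_input_tokens, tokens[0].pipeline.value)

The trailing comma after the last keyword argument is what lets black keep one argument per line, which is the layout both commits converge on.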