Commit

Fix lint errors
Fix last lint error

Fix last lint errors
alexjoham committed Oct 11, 2024
1 parent 9905460 commit 3b81a30
Showing 13 changed files with 48 additions and 29 deletions.
app/domain/data/token_usage_dto.py (2 changes: 1 addition & 1 deletion)
@@ -7,4 +7,4 @@ class TokenUsageDTO(BaseModel):
     model_info: str
     num_input_tokens: int
     num_output_tokens: int
-    pipeline: PipelineEnum
+    pipeline: PipelineEnum
app/llm/external/LLMTokenCount.py (8 changes: 7 additions & 1 deletion)
@@ -8,7 +8,13 @@ class LLMTokenCount:
     num_output_tokens: int
     pipeline: PipelineEnum

-    def __init__(self, model_info: str, num_input_tokens: int, num_output_tokens: int, pipeline: PipelineEnum):
+    def __init__(
+        self,
+        model_info: str,
+        num_input_tokens: int,
+        num_output_tokens: int,
+        pipeline: PipelineEnum,
+    ):
         self.model_info = model_info
         self.num_input_tokens = num_input_tokens
         self.num_output_tokens = num_output_tokens
app/llm/external/ollama.py (11 changes: 9 additions & 2 deletions)
@@ -57,7 +57,9 @@ def convert_to_ollama_messages(messages: list[PyrisMessage]) -> list[Message]:
     return messages_to_return


-def convert_to_iris_message(message: Message, num_input_tokens: int, num_output_tokens: int, model: str) -> PyrisMessage:
+def convert_to_iris_message(
+    message: Message, num_input_tokens: int, num_output_tokens: int, model: str
+) -> PyrisMessage:
     """
     Convert a Message to a PyrisMessage
     """
@@ -111,7 +113,12 @@ def chat(
             format="json" if arguments.response_format == "JSON" else "",
             options=self.options,
         )
-        return convert_to_iris_message(response["message"], response["prompt_eval_count"], response["eval_count"], response["model"])
+        return convert_to_iris_message(
+            response["message"],
+            response["prompt_eval_count"],
+            response["eval_count"],
+            response["model"],
+        )

     def embed(self, text: str) -> list[float]:
         response = self._client.embeddings(
app/llm/external/openai_chat.py (15 changes: 10 additions & 5 deletions)
@@ -62,23 +62,26 @@ def convert_to_open_ai_messages(
     return openai_messages


-def convert_to_iris_message(message: ChatCompletionMessage, usage: Optional[CompletionUsage], model: str) -> PyrisMessage:
+def convert_to_iris_message(
+    message: ChatCompletionMessage, usage: Optional[CompletionUsage], model: str
+) -> PyrisMessage:
     """
     Convert a ChatCompletionMessage to a PyrisMessage
     """
-    num_input_tokens = getattr(usage, 'prompt_tokens', -1)
-    num_output_tokens = getattr(usage, 'completion_tokens', -1)
+    num_input_tokens = getattr(usage, "prompt_tokens", -1)
+    num_output_tokens = getattr(usage, "completion_tokens", -1)

     message = PyrisMessage(
         sender=map_str_to_role(message.role),
         contents=[TextMessageContentDTO(textContent=message.content)],
         send_at=datetime.now(),
         num_input_tokens=num_input_tokens,
         num_output_tokens=num_output_tokens,
-        model_info=model
+        model_info=model,
     )
     return message

+
 class OpenAIChatModel(ChatModel):
     model: str
     api_key: str
@@ -110,7 +113,9 @@ def chat(
                     temperature=arguments.temperature,
                     max_tokens=arguments.max_tokens,
                 )
-                return convert_to_iris_message(response.choices[0].message, response.usage, response.model)
+                return convert_to_iris_message(
+                    response.choices[0].message, response.usage, response.model
+                )
             except Exception as e:
                 wait_time = initial_delay * (backoff_factor**attempt)
                 logging.warning(f"Exception on attempt {attempt + 1}: {e}")
app/llm/langchain/iris_langchain_chat_model.py (10 changes: 6 additions & 4 deletions)
@@ -45,10 +45,12 @@ def _generate(
         iris_message = self.request_handler.chat(iris_messages, self.completion_args)
         base_message = convert_iris_message_to_langchain_message(iris_message)
         chat_generation = ChatGeneration(message=base_message)
-        self.tokens = LLMTokenCount(model_info=iris_message.model_info,
-                                    num_input_tokens=iris_message.num_input_tokens,
-                                    num_output_tokens=iris_message.num_output_tokens,
-                                    pipeline=PipelineEnum.NOT_SET)
+        self.tokens = LLMTokenCount(
+            model_info=iris_message.model_info,
+            num_input_tokens=iris_message.num_input_tokens,
+            num_output_tokens=iris_message.num_output_tokens,
+            pipeline=PipelineEnum.NOT_SET,
+        )
         return ChatResult(generations=[chat_generation])

     @property
app/pipeline/chat/code_feedback_pipeline.py (1 change: 0 additions & 1 deletion)
@@ -7,7 +7,6 @@
 from langchain_core.runnables import Runnable
 from langsmith import traceable
 from pydantic import BaseModel
-from sipbuild.generator.parser.tokens import tokens

 from ...domain import PyrisMessage
 from ...domain.data.build_log_entry import BuildLogEntryDTO
app/pipeline/chat/course_chat_pipeline.py (2 changes: 0 additions & 2 deletions)
@@ -14,7 +14,6 @@
 from langchain_core.runnables import Runnable
 from langchain_core.tools import tool
 from langsmith import traceable
-from sipbuild.generator.parser.tokens import tokens
 from weaviate.collections.classes.filters import Filter

 from .interaction_suggestion_pipeline import (
@@ -42,7 +41,6 @@
     elicit_begin_agent_jol_prompt,
 )
 from ...domain import CourseChatPipelineExecutionDTO
-from ...llm.external.LLMTokenCount import LLMTokenCount
 from ...llm.external.PipelineEnum import PipelineEnum
 from ...retrieval.lecture_retrieval import LectureRetrieval
 from ...vector_database.database import VectorDatabase
app/pipeline/chat/exercise_chat_pipeline.py (12 changes: 8 additions & 4 deletions)
@@ -10,7 +10,6 @@
 )
 from langchain_core.runnables import Runnable
 from langsmith import traceable, get_current_run_tree
-from sipbuild.generator.parser.tokens import tokens
 from weaviate.collections.classes.filters import Filter

 from .code_feedback_pipeline import CodeFeedbackPipeline
@@ -35,7 +34,6 @@
 from ...domain.data.programming_submission_dto import ProgrammingSubmissionDTO
 from ...llm import CapabilityRequestHandler, RequirementList
 from ...llm import CompletionArguments
-from ...llm.external.LLMTokenCount import LLMTokenCount
 from ...llm.external.PipelineEnum import PipelineEnum
 from ...llm.langchain import IrisLangchainChatModel
 from ...retrieval.lecture_retrieval import LectureRetrieval
@@ -102,7 +100,9 @@ def __call__(self, dto: ExerciseChatPipelineExecutionDTO):
         )
         self._run_exercise_chat_pipeline(dto, should_execute_lecture_pipeline),
         self.callback.done(
-            "Generated response", final_result=self.exercise_chat_response, tokens=self.tokens
+            "Generated response",
+            final_result=self.exercise_chat_response,
+            tokens=self.tokens,
         )

         try:
@@ -116,7 +116,11 @@ def __call__(self, dto: ExerciseChatPipelineExecutionDTO):
                 suggestion_dto.last_message = self.exercise_chat_response
                 suggestion_dto.problem_statement = dto.exercise.problem_statement
                 suggestions = self.suggestion_pipeline(suggestion_dto)
-                self.callback.done(final_result=None, suggestions=suggestions, tokens=[self.suggestion_pipeline.tokens])
+                self.callback.done(
+                    final_result=None,
+                    suggestions=suggestions,
+                    tokens=[self.suggestion_pipeline.tokens],
+                )
             else:
                 # This should never happen but whatever
                 self.callback.skip(
app/pipeline/chat/lecture_chat_pipeline.py (1 change: 0 additions & 1 deletion)
@@ -16,7 +16,6 @@
     LectureChatPipelineExecutionDTO,
 )
 from ...llm import CapabilityRequestHandler, RequirementList
-from ...llm.external import LLMTokenCount
 from ...llm.external.PipelineEnum import PipelineEnum
 from ...retrieval.lecture_retrieval import LectureRetrieval
 from ...vector_database.database import VectorDatabase
app/pipeline/competency_extraction_pipeline.py (11 changes: 6 additions & 5 deletions)
@@ -5,7 +5,6 @@
 from langchain_core.prompts import (
     ChatPromptTemplate,
 )
-from sipbuild.generator.parser.tokens import tokens

 from app.domain import (
     CompetencyExtractionPipelineExecutionDTO,
@@ -80,10 +79,12 @@ def __call__(
         response = self.request_handler.chat(
             [prompt], CompletionArguments(temperature=0.4)
         )
-        num_tokens = LLMTokenCount(model_info=response.model_info,
-                                   num_input_tokens=response.num_input_tokens,
-                                   num_output_tokens=response.num_output_tokens,
-                                   pipeline=PipelineEnum.IRIS_COMPETENCY_GENERATION)
+        num_tokens = LLMTokenCount(
+            model_info=response.model_info,
+            num_input_tokens=response.num_input_tokens,
+            num_output_tokens=response.num_output_tokens,
+            pipeline=PipelineEnum.IRIS_COMPETENCY_GENERATION,
+        )
         self.tokens.append(num_tokens)
         response = response.contents[0].text_content

app/pipeline/shared/citation_pipeline.py (1 change: 0 additions & 1 deletion)
@@ -7,7 +7,6 @@
 from langchain_core.runnables import Runnable

 from app.llm import CapabilityRequestHandler, RequirementList, CompletionArguments
-from app.llm.external import LLMTokenCount
 from app.llm.external.PipelineEnum import PipelineEnum
 from app.llm.langchain import IrisLangchainChatModel
 from app.pipeline import Pipeline
app/pipeline/shared/reranker_pipeline.py (1 change: 0 additions & 1 deletion)
@@ -6,7 +6,6 @@
 from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
 from langchain_core.runnables import Runnable
 from langsmith import traceable
-from sipbuild.generator.parser.tokens import tokens

 from app.domain import PyrisMessage
 from app.llm import CapabilityRequestHandler, RequirementList, CompletionArguments
app/retrieval/lecture_retrieval.py (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,6 @@
 from typing import List

 from langsmith import traceable
-from sipbuild.generator.parser.tokens import tokens
 from weaviate import WeaviateClient
 from weaviate.classes.query import Filter

@@ -82,6 +81,7 @@ class LectureRetrieval(Pipeline):
     """
     Class for retrieving lecture data from the database.
     """
+
     tokens: LLMTokenCount

     def __init__(self, client: WeaviateClient, **kwargs):

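For reference, a minimal sketch of the one-keyword-per-line call style these hunks converge on (the layout Black produces once a call exceeds the line-length limit), using the LLMTokenCount constructor reformatted above. The model name and token counts are hypothetical placeholders; only the field names, import paths, and the PipelineEnum member come from the diffs in this commit.

# Hypothetical usage sketch; values are placeholders, not from the commit.
from app.llm.external.LLMTokenCount import LLMTokenCount
from app.llm.external.PipelineEnum import PipelineEnum

tokens = LLMTokenCount(
    model_info="example-model",  # placeholder model identifier
    num_input_tokens=1024,  # placeholder prompt token count
    num_output_tokens=256,  # placeholder completion token count
    pipeline=PipelineEnum.IRIS_COMPETENCY_GENERATION,
)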