diff --git a/app/pipeline/chat/chat_pipeline.py b/app/pipeline/chat/chat_pipeline.py
index 8e0f7fdd..f0d8396d 100644
--- a/app/pipeline/chat/chat_pipeline.py
+++ b/app/pipeline/chat/chat_pipeline.py
@@ -5,7 +5,10 @@
 class ProgrammingExerciseTutorChatPipeline(Pipeline, ABC):
-    """Abstract class for the programming exercise tutor chat pipeline implementations"""
+    """
+    Abstract class for the programming exercise tutor chat pipeline implementations.
+    This class defines the signature of all implementations of this Iris feature.
+    """
 
     def __call__(self, query: IrisMessage, **kwargs) -> IrisMessage:
         return self._run(query)
diff --git a/app/pipeline/chat/tutor_chat_pipeline.py b/app/pipeline/chat/tutor_chat_pipeline.py
index a422c05b..93461812 100644
--- a/app/pipeline/chat/tutor_chat_pipeline.py
+++ b/app/pipeline/chat/tutor_chat_pipeline.py
@@ -5,42 +5,38 @@
 from domain import IrisMessage, IrisMessageRole
 from llm.langchain import IrisLangchainChatModel
-from pipeline import AbstractPipeline
-
+from pipeline.chat.chat_pipeline import ProgrammingExerciseTutorChatPipeline
 
 logger = logging.getLogger(__name__)
 
 
-class TutorChatPipeline(AbstractPipeline):
+class TutorChatPipelineReferenceImpl(ProgrammingExerciseTutorChatPipeline):
     """Tutor chat pipeline that answers exercises related questions from students."""
 
     llm: IrisLangchainChatModel
     pipeline: Runnable
-    prompt_str: str
-    prompt: ChatPromptTemplate
 
-    def __init__(self, llm: IrisLangchainChatModel, name=None):
-        super().__init__(name=name)
+    def __init__(self, llm: IrisLangchainChatModel):
+        super().__init__(implementation_id="tutor_chat_pipeline_reference_impl")
         # Set the langchain chat model
         self.llm = llm
         # Load the prompt from a file
         with open("../prompts/iris_tutor_chat_prompt.txt", "r") as file:
             logger.debug("Loading tutor chat prompt...")
-            self.prompt_str = file.read()
+            prompt_str = file.read()
         # Create the prompt
-        self.prompt = ChatPromptTemplate.from_messages(
+        prompt = ChatPromptTemplate.from_messages(
             [
-                SystemMessagePromptTemplate.from_template(self.prompt_str),
+                SystemMessagePromptTemplate.from_template(prompt_str),
             ]
         )
         # Create the pipeline
-        self.pipeline = self.prompt | llm | StrOutputParser()
+        self.pipeline = prompt | llm | StrOutputParser()
 
-    def __call__(self, query: IrisMessage, **kwargs) -> IrisMessage:
+    def _run(self, query: IrisMessage) -> IrisMessage:
         """
         Runs the pipeline
         :param query: The query
-        :param kwargs: keyword arguments
         :return: IrisMessage
         """
         if query is None:
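
For reviewers, here is a minimal self-contained sketch of the call-delegation pattern this diff introduces: the abstract feature class owns `__call__` and delegates to `_run`, which each concrete implementation overrides after registering an `implementation_id` with the base `Pipeline`. The `Pipeline`, `IrisMessage`, and `EchoTutorChatPipeline` stand-ins below are hypothetical simplifications; the real types live in app/pipeline and app/domain.

from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class IrisMessage:  # simplified stand-in for domain.IrisMessage
    role: str
    text: str


class Pipeline:  # simplified stand-in for the project's Pipeline base class
    def __init__(self, implementation_id: str = None):
        self.implementation_id = implementation_id


class ProgrammingExerciseTutorChatPipeline(Pipeline, ABC):
    """Defines the signature shared by all implementations of this Iris feature."""

    def __call__(self, query: IrisMessage, **kwargs) -> IrisMessage:
        # Callers always go through __call__; concrete behavior lives in _run.
        return self._run(query)

    @abstractmethod
    def _run(self, query: IrisMessage) -> IrisMessage: ...


class EchoTutorChatPipeline(ProgrammingExerciseTutorChatPipeline):
    """Hypothetical implementation, standing in for TutorChatPipelineReferenceImpl."""

    def __init__(self):
        super().__init__(implementation_id="echo_tutor_chat_pipeline")

    def _run(self, query: IrisMessage) -> IrisMessage:
        if query is None:
            # Illustrative error handling; the diff truncates before the real branch.
            raise ValueError("query must not be None")
        return IrisMessage(role="assistant", text=f"echo: {query.text}")


pipeline = EchoTutorChatPipeline()
print(pipeline(IrisMessage(role="user", text="How do I fix my loop?")))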