diff --git a/app/pipeline/chat/__init__.py b/app/pipeline/chat/__init__.py
index 3cdcc370..629dfd69 100644
--- a/app/pipeline/chat/__init__.py
+++ b/app/pipeline/chat/__init__.py
@@ -1 +1 @@
-from simple_chat_pipeline import SimpleChatPipeline
+from pipeline.chat.simple_chat_pipeline import SimpleChatPipeline
diff --git a/app/pipeline/chat/chat_pipeline.py b/app/pipeline/chat/chat_pipeline.py
index f0d8396d..60b41741 100644
--- a/app/pipeline/chat/chat_pipeline.py
+++ b/app/pipeline/chat/chat_pipeline.py
@@ -1,20 +1,17 @@
-from abc import ABC, abstractmethod
+from abc import abstractmethod, ABCMeta
 
 from domain import IrisMessage
 from pipeline import Pipeline
 
 
-class ProgrammingExerciseTutorChatPipeline(Pipeline, ABC):
+class ChatPipeline(Pipeline, metaclass=ABCMeta):
     """
     Abstract class for the programming exercise tutor chat pipeline implementations.
     This class defines the signature of all implementations of this Iris feature.
     """
 
-    def __call__(self, query: IrisMessage, **kwargs) -> IrisMessage:
-        return self._run(query)
-
     @abstractmethod
-    def _run(self, query: IrisMessage) -> IrisMessage:
+    def __call__(self, query: IrisMessage, **kwargs) -> IrisMessage:
         """
         Runs the pipeline and returns the response message.
         """
diff --git a/app/pipeline/chat/simple_chat_pipeline.py b/app/pipeline/chat/simple_chat_pipeline.py
index 3ef7b180..efad821c 100644
--- a/app/pipeline/chat/simple_chat_pipeline.py
+++ b/app/pipeline/chat/simple_chat_pipeline.py
@@ -5,10 +5,10 @@
 from domain import IrisMessage, IrisMessageRole
 from llm.langchain import IrisLangchainChatModel
-from pipeline.chat.chat_pipeline import ProgrammingExerciseTutorChatPipeline
+from pipeline.chat.chat_pipeline import ChatPipeline
 
 
-class SimpleChatPipeline(ProgrammingExerciseTutorChatPipeline):
+class SimpleChatPipeline(ChatPipeline):
     """A simple chat pipeline that uses our custom langchain chat model
     for our own request handler"""
 
     llm: IrisLangchainChatModel
@@ -19,7 +19,7 @@ def __init__(self, llm: IrisLangchainChatModel):
         self.pipeline = {"query": itemgetter("query")} | llm | StrOutputParser()
         super().__init__(implementation_id="simple_chat_pipeline")
 
-    def _run(self, query: IrisMessage) -> IrisMessage:
+    def __call__(self, query: IrisMessage, **kwargs) -> IrisMessage:
         """
         Gets a response from the langchain chat model
         """
diff --git a/app/pipeline/chat/tutor_chat_pipeline.py b/app/pipeline/chat/tutor_chat_pipeline.py
index 216a1e67..4e81ae58 100644
--- a/app/pipeline/chat/tutor_chat_pipeline.py
+++ b/app/pipeline/chat/tutor_chat_pipeline.py
@@ -7,12 +7,12 @@
 from domain import IrisMessage, IrisMessageRole
 from llm.langchain import IrisLangchainChatModel
-from pipeline.chat.chat_pipeline import ProgrammingExerciseTutorChatPipeline
+from pipeline.chat.chat_pipeline import ChatPipeline
 
 logger = logging.getLogger(__name__)
 
 
-class TutorChatPipelineReferenceImpl(ProgrammingExerciseTutorChatPipeline):
+class TutorChatPipelineReferenceImpl(ChatPipeline):
     """Tutor chat pipeline that answers exercises related questions from students."""
 
     llm: IrisLangchainChatModel
@@ -38,7 +38,7 @@ def __init__(self, llm: IrisLangchainChatModel):
         # Create the pipeline
         self.pipeline = prompt | llm | StrOutputParser()
 
-    def _run(self, query: IrisMessage) -> IrisMessage:
+    def __call__(self, query: IrisMessage, **kwargs) -> IrisMessage:
         """
         Runs the pipeline
         :param query: The query
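The chat pipeline changes above collapse the old template-method pair (a concrete __call__ that delegated to an abstract _run) into a single abstract __call__ on the renamed ChatPipeline base, so implementations are invoked directly as callables. Below is a minimal, self-contained sketch of the resulting pattern; the Pipeline and IrisMessage stand-ins and the EchoChatPipeline subclass are illustrative assumptions, not part of this diff (the real Pipeline base and domain classes live elsewhere in the repository).

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass


@dataclass
class IrisMessage:
    # Stand-in for the repository's domain.IrisMessage (assumed shape).
    role: str
    text: str


class Pipeline:
    # Stand-in for pipeline.Pipeline; the diff only shows that subclasses
    # pass an implementation_id to its constructor.
    def __init__(self, implementation_id: str):
        self.implementation_id = implementation_id


class ChatPipeline(Pipeline, metaclass=ABCMeta):
    """Abstract chat pipeline: implementations are called directly."""

    @abstractmethod
    def __call__(self, query: IrisMessage, **kwargs) -> IrisMessage:
        """Runs the pipeline and returns the response message."""


class EchoChatPipeline(ChatPipeline):
    """Hypothetical implementation, used here only to show the contract."""

    def __init__(self):
        super().__init__(implementation_id="echo_chat_pipeline")

    def __call__(self, query: IrisMessage, **kwargs) -> IrisMessage:
        # Echo the query back as the assistant's response.
        return IrisMessage(role="assistant", text=query.text)


pipeline = EchoChatPipeline()
print(pipeline(IrisMessage(role="user", text="Hello")).text)  # -> Hello

Because __call__ is now the abstract entry point, forgetting to implement it in a subclass fails at instantiation time rather than at first use.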
diff --git a/app/pipeline/shared/summary_pipeline.py b/app/pipeline/shared/summary_pipeline.py
index 2d54efee..aa7a8089 100644
--- a/app/pipeline/shared/summary_pipeline.py
+++ b/app/pipeline/shared/summary_pipeline.py
@@ -6,12 +6,12 @@
 from langchain_core.runnables import Runnable
 
 from llm.langchain import IrisLangchainChatModel
-from pipeline import AbstractPipeline
+from pipeline import Pipeline
 
 logger = logging.getLogger(__name__)
 
 
-class SummaryPipeline(AbstractPipeline):
+class SummaryPipeline(Pipeline):
     """A generic summary pipeline that can be used to summarize any text"""
 
     llm: IrisLangchainChatModel
@@ -19,8 +19,8 @@ class SummaryPipeline(AbstractPipeline):
     prompt_str: str
     prompt: ChatPromptTemplate
 
-    def __init__(self, llm: IrisLangchainChatModel, name=None):
-        super().__init__(name=name)
+    def __init__(self, llm: IrisLangchainChatModel):
+        super().__init__(implementation_id="summary_pipeline")
         # Set the langchain chat model
         self.llm = llm
         # Load the prompt from a file
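These pipelines compose their chains with langchain's pipe syntax, e.g. prompt | llm | StrOutputParser() in the tutor pipeline's constructor, and SummaryPipeline exposes the same prompt and Runnable attributes. A rough, stand-alone sketch of that composition, assuming a recent langchain_core is installed and substituting a RunnableLambda for the repository's IrisLangchainChatModel (the prompt wording is also assumed, since the real one is loaded from a file):

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda

# Prompt analogous to the one SummaryPipeline reads from disk (wording assumed).
prompt = ChatPromptTemplate.from_template("Summarize the following text:\n{text}")

# Stand-in for IrisLangchainChatModel: echoes the rendered prompt
# instead of calling a real model.
fake_llm = RunnableLambda(lambda prompt_value: prompt_value.to_string())

# Same composition shape as in the constructors above.
pipeline = prompt | fake_llm | StrOutputParser()
print(pipeline.invoke({"text": "Iris pipelines are composed from runnables."}))

Swapping the stand-in for the real IrisLangchainChatModel is the only change needed to turn this sketch into the shape the diff's constructors build.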