Skip to content

Commit

Permalink
Naming changes and import bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
kaancayli committed Feb 19, 2024
1 parent 47c49e5 commit aba2d25
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 17 deletions.
2 changes: 1 addition & 1 deletion app/pipeline/chat/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from simple_chat_pipeline import SimpleChatPipeline
from pipeline.chat.simple_chat_pipeline import SimpleChatPipeline
9 changes: 3 additions & 6 deletions app/pipeline/chat/chat_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
from abc import ABC, abstractmethod
from abc import abstractmethod, ABCMeta

from domain import IrisMessage
from pipeline import Pipeline


class ProgrammingExerciseTutorChatPipeline(Pipeline, ABC):
class ChatPipeline(Pipeline, metaclass=ABCMeta):
"""
Abstract class for the programming exercise tutor chat pipeline implementations.
This class defines the signature of all implementations of this Iris feature.
"""

def __call__(self, query: IrisMessage, **kwargs) -> IrisMessage:
return self._run(query)

@abstractmethod
def _run(self, query: IrisMessage) -> IrisMessage:
def __call__(self, query: IrisMessage, **kwargs) -> IrisMessage:
"""
Runs the pipeline and returns the response message.
"""
Expand Down
6 changes: 3 additions & 3 deletions app/pipeline/chat/simple_chat_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

from domain import IrisMessage, IrisMessageRole
from llm.langchain import IrisLangchainChatModel
from pipeline.chat.chat_pipeline import ProgrammingExerciseTutorChatPipeline
from pipeline.chat.chat_pipeline import ChatPipeline


class SimpleChatPipeline(ProgrammingExerciseTutorChatPipeline):
class SimpleChatPipeline(ChatPipeline):
"""A simple chat pipeline that uses our custom langchain chat model for our own request handler"""

llm: IrisLangchainChatModel
Expand All @@ -19,7 +19,7 @@ def __init__(self, llm: IrisLangchainChatModel):
self.pipeline = {"query": itemgetter("query")} | llm | StrOutputParser()
super().__init__(implementation_id="simple_chat_pipeline")

def _run(self, query: IrisMessage) -> IrisMessage:
def __call__(self, query: IrisMessage, **kwargs) -> IrisMessage:
"""
Gets a response from the langchain chat model
"""
Expand Down
6 changes: 3 additions & 3 deletions app/pipeline/chat/tutor_chat_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@

from domain import IrisMessage, IrisMessageRole
from llm.langchain import IrisLangchainChatModel
from pipeline.chat.chat_pipeline import ProgrammingExerciseTutorChatPipeline
from pipeline.chat.chat_pipeline import ChatPipeline

logger = logging.getLogger(__name__)


class TutorChatPipelineReferenceImpl(ProgrammingExerciseTutorChatPipeline):
class TutorChatPipelineReferenceImpl(ChatPipeline):
"""Tutor chat pipeline that answers exercises related questions from students."""

llm: IrisLangchainChatModel
Expand All @@ -38,7 +38,7 @@ def __init__(self, llm: IrisLangchainChatModel):
# Create the pipeline
self.pipeline = prompt | llm | StrOutputParser()

def _run(self, query: IrisMessage) -> IrisMessage:
def __call__(self, query: IrisMessage, **kwargs) -> IrisMessage:
"""
Runs the pipeline
:param query: The query
Expand Down
8 changes: 4 additions & 4 deletions app/pipeline/shared/summary_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,21 @@
from langchain_core.runnables import Runnable

from llm.langchain import IrisLangchainChatModel
from pipeline import AbstractPipeline
from pipeline import Pipeline

logger = logging.getLogger(__name__)


class SummaryPipeline(AbstractPipeline):
class SummaryPipeline(Pipeline):
"""A generic summary pipeline that can be used to summarize any text"""

llm: IrisLangchainChatModel
pipeline: Runnable
prompt_str: str
prompt: ChatPromptTemplate

def __init__(self, llm: IrisLangchainChatModel, name=None):
super().__init__(name=name)
def __init__(self, llm: IrisLangchainChatModel):
super().__init__(implementation_id="summary_pipeline")
# Set the langchain chat model
self.llm = llm
# Load the prompt from a file
Expand Down

0 comments on commit aba2d25

Please sign in to comment.