Use chat-gpt-variant instead of separate endpoint
wasnertobias committed Jan 15, 2025
1 parent 837494c commit 148a065
Showing 5 changed files with 56 additions and 90 deletions.
10 changes: 0 additions & 10 deletions app/domain/chat_gpt_wrapper_pipeline_execution_dto.py

This file was deleted.

11 changes: 0 additions & 11 deletions app/domain/status/chat_gpt_wrapper_status_update_dto.py

This file was deleted.

62 changes: 45 additions & 17 deletions app/pipeline/chat_gpt_wrapper_pipeline.py
@@ -1,28 +1,51 @@
import logging
from typing import Optional
from typing import List, Optional

from langchain_core.prompts import (
    ChatPromptTemplate,
)
from app.common.pyris_message import IrisMessageRole, PyrisMessage
from app.domain.chat.exercise_chat.exercise_chat_pipeline_execution_dto import ExerciseChatPipelineExecutionDTO
from app.domain.data.text_message_content_dto import TextMessageContentDTO
from app.pipeline.prompts.chat_gpt_wrapper_prompts import chat_gpt_initial_system_prompt
from langchain_core.messages import SystemMessage, HumanMessage

from app.domain.chat_gpt_wrapper_pipeline_execution_dto import (
    ChatGPTWrapperPipelineExecutionDTO,
)
from app.llm import CapabilityRequestHandler, RequirementList, CompletionArguments
from app.pipeline import Pipeline
from app.web.status.status_update import ChatGPTWrapperCallback
from app.web.status.status_update import ExerciseChatStatusCallback

logger = logging.getLogger(__name__)

def convert_chat_history_to_str(chat_history: List[PyrisMessage]) -> str:
    """
    Converts the chat history to a string
    :param chat_history: The chat history
    :return: The chat history as a string
    """

    def map_message_role(role: IrisMessageRole) -> str:
        if role == IrisMessageRole.SYSTEM:
            return "System"
        elif role == IrisMessageRole.ASSISTANT:
            return "AI Tutor"
        elif role == IrisMessageRole.USER:
            return "Student"
        else:
            return "Unknown"

    return "\n\n".join(
        [
            f"{map_message_role(message.sender)} "
            + ("" if not message.sent_at else f"at {message.sent_at.strftime('%Y-%m-%d %H:%M:%S')}")
            + f": {message.contents[0].text_content}"
            for message in chat_history
        ]
    )

class ChatGPTWrapperPipeline(Pipeline):
    callback: ChatGPTWrapperCallback
    callback: ExerciseChatStatusCallback
    request_handler: CapabilityRequestHandler

    def __init__(self, callback: Optional[ChatGPTWrapperCallback] = None):
    def __init__(self, callback: Optional[ExerciseChatStatusCallback] = None):
        super().__init__(implementation_id="chat_gpt_wrapper_pipeline_reference_impl")
        self.callback = callback
        self.request_handler = CapabilityRequestHandler(
@@ -35,25 +58,30 @@ def __init__(self, callback: Optional[ChatGPTWrapperCallback] = None):

    def __call__(
        self,
        dto: ChatGPTWrapperPipelineExecutionDTO,
        dto: ExerciseChatPipelineExecutionDTO,
        prompt: Optional[ChatPromptTemplate] = None,
        **kwargs,
    ):
        """
        Run the ChatGPT wrapper pipeline.
        This consists of a single response generation step.
        """
        if not dto.conversation:
            raise ValueError("Conversation with at least one message is required")

        pyris_system_prompt = PyrisMessage(
            sender=IrisMessageRole.SYSTEM,
            contents=[
                TextMessageContentDTO(text_content=chat_gpt_initial_system_prompt)
            ],
        )
        query = dto.chat_history[-1] if dto.chat_history else None
        if query and query.sender != IrisMessageRole.USER:
            query = None

        chat_history = (
            dto.chat_history[-5:] if query is None else dto.chat_history[-6:-1]
        )

        chat_history_messages = convert_chat_history_to_str(chat_history)

        prompts = [pyris_system_prompt] + dto.conversation
        prompts = ChatPromptTemplate.from_messages(
            [
                SystemMessage(chat_gpt_initial_system_prompt),
                HumanMessage(chat_history_messages),
            ]
        )

        response = self.request_handler.chat(
            prompts, CompletionArguments(temperature=0.4)
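For context, here is a minimal, self-contained sketch of the string format the new convert_chat_history_to_str helper produces. The Role, Content, and Message classes are hypothetical stand-ins for the app-internal IrisMessageRole and PyrisMessage types, trimmed to the fields the helper actually reads; only the formatting logic mirrors the diff above.

from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import List, Optional


class Role(Enum):
    SYSTEM = "system"
    ASSISTANT = "assistant"
    USER = "user"


@dataclass
class Content:
    text_content: str


@dataclass
class Message:
    sender: Role
    contents: List[Content]
    sent_at: Optional[datetime] = None


def chat_history_to_str(chat_history: List[Message]) -> str:
    # One block per message: "<role label> [at <timestamp>]: <text>", blocks separated by blank lines.
    labels = {Role.SYSTEM: "System", Role.ASSISTANT: "AI Tutor", Role.USER: "Student"}
    blocks = []
    for message in chat_history:
        label = labels.get(message.sender, "Unknown")
        timestamp = (
            f"at {message.sent_at.strftime('%Y-%m-%d %H:%M:%S')}" if message.sent_at else ""
        )
        blocks.append(f"{label} {timestamp}: {message.contents[0].text_content}")
    return "\n\n".join(blocks)


history = [
    Message(Role.USER, [Content("Why does my loop never terminate?")], datetime(2025, 1, 15, 9, 30)),
    Message(Role.ASSISTANT, [Content("Check whether the loop variable is ever updated.")]),
]
print(chat_history_to_str(history))
# Student at 2025-01-15 09:30:00: Why does my loop never terminate?
#
# AI Tutor : Check whether the loop variable is ever updated.

The pipeline then sends this single string as the HumanMessage of a two-message ChatPromptTemplate, with chat_gpt_initial_system_prompt as the SystemMessage, instead of forwarding the raw conversation DTO.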
34 changes: 11 additions & 23 deletions app/web/routers/pipelines.py
@@ -21,7 +20,6 @@
    CourseChatStatusCallback,
    CompetencyExtractionCallback,
    LectureChatCallback,
    ChatGPTWrapperCallback,
)
from app.pipeline.chat.course_chat_pipeline import CourseChatPipeline
from app.dependencies import TokenValidator
@@ -32,9 +31,6 @@
)
from app.pipeline.text_exercise_chat_pipeline import TextExerciseChatPipeline
from app.web.status.status_update import TextExerciseChatCallback
from app.domain.chat_gpt_wrapper_pipeline_execution_dto import (
    ChatGPTWrapperPipelineExecutionDTO,
)
from app.pipeline.chat_gpt_wrapper_pipeline import ChatGPTWrapperPipeline

router = APIRouter(prefix="/api/v1/pipelines", tags=["pipelines"])
@@ -79,9 +75,12 @@ def run_exercise_chat_pipeline(
description="Exercise Chat Pipeline Execution DTO"
),
):
thread = Thread(
target=run_exercise_chat_pipeline_worker, args=(dto, variant, event)
)
if variant == "default":
thread = Thread(
target=run_exercise_chat_pipeline_worker, args=(dto, variant, event)
)
else:
thread = Thread(target=run_chatgpt_wrapper_pipeline_worker, args=(dto, variant))
thread.start()


@@ -238,13 +237,13 @@ def run_competency_extraction_pipeline(


def run_chatgpt_wrapper_pipeline_worker(
    dto: ChatGPTWrapperPipelineExecutionDTO, _variant: str
    dto: ExerciseChatPipelineExecutionDTO, _variant: str
):
    try:
        callback = ChatGPTWrapperCallback(
            run_id=dto.execution.settings.authentication_token,
            base_url=dto.execution.settings.artemis_base_url,
            initial_stages=dto.execution.initial_stages,
        callback = ExerciseChatStatusCallback(
            run_id=dto.settings.authentication_token,
            base_url=dto.settings.artemis_base_url,
            initial_stages=dto.initial_stages,
        )
        pipeline = ChatGPTWrapperPipeline(callback=callback)
    except Exception as e:
@@ -260,17 +259,6 @@ def run_chatgpt_wrapper_pipeline_worker(
        logger.error(traceback.format_exc())
        callback.error("Fatal error.", exception=e)


@router.post(
    "/chat-gpt-wrapper/{variant}/run",
    status_code=status.HTTP_202_ACCEPTED,
    dependencies=[Depends(TokenValidator())],
)
def run_chatgpt_wrapper_pipeline(variant: str, dto: ChatGPTWrapperPipelineExecutionDTO):
    thread = Thread(target=run_chatgpt_wrapper_pipeline_worker, args=(dto, variant))
    thread.start()


@router.get("/{feature}/variants")
def get_pipeline(feature: str):
"""
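A compact sketch of the dispatch shape this commit moves to: the existing exercise-chat endpoint fans out by variant instead of exposing a dedicated /chat-gpt-wrapper route. The DTO, the worker bodies, and the route path below are illustrative stand-ins, not the repository's real definitions; only the branching on variant == "default" mirrors the diff.

from threading import Thread

from fastapi import APIRouter, status
from pydantic import BaseModel

router = APIRouter(prefix="/api/v1/pipelines", tags=["pipelines"])


class ExerciseChatDTO(BaseModel):
    # Hypothetical, trimmed stand-in for ExerciseChatPipelineExecutionDTO.
    question: str


def run_exercise_chat_pipeline_worker(dto: ExerciseChatDTO, variant: str, event: str | None) -> None:
    print(f"default exercise chat pipeline: {dto.question} (variant={variant}, event={event})")


def run_chatgpt_wrapper_pipeline_worker(dto: ExerciseChatDTO, variant: str) -> None:
    print(f"ChatGPT wrapper pipeline: {dto.question} (variant={variant})")


@router.post("/exercise-chat/{variant}/run", status_code=status.HTTP_202_ACCEPTED)
def run_exercise_chat_pipeline(variant: str, dto: ExerciseChatDTO, event: str | None = None):
    # "default" keeps the original worker; any other variant (e.g. a chat-gpt
    # wrapper variant) is handed to the wrapper worker from the same endpoint.
    if variant == "default":
        thread = Thread(target=run_exercise_chat_pipeline_worker, args=(dto, variant, event))
    else:
        thread = Thread(target=run_chatgpt_wrapper_pipeline_worker, args=(dto, variant))
    thread.start()

Both workers run on a background thread so the endpoint can return 202 Accepted immediately and report progress through the status callback instead of the HTTP response.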
29 changes: 0 additions & 29 deletions app/web/status/status_update.py
@@ -24,9 +24,6 @@
from app.domain.chat.exercise_chat.exercise_chat_status_update_dto import (
    ExerciseChatStatusUpdateDTO,
)
from app.domain.status.chat_gpt_wrapper_status_update_dto import (
    ChatGPTWrapperStatusUpdateDTO,
)
from app.domain.status.status_update_dto import StatusUpdateDTO
import logging

@@ -303,29 +300,3 @@ def __init__(
            stages[stage],
            stage,
        )


class ChatGPTWrapperCallback(StatusCallback):
    def __init__(
        self,
        run_id: str,
        base_url: str,
        initial_stages: List[StageDTO],
    ):
        url = f"{base_url}/api/public/pyris/pipelines/chat-gpt-wrapper/runs/{run_id}/status"
        stages = initial_stages or []
        stage = len(stages)
        stages += [
            StageDTO(
                weight=30,
                state=StageStateEnum.NOT_STARTED,
                name="Thinking",
            ),
        ]
        super().__init__(
            url,
            run_id,
            ChatGPTWrapperStatusUpdateDTO(stages=stages, result=""),
            stages[stage],
            stage,
        )
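The removed class shows the stage bookkeeping these callbacks perform: keep whatever initial stages Artemis sent, append one "Thinking" stage for the pipeline itself, and remember its index so later status updates target it. A minimal sketch of that pattern, with hypothetical trimmed stand-ins for StageDTO and StageStateEnum:

from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Tuple


class StageState(Enum):
    NOT_STARTED = "not_started"


@dataclass
class Stage:
    weight: int
    state: StageState
    name: str


def build_stages(initial_stages: Optional[List[Stage]]) -> Tuple[List[Stage], int]:
    # Keep the caller-provided stages, append the pipeline's own "Thinking" stage,
    # and return that stage's index so progress updates can be sent against it.
    stages = list(initial_stages or [])
    own_stage_index = len(stages)
    stages.append(Stage(weight=30, state=StageState.NOT_STARTED, name="Thinking"))
    return stages, own_stage_index


stages, index = build_stages([Stage(weight=10, state=StageState.NOT_STARTED, name="Preparing")])
print(index, [stage.name for stage in stages])  # 1 ['Preparing', 'Thinking']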
