From bae6252d368fcbe2878a45f43a5c51ae6d9937ca Mon Sep 17 00:00:00 2001 From: Michael Dyer Date: Tue, 3 Sep 2024 17:21:21 +0200 Subject: [PATCH] Initial commit --- app/domain/data/text_exercise_dto.py | 15 +++++ ...xt_exercise_chat_pipeline_execution_dto.py | 10 ++++ .../prompts/text_exercise_chat_prompts.py | 35 +++++++++++ app/pipeline/text_exercise_chat_pipeline.py | 59 +++++++++++++++++++ app/web/routers/pipelines.py | 43 ++++++++++++++ app/web/status/status_update.py | 22 +++++++ 6 files changed, 184 insertions(+) create mode 100644 app/domain/data/text_exercise_dto.py create mode 100644 app/domain/text_exercise_chat_pipeline_execution_dto.py create mode 100644 app/pipeline/prompts/text_exercise_chat_prompts.py create mode 100644 app/pipeline/text_exercise_chat_pipeline.py diff --git a/app/domain/data/text_exercise_dto.py b/app/domain/data/text_exercise_dto.py new file mode 100644 index 00000000..7040b181 --- /dev/null +++ b/app/domain/data/text_exercise_dto.py @@ -0,0 +1,15 @@ +from datetime import datetime +from typing import Optional + +from pydantic import BaseModel, Field + +from domain.data.course_dto import CourseDTO + + +class TextExerciseDTO(BaseModel): + id: int + name: str + course: CourseDTO + problem_statement: str = Field(alias="problemStatement") + start_date: Optional[datetime] = Field(alias="startDate", default=None) + end_date: Optional[datetime] = Field(alias="endDate", default=None) diff --git a/app/domain/text_exercise_chat_pipeline_execution_dto.py b/app/domain/text_exercise_chat_pipeline_execution_dto.py new file mode 100644 index 00000000..03ff7c19 --- /dev/null +++ b/app/domain/text_exercise_chat_pipeline_execution_dto.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel, Field + +from domain import PipelineExecutionDTO +from domain.data.text_exercise_dto import TextExerciseDTO + + +class TextExerciseChatPipelineExecutionDTO(BaseModel): + execution: PipelineExecutionDTO + exercise: TextExerciseDTO + current_answer: str = Field(alias="currentAnswer") diff --git a/app/pipeline/prompts/text_exercise_chat_prompts.py b/app/pipeline/prompts/text_exercise_chat_prompts.py new file mode 100644 index 00000000..390cb954 --- /dev/null +++ b/app/pipeline/prompts/text_exercise_chat_prompts.py @@ -0,0 +1,35 @@ +def system_prompt( + exercise_name: str, + course_name: str, + course_description: str, + problem_statement: str, + start_date: str, + end_date: str, + current_date: str, + current_answer: str, +) -> str: + return """ + The student is working on a free-response exercise called '{exercise_name}' in the course '{course_name}'. + The course has the following description: + {course_description} + + The exercise has the following problem statement: + {problem_statement} + + The exercise began on {start_date} and will end on {end_date}. The current date is {current_date}. + + This is what the student has written so far: + {current_answer} + + You are a writing tutor. Provide feedback to the student on their response, + giving specific tips to better answer the problem statement. + """.format( + exercise_name=exercise_name, + course_name=course_name, + course_description=course_description, + problem_statement=problem_statement, + start_date=start_date, + end_date=end_date, + current_date=current_date, + current_answer=current_answer, + ) diff --git a/app/pipeline/text_exercise_chat_pipeline.py b/app/pipeline/text_exercise_chat_pipeline.py new file mode 100644 index 00000000..253504a2 --- /dev/null +++ b/app/pipeline/text_exercise_chat_pipeline.py @@ -0,0 +1,59 @@ +import logging +from datetime import datetime +from typing import Optional + +from app.llm import CapabilityRequestHandler, RequirementList, CompletionArguments +from app.pipeline import Pipeline +from domain import PyrisMessage, IrisMessageRole +from domain.data.text_message_content_dto import TextMessageContentDTO +from domain.text_exercise_chat_pipeline_execution_dto import ( + TextExerciseChatPipelineExecutionDTO, +) +from pipeline.prompts.text_exercise_chat_prompts import system_prompt +from web.status.status_update import TextExerciseChatCallback + +logger = logging.getLogger(__name__) + + +class TextExerciseChatPipeline(Pipeline): + callback: TextExerciseChatCallback + request_handler: CapabilityRequestHandler + + def __init__(self, callback: Optional[TextExerciseChatCallback] = None): + super().__init__(implementation_id="text_exercise_chat_pipeline_reference_impl") + self.callback = callback + self.request_handler = CapabilityRequestHandler( + requirements=RequirementList(context_length=8000) + ) + + def __call__( + self, + dto: TextExerciseChatPipelineExecutionDTO, + **kwargs, + ): + if not dto.exercise: + raise ValueError("Exercise is required") + + prompt = system_prompt( + exercise_name=dto.exercise.name, + course_name=dto.exercise.course.name, + course_description=dto.exercise.course.description, + problem_statement=dto.exercise.problem_statement, + start_date=str(dto.exercise.start_date), + end_date=str(dto.exercise.end_date), + current_date=str(datetime.now()), + current_answer=dto.current_answer, + ) + prompt = PyrisMessage( + sender=IrisMessageRole.SYSTEM, + contents=[TextMessageContentDTO(text_content=prompt)], + ) + + # done building prompt + + response = self.request_handler.chat( + [prompt], CompletionArguments(temperature=0.4) + ) + response = response.contents[0].text_content + + self.callback.done(response) diff --git a/app/web/routers/pipelines.py b/app/web/routers/pipelines.py index eb198199..fae60926 100644 --- a/app/web/routers/pipelines.py +++ b/app/web/routers/pipelines.py @@ -21,6 +21,11 @@ from app.dependencies import TokenValidator from app.domain import FeatureDTO from app.pipeline.competency_extraction_pipeline import CompetencyExtractionPipeline +from domain.text_exercise_chat_pipeline_execution_dto import ( + TextExerciseChatPipelineExecutionDTO, +) +from pipeline.text_exercise_chat_pipeline import TextExerciseChatPipeline +from web.status.status_update import TextExerciseChatCallback router = APIRouter(prefix="/api/v1/pipelines", tags=["pipelines"]) logger = logging.getLogger(__name__) @@ -90,6 +95,44 @@ def run_course_chat_pipeline(variant: str, dto: CourseChatPipelineExecutionDTO): thread.start() +def run_text_exercise_chat_pipeline_worker(dto, variant): + try: + callback = TextExerciseChatCallback( + run_id=dto.settings.authentication_token, + base_url=dto.settings.artemis_base_url, + initial_stages=dto.initial_stages, + ) + match variant: + case "default" | "text_exercise_chat_pipeline_reference_impl": + pipeline = TextExerciseChatPipeline(callback=callback) + case _: + raise ValueError(f"Unknown variant: {variant}") + except Exception as e: + logger.error(f"Error preparing text exercise chat pipeline: {e}") + logger.error(traceback.format_exc()) + capture_exception(e) + return + + try: + pipeline(dto=dto) + except Exception as e: + logger.error(f"Error running text exercise chat pipeline: {e}") + logger.error(traceback.format_exc()) + callback.error("Fatal error.", exception=e) + + +@router.post( + "/text-exercise-chat/{variant}/run", + status_code=status.HTTP_202_ACCEPTED, + dependencies=[Depends(TokenValidator())], +) +def run_text_exercise_chat_pipeline( + variant: str, dto: TextExerciseChatPipelineExecutionDTO +): + thread = Thread(target=run_text_exercise_chat_pipeline_worker, args=(dto, variant)) + thread.start() + + def run_competency_extraction_pipeline_worker( dto: CompetencyExtractionPipelineExecutionDTO, _variant: str ): diff --git a/app/web/status/status_update.py b/app/web/status/status_update.py index 1f497f75..1ddf1ca9 100644 --- a/app/web/status/status_update.py +++ b/app/web/status/status_update.py @@ -218,6 +218,28 @@ def __init__( super().__init__(url, run_id, status, stage, current_stage_index) +class TextExerciseChatCallback(StatusCallback): + def __init__( + self, + run_id: str, + base_url: str, + initial_stages: List[StageDTO], + ): + url = f"{base_url}/api/public/pyris/pipelines/text-exercise-chat/runs/{run_id}/status" + stages = initial_stages or [] + stage = len(stages) + stages += [ + StageDTO( + weight=40, + state=StageStateEnum.NOT_STARTED, + name="Thinking", + ) + ] + super().__init__( + url, run_id, StatusUpdateDTO(stages=stages), stages[stage], stage + ) + + class CompetencyExtractionCallback(StatusCallback): def __init__( self,