From 5c82d5b86b2f5332fb5d58c7a73e220e6f6ddefa Mon Sep 17 00:00:00 2001 From: Sebastian Loose Date: Wed, 27 Nov 2024 20:45:53 +0100 Subject: [PATCH] Pipeline: Add lecture chat pipeline connection (#173) --- .../status/lecture_chat_status_update_dto.py | 12 +++++ app/pipeline/chat/lecture_chat_pipeline.py | 28 ++++++++-- app/vector_database/lecture_schema.py | 30 +++++++++-- app/web/routers/pipelines.py | 51 +++++++++++++++++++ app/web/status/status_update.py | 29 +++++++++++ 5 files changed, 142 insertions(+), 8 deletions(-) create mode 100644 app/domain/status/lecture_chat_status_update_dto.py diff --git a/app/domain/status/lecture_chat_status_update_dto.py b/app/domain/status/lecture_chat_status_update_dto.py new file mode 100644 index 00000000..0323c885 --- /dev/null +++ b/app/domain/status/lecture_chat_status_update_dto.py @@ -0,0 +1,12 @@ +from app.domain.status.status_update_dto import StatusUpdateDTO + + +class LectureChatStatusUpdateDTO(StatusUpdateDTO): + """Data Transfer Object for lecture chat status updates. + This DTO extends the base StatusUpdateDTO to include the result of lecture chat + pipeline operations, facilitating communication between Artemis and the lecture + chat system. + """ + + result: str + """The result message or status of the lecture chat pipeline operation.""" diff --git a/app/pipeline/chat/lecture_chat_pipeline.py b/app/pipeline/chat/lecture_chat_pipeline.py index 7d3d5312..816f531b 100644 --- a/app/pipeline/chat/lecture_chat_pipeline.py +++ b/app/pipeline/chat/lecture_chat_pipeline.py @@ -25,6 +25,7 @@ from ...llm.langchain import IrisLangchainChatModel from ..pipeline import Pipeline +from ...web.status.status_update import LectureChatCallback logger = logging.getLogger(__name__) @@ -57,17 +58,28 @@ class LectureChatPipeline(Pipeline): llm: IrisLangchainChatModel pipeline: Runnable prompt: ChatPromptTemplate + callback: LectureChatCallback - def __init__(self): + def __init__( + self, + callback: LectureChatCallback, + dto: LectureChatPipelineExecutionDTO, + variant: str, + ): super().__init__(implementation_id="lecture_chat_pipeline") # Set the langchain chat model request_handler = CapabilityRequestHandler( requirements=RequirementList( - gpt_version_equivalent=3.5, + gpt_version_equivalent=4.5, context_length=16385, privacy_compliance=True, ) ) + + self.callback = callback + self.dto = dto + self.variant = variant + completion_args = CompletionArguments(temperature=0, max_tokens=2000) self.llm = IrisLangchainChatModel( request_handler=request_handler, completion_args=completion_args @@ -91,7 +103,6 @@ def __call__(self, dto: LectureChatPipelineExecutionDTO): Runs the pipeline :param dto: execution data transfer object """ - self.prompt = ChatPromptTemplate.from_messages( [ ("system", lecture_initial_prompt()), @@ -124,8 +135,17 @@ def __call__(self, dto: LectureChatPipelineExecutionDTO): ) self.tokens.extend(self.citation_pipeline.tokens) logger.info(f"Response from lecture chat pipeline: {response}") - return response_with_citation + self.callback.done( + "Response created", + final_result=response_with_citation, + tokens=self.tokens, + ) except Exception as e: + self.callback.error( + "Generating interaction suggestions failed.", + exception=e, + tokens=self.tokens, + ) raise e def _add_conversation_to_prompt( diff --git a/app/vector_database/lecture_schema.py b/app/vector_database/lecture_schema.py index 8c998c85..456dce91 100644 --- a/app/vector_database/lecture_schema.py +++ b/app/vector_database/lecture_schema.py @@ -33,12 +33,26 @@ def init_lecture_schema(client: WeaviateClient) -> Collection: if client.collections.exists(LectureSchema.COLLECTION_NAME.value): collection = client.collections.get(LectureSchema.COLLECTION_NAME.value) properties = collection.config.get(simple=True).properties + + # Check and add 'course_language' property if missing + if not any( + property.name == LectureSchema.COURSE_LANGUAGE.value + for property in collection.config.get(simple=False).properties + ): + collection.config.add_property( + Property( + name=LectureSchema.COURSE_LANGUAGE.value, + description="The language of the COURSE", + data_type=DataType.TEXT, + index_searchable=False, + ) + ) + + # Check and add 'lecture_unit_link' property if missing if not any( - property.__name__ == LectureSchema.LECTURE_UNIT_LINK.value - for property_found in properties + property.name == LectureSchema.LECTURE_UNIT_LINK.value + for property in properties ): - return collection - else: collection.config.add_property( Property( name=LectureSchema.LECTURE_UNIT_LINK.value, @@ -47,7 +61,9 @@ def init_lecture_schema(client: WeaviateClient) -> Collection: index_searchable=False, ) ) + return collection + return client.collections.create( name=LectureSchema.COLLECTION_NAME.value, vectorizer_config=Configure.Vectorizer.none(), @@ -73,6 +89,12 @@ def init_lecture_schema(client: WeaviateClient) -> Collection: data_type=DataType.TEXT, index_searchable=False, ), + Property( + name=LectureSchema.COURSE_LANGUAGE.value, + description="The language of the COURSE", + data_type=DataType.TEXT, + index_searchable=False, + ), Property( name=LectureSchema.LECTURE_ID.value, description="The ID of the lecture", diff --git a/app/web/routers/pipelines.py b/app/web/routers/pipelines.py index 4ef8880a..fbc3c9f3 100644 --- a/app/web/routers/pipelines.py +++ b/app/web/routers/pipelines.py @@ -12,10 +12,15 @@ CompetencyExtractionPipelineExecutionDTO, ) from app.pipeline.chat.exercise_chat_agent_pipeline import ExerciseChatAgentPipeline +from app.domain.chat.lecture_chat.lecture_chat_pipeline_execution_dto import ( + LectureChatPipelineExecutionDTO, +) +from app.pipeline.chat.lecture_chat_pipeline import LectureChatPipeline from app.web.status.status_update import ( ExerciseChatStatusCallback, CourseChatStatusCallback, CompetencyExtractionCallback, + LectureChatCallback, ) from app.pipeline.chat.course_chat_pipeline import CourseChatPipeline from app.dependencies import TokenValidator @@ -139,6 +144,34 @@ def run_text_exercise_chat_pipeline_worker(dto, variant): callback.error("Fatal error.", exception=e) +def run_lecture_chat_pipeline_worker(dto, variant): + try: + callback = LectureChatCallback( + run_id=dto.settings.authentication_token, + base_url=dto.settings.artemis_base_url, + initial_stages=dto.initial_stages, + ) + match variant: + case "default" | "lecture_chat_pipeline_reference_impl": + pipeline = LectureChatPipeline( + callback=callback, dto=dto, variant=variant + ) + case _: + raise ValueError(f"Unknown variant: {variant}") + except Exception as e: + logger.error(f"Error preparing lecture chat pipeline: {e}") + logger.error(traceback.format_exc()) + capture_exception(e) + return + + try: + pipeline(dto=dto) + except Exception as e: + logger.error(f"Error running lecture chat pipeline: {e}") + logger.error(traceback.format_exc()) + callback.error("Fatal error.", exception=e) + + @router.post( "/text-exercise-chat/{variant}/run", status_code=status.HTTP_202_ACCEPTED, @@ -151,6 +184,16 @@ def run_text_exercise_chat_pipeline( thread.start() +@router.post( + "/lecture-chat/{variant}/run", + status_code=status.HTTP_202_ACCEPTED, + dependencies=[Depends(TokenValidator())], +) +def run_lecture_chat_pipeline(variant: str, dto: LectureChatPipelineExecutionDTO): + thread = Thread(target=run_lecture_chat_pipeline_worker, args=(dto, variant)) + thread.start() + + def run_competency_extraction_pipeline_worker( dto: CompetencyExtractionPipelineExecutionDTO, _variant: str ): @@ -243,5 +286,13 @@ def get_pipeline(feature: str): description="Default lecture ingestion variant.", ) ] + case "LECTURE_CHAT": + return [ + FeatureDTO( + id="default", + name="Default Variant", + description="Default lecture chat variant.", + ) + ] case _: return Response(status_code=status.HTTP_400_BAD_REQUEST) diff --git a/app/web/status/status_update.py b/app/web/status/status_update.py index bfd5e889..874c5884 100644 --- a/app/web/status/status_update.py +++ b/app/web/status/status_update.py @@ -12,6 +12,9 @@ from app.domain.chat.course_chat.course_chat_status_update_dto import ( CourseChatStatusUpdateDTO, ) +from app.domain.status.lecture_chat_status_update_dto import ( + LectureChatStatusUpdateDTO, +) from app.domain.status.stage_state_dto import StageStateEnum from app.domain.status.stage_dto import StageDTO from app.domain.status.text_exercise_chat_status_update_dto import ( @@ -267,3 +270,29 @@ def __init__( status = CompetencyExtractionStatusUpdateDTO(stages=stages) stage = stages[-1] super().__init__(url, run_id, status, stage, len(stages) - 1) + + +class LectureChatCallback(StatusCallback): + def __init__( + self, + run_id: str, + base_url: str, + initial_stages: List[StageDTO], + ): + url = f"{base_url}/api/public/pyris/pipelines/lecture-chat/runs/{run_id}/status" + stages = initial_stages or [] + stage = len(stages) + stages += [ + StageDTO( + weight=30, + state=StageStateEnum.NOT_STARTED, + name="Thinking", + ), + ] + super().__init__( + url, + run_id, + LectureChatStatusUpdateDTO(stages=stages, result=""), + stages[stage], + stage, + )