From 6891e3c3d6c9b1ad8ec0462ed068cb2104d7089e Mon Sep 17 00:00:00 2001 From: Sebastian Loose Date: Thu, 31 Oct 2024 16:03:23 +0100 Subject: [PATCH] Add lecture pipeline connection --- .../status/lecture_chat_status_update_dto.py | 5 +++ .../request_handler/basic_request_handler.py | 2 + app/pipeline/chat/lecture_chat_pipeline.py | 19 ++++++++- app/retrieval/lecture_retrieval.py | 4 ++ app/vector_database/lecture_schema.py | 20 +++++++++- app/web/routers/pipelines.py | 39 ++++++++++++++++--- app/web/status/status_update.py | 35 +++++++++++++++++ 7 files changed, 116 insertions(+), 8 deletions(-) create mode 100644 app/domain/status/lecture_chat_status_update_dto.py diff --git a/app/domain/status/lecture_chat_status_update_dto.py b/app/domain/status/lecture_chat_status_update_dto.py new file mode 100644 index 00000000..3eb40a31 --- /dev/null +++ b/app/domain/status/lecture_chat_status_update_dto.py @@ -0,0 +1,5 @@ +from app.domain.status.status_update_dto import StatusUpdateDTO + + +class LectureChatStatusUpdateDTO(StatusUpdateDTO): + result: str diff --git a/app/llm/request_handler/basic_request_handler.py b/app/llm/request_handler/basic_request_handler.py index 1342a71c..5392bd75 100644 --- a/app/llm/request_handler/basic_request_handler.py +++ b/app/llm/request_handler/basic_request_handler.py @@ -14,6 +14,7 @@ class BasicRequestHandler(RequestHandler): def __init__(self, model_id: str): self.model_id = model_id self.llm_manager = LlmManager() + print("llm manager" + str(self.llm_manager.entries)) def complete( self, @@ -32,4 +33,5 @@ def chat( def embed(self, text: str) -> list[float]: llm = self.llm_manager.get_llm_by_id(self.model_id) + print("llm uiuiu" + str(llm)) return llm.embed(text) diff --git a/app/pipeline/chat/lecture_chat_pipeline.py b/app/pipeline/chat/lecture_chat_pipeline.py index 22eb8c7a..273c7d3f 100644 --- a/app/pipeline/chat/lecture_chat_pipeline.py +++ b/app/pipeline/chat/lecture_chat_pipeline.py @@ -25,6 +25,7 @@ from ...llm.langchain import IrisLangchainChatModel from ..pipeline import Pipeline +from ...web.status.status_update import LectureChatCallback logger = logging.getLogger(__name__) @@ -55,8 +56,9 @@ class LectureChatPipeline(Pipeline): llm: IrisLangchainChatModel pipeline: Runnable prompt: ChatPromptTemplate + callback: LectureChatCallback - def __init__(self): + def __init__(self, callback: LectureChatCallback): super().__init__(implementation_id="lecture_chat_pipeline") # Set the langchain chat model request_handler = CapabilityRequestHandler( @@ -66,6 +68,9 @@ def __init__(self): privacy_compliance=True, ) ) + + self.callback = callback + completion_args = CompletionArguments(temperature=0, max_tokens=2000) self.llm = IrisLangchainChatModel( request_handler=request_handler, completion_args=completion_args @@ -89,7 +94,6 @@ def __call__(self, dto: LectureChatPipelineExecutionDTO): Runs the pipeline :param dto: execution data transfer object """ - self.prompt = ChatPromptTemplate.from_messages( [ ("system", lecture_initial_prompt()), @@ -115,7 +119,13 @@ def __call__(self, dto: LectureChatPipelineExecutionDTO): prompt_val = self.prompt.format_messages() self.prompt = ChatPromptTemplate.from_messages(prompt_val) try: + self.callback.in_progress() response = (self.prompt | self.pipeline).invoke({}) + self.callback.done( + "Generated response", + final_result=response, + tokens=self.tokens, + ) self._append_tokens(self.llm.tokens, PipelineEnum.IRIS_CHAT_LECTURE_MESSAGE) response_with_citation = self.citation_pipeline( retrieved_lecture_chunks, response @@ -124,6 +134,11 @@ def __call__(self, dto: LectureChatPipelineExecutionDTO): logger.info(f"Response from lecture chat pipeline: {response}") return response_with_citation except Exception as e: + self.callback.error( + "Generating interaction suggestions failed.", + exception=e, + tokens=self.tokens, + ) raise e def _add_conversation_to_prompt( diff --git a/app/retrieval/lecture_retrieval.py b/app/retrieval/lecture_retrieval.py index 7bcd8ce0..17ef723a 100644 --- a/app/retrieval/lecture_retrieval.py +++ b/app/retrieval/lecture_retrieval.py @@ -512,6 +512,10 @@ def fetch_course_language(self, course_id): """ course_language = "english" + print(self.collection) + print(self.llm_embedding) + print(self.collection.query) + if course_id: # Fetch the first object that matches the course ID with the language property result = self.collection.query.fetch_objects( diff --git a/app/vector_database/lecture_schema.py b/app/vector_database/lecture_schema.py index 912abed7..a1305f2b 100644 --- a/app/vector_database/lecture_schema.py +++ b/app/vector_database/lecture_schema.py @@ -30,7 +30,19 @@ def init_lecture_schema(client: WeaviateClient) -> Collection: Initialize the schema for the lecture slides """ if client.collections.exists(LectureSchema.COLLECTION_NAME.value): - return client.collections.get(LectureSchema.COLLECTION_NAME.value) + collection = client.collections.get(LectureSchema.COLLECTION_NAME.value) + + # collection.config.add_property( + # Property( + # name=LectureSchema.COURSE_LANGUAGE.value, + # description="The language of the COURSE", + # data_type=DataType.TEXT, + # index_searchable=False, + # ) + # ) + + return collection + return client.collections.create( name=LectureSchema.COLLECTION_NAME.value, vectorizer_config=Configure.Vectorizer.none(), @@ -56,6 +68,12 @@ def init_lecture_schema(client: WeaviateClient) -> Collection: data_type=DataType.TEXT, index_searchable=False, ), + Property( + name=LectureSchema.COURSE_LANGUAGE.value, + description="The language of the COURSE", + data_type=DataType.TEXT, + index_searchable=False, + ), Property( name=LectureSchema.LECTURE_ID.value, description="The ID of the lecture", diff --git a/app/web/routers/pipelines.py b/app/web/routers/pipelines.py index fe4bfc39..ba4cfc17 100644 --- a/app/web/routers/pipelines.py +++ b/app/web/routers/pipelines.py @@ -12,10 +12,11 @@ CompetencyExtractionPipelineExecutionDTO, ) from app.domain.chat.lecture_chat.lecture_chat_pipeline_execution_dto import LectureChatPipelineExecutionDTO +from app.pipeline.chat.lecture_chat_pipeline import LectureChatPipeline from app.web.status.status_update import ( ExerciseChatStatusCallback, CourseChatStatusCallback, - CompetencyExtractionCallback, + CompetencyExtractionCallback, LectureChatCallback, ) from app.pipeline.chat.course_chat_pipeline import CourseChatPipeline from app.pipeline.chat.exercise_chat_pipeline import ExerciseChatPipeline @@ -122,17 +123,44 @@ def run_text_exercise_chat_pipeline_worker(dto, variant): callback.error("Fatal error.", exception=e) +def run_lecture_chat_pipeline_worker(dto, variant): + try: + callback = LectureChatCallback( + run_id=dto.settings.authentication_token, + base_url=dto.settings.artemis_base_url, + initial_stages=dto.initial_stages, + ) + match variant: + case "default" | "lecture_chat_pipeline_reference_impl": + pipeline = LectureChatPipeline(callback=callback) + case _: + raise ValueError(f"Unknown variant: {variant}") + except Exception as e: + logger.error(f"Error preparing lecture chat pipeline: {e}") + logger.error(traceback.format_exc()) + capture_exception(e) + return + + try: + pipeline(dto=dto) + except Exception as e: + logger.error(f"Error running lecture chat pipeline: {e}") + logger.error(traceback.format_exc()) + callback.error("Fatal error.", exception=e) + + @router.post( "/text-exercise-chat/{variant}/run", status_code=status.HTTP_202_ACCEPTED, dependencies=[Depends(TokenValidator())], ) def run_text_exercise_chat_pipeline( - variant: str, dto: TextExerciseChatPipelineExecutionDTO + variant: str, dto: TextExerciseChatPipelineExecutionDTO ): thread = Thread(target=run_text_exercise_chat_pipeline_worker, args=(dto, variant)) thread.start() + @router.post( "/lecture-chat/{variant}/run", status_code=status.HTTP_202_ACCEPTED, @@ -141,11 +169,12 @@ def run_text_exercise_chat_pipeline( def run_lecture_chat_pipeline( variant: str, dto: LectureChatPipelineExecutionDTO ): - thread = Thread(target=run_lecture_chat_pipeline, args=(dto, variant)) + thread = Thread(target=run_lecture_chat_pipeline_worker, args=(dto, variant)) thread.start() + def run_competency_extraction_pipeline_worker( - dto: CompetencyExtractionPipelineExecutionDTO, _variant: str + dto: CompetencyExtractionPipelineExecutionDTO, _variant: str ): try: callback = CompetencyExtractionCallback( @@ -174,7 +203,7 @@ def run_competency_extraction_pipeline_worker( dependencies=[Depends(TokenValidator())], ) def run_competency_extraction_pipeline( - variant: str, dto: CompetencyExtractionPipelineExecutionDTO + variant: str, dto: CompetencyExtractionPipelineExecutionDTO ): thread = Thread( target=run_competency_extraction_pipeline_worker, args=(dto, variant) diff --git a/app/web/status/status_update.py b/app/web/status/status_update.py index a122b6e7..8a873dd0 100644 --- a/app/web/status/status_update.py +++ b/app/web/status/status_update.py @@ -12,6 +12,9 @@ from app.domain.chat.course_chat.course_chat_status_update_dto import ( CourseChatStatusUpdateDTO, ) +from app.domain.status.lecture_chat_status_update_dto import ( + LectureChatStatusUpdateDTO, +) from app.domain.status.stage_state_dto import StageStateEnum from app.domain.status.stage_dto import StageDTO from app.domain.status.text_exercise_chat_status_update_dto import ( @@ -54,6 +57,7 @@ def __init__( def on_status_update(self): """Send a status update to the Artemis API.""" try: + print(self.url) print(self.status.dict(by_alias=True)) requests.post( self.url, @@ -277,3 +281,34 @@ def __init__( status = CompetencyExtractionStatusUpdateDTO(stages=stages) stage = stages[-1] super().__init__(url, run_id, status, stage, len(stages) - 1) + + +class LectureChatCallback(StatusCallback): + def __init__( + self, + run_id: str, + base_url: str, + initial_stages: List[StageDTO], + ): + url = f"{base_url}/api/public/pyris/pipelines/lecture-chat/runs/{run_id}/status" + stages = initial_stages or [] + stage = len(stages) + stages += [ + StageDTO( + weight=30, + state=StageStateEnum.NOT_STARTED, + name="Thinking", + ), + StageDTO( + weight=20, + state=StageStateEnum.NOT_STARTED, + name="Responding", + ), + ] + super().__init__( + url, + run_id, + LectureChatStatusUpdateDTO(stages=stages, result=""), + stages[stage], + stage, + )