diff --git a/app/pipeline/chat/tutor_chat_pipeline.py b/app/pipeline/chat/tutor_chat_pipeline.py index 7efc8409..acc6997d 100644 --- a/app/pipeline/chat/tutor_chat_pipeline.py +++ b/app/pipeline/chat/tutor_chat_pipeline.py @@ -12,6 +12,7 @@ PromptTemplate, ) from langchain_core.runnables import Runnable +from weaviate.collections.classes.filters import Filter from .lecture_chat_pipeline import LectureChatPipeline from .output_models.output_models.selected_paragraphs import SelectedParagraphs @@ -87,16 +88,25 @@ def __call__(self, dto: TutorChatPipelineExecutionDTO, **kwargs): :param kwargs: The keyword arguments """ try: - execution_dto = LectureChatPipelineExecutionDTO( - settings=dto.settings, course=dto.course, chatHistory=dto.chat_history - ) - lecture_chat_thread = threading.Thread( - target=self._run_lecture_chat_pipeline(execution_dto), args=(dto,) + should_execute_lecture_pipeline = self.should_execute_lecture_pipeline( + dto.course.id ) + self.lecture_chat_response = "" + if should_execute_lecture_pipeline: + execution_dto = LectureChatPipelineExecutionDTO( + settings=dto.settings, + course=dto.course, + chatHistory=dto.chat_history, + ) + lecture_chat_thread = threading.Thread( + target=self._run_lecture_chat_pipeline(execution_dto), args=(dto,) + ) + lecture_chat_thread.start() + tutor_chat_thread = threading.Thread( - target=self._run_tutor_chat_pipeline(dto), args=(dto,) + target=self._run_tutor_chat_pipeline(dto), + args=(dto, should_execute_lecture_pipeline), ) - lecture_chat_thread.start() tutor_chat_thread.start() response = self.choose_best_response( [self.tutor_chat_response, self.lecture_chat_response], @@ -150,7 +160,11 @@ def _run_lecture_chat_pipeline(self, dto: LectureChatPipelineExecutionDTO): pipeline = LectureChatPipeline() self.lecture_chat_response = pipeline(dto=dto) - def _run_tutor_chat_pipeline(self, dto: TutorChatPipelineExecutionDTO): + def _run_tutor_chat_pipeline( + self, + dto: TutorChatPipelineExecutionDTO, + should_execute_lecture_pipeline: bool = False, + ): """ Runs the pipeline :param dto: execution data transfer object @@ -206,18 +220,18 @@ def _run_tutor_chat_pipeline(self, dto: TutorChatPipelineExecutionDTO): submission, selected_files, ) - - retrieved_lecture_chunks = self.retriever( - chat_history=history, - student_query=query.contents[0].text_content, - result_limit=10, - course_name=dto.course.name, - problem_statement=problem_statement, - exercise_title=exercise_title, - course_id=dto.course.id, - base_url=dto.settings.artemis_base_url, - ) - self._add_relevant_chunks_to_prompt(retrieved_lecture_chunks) + if should_execute_lecture_pipeline: + retrieved_lecture_chunks = self.retriever( + chat_history=history, + student_query=query.contents[0].text_content, + result_limit=10, + course_name=dto.course.name, + problem_statement=problem_statement, + exercise_title=exercise_title, + course_id=dto.course.id, + base_url=dto.settings.artemis_base_url, + ) + self._add_relevant_chunks_to_prompt(retrieved_lecture_chunks) self.callback.in_progress("Generating response...") @@ -360,3 +374,21 @@ def _add_relevant_chunks_to_prompt(self, retrieved_lecture_chunks: List[dict]): self.prompt += SystemMessagePromptTemplate.from_template( "USE ONLY THE CONTENT YOU NEED TO ANSWER THE QUESTION:\n" ) + + def should_execute_lecture_pipeline(self, course_id: int) -> bool: + """ + Checks if the lecture pipeline should be executed + :param course_id: The course ID + :return: True if the lecture pipeline should be executed + """ + if course_id: + # Fetch the first object that matches the course ID with the language property + result = self.db.lectures.query.fetch_objects( + filters=Filter.by_property(LectureSchema.COURSE_ID.value).equal( + course_id + ), + limit=1, + return_properties=[LectureSchema.COURSE_NAME.value], + ) + return len(result.objects) > 0 + return False diff --git a/app/retrieval/lecture_retrieval.py b/app/retrieval/lecture_retrieval.py index c8509a4a..7ff483a0 100644 --- a/app/retrieval/lecture_retrieval.py +++ b/app/retrieval/lecture_retrieval.py @@ -111,15 +111,8 @@ def __call__( """ Retrieve lecture data from the database. """ - course_language = ( - self.collection.query.fetch_objects( - limit=1, return_properties=[LectureSchema.COURSE_LANGUAGE.value] - ) - .objects[0] - .properties.get(LectureSchema.COURSE_LANGUAGE.value) - ) + course_language = self.fetch_course_language(course_id) - # Call the function to run the tasks response, response_hyde = self.run_parallel_rewrite_tasks( chat_history=chat_history, student_query=student_query, @@ -427,3 +420,30 @@ def run_parallel_rewrite_tasks( response_hyde = response_hyde_future.result() return response, response_hyde + + def fetch_course_language(self, course_id): + """ + Fetch the language of the course based on the course ID. + If no specific language is set, it defaults to English. + """ + course_language = "english" + + if course_id: + # Fetch the first object that matches the course ID with the language property + result = self.collection.query.fetch_objects( + filters=Filter.by_property(LectureSchema.COURSE_ID.value).equal( + course_id + ), + limit=1, # We only need one object to check and retrieve the language + return_properties=[LectureSchema.COURSE_LANGUAGE.value], + ) + + # Check if the result has objects and retrieve the language + if result.objects: + fetched_language = result.objects[0].properties.get( + LectureSchema.COURSE_LANGUAGE.value + ) + if fetched_language: + course_language = fetched_language + + return course_language