Skip to content

Commit

Permalink
Pipeline: Add lecture chat pipeline connection (#173)
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastianloose authored Nov 27, 2024
1 parent 3e18a0b commit 5c82d5b
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 8 deletions.
12 changes: 12 additions & 0 deletions app/domain/status/lecture_chat_status_update_dto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from app.domain.status.status_update_dto import StatusUpdateDTO


class LectureChatStatusUpdateDTO(StatusUpdateDTO):
"""Data Transfer Object for lecture chat status updates.
This DTO extends the base StatusUpdateDTO to include the result of lecture chat
pipeline operations, facilitating communication between Artemis and the lecture
chat system.
"""

result: str
"""The result message or status of the lecture chat pipeline operation."""
28 changes: 24 additions & 4 deletions app/pipeline/chat/lecture_chat_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ...llm.langchain import IrisLangchainChatModel

from ..pipeline import Pipeline
from ...web.status.status_update import LectureChatCallback

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -57,17 +58,28 @@ class LectureChatPipeline(Pipeline):
llm: IrisLangchainChatModel
pipeline: Runnable
prompt: ChatPromptTemplate
callback: LectureChatCallback

def __init__(self):
def __init__(
self,
callback: LectureChatCallback,
dto: LectureChatPipelineExecutionDTO,
variant: str,
):
super().__init__(implementation_id="lecture_chat_pipeline")
# Set the langchain chat model
request_handler = CapabilityRequestHandler(
requirements=RequirementList(
gpt_version_equivalent=3.5,
gpt_version_equivalent=4.5,
context_length=16385,
privacy_compliance=True,
)
)

self.callback = callback
self.dto = dto
self.variant = variant

completion_args = CompletionArguments(temperature=0, max_tokens=2000)
self.llm = IrisLangchainChatModel(
request_handler=request_handler, completion_args=completion_args
Expand All @@ -91,7 +103,6 @@ def __call__(self, dto: LectureChatPipelineExecutionDTO):
Runs the pipeline
:param dto: execution data transfer object
"""

self.prompt = ChatPromptTemplate.from_messages(
[
("system", lecture_initial_prompt()),
Expand Down Expand Up @@ -124,8 +135,17 @@ def __call__(self, dto: LectureChatPipelineExecutionDTO):
)
self.tokens.extend(self.citation_pipeline.tokens)
logger.info(f"Response from lecture chat pipeline: {response}")
return response_with_citation
self.callback.done(
"Response created",
final_result=response_with_citation,
tokens=self.tokens,
)
except Exception as e:
self.callback.error(
"Generating interaction suggestions failed.",
exception=e,
tokens=self.tokens,
)
raise e

def _add_conversation_to_prompt(
Expand Down
30 changes: 26 additions & 4 deletions app/vector_database/lecture_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,26 @@ def init_lecture_schema(client: WeaviateClient) -> Collection:
if client.collections.exists(LectureSchema.COLLECTION_NAME.value):
collection = client.collections.get(LectureSchema.COLLECTION_NAME.value)
properties = collection.config.get(simple=True).properties

# Check and add 'course_language' property if missing
if not any(
property.name == LectureSchema.COURSE_LANGUAGE.value
for property in collection.config.get(simple=False).properties
):
collection.config.add_property(
Property(
name=LectureSchema.COURSE_LANGUAGE.value,
description="The language of the COURSE",
data_type=DataType.TEXT,
index_searchable=False,
)
)

# Check and add 'lecture_unit_link' property if missing
if not any(
property.__name__ == LectureSchema.LECTURE_UNIT_LINK.value
for property_found in properties
property.name == LectureSchema.LECTURE_UNIT_LINK.value
for property in properties
):
return collection
else:
collection.config.add_property(
Property(
name=LectureSchema.LECTURE_UNIT_LINK.value,
Expand All @@ -47,7 +61,9 @@ def init_lecture_schema(client: WeaviateClient) -> Collection:
index_searchable=False,
)
)

return collection

return client.collections.create(
name=LectureSchema.COLLECTION_NAME.value,
vectorizer_config=Configure.Vectorizer.none(),
Expand All @@ -73,6 +89,12 @@ def init_lecture_schema(client: WeaviateClient) -> Collection:
data_type=DataType.TEXT,
index_searchable=False,
),
Property(
name=LectureSchema.COURSE_LANGUAGE.value,
description="The language of the COURSE",
data_type=DataType.TEXT,
index_searchable=False,
),
Property(
name=LectureSchema.LECTURE_ID.value,
description="The ID of the lecture",
Expand Down
51 changes: 51 additions & 0 deletions app/web/routers/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,15 @@
CompetencyExtractionPipelineExecutionDTO,
)
from app.pipeline.chat.exercise_chat_agent_pipeline import ExerciseChatAgentPipeline
from app.domain.chat.lecture_chat.lecture_chat_pipeline_execution_dto import (
LectureChatPipelineExecutionDTO,
)
from app.pipeline.chat.lecture_chat_pipeline import LectureChatPipeline
from app.web.status.status_update import (
ExerciseChatStatusCallback,
CourseChatStatusCallback,
CompetencyExtractionCallback,
LectureChatCallback,
)
from app.pipeline.chat.course_chat_pipeline import CourseChatPipeline
from app.dependencies import TokenValidator
Expand Down Expand Up @@ -139,6 +144,34 @@ def run_text_exercise_chat_pipeline_worker(dto, variant):
callback.error("Fatal error.", exception=e)


def run_lecture_chat_pipeline_worker(dto, variant):
try:
callback = LectureChatCallback(
run_id=dto.settings.authentication_token,
base_url=dto.settings.artemis_base_url,
initial_stages=dto.initial_stages,
)
match variant:
case "default" | "lecture_chat_pipeline_reference_impl":
pipeline = LectureChatPipeline(
callback=callback, dto=dto, variant=variant
)
case _:
raise ValueError(f"Unknown variant: {variant}")
except Exception as e:
logger.error(f"Error preparing lecture chat pipeline: {e}")
logger.error(traceback.format_exc())
capture_exception(e)
return

try:
pipeline(dto=dto)
except Exception as e:
logger.error(f"Error running lecture chat pipeline: {e}")
logger.error(traceback.format_exc())
callback.error("Fatal error.", exception=e)


@router.post(
"/text-exercise-chat/{variant}/run",
status_code=status.HTTP_202_ACCEPTED,
Expand All @@ -151,6 +184,16 @@ def run_text_exercise_chat_pipeline(
thread.start()


@router.post(
"/lecture-chat/{variant}/run",
status_code=status.HTTP_202_ACCEPTED,
dependencies=[Depends(TokenValidator())],
)
def run_lecture_chat_pipeline(variant: str, dto: LectureChatPipelineExecutionDTO):
thread = Thread(target=run_lecture_chat_pipeline_worker, args=(dto, variant))
thread.start()


def run_competency_extraction_pipeline_worker(
dto: CompetencyExtractionPipelineExecutionDTO, _variant: str
):
Expand Down Expand Up @@ -243,5 +286,13 @@ def get_pipeline(feature: str):
description="Default lecture ingestion variant.",
)
]
case "LECTURE_CHAT":
return [
FeatureDTO(
id="default",
name="Default Variant",
description="Default lecture chat variant.",
)
]
case _:
return Response(status_code=status.HTTP_400_BAD_REQUEST)
29 changes: 29 additions & 0 deletions app/web/status/status_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from app.domain.chat.course_chat.course_chat_status_update_dto import (
CourseChatStatusUpdateDTO,
)
from app.domain.status.lecture_chat_status_update_dto import (
LectureChatStatusUpdateDTO,
)
from app.domain.status.stage_state_dto import StageStateEnum
from app.domain.status.stage_dto import StageDTO
from app.domain.status.text_exercise_chat_status_update_dto import (
Expand Down Expand Up @@ -267,3 +270,29 @@ def __init__(
status = CompetencyExtractionStatusUpdateDTO(stages=stages)
stage = stages[-1]
super().__init__(url, run_id, status, stage, len(stages) - 1)


class LectureChatCallback(StatusCallback):
def __init__(
self,
run_id: str,
base_url: str,
initial_stages: List[StageDTO],
):
url = f"{base_url}/api/public/pyris/pipelines/lecture-chat/runs/{run_id}/status"
stages = initial_stages or []
stage = len(stages)
stages += [
StageDTO(
weight=30,
state=StageStateEnum.NOT_STARTED,
name="Thinking",
),
]
super().__init__(
url,
run_id,
LectureChatStatusUpdateDTO(stages=stages, result=""),
stages[stage],
stage,
)

0 comments on commit 5c82d5b

Please sign in to comment.