Add lecture pipeline connection
sebastianloose committed Nov 11, 2024
1 parent fa72a17 commit 6891e3c
Showing 7 changed files with 116 additions and 8 deletions.
5 changes: 5 additions & 0 deletions app/domain/status/lecture_chat_status_update_dto.py
@@ -0,0 +1,5 @@
from app.domain.status.status_update_dto import StatusUpdateDTO


class LectureChatStatusUpdateDTO(StatusUpdateDTO):
    result: str
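
A minimal usage sketch (not part of the commit), assuming the base StatusUpdateDTO is the same Pydantic-style model with a stages list that the other status DTOs in this repository use; the StageDTO field names are taken from the callback added in app/web/status/status_update.py below:

from app.domain.status.lecture_chat_status_update_dto import LectureChatStatusUpdateDTO
from app.domain.status.stage_dto import StageDTO
from app.domain.status.stage_state_dto import StageStateEnum

# Build a status update with one pending stage and an (initially empty) result text.
status = LectureChatStatusUpdateDTO(
    stages=[StageDTO(weight=30, state=StageStateEnum.NOT_STARTED, name="Thinking")],
    result="",
)

# StatusCallback serializes updates the same way before POSTing them to Artemis.
print(status.dict(by_alias=True))
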
2 changes: 2 additions & 0 deletions app/llm/request_handler/basic_request_handler.py
@@ -14,6 +14,7 @@ class BasicRequestHandler(RequestHandler):
    def __init__(self, model_id: str):
        self.model_id = model_id
        self.llm_manager = LlmManager()
        print("llm manager" + str(self.llm_manager.entries))

    def complete(
        self,
@@ -32,4 +33,5 @@ def chat(

    def embed(self, text: str) -> list[float]:
        llm = self.llm_manager.get_llm_by_id(self.model_id)
        print("llm uiuiu" + str(llm))
        return llm.embed(text)
19 changes: 17 additions & 2 deletions app/pipeline/chat/lecture_chat_pipeline.py
@@ -25,6 +25,7 @@
from ...llm.langchain import IrisLangchainChatModel

from ..pipeline import Pipeline
from ...web.status.status_update import LectureChatCallback

logger = logging.getLogger(__name__)

@@ -55,8 +56,9 @@ class LectureChatPipeline(Pipeline):
    llm: IrisLangchainChatModel
    pipeline: Runnable
    prompt: ChatPromptTemplate
    callback: LectureChatCallback

    def __init__(self):
    def __init__(self, callback: LectureChatCallback):
        super().__init__(implementation_id="lecture_chat_pipeline")
        # Set the langchain chat model
        request_handler = CapabilityRequestHandler(
@@ -66,6 +68,9 @@ def __init__(self):
                privacy_compliance=True,
            )
        )

        self.callback = callback

        completion_args = CompletionArguments(temperature=0, max_tokens=2000)
        self.llm = IrisLangchainChatModel(
            request_handler=request_handler, completion_args=completion_args
@@ -89,7 +94,6 @@ def __call__(self, dto: LectureChatPipelineExecutionDTO):
        Runs the pipeline
        :param dto: execution data transfer object
        """

        self.prompt = ChatPromptTemplate.from_messages(
            [
                ("system", lecture_initial_prompt()),
@@ -115,7 +119,13 @@ def __call__(self, dto: LectureChatPipelineExecutionDTO):
        prompt_val = self.prompt.format_messages()
        self.prompt = ChatPromptTemplate.from_messages(prompt_val)
        try:
            self.callback.in_progress()
            response = (self.prompt | self.pipeline).invoke({})
            self.callback.done(
                "Generated response",
                final_result=response,
                tokens=self.tokens,
            )
            self._append_tokens(self.llm.tokens, PipelineEnum.IRIS_CHAT_LECTURE_MESSAGE)
            response_with_citation = self.citation_pipeline(
                retrieved_lecture_chunks, response
@@ -124,6 +134,11 @@ def __call__(self, dto: LectureChatPipelineExecutionDTO):
logger.info(f"Response from lecture chat pipeline: {response}")
return response_with_citation
except Exception as e:
            self.callback.error(
                "Generating the lecture chat response failed.",
                exception=e,
                tokens=self.tokens,
            )
            raise e

    def _add_conversation_to_prompt(
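
Taken together with the router change below, the pipeline is now constructed with a status callback and reports progress around the LLM call. A condensed sketch of the new control flow (not part of the commit; the argument values mirror the worker added in app/web/routers/pipelines.py):

from app.pipeline.chat.lecture_chat_pipeline import LectureChatPipeline
from app.web.status.status_update import LectureChatCallback


def answer_lecture_question(dto):
    # dto is a LectureChatPipelineExecutionDTO, as received by the new endpoint.
    callback = LectureChatCallback(
        run_id=dto.settings.authentication_token,
        base_url=dto.settings.artemis_base_url,
        initial_stages=dto.initial_stages,
    )
    pipeline = LectureChatPipeline(callback=callback)
    # __call__ now brackets the LLM invocation with in_progress()/done() and
    # reports error() together with the token usage if anything throws.
    return pipeline(dto=dto)
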
4 changes: 4 additions & 0 deletions app/retrieval/lecture_retrieval.py
@@ -512,6 +512,10 @@ def fetch_course_language(self, course_id):
"""
course_language = "english"

print(self.collection)
print(self.llm_embedding)
print(self.collection.query)

if course_id:
# Fetch the first object that matches the course ID with the language property
result = self.collection.query.fetch_objects(
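
The hunk is cut off right after the fetch_objects call, so the new COURSE_LANGUAGE property is presumably read from the returned object. A stand-alone sketch of how such a lookup could look with the Weaviate v4 client (an assumption, not part of the commit; LectureSchema.COURSE_ID and the filter/limit arguments are not visible in the diff):

from weaviate.classes.query import Filter

from app.vector_database.lecture_schema import LectureSchema


def fetch_course_language_sketch(collection, course_id, default="english"):
    """Hypothetical stand-alone variant of the method above."""
    if not course_id:
        return default
    # Fetch one chunk of the course and read its language property.
    result = collection.query.fetch_objects(
        filters=Filter.by_property(LectureSchema.COURSE_ID.value).equal(course_id),
        limit=1,
        return_properties=[LectureSchema.COURSE_LANGUAGE.value],
    )
    if result.objects:
        return result.objects[0].properties.get(
            LectureSchema.COURSE_LANGUAGE.value, default
        )
    return default
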
20 changes: 19 additions & 1 deletion app/vector_database/lecture_schema.py
@@ -30,7 +30,19 @@ def init_lecture_schema(client: WeaviateClient) -> Collection:
    Initialize the schema for the lecture slides
    """
    if client.collections.exists(LectureSchema.COLLECTION_NAME.value):
        return client.collections.get(LectureSchema.COLLECTION_NAME.value)
        collection = client.collections.get(LectureSchema.COLLECTION_NAME.value)

        # collection.config.add_property(
        #     Property(
        #         name=LectureSchema.COURSE_LANGUAGE.value,
        #         description="The language of the COURSE",
        #         data_type=DataType.TEXT,
        #         index_searchable=False,
        #     )
        # )

        return collection

    return client.collections.create(
        name=LectureSchema.COLLECTION_NAME.value,
        vectorizer_config=Configure.Vectorizer.none(),
@@ -56,6 +68,12 @@ def init_lecture_schema(client: WeaviateClient) -> Collection:
                data_type=DataType.TEXT,
                index_searchable=False,
            ),
            Property(
                name=LectureSchema.COURSE_LANGUAGE.value,
                description="The language of the COURSE",
                data_type=DataType.TEXT,
                index_searchable=False,
            ),
            Property(
                name=LectureSchema.LECTURE_ID.value,
                description="The ID of the lecture",
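
Note that the migration for an existing collection is only sketched in the commented-out block above: newly created collections get the COURSE_LANGUAGE property, while an already-created collection keeps its old schema. If that migration is enabled later, a guarded drop-in for the commented block could look like this (a sketch, not part of the commit; it assumes the v4 client's collection.config.get()/add_property API and reuses the names already imported in this module):

        # Add COURSE_LANGUAGE to an existing collection only if it is missing,
        # so that re-running init_lecture_schema stays idempotent.
        existing_properties = {p.name for p in collection.config.get().properties}
        if LectureSchema.COURSE_LANGUAGE.value not in existing_properties:
            collection.config.add_property(
                Property(
                    name=LectureSchema.COURSE_LANGUAGE.value,
                    description="The language of the COURSE",
                    data_type=DataType.TEXT,
                    index_searchable=False,
                )
            )
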
39 changes: 34 additions & 5 deletions app/web/routers/pipelines.py
@@ -12,10 +12,11 @@
    CompetencyExtractionPipelineExecutionDTO,
)
from app.domain.chat.lecture_chat.lecture_chat_pipeline_execution_dto import LectureChatPipelineExecutionDTO
from app.pipeline.chat.lecture_chat_pipeline import LectureChatPipeline
from app.web.status.status_update import (
    ExerciseChatStatusCallback,
    CourseChatStatusCallback,
    CompetencyExtractionCallback,
    CompetencyExtractionCallback, LectureChatCallback,
)
from app.pipeline.chat.course_chat_pipeline import CourseChatPipeline
from app.pipeline.chat.exercise_chat_pipeline import ExerciseChatPipeline
@@ -122,17 +123,44 @@ def run_text_exercise_chat_pipeline_worker(dto, variant):
callback.error("Fatal error.", exception=e)


def run_lecture_chat_pipeline_worker(dto, variant):
    try:
        callback = LectureChatCallback(
            run_id=dto.settings.authentication_token,
            base_url=dto.settings.artemis_base_url,
            initial_stages=dto.initial_stages,
        )
        match variant:
            case "default" | "lecture_chat_pipeline_reference_impl":
                pipeline = LectureChatPipeline(callback=callback)
            case _:
                raise ValueError(f"Unknown variant: {variant}")
    except Exception as e:
        logger.error(f"Error preparing lecture chat pipeline: {e}")
        logger.error(traceback.format_exc())
        capture_exception(e)
        return

    try:
        pipeline(dto=dto)
    except Exception as e:
        logger.error(f"Error running lecture chat pipeline: {e}")
        logger.error(traceback.format_exc())
        callback.error("Fatal error.", exception=e)


@router.post(
    "/text-exercise-chat/{variant}/run",
    status_code=status.HTTP_202_ACCEPTED,
    dependencies=[Depends(TokenValidator())],
)
def run_text_exercise_chat_pipeline(
    variant: str, dto: TextExerciseChatPipelineExecutionDTO
):
    thread = Thread(target=run_text_exercise_chat_pipeline_worker, args=(dto, variant))
    thread.start()


@router.post(
    "/lecture-chat/{variant}/run",
    status_code=status.HTTP_202_ACCEPTED,
@@ -141,11 +169,12 @@ def run_text_exercise_chat_pipeline(
def run_lecture_chat_pipeline(
    variant: str, dto: LectureChatPipelineExecutionDTO
):
    thread = Thread(target=run_lecture_chat_pipeline, args=(dto, variant))
    thread = Thread(target=run_lecture_chat_pipeline_worker, args=(dto, variant))
    thread.start()


def run_competency_extraction_pipeline_worker(
    dto: CompetencyExtractionPipelineExecutionDTO, _variant: str
):
    try:
        callback = CompetencyExtractionCallback(
@@ -174,7 +203,7 @@ def run_competency_extraction_pipeline_worker(
    dependencies=[Depends(TokenValidator())],
)
def run_competency_extraction_pipeline(
    variant: str, dto: CompetencyExtractionPipelineExecutionDTO
):
    thread = Thread(
        target=run_competency_extraction_pipeline_worker, args=(dto, variant)
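
For reference, the new worker only relies on a few fields of the execution DTO. A sketch of the minimal shape it needs (inferred from the attribute accesses above; the field names of the settings object come from the worker, everything else about the real LectureChatPipelineExecutionDTO is an assumption):

from typing import List

from pydantic import BaseModel

from app.domain.status.stage_dto import StageDTO


class SettingsSketch(BaseModel):
    authentication_token: str  # used as the status run id
    artemis_base_url: str      # prefix of the status callback URL


class LectureChatExecutionDTOSketch(BaseModel):
    settings: SettingsSketch
    initial_stages: List[StageDTO] = []
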
35 changes: 35 additions & 0 deletions app/web/status/status_update.py
@@ -12,6 +12,9 @@
from app.domain.chat.course_chat.course_chat_status_update_dto import (
    CourseChatStatusUpdateDTO,
)
from app.domain.status.lecture_chat_status_update_dto import (
    LectureChatStatusUpdateDTO,
)
from app.domain.status.stage_state_dto import StageStateEnum
from app.domain.status.stage_dto import StageDTO
from app.domain.status.text_exercise_chat_status_update_dto import (
@@ -54,6 +57,7 @@ def __init__(
    def on_status_update(self):
        """Send a status update to the Artemis API."""
        try:
            print(self.url)
            print(self.status.dict(by_alias=True))
            requests.post(
                self.url,
@@ -277,3 +281,34 @@ def __init__(
        status = CompetencyExtractionStatusUpdateDTO(stages=stages)
        stage = stages[-1]
        super().__init__(url, run_id, status, stage, len(stages) - 1)


class LectureChatCallback(StatusCallback):
    def __init__(
        self,
        run_id: str,
        base_url: str,
        initial_stages: List[StageDTO],
    ):
        url = f"{base_url}/api/public/pyris/pipelines/lecture-chat/runs/{run_id}/status"
        stages = initial_stages or []
        stage = len(stages)
        stages += [
            StageDTO(
                weight=30,
                state=StageStateEnum.NOT_STARTED,
                name="Thinking",
            ),
            StageDTO(
                weight=20,
                state=StageStateEnum.NOT_STARTED,
                name="Responding",
            ),
        ]
        super().__init__(
            url,
            run_id,
            LectureChatStatusUpdateDTO(stages=stages, result=""),
            stages[stage],
            stage,
        )
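
To summarize the new callback: it appends a "Thinking" and a "Responding" stage (both NOT_STARTED) after any stages the caller already completed, and points at "Thinking" as the current stage. A lifecycle sketch of how the lecture chat pipeline above drives it (not part of the commit; the run id and base URL are placeholders):

from app.web.status.status_update import LectureChatCallback

callback = LectureChatCallback(
    run_id="run-123",                        # placeholder authentication token
    base_url="https://artemis.example.org",  # placeholder Artemis base URL
    initial_stages=[],
)

callback.in_progress()  # presumably marks "Thinking" as in progress and POSTs an update
callback.done(
    "Generated response",
    final_result="the generated answer",  # presumably ends up in LectureChatStatusUpdateDTO.result
    tokens=[],                            # token usage list forwarded by the pipeline
)
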
