Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Populate the node assistant's trace service with the existing callback #1091

Merged
merged 2 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion apps/pipelines/nodes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def _process(self, state: PipelineState) -> PipelineState:
_config: RunnableConfig | None = None
name: str = Field(title="Node Name", json_schema_extra={"ui:widget": "node_name"})

def process(self, node_id: str, incoming_edges: list, state: PipelineState, config) -> PipelineState:
def process(
self, node_id: str, incoming_edges: list, state: PipelineState, config: RunnableConfig
) -> PipelineState:
from apps.channels.datamodels import Attachment

self._config = config
Expand Down
7 changes: 5 additions & 2 deletions apps/pipelines/nodes/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,7 @@ def _process(self, input, state: PipelineState, node_id: str, **kwargs) -> Pipel
session: ExperimentSession | None = state.get("experiment_session")
runnable = self._get_assistant_runnable(assistant, session=session, node_id=node_id)
attachments = self._get_attachments(state)
chain_output: ChainOutput = runnable.invoke(input, config={}, attachments=attachments)
chain_output: ChainOutput = runnable.invoke(input, config=self._config, attachments=attachments)
output = chain_output.output

return PipelineState.from_node_output(
Expand All @@ -699,8 +699,11 @@ def _get_attachments(self, state) -> list:
return [att for att in state.get("shared_state", {}).get("attachments", []) if att.upload_to_assistant]

def _get_assistant_runnable(self, assistant: OpenAiAssistant, session: ExperimentSession, node_id: str):
trace_service = session.experiment.trace_service
trace_service.from_callback_manager(self._config.get("callbacks"))

history_manager = PipelineHistoryManager.for_assistant()
adapter = AssistantAdapter.for_pipeline(session=session, node=self)
adapter = AssistantAdapter.for_pipeline(session=session, node=self, trace_service=trace_service)
if assistant.tools_enabled:
return AgentAssistantChat(adapter=adapter, history_manager=history_manager)
else:
Expand Down
5 changes: 2 additions & 3 deletions apps/service_providers/llm_service/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,15 +180,14 @@ def for_experiment(experiment: Experiment, session: ExperimentSession, trace_ser
)

@staticmethod
def for_pipeline(session: ExperimentSession, node: "AssistantNode") -> Self:
def for_pipeline(session: ExperimentSession, node: "AssistantNode", trace_service=None) -> Self:
assistant = OpenAiAssistant.objects.get(id=node.assistant_id)
experiment = session.experiment
return AssistantAdapter(
session=session,
assistant=assistant,
citations_enabled=node.citations_enabled,
input_formatter=node.input_formatter,
trace_service=experiment.trace_service,
trace_service=trace_service,
save_message_metadata_only=True,
)

Expand Down
15 changes: 15 additions & 0 deletions apps/service_providers/tracing/service.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from langchain_core.callbacks.manager import CallbackManager
from langchain_core.tracers import LangChainTracer
from pydantic import BaseModel

Expand Down Expand Up @@ -29,6 +30,9 @@ def update_trace(self, metadata: dict):
def get_current_trace_info(self) -> TraceInfo | None:
return None

def from_callback_manager(self, callback_manager: CallbackManager):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: rename

Suggested change
def from_callback_manager(self, callback_manager: CallbackManager):
def initialize_from_callback_manager(self, callback_manager: CallbackManager):

pass


class LangFuseTraceService(TraceService):
"""
Expand All @@ -51,6 +55,17 @@ def get_callback(self, participant_id: str, session_id: str):
self._callback = CallbackHandler(user_id=participant_id, session_id=session_id, **self.config)
return self._callback

def from_callback_manager(self, callback_manager: CallbackManager):
"""
Populates the callback from the callback handler already configured in `callback_manager`. This allows the trace
service to reuse existing callbacks.
"""
from langfuse.callback import CallbackHandler

for handler in callback_manager.handlers:
if isinstance(handler, CallbackHandler):
self._callback = handler

def update_trace(self, metadata: dict):
if not metadata:
return
Expand Down
Loading