From 5560cc448cd051c2ff2ee85f6524ad51aa5d1431 Mon Sep 17 00:00:00 2001 From: Aliaksandr Kuzmik <98702584+alexkuzmik@users.noreply.github.com> Date: Tue, 23 Apr 2024 19:24:41 +0200 Subject: [PATCH] community[patch]: fix CometTracer bug (#20796) Hi! My name is Alex, I'm an SDK engineer from [Comet](https://www.comet.com/site/) This PR updates the `CometTracer` class. Fixed an issue when `CometTracer` failed while logging the data to Comet because this data is not JSON-encodable. The problem was in some of the `Run` attributes that could contain non-default types inside, now these attributes are taken not from the run instance, but from the `run.dict()` return value. --- .../langchain_community/callbacks/tracers/comet.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/libs/community/langchain_community/callbacks/tracers/comet.py b/libs/community/langchain_community/callbacks/tracers/comet.py index d3577e722e4cb..5cabdb4bd9e32 100644 --- a/libs/community/langchain_community/callbacks/tracers/comet.py +++ b/libs/community/langchain_community/callbacks/tracers/comet.py @@ -70,24 +70,26 @@ def _initialize_comet_modules(self) -> None: self._flush: Callable[[], None] = comet_llm_api.flush def _persist_run(self, run: "Run") -> None: + run_dict: Dict[str, Any] = run.dict() chain_ = self._chains_map[run.id] - chain_.set_outputs(outputs=run.outputs) + chain_.set_outputs(outputs=run_dict["outputs"]) self._chain_api.log_chain(chain_) def _process_start_trace(self, run: "Run") -> None: + run_dict: Dict[str, Any] = run.dict() if not run.parent_run_id: # This is the first run, which maps to a chain chain_: "Chain" = self._chain.Chain( - inputs=run.inputs, + inputs=run_dict["inputs"], metadata=None, experiment_info=self._experiment_info.get(), ) self._chains_map[run.id] = chain_ else: span: "Span" = self._span.Span( - inputs=run.inputs, + inputs=run_dict["inputs"], category=_get_run_type(run), - metadata=run.extra, + metadata=run_dict["extra"], name=run.name, ) span.__api__start__(self._chains_map[run.parent_run_id]) @@ -95,12 +97,13 @@ def _process_start_trace(self, run: "Run") -> None: self._span_map[run.id] = span def _process_end_trace(self, run: "Run") -> None: + run_dict: Dict[str, Any] = run.dict() if not run.parent_run_id: pass # Langchain will call _persist_run for us else: span = self._span_map[run.id] - span.set_outputs(outputs=run.outputs) + span.set_outputs(outputs=run_dict["outputs"]) span.__api__end__() def flush(self) -> None: