Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Community: Implement Mlflow tracer #17099

Closed
wants to merge 11 commits into from
166 changes: 166 additions & 0 deletions libs/community/langchain_community/callbacks/tracers/mlflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
"""A Tracer Implementation that records activity to MLflow."""
import logging
import os
import random
import string
import uuid
from datetime import datetime
from typing import Any, Dict, Optional

from langchain_core.env import get_runtime_environment
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run

from langchain_community.callbacks.utils import (
flatten_dict,
)

logger = logging.getLogger(__name__)

RUN_DETAILS_ORDER = [
"session_id",
"trace_id",
"id",
"parent_run_id",
"execution_order",
"child_runs",
"child_execution_order",
"name",
"run_type",
"start_time",
"end_time",
"inputs",
"outputs",
"serialized",
"serialized_object",
"events",
"extra",
"tags",
"dotted_order",
"error",
]


def import_mlflow() -> Any:
"""Import the mlflow python package and raise an error if it is not installed."""
try:
import mlflow
except ImportError:
raise ImportError(
"To use MLflowTracer you need to have the `mlflow` python "
"package installed. Please install it with `pip install mlflow -U`"
)
return mlflow


class MLflowTracer(BaseTracer):
"""Callback Handler that logs to MLflow.

This handler will log the model architecture and run traces to MLflow.
This will ensure that all LangChain activity is logged to MLflow.
"""

def __init__(
self,
tracking_uri: Optional[str] = None,
experiment_name: Optional[str] = None,
run_id: Optional[str] = None,
run_name: Optional[str] = None,
tags: Optional[Dict] = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)

self.mlflow = import_mlflow()
if "DATABRICKS_RUNTIME_VERSION" in os.environ:
self.mlflow.set_tracking_uri("databricks")
self.mlf_expid = self.mlflow.tracking.fluent._get_experiment_id()
else:
if tracking_uri:
self.mlflow.set_tracking_uri(tracking_uri)
if experiment_name:
if exp := self.mlflow.get_experiment_by_name(experiment_name):
self.mlf_expid = exp.experiment_id
else:
self.mlf_expid = self.mlflow.create_experiment(experiment_name)

self.tags = tags or {}
if run_id is None:
if run_name is None:
run_name = "langchain-tracer-" + "".join(
random.choices(string.ascii_uppercase + string.digits, k=7)
)
run = self.mlflow.MlflowClient().create_run(
self.mlf_expid, run_name=run_name, tags=tags
)
self.run_id = run.info.run_id
else:
self.run_id = run_id
self.session_id = kwargs.get("session_id", uuid.uuid4().hex)
self.run_table = kwargs.get("run_table_name", "langchain_runs.json")
self.run_dict: Dict[str, Any] = {}

def _convert_type(self, value: Any) -> Any:
"""Convert a value to a type that can be json-serialized."""
if isinstance(value, dict):
for k, v in value.items():
value[k] = self._convert_type(v)
elif isinstance(value, list):
value = [self._convert_type(v) for v in value]
elif isinstance(value, datetime):
value = value.isoformat()
elif not isinstance(value, (str, int, float, bool, type(None))):
value = str(value)
return value

def _order_dict_by_list(self, d: Dict, order: list) -> Dict:
"""Order a dictionary by a list."""
return {k: d[k] for k in order if k in d}

def _convert_run_to_dict(self, run: Run) -> Dict:
"""Convert a Run object to a dictionary."""
run_dict = run.dict(exclude={"child_runs"})
extra = run_dict.get("extra", {})
extra["runtime"] = get_runtime_environment()
run_dict["extra"] = extra
run_dict["session_id"] = self.session_id
run_dict["child_runs"] = [str(run.id) for run in run.child_runs]
if run.serialized:
run_dict["serialized_object"] = flatten_dict(run.serialized)
run_dict = self._convert_type(run_dict)
# order the run dict
run_dict = self._order_dict_by_list(run_dict, RUN_DETAILS_ORDER)
return run_dict

def _log_trace_from_run(self, run_dict: Dict[str, Any]) -> None:
"""Log the trace of a run dictionary into MLflow."""
self.mlflow.log_table(run_dict, self.run_table, self.run_id)
for child_run_id in run_dict.get("child_runs", []):
if child_run_id not in self.run_dict:
logger.debug(
f"Child run {child_run_id} not found in run_dict {self.run_dict}"
)
else:
self._log_trace_from_run(self.run_dict[child_run_id])

def _persist_run(self, run: Run) -> None:
"""Persist a run."""
# Only persist the runs without a parent
run_dict = self._convert_run_to_dict(run)
self.run_dict[str(run.id)] = run_dict
self._log_trace_from_run(run_dict)

def _on_run_update(self, run: Run) -> None:
"""Process a run upon update."""
# save the run when the run trace ends
if run.parent_run_id:
self.run_dict[str(run.id)] = self._convert_run_to_dict(run)

def _reset(self) -> None:
"""Reset the tracer."""
self.run_dict = {}

def end_run(self) -> None:
"""End the run."""
self._reset()
self.mlflow.MlflowClient().set_terminated(self.run_id)
Loading