Skip to content

Commit

Permalink
refactor: switch to LiteLLM (#164)
Browse files Browse the repository at this point in the history
* refactor: LLM calls to use LiteLLM
* fix: re-add and fix Otel callback_handler
  * It was throwing errors and then crashing the main process.
  * Also fixes the instrumentation logic errors
* bump version v0.6.1 -> v0.6.2
  • Loading branch information
janaka authored Nov 29, 2023
1 parent 0122615 commit cfa76cd
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 61 deletions.
25 changes: 24 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "docq"
version = "0.6.1"
version = "0.6.2"
description = "Docq.AI - private and secure knowledge insight on your data."
authors = ["Docq.AI Team <[email protected]>"]
maintainers = ["Docq.AI Team <[email protected]>"]
Expand Down Expand Up @@ -46,6 +46,7 @@ google-auth-oauthlib = "^1.1.0"
google-api-python-client = "^2.104.0"
google-auth-httplib2 = "^0.1.1"
llama-index = "^0.9.8.post1"
litellm = "^1.7.7"

[tool.poetry.group.dev.dependencies]
pre-commit = "^2.18.1"
Expand Down
4 changes: 3 additions & 1 deletion source/docq/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Optional, Self, Type

from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode

import docq

Expand Down Expand Up @@ -123,7 +124,8 @@ def _import_extensions(extensions_config_path: str = DEFAULT_EXTENSION_JSON_PATH
span.add_event("importlib.spec_from_file_location() for extension failed", {"module_name": module_name, "module_source": module_source, "class_name": class_name})
logging.error("importlib.spec_from_file_location() for extension failed. Skipping... could not find extension module '%s' at '%s'.", module_name, module_source)
except Exception as e:
span.set_status(status=trace.StatusCode.ERROR, description=str(e))
span.set_status(status=Status(StatusCode.ERROR))
span.record_exception(e)
logging.error("_import_extensions() failed hard!")
logging.error(e)
raise e
Expand Down
99 changes: 82 additions & 17 deletions source/docq/support/llamaindex_otel_callbackhandler.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
"""Llama Index callback handler for OpenTelemetry tracing."""
import inspect
import logging
import uuid
from typing import Any, Dict, List, Optional, Self
from typing import Any, Dict, List, Optional, Self, Tuple

import llama_index
from llama_index.callbacks.base_handler import BaseCallbackHandler
from llama_index.callbacks.schema import BASE_TRACE_EVENT, CBEventType
from llama_index.callbacks.schema import CBEventType, EventPayload
from opentelemetry import trace
from opentelemetry.trace import NonRecordingSpan

logger = logging.getLogger(__name__)

class OtelCallbackHandler(BaseCallbackHandler):
"""Base callback handler that can be used to track event starts and ends."""

_spans: dict[str, Any] = {} # track Otel spans so they can be ended on the end events.
def __init__(
self: Self,
tracer_provider: trace.TracerProvider,
Expand All @@ -22,11 +25,14 @@ def __init__(
end_ignore = event_ends_to_ignore or []

self.event_starts_to_ignore = tuple(start_ignore)

self.event_ends_to_ignore = tuple(end_ignore)
#module_name, function_name = get_caller_function_and_module()
self._tracer = tracer_provider.get_tracer(instrumenting_module_name="docq.llama_index_otel_callbackhandler", instrumenting_library_version=llama_index.__version__)

self._tracer = tracer_provider.get_tracer(instrumenting_module_name="llama_index", instrumenting_library_version=llama_index.__version__)
logging.debug("OtelCallbackHandler initialized")
super().__init__(
event_starts_to_ignore=start_ignore,
event_ends_to_ignore=end_ignore,
)

def on_event_start(
self:Self,
Expand All @@ -37,8 +43,17 @@ def on_event_start(
**kwargs: Any,
) -> str:
"""Run when an event starts and return id of event."""
trace.get_current_span().add_event(name=f"{event_type.value}_started", attributes=payload)
#logging.debug("Starting event %s, %s", event_type, event_id)
try:
#logging.debug("Starting event event_id: %s, event_type: %s, payload: %s", event_id, event_type, payload)
parent_span =self._spans[parent_id] if parent_id in self._spans else trace.get_current_span()
ctx = trace.set_span_in_context(NonRecordingSpan(parent_span.get_span_context()))
span = self._tracer.start_span(name=event_type.name, context=ctx, attributes=self._serialize_payload(payload))
span.add_event(name="callback_handler.on_event_start", attributes={"cbevent.event_id": event_id, "cbevent.parent_id": parent_id, "cbevent.event_type": event_type})
self._spans[event_id] = span

except Exception as e:
logger.error("tracer threw an error: %s", e)

return event_id

def on_event_end(
Expand All @@ -49,21 +64,71 @@ def on_event_end(
**kwargs: Any,
) -> None:
"""Run when an event ends."""
trace.get_current_span().add_event(name=f"{event_type.value}_ended", attributes=payload)

#logger.debug("Ending event - event_id: '%s', event_type: '%s', event_payload: '%s'", event_id, event_type, payload)
if event_id in self._spans:
span = self._spans.pop(event_id)
span.set_attributes(self._serialize_payload(payload))
span.add_event(name="callback_handler.on_event_end", attributes={"cbevent.event_id": event_id, "cbevent.event_type": event_type})
span.end()

def start_trace(self: Self, trace_id: Optional[str] = None) -> None:
"""Run when an overall trace is launched."""
trace_id = trace_id or str(uuid.uuid4())
with self._tracer.start_as_current_span(name=trace_id):
pass
logging.debug("Starting trace %s", trace_id)
if trace_id:
#logger.debug("Starting trace - trace_id: '%s'", trace_id)
current_span = trace.get_current_span()
ctx = trace.set_span_in_context(NonRecordingSpan(current_span.get_span_context()))
span = self._tracer.start_span(name=trace_id, context=ctx)
span.add_event(name="callback_handler.start_trace", attributes={"cbevent.trace_id": trace_id})
self._spans[trace_id] = span

def end_trace(
self: Self,
trace_id: Optional[str] = None,
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
"""Run when an overall trace is exited."""
trace.get_current_span().end()
logging.debug("Ending trace %s", trace_id)
#logger.debug("Ending trace - trace_id: '%s'", trace_id)
#logger.debug("Ending trace - trace_map: '%s'", trace_map)
if trace_id and trace_id in self._spans:
span = self._spans.pop(trace_id)
span.add_event(name="callback_handler.end_trace", attributes={"cbevent.trace_id": trace_id})
span.end()


@staticmethod
def _serialize_payload(payload: Dict[str, Any] | None) -> Dict[str, str]:
"""Serialize payload."""
_result: Dict[str, str] = {}
try:
if payload:
if EventPayload.SERIALIZED in payload:
_result = payload[EventPayload.SERIALIZED]
else:
_result = {k: str(v) for k, v in payload.items()}
except Exception as e:
_result = {EventPayload.EXCEPTION: "error message: " + str(e)}
logger.error("tracer threw an error: %s", e)

return _result


@staticmethod
def get_caller_function_and_module() -> Tuple[str, str]:
"""Get the caller function and module."""
function_name, module_name = "", ""
# Get the frame of the caller
frame = inspect.currentframe()
if frame is not None:
_frame = frame.f_back
if _frame is not None:
# Get the name of the function of the caller
frame_info = inspect.getframeinfo(frame)
function_name = frame_info.function

# Get the name of the module of the caller
module = inspect.getmodule(frame)
if module is not None:
module_name = module.__name__

return function_name, module_name

Loading

0 comments on commit cfa76cd

Please sign in to comment.