Skip to content

Commit

Permalink
feat(audit): add audit events for SimilarityIndexes (#33)
Browse files Browse the repository at this point in the history
* feat(audit): add audit events for SimilarityIndexes

* Apply suggestions from code review

Co-authored-by: Michał Pstrąg <[email protected]>

---------

Co-authored-by: Michał Pstrąg <[email protected]>
  • Loading branch information
mhordynski and micpst authored May 15, 2024
1 parent c8cfe97 commit 59e08ac
Show file tree
Hide file tree
Showing 13 changed files with 126 additions and 33 deletions.
10 changes: 5 additions & 5 deletions src/dbally/audit/event_handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from abc import ABC
from typing import Generic, TypeVar, Union

from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart
from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart, SimilarityEvent

RequestCtx = TypeVar("RequestCtx")
EventCtx = TypeVar("EventCtx")
Expand All @@ -26,13 +26,13 @@ async def request_start(self, user_request: RequestStart) -> RequestCtx:
"""

@abc.abstractmethod
async def event_start(self, event: Union[LLMEvent], request_context: RequestCtx) -> EventCtx:
async def event_start(self, event: Union[LLMEvent, SimilarityEvent], request_context: RequestCtx) -> EventCtx:
"""
Function that is called during every event execution.
Args:
event: LLMEvent to be logged with all the details.
event: db-ally event to be logged with all the details.
request_context: Optional context passed from request_start method
Returns:
Expand All @@ -41,13 +41,13 @@ async def event_start(self, event: Union[LLMEvent], request_context: RequestCtx)

@abc.abstractmethod
async def event_end(
self, event: Union[None, LLMEvent], request_context: RequestCtx, event_context: EventCtx
self, event: Union[None, LLMEvent, SimilarityEvent], request_context: RequestCtx, event_context: EventCtx
) -> None:
"""
Function that is called during every event execution.
Args:
event: LLMEvent to be logged with all the details.
event: db-ally event to be logged with all the details.
request_context: Optional context passed from request_start method
event_context: Optional context passed from event_start method
"""
Expand Down
27 changes: 20 additions & 7 deletions src/dbally/audit/event_handlers/cli_event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
pprint = print # type: ignore

from dbally.audit.event_handlers.base import EventHandler
from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart
from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart, SimilarityEvent


class CLIEventHandler(EventHandler):
"""
This handler displays all interactions between LLM and user happending during `Collection.ask`\
This handler displays all interactions between LLM and user happening during `Collection.ask`\
execution inside the terminal.
### Usage
Expand Down Expand Up @@ -57,13 +57,13 @@ async def request_start(self, user_request: RequestStart) -> None:
pprint("[grey53]\n=======================================")
pprint("[grey53]=======================================\n")

async def event_start(self, event: Union[LLMEvent], request_context: None) -> None:
async def event_start(self, event: Union[LLMEvent, SimilarityEvent], request_context: None) -> None:
"""
Displays information that event has started, then all messages inside the prompt
Args:
event: LLMEvent to be logged with all the details.
event: db-ally event to be logged with all the details.
request_context: Optional context passed from request_start method
"""

Expand All @@ -76,13 +76,22 @@ async def event_start(self, event: Union[LLMEvent], request_context: None) -> No
self._print_syntax(msg["content"], "text")
else:
self._print_syntax(f"{event.prompt}", "text")

async def event_end(self, event: Union[None, LLMEvent], request_context: None, event_context: None) -> None:
elif isinstance(event, SimilarityEvent):
pprint(
f"[cyan bold]Similarity event starts... \n"
f"[cyan bold]INPUT: [grey53]{event.input_value}\n"
f"[cyan bold]STORE: [grey53]{event.store}\n"
f"[cyan bold]FETCHER: [grey53]{event.fetcher}\n"
)

async def event_end(
self, event: Union[None, LLMEvent, SimilarityEvent], request_context: None, event_context: None
) -> None:
"""
Displays the response from the LLM.
Args:
event: LLMEvent to be logged with all the details.
event: db-ally event to be logged with all the details.
request_context: Optional context passed from request_start method
event_context: Optional context passed from event_start method
"""
Expand All @@ -91,6 +100,10 @@ async def event_end(self, event: Union[None, LLMEvent], request_context: None, e
pprint(f"\n[green bold]RESPONSE: {event.response}")
pprint("[grey53]\n=======================================")
pprint("[grey53]=======================================\n")
elif isinstance(event, SimilarityEvent):
pprint(f"[green bold]OUTPUT: {event.output_value}")
pprint("[grey53]\n=======================================")
pprint("[grey53]=======================================\n")

async def request_end(self, output: RequestEnd, request_context: Optional[dict] = None) -> None:
"""
Expand Down
17 changes: 14 additions & 3 deletions src/dbally/audit/event_handlers/langsmith_event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from langsmith.run_trees import RunTree

from dbally.audit.event_handlers.base import EventHandler
from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart
from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart, SimilarityEvent


class LangSmithEventHandler(EventHandler[RunTree, RunTree]):
Expand Down Expand Up @@ -47,7 +47,7 @@ async def request_start(self, user_request: RequestStart) -> RunTree:

return run_tree

async def event_start(self, event: Union[None, LLMEvent], request_context: RunTree) -> RunTree:
async def event_start(self, event: Union[None, LLMEvent, SimilarityEvent], request_context: RunTree) -> RunTree:
"""
Log the start of the event.
Expand All @@ -67,12 +67,21 @@ async def event_start(self, event: Union[None, LLMEvent], request_context: RunTr
run_type="llm",
inputs={"prompts": [event.prompt]},
)
return child_run

if isinstance(event, SimilarityEvent):
child_run = request_context.create_child(
name="Similarity Lookup",
run_type="tool",
inputs={"input": event.input_value, "store": event.store, "fetcher": event.fetcher},
)
return child_run

raise ValueError("Unsupported event")

async def event_end(self, event: Union[None, LLMEvent], request_context: RunTree, event_context: RunTree) -> None:
async def event_end(
self, event: Union[None, LLMEvent, SimilarityEvent], request_context: RunTree, event_context: RunTree
) -> None:
"""
Log the end of the event.
Expand All @@ -83,6 +92,8 @@ async def event_end(self, event: Union[None, LLMEvent], request_context: RunTree
"""
if isinstance(event, LLMEvent):
event_context.end(outputs={"output": event.response})
elif isinstance(event, SimilarityEvent):
event_context.end(outputs={"output": event.output_value})

async def request_end(self, output: RequestEnd, request_context: RunTree) -> None:
"""
Expand Down
6 changes: 3 additions & 3 deletions src/dbally/audit/event_span.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Optional
from typing import Any, Optional, Union

from dbally.data_models.audit import LLMEvent
from dbally.data_models.audit import LLMEvent, SimilarityEvent


class EventSpan:
Expand All @@ -11,7 +11,7 @@ class EventSpan:
def __init__(self) -> None:
self.data = None

def __call__(self, data: LLMEvent) -> None:
def __call__(self, data: Union[LLMEvent, SimilarityEvent]) -> None:
"""
Call method for logging events.
Expand Down
6 changes: 3 additions & 3 deletions src/dbally/audit/event_tracker.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from contextlib import asynccontextmanager
from typing import AsyncIterator, Dict, List, Optional
from typing import AsyncIterator, Dict, List, Optional, Union

from dbally.audit.event_handlers.base import EventHandler
from dbally.audit.event_span import EventSpan
from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart
from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart, SimilarityEvent


class EventTracker:
Expand Down Expand Up @@ -69,7 +69,7 @@ def subscribe(self, event_handler: EventHandler) -> None:
self._handlers.append(event_handler)

@asynccontextmanager
async def track_event(self, event: LLMEvent) -> AsyncIterator[EventSpan]:
async def track_event(self, event: Union[LLMEvent, SimilarityEvent]) -> AsyncIterator[EventSpan]:
"""
Context manager for processing an event.
Expand Down
13 changes: 13 additions & 0 deletions src/dbally/data_models/audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,19 @@ class LLMEvent:
total_tokens: Optional[int] = None


@dataclass
class SimilarityEvent:
"""
SimilarityEvent is fired when a SimilarityIndex lookup is performed.
"""

store: str
fetcher: str

input_value: str
output_value: Optional[str] = None


@dataclass
class RequestStart:
"""
Expand Down
10 changes: 7 additions & 3 deletions src/dbally/iql/_processor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ast
from typing import TYPE_CHECKING, Any, List, Union
from typing import TYPE_CHECKING, Any, List, Optional, Union

from dbally.audit.event_tracker import EventTracker
from dbally.iql import syntax
from dbally.iql._exceptions import (
IQLArgumentParsingError,
Expand All @@ -20,9 +21,12 @@ class IQLProcessor:
Parses IQL string to tree structure.
"""

def __init__(self, source: str, allowed_functions: List["ExposedFunction"]):
def __init__(
self, source: str, allowed_functions: List["ExposedFunction"], event_tracker: Optional[EventTracker] = None
) -> None:
self.source = source
self.allowed_functions = {func.name: func for func in allowed_functions}
self._event_tracker = event_tracker or EventTracker()

async def process(self) -> syntax.Node:
"""
Expand Down Expand Up @@ -84,7 +88,7 @@ async def _parse_call(self, node: ast.Call) -> syntax.FunctionCall:
arg_value = self._parse_arg(arg)

if arg_def.similarity_index:
arg_value = await arg_def.similarity_index.similar(arg_value)
arg_value = await arg_def.similarity_index.similar(arg_value, event_tracker=self._event_tracker)

check_result = validate_arg_type(arg_def.type, arg_value)

Expand Down
11 changes: 7 additions & 4 deletions src/dbally/iql/_query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, List, Optional

from ..audit.event_tracker import EventTracker
from . import syntax
from ._processor import IQLProcessor

Expand All @@ -18,15 +19,17 @@ def __init__(self, root: syntax.Node):
self.root = root

@classmethod
async def parse(cls, source: str, allowed_functions: List["ExposedFunction"]) -> "IQLQuery":
async def parse(
cls, source: str, allowed_functions: List["ExposedFunction"], event_tracker: Optional[EventTracker] = None
) -> "IQLQuery":
"""
Parse IQL string to IQLQuery object.
Args:
source: IQL string that needs to be parsed
allowed_functions: list of IQL functions that are allowed for this query
event_tracker: EventTracker object to track events
Returns:
IQLQuery object
"""
return cls(await IQLProcessor(source, allowed_functions).process())
return cls(await IQLProcessor(source, allowed_functions, event_tracker=event_tracker).process())
11 changes: 10 additions & 1 deletion src/dbally/similarity/chroma_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class ChromadbStore(SimilarityStore):
def __init__(
self,
index_name: str,
chroma_client: chromadb.Client,
chroma_client: chromadb.ClientAPI,
embedding_function: Union[EmbeddingClient, chromadb.EmbeddingFunction],
max_distance: Optional[float] = None,
distance_method: Literal["l2", "ip", "cosine"] = "l2",
Expand Down Expand Up @@ -94,3 +94,12 @@ async def find_similar(self, text: str) -> Optional[str]:
retrieved = collection.query(query_texts=[text], n_results=1)

return self._return_best_match(retrieved)

def __repr__(self) -> str:
"""
Returns the string representation of the object.
Returns:
str: The string representation of the object.
"""
return f"{self.__class__.__name__}(index_name={self.index_name})"
9 changes: 9 additions & 0 deletions src/dbally/similarity/faiss_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,12 @@ async def find_similar(self, text: str) -> Optional[str]:
data = np.load(file)
return data[best_idx]
return None

def __repr__(self) -> str:
"""
Returns the string representation of the FaissStore.
Returns:
str: The string representation of the FaissStore.
"""
return f"{self.__class__.__name__}(index_dir={self.index_dir}, index_name={self.index_name})"
19 changes: 16 additions & 3 deletions src/dbally/similarity/index.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import abc
from typing import Optional

from dbally.audit.event_tracker import EventTracker
from dbally.data_models.audit import SimilarityEvent
from dbally.similarity.fetcher import SimilarityFetcher
from dbally.similarity.store import SimilarityStore

Expand All @@ -20,12 +23,13 @@ async def update(self) -> None:
"""

@abc.abstractmethod
async def similar(self, text: str) -> str:
async def similar(self, text: str, event_tracker: Optional[EventTracker] = None) -> str:
"""
Finds the most similar text or returns the original text if no similar text is found.
Args:
text: The text to find similar to.
event_tracker: The event tracker to use for auditing the similarity search.
Returns:
str: The most similar text or the original text if no similar text is found.
Expand Down Expand Up @@ -54,15 +58,24 @@ async def update(self) -> None:
data = await self.fetcher.fetch()
await self.store.store(data)

async def similar(self, text: str) -> str:
async def similar(self, text: str, event_tracker: Optional[EventTracker] = None) -> str:
"""
Finds the most similar text in the store or returns the original text if no similar text is found.
Args:
text: The text to find similar to.
event_tracker: The event tracker to use for auditing the similarity search.
Returns:
str: The most similar text or the original text if no similar text is found.
"""
found = await self.store.find_similar(text)

event_tracker = event_tracker or EventTracker()
event = SimilarityEvent(input_value=text, store=repr(self.store), fetcher=repr(self.fetcher))

async with event_tracker.track_event(event) as span:
found = await self.store.find_similar(text)
event.output_value = found
span(event)

return found if found else text
18 changes: 18 additions & 0 deletions src/dbally/similarity/sqlalchemy_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ async def fetch(self) -> List[str]:
result = conn.execute(self.get_query())
return [row[0] for row in result]

def __repr__(self) -> str:
"""
Returns a string representation of the fetcher.
Returns:
str: The string representation of the fetcher.
"""
return f"{self.__class__.__name__}()"


class SimpleSqlAlchemyFetcher(SqlAlchemyFetcher):
"""
Expand All @@ -58,6 +67,15 @@ def get_query(self) -> sqlalchemy.Select:
"""
return sqlalchemy.select(self.column).select_from(self.table).distinct()

def __repr__(self) -> str:
"""
Returns a string representation of the fetcher.
Returns:
str: The string representation of the fetcher.
"""
return f"{self.__class__.__name__}(table={self.table.name}, column={self.column.name})"


class AbstractSqlAlchemyStore(SimilarityStore, metaclass=abc.ABCMeta):
"""
Expand Down
Loading

0 comments on commit 59e08ac

Please sign in to comment.