From 59e08ac7e72fa4f946c7229872fea70d277ef411 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Hordy=C5=84ski?= <26008518+mhordynski@users.noreply.github.com> Date: Wed, 15 May 2024 16:44:35 +0200 Subject: [PATCH] feat(audit): add audit events for SimilarityIndexes (#33) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(audit): add audit events for SimilarityIndexes * Apply suggestions from code review Co-authored-by: Michał Pstrąg --------- Co-authored-by: Michał Pstrąg --- src/dbally/audit/event_handlers/base.py | 10 +++---- .../audit/event_handlers/cli_event_handler.py | 27 ++++++++++++++----- .../event_handlers/langsmith_event_handler.py | 17 +++++++++--- src/dbally/audit/event_span.py | 6 ++--- src/dbally/audit/event_tracker.py | 6 ++--- src/dbally/data_models/audit.py | 13 +++++++++ src/dbally/iql/_processor.py | 10 ++++--- src/dbally/iql/_query.py | 11 +++++--- src/dbally/similarity/chroma_store.py | 11 +++++++- src/dbally/similarity/faiss_store.py | 9 +++++++ src/dbally/similarity/index.py | 19 ++++++++++--- src/dbally/similarity/sqlalchemy_base.py | 18 +++++++++++++ src/dbally/views/structured.py | 2 +- 13 files changed, 126 insertions(+), 33 deletions(-) diff --git a/src/dbally/audit/event_handlers/base.py b/src/dbally/audit/event_handlers/base.py index 05d16cf2..10fce0cf 100644 --- a/src/dbally/audit/event_handlers/base.py +++ b/src/dbally/audit/event_handlers/base.py @@ -2,7 +2,7 @@ from abc import ABC from typing import Generic, TypeVar, Union -from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart +from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart, SimilarityEvent RequestCtx = TypeVar("RequestCtx") EventCtx = TypeVar("EventCtx") @@ -26,13 +26,13 @@ async def request_start(self, user_request: RequestStart) -> RequestCtx: """ @abc.abstractmethod - async def event_start(self, event: Union[LLMEvent], request_context: RequestCtx) -> EventCtx: + async def event_start(self, event: Union[LLMEvent, SimilarityEvent], request_context: RequestCtx) -> EventCtx: """ Function that is called during every event execution. Args: - event: LLMEvent to be logged with all the details. + event: db-ally event to be logged with all the details. request_context: Optional context passed from request_start method Returns: @@ -41,13 +41,13 @@ async def event_start(self, event: Union[LLMEvent], request_context: RequestCtx) @abc.abstractmethod async def event_end( - self, event: Union[None, LLMEvent], request_context: RequestCtx, event_context: EventCtx + self, event: Union[None, LLMEvent, SimilarityEvent], request_context: RequestCtx, event_context: EventCtx ) -> None: """ Function that is called during every event execution. Args: - event: LLMEvent to be logged with all the details. + event: db-ally event to be logged with all the details. request_context: Optional context passed from request_start method event_context: Optional context passed from event_start method """ diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py index fba1d3be..15583f1d 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -12,12 +12,12 @@ pprint = print # type: ignore from dbally.audit.event_handlers.base import EventHandler -from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart +from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart, SimilarityEvent class CLIEventHandler(EventHandler): """ - This handler displays all interactions between LLM and user happending during `Collection.ask`\ + This handler displays all interactions between LLM and user happening during `Collection.ask`\ execution inside the terminal. ### Usage @@ -57,13 +57,13 @@ async def request_start(self, user_request: RequestStart) -> None: pprint("[grey53]\n=======================================") pprint("[grey53]=======================================\n") - async def event_start(self, event: Union[LLMEvent], request_context: None) -> None: + async def event_start(self, event: Union[LLMEvent, SimilarityEvent], request_context: None) -> None: """ Displays information that event has started, then all messages inside the prompt Args: - event: LLMEvent to be logged with all the details. + event: db-ally event to be logged with all the details. request_context: Optional context passed from request_start method """ @@ -76,13 +76,22 @@ async def event_start(self, event: Union[LLMEvent], request_context: None) -> No self._print_syntax(msg["content"], "text") else: self._print_syntax(f"{event.prompt}", "text") - - async def event_end(self, event: Union[None, LLMEvent], request_context: None, event_context: None) -> None: + elif isinstance(event, SimilarityEvent): + pprint( + f"[cyan bold]Similarity event starts... \n" + f"[cyan bold]INPUT: [grey53]{event.input_value}\n" + f"[cyan bold]STORE: [grey53]{event.store}\n" + f"[cyan bold]FETCHER: [grey53]{event.fetcher}\n" + ) + + async def event_end( + self, event: Union[None, LLMEvent, SimilarityEvent], request_context: None, event_context: None + ) -> None: """ Displays the response from the LLM. Args: - event: LLMEvent to be logged with all the details. + event: db-ally event to be logged with all the details. request_context: Optional context passed from request_start method event_context: Optional context passed from event_start method """ @@ -91,6 +100,10 @@ async def event_end(self, event: Union[None, LLMEvent], request_context: None, e pprint(f"\n[green bold]RESPONSE: {event.response}") pprint("[grey53]\n=======================================") pprint("[grey53]=======================================\n") + elif isinstance(event, SimilarityEvent): + pprint(f"[green bold]OUTPUT: {event.output_value}") + pprint("[grey53]\n=======================================") + pprint("[grey53]=======================================\n") async def request_end(self, output: RequestEnd, request_context: Optional[dict] = None) -> None: """ diff --git a/src/dbally/audit/event_handlers/langsmith_event_handler.py b/src/dbally/audit/event_handlers/langsmith_event_handler.py index 498ea09b..5974a068 100644 --- a/src/dbally/audit/event_handlers/langsmith_event_handler.py +++ b/src/dbally/audit/event_handlers/langsmith_event_handler.py @@ -6,7 +6,7 @@ from langsmith.run_trees import RunTree from dbally.audit.event_handlers.base import EventHandler -from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart +from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart, SimilarityEvent class LangSmithEventHandler(EventHandler[RunTree, RunTree]): @@ -47,7 +47,7 @@ async def request_start(self, user_request: RequestStart) -> RunTree: return run_tree - async def event_start(self, event: Union[None, LLMEvent], request_context: RunTree) -> RunTree: + async def event_start(self, event: Union[None, LLMEvent, SimilarityEvent], request_context: RunTree) -> RunTree: """ Log the start of the event. @@ -67,12 +67,21 @@ async def event_start(self, event: Union[None, LLMEvent], request_context: RunTr run_type="llm", inputs={"prompts": [event.prompt]}, ) + return child_run + if isinstance(event, SimilarityEvent): + child_run = request_context.create_child( + name="Similarity Lookup", + run_type="tool", + inputs={"input": event.input_value, "store": event.store, "fetcher": event.fetcher}, + ) return child_run raise ValueError("Unsupported event") - async def event_end(self, event: Union[None, LLMEvent], request_context: RunTree, event_context: RunTree) -> None: + async def event_end( + self, event: Union[None, LLMEvent, SimilarityEvent], request_context: RunTree, event_context: RunTree + ) -> None: """ Log the end of the event. @@ -83,6 +92,8 @@ async def event_end(self, event: Union[None, LLMEvent], request_context: RunTree """ if isinstance(event, LLMEvent): event_context.end(outputs={"output": event.response}) + elif isinstance(event, SimilarityEvent): + event_context.end(outputs={"output": event.output_value}) async def request_end(self, output: RequestEnd, request_context: RunTree) -> None: """ diff --git a/src/dbally/audit/event_span.py b/src/dbally/audit/event_span.py index e4a1b728..c7cba584 100644 --- a/src/dbally/audit/event_span.py +++ b/src/dbally/audit/event_span.py @@ -1,6 +1,6 @@ -from typing import Any, Optional +from typing import Any, Optional, Union -from dbally.data_models.audit import LLMEvent +from dbally.data_models.audit import LLMEvent, SimilarityEvent class EventSpan: @@ -11,7 +11,7 @@ class EventSpan: def __init__(self) -> None: self.data = None - def __call__(self, data: LLMEvent) -> None: + def __call__(self, data: Union[LLMEvent, SimilarityEvent]) -> None: """ Call method for logging events. diff --git a/src/dbally/audit/event_tracker.py b/src/dbally/audit/event_tracker.py index d40a08b1..c483a65e 100644 --- a/src/dbally/audit/event_tracker.py +++ b/src/dbally/audit/event_tracker.py @@ -1,9 +1,9 @@ from contextlib import asynccontextmanager -from typing import AsyncIterator, Dict, List, Optional +from typing import AsyncIterator, Dict, List, Optional, Union from dbally.audit.event_handlers.base import EventHandler from dbally.audit.event_span import EventSpan -from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart +from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart, SimilarityEvent class EventTracker: @@ -69,7 +69,7 @@ def subscribe(self, event_handler: EventHandler) -> None: self._handlers.append(event_handler) @asynccontextmanager - async def track_event(self, event: LLMEvent) -> AsyncIterator[EventSpan]: + async def track_event(self, event: Union[LLMEvent, SimilarityEvent]) -> AsyncIterator[EventSpan]: """ Context manager for processing an event. diff --git a/src/dbally/data_models/audit.py b/src/dbally/data_models/audit.py index 4dc466ad..3315360f 100644 --- a/src/dbally/data_models/audit.py +++ b/src/dbally/data_models/audit.py @@ -29,6 +29,19 @@ class LLMEvent: total_tokens: Optional[int] = None +@dataclass +class SimilarityEvent: + """ + SimilarityEvent is fired when a SimilarityIndex lookup is performed. + """ + + store: str + fetcher: str + + input_value: str + output_value: Optional[str] = None + + @dataclass class RequestStart: """ diff --git a/src/dbally/iql/_processor.py b/src/dbally/iql/_processor.py index a11e45f1..8127ddfe 100644 --- a/src/dbally/iql/_processor.py +++ b/src/dbally/iql/_processor.py @@ -1,6 +1,7 @@ import ast -from typing import TYPE_CHECKING, Any, List, Union +from typing import TYPE_CHECKING, Any, List, Optional, Union +from dbally.audit.event_tracker import EventTracker from dbally.iql import syntax from dbally.iql._exceptions import ( IQLArgumentParsingError, @@ -20,9 +21,12 @@ class IQLProcessor: Parses IQL string to tree structure. """ - def __init__(self, source: str, allowed_functions: List["ExposedFunction"]): + def __init__( + self, source: str, allowed_functions: List["ExposedFunction"], event_tracker: Optional[EventTracker] = None + ) -> None: self.source = source self.allowed_functions = {func.name: func for func in allowed_functions} + self._event_tracker = event_tracker or EventTracker() async def process(self) -> syntax.Node: """ @@ -84,7 +88,7 @@ async def _parse_call(self, node: ast.Call) -> syntax.FunctionCall: arg_value = self._parse_arg(arg) if arg_def.similarity_index: - arg_value = await arg_def.similarity_index.similar(arg_value) + arg_value = await arg_def.similarity_index.similar(arg_value, event_tracker=self._event_tracker) check_result = validate_arg_type(arg_def.type, arg_value) diff --git a/src/dbally/iql/_query.py b/src/dbally/iql/_query.py index ad252d51..7ad86490 100644 --- a/src/dbally/iql/_query.py +++ b/src/dbally/iql/_query.py @@ -1,5 +1,6 @@ -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional +from ..audit.event_tracker import EventTracker from . import syntax from ._processor import IQLProcessor @@ -18,15 +19,17 @@ def __init__(self, root: syntax.Node): self.root = root @classmethod - async def parse(cls, source: str, allowed_functions: List["ExposedFunction"]) -> "IQLQuery": + async def parse( + cls, source: str, allowed_functions: List["ExposedFunction"], event_tracker: Optional[EventTracker] = None + ) -> "IQLQuery": """ Parse IQL string to IQLQuery object. Args: source: IQL string that needs to be parsed allowed_functions: list of IQL functions that are allowed for this query - + event_tracker: EventTracker object to track events Returns: IQLQuery object """ - return cls(await IQLProcessor(source, allowed_functions).process()) + return cls(await IQLProcessor(source, allowed_functions, event_tracker=event_tracker).process()) diff --git a/src/dbally/similarity/chroma_store.py b/src/dbally/similarity/chroma_store.py index f1684a44..53ef657a 100644 --- a/src/dbally/similarity/chroma_store.py +++ b/src/dbally/similarity/chroma_store.py @@ -13,7 +13,7 @@ class ChromadbStore(SimilarityStore): def __init__( self, index_name: str, - chroma_client: chromadb.Client, + chroma_client: chromadb.ClientAPI, embedding_function: Union[EmbeddingClient, chromadb.EmbeddingFunction], max_distance: Optional[float] = None, distance_method: Literal["l2", "ip", "cosine"] = "l2", @@ -94,3 +94,12 @@ async def find_similar(self, text: str) -> Optional[str]: retrieved = collection.query(query_texts=[text], n_results=1) return self._return_best_match(retrieved) + + def __repr__(self) -> str: + """ + Returns the string representation of the object. + + Returns: + str: The string representation of the object. + """ + return f"{self.__class__.__name__}(index_name={self.index_name})" diff --git a/src/dbally/similarity/faiss_store.py b/src/dbally/similarity/faiss_store.py index 7839861f..7f43e34b 100644 --- a/src/dbally/similarity/faiss_store.py +++ b/src/dbally/similarity/faiss_store.py @@ -92,3 +92,12 @@ async def find_similar(self, text: str) -> Optional[str]: data = np.load(file) return data[best_idx] return None + + def __repr__(self) -> str: + """ + Returns the string representation of the FaissStore. + + Returns: + str: The string representation of the FaissStore. + """ + return f"{self.__class__.__name__}(index_dir={self.index_dir}, index_name={self.index_name})" diff --git a/src/dbally/similarity/index.py b/src/dbally/similarity/index.py index d83085c2..6895c566 100644 --- a/src/dbally/similarity/index.py +++ b/src/dbally/similarity/index.py @@ -1,5 +1,8 @@ import abc +from typing import Optional +from dbally.audit.event_tracker import EventTracker +from dbally.data_models.audit import SimilarityEvent from dbally.similarity.fetcher import SimilarityFetcher from dbally.similarity.store import SimilarityStore @@ -20,12 +23,13 @@ async def update(self) -> None: """ @abc.abstractmethod - async def similar(self, text: str) -> str: + async def similar(self, text: str, event_tracker: Optional[EventTracker] = None) -> str: """ Finds the most similar text or returns the original text if no similar text is found. Args: text: The text to find similar to. + event_tracker: The event tracker to use for auditing the similarity search. Returns: str: The most similar text or the original text if no similar text is found. @@ -54,15 +58,24 @@ async def update(self) -> None: data = await self.fetcher.fetch() await self.store.store(data) - async def similar(self, text: str) -> str: + async def similar(self, text: str, event_tracker: Optional[EventTracker] = None) -> str: """ Finds the most similar text in the store or returns the original text if no similar text is found. Args: text: The text to find similar to. + event_tracker: The event tracker to use for auditing the similarity search. Returns: str: The most similar text or the original text if no similar text is found. """ - found = await self.store.find_similar(text) + + event_tracker = event_tracker or EventTracker() + event = SimilarityEvent(input_value=text, store=repr(self.store), fetcher=repr(self.fetcher)) + + async with event_tracker.track_event(event) as span: + found = await self.store.find_similar(text) + event.output_value = found + span(event) + return found if found else text diff --git a/src/dbally/similarity/sqlalchemy_base.py b/src/dbally/similarity/sqlalchemy_base.py index 0523936d..7f066804 100644 --- a/src/dbally/similarity/sqlalchemy_base.py +++ b/src/dbally/similarity/sqlalchemy_base.py @@ -36,6 +36,15 @@ async def fetch(self) -> List[str]: result = conn.execute(self.get_query()) return [row[0] for row in result] + def __repr__(self) -> str: + """ + Returns a string representation of the fetcher. + + Returns: + str: The string representation of the fetcher. + """ + return f"{self.__class__.__name__}()" + class SimpleSqlAlchemyFetcher(SqlAlchemyFetcher): """ @@ -58,6 +67,15 @@ def get_query(self) -> sqlalchemy.Select: """ return sqlalchemy.select(self.column).select_from(self.table).distinct() + def __repr__(self) -> str: + """ + Returns a string representation of the fetcher. + + Returns: + str: The string representation of the fetcher. + """ + return f"{self.__class__.__name__}(table={self.table.name}, column={self.column.name})" + class AbstractSqlAlchemyStore(SimilarityStore, metaclass=abc.ABCMeta): """ diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index 4588c925..4cead890 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -65,7 +65,7 @@ async def ask( for _ in range(n_retries): try: - filters = await IQLQuery.parse(iql_filters, filter_list) + filters = await IQLQuery.parse(iql_filters, filter_list, event_tracker=event_tracker) await self.apply_filters(filters) break except (IQLError, ValueError) as e: