From 9793c6d80af98d37dd563907e410f7044c4410d6 Mon Sep 17 00:00:00 2001
From: Jan Chorowski
Date: Thu, 25 Jan 2024 16:16:46 +0100
Subject: [PATCH] Make vector store more embeddable (#5491)

Make vector store more embeddable:
* Add new endpoints
* Add a decorator for checking schemas
* Add glob filters

---------

Co-authored-by: Szymon Dudycz
GitOrigin-RevId: b874f4ce685792953078dba86b2891b518f73883
---
 .../pathway/stdlib/ml/classifiers/_knn_lsh.py |  38 ++++
 .../xpacks/llm/tests/test_vector_store.py     |  46 ++--
 python/pathway/xpacks/llm/vector_store.py     | 199 ++++++++++++------
 3 files changed, 183 insertions(+), 100 deletions(-)

diff --git a/python/pathway/stdlib/ml/classifiers/_knn_lsh.py b/python/pathway/stdlib/ml/classifiers/_knn_lsh.py
index 2a501ed8..3b37be0c 100644
--- a/python/pathway/stdlib/ml/classifiers/_knn_lsh.py
+++ b/python/pathway/stdlib/ml/classifiers/_knn_lsh.py
@@ -24,11 +24,13 @@
 from __future__ import annotations
 
+import fnmatch
 import logging
 from statistics import mode
 from typing import Literal
 
 import jmespath
+import jmespath.functions
 import numpy as np
 
 # TODO change to `import pathway as pw` when it is not imported as part of stdlib, OR move the whole file to stdlib
@@ -95,6 +97,41 @@ def knn_lsh_classifier_train(
     )
 
 
+# support for glob metadata search
+def _globmatch_impl(pat_i, pat_n, pattern, p_i, p_n, path):
+    """Match pattern to path, recursively expanding **."""
+    if pat_i == pat_n:
+        return p_i == p_n
+    if p_i == p_n:
+        return False
+    if pattern[pat_i] == "**":
+        return _globmatch_impl(
+            pat_i, pat_n, pattern, p_i + 1, p_n, path
+        ) or _globmatch_impl(pat_i + 1, pat_n, pattern, p_i, p_n, path)
+    if fnmatch.fnmatch(path[p_i], pattern[pat_i]):
+        return _globmatch_impl(pat_i + 1, pat_n, pattern, p_i + 1, p_n, path)
+    return False
+
+
+def _globmatch(pattern, path):
+    """Globmatch path to pattern, using fnmatch at every level."""
+    pattern_parts = pattern.split("/")
+    path_parts = path.split("/")
+    return _globmatch_impl(
+        0, len(pattern_parts), pattern_parts, 0, len(path_parts), path_parts
+    )
+
+
+class CustomFunctions(jmespath.functions.Functions):
+    @jmespath.functions.signature({"types": ["string"]}, {"types": ["string"]})
+    def _func_globmatch(self, pattern, string):
+        # Given a string, check if it matches the globbing pattern
+        return _globmatch(pattern, string)
+
+
+_glob_options = jmespath.Options(custom_functions=CustomFunctions())
+
+
 def knn_lsh_generic_classifier_train(
     data: pw.Table, lsh_projection, distance_function, L: int
 ):
@@ -202,6 +239,7 @@ def knns(self) -> list[tuple[pw.Pointer, float]]:
                         self.transformer.training_data[
                             id_candidate
                         ].metadata.value,
+                        options=_glob_options,
                     )
                     is True
                 ]
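A note on the matcher added above: the `**` case in _globmatch_impl recurses two ways, either consuming one path segment while staying on `**`, or dropping `**` and moving to the next pattern segment, so `**` spans any number of whole segments, including zero; every other pattern segment must fnmatch exactly one path segment. A minimal sketch of the resulting semantics, using the private helpers added above (the metadata dict and filter expression are illustrative, not from the patch):

    import jmespath

    from pathway.stdlib.ml.classifiers import _knn_lsh

    # "**" spans zero or more whole path segments; "*" never crosses "/"
    assert _knn_lsh._globmatch("docs/**/*.md", "docs/setup.md")
    assert _knn_lsh._globmatch("docs/**/*.md", "docs/guides/deep/setup.md")
    assert not _knn_lsh._globmatch("docs/*.md", "docs/guides/setup.md")

    # The same matcher is exposed to JMESPath filters as globmatch(pattern, path)
    meta = {"path": "docs/guides/setup.md", "owner": "alice"}
    expr = "owner == 'alice' && globmatch(`\"docs/**/*.md\"`, path)"
    assert jmespath.search(expr, meta, options=_knn_lsh._glob_options) is True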
diff --git a/python/pathway/xpacks/llm/tests/test_vector_store.py b/python/pathway/xpacks/llm/tests/test_vector_store.py
index 96844781..76384d65 100644
--- a/python/pathway/xpacks/llm/tests/test_vector_store.py
+++ b/python/pathway/xpacks/llm/tests/test_vector_store.py
@@ -13,22 +13,13 @@
 import pathway as pw
 from pathway.tests.utils import assert_table_equality
-from pathway.xpacks.llm.vector_store import (
-    QueryInputSchema,
-    StatsInputSchema,
-    VectorStoreClient,
-    VectorStoreServer,
-)
+from pathway.xpacks.llm.vector_store import VectorStoreClient, VectorStoreServer
 
 PATHWAY_HOST = "127.0.0.1"
 PATHWAY_PORT = int(os.environ.get("PATHWAY_MONITORING_HTTP_PORT", "20000")) + 20000
 
-class DebugStatsInputSchema(StatsInputSchema):
-    debug: str | None = pw.column_definition(default_value=None)
-
-
-class DebugInputInputSchema(StatsInputSchema):
+class DebugStatsInputSchema(VectorStoreServer.StatisticsQuerySchema):
     debug: str | None = pw.column_definition(default_value=None)
 
@@ -45,7 +36,6 @@ def _test_vs(fake_embeddings_model):
         embedder=fake_embeddings_model,
     )
 
-    queries = pw.debug.table_from_rows(schema=QueryInputSchema, rows=[])
     info_queries = pw.debug.table_from_rows(
         schema=DebugStatsInputSchema,
         rows=[
@@ -53,16 +43,7 @@ def _test_vs(fake_embeddings_model):
         ],
     ).select()
 
-    input_queries = pw.debug.table_from_rows(
-        schema=DebugInputInputSchema,
-        rows=[
-            (None,),
-        ],
-    ).select()
-
-    info_outputs = vector_server._build_graph(
-        queries, info_queries, input_queries
-    ).info_results
+    info_outputs = vector_server.statistics_query(info_queries)
     assert_table_equality(
         info_outputs.select(result=pw.unwrap(pw.this.result["file_count"].as_int())),
         pw.debug.table_from_markdown(
             """
         ),
     )
 
-    input_outputs = vector_server._build_graph(
-        queries, info_queries, input_queries
-    ).input_results
+    input_queries = pw.debug.table_from_rows(
+        schema=VectorStoreServer.InputsQuerySchema,
+        rows=[
+            (None, "**/*.py"),
+        ],
+    )
+
+    input_outputs = vector_server.inputs_query(input_queries)
 
     @pw.udf
     def get_file_name(path) -> str:
@@ -98,15 +84,13 @@ def get_file_name(path) -> str:
 
     # parse_graph.G.clear()
     retrieve_queries = pw.debug.table_from_markdown(
         """
-        query | k | metadata_filter
-        "Foo" | 1 |
+        metadata_filter | filepath_globpattern | query | k
+                        |                      | "Foo" | 1
         """,
-        schema=QueryInputSchema,
+        schema=VectorStoreServer.RetrieveQuerySchema,
     )
 
-    retrieve_outputs = vector_server._build_graph(
-        retrieve_queries, info_queries, input_queries
-    ).retrieval_results
+    retrieve_outputs = vector_server.retrieve_query(retrieve_queries)
     _, rows = pw.debug.table_to_dicts(retrieve_outputs)
     (val,) = rows["result"].values()
     assert isinstance(val, pw.Json)
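The new schemas above split filtering into two user-facing columns: metadata_filter, a JMESPath expression over document metadata, and filepath_globpattern, a glob on the document path. The merge_filters helper added to vector_store.py below folds the two into a single JMESPath string. A plain-Python mirror of its _get_jmespath_filter UDF, showing the strings it is expected to build (inputs are illustrative):

    def merge(metadata_filter: str | None, filepath_globpattern: str | None) -> str | None:
        # mirrors the _get_jmespath_filter UDF from the patch below
        ret_parts = []
        if metadata_filter:
            ret_parts.append(f"({metadata_filter})")
        if filepath_globpattern:
            ret_parts.append(f'globmatch(`"{filepath_globpattern}"`, path)')
        return " && ".join(ret_parts) or None

    assert merge("owner == 'alice'", "**/*.md") == (
        "(owner == 'alice') && globmatch(`\"**/*.md\"`, path)"
    )
    assert merge(None, "**/*.py") == 'globmatch(`"**/*.py"`, path)'
    assert merge(None, None) is None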
""" @@ -188,6 +172,30 @@ def embedder(txt): metadata=chunked_docs.data["metadata"], ) + parsed_docs += parsed_docs.select( + modified=pw.this.data["metadata"]["modified_at"], + path=pw.this.data["metadata"]["path"], + ) + + stats = parsed_docs.reduce( + count=pw.reducers.count(), + last_modified=pw.reducers.max(pw.this.modified), + paths=pw.reducers.tuple(pw.this.path), + ) + return locals() + + class StatisticsQuerySchema(pw.Schema): + pass + + class QueryResultSchema(pw.Schema): + result: pw.Json + + @pw.table_transformer + def statistics_query( + self, info_queries: pw.Table[StatisticsQuerySchema] + ) -> pw.Table[QueryResultSchema]: + stats = self._graph["stats"] + # VectorStore statistics computation @pw.udf def format_stats(counts, last_modified) -> pw.Json: @@ -197,33 +205,98 @@ def format_stats(counts, last_modified) -> pw.Json: response = {"file_count": 0, "last_modified": None} return pw.Json(response) - @pw.udf - def format_inputs(paths: list[pw.Json]) -> pw.Json: - input_set = [] - if paths: - input_set = list(set([i.value for i in paths])) + info_results = info_queries.join_left(stats, id=info_queries.id).select( + result=format_stats(stats.count, stats.last_modified) + ) + return info_results - response = {"input_files": input_set} - return pw.Json(response) # type: ignore + class FilterSchema(pw.Schema): + metadata_filter: str | None = pw.column_definition(default_value=None) + filepath_globpattern: str | None = pw.column_definition(default_value=None) - parsed_docs += parsed_docs.select( - modified=pw.this.data["metadata"]["modified_at"], - path=pw.this.data["metadata"]["path"], - ) - stats = parsed_docs.reduce( - count=pw.reducers.count(), last_modified=pw.reducers.max(pw.this.modified) + @staticmethod + def merge_filters(queries: pw.Table): + @pw.udf + def _get_jmespath_filter( + metadata_filter: str, filepath_globpattern: str + ) -> str | None: + ret_parts = [] + if metadata_filter: + ret_parts.append(f"({metadata_filter})") + if filepath_globpattern: + ret_parts.append(f'globmatch(`"{filepath_globpattern}"`, path)') + if ret_parts: + return " && ".join(ret_parts) + return None + + queries = queries.without( + *VectorStoreServer.FilterSchema.__columns__.keys() + ) + queries.select( + metadata_filter=_get_jmespath_filter( + pw.this.metadata_filter, pw.this.filepath_globpattern + ) ) - inputs = parsed_docs.reduce(paths=pw.reducers.tuple(pw.this.path)) + return queries - info_results = info_queries.join_left(stats, id=info_queries.id).select( - result=format_stats(stats.count, stats.last_modified) - ) + class InputsQuerySchema(FilterSchema): + pass - input_results = input_queries.join_left(inputs, id=input_queries.id).select( - result=format_inputs(pw.this.paths) # pw.this.paths# + @pw.table_transformer + def inputs_query( + self, input_queries: pw.Table[InputsQuerySchema] + ) -> pw.Table[QueryResultSchema]: + docs = self._graph["docs"] + # TODO: compare this approach to first joining queries to dicuments, then filtering, + # then grouping to get each response. 
+ # The "dumb" tuple approach has more work precomputed for an all inputs query + all_metas = docs.reduce(metadatas=pw.reducers.tuple(pw.this._metadata)) + + input_queries = self.merge_filters(input_queries) + + @pw.udf + def format_inputs( + metadatas: list[pw.Json] | None, metadata_filter: str | None + ) -> pw.Json: + metadatas: list = metadatas if metadatas is not None else [] # type:ignore + assert metadatas is not None + if metadata_filter: + metadatas = [ + m + for m in metadatas + if jmespath.search( + metadata_filter, m.value, options=_knn_lsh._glob_options + ) + ] + return pw.Json({"input_files": [m.value["path"] for m in metadatas]}) # type: ignore + + input_results = input_queries.join_left(all_metas, id=input_queries.id).select( + all_metas.metadatas, input_queries.metadata_filter ) + input_results = input_results.select( + result=format_inputs(pw.this.metadatas, pw.this.metadata_filter) + ) + return input_results + + # TODO: fix + # class RetrieveQuerySchema(FilterSchema): + # query: str + # k: int + # Keep a weird column ordering to match one that stems from the class inheritance above. + class RetrieveQuerySchema(pw.Schema): + metadata_filter: str | None = pw.column_definition(default_value=None) + filepath_globpattern: str | None = pw.column_definition(default_value=None) + query: str + k: int + + @pw.table_transformer + def retrieve_query( + self, retrieval_queries: pw.Table[RetrieveQuerySchema] + ) -> pw.Table[QueryResultSchema]: + embedder = self._graph["embedder"] + knn_index = self._graph["knn_index"] # Relevant document search + retrieval_queries = self.merge_filters(retrieval_queries) retrieval_queries += retrieval_queries.select( embedding=embedder(pw.this.query), ) @@ -250,7 +323,7 @@ def format_inputs(paths: list[pw.Json]) -> pw.Json: ) ) - return GraphResultTables(retrieval_results, info_results, input_results) + return retrieval_results def run_server( self, @@ -280,33 +353,21 @@ def run_server( """ webserver = pw.io.http.PathwayWebserver(host=host, port=port) - retrieval_queries, retrieval_response_writer = pw.io.http.rest_connector( - webserver=webserver, - route="/query", - schema=QueryInputSchema, - autocommit_duration_ms=50, - delete_completed_queries=True, - ) - info_queries, info_response_writer = pw.io.http.rest_connector( - webserver=webserver, - route="/stats", - schema=StatsInputSchema, - autocommit_duration_ms=50, - delete_completed_queries=True, - ) - input_queries, inputs_response_writer = pw.io.http.rest_connector( - webserver=webserver, - route="/get_inputs", - schema=StatsInputSchema, - autocommit_duration_ms=50, - delete_completed_queries=True, - ) - graph_tables = self._build_graph(retrieval_queries, info_queries, input_queries) + # TODO(move into webserver??) 
+        def serve(route, schema, handler):
+            queries, writer = pw.io.http.rest_connector(
+                webserver=webserver,
+                route=route,
+                schema=schema,
+                autocommit_duration_ms=50,
+                delete_completed_queries=True,
+            )
+            writer(handler(queries))
 
-        retrieval_response_writer(graph_tables.retrieval_results)
-        info_response_writer(graph_tables.info_results)
-        inputs_response_writer(graph_tables.input_results)
+        serve("/v1/retrieve", self.RetrieveQuerySchema, self.retrieve_query)
+        serve("/v1/statistics", self.StatisticsQuerySchema, self.statistics_query)
+        serve("/v1/inputs", self.InputsQuerySchema, self.inputs_query)
 
         def run():
             if with_cache:
@@ -345,7 +406,7 @@ def query(self, query, k=3, metadata_filter=None) -> list[dict]:
         data = {"query": query, "k": k}
         if metadata_filter is not None:
             data["metadata_filter"] = metadata_filter
-        url = f"http://{self.host}:{self.port}/query"
+        url = f"http://{self.host}:{self.port}/v1/retrieve"
         response = requests.post(
             url,
             data=json.dumps(data),
@@ -360,7 +421,7 @@
 
     def get_vectorstore_statistics(self):
         """Fetch basic statistics about the vector store."""
-        url = f"http://{self.host}:{self.port}/stats"
+        url = f"http://{self.host}:{self.port}/v1/statistics"
         response = requests.post(
             url,
             json={},
@@ -371,7 +432,7 @@
 
     def get_input_files(self):
         """Fetch the list of input files indexed by the vector store."""
-        url = f"http://{self.host}:{self.port}/get_inputs"
+        url = f"http://{self.host}:{self.port}/v1/inputs"
        response = requests.post(
            url,
            json={},
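With the endpoints exposed as plain table transformers, the HTTP layer in serve() above becomes optional: the same methods can be applied to tables inside a larger Pathway program, which is the "embeddable" point of this patch. A minimal sketch, assuming `server` is an already-built VectorStoreServer (constructed as in the test, e.g. VectorStoreServer(docs, embedder=...)); the query rows are illustrative:

    import pathway as pw

    from pathway.xpacks.llm.vector_store import VectorStoreServer

    # Retrieval: any table matching RetrieveQuerySchema can be fed in directly.
    queries = pw.debug.table_from_markdown(
        """
        metadata_filter | filepath_globpattern | query | k
                        | "**/*.md"            | "Foo" | 2
        """,
        schema=VectorStoreServer.RetrieveQuerySchema,
    )
    results = server.retrieve_query(queries)  # table with a Json "result" column

    # Listing indexed inputs, restricted to Python files by the glob filter.
    input_queries = pw.debug.table_from_rows(
        schema=VectorStoreServer.InputsQuerySchema,
        rows=[(None, "**/*.py")],
    )
    inputs = server.inputs_query(input_queries)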
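On the client side, the three VectorStoreClient methods at the end of the patch map one-to-one onto the renamed routes. A usage sketch, assuming the client is constructed with the usual host and port arguments and that a server started via run_server is reachable there (the address is illustrative):

    from pathway.xpacks.llm.vector_store import VectorStoreClient

    client = VectorStoreClient(host="127.0.0.1", port=8000)  # illustrative address

    docs = client.query("What is Pathway?", k=2)  # POST /v1/retrieve
    stats = client.get_vectorstore_statistics()   # POST /v1/statistics
    files = client.get_input_files()              # POST /v1/inputs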