Skip to content

Commit

Permalink
Make vector store more embeddable (#5491)
Browse files Browse the repository at this point in the history
Make vector store more embeddable:

* Add new endpoints
* Adding decorator for checking schemas
* Add glob filters

---------

Co-authored-by: Szymon Dudycz <[email protected]>
GitOrigin-RevId: b874f4ce685792953078dba86b2891b518f73883
  • Loading branch information
2 people authored and Manul from Pathway committed Jan 25, 2024
1 parent 3928bf5 commit 9793c6d
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 100 deletions.
38 changes: 38 additions & 0 deletions python/pathway/stdlib/ml/classifiers/_knn_lsh.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@

from __future__ import annotations

import fnmatch
import logging
from statistics import mode
from typing import Literal

import jmespath
import jmespath.functions
import numpy as np

# TODO change to `import pathway as pw` when it is not imported as part of stdlib, OR move the whole file to stdlib
Expand Down Expand Up @@ -95,6 +97,41 @@ def knn_lsh_classifier_train(
)


# support for glob metadata search
def _globmatch_impl(pat_i, pat_n, pattern, p_i, p_n, path):
"""Match pattern to path, recursively expanding **."""
if pat_i == pat_n:
return p_i == p_n
if p_i == p_n:
return False
if pattern[pat_i] == "**":
return _globmatch_impl(
pat_i, pat_n, pattern, p_i + 1, p_n, path
) or _globmatch_impl(pat_i + 1, pat_n, pattern, p_i, p_n, path)
if fnmatch.fnmatch(path[p_i], pattern[pat_i]):
return _globmatch_impl(pat_i + 1, pat_n, pattern, p_i + 1, p_n, path)
return False


def _globmatch(pattern, path):
"""globmatch path to patter, using fnmatch at every level."""
pattern_parts = pattern.split("/")
path_parts = path.split("/")
return _globmatch_impl(
0, len(pattern_parts), pattern_parts, 0, len(path_parts), path_parts
)


class CustomFunctions(jmespath.functions.Functions):
@jmespath.functions.signature({"types": ["string"]}, {"types": ["string"]})
def _func_globmatch(self, pattern, string):
# Given a string, check if it matches the globbing pattern
return _globmatch(pattern, string)


_glob_options = jmespath.Options(custom_functions=CustomFunctions())


def knn_lsh_generic_classifier_train(
data: pw.Table, lsh_projection, distance_function, L: int
):
Expand Down Expand Up @@ -202,6 +239,7 @@ def knns(self) -> list[tuple[pw.Pointer, float]]:
self.transformer.training_data[
id_candidate
].metadata.value,
options=_glob_options,
)
is True
]
Expand Down
46 changes: 15 additions & 31 deletions python/pathway/xpacks/llm/tests/test_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,13 @@

import pathway as pw
from pathway.tests.utils import assert_table_equality
from pathway.xpacks.llm.vector_store import (
QueryInputSchema,
StatsInputSchema,
VectorStoreClient,
VectorStoreServer,
)
from pathway.xpacks.llm.vector_store import VectorStoreClient, VectorStoreServer

PATHWAY_HOST = "127.0.0.1"
PATHWAY_PORT = int(os.environ.get("PATHWAY_MONITORING_HTTP_PORT", "20000")) + 20000


class DebugStatsInputSchema(StatsInputSchema):
debug: str | None = pw.column_definition(default_value=None)


class DebugInputInputSchema(StatsInputSchema):
class DebugStatsInputSchema(VectorStoreServer.StatisticsQuerySchema):
debug: str | None = pw.column_definition(default_value=None)


Expand All @@ -45,24 +36,14 @@ def _test_vs(fake_embeddings_model):
embedder=fake_embeddings_model,
)

queries = pw.debug.table_from_rows(schema=QueryInputSchema, rows=[])
info_queries = pw.debug.table_from_rows(
schema=DebugStatsInputSchema,
rows=[
(None,),
],
).select()

input_queries = pw.debug.table_from_rows(
schema=DebugInputInputSchema,
rows=[
(None,),
],
).select()

info_outputs = vector_server._build_graph(
queries, info_queries, input_queries
).info_results
info_outputs = vector_server.statistics_query(info_queries)
assert_table_equality(
info_outputs.select(result=pw.unwrap(pw.this.result["file_count"].as_int())),
pw.debug.table_from_markdown(
Expand All @@ -73,9 +54,14 @@ def _test_vs(fake_embeddings_model):
),
)

input_outputs = vector_server._build_graph(
queries, info_queries, input_queries
).input_results
input_queries = pw.debug.table_from_rows(
schema=VectorStoreServer.InputsQuerySchema,
rows=[
(None, "**/*.py"),
],
)

input_outputs = vector_server.inputs_query(input_queries)

@pw.udf
def get_file_name(path) -> str:
Expand All @@ -98,15 +84,13 @@ def get_file_name(path) -> str:
# parse_graph.G.clear()
retrieve_queries = pw.debug.table_from_markdown(
"""
query | k | metadata_filter
"Foo" | 1 |
metadata_filter | filepath_globpattern | query | k
| | "Foo" | 1
""",
schema=QueryInputSchema,
schema=VectorStoreServer.RetrieveQuerySchema,
)

retrieve_outputs = vector_server._build_graph(
retrieve_queries, info_queries, input_queries
).retrieval_results
retrieve_outputs = vector_server.retrieve_query(retrieve_queries)
_, rows = pw.debug.table_to_dicts(retrieve_outputs)
(val,) = rows["result"].values()
assert isinstance(val, pw.Json)
Expand Down
Loading

0 comments on commit 9793c6d

Please sign in to comment.