Skip to content

Commit

Permalink
Detect graph queries and route to graph index, closes #865
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmezzetti committed Feb 2, 2025
1 parent 2b6aa5f commit 44fb2f9
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 11 deletions.
10 changes: 5 additions & 5 deletions src/python/txtai/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,8 @@ def count(self):

def search(self, query, limit=None, weights=None, index=None, parameters=None, graph=False):
"""
Finds documents most similar to the input query. This method will run either an index search
or an index + database search depending on if a database is available.
Finds documents most similar to the input query. This method runs an index search, index + database search
or a graph search, depending on the embeddings configuration and query.
Args:
query: input query
Expand All @@ -377,8 +377,8 @@ def search(self, query, limit=None, weights=None, index=None, parameters=None, g

def batchsearch(self, queries, limit=None, weights=None, index=None, parameters=None, graph=False):
"""
Finds documents most similar to the input queries. This method will run either an index search
or an index + database search depending on if a database is available.
Finds documents most similar to the input query. This method runs an index search, index + database search
or a graph search, depending on the embeddings configuration and query.
Args:
queries: input queries
Expand All @@ -401,7 +401,7 @@ def batchsearch(self, queries, limit=None, weights=None, index=None, parameters=
results = Search(self, indexids=graph)(queries, limit, weights, index, parameters)

# Create subgraphs using results, if necessary
return [self.graph.filter(x) for x in results] if graph else results
return [self.graph.filter(x) if isinstance(x, list) else x for x in results] if graph else results

def similarity(self, query, data):
"""
Expand Down
9 changes: 8 additions & 1 deletion src/python/txtai/embeddings/search/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(self, embeddings, indexids=False, indexonly=False):
self.database = embeddings.database
self.ids = embeddings.ids
self.indexes = embeddings.indexes
self.graph = embeddings.graph
self.query = embeddings.query
self.scoring = embeddings.scoring if embeddings.issparse() else None

Expand All @@ -52,7 +53,9 @@ def __call__(self, queries, limit=None, weights=None, index=None, parameters=Non
parameters: list of dicts of named parameters to bind to placeholders
Returns:
list of (id, score) per query for index search, list of dict per query for an index + database search
list of (id, score) per query for index search
list of dict per query for an index + database search
list of graph results for a graph index search
"""

# Default input parameters
Expand All @@ -67,6 +70,10 @@ def __call__(self, queries, limit=None, weights=None, index=None, parameters=Non
if not index and not self.ann and not self.scoring and self.indexes:
index = self.indexes.default()

# Graph search
if self.graph and self.graph.isquery(queries):
return self.graph.batchsearch(queries, limit, self.indexids)

# Database search
if not self.indexonly and self.database:
return self.dbsearch(queries, limit, weights, index, parameters)
Expand Down
13 changes: 13 additions & 0 deletions src/python/txtai/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,19 @@ def batchsearch(self, queries, limit=None, graph=False):

return [self.search(query, limit, graph) for query in queries]

def isquery(self, queries):
"""
Checks if queries are supported graph queries.
Args:
queries: queries to check
Returns:
True if all the queries are supported graph queries, False otherwise
"""

raise NotImplementedError

def communities(self, config):
"""
Run community detection on the graph.
Expand Down
4 changes: 4 additions & 0 deletions src/python/txtai/graph/networkx.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ def search(self, query, limit=None, graph=False):

return rows

def isquery(self, queries):
# Check for required graph query clauses
return all(query and query.strip().startswith("MATCH ") and "RETURN " in query for query in queries)

def communities(self, config):
# Get community detection algorithm
algorithm = config.get("algorithm")
Expand Down
11 changes: 6 additions & 5 deletions test/python/testgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ def testNotImplemented(self):
self.assertRaises(NotImplementedError, graph.pagerank)
self.assertRaises(NotImplementedError, graph.showpath, None, None)
self.assertRaises(NotImplementedError, graph.search, None)
self.assertRaises(NotImplementedError, graph.isquery, None)
self.assertRaises(NotImplementedError, graph.communities, None)
self.assertRaises(NotImplementedError, graph.load, None)
self.assertRaises(NotImplementedError, graph.save, None)
Expand Down Expand Up @@ -423,7 +424,7 @@ def testSearch(self):
self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

# Run standard search
results = self.embeddings.graph.search(
results = self.embeddings.search(
"""
MATCH (A)-[]->(B)
RETURN A, B
Expand All @@ -432,7 +433,7 @@ def testSearch(self):
self.assertEqual(len(results), 3)

# Run path search
results = self.embeddings.graph.search(
results = self.embeddings.search(
"""
MATCH P=()-[]->()
RETURN P
Expand All @@ -441,7 +442,7 @@ def testSearch(self):
self.assertEqual(len(results), 3)

# Run graph search
g = self.embeddings.graph.search(
g = self.embeddings.search(
"""
MATCH (A)-[]->(B)
RETURN A, B
Expand All @@ -451,7 +452,7 @@ def testSearch(self):
self.assertEqual(g.count(), 3)

# Run path search
results = self.embeddings.graph.search(
results = self.embeddings.search(
"""
MATCH P=()-[]->()
RETURN P
Expand All @@ -469,7 +470,7 @@ def testSearchBatch(self):
self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

# Run standard search
results = self.embeddings.graph.batchsearch(
results = self.embeddings.batchsearch(
[
"""
MATCH (A)-[]->(B)
Expand Down

0 comments on commit 44fb2f9

Please sign in to comment.