Skip to content

Commit

Permalink
reverse search
Browse files Browse the repository at this point in the history
  • Loading branch information
Behrad Babaee committed Dec 17, 2024
1 parent 6b88aeb commit 71e808b
Showing 1 changed file with 27 additions and 19 deletions.
46 changes: 27 additions & 19 deletions graph/basic-graph-vector-search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,6 @@
default=10,
help="number of items in a dimension",
)
arg_parser.add_argument(
"--search-count",
dest="search_count",
required=False,
default=10,
help="number of random searches at the end",
)
arg_parser.add_argument(
"--load-balancer",
dest="load_balancer",
Expand All @@ -105,9 +98,9 @@ def vector_space_builder(current_list, current_iteration):
list.append(newItem)
return list

def vertex_builder(fake, dimensions):
def vertex_builder(fake):
result = []
for i in range(dimensions):
for i in range(args.dimensions):
list = []
list.append(i)
list.append(fake.job())
Expand All @@ -126,7 +119,7 @@ def insert_data(gClient, vClient, vectors):
for v in vectors:
key = ','.join( str(x) for x in v )
vClient.upsert(namespace=args.namespace, set_name=args.set, key=key, record_data={ "vector": v } )
person = gClient.add_v('Person').property(T.id, key).property('name', fake.name()).property('ip', fake.ipv4_private()).next()
person = gClient.add_v('Person').property(T.id, key).property('Name', fake.name()).property('IP', fake.ipv4_private()).next()
gClient.V(person).addE("HAS_JOB").to(gClient.V(random.randint(0, args.dimensions-1)).next()).next()

m_secs = round((perf_counter_ns() - start) / 10 ** 6, 3)
Expand All @@ -139,22 +132,38 @@ def wait_for_index(vClient):
m_secs = round((perf_counter_ns() - start) / 10 ** 6, 3)
print(f"Indexing took: {m_secs} milliseconds")

def query_random(vClient, gClient):
print("querying")
for i in range(args.search_count):
def query_vector_and_then_graph(vClient, gClient):
print("Querying Vector and then Graph:")
for i in range(args.dimensions):
v = []
for j in range(args.dimensions):
v.append(random.uniform(0, args.number_of_items_in_each_dimesnsion))

key = ','.join(map(str, v))
start = perf_counter_ns()
results = vClient.vector_search(namespace=args.namespace, index_name=args.index_name, query=v, limit=args.dimensions*2 + 1)
results = vClient.vector_search(namespace=args.namespace, index_name=args.index_name, query=v, limit=args.dimensions + 1)
m_secs = round((perf_counter_ns() - start) / 10 ** 6, 3)
print(f"Querying [{key}] took: {m_secs} milliseconds")

for result in results:
print(str(gClient.V(result.key.key).element_map().to_list()) + " -> " + str(gClient.V(result.key.key).out("HAS_JOB").value_map().to_list()))

def query_graph_and_then_vector(vClient, gClient, jobs):
print("Querying Graph and then Vector:")
jobId = jobs[random.randint(0,args.dimensions-1)]
print("Finding people who are " + gClient.V(jobId).values("Job").next())
vertices = gClient.V(gClient.V(jobId).next()).inE().outV().to_list()
people = list(map(lambda v: v.id, vertices))

for j in range(args.dimensions):
personId = people[random.randint(0,len(people)-1)]
personVertex = gClient.V(personId).element_map().to_list()[0]
personVector = [float(i) for i in personId.split(",")]
print("A vertex representing of a random person with that job is: " + str(personVertex) )
print("Here are few people close to them in vector space:")
results = vClient.vector_search(namespace=args.namespace, index_name=args.index_name, query=personVector, limit=args.dimensions + 1)
for result in results:
print(str(gClient.V(result.key.key).element_map().to_list()))

print("Clearing the environment and setting up!")
with AdminClient(seeds=types.HostPort(host=args.host, port=args.port), is_loadbalancer=args.load_balancer) as adminClient:
try:
Expand All @@ -163,9 +172,7 @@ def query_random(vClient, gClient):
adminClient.index_drop(namespace=args.namespace, name=args.index_name, timeout=60)
except Exception as e:
pass

sys.stderr = old_stderr # reset old stderr

try:
adminClient.index_create(namespace=args.namespace, name=args.index_name, vector_field="vector", dimensions=args.dimensions, sets=args.set, index_storage=types.IndexStorage(namespace=args.index_namespace, set_name=args.index_set))
except Exception as e:
Expand All @@ -174,7 +181,7 @@ def query_random(vClient, gClient):

fake = Faker()
vectors = vector_space_builder([], 1)
jobs = vertex_builder(fake, args.dimensions)
jobs = vertex_builder(fake)
vClient = Client(seeds=types.HostPort(host=args.host, port=args.port), is_loadbalancer=args.load_balancer)
gClient = traversal().with_remote(DriverRemoteConnection('ws://localhost:8182/gremlin', 'g'))
gClient.V().drop().iterate()
Expand All @@ -183,4 +190,5 @@ def query_random(vClient, gClient):
insert_jobs(gClient, jobs)
insert_data(gClient, vClient, vectors)
wait_for_index(vClient)
query_random(vClient, gClient)
query_vector_and_then_graph(vClient, gClient)
query_graph_and_then_vector(vClient, gClient, jobs)

0 comments on commit 71e808b

Please sign in to comment.