Skip to content

Commit

Permalink
Merge branch 'main' into cilint
Browse files Browse the repository at this point in the history
  • Loading branch information
kemingy committed Nov 1, 2024
2 parents 2f6975a + c22885f commit 591c192
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 32 deletions.
60 changes: 32 additions & 28 deletions bench/index.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import asyncio
from sys import version_info
from time import perf_counter
import argparse
from pathlib import Path
import multiprocessing

if version_info >= (3, 12):
raise RuntimeError("h5py doesn't support 3.12")

import psycopg
import h5py
from pgvector.psycopg import register_vector_async
Expand Down Expand Up @@ -47,7 +43,7 @@ def build_arg_parse():
help="Workers to build index",
type=int,
required=False,
default=max(multiprocessing.cpu_count() - 2, 1),
default=max(multiprocessing.cpu_count() - 1, 1),
)
return parser

Expand Down Expand Up @@ -87,9 +83,9 @@ def get_ivf_ops_config(metric, k, name=None):
return metric_ops, ivf_config


async def create_connection(password):
async def create_connection(url):
conn = await psycopg.AsyncConnection.connect(
conninfo=f"postgresql://postgres:{password}@localhost:5432/postgres",
conninfo=url,
dbname="postgres",
autocommit=True,
**KEEPALIVE_KWARGS,
Expand All @@ -113,8 +109,8 @@ async def add_centroids(conn, name, centroids):
copy.set_types(["vector"])
for centroid in tqdm(centroids, desc="Adding centroids"):
await copy.write_row((centroid,))
while conn.pgconn.flush() == 1:
pass
while conn.pgconn.flush() == 1:
await asyncio.sleep(0)


async def add_embeddings(conn, name, dim, train):
Expand All @@ -130,44 +126,50 @@ async def add_embeddings(conn, name, dim, train):
enumerate(train), desc="Adding embeddings", total=len(train)
):
await copy.write_row((i, vec))
while conn.pgconn.flush() == 1:
pass
while conn.pgconn.flush() == 1:
await asyncio.sleep(0)


async def build_index(
    conn, name, workers, metric_ops, ivf_config, finish: asyncio.Event
):
    """Create the index on table ``name`` and signal completion.

    Issues two ``SET`` statements that apply ``workers`` to
    ``max_parallel_maintenance_workers`` and ``max_parallel_workers``, then
    runs ``CREATE INDEX ... USING rabbithole`` with ``metric_ops`` as the
    operator class and ``ivf_config`` as the dollar-quoted ``options`` value.
    The elapsed wall-clock time is printed and ``finish`` is set afterwards.

    Args:
        conn: connection object exposing an awaitable ``execute(sql)``.
        name: table to index; interpolated into the SQL verbatim.
        workers: value applied to both parallel-worker settings.
        metric_ops: operator class for the ``embedding`` column.
        ivf_config: text placed inside ``$$...$$`` as the index options.
        finish: event set once every statement has completed.
    """
    began = perf_counter()
    # Executed strictly in order: the settings must be in place before the build.
    statements = (
        f"SET max_parallel_maintenance_workers TO {workers}",
        f"SET max_parallel_workers TO {workers}",
        f"CREATE INDEX ON {name} USING rabbithole (embedding {metric_ops}) WITH (options = $${ivf_config}$$)",
    )
    for statement in statements:
        await conn.execute(statement)
    elapsed = perf_counter() - began
    print(f"Index build time: {elapsed:.2f}s")
    # Only reached when all statements succeeded.
    finish.set()


async def monitor_index_build(password, finish: asyncio.Event):
conn = await psycopg.AsyncConnection.connect(
conninfo=f"postgresql://postgres:{password}@localhost:5432/postgres",
dbname="postgres",
autocommit=True,
**KEEPALIVE_KWARGS,
)
pbar = tqdm(smoothing=0.0)
async def monitor_index_build(conn, finish: asyncio.Event):
async with conn.cursor() as acur:
blocks_total = None
while blocks_total is None:
await asyncio.sleep(1)
await acur.execute(
f"SELECT blocks_total FROM pg_stat_progress_create_index"
)
blocks_total = await acur.fetchone()
total = 0 if blocks_total is None else blocks_total[0]
pbar = tqdm(smoothing=0.0, total=total, desc="Building index")
while True:
if finish.is_set():
pbar.update(pbar.total - pbar.n)
return
await acur.execute(f"SELECT tuples_done FROM pg_stat_progress_create_index")
tuples_done = await acur.fetchone()
update = 0 if tuples_done is None else tuples_done[0]
pbar.update(update - pbar.n)
await acur.execute(f"SELECT blocks_done FROM pg_stat_progress_create_index")
blocks_done = await acur.fetchone()
done = 0 if blocks_done is None else blocks_done[0]
pbar.update(done - pbar.n)
await asyncio.sleep(1)
pbar.close()


async def main(dataset):
dataset = h5py.File(Path(args.input), "r")
conn = await create_connection(args.password)
# conninfo must be one plain string; a trailing comma here would turn it into a 1-tuple.
url = f"postgresql://postgres:{args.password}@localhost:5432/postgres"
conn = await create_connection(url)
if args.centroids:
centroids = np.load(args.centroids, allow_pickle=False)
await add_centroids(conn, args.name, centroids)
Expand All @@ -177,6 +179,12 @@ async def main(dataset):
await add_embeddings(conn, args.name, args.dim, dataset["train"])

index_finish = asyncio.Event()
# Need a separate connection for the monitor task
monitor_conn = await create_connection(url)
monitor_task = monitor_index_build(
monitor_conn,
index_finish,
)
index_task = build_index(
conn,
args.name,
Expand All @@ -185,10 +193,6 @@ async def main(dataset):
ivf_config,
index_finish,
)
monitor_task = monitor_index_build(
args.password,
index_finish,
)
await asyncio.gather(index_task, monitor_task)


Expand Down
10 changes: 7 additions & 3 deletions bench/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def build_arg_parse():
"--niter", help="number of iterations", type=int, default=N_ITER
)
parser.add_argument("-m", "--metric", choices=["l2", "cos"], default="l2")
parser.add_argument(
"-g", "--gpu", help="enable GPU for KMeans", action="store_true"
)
return parser


Expand All @@ -56,14 +59,14 @@ def filter_by_label(iter, labels, target):
yield vec


def kmeans_cluster(data, k, child_k, niter, metric):
def kmeans_cluster(data, k, child_k, niter, metric, gpu=False):
n, dim = data.shape
if n > MAX_POINTS_PER_CLUSTER * k:
# Size the sample from the `k` argument (not the global `args.k`), matching the size check above.
train = reservoir_sampling(iter(data), MAX_POINTS_PER_CLUSTER * k)
else:
train = data[:]
kmeans = Kmeans(
dim, k, verbose=True, niter=niter, seed=SEED, spherical=metric == "cos"
dim, k, gpu=gpu, verbose=True, niter=niter, seed=SEED, spherical=metric == "cos"
)
kmeans.train(train)
if not child_k:
Expand All @@ -85,6 +88,7 @@ def kmeans_cluster(data, k, child_k, niter, metric):
child_kmeans = Kmeans(
dim,
child_k,
gpu=gpu,
verbose=True,
niter=niter,
seed=SEED,
Expand All @@ -105,7 +109,7 @@ def kmeans_cluster(data, k, child_k, niter, metric):

start_time = perf_counter()
centroids = kmeans_cluster(
dataset["train"], args.k, args.child_k, args.niter, args.metric
dataset["train"], args.k, args.child_k, args.niter, args.metric, args.gpu
)
print(f"K-means (k=({args.k}, {args.child_k})): {perf_counter() - start_time:.2f}s")

Expand Down
2 changes: 1 addition & 1 deletion src/index/am.rs
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ pub unsafe extern "C" fn rabbithole_parallel_build_main(
self.index_info,
true,
false,
false,
true,
0,
pgrx::pg_sys::InvalidBlockNumber,
Some(call::<F>),
Expand Down

0 comments on commit 591c192

Please sign in to comment.