Add in-mem vector search (#2547)
hinthornw authored Nov 27, 2024
1 parent dfaff25 commit 62a36be
Showing 13 changed files with 2,046 additions and 190 deletions.
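
The headline change wires vector search into the in-memory store. A rough sketch of how the feature is exercised follows; the index config and search signature track the langgraph store API this commit targets, but treat the exact parameter names as assumptions, and embed_texts is a hypothetical stand-in for a real embedding model:

from langgraph.store.memory import InMemoryStore

def embed_texts(texts: list[str]) -> list[list[float]]:
    # Hypothetical embedder: maps each text to a fixed-size float vector.
    return [[float(len(t)), 1.0] for t in texts]

store = InMemoryStore(index={"embed": embed_texts, "dims": 2})
store.put(("docs",), "doc1", {"text": "LangGraph checkpoints agent state"})
# Semantic search over a namespace prefix; results carry a relevance score.
for item in store.search(("docs",), query="how is agent state persisted?", limit=5):
    print(item.key, item.score)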
19 changes: 18 additions & 1 deletion libs/checkpoint-duckdb/langgraph/store/duckdb/base.py
@@ -23,6 +23,7 @@
     Op,
     PutOp,
     Result,
+    SearchItem,
     SearchOp,
 )

@@ -283,7 +284,7 @@ def _batch_search_ops(

         for cur, idx in cursors:
             rows = cur.fetchall()
-            items = [_row_to_item(_convert_ns(row[0]), row) for row in rows]
+            items = [_row_to_search_item(_convert_ns(row[0]), row) for row in rows]
             results[idx] = items

     def _batch_list_namespaces_ops(
@@ -376,6 +377,22 @@ def _row_to_item(
     )


+def _row_to_search_item(
+    namespace: tuple[str, ...],
+    row: tuple,
+) -> SearchItem:
+    """Convert a row from the database into a SearchItem."""
+    # TODO: Add support for search
+    _, key, val, created_at, updated_at = row
+    return SearchItem(
+        value=val if isinstance(val, dict) else json.loads(val),
+        key=key,
+        namespace=namespace,
+        created_at=created_at,
+        updated_at=updated_at,
+    )
+
+
 def _group_ops(ops: Iterable[Op]) -> tuple[dict[type, list[tuple[int, Op]]], int]:
     grouped_ops: dict[type, list[tuple[int, Op]]] = defaultdict(list)
     tot = 0
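
Note the DuckDB helper never populates a score (vector search is not wired up for this backend yet, per the TODO), so results carry no relevance score. A small sketch of the conversion under the row shape the unpacking assumes, with hypothetical values:

from datetime import datetime, timezone

now = datetime.now(timezone.utc)
row = ("users.profiles", "alice", '{"name": "Alice"}', now, now)  # hypothetical row
item = _row_to_search_item(_convert_ns(row[0]), row)
# Assumes SearchItem defaults its score field to None when omitted.
assert item.value == {"name": "Alice"} and item.score is None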
16 changes: 10 additions & 6 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py
@@ -378,15 +378,19 @@ def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]:
                 # a connection not in pipeline mode can only be used by one
                 # thread/coroutine at a time, so we acquire a lock
                 if self.supports_pipeline:
-                    with self.lock, conn.pipeline(), conn.cursor(
-                        binary=True, row_factory=dict_row
-                    ) as cur:
+                    with (
+                        self.lock,
+                        conn.pipeline(),
+                        conn.cursor(binary=True, row_factory=dict_row) as cur,
+                    ):
                         yield cur
                 else:
                     # Use connection's transaction context manager when pipeline mode not supported
-                    with self.lock, conn.transaction(), conn.cursor(
-                        binary=True, row_factory=dict_row
-                    ) as cur:
+                    with (
+                        self.lock,
+                        conn.transaction(),
+                        conn.cursor(binary=True, row_factory=dict_row) as cur,
+                    ):
                         yield cur
             else:
                 with self.lock, conn.cursor(binary=True, row_factory=dict_row) as cur:
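
The _cursor edits in this and the following files are purely mechanical: line-wrapped multi-context with statements become parenthesized ones, a form CPython officially supports from Python 3.10. Both spellings are equivalent:

# Old style: all context managers on one logical line, wrapped via a call's
# argument list when too long.
with open("a.txt") as fa, open("b.txt") as fb:
    ...

# New style: parentheses allow one context manager per line, trailing comma and all.
with (
    open("a.txt") as fa,
    open("b.txt") as fb,
):
    ...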
23 changes: 14 additions & 9 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py
Expand Up @@ -338,20 +338,25 @@ async def _cursor(
# a connection not in pipeline mode can only be used by one
# thread/coroutine at a time, so we acquire a lock
if self.supports_pipeline:
async with self.lock, conn.pipeline(), conn.cursor(
binary=True, row_factory=dict_row
) as cur:
async with (
self.lock,
conn.pipeline(),
conn.cursor(binary=True, row_factory=dict_row) as cur,
):
yield cur
else:
# Use connection's transaction context manager when pipeline mode not supported
async with self.lock, conn.transaction(), conn.cursor(
binary=True, row_factory=dict_row
) as cur:
async with (
self.lock,
conn.transaction(),
conn.cursor(binary=True, row_factory=dict_row) as cur,
):
yield cur
else:
async with self.lock, conn.cursor(
binary=True, row_factory=dict_row
) as cur:
async with (
self.lock,
conn.cursor(binary=True, row_factory=dict_row) as cur,
):
yield cur

def list(
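
As the inline comment notes, a psycopg connection outside pipeline mode is not safe for concurrent use, so every branch stacks self.lock in front of the cursor. A minimal sketch of the pattern with assumed names (lock standing in for self.lock, conn for an open psycopg AsyncConnection):

import asyncio

lock = asyncio.Lock()

async def run_query(conn) -> None:
    # Serialize access to the shared connection, then open a binary cursor on it.
    async with lock, conn.cursor(binary=True) as cur:
        await cur.execute("SELECT 1")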
11 changes: 7 additions & 4 deletions libs/checkpoint-postgres/langgraph/store/postgres/aio.py
@@ -28,6 +28,7 @@
     _decode_ns_bytes,
     _group_ops,
     _row_to_item,
+    _row_to_search_item,
 )

 logger = logging.getLogger(__name__)
@@ -146,7 +147,7 @@ async def _batch_search_ops(
                 await cur.execute(query, params)
                 rows = cast(list[Row], await cur.fetchall())
                 items = [
-                    _row_to_item(
+                    _row_to_search_item(
                         _decode_ns_bytes(row["prefix"]), row, loader=self._deserializer
                     )
                     for row in rows
@@ -195,9 +196,11 @@ async def _cursor(
                     async with self.lock, conn.pipeline(), conn.cursor(binary=True) as cur:
                         yield cur
                 else:
-                    async with self.lock, conn.transaction(), conn.cursor(
-                        binary=True
-                    ) as cur:
+                    async with (
+                        self.lock,
+                        conn.transaction(),
+                        conn.cursor(binary=True) as cur,
+                    ):
                         yield cur
             else:
                 async with conn.cursor(binary=True) as cur:
45 changes: 38 additions & 7 deletions libs/checkpoint-postgres/langgraph/store/postgres/base.py
@@ -36,6 +36,7 @@
     Op,
     PutOp,
     Result,
+    SearchItem,
     SearchOp,
 )

@@ -344,14 +345,18 @@ def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]:
                 # a connection not in pipeline mode can only be used by one
                 # thread/coroutine at a time, so we acquire a lock
                 if self.supports_pipeline:
-                    with self.lock, conn.pipeline(), conn.cursor(
-                        binary=True, row_factory=dict_row
-                    ) as cur:
+                    with (
+                        self.lock,
+                        conn.pipeline(),
+                        conn.cursor(binary=True, row_factory=dict_row) as cur,
+                    ):
                         yield cur
                 else:
-                    with self.lock, conn.transaction(), conn.cursor(
-                        binary=True, row_factory=dict_row
-                    ) as cur:
+                    with (
+                        self.lock,
+                        conn.transaction(),
+                        conn.cursor(binary=True, row_factory=dict_row) as cur,
+                    ):
                         yield cur
             else:
                 with conn.cursor(binary=True, row_factory=dict_row) as cur:
@@ -430,7 +435,7 @@ def _batch_search_ops(
                 cur.execute(query, params)
                 rows = cast(list[Row], cur.fetchall())
                 results[idx] = [
-                    _row_to_item(
+                    _row_to_search_item(
                         _decode_ns_bytes(row["prefix"]), row, loader=self._deserializer
                     )
                     for row in rows
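
With both stores routing rows through _row_to_search_item, _batch_search_ops now yields SearchItem objects: the same fields as Item plus an optional relevance score (None when no ranking was involved). Consuming them looks roughly like this, assuming a configured store:

for item in store.search(("docs",), query="vector search"):
    # item.score is a float for ranked results, None otherwise.
    print(item.namespace, item.key, item.score)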
@@ -517,6 +522,32 @@ def _row_to_item(
     )


+def _row_to_search_item(
+    namespace: tuple[str, ...],
+    row: Row,
+    *,
+    loader: Optional[Callable[[Union[bytes, orjson.Fragment]], dict[str, Any]]] = None,
+) -> SearchItem:
+    """Convert a row from the database into a SearchItem."""
+    loader = loader or _json_loads
+    val = row["value"]
+    score = row.get("score")
+    if score is not None:
+        try:
+            score = float(score)  # type: ignore[arg-type]
+        except ValueError:
+            logger.warning("Invalid score: %s", score)
+            score = None
+    return SearchItem(
+        value=val if isinstance(val, dict) else loader(val),
+        key=row["key"],
+        namespace=namespace,
+        created_at=row["created_at"],
+        updated_at=row["updated_at"],
+        score=score,
+    )
+
+
 def _group_ops(ops: Iterable[Op]) -> tuple[dict[type, list[tuple[int, Op]]], int]:
     grouped_ops: dict[type, list[tuple[int, Op]]] = defaultdict(list)
     tot = 0
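
Unlike the DuckDB variant, this helper reads columns by name and defensively coerces the score, logging and discarding values that fail float(). A sketch against a hypothetical DictRow, assuming the search SQL exposes the similarity under a score alias:

from datetime import datetime, timezone

now = datetime.now(timezone.utc)
row = {"key": "doc1", "value": {"text": "hello"}, "created_at": now,
       "updated_at": now, "score": "0.87"}  # hypothetical row contents
item = _row_to_search_item(("docs",), row)
assert item.score == 0.87  # numeric strings are coerced to float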
6 changes: 4 additions & 2 deletions libs/checkpoint/Makefile
@@ -4,11 +4,13 @@
 # TESTING AND COVERAGE
 ######################

+TEST ?= .
+
 test:
-	poetry run pytest tests
+	poetry run pytest $(TEST)

 test_watch:
-	poetry run ptw .
+	poetry run ptw $(TEST)

 ######################
 # LINTING AND FORMATTING
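
The TEST variable makes both targets selectable from the command line; note the test default also widens from the tests directory to the whole package directory (pytest tests becomes pytest .). Typical invocations, with hypothetical paths:

make test                                         # whole suite, TEST defaults to .
make test TEST=tests/test_store.py                # a single file
make test TEST=tests/test_store.py::test_search   # a single test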