Skip to content

A new Store for SQL DB #22203

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion libs/community/langchain_community/docstore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,17 @@
from langchain_community.docstore.in_memory import (
InMemoryDocstore,
)
from langchain_community.docstore.sql_docstore import (
SQLStore,
)
from langchain_community.docstore.wikipedia import (
Wikipedia,
)

_module_lookup = {
"DocstoreFn": "langchain_community.docstore.arbitrary_fn",
"InMemoryDocstore": "langchain_community.docstore.in_memory",
"SQLStore": "langchain_community.docstore.sql_docstore",
"Wikipedia": "langchain_community.docstore.wikipedia",
}

Expand All @@ -43,4 +47,4 @@ def __getattr__(name: str) -> Any:
raise AttributeError(f"module {__name__} has no attribute {name}")


__all__ = ["DocstoreFn", "InMemoryDocstore", "Wikipedia"]
__all__ = ["DocstoreFn", "InMemoryDocstore", "SQLStore", "Wikipedia"]
264 changes: 264 additions & 0 deletions libs/community/langchain_community/docstore/sql_docstore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
import contextlib
from pathlib import Path
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Dict,
Generator,
Iterator,
List,
Optional,
Sequence,
Tuple,
Union,
cast,
)

from langchain_core.stores import BaseStore, V
from sqlalchemy import (
Column,
Engine,
PickleType,
and_,
create_engine,
delete,
select,
)
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import (
Mapped,
Session,
declarative_base,
mapped_column,
sessionmaker,
)

Base = declarative_base()


def items_equal(x: Any, y: Any) -> bool:
return x == y


class Value(Base): # type: ignore[valid-type,misc]
"""Table used to save values."""

# ATTENTION:
# Prior to modifying this table, please determine whether
# we should create migrations for this table to make sure
# users do not experience data loss.
__tablename__ = "docstore"

namespace: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
key: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
# value: Mapped[Any] = Column(type_=PickleType, index=False, nullable=False)
value: Any = Column("earthquake", PickleType(comparator=items_equal))


# This is a fix of original SQLStore.
# This can will be removed when a PR will be merged.
class SQLStore(BaseStore[str, bytes]):
"""BaseStore interface that works on an SQL database.

Examples:
Create a SQLStore instance and perform operations on it:

.. code-block:: python

from langchain_rag.storage import SQLStore

# Instantiate the SQLStore with the root path
sql_store = SQLStore(namespace="test", db_url="sqllite://:memory:")

# Set values for keys
sql_store.mset([("key1", b"value1"), ("key2", b"value2")])

# Get values for keys
values = sql_store.mget(["key1", "key2"]) # Returns [b"value1", b"value2"]

# Delete keys
sql_store.mdelete(["key1"])

# Iterate over keys
for key in sql_store.yield_keys():
print(key)

"""

def __init__(
self,
*,
namespace: str,
db_url: Optional[Union[str, Path]] = None,
engine: Optional[Union[Engine, AsyncEngine]] = None,
engine_kwargs: Optional[Dict[str, Any]] = None,
async_mode: Optional[bool] = None,
):
if db_url is None and engine is None:
raise ValueError("Must specify either db_url or engine")

if db_url is not None and engine is not None:
raise ValueError("Must specify either db_url or engine, not both")

_engine: Union[Engine, AsyncEngine]
if db_url:
if async_mode is None:
async_mode = False
if async_mode:
_engine = create_async_engine(
url=str(db_url),
**(engine_kwargs or {}),
)
else:
_engine = create_engine(url=str(db_url), **(engine_kwargs or {}))
elif engine:
_engine = engine

else:
raise AssertionError("Something went wrong with configuration of engine.")

_session_maker: Union[sessionmaker[Session], async_sessionmaker[AsyncSession]]
if isinstance(_engine, AsyncEngine):
self.async_mode = True
_session_maker = async_sessionmaker(bind=_engine)
else:
self.async_mode = False
_session_maker = sessionmaker(bind=_engine)

self.engine = _engine
self.dialect = _engine.dialect.name
self.session_maker = _session_maker
self.namespace = namespace

def create_schema(self) -> None:
Base.metadata.create_all(self.engine)

async def acreate_schema(self) -> None:
assert isinstance(self.engine, AsyncEngine)
async with self.engine.begin() as session:
await session.run_sync(Base.metadata.create_all)

def drop(self) -> None:
Base.metadata.drop_all(bind=self.engine.connect())

async def amget(self, keys: Sequence[str]) -> List[Optional[V]]:
assert isinstance(self.engine, AsyncEngine)
result: Dict[str, V] = {}
async with self._make_async_session() as session:
stmt = select(Value).filter(
and_(
Value.key.in_(keys),
Value.namespace == self.namespace,
)
)
for v in await session.scalars(stmt):
result[v.key] = v.value
return [result.get(key) for key in keys]

def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
result = {}

with self._make_sync_session() as session:
stmt = select(Value).filter(
and_(
Value.key.in_(keys),
Value.namespace == self.namespace,
)
)
for v in session.scalars(stmt):
result[v.key] = v.value
return [result.get(key) for key in keys]

async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
async with self._make_async_session() as session:
await self._amdelete([key for key, _ in key_value_pairs], session)
session.add_all(
[
Value(namespace=self.namespace, key=k, value=v)
for k, v in key_value_pairs
]
)
await session.commit()

def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
values: Dict[str, bytes] = dict(key_value_pairs)
with self._make_sync_session() as session:
self._mdelete(list(values.keys()), session)
session.add_all(
[
Value(namespace=self.namespace, key=k, value=v)
for k, v in values.items()
]
)
session.commit()

def _mdelete(self, keys: Sequence[str], session: Session) -> None:
stmt = delete(Value).filter(
and_(
Value.key.in_(keys),
Value.namespace == self.namespace,
)
)
session.execute(stmt)

async def _amdelete(self, keys: Sequence[str], session: AsyncSession) -> None:
stmt = delete(Value).filter(
and_(
Value.key.in_(keys),
Value.namespace == self.namespace,
)
)
await session.execute(stmt)

def mdelete(self, keys: Sequence[str]) -> None:
with self._make_sync_session() as session:
self._mdelete(keys, session)
session.commit()

async def amdelete(self, keys: Sequence[str]) -> None:
async with self._make_async_session() as session:
await self._amdelete(keys, session)
await session.commit()

def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
with self._make_sync_session() as session:
for v in session.query(Value).filter( # type: ignore
Value.namespace == self.namespace
):
yield str(v.key)
session.close()

async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
async with self._make_async_session() as session:
stmt = select(Value).filter(Value.namespace == self.namespace)
for v in await session.scalars(stmt):
yield str(v.key)
await session.close()

@contextlib.contextmanager
def _make_sync_session(self) -> Generator[Session, None, None]:
"""Make an async session."""
if self.async_mode:
raise ValueError(
"Attempting to use a sync method in when async mode is turned on. "
"Please use the corresponding async method instead."
)
with cast(Session, self.session_maker()) as session:
yield cast(Session, session)

@contextlib.asynccontextmanager
async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]:
"""Make an async session."""
if not self.async_mode:
raise ValueError(
"Attempting to use an async method in when sync mode is turned on. "
"Please use the corresponding async method instead."
)
async with cast(AsyncSession, self.session_maker()) as session:
yield cast(AsyncSession, session)
9 changes: 4 additions & 5 deletions libs/community/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions libs/community/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ pytest-mock = "^3.10.0"
pytest-socket = "^0.6.0"
syrupy = "^4.0.2"
requests-mock = "^1.11.0"
aiosqlite = "^0.19.0"
langchain-core = {path = "../core", develop = true}
langchain = {path = "../langchain", develop = true}

Expand Down
Loading
Loading