diff --git a/docs/docs/integrations/vectorstores/lancedb.ipynb b/docs/docs/integrations/vectorstores/lancedb.ipynb index 9d9fcbdbdecc4..bd67759b9769f 100644 --- a/docs/docs/integrations/vectorstores/lancedb.ipynb +++ b/docs/docs/integrations/vectorstores/lancedb.ipynb @@ -12,6 +12,16 @@ "This notebook shows how to use functionality related to the `LanceDB` vector database based on the Lance data format." ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "88ac92c0", + "metadata": {}, + "outputs": [], + "source": [ + "! pip install -U langchain-openai" + ] + }, { "cell_type": "code", "execution_count": null, @@ -32,7 +42,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "id": "a0361f5c-e6f4-45f4-b829-11680cf03cec", "metadata": { "tags": [] @@ -47,25 +57,14 @@ }, { "cell_type": "code", - "execution_count": 10, - "id": "aac9563e", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from langchain.embeddings import OpenAIEmbeddings\n", - "from langchain.vectorstores import LanceDB" - ] - }, - { - "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "id": "a3c3999a", "metadata": {}, "outputs": [], "source": [ "from langchain.document_loaders import TextLoader\n", + "from langchain.vectorstores import LanceDB\n", + "from langchain_openai import OpenAIEmbeddings\n", "from langchain_text_splitters import CharacterTextSplitter\n", "\n", "loader = TextLoader(\"../../modules/state_of_the_union.txt\")\n", @@ -75,22 +74,61 @@ "embeddings = OpenAIEmbeddings()" ] }, + { + "cell_type": "markdown", + "id": "e9517bb0", + "metadata": {}, + "source": [ + "##### For LanceDB cloud, you can invoke the vector store as follows :\n", + "\n", + "\n", + "```python\n", + "db_url = \"db://lang_test\" # url of db you created\n", + "api_key = \"xxxxx\" # your API key\n", + "region=\"us-east-1-dev\" # your selected region\n", + "\n", + "vector_store = LanceDB(\n", + " uri=db_url,\n", + " api_key=api_key,\n", + " 
region=region,\n", + " embedding=embeddings,\n", + " table_name='langchain_test'\n", + " )\n", + "```\n" + ] + }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 7, "id": "6e104aee", "metadata": {}, "outputs": [], "source": [ "docsearch = LanceDB.from_documents(documents, embeddings)\n", - "\n", "query = \"What did the president say about Ketanji Brown Jackson\"\n", "docs = docsearch.similarity_search(query)" ] }, + { + "cell_type": "markdown", + "id": "f5e1cdfd", + "metadata": {}, + "source": [ + "Additionally, to explore the table you can load it into a pandas DataFrame or save it as a CSV file: \n", + "```python\n", + "tbl = docsearch.get_table()\n", + "print(\"tbl:\", tbl)\n", + "pd_df = tbl.to_pandas()\n", + "# pd_df.to_csv(\"docsearch.csv\", index=False)\n", + "\n", + "# you can also create a new vector store object using an older connection object:\n", + "vector_store = LanceDB(connection=tbl, embedding=embeddings)\n", + "```" + ] + }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 9, "id": "9c608226", "metadata": {}, "outputs": [ @@ -166,7 +204,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 10, "id": "a359ed74", "metadata": {}, "outputs": [ @@ -267,7 +305,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.8" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/docs/docs/modules/data_connection/indexing.ipynb b/docs/docs/modules/data_connection/indexing.ipynb index 922f238840bbc..a7980a6896ab4 100644 --- a/docs/docs/modules/data_connection/indexing.ipynb +++ b/docs/docs/modules/data_connection/indexing.ipynb @@ -60,7 +60,7 @@ " * document addition by id (`add_documents` method with `ids` argument)\n", " * delete by id (`delete` method with `ids` argument)\n", "\n", - "Compatible Vectorstores: `AnalyticDB`, `AstraDB`, `AwaDB`, `Bagel`, `Cassandra`, `Chroma`, `CouchbaseVectorStore`, `DashVector`, `DatabricksVectorSearch`, `DeepLake`, `Dingo`, 
`ElasticVectorSearch`, `ElasticsearchStore`, `FAISS`, `HanaDB`, `Milvus`, `MyScale`, `OpenSearchVectorSearch`, `PGVector`, `Pinecone`, `Qdrant`, `Redis`, `Rockset`, `ScaNN`, `SupabaseVectorStore`, `SurrealDBStore`, `TimescaleVector`, `Vald`, `VDMS`, `Vearch`, `VespaStore`, `Weaviate`, `ZepVectorStore`, `TencentVectorDB`, `OpenSearchVectorSearch`.\n", + "Compatible Vectorstores: `AnalyticDB`, `AstraDB`, `AwaDB`, `Bagel`, `Cassandra`, `Chroma`, `CouchbaseVectorStore`, `DashVector`, `DatabricksVectorSearch`, `DeepLake`, `Dingo`, `ElasticVectorSearch`, `ElasticsearchStore`, `FAISS`, `HanaDB`, `LanceDB`, `Milvus`, `MyScale`, `OpenSearchVectorSearch`, `PGVector`, `Pinecone`, `Qdrant`, `Redis`, `Rockset`, `ScaNN`, `SupabaseVectorStore`, `SurrealDBStore`, `TimescaleVector`, `Vald`, `VDMS`, `Vearch`, `VespaStore`, `Weaviate`, `ZepVectorStore`, `TencentVectorDB`, `OpenSearchVectorSearch`.\n", " \n", "## Caution\n", "\n", diff --git a/libs/community/langchain_community/vectorstores/lancedb.py b/libs/community/langchain_community/vectorstores/lancedb.py index 414517793ee44..e591fc76664c9 100644 --- a/libs/community/langchain_community/vectorstores/lancedb.py +++ b/libs/community/langchain_community/vectorstores/lancedb.py @@ -1,6 +1,8 @@ from __future__ import annotations +import os import uuid +import warnings from typing import Any, Iterable, List, Optional from langchain_core.documents import Document @@ -8,6 +10,17 @@ from langchain_core.vectorstores import VectorStore +def import_lancedb() -> Any: + try: + import lancedb + except ImportError as e: + raise ImportError( + "Could not import pinecone lancedb package. " + "Please install it with `pip install lancedb`." + ) from e + return lancedb + + class LanceDB(VectorStore): """`LanceDB` vector store. @@ -22,15 +35,15 @@ class LanceDB(VectorStore): id_key: Key to use for the id in the database. Defaults to ``id``. text_key: Key to use for the text in the database. Defaults to ``text``. table_name: Name of the table to use. 
Defaults to ``vectorstore``. + api_key: API key to use for LanceDB cloud database. + region: Region to use for LanceDB cloud database. + mode: Mode to use for adding data to the table. Defaults to ``overwrite``. Example: .. code-block:: python - - db = lancedb.connect('./lancedb') - table = db.open_table('my_table') - vectorstore = LanceDB(table, embedding_function) + vectorstore = LanceDB(uri='/lancedb', embedding=embedding_function) vectorstore.add_texts(['text1', 'text2']) result = vectorstore.similarity_search('text1') """ @@ -39,38 +52,55 @@ def __init__( self, connection: Optional[Any] = None, embedding: Optional[Embeddings] = None, + uri: Optional[str] = "/tmp/lancedb", vector_key: Optional[str] = "vector", id_key: Optional[str] = "id", text_key: Optional[str] = "text", table_name: Optional[str] = "vectorstore", + api_key: Optional[str] = None, + region: Optional[str] = None, + mode: Optional[str] = "overwrite", ): """Initialize with Lance DB vectorstore""" - try: - import lancedb - except ImportError: - raise ImportError( - "Could not import lancedb python package. " - "Please install it with `pip install lancedb`." 
- ) - self.lancedb = lancedb + lancedb = import_lancedb() self._embedding = embedding self._vector_key = vector_key self._id_key = id_key self._text_key = text_key self._table_name = table_name + self.api_key = api_key or os.getenv("LANCE_API_KEY") if api_key != "" else None + self.region = region + self.mode = mode + + if isinstance(uri, str) and self.api_key is None: + if uri.startswith("db://"): + raise ValueError("API key is required for LanceDB cloud.") if self._embedding is None: - raise ValueError("embedding should be provided") + raise ValueError("embedding object should be provided") - if connection is not None: - if not isinstance(connection, lancedb.db.LanceTable): - raise ValueError( - "connection should be an instance of lancedb.db.LanceTable, ", - f"got {type(connection)}", - ) + if isinstance(connection, lancedb.db.LanceDBConnection): self._connection = connection + elif isinstance(connection, (str, lancedb.db.LanceTable)): + raise ValueError( + "`connection` has to be a lancedb.db.LanceDBConnection object.\ + `lancedb.db.LanceTable` is deprecated." + ) else: - self._connection = self._init_table() + if self.api_key is None: + self._connection = lancedb.connect(uri) + else: + if isinstance(uri, str): + if uri.startswith("db://"): + self._connection = lancedb.connect( + uri, api_key=self.api_key, region=self.region + ) + else: + self._connection = lancedb.connect(uri) + warnings.warn( + "api key provided with local uri.\ + The data will be stored locally" + ) @property def embeddings(self) -> Optional[Embeddings]: @@ -88,7 +118,7 @@ def add_texts( Args: texts: Iterable of strings to add to the vectorstore. metadatas: Optional list of metadatas associated with the texts. - ids: Optional list of ids to associate with the texts. + ids: Optional list of ids to associate with the texts. Returns: List of ids of the added texts. 
@@ -99,20 +129,70 @@ def add_texts( embeddings = self._embedding.embed_documents(list(texts)) # type: ignore for idx, text in enumerate(texts): embedding = embeddings[idx] - metadata = metadatas[idx] if metadatas else {} + metadata = metadatas[idx] if metadatas else {"id": ids[idx]} docs.append( { self._vector_key: embedding, self._id_key: ids[idx], self._text_key: text, - **metadata, + "metadata": metadata, } ) - self._connection.add(docs) + + if self._table_name in self._connection.table_names(): + tbl = self._connection.open_table(self._table_name) + if self.api_key is None: + tbl.add(docs, mode=self.mode) + else: + tbl.add(docs) + else: + self._connection.create_table(self._table_name, data=docs) return ids + def get_table(self, name: Optional[str] = None) -> Any: + if name is not None: + try: + self._connection.open_table(name) + except Exception: + raise ValueError(f"Table {name} not found in the database") + else: + return self._connection.open_table(self._table_name) + + def create_index( + self, + col_name: Optional[str] = None, + vector_col: Optional[str] = None, + num_partitions: Optional[int] = 256, + num_sub_vectors: Optional[int] = 96, + index_cache_size: Optional[int] = None, + ) -> None: + """ + Create a scalar(for non-vector cols) or a vector index on a table. + Make sure your vector column has enough data before creating an index on it. + + Args: + vector_col: Provide if you want to create index on a vector column. + col_name: Provide if you want to create index on a non-vector column. + metric: Provide the metric to use for vector index. 
Defaults to 'L2' + choice of metrics: 'L2', 'dot', 'cosine' + + Returns: + None + """ + if vector_col: + self._connection.create_index( + vector_column_name=vector_col, + num_partitions=num_partitions, + num_sub_vectors=num_sub_vectors, + index_cache_size=index_cache_size, + ) + elif col_name: + self._connection.create_scalar_index(col_name) + else: + raise ValueError("Provide either vector_col or col_name") + def similarity_search( - self, query: str, k: int = 4, **kwargs: Any + self, query: str, k: int = 4, name: Optional[str] = None, **kwargs: Any ) -> List[Document]: """Return documents most similar to the query @@ -124,8 +204,9 @@ def similarity_search( List of documents most similar to the query. """ embedding = self._embedding.embed_query(query) # type: ignore + tbl = self.get_table(name) docs = ( - self._connection.search(embedding, vector_column_name=self._vector_key) + tbl.search(embedding, vector_column_name=self._vector_key) .limit(k) .to_arrow() ) @@ -155,32 +236,47 @@ def from_texts( **kwargs: Any, ) -> LanceDB: instance = LanceDB( - connection, - embedding, - vector_key, - id_key, - text_key, + connection=connection, + embedding=embedding, + vector_key=vector_key, + id_key=id_key, + text_key=text_key, ) instance.add_texts(texts, metadatas=metadatas, **kwargs) return instance - def _init_table(self) -> Any: - import pyarrow as pa - - schema = pa.schema( - [ - pa.field( - self._vector_key, - pa.list_( - pa.float32(), - len(self.embeddings.embed_query("test")), # type: ignore - ), - ), - pa.field(self._id_key, pa.string()), - pa.field(self._text_key, pa.string()), - ] - ) - db = self.lancedb.connect("/tmp/lancedb") - tbl = db.create_table(self._table_name, schema=schema, mode="overwrite") - return tbl + def delete( + self, + ids: Optional[List[str]] = None, + delete_all: Optional[bool] = None, + filter: Optional[str] = None, + drop_columns: Optional[List[str]] = None, + name: Optional[str] = None, + **kwargs: Any, + ) -> None: + """ + Allows deleting 
rows by filtering, by ids or drop columns from the table. + + Args: + filter: Provide a string SQL expression - "{col} {operation} {value}". + ids: Provide list of ids to delete from the table. + drop_columns: Provide list of columns to drop from the table. + delete_all: If True, delete all rows from the table. + """ + tbl = self.get_table(name) + if filter: + tbl.delete(filter) + elif ids: + tbl.delete("id in ('{}')".format(",".join(ids))) + elif drop_columns: + if self.api_key is not None: + raise NotImplementedError( + "Column operations currently not supported in LanceDB Cloud." + ) + else: + tbl.drop_columns(drop_columns) + elif delete_all: + tbl.delete("true") + else: + raise ValueError("Provide either filter, ids, drop_columns or delete_all") diff --git a/libs/community/tests/integration_tests/vectorstores/test_lancedb.py b/libs/community/tests/integration_tests/vectorstores/test_lancedb.py index bde46e800e116..d9ddf966e5437 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_lancedb.py +++ b/libs/community/tests/integration_tests/vectorstores/test_lancedb.py @@ -1,30 +1,39 @@ +from typing import Any + import pytest from langchain_community.vectorstores import LanceDB from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings +def import_lancedb() -> Any: + try: + import lancedb + except ImportError as e: + raise ImportError( + "Could not import pinecone lancedb package. " + "Please install it with `pip install lancedb`." 
+ ) from e + return lancedb + + @pytest.mark.requires("lancedb") def test_lancedb_with_connection() -> None: - import lancedb + lancedb = import_lancedb() embeddings = FakeEmbeddings() - db = lancedb.connect("/tmp/lancedb") + db = lancedb.connect("/tmp/lancedb_connection") texts = ["text 1", "text 2", "item 3"] - vectors = embeddings.embed_documents(texts) - table = db.create_table( - "my_table", - data=[ - {"vector": vectors[idx], "id": text, "text": text} - for idx, text in enumerate(texts) - ], - mode="overwrite", - ) - store = LanceDB(table, embeddings) + store = LanceDB(connection=db, embedding=embeddings) + store.add_texts(texts) + result = store.similarity_search("text 1") result_texts = [doc.page_content for doc in result] assert "text 1" in result_texts + store.delete(filter="text = 'text 1'") + assert store.get_table().count_rows() == 2 + @pytest.mark.requires("lancedb") def test_lancedb_without_connection() -> None: diff --git a/libs/community/tests/unit_tests/vectorstores/test_indexing_docs.py b/libs/community/tests/unit_tests/vectorstores/test_indexing_docs.py index b5b9c4b78e03b..c0e29eb995173 100644 --- a/libs/community/tests/unit_tests/vectorstores/test_indexing_docs.py +++ b/libs/community/tests/unit_tests/vectorstores/test_indexing_docs.py @@ -67,6 +67,7 @@ def check_compatibility(vector_store: VectorStore) -> bool: "FAISS", "HanaDB", "InMemoryVectorStore", + "LanceDB", "Milvus", "MomentoVectorIndex", "MyScale",