From 91e9810b02acc05ff809ecf08c9f6e5dfbee39c3 Mon Sep 17 00:00:00 2001 From: Yuki Watanabe <31463517+B-Step62@users.noreply.github.com> Date: Fri, 18 Oct 2024 11:07:05 +0900 Subject: [PATCH] Validate index name parameter (#29) Signed-off-by: B-Step62 --- .../langchain_databricks/vectorstores.py | 8 ++++++++ .../tests/unit_tests/test_vectorstore.py | 17 ++++++++++++----- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/libs/databricks/langchain_databricks/vectorstores.py b/libs/databricks/langchain_databricks/vectorstores.py index 097b468..36cbe48 100644 --- a/libs/databricks/langchain_databricks/vectorstores.py +++ b/libs/databricks/langchain_databricks/vectorstores.py @@ -3,6 +3,7 @@ import asyncio import json import logging +import re import uuid from enum import Enum from functools import partial @@ -36,6 +37,7 @@ class IndexType(str, Enum): _NON_MANAGED_EMB_ONLY_MSG = ( "`%s` is not supported for index with Databricks-managed embeddings." ) +_INDEX_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+$") class DatabricksVectorSearch(VectorStore): @@ -215,6 +217,12 @@ def __init__( text_column: Optional[str] = None, columns: Optional[List[str]] = None, ): + if not (isinstance(index_name, str) and _INDEX_NAME_PATTERN.match(index_name)): + raise ValueError( + "The `index_name` parameter must be a string in the format " + f"'catalog.schema.index'. Received: {index_name}" + ) + try: from databricks.vector_search.client import ( # type: ignore[import] VectorSearchClient, diff --git a/libs/databricks/tests/unit_tests/test_vectorstore.py b/libs/databricks/tests/unit_tests/test_vectorstore.py index 164cd5a..44c41e1 100644 --- a/libs/databricks/tests/unit_tests/test_vectorstore.py +++ b/libs/databricks/tests/unit_tests/test_vectorstore.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch import pytest +from databricks.vector_search.client import VectorSearchIndex # type: ignore from langchain_core.embeddings import Embeddings from langchain_databricks.vectorstores import DatabricksVectorSearch @@ -65,9 +66,9 @@ def embed_query(self, text: str) -> List[float]: ### Dummy Indices #### ENDPOINT_NAME = "test-endpoint" -DIRECT_ACCESS_INDEX = "test-direct-access-index" -DELTA_SYNC_INDEX = "test-delta-sync-index" -DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX = "test-delta-sync-self-managed-index" +DIRECT_ACCESS_INDEX = "test.direct_access.index" +DELTA_SYNC_INDEX = "test.delta_sync.index" +DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX = "test.delta_sync_self_managed.index" ALL_INDEX_NAMES = { DIRECT_ACCESS_INDEX, DELTA_SYNC_INDEX, @@ -137,8 +138,6 @@ def _get_index( endpoint_name: Optional[str] = None, index_name: str = None, # type: ignore ) -> MagicMock: - from databricks.vector_search.client import VectorSearchIndex # type: ignore - index = MagicMock(spec=VectorSearchIndex) index.describe.return_value = INDEX_DETAILS[index_name] index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE @@ -184,6 +183,14 @@ def test_init_with_endpoint_name() -> None: assert vectorsearch.index.describe() == INDEX_DETAILS[DELTA_SYNC_INDEX] +@pytest.mark.parametrize( + "index_name", [None, "invalid", 123, MagicMock(spec=VectorSearchIndex)] +) +def test_init_fail_invalid_index_name(index_name) -> None: + with pytest.raises(ValueError, match="The `index_name` parameter must be"): + DatabricksVectorSearch(index_name=index_name) + + def test_init_fail_text_column_mismatch() -> None: with pytest.raises(ValueError, match=f"The index '{DELTA_SYNC_INDEX}' has"): DatabricksVectorSearch(