Skip to content

Commit

Permalink
Validate index name parameter (#29)
Browse files Browse the repository at this point in the history
Signed-off-by: B-Step62 <[email protected]>
  • Loading branch information
B-Step62 authored Oct 18, 2024
1 parent ff1a60b commit 91e9810
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
8 changes: 8 additions & 0 deletions libs/databricks/langchain_databricks/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import json
import logging
import re
import uuid
from enum import Enum
from functools import partial
Expand Down Expand Up @@ -36,6 +37,7 @@ class IndexType(str, Enum):
# Error-message template; `%s` is filled with the name of the unsupported
# argument or operation.
_NON_MANAGED_EMB_ONLY_MSG = (
    "`%s` is not supported for index with Databricks-managed embeddings."
)
# Fully qualified index name: exactly three dot-separated segments
# ('catalog.schema.index'), each made of ASCII letters, digits, or underscores.
# The pattern is anchored with `\Z` rather than `$` because `$` also matches
# just before a trailing newline, which would let 'catalog.schema.index\n'
# pass validation when used with `.match()`.
_INDEX_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\Z")


class DatabricksVectorSearch(VectorStore):
Expand Down Expand Up @@ -215,6 +217,12 @@ def __init__(
text_column: Optional[str] = None,
columns: Optional[List[str]] = None,
):
if not (isinstance(index_name, str) and _INDEX_NAME_PATTERN.match(index_name)):
raise ValueError(
"The `index_name` parameter must be a string in the format "
f"'catalog.schema.index'. Received: {index_name}"
)

try:
from databricks.vector_search.client import ( # type: ignore[import]
VectorSearchClient,
Expand Down
17 changes: 12 additions & 5 deletions libs/databricks/tests/unit_tests/test_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from unittest.mock import MagicMock, patch

import pytest
from databricks.vector_search.client import VectorSearchIndex # type: ignore
from langchain_core.embeddings import Embeddings

from langchain_databricks.vectorstores import DatabricksVectorSearch
Expand Down Expand Up @@ -65,9 +66,9 @@ def embed_query(self, text: str) -> List[float]:
### Dummy Indices ####

ENDPOINT_NAME = "test-endpoint"
DIRECT_ACCESS_INDEX = "test-direct-access-index"
DELTA_SYNC_INDEX = "test-delta-sync-index"
DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX = "test-delta-sync-self-managed-index"
DIRECT_ACCESS_INDEX = "test.direct_access.index"
DELTA_SYNC_INDEX = "test.delta_sync.index"
DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX = "test.delta_sync_self_managed.index"
ALL_INDEX_NAMES = {
DIRECT_ACCESS_INDEX,
DELTA_SYNC_INDEX,
Expand Down Expand Up @@ -137,8 +138,6 @@ def _get_index(
endpoint_name: Optional[str] = None,
index_name: str = None, # type: ignore
) -> MagicMock:
from databricks.vector_search.client import VectorSearchIndex # type: ignore

index = MagicMock(spec=VectorSearchIndex)
index.describe.return_value = INDEX_DETAILS[index_name]
index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
Expand Down Expand Up @@ -184,6 +183,14 @@ def test_init_with_endpoint_name() -> None:
assert vectorsearch.index.describe() == INDEX_DETAILS[DELTA_SYNC_INDEX]


@pytest.mark.parametrize(
    "index_name",
    [
        # Non-string values.
        None,
        123,
        MagicMock(spec=VectorSearchIndex),
        # Strings that are not of the form 'catalog.schema.index'.
        "",
        "invalid",
        "catalog.schema",
        "catalog.schema.index.extra",
    ],
)
def test_init_fail_invalid_index_name(index_name) -> None:
    """Constructor rejects index names that are not 'catalog.schema.index'.

    Covers non-string inputs as well as strings with too few or too many
    dot-separated segments; each must raise ``ValueError`` with the
    `index_name` format message.
    """
    with pytest.raises(ValueError, match="The `index_name` parameter must be"):
        DatabricksVectorSearch(index_name=index_name)


def test_init_fail_text_column_mismatch() -> None:
with pytest.raises(ValueError, match=f"The index '{DELTA_SYNC_INDEX}' has"):
DatabricksVectorSearch(
Expand Down

0 comments on commit 91e9810

Please sign in to comment.