diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index 26dd49d..5f05aea 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -88,7 +88,7 @@ class DistanceStrategy(str, enum.Enum): "$ilike", } -LOGICAL_OPERATORS = {"$and", "$or"} +LOGICAL_OPERATORS = {"$and", "$or", "$not"} SUPPORTED_OPERATORS = ( set(COMPARISONS_TO_NATIVE) @@ -1248,26 +1248,25 @@ def _create_filter_clause(self, filters: Any) -> Any: """ if isinstance(filters, dict): if len(filters) == 1: - # The only operators allowed at the top level are $AND and $OR + # The only operators allowed at the top level are $AND, $OR, and $NOT # First check if an operator or a field key, value = list(filters.items())[0] if key.startswith("$"): # Then it's an operator - if key.lower() not in ["$and", "$or"]: + if key.lower() not in ["$and", "$or", "$not"]: raise ValueError( - f"Invalid filter condition. Expected $and or $or " + f"Invalid filter condition. Expected $and, $or or $not " f"but got: {key}" ) else: # Then it's a field return self._handle_field_filter(key, filters[key]) - # Here we handle the $and and $or operators - if not isinstance(value, list): - raise ValueError( - f"Expected a list, but got {type(value)} for value: {value}" - ) if key.lower() == "$and": + if not isinstance(value, list): + raise ValueError( + f"Expected a list, but got {type(value)} for value: {value}" + ) and_ = [self._create_filter_clause(el) for el in value] if len(and_) > 1: return sqlalchemy.and_(*and_) @@ -1279,6 +1278,10 @@ def _create_filter_clause(self, filters: Any) -> Any: "but got an empty dictionary" ) elif key.lower() == "$or": + if not isinstance(value, list): + raise ValueError( + f"Expected a list, but got {type(value)} for value: {value}" + ) or_ = [self._create_filter_clause(el) for el in value] if len(or_) > 1: return sqlalchemy.or_(*or_) @@ -1289,9 +1292,29 @@ def _create_filter_clause(self, filters: Any) -> Any: "Invalid filter condition. Expected a dictionary " "but got an empty dictionary" ) + elif key.lower() == "$not": + if isinstance(value, list): + not_conditions = [ + self._create_filter_clause(item) for item in value + ] + not_ = sqlalchemy.and_( + *[ + sqlalchemy.not_(condition) + for condition in not_conditions + ] + ) + return not_ + elif isinstance(value, dict): + not_ = self._create_filter_clause(value) + return sqlalchemy.not_(not_) + else: + raise ValueError( + f"Invalid filter condition. Expected a dictionary " + f"or a list but got: {type(value)}" + ) else: raise ValueError( - f"Invalid filter condition. Expected $and or $or " + f"Invalid filter condition. Expected $and, $or or $not " f"but got: {key}" ) elif len(filters) > 1: diff --git a/tests/unit_tests/fixtures/filtering_test_cases.py b/tests/unit_tests/fixtures/filtering_test_cases.py index 701260c..181e8ba 100644 --- a/tests/unit_tests/fixtures/filtering_test_cases.py +++ b/tests/unit_tests/fixtures/filtering_test_cases.py @@ -81,7 +81,7 @@ TYPE_2_FILTERING_TEST_CASES = [ # These involve equality checks and other operators - # like $ne, $gt, $gte, $lt, $lte, $not + # like $ne, $gt, $gte, $lt, $lte ( {"id": 1}, [1], @@ -168,7 +168,7 @@ ] TYPE_3_FILTERING_TEST_CASES = [ - # These involve usage of AND and OR operators + # These involve usage of AND, OR and NOT operators ( {"$or": [{"id": 1}, {"id": 2}]}, [1, 2], @@ -185,6 +185,39 @@ {"$or": [{"id": 1}, {"id": 2}, {"id": 3}]}, [1, 2, 3], ), + # Test for $not operator + ( + {"$not": {"id": 1}}, + [2, 3], + ), + ( + {"$not": [{"id": 1}]}, + [2, 3], + ), + ( + {"$not": {"name": "adam"}}, + [2, 3], + ), + ( + {"$not": [{"name": "adam"}]}, + [2, 3], + ), + ( + {"$not": {"is_active": True}}, + [2], + ), + ( + {"$not": [{"is_active": True}]}, + [2], + ), + ( + {"$not": {"height": {"$gt": 5.0}}}, + [3], + ), + ( + {"$not": [{"height": {"$gt": 5.0}}]}, + [3], + ), ] TYPE_4_FILTERING_TEST_CASES = [ diff --git a/tests/unit_tests/test_vectorstore.py b/tests/unit_tests/test_vectorstore.py index fcba8ef..cf0a184 100644 --- a/tests/unit_tests/test_vectorstore.py +++ b/tests/unit_tests/test_vectorstore.py @@ -992,6 +992,7 @@ async def test_async_pgvector_with_with_metadata_filters_5( {"$eq": {}}, {"$exists": {}}, {"$exists": 1}, + {"$not": 2}, ], ) def test_invalid_filters(pgvector: PGVector, invalid_filter: Any) -> None: @@ -1016,5 +1017,6 @@ def test_validate_operators() -> None: "$lte", "$ne", "$nin", + "$not", "$or", ]