diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index 4092e05..5f05aea 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -1262,12 +1262,11 @@ def _create_filter_clause(self, filters: Any) -> Any: # Then it's a field return self._handle_field_filter(key, filters[key]) - # Here we handle the $and, $or, and $not operators - if not isinstance(value, list): - raise ValueError( - f"Expected a list, but got {type(value)} for value: {value}" - ) if key.lower() == "$and": + if not isinstance(value, list): + raise ValueError( + f"Expected a list, but got {type(value)} for value: {value}" + ) and_ = [self._create_filter_clause(el) for el in value] if len(and_) > 1: return sqlalchemy.and_(*and_) @@ -1279,6 +1278,10 @@ def _create_filter_clause(self, filters: Any) -> Any: "but got an empty dictionary" ) elif key.lower() == "$or": + if not isinstance(value, list): + raise ValueError( + f"Expected a list, but got {type(value)} for value: {value}" + ) or_ = [self._create_filter_clause(el) for el in value] if len(or_) > 1: return sqlalchemy.or_(*or_) @@ -1290,13 +1293,25 @@ def _create_filter_clause(self, filters: Any) -> Any: "but got an empty dictionary" ) elif key.lower() == "$not": - not_conditions = [ - self._create_filter_clause(item) for item in value - ] - not_ = sqlalchemy.and_( - *[sqlalchemy.not_(condition) for condition in not_conditions] - ) - return not_ + if isinstance(value, list): + not_conditions = [ + self._create_filter_clause(item) for item in value + ] + not_ = sqlalchemy.and_( + *[ + sqlalchemy.not_(condition) + for condition in not_conditions + ] + ) + return not_ + elif isinstance(value, dict): + not_ = self._create_filter_clause(value) + return sqlalchemy.not_(not_) + else: + raise ValueError( + f"Invalid filter condition. Expected a dictionary " + f"or a list but got: {type(value)}" + ) else: raise ValueError( f"Invalid filter condition. Expected $and, $or or $not " diff --git a/tests/unit_tests/fixtures/filtering_test_cases.py b/tests/unit_tests/fixtures/filtering_test_cases.py index a354847..181e8ba 100644 --- a/tests/unit_tests/fixtures/filtering_test_cases.py +++ b/tests/unit_tests/fixtures/filtering_test_cases.py @@ -81,7 +81,7 @@ TYPE_2_FILTERING_TEST_CASES = [ # These involve equality checks and other operators - # like $ne, $gt, $gte, $lt, $lte, $not + # like $ne, $gt, $gte, $lt, $lte ( {"id": 1}, [1], @@ -165,42 +165,58 @@ {"height": {"$lte": 5.8}}, [2, 3], ), +] + +TYPE_3_FILTERING_TEST_CASES = [ + # These involve usage of AND, OR and NOT operators + ( + {"$or": [{"id": 1}, {"id": 2}]}, + [1, 2], + ), + ( + {"$or": [{"id": 1}, {"name": "bob"}]}, + [1, 2], + ), + ( + {"$and": [{"id": 1}, {"id": 2}]}, + [], + ), + ( + {"$or": [{"id": 1}, {"id": 2}, {"id": 3}]}, + [1, 2, 3], + ), # Test for $not operator ( {"$not": {"id": 1}}, [2, 3], ), ( - {"$not": {"name": "adam"}}, + {"$not": [{"id": 1}]}, [2, 3], ), ( - {"$not": {"is_active": True}}, - [2], + {"$not": {"name": "adam"}}, + [2, 3], ), ( - {"$not": {"height": {"$gt": 5.0}}}, - [3], + {"$not": [{"name": "adam"}]}, + [2, 3], ), -] - -TYPE_3_FILTERING_TEST_CASES = [ - # These involve usage of AND and OR operators ( - {"$or": [{"id": 1}, {"id": 2}]}, - [1, 2], + {"$not": {"is_active": True}}, + [2], ), ( - {"$or": [{"id": 1}, {"name": "bob"}]}, - [1, 2], + {"$not": [{"is_active": True}]}, + [2], ), ( - {"$and": [{"id": 1}, {"id": 2}]}, - [], + {"$not": {"height": {"$gt": 5.0}}}, + [3], ), ( - {"$or": [{"id": 1}, {"id": 2}, {"id": 3}]}, - [1, 2, 3], + {"$not": [{"height": {"$gt": 5.0}}]}, + [3], ), ]