diff --git a/advanced_alchemy/filters.py b/advanced_alchemy/filters.py index 72123ffb..49c8e672 100644 --- a/advanced_alchemy/filters.py +++ b/advanced_alchemy/filters.py @@ -6,12 +6,15 @@ from collections import abc # noqa: TCH003 from dataclasses import dataclass from datetime import datetime # noqa: TCH003 +from operator import attrgetter from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast from sqlalchemy import BinaryExpression, and_, any_, or_, text if TYPE_CHECKING: - from sqlalchemy import Select, StatementLambdaElement + from typing import Callable + + from sqlalchemy import ColumnElement, Select, StatementLambdaElement from sqlalchemy.orm import InstrumentedAttribute from typing_extensions import TypeAlias @@ -32,7 +35,6 @@ "InAnyFilter", "StatementFilter", "StatementFilterT", - "TextSearchFilter", ) T = TypeVar("T") @@ -281,7 +283,9 @@ def append_to_lambda_statement( @dataclass -class TextSearchFilter(StatementFilter): +class SearchFilter(StatementFilter): + """Data required to construct a ``WHERE field_name LIKE '%' || :value || '%'`` clause.""" + field_name: str | set[str] """Name of the model attribute to search on.""" value: str @@ -289,82 +293,52 @@ class TextSearchFilter(StatementFilter): ignore_case: bool | None = False """Should the search be case insensitive.""" + @property + def _operator(self) -> Callable[..., ColumnElement[bool]]: + return or_ + + @property + def _func(self) -> attrgetter[Callable[[str], BinaryExpression[bool]]]: + return attrgetter("ilike" if self.ignore_case else "like") + @property def normalized_field_names(self) -> set[str]: return {self.field_name} if isinstance(self.field_name, str) else self.field_name - -@dataclass -class SearchFilter(TextSearchFilter): - """Data required to construct a ``WHERE field_name LIKE '%' || :value || '%'`` clause.""" + def get_search_clauses(self, model: type[ModelT]) -> list[BinaryExpression[bool]]: + search_clause: list[BinaryExpression[bool]] = [] + for field_name in self.normalized_field_names: + field = self._get_instrumented_attr(model, field_name) + search_text = f"%{self.value}%" + search_clause.append(self._func(field)(search_text)) + return search_clause def append_to_statement( self, statement: Select[tuple[ModelT]], model: type[ModelT], ) -> Select[tuple[ModelT]]: - fields = self.normalized_field_names - search_clause: list[BinaryExpression[bool]] = [] - for field_name in fields: - field = self._get_instrumented_attr(model, field_name) - search_text = f"%{self.value}%" - if self.ignore_case: - search_clause.append(field.ilike(search_text)) - else: - search_clause.append(field.like(search_text)) - return statement.where(or_(*search_clause)) + where_clause = self._operator(*self.get_search_clauses(model)) + return statement.where(where_clause) def append_to_lambda_statement( self, statement: StatementLambdaElement, model: type[ModelT], ) -> StatementLambdaElement: - fields = self.normalized_field_names - search_clause: list[BinaryExpression[bool]] = [] - for field_name in fields: - field = self._get_instrumented_attr(model, field_name) - search_text = f"%{self.value}%" - if self.ignore_case: - search_clause.append(field.ilike(search_text)) - else: - search_clause.append(field.like(search_text)) - statement += lambda s: s.where(or_(*search_clause)) # pyright: ignore[reportUnknownLambdaType,reportArgumentType,reportUnknownMemberType] + where_clause = self._operator(*self.get_search_clauses(model)) + statement += lambda s: s.where(where_clause) return statement @dataclass -class NotInSearchFilter(TextSearchFilter): +class NotInSearchFilter(SearchFilter): """Data required to construct a ``WHERE field_name NOT LIKE '%' || :value || '%'`` clause.""" - def append_to_statement( - self, - statement: Select[tuple[ModelT]], - model: type[ModelT], - ) -> Select[tuple[ModelT]]: - fields = self.normalized_field_names - search_clause: list[BinaryExpression[bool]] = [] - for field_name in fields: - field = self._get_instrumented_attr(model, field_name) - search_text = f"%{self.value}%" - if self.ignore_case: - search_clause.append(field.not_ilike(search_text)) - else: - search_clause.append(field.not_like(search_text)) - return statement.where(and_(*search_clause)) + @property + def _operator(self) -> Callable[..., ColumnElement[bool]]: + return and_ - def append_to_lambda_statement( - self, - statement: StatementLambdaElement, - model: type[ModelT], - ) -> StatementLambdaElement: - fields = self.normalized_field_names - search_clause: list[BinaryExpression[bool]] = [] - for field_name in fields: - field = self._get_instrumented_attr(model, field_name) - search_text = f"%{self.value}%" - if self.ignore_case: - search_clause.append(field.not_ilike(search_text)) - else: - search_clause.append(field.not_like(search_text)) - statement += lambda s: s.where(and_(*search_clause)) # pyright: ignore[reportUnknownLambdaType,reportArgumentType,reportUnknownMemberType] - return statement + @property + def _func(self) -> attrgetter[Callable[[str], BinaryExpression[bool]]]: + return attrgetter("not_ilike" if self.ignore_case else "not_like") diff --git a/advanced_alchemy/repository/memory/_async.py b/advanced_alchemy/repository/memory/_async.py index d6b8fe2f..9809ae78 100644 --- a/advanced_alchemy/repository/memory/_async.py +++ b/advanced_alchemy/repository/memory/_async.py @@ -290,15 +290,15 @@ def _apply_filters( filter_.field_name, sort_desc=filter_.sort_order == "desc", ) - elif isinstance(filter_, SearchFilter): - result = self._filter_by_like( + elif isinstance(filter_, NotInSearchFilter): + result = self._filter_by_not_like( result, filter_.field_name, value=filter_.value, ignore_case=bool(filter_.ignore_case), ) - elif isinstance(filter_, NotInSearchFilter): - result = self._filter_by_not_like( + elif isinstance(filter_, SearchFilter): + result = self._filter_by_like( result, filter_.field_name, value=filter_.value, diff --git a/advanced_alchemy/repository/memory/_sync.py b/advanced_alchemy/repository/memory/_sync.py index 5e24d807..88320577 100644 --- a/advanced_alchemy/repository/memory/_sync.py +++ b/advanced_alchemy/repository/memory/_sync.py @@ -292,15 +292,15 @@ def _apply_filters( filter_.field_name, sort_desc=filter_.sort_order == "desc", ) - elif isinstance(filter_, SearchFilter): - result = self._filter_by_like( + elif isinstance(filter_, NotInSearchFilter): + result = self._filter_by_not_like( result, filter_.field_name, value=filter_.value, ignore_case=bool(filter_.ignore_case), ) - elif isinstance(filter_, NotInSearchFilter): - result = self._filter_by_not_like( + elif isinstance(filter_, SearchFilter): + result = self._filter_by_like( result, filter_.field_name, value=filter_.value,