Skip to content

Commit 3db517c

Browse files
author
Alc-Alc
committed
chore: test ci
1 parent 0941b61 commit 3db517c

File tree

1 file changed

+31
-55
lines changed

1 file changed

+31
-55
lines changed

advanced_alchemy/filters.py

+31-55
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@
66
from collections import abc # noqa: TCH003
77
from dataclasses import dataclass
88
from datetime import datetime # noqa: TCH003
9+
from operator import attrgetter
910
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
1011

1112
from sqlalchemy import BinaryExpression, and_, any_, or_, text
1213

1314
if TYPE_CHECKING:
15+
from typing import Callable
16+
1417
from sqlalchemy import Select, StatementLambdaElement
1518
from sqlalchemy.orm import InstrumentedAttribute
1619
from typing_extensions import TypeAlias
@@ -295,76 +298,49 @@ def normalized_field_names(self) -> set[str]:
295298

296299

297300
@dataclass
298-
class SearchFilter(TextSearchFilter):
299-
"""Data required to construct a ``WHERE field_name LIKE '%' || :value || '%'`` clause."""
301+
class SearchFilter(StatementFilter):
302+
field_name: str | set[str]
303+
value: str
304+
ignore_case: bool | None = False
300305

301-
def append_to_statement(
302-
self,
303-
statement: Select[tuple[ModelT]],
304-
model: type[ModelT],
305-
) -> Select[tuple[ModelT]]:
306-
fields = self.normalized_field_names
307-
search_clause: list[BinaryExpression[bool]] = []
308-
for field_name in fields:
309-
field = self._get_instrumented_attr(model, field_name)
310-
search_text = f"%{self.value}%"
311-
if self.ignore_case:
312-
search_clause.append(field.ilike(search_text))
313-
else:
314-
search_clause.append(field.like(search_text))
315-
return statement.where(or_(*search_clause))
306+
_operator: Any = or_
316307

317-
def append_to_lambda_statement(
318-
self,
319-
statement: StatementLambdaElement,
320-
model: type[ModelT],
321-
) -> StatementLambdaElement:
322-
fields = self.normalized_field_names
308+
@property
309+
def func(self) -> attrgetter[Callable[[str], BinaryExpression[bool]]]:
310+
return attrgetter("ilike" if self.ignore_case else "like")
311+
312+
@property
313+
def normalized_field_names(self) -> set[str]:
314+
return {self.field_name} if isinstance(self.field_name, str) else self.field_name
315+
316+
def get_search_clauses(self, model: type[ModelT]) -> list[BinaryExpression[bool]]:
323317
search_clause: list[BinaryExpression[bool]] = []
324-
for field_name in fields:
318+
for field_name in self.normalized_field_names:
325319
field = self._get_instrumented_attr(model, field_name)
326320
search_text = f"%{self.value}%"
327-
if self.ignore_case:
328-
search_clause.append(field.ilike(search_text))
329-
else:
330-
search_clause.append(field.like(search_text))
331-
statement += lambda s: s.where(or_(*search_clause)) # pyright: ignore[reportUnknownLambdaType,reportArgumentType,reportUnknownMemberType]
332-
return statement
333-
334-
335-
@dataclass
336-
class NotInSearchFilter(TextSearchFilter):
337-
"""Data required to construct a ``WHERE field_name NOT LIKE '%' || :value || '%'`` clause."""
321+
search_clause.append(self.func(field)(search_text))
322+
return search_clause
338323

339324
def append_to_statement(
340325
self,
341326
statement: Select[tuple[ModelT]],
342327
model: type[ModelT],
343328
) -> Select[tuple[ModelT]]:
344-
fields = self.normalized_field_names
345-
search_clause: list[BinaryExpression[bool]] = []
346-
for field_name in fields:
347-
field = self._get_instrumented_attr(model, field_name)
348-
search_text = f"%{self.value}%"
349-
if self.ignore_case:
350-
search_clause.append(field.not_ilike(search_text))
351-
else:
352-
search_clause.append(field.not_like(search_text))
353-
return statement.where(and_(*search_clause))
329+
return statement.where(self._operator(*self.get_search_clauses(model)))
354330

355331
def append_to_lambda_statement(
356332
self,
357333
statement: StatementLambdaElement,
358334
model: type[ModelT],
359335
) -> StatementLambdaElement:
360-
fields = self.normalized_field_names
361-
search_clause: list[BinaryExpression[bool]] = []
362-
for field_name in fields:
363-
field = self._get_instrumented_attr(model, field_name)
364-
search_text = f"%{self.value}%"
365-
if self.ignore_case:
366-
search_clause.append(field.not_ilike(search_text))
367-
else:
368-
search_clause.append(field.not_like(search_text))
369-
statement += lambda s: s.where(and_(*search_clause)) # pyright: ignore[reportUnknownLambdaType,reportArgumentType,reportUnknownMemberType]
336+
statement += lambda s: s.where(self._operator(*self.get_search_clauses(model)))
370337
return statement
338+
339+
340+
@dataclass
341+
class NotInSearchFilter(SearchFilter):
342+
_operator: Any = and_
343+
344+
@property
345+
def func(self) -> attrgetter[Callable[[str], BinaryExpression[bool]]]:
346+
return attrgetter("not_ilike" if self.ignore_case else "not_like")

0 commit comments

Comments
 (0)