Skip to content

Commit 057573c

Browse files
author
Alc-Alc
committed
refactor(search_filter): Removes some redundant code
1 parent 0941b61 commit 057573c

File tree

1 file changed

+28
-60
lines changed

1 file changed

+28
-60
lines changed

advanced_alchemy/filters.py

+28-60
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,15 @@
66
from collections import abc # noqa: TCH003
77
from dataclasses import dataclass
88
from datetime import datetime # noqa: TCH003
9+
from operator import attrgetter
910
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
1011

1112
from sqlalchemy import BinaryExpression, and_, any_, or_, text
1213

1314
if TYPE_CHECKING:
14-
from sqlalchemy import Select, StatementLambdaElement
15+
from typing import Callable
16+
17+
from sqlalchemy import ColumnElement, Select, StatementLambdaElement
1518
from sqlalchemy.orm import InstrumentedAttribute
1619
from typing_extensions import TypeAlias
1720

@@ -32,7 +35,6 @@
3235
"InAnyFilter",
3336
"StatementFilter",
3437
"StatementFilterT",
35-
"TextSearchFilter",
3638
)
3739

3840
T = TypeVar("T")
@@ -281,90 +283,56 @@ def append_to_lambda_statement(
281283

282284

283285
@dataclass
284-
class TextSearchFilter(StatementFilter):
286+
class SearchFilter(StatementFilter):
287+
"""Data required to construct a ``WHERE field_name LIKE '%' || :value || '%'`` clause."""
285288
field_name: str | set[str]
286289
"""Name of the model attribute to search on."""
287290
value: str
288291
"""Values for ``NOT LIKE`` clause."""
289292
ignore_case: bool | None = False
290293
"""Should the search be case insensitive."""
291294

295+
_operator: Callable[..., ColumnElement[bool]] = or_
296+
297+
@property
298+
def _func(self) -> attrgetter[Callable[[str], BinaryExpression[bool]]]:
299+
return attrgetter("ilike" if self.ignore_case else "like")
300+
292301
@property
293302
def normalized_field_names(self) -> set[str]:
294303
return {self.field_name} if isinstance(self.field_name, str) else self.field_name
295304

296-
297-
@dataclass
298-
class SearchFilter(TextSearchFilter):
299-
"""Data required to construct a ``WHERE field_name LIKE '%' || :value || '%'`` clause."""
305+
def get_search_clauses(self, model: type[ModelT]) -> list[BinaryExpression[bool]]:
306+
search_clause: list[BinaryExpression[bool]] = []
307+
for field_name in self.normalized_field_names:
308+
field = self._get_instrumented_attr(model, field_name)
309+
search_text = f"%{self.value}%"
310+
search_clause.append(self._func(field)(search_text))
311+
return search_clause
300312

301313
def append_to_statement(
302314
self,
303315
statement: Select[tuple[ModelT]],
304316
model: type[ModelT],
305317
) -> Select[tuple[ModelT]]:
306-
fields = self.normalized_field_names
307-
search_clause: list[BinaryExpression[bool]] = []
308-
for field_name in fields:
309-
field = self._get_instrumented_attr(model, field_name)
310-
search_text = f"%{self.value}%"
311-
if self.ignore_case:
312-
search_clause.append(field.ilike(search_text))
313-
else:
314-
search_clause.append(field.like(search_text))
315-
return statement.where(or_(*search_clause))
318+
where_clause = self._operator(*self.get_search_clauses(model))
319+
return statement.where(where_clause)
316320

317321
def append_to_lambda_statement(
318322
self,
319323
statement: StatementLambdaElement,
320324
model: type[ModelT],
321325
) -> StatementLambdaElement:
322-
fields = self.normalized_field_names
323-
search_clause: list[BinaryExpression[bool]] = []
324-
for field_name in fields:
325-
field = self._get_instrumented_attr(model, field_name)
326-
search_text = f"%{self.value}%"
327-
if self.ignore_case:
328-
search_clause.append(field.ilike(search_text))
329-
else:
330-
search_clause.append(field.like(search_text))
331-
statement += lambda s: s.where(or_(*search_clause)) # pyright: ignore[reportUnknownLambdaType,reportArgumentType,reportUnknownMemberType]
326+
where_clause = self._operator(*self.get_search_clauses(model))
327+
statement += lambda s: s.where(where_clause)
332328
return statement
333329

334330

335331
@dataclass
336-
class NotInSearchFilter(TextSearchFilter):
332+
class NotInSearchFilter(SearchFilter):
337333
"""Data required to construct a ``WHERE field_name NOT LIKE '%' || :value || '%'`` clause."""
334+
_operator: Callable[..., ColumnElement[bool]] = and_
338335

339-
def append_to_statement(
340-
self,
341-
statement: Select[tuple[ModelT]],
342-
model: type[ModelT],
343-
) -> Select[tuple[ModelT]]:
344-
fields = self.normalized_field_names
345-
search_clause: list[BinaryExpression[bool]] = []
346-
for field_name in fields:
347-
field = self._get_instrumented_attr(model, field_name)
348-
search_text = f"%{self.value}%"
349-
if self.ignore_case:
350-
search_clause.append(field.not_ilike(search_text))
351-
else:
352-
search_clause.append(field.not_like(search_text))
353-
return statement.where(and_(*search_clause))
354-
355-
def append_to_lambda_statement(
356-
self,
357-
statement: StatementLambdaElement,
358-
model: type[ModelT],
359-
) -> StatementLambdaElement:
360-
fields = self.normalized_field_names
361-
search_clause: list[BinaryExpression[bool]] = []
362-
for field_name in fields:
363-
field = self._get_instrumented_attr(model, field_name)
364-
search_text = f"%{self.value}%"
365-
if self.ignore_case:
366-
search_clause.append(field.not_ilike(search_text))
367-
else:
368-
search_clause.append(field.not_like(search_text))
369-
statement += lambda s: s.where(and_(*search_clause)) # pyright: ignore[reportUnknownLambdaType,reportArgumentType,reportUnknownMemberType]
370-
return statement
336+
@property
337+
def _func(self) -> attrgetter[Callable[[str], BinaryExpression[bool]]]:
338+
return attrgetter("not_ilike" if self.ignore_case else "not_like")

0 commit comments

Comments
 (0)