Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(search_filter): Removes some redundant code #211

Merged
merged 6 commits into from
Jun 5, 2024
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 33 additions & 59 deletions advanced_alchemy/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
from collections import abc # noqa: TCH003
from dataclasses import dataclass
from datetime import datetime # noqa: TCH003
from operator import attrgetter
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast

from sqlalchemy import BinaryExpression, and_, any_, or_, text

if TYPE_CHECKING:
from sqlalchemy import Select, StatementLambdaElement
from typing import Callable

from sqlalchemy import ColumnElement, Select, StatementLambdaElement
from sqlalchemy.orm import InstrumentedAttribute
from typing_extensions import TypeAlias

Expand All @@ -32,7 +35,6 @@
"InAnyFilter",
"StatementFilter",
"StatementFilterT",
"TextSearchFilter",
)

T = TypeVar("T")
Expand Down Expand Up @@ -281,90 +283,62 @@ def append_to_lambda_statement(


@dataclass
class TextSearchFilter(StatementFilter):
class SearchFilter(StatementFilter):
"""Data required to construct a ``WHERE field_name LIKE '%' || :value || '%'`` clause."""

field_name: str | set[str]
"""Name of the model attribute to search on."""
value: str
"""Values for ``NOT LIKE`` clause."""
ignore_case: bool | None = False
"""Should the search be case insensitive."""

@property
def _operator(self) -> Callable[..., ColumnElement[bool]]:
return or_

@property
def _func(self) -> attrgetter[Callable[[str], BinaryExpression[bool]]]:
return attrgetter("ilike" if self.ignore_case else "like")

@property
def normalized_field_names(self) -> set[str]:
return {self.field_name} if isinstance(self.field_name, str) else self.field_name


@dataclass
class SearchFilter(TextSearchFilter):
"""Data required to construct a ``WHERE field_name LIKE '%' || :value || '%'`` clause."""
def get_search_clauses(self, model: type[ModelT]) -> list[BinaryExpression[bool]]:
search_clause: list[BinaryExpression[bool]] = []
for field_name in self.normalized_field_names:
field = self._get_instrumented_attr(model, field_name)
search_text = f"%{self.value}%"
search_clause.append(self._func(field)(search_text))
return search_clause

def append_to_statement(
self,
statement: Select[tuple[ModelT]],
model: type[ModelT],
) -> Select[tuple[ModelT]]:
fields = self.normalized_field_names
search_clause: list[BinaryExpression[bool]] = []
for field_name in fields:
field = self._get_instrumented_attr(model, field_name)
search_text = f"%{self.value}%"
if self.ignore_case:
search_clause.append(field.ilike(search_text))
else:
search_clause.append(field.like(search_text))
return statement.where(or_(*search_clause))
where_clause = self._operator(*self.get_search_clauses(model))
return statement.where(where_clause)

def append_to_lambda_statement(
self,
statement: StatementLambdaElement,
model: type[ModelT],
) -> StatementLambdaElement:
fields = self.normalized_field_names
search_clause: list[BinaryExpression[bool]] = []
for field_name in fields:
field = self._get_instrumented_attr(model, field_name)
search_text = f"%{self.value}%"
if self.ignore_case:
search_clause.append(field.ilike(search_text))
else:
search_clause.append(field.like(search_text))
statement += lambda s: s.where(or_(*search_clause)) # pyright: ignore[reportUnknownLambdaType,reportArgumentType,reportUnknownMemberType]
where_clause = self._operator(*self.get_search_clauses(model))
statement += lambda s: s.where(where_clause)
return statement


@dataclass
class NotInSearchFilter(TextSearchFilter):
class NotInSearchFilter(SearchFilter):
"""Data required to construct a ``WHERE field_name NOT LIKE '%' || :value || '%'`` clause."""

def append_to_statement(
self,
statement: Select[tuple[ModelT]],
model: type[ModelT],
) -> Select[tuple[ModelT]]:
fields = self.normalized_field_names
search_clause: list[BinaryExpression[bool]] = []
for field_name in fields:
field = self._get_instrumented_attr(model, field_name)
search_text = f"%{self.value}%"
if self.ignore_case:
search_clause.append(field.not_ilike(search_text))
else:
search_clause.append(field.not_like(search_text))
return statement.where(and_(*search_clause))
@property
def _operator(self) -> Callable[..., ColumnElement[bool]]:
return and_

def append_to_lambda_statement(
self,
statement: StatementLambdaElement,
model: type[ModelT],
) -> StatementLambdaElement:
fields = self.normalized_field_names
search_clause: list[BinaryExpression[bool]] = []
for field_name in fields:
field = self._get_instrumented_attr(model, field_name)
search_text = f"%{self.value}%"
if self.ignore_case:
search_clause.append(field.not_ilike(search_text))
else:
search_clause.append(field.not_like(search_text))
statement += lambda s: s.where(and_(*search_clause)) # pyright: ignore[reportUnknownLambdaType,reportArgumentType,reportUnknownMemberType]
return statement
@property
def _func(self) -> attrgetter[Callable[[str], BinaryExpression[bool]]]:
return attrgetter("not_ilike" if self.ignore_case else "not_like")
Loading