|
6 | 6 | from collections import abc # noqa: TCH003
|
7 | 7 | from dataclasses import dataclass
|
8 | 8 | from datetime import datetime # noqa: TCH003
|
| 9 | +from operator import attrgetter |
9 | 10 | from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
|
10 | 11 |
|
11 | 12 | from sqlalchemy import BinaryExpression, and_, any_, or_, text
|
12 | 13 |
|
13 | 14 | if TYPE_CHECKING:
|
| 15 | + from typing import Callable |
| 16 | + |
14 | 17 | from sqlalchemy import Select, StatementLambdaElement
|
15 | 18 | from sqlalchemy.orm import InstrumentedAttribute
|
16 | 19 | from typing_extensions import TypeAlias
|
@@ -295,76 +298,49 @@ def normalized_field_names(self) -> set[str]:
|
295 | 298 |
|
296 | 299 |
|
297 | 300 | @dataclass
|
298 |
| -class SearchFilter(TextSearchFilter): |
299 |
| - """Data required to construct a ``WHERE field_name LIKE '%' || :value || '%'`` clause.""" |
| 301 | +class SearchFilter(StatementFilter): |
| 302 | + field_name: str | set[str] |
| 303 | + value: str |
| 304 | + ignore_case: bool | None = False |
300 | 305 |
|
301 |
| - def append_to_statement( |
302 |
| - self, |
303 |
| - statement: Select[tuple[ModelT]], |
304 |
| - model: type[ModelT], |
305 |
| - ) -> Select[tuple[ModelT]]: |
306 |
| - fields = self.normalized_field_names |
307 |
| - search_clause: list[BinaryExpression[bool]] = [] |
308 |
| - for field_name in fields: |
309 |
| - field = self._get_instrumented_attr(model, field_name) |
310 |
| - search_text = f"%{self.value}%" |
311 |
| - if self.ignore_case: |
312 |
| - search_clause.append(field.ilike(search_text)) |
313 |
| - else: |
314 |
| - search_clause.append(field.like(search_text)) |
315 |
| - return statement.where(or_(*search_clause)) |
| 306 | + _operator: Any = or_ |
316 | 307 |
|
317 |
| - def append_to_lambda_statement( |
318 |
| - self, |
319 |
| - statement: StatementLambdaElement, |
320 |
| - model: type[ModelT], |
321 |
| - ) -> StatementLambdaElement: |
322 |
| - fields = self.normalized_field_names |
| 308 | + @property |
| 309 | + def func(self) -> attrgetter[Callable[[str], BinaryExpression[bool]]]: |
| 310 | + return attrgetter("ilike" if self.ignore_case else "like") |
| 311 | + |
| 312 | + @property |
| 313 | + def normalized_field_names(self) -> set[str]: |
| 314 | + return {self.field_name} if isinstance(self.field_name, str) else self.field_name |
| 315 | + |
| 316 | + def get_search_clauses(self, model: type[ModelT]) -> list[BinaryExpression[bool]]: |
323 | 317 | search_clause: list[BinaryExpression[bool]] = []
|
324 |
| - for field_name in fields: |
| 318 | + for field_name in self.normalized_field_names: |
325 | 319 | field = self._get_instrumented_attr(model, field_name)
|
326 | 320 | search_text = f"%{self.value}%"
|
327 |
| - if self.ignore_case: |
328 |
| - search_clause.append(field.ilike(search_text)) |
329 |
| - else: |
330 |
| - search_clause.append(field.like(search_text)) |
331 |
| - statement += lambda s: s.where(or_(*search_clause)) # pyright: ignore[reportUnknownLambdaType,reportArgumentType,reportUnknownMemberType] |
332 |
| - return statement |
333 |
| - |
334 |
| - |
335 |
| -@dataclass |
336 |
| -class NotInSearchFilter(TextSearchFilter): |
337 |
| - """Data required to construct a ``WHERE field_name NOT LIKE '%' || :value || '%'`` clause.""" |
| 321 | + search_clause.append(self.func(field)(search_text)) |
| 322 | + return search_clause |
338 | 323 |
|
339 | 324 | def append_to_statement(
|
340 | 325 | self,
|
341 | 326 | statement: Select[tuple[ModelT]],
|
342 | 327 | model: type[ModelT],
|
343 | 328 | ) -> Select[tuple[ModelT]]:
|
344 |
| - fields = self.normalized_field_names |
345 |
| - search_clause: list[BinaryExpression[bool]] = [] |
346 |
| - for field_name in fields: |
347 |
| - field = self._get_instrumented_attr(model, field_name) |
348 |
| - search_text = f"%{self.value}%" |
349 |
| - if self.ignore_case: |
350 |
| - search_clause.append(field.not_ilike(search_text)) |
351 |
| - else: |
352 |
| - search_clause.append(field.not_like(search_text)) |
353 |
| - return statement.where(and_(*search_clause)) |
| 329 | + return statement.where(self._operator(*self.get_search_clauses(model))) |
354 | 330 |
|
355 | 331 | def append_to_lambda_statement(
|
356 | 332 | self,
|
357 | 333 | statement: StatementLambdaElement,
|
358 | 334 | model: type[ModelT],
|
359 | 335 | ) -> StatementLambdaElement:
|
360 |
| - fields = self.normalized_field_names |
361 |
| - search_clause: list[BinaryExpression[bool]] = [] |
362 |
| - for field_name in fields: |
363 |
| - field = self._get_instrumented_attr(model, field_name) |
364 |
| - search_text = f"%{self.value}%" |
365 |
| - if self.ignore_case: |
366 |
| - search_clause.append(field.not_ilike(search_text)) |
367 |
| - else: |
368 |
| - search_clause.append(field.not_like(search_text)) |
369 |
| - statement += lambda s: s.where(and_(*search_clause)) # pyright: ignore[reportUnknownLambdaType,reportArgumentType,reportUnknownMemberType] |
| 336 | + statement += lambda s: s.where(self._operator(*self.get_search_clauses(model))) |
370 | 337 | return statement
|
| 338 | + |
| 339 | + |
| 340 | +@dataclass |
| 341 | +class NotInSearchFilter(SearchFilter): |
| 342 | + _operator: Any = and_ |
| 343 | + |
| 344 | + @property |
| 345 | + def func(self) -> attrgetter[Callable[[str], BinaryExpression[bool]]]: |
| 346 | + return attrgetter("not_ilike" if self.ignore_case else "not_like") |
0 commit comments