|
6 | 6 | from collections import abc # noqa: TCH003
|
7 | 7 | from dataclasses import dataclass
|
8 | 8 | from datetime import datetime # noqa: TCH003
|
| 9 | +from operator import attrgetter |
9 | 10 | from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
|
10 | 11 |
|
11 | 12 | from sqlalchemy import BinaryExpression, and_, any_, or_, text
|
12 | 13 |
|
13 | 14 | if TYPE_CHECKING:
|
14 |
| - from sqlalchemy import Select, StatementLambdaElement |
| 15 | + from typing import Callable |
| 16 | + |
| 17 | + from sqlalchemy import ColumnElement, Select, StatementLambdaElement |
15 | 18 | from sqlalchemy.orm import InstrumentedAttribute
|
16 | 19 | from typing_extensions import TypeAlias
|
17 | 20 |
|
|
32 | 35 | "InAnyFilter",
|
33 | 36 | "StatementFilter",
|
34 | 37 | "StatementFilterT",
|
35 |
| - "TextSearchFilter", |
36 | 38 | )
|
37 | 39 |
|
38 | 40 | T = TypeVar("T")
|
@@ -281,90 +283,56 @@ def append_to_lambda_statement(
|
281 | 283 |
|
282 | 284 |
|
283 | 285 | @dataclass
|
284 |
| -class TextSearchFilter(StatementFilter): |
| 286 | +class SearchFilter(StatementFilter): |
| 287 | + """Data required to construct a ``WHERE field_name LIKE '%' || :value || '%'`` clause.""" |
285 | 288 | field_name: str | set[str]
|
286 | 289 | """Name of the model attribute to search on."""
|
287 | 290 | value: str
|
288 | 291 | """Values for ``NOT LIKE`` clause."""
|
289 | 292 | ignore_case: bool | None = False
|
290 | 293 | """Should the search be case insensitive."""
|
291 | 294 |
|
| 295 | + _operator: Callable[..., ColumnElement[bool]] = or_ |
| 296 | + |
| 297 | + @property |
| 298 | + def _func(self) -> attrgetter[Callable[[str], BinaryExpression[bool]]]: |
| 299 | + return attrgetter("ilike" if self.ignore_case else "like") |
| 300 | + |
292 | 301 | @property
|
293 | 302 | def normalized_field_names(self) -> set[str]:
|
294 | 303 | return {self.field_name} if isinstance(self.field_name, str) else self.field_name
|
295 | 304 |
|
296 |
| - |
297 |
| -@dataclass |
298 |
| -class SearchFilter(TextSearchFilter): |
299 |
| - """Data required to construct a ``WHERE field_name LIKE '%' || :value || '%'`` clause.""" |
| 305 | + def get_search_clauses(self, model: type[ModelT]) -> list[BinaryExpression[bool]]: |
| 306 | + search_clause: list[BinaryExpression[bool]] = [] |
| 307 | + for field_name in self.normalized_field_names: |
| 308 | + field = self._get_instrumented_attr(model, field_name) |
| 309 | + search_text = f"%{self.value}%" |
| 310 | + search_clause.append(self._func(field)(search_text)) |
| 311 | + return search_clause |
300 | 312 |
|
301 | 313 | def append_to_statement(
|
302 | 314 | self,
|
303 | 315 | statement: Select[tuple[ModelT]],
|
304 | 316 | model: type[ModelT],
|
305 | 317 | ) -> Select[tuple[ModelT]]:
|
306 |
| - fields = self.normalized_field_names |
307 |
| - search_clause: list[BinaryExpression[bool]] = [] |
308 |
| - for field_name in fields: |
309 |
| - field = self._get_instrumented_attr(model, field_name) |
310 |
| - search_text = f"%{self.value}%" |
311 |
| - if self.ignore_case: |
312 |
| - search_clause.append(field.ilike(search_text)) |
313 |
| - else: |
314 |
| - search_clause.append(field.like(search_text)) |
315 |
| - return statement.where(or_(*search_clause)) |
| 318 | + where_clause = self._operator(*self.get_search_clauses(model)) |
| 319 | + return statement.where(where_clause) |
316 | 320 |
|
317 | 321 | def append_to_lambda_statement(
|
318 | 322 | self,
|
319 | 323 | statement: StatementLambdaElement,
|
320 | 324 | model: type[ModelT],
|
321 | 325 | ) -> StatementLambdaElement:
|
322 |
| - fields = self.normalized_field_names |
323 |
| - search_clause: list[BinaryExpression[bool]] = [] |
324 |
| - for field_name in fields: |
325 |
| - field = self._get_instrumented_attr(model, field_name) |
326 |
| - search_text = f"%{self.value}%" |
327 |
| - if self.ignore_case: |
328 |
| - search_clause.append(field.ilike(search_text)) |
329 |
| - else: |
330 |
| - search_clause.append(field.like(search_text)) |
331 |
| - statement += lambda s: s.where(or_(*search_clause)) # pyright: ignore[reportUnknownLambdaType,reportArgumentType,reportUnknownMemberType] |
| 326 | + where_clause = self._operator(*self.get_search_clauses(model)) |
| 327 | + statement += lambda s: s.where(where_clause) |
332 | 328 | return statement
|
333 | 329 |
|
334 | 330 |
|
335 | 331 | @dataclass
|
336 |
| -class NotInSearchFilter(TextSearchFilter): |
| 332 | +class NotInSearchFilter(SearchFilter): |
337 | 333 | """Data required to construct a ``WHERE field_name NOT LIKE '%' || :value || '%'`` clause."""
|
| 334 | + _operator: Callable[..., ColumnElement[bool]] = and_ |
338 | 335 |
|
339 |
| - def append_to_statement( |
340 |
| - self, |
341 |
| - statement: Select[tuple[ModelT]], |
342 |
| - model: type[ModelT], |
343 |
| - ) -> Select[tuple[ModelT]]: |
344 |
| - fields = self.normalized_field_names |
345 |
| - search_clause: list[BinaryExpression[bool]] = [] |
346 |
| - for field_name in fields: |
347 |
| - field = self._get_instrumented_attr(model, field_name) |
348 |
| - search_text = f"%{self.value}%" |
349 |
| - if self.ignore_case: |
350 |
| - search_clause.append(field.not_ilike(search_text)) |
351 |
| - else: |
352 |
| - search_clause.append(field.not_like(search_text)) |
353 |
| - return statement.where(and_(*search_clause)) |
354 |
| - |
355 |
| - def append_to_lambda_statement( |
356 |
| - self, |
357 |
| - statement: StatementLambdaElement, |
358 |
| - model: type[ModelT], |
359 |
| - ) -> StatementLambdaElement: |
360 |
| - fields = self.normalized_field_names |
361 |
| - search_clause: list[BinaryExpression[bool]] = [] |
362 |
| - for field_name in fields: |
363 |
| - field = self._get_instrumented_attr(model, field_name) |
364 |
| - search_text = f"%{self.value}%" |
365 |
| - if self.ignore_case: |
366 |
| - search_clause.append(field.not_ilike(search_text)) |
367 |
| - else: |
368 |
| - search_clause.append(field.not_like(search_text)) |
369 |
| - statement += lambda s: s.where(and_(*search_clause)) # pyright: ignore[reportUnknownLambdaType,reportArgumentType,reportUnknownMemberType] |
370 |
| - return statement |
| 336 | + @property |
| 337 | + def _func(self) -> attrgetter[Callable[[str], BinaryExpression[bool]]]: |
| 338 | + return attrgetter("not_ilike" if self.ignore_case else "not_like") |
0 commit comments