Skip to content

Commit

Permalink
Allow the user to hard ignore certain tables from soft-deletion (#23)
Browse files Browse the repository at this point in the history
* Update README.md with usage examples for ignored_tables=

* Change deprecated stmt.froms to stmt.get_final_froms()

* Updated isort in .pre-commit-config.yaml

* Implemented and fixed tests.
  • Loading branch information
flipbit03 committed Aug 27, 2023
1 parent 2f221ce commit 221ea6f
Show file tree
Hide file tree
Showing 14 changed files with 145 additions and 34 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.8.2
current_version = 0.8.3
commit = True
tag = False

Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ repos:
exclude: "tests/snapshots/.*"
args: [ --unsafe ]
- repo: https://github.com/pycqa/isort
rev: 5.10.1
rev: 5.12.0
hooks:
- id: isort
exclude: "tests/snapshots/.*"
Expand Down
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,17 @@ pip install sqlalchemy-easy-softdelete

```py
from sqlalchemy_easy_softdelete.mixin import generate_soft_delete_mixin_class
from sqlalchemy_easy_softdelete.hook import IgnoredTable
from sqlalchemy.orm import declarative_base
from sqlalchemy import Column, Integer
from datetime import datetime

# Create a Class that inherits from our class builder
class SoftDeleteMixin(generate_soft_delete_mixin_class()):
class SoftDeleteMixin(generate_soft_delete_mixin_class(
# This table will be ignored by the hook
# even if the table has the soft-delete column
ignored_tables=[IgnoredTable(table_schema="public", name="cars"),]
)):
# type hint for autocomplete IDE support
deleted_at: datetime

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[tool]
[tool.poetry]
name = "sqlalchemy-easy-softdelete"
version = "0.8.2"
version = "0.8.3"
homepage = "https://github.com/flipbit03/sqlalchemy-easy-softdelete"
description = "Easily add soft-deletion to your SQLAlchemy Models."
authors = ["Cadu <[email protected]>"]
Expand Down
2 changes: 1 addition & 1 deletion sqlalchemy_easy_softdelete/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

__author__ = """Cadu"""
__email__ = '[email protected]'
__version__ = '0.8.2'
__version__ = '0.8.3'
38 changes: 30 additions & 8 deletions sqlalchemy_easy_softdelete/handler/rewriter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Main query rewriter logic."""

from __future__ import annotations

from typing import TypeVar, Union

from sqlalchemy import Table
Expand All @@ -8,13 +10,20 @@
from sqlalchemy.sql import Alias, CompoundSelect, Executable, Join, Select, Subquery, TableClause
from sqlalchemy.sql.elements import TextClause

from sqlalchemy_easy_softdelete.hook import IgnoredTable

Statement = TypeVar('Statement', bound=Union[Select, FromStatement, CompoundSelect, Executable])


class SoftDeleteQueryRewriter:
"""Rewrites SQL statements based on configuration."""

def __init__(self, deleted_field_name: str, disable_soft_delete_option_name: str):
def __init__(
self,
deleted_field_name: str,
disable_soft_delete_option_name: str,
ignored_tables: list[IgnoredTable] | None = None,
):
"""
Instantiate a new query rewriter.
Expand All @@ -29,6 +38,8 @@ def __init__(self, deleted_field_name: str, disable_soft_delete_option_name: str
soft deletion rewriting in a query
"""
"""List of table names that should be ignored from soft-deletion"""
self.ignored_tables = ignored_tables or []
self.deleted_field_name = deleted_field_name
self.disable_soft_delete_option_name = disable_soft_delete_option_name

Expand Down Expand Up @@ -133,13 +144,24 @@ def analyze_from(self, stmt: Select, from_obj):
raise NotImplementedError(f"Unsupported object \"{(type(from_obj))}\" in statement.froms")

def rewrite_from_table(self, stmt: Select, table: Table) -> Select:
"""(possibly) Rewrite a Select based on whether the Table contains the soft-delete field or not."""
"""
(possibly) Rewrite a Select based on whether the Table contains the soft-delete field or not.
Ignore tables named like the ignore_tabl
"""
# Early return if the table is ignored
if any(ignored.match_name(table) for ignored in self.ignored_tables):
return stmt

# Try to retrieve the column object
column_obj = table.columns.get(self.deleted_field_name)

# Caveat: The automatic "bool(column_obj)" conversion actually returns
# a truthy value of False (?), so we have to explicitly compare against None
if column_obj is not None:
return stmt.filter(column_obj.is_(None))
# If the column object is not found, return unchanged statement
# Caveat: The automatic "bool(column_obj)" conversion actually returns a truthy value of False (?),
# so we have to explicitly compare against None
if column_obj is None:
return stmt

# Soft-delete argument was not found, return unchanged statement
return stmt
# Column found. Rewrite the statement with a filter condition in the soft-delete column
return stmt.filter(column_obj.is_(None))
27 changes: 20 additions & 7 deletions sqlalchemy_easy_softdelete/handler/sqlalchemy_easy_softdelete.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,36 @@
"""This module is responsible for activating the query rewriter."""

from functools import cache
from typing import List, Optional

from sqlalchemy.event import listens_for
from sqlalchemy.orm import ORMExecuteState, Session

from sqlalchemy_easy_softdelete.handler.rewriter import SoftDeleteQueryRewriter
from sqlalchemy_easy_softdelete.hook import IgnoredTable

global_rewriter: Optional[SoftDeleteQueryRewriter] = None

@cache
def activate_soft_delete_hook(deleted_field_name: str, disable_soft_delete_option_name: str):

def activate_soft_delete_hook(
deleted_field_name: str, disable_soft_delete_option_name: str, ignored_tables: List[IgnoredTable]
):
"""Activate an event hook to rewrite the queries."""

global global_rewriter
global_rewriter = SoftDeleteQueryRewriter(
deleted_field_name=deleted_field_name,
disable_soft_delete_option_name=disable_soft_delete_option_name,
ignored_tables=ignored_tables,
)

# Enable Soft Delete on all Relationship Loads which implement SoftDeleteMixin
@listens_for(Session, "do_orm_execute")
@listens_for(Session, identifier="do_orm_execute")
def soft_delete_execute(state: ORMExecuteState):
if not state.is_select:
return

adapted = SoftDeleteQueryRewriter(deleted_field_name, disable_soft_delete_option_name).rewrite_statement(
state.statement
)
# Rewrite the statement
adapted = global_rewriter.rewrite_statement(state.statement)

# Replace the statement
state.statement = adapted
14 changes: 14 additions & 0 deletions sqlalchemy_easy_softdelete/hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from dataclasses import dataclass
from typing import Optional

from sqlalchemy import Table


@dataclass
class IgnoredTable:
name: str
table_schema: Optional[str] = None

def match_name(self, table: Table):
# Table matches if the name and schema match
return self.name == table.name and self.table_schema == table.schema
8 changes: 7 additions & 1 deletion sqlalchemy_easy_softdelete/mixin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Functions related to dynamic generation of the soft-delete mixin."""
from __future__ import annotations

from datetime import datetime
from typing import Any, Callable, Optional, Type
Expand All @@ -7,10 +8,12 @@
from sqlalchemy.sql.type_api import TypeEngine

from sqlalchemy_easy_softdelete.handler.sqlalchemy_easy_softdelete import activate_soft_delete_hook
from sqlalchemy_easy_softdelete.hook import IgnoredTable


def generate_soft_delete_mixin_class(
deleted_field_name: str = "deleted_at",
ignored_tables: list[IgnoredTable] | None = None,
class_name: str = "_SoftDeleteMixin",
deleted_field_type: TypeEngine = DateTime(timezone=True),
disable_soft_delete_filtering_option_name: str = "include_deleted",
Expand All @@ -21,6 +24,9 @@ def generate_soft_delete_mixin_class(
undelete_method_name: str = "undelete",
) -> Type:
"""Generate the actual soft-delete Mixin class."""
if not ignored_tables:
ignored_tables = []

class_attributes = {deleted_field_name: Column(deleted_field_name, deleted_field_type)}

if generate_delete_method:
Expand All @@ -37,7 +43,7 @@ def undelete_method(_self):

class_attributes[undelete_method_name] = undelete_method

activate_soft_delete_hook(deleted_field_name, disable_soft_delete_filtering_option_name)
activate_soft_delete_hook(deleted_field_name, disable_soft_delete_filtering_option_name, ignored_tables)

generated_class = type(class_name, tuple(), class_attributes)

Expand Down
4 changes: 3 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,6 @@ def seeded_session(db_session) -> Session:

@pytest.fixture
def rewriter() -> SoftDeleteQueryRewriter:
return SoftDeleteQueryRewriter("deleted_at", "include_deleted")
from sqlalchemy_easy_softdelete.handler.sqlalchemy_easy_softdelete import global_rewriter

return global_rewriter
19 changes: 17 additions & 2 deletions tests/model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from datetime import datetime
from typing import List

from sqlalchemy import Column, ForeignKey, Integer, String
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String
from sqlalchemy.orm import as_declarative, declared_attr, relationship

from sqlalchemy_easy_softdelete.hook import IgnoredTable
from sqlalchemy_easy_softdelete.mixin import generate_soft_delete_mixin_class


Expand All @@ -19,7 +20,13 @@ def __repr__(self):
return f"<{self.__class__.__name__} id={self.id}>"


class SoftDeleteMixin(generate_soft_delete_mixin_class()):
class SoftDeleteMixin(
generate_soft_delete_mixin_class(
ignored_tables=[
IgnoredTable(table_schema=None, name='sdtablethatshouldnotbesoftdeleted'),
],
)
):
# for autocomplete
deleted_at: datetime

Expand Down Expand Up @@ -86,3 +93,11 @@ class SDDerivedRequest(SDBaseRequest):
__mapper_args__ = {
"polymorphic_identity": "sdderivedrequest",
}


class SDTableThatShouldNotBeSoftDeleted(TestModelBase):
id: Integer = Column(Integer, primary_key=True)
deleted_at: datetime = Column(DateTime(timezone=True))

def __repr__(self):
return f"<{self.__class__.__name__} id={self.id} name={self.name}>"
29 changes: 28 additions & 1 deletion tests/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,15 @@
from sqlalchemy.orm import Query
from sqlalchemy.sql import Select

from tests.model import SDBaseRequest, SDChild, SDChildChild, SDDerivedRequest, SDParent, SDSimpleTable
from tests.model import (
SDBaseRequest,
SDChild,
SDChildChild,
SDDerivedRequest,
SDParent,
SDSimpleTable,
SDTableThatShouldNotBeSoftDeleted,
)
from tests.utils import is_filtering_for_softdeleted


Expand Down Expand Up @@ -195,3 +203,22 @@ def test_query_with_more_than_one_join(snapshot, seeded_session, rewriter):
)
is True
)


def test_query_with_same_field_as_softdelete_field_but_ignored(seeded_session, rewriter):
"""Test that a query with a field that has the same name as the soft-delete field
but is ignored, does not get rewritten"""

test_query = seeded_session.query(SDTableThatShouldNotBeSoftDeleted)

soft_deleted_rewritten_statement = rewriter.rewrite_statement(test_query.statement)

assert (
is_filtering_for_softdeleted(
soft_deleted_rewritten_statement,
{
SDTableThatShouldNotBeSoftDeleted.__table__,
},
)
is False
)
6 changes: 5 additions & 1 deletion tests/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ def is_simple_select_doing_soft_delete_filtering(stmt: Select, tables: set[Table
# Skip checking in this query
return True

assert stmt.whereclause is not None
# if we don't have a where clause, we can't be filtering for soft-deleted
# Caveat: We need to compare with None, since and whereclause usually does not have a __bool__ method
if stmt.whereclause is None:
return False

binary_expressions = extract_binary_expressions_from_where(stmt.whereclause)

found_tables = set()
Expand Down
19 changes: 11 additions & 8 deletions tests/utils/simple_select_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
We consider SIMPLE SELECT STATEMENTS to be those that their froms are tables, and not subqueries.
"""
from __future__ import annotations

from typing import Union

from sqlalchemy.orm.util import _ORMJoin
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.selectable import CompoundSelect, Join, Select, Subquery
from sqlalchemy.sql.selectable import CompoundSelect, Join, Select, SelectBase, Subquery


def is_simple_join(j: Union[Join, _ORMJoin]) -> bool:
Expand All @@ -33,10 +35,11 @@ def is_simple_select(s: Union[Select, Subquery, CompoundSelect]) -> bool:
if isinstance(s, Subquery):
return False

if not isinstance(s.froms, list):
final_froms = s.get_final_froms()
if not isinstance(final_froms, list):
raise NotImplementedError(f"statement.froms is not a list! type -> \"{(type(s.froms))}\"!")

for from_obj in s.froms:
for from_obj in final_froms:
if isinstance(from_obj, Table):
continue
elif isinstance(from_obj, Subquery):
Expand All @@ -51,17 +54,17 @@ def is_simple_select(s: Union[Select, Subquery, CompoundSelect]) -> bool:
return True


def extract_simple_selects(statement: Union[Select, CompoundSelect]) -> list[Select]:
def extract_simple_selects(statement: Select | CompoundSelect | SelectBase) -> list[SelectBase]:
if is_simple_select(statement):
return [statement]

if isinstance(statement, CompoundSelect):
extraced_selects = []
extracted_elements = []
for select in statement.selects:
extraced_selects.extend(extract_simple_selects(select))
return extraced_selects
extracted_elements.extend(extract_simple_selects(select))
return extracted_elements

for from_obj in statement.froms:
for from_obj in statement.get_final_froms():
if isinstance(from_obj, Table):
continue
elif isinstance(from_obj, Subquery):
Expand Down

0 comments on commit 221ea6f

Please sign in to comment.