Skip to content

Commit

Permalink
Fix the SQL constraint violation logic (distributed-system-analysis#3528)
Browse files Browse the repository at this point in the history

* Fix the SQL constraint violation logic

PBENCH-1251

This is mostly a fix for the ops review problem discovered Thursday, that our
carefully unit-tested logic to decode SQLAlchemy `IntegrityError` instances to
identify the specific violated constraint (specifically, duplicate vs null)
doesn't work in a production environment.

The root problem is that the text strings reported differ between the unit
test environment (`sqlite3` engine) and the production environment
(`postgresql` engine).

In order to consolidate on the central constraint error decoder (which was
why this wasn't done earlier), we need all SQL-related exception constructors
to be consistent, but I didn't like dropping all context-specific arguments.
Instead I've adopted a `**kwargs` model where the decoder passes through
arbitrary parameters, which are captured in the exception to provide context.

This also adopts a more "pythonic exception" model where `e.args` captures a
statically formatted message string and allows the base `Exception.__str__` to
provide formatting. I've done this for all the SQLAlchemy model objects, but
nowhere else.

And to avoid complication in many places where we use `str(exception)`,
I'm not including the `kwargs` or raw `cause` in the `args` tuple. I played
with that idea, but since `Exception.__str__` formats the entire tuple, the
results are ugly and also more difficult to meaningfully compare in tests. We
use `str(exception)` a lot both internally and in tests.
  • Loading branch information
dbutenhof authored Aug 28, 2023
1 parent d731803 commit 3d9261e
Show file tree
Hide file tree
Showing 15 changed files with 346 additions and 405 deletions.
41 changes: 27 additions & 14 deletions lib/pbench/server/database/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datetime
from typing import Callable
from typing import Callable, Optional

from sqlalchemy import DateTime
from sqlalchemy.exc import IntegrityError
Expand Down Expand Up @@ -54,24 +54,37 @@ def process_result_value(self, value, dialect):
return value


def decode_integrity_error(
exception: IntegrityError, on_null: Callable, on_duplicate: Callable
def decode_sql_error(
exception: Exception,
on_null: Callable[[Exception], Exception],
on_duplicate: Callable[[Exception], Exception],
fallback: Optional[Callable[[Exception], Exception]] = None,
**kwargs
) -> Exception:

"""Decode a SQLAlchemy IntegrityError to look for a recognizable UNIQUE
or NOT NULL constraint violation.
"""Analyze an exception for a SQL constraint violation
Return the original exception if it doesn't match.
Analyzes SQLAlchemy IntegrityError instances for NOT NULL and UNIQUE
KEY constraints, constructing and returning an appropriate exception
instance. If the exception doesn't match a recognized SQL constraint,
construct and return a fallback exception instance if specified or the
original exception.
Args:
exception : An IntegrityError to decode
exception: An exception to decode
on_null: Exception class to build if null constraint
on_duplicate: Exception class to build if duplicate constraint
fallback: Exception class to build otherwise
kwargs: additional arguments passed to exception constructors
Returns:
a more specific exception, or the original if decoding fails
a more specific exception, or the original if no matches are found and
no fallback template is provided.
"""
cause = exception.orig.args[-1]
if "UNIQUE constraint" in cause:
return on_duplicate(cause)
elif "NOT NULL constraint" in cause:
return on_null(cause)
return exception
if isinstance(exception, IntegrityError):
cause = exception.orig.args[-1].lower()
if "unique constraint" in cause:
return on_duplicate(exception, **kwargs)
elif "not null constraint" in cause:
return on_null(exception, **kwargs)
return exception if not fallback else fallback(exception, **kwargs)
47 changes: 25 additions & 22 deletions lib/pbench/server/database/models/api_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from pbench.server import JSONOBJECT
from pbench.server.database.database import Database
from pbench.server.database.models import decode_integrity_error, TZDateTime
from pbench.server.database.models import decode_sql_error, TZDateTime
from pbench.server.database.models.users import User

# Module private constants
Expand All @@ -17,31 +17,34 @@
class APIKeyError(Exception):
"""A base class for errors reported by the APIKey class."""

def __init__(self, message):
self.message = message
pass

def __str__(self):
return repr(self.message)

class APIKeySqlError(APIKeyError):
    """Generic SQL failure reported for an API key operation.

    The formatted message is handed to the base exception class; the
    original exception and any keyword context supplied by the caller are
    kept as attributes.

    Args:
        cause: the exception that triggered this error
        kwargs: arbitrary context describing the failed operation
    """

    def __init__(self, cause: Exception, **kwargs):
        message = f"API key SQL error: '{cause}' {kwargs}"
        super().__init__(message)
        self.kwargs = kwargs
        self.cause = cause


class DuplicateApiKey(APIKeyError):
"""Attempt to commit a duplicate unique value."""

def __init__(self, cause: str):
def __init__(self, cause: Exception, **kwargs):
super().__init__(f"API key duplicate key error: '{cause}' {kwargs}")
self.cause = cause

def __str__(self) -> str:
return f"Duplicate api_key: {self.cause}"
self.kwargs = kwargs


class NullKey(APIKeyError):
"""Attempt to commit an APIkey row with an empty required column."""

def __init__(self, cause: str):
def __init__(self, cause: Exception, **kwargs):
super().__init__(f"API key null key error: '{cause}' {kwargs}")
self.cause = cause

def __str__(self) -> str:
return f"Missing required value: {self.cause}"
self.kwargs = kwargs


class APIKey(Database.Base):
Expand All @@ -68,14 +71,14 @@ def add(self):
Database.db_session.commit()
except Exception as e:
Database.db_session.rollback()
self.logger.error("Can't add {} to DB: {}", str(self), str(e))
decode_exc = decode_integrity_error(
e, on_duplicate=DuplicateApiKey, on_null=NullKey
)
if decode_exc is e:
raise APIKeyError(str(e)) from e
else:
raise decode_exc from e
raise decode_sql_error(
e,
on_duplicate=DuplicateApiKey,
on_null=NullKey,
fallback=APIKeySqlError,
operation="add",
key=self,
) from e

@staticmethod
def query(**kwargs) -> Optional["APIKey"]:
Expand All @@ -99,7 +102,7 @@ def delete(self):
Database.db_session.commit()
except Exception as e:
Database.db_session.rollback()
raise APIKeyError(f"Error deleting api_key from db : {e}") from e
raise APIKeySqlError(e, operation="delete", key=self) from e

def as_json(self) -> JSONOBJECT:
"""Return a JSON object for this APIkey object.
Expand Down
77 changes: 24 additions & 53 deletions lib/pbench/server/database/models/audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from typing import Optional

from sqlalchemy import Column, Enum, Integer, String
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.sql.sqltypes import JSON

from pbench.server import JSONOBJECT, OperationCode
from pbench.server.database.database import Database
from pbench.server.database.models import TZDateTime
from pbench.server.database.models import decode_sql_error, TZDateTime
from pbench.server.database.models.datasets import Dataset
from pbench.server.database.models.users import User

Expand All @@ -19,6 +19,8 @@ class AuditError(Exception):
It is never raised directly, but may be used in "except" clauses.
"""

pass


class AuditSqlError(AuditError):
"""SQLAlchemy errors reported through Audit operations.
Expand All @@ -27,35 +29,28 @@ class AuditSqlError(AuditError):
key; the cause will specify the original SQLAlchemy exception.
"""

def __init__(self, operation: str, params: JSONOBJECT, cause: str):
self.operation = operation
self.params = params
def __init__(self, cause: Exception, **kwargs):
super().__init__(f"SQL error on {kwargs}: '{cause}'")
self.cause = cause

def __str__(self) -> str:
return f"Error {self.operation} {self.params!r}: {self.cause}"
self.kwargs = kwargs


class AuditDuplicate(AuditError):
"""Attempt to commit a duplicate unique value."""

def __init__(self, audit: "Audit", cause: str):
self.audit = audit
def __init__(self, cause: Exception, **kwargs):
super().__init__(f"Duplicate key in {kwargs}: '{cause}'")
self.cause = cause

def __str__(self) -> str:
return f"Duplicate audit setting in {self.audit.as_json()}: {self.cause}"
self.kwargs = kwargs


class AuditNullKey(AuditError):
"""Attempt to commit an Audit row with an empty required column."""

def __init__(self, audit: "Audit", cause: str):
self.audit = audit
def __init__(self, cause: Exception, **kwargs):
super().__init__(f"Missing required key in {kwargs}: '{cause}'")
self.cause = cause

def __str__(self) -> str:
return f"Missing required key in {self.audit.as_json()}: {self.cause}"
self.kwargs = kwargs


class AuditType(enum.Enum):
Expand Down Expand Up @@ -333,50 +328,26 @@ def query(

audit = query.order_by(Audit.timestamp).all()
except SQLAlchemyError as e:
args = {
k: v
for k, v in (
("start", start),
("end", end),
("dataset", dataset),
*kwargs.items(),
)
if v is not None
}
raise AuditSqlError("finding", args, str(e)) from e
raise AuditSqlError(
e, operation="query", start=start, end=end, dataset=dataset, **kwargs
) from e
return audit

def _decode(self, exception: IntegrityError) -> Exception:
"""Decode a SQLAlchemy IntegrityError to look for a recognizable UNIQUE
or NOT NULL constraint violation.
Return the original exception if it doesn't match.
Args:
exception : An IntegrityError to decode
Returns:
a more specific exception, or the original if decoding fails
"""
# Postgres engine returns (code, message) but sqlite3 engine only
# returns (message); so always take the last element.
cause = exception.orig.args[-1]
if "UNIQUE constraint" in cause:
return AuditDuplicate(self, cause)
elif "NOT NULL constraint" in cause:
return AuditNullKey(self, cause)
return exception

def add(self):
"""Add the Audit object to the database."""
try:
Database.db_session.add(self)
Database.db_session.commit()
except Exception as e:
Database.db_session.rollback()
if isinstance(e, IntegrityError):
raise self._decode(e) from e
raise AuditSqlError("adding", self.as_json(), str(e)) from e
raise decode_sql_error(
e,
on_duplicate=AuditDuplicate,
on_null=AuditNullKey,
fallback=AuditSqlError,
operation="add",
audit=self,
) from e

def as_json(self) -> JSONOBJECT:
"""Return a JSON object for this Audit object.
Expand Down
Loading

0 comments on commit 3d9261e

Please sign in to comment.