Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: backup/rollback instead of dropping in override mode #1941

Closed
wants to merge 14 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class {{ cookiecutter.destination_name }}Connector(SQLConnector):
allow_column_rename: bool = True # Whether RENAME COLUMN is supported.
allow_column_alter: bool = False # Whether altering column types is supported.
allow_merge_upsert: bool = False # Whether MERGE UPSERT is supported.
allow_overwrite: bool = False # Whether overwrite load method is supported.
allow_temp_tables: bool = True # Whether temp tables are supported.

def get_sqlalchemy_url(self, config: dict) -> str:
Expand Down
1 change: 1 addition & 0 deletions samples/sample_target_sqlite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class SQLiteConnector(SQLConnector):
allow_temp_tables = False
allow_column_alter = False
allow_merge_upsert = True
allow_overwrite: bool = True

def get_sqlalchemy_url(self, config: dict[str, t.Any]) -> str:
"""Generates a SQLAlchemy URL for SQLite."""
Expand Down
63 changes: 63 additions & 0 deletions singer_sdk/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from singer_sdk import typing as th
from singer_sdk._singerlib import CatalogEntry, MetadataMapping, Schema
from singer_sdk.exceptions import ConfigValidationError
from singer_sdk.helpers.capabilities import TargetLoadMethods

if t.TYPE_CHECKING:
from sqlalchemy.engine.reflection import Inspector
Expand All @@ -40,6 +41,7 @@ class SQLConnector:
allow_column_rename: bool = True # Whether RENAME COLUMN is supported.
allow_column_alter: bool = False # Whether altering column types is supported.
allow_merge_upsert: bool = False # Whether MERGE UPSERT is supported.
allow_overwrite: bool = False # Whether overwrite load method is supported.
allow_temp_tables: bool = True # Whether temp tables are supported.
_cached_engine: Engine | None = None

Expand Down Expand Up @@ -740,13 +742,57 @@ def prepare_schema(self, schema_name: str) -> None:
if not schema_exists:
self.create_schema(schema_name)

@staticmethod
def get_alter_table_name_ddl(
existing_table_name: str,
new_table_name: str,
) -> tuple[TextClause, dict]:
"""Get the alter table name SQL statement.

Override this if your database uses a different syntax for alter table names.

Args:
existing_table_name: Fully qualified table name of column to alter.
new_table_name: The new fully qualified table name.

Returns:
A tuple of the SQL statement and a dictionary of bind parameters.
"""
return (
text(f"ALTER TABLE {existing_table_name} RENAME TO {new_table_name}"),
{},
)

def alter_table_name(
self, existing_table_name: str, new_full_table_name: str
) -> None:
"""Alter target table name.

Args:
full_table_name: Fully qualified table name of column to alter.

Raises:
NotImplementedError: if truncating tables is not supported.
"""
if not self.allow_overwrite:
msg = "Altering table names for overwrite load method is not supported."
raise NotImplementedError(msg)

alter_table_name_ddl, kwargs = self.get_alter_table_name_ddl(
existing_table_name,
new_full_table_name,
)
with self._connect() as conn, conn.begin():
conn.execute(alter_table_name_ddl, **kwargs)

def prepare_table(
self,
full_table_name: str,
schema: dict,
primary_keys: list[str],
partition_keys: list[str] | None = None,
as_temp_table: bool = False, # noqa: FBT002, FBT001
sync_started_at: int | None = None,
) -> None:
"""Adapt target table to provided schema if possible.

Expand All @@ -766,6 +812,19 @@ def prepare_table(
as_temp_table=as_temp_table,
)
return
if self.config["load_method"] == TargetLoadMethods.OVERWRITE:
self.alter_table_name(
full_table_name, f"{full_table_name}_{sync_started_at}"
)

self.create_empty_table(
full_table_name=full_table_name,
schema=schema,
primary_keys=primary_keys,
partition_keys=partition_keys,
as_temp_table=as_temp_table,
)
return

for property_name, property_def in schema["properties"].items():
self.prepare_column(
Expand All @@ -774,6 +833,10 @@ def prepare_table(
self.to_sql_type(property_def),
)

def drop_backup_table(self, full_table_name: str) -> None:
table = self.get_table(full_table_name=full_table_name)
table.drop(self._engine)

def prepare_column(
self,
full_table_name: str,
Expand Down
34 changes: 34 additions & 0 deletions singer_sdk/helpers/capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,40 @@
).to_dict()


class TargetLoadMethods(str, Enum):
"""Target-specific capabilities."""

# always write all input records whether that records already exists or not
APPEND_ONLY = "append-only"

# update existing records and insert new records
UPSERT = "upsert"

# delete all existing records and insert all input records
OVERWRITE = "overwrite"


TARGET_LOAD_METHOD_CONFIG = PropertiesList(
Property(
"load_method",
StringType(),
description=(
"The method to use when loading data into the destination. "
"`append-only` will always write all input records whether that records "
"already exists or not. `upsert` will update existing records and insert "
"new records. `overwrite` will delete all existing records and insert all "
"input records."
),
allowed_values=[
TargetLoadMethods.APPEND_ONLY,
TargetLoadMethods.UPSERT,
TargetLoadMethods.OVERWRITE,
],
default=TargetLoadMethods.APPEND_ONLY,
),
).to_dict()


class DeprecatedEnum(Enum):
"""Base class for capabilities enumeration."""

Expand Down
16 changes: 16 additions & 0 deletions singer_sdk/sinks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from singer_sdk.connectors import SQLConnector
from singer_sdk.exceptions import ConformedNameClashException
from singer_sdk.helpers._conformers import replace_leading_digit
from singer_sdk.helpers.capabilities import TargetLoadMethods
from singer_sdk.sinks.batch import BatchSink

if t.TYPE_CHECKING:
Expand Down Expand Up @@ -242,6 +243,7 @@ def setup(self) -> None:
schema=self.conform_schema(self.schema),
primary_keys=self.key_properties,
as_temp_table=False,
sync_started_at=self.sync_started_at,
)

@property
Expand All @@ -264,6 +266,7 @@ def process_batch(self, context: dict) -> None:
"""
# If duplicates are merged, these can be tracked via
# :meth:`~singer_sdk.Sink.tally_duplicate_merged()`.
# Append only or overwrite use inserts
self.bulk_insert_records(
full_table_name=self.full_table_name,
schema=self.schema,
Expand Down Expand Up @@ -417,5 +420,18 @@ def activate_version(self, new_version: int) -> None:
with self.connector._connect() as conn, conn.begin():
conn.execute(query)

def clean_up(self) -> None:
"""Clean up any resources held by the sink."""
if self.config[
"load_method"
] == TargetLoadMethods.OVERWRITE and self.connector.table_exists(
full_table_name=f"{self.full_table_name}_{self.sync_started_at}"
):
# if success then drop the backup table, if fail revert to old table
backup_table = f"{self.full_table_name}_{self.sync_started_at}"
self.logger.info(f"Dropping backup table {backup_table}")
self.connector.drop_backup_table(backup_table)
super().clean_up()


__all__ = ["SQLSink", "SQLConnector"]
2 changes: 2 additions & 0 deletions singer_sdk/target_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from singer_sdk.helpers.capabilities import (
ADD_RECORD_METADATA_CONFIG,
BATCH_CONFIG,
TARGET_LOAD_METHOD_CONFIG,
TARGET_SCHEMA_CONFIG,
CapabilitiesEnum,
PluginCapabilities,
Expand Down Expand Up @@ -597,6 +598,7 @@ def _merge_missing(source_jsonschema: dict, target_jsonschema: dict) -> None:
target_jsonschema["properties"][k] = v

_merge_missing(ADD_RECORD_METADATA_CONFIG, config_jsonschema)
_merge_missing(TARGET_LOAD_METHOD_CONFIG, config_jsonschema)

capabilities = cls.capabilities

Expand Down
66 changes: 66 additions & 0 deletions tests/samples/test_target_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,3 +508,69 @@ def test_hostile_to_sqlite(
"hname_starts_with_number",
"name_with_emoji_",
}


def test_overwrite_load_method(
sqlite_target_test_config: dict,
):
sqlite_target_test_config["load_method"] = "overwrite"
target = SQLiteTarget(config=sqlite_target_test_config)
test_tbl = f"zzz_tmp_{str(uuid4()).split('-')[-1]}"
schema_msg = {
"type": "SCHEMA",
"stream": test_tbl,
"schema": {
"type": "object",
"properties": {"col_a": th.StringType().to_dict()},
},
}

tap_output_a = "\n".join(
json.dumps(msg)
for msg in [
schema_msg,
{"type": "RECORD", "stream": test_tbl, "record": {"col_a": "123"}},
]
)
# Assert
db = sqlite3.connect(sqlite_target_test_config["path_to_db"])
cursor = db.cursor()

target_sync_test(target, input=StringIO(tap_output_a), finalize=True)
cursor.execute(f"SELECT col_a FROM {test_tbl} ;") # noqa: S608
records = [res[0] for res in cursor.fetchall()]
assert records == ["123"]

# Assert overwrite works
tap_output_b = "\n".join(
json.dumps(msg)
for msg in [
schema_msg,
{"type": "RECORD", "stream": test_tbl, "record": {"col_a": "456"}},
]
)
target = SQLiteTarget(config=sqlite_target_test_config)
target_sync_test(target, input=StringIO(tap_output_b), finalize=True)
cursor.execute(f"SELECT col_a FROM {test_tbl} ;") # noqa: S608
records = [res[0] for res in cursor.fetchall()]
assert records == ["456"]

# Assert overwrite rollback works on error
tap_output_c = "\n".join(
json.dumps(msg)
for msg in [
schema_msg,
# Type mismatch to cause an Exception
{"type": "RECORD", "stream": test_tbl, "record": {"col_a": 789}},
]
)
target2 = SQLiteTarget(config=sqlite_target_test_config)
try:
target_sync_test(target2, input=StringIO(tap_output_c), finalize=True)
except Exception:
cursor.execute(f"SELECT col_a FROM {test_tbl} ;") # noqa: S608
records = [res[0] for res in cursor.fetchall()]
assert records == ["456"]
else:
msg = "Should have raised an exception"
raise AssertionError(msg)
Loading