From 2014d4d9ab389b51d2df14002b208bfd4130abfb Mon Sep 17 00:00:00 2001 From: pnadolny13 Date: Thu, 3 Aug 2023 15:47:11 -0400 Subject: [PATCH 01/10] initial implementation of standard load methods, sql connector implementation --- .../{{cookiecutter.library_name}}/sinks.py | 1 + samples/sample_target_sqlite/__init__.py | 1 + singer_sdk/connectors/sql.py | 41 +++++++++++++++++++ singer_sdk/helpers/capabilities.py | 33 +++++++++++++++ singer_sdk/target_base.py | 2 + 5 files changed, 78 insertions(+) diff --git a/cookiecutter/target-template/{{cookiecutter.target_id}}/{{cookiecutter.library_name}}/sinks.py b/cookiecutter/target-template/{{cookiecutter.target_id}}/{{cookiecutter.library_name}}/sinks.py index 4e84d1284..9edd13a11 100644 --- a/cookiecutter/target-template/{{cookiecutter.target_id}}/{{cookiecutter.library_name}}/sinks.py +++ b/cookiecutter/target-template/{{cookiecutter.target_id}}/{{cookiecutter.library_name}}/sinks.py @@ -35,6 +35,7 @@ class {{ cookiecutter.destination_name }}Connector(SQLConnector): allow_column_rename: bool = True # Whether RENAME COLUMN is supported. allow_column_alter: bool = False # Whether altering column types is supported. allow_merge_upsert: bool = False # Whether MERGE UPSERT is supported. + allow_overwrite: bool = False # Whether overwrite load method is supported. allow_temp_tables: bool = True # Whether temp tables are supported. def get_sqlalchemy_url(self, config: dict) -> str: diff --git a/samples/sample_target_sqlite/__init__.py b/samples/sample_target_sqlite/__init__.py index bd759e464..a494d750c 100644 --- a/samples/sample_target_sqlite/__init__.py +++ b/samples/sample_target_sqlite/__init__.py @@ -19,6 +19,7 @@ class SQLiteConnector(SQLConnector): allow_temp_tables = False allow_column_alter = False allow_merge_upsert = True + allow_overwrite: bool = False def get_sqlalchemy_url(self, config: dict[str, t.Any]) -> str: """Generates a SQLAlchemy URL for SQLite.""" diff --git a/singer_sdk/connectors/sql.py b/singer_sdk/connectors/sql.py index e9a65cf80..ae5ff7d24 100644 --- a/singer_sdk/connectors/sql.py +++ b/singer_sdk/connectors/sql.py @@ -15,6 +15,7 @@ from singer_sdk import typing as th from singer_sdk._singerlib import CatalogEntry, MetadataMapping, Schema from singer_sdk.exceptions import ConfigValidationError +from singer_sdk.helpers.capabilities import TargetLoadMethods if t.TYPE_CHECKING: from sqlalchemy.engine.reflection import Inspector @@ -37,6 +38,7 @@ class SQLConnector: allow_column_rename: bool = True # Whether RENAME COLUMN is supported. allow_column_alter: bool = False # Whether altering column types is supported. allow_merge_upsert: bool = False # Whether MERGE UPSERT is supported. + allow_overwrite: bool = False # Whether overwrite load method is supported. allow_temp_tables: bool = True # Whether temp tables are supported. _cached_engine: Engine | None = None @@ -732,6 +734,43 @@ def prepare_schema(self, schema_name: str) -> None: if not schema_exists: self.create_schema(schema_name) + @staticmethod + def get_truncate_table_ddl( + table_name: str, + ) -> sqlalchemy.DDL: + """Get the truncate table DDL statement. + + Override this if your database uses a different syntax for truncating tables. + + Args: + table_name: Fully qualified table name of column to alter. + + Returns: + A sqlalchemy DDL instance. + """ + return sqlalchemy.DDL( + "TRUNCATE TABLE %(table_name)s", + { + "table_name": table_name, + }, + ) + + def truncate_table(self, full_table_name: str) -> None: + """Truncate target table. + + Args: + full_table_name: Fully qualified table name of column to alter. + """ + if not self.allow_overwrite: + msg = "Truncating tables is not supported." + raise NotImplementedError(msg) + + truncate_table_ddl = self.get_truncate_table_ddl( + table_name=full_table_name, + ) + with self._connect() as conn: + conn.execute(truncate_table_ddl) + def prepare_table( self, full_table_name: str, @@ -758,6 +797,8 @@ def prepare_table( as_temp_table=as_temp_table, ) return + elif self.config.get("load_method") == TargetLoadMethods.OVERWRITE: + self.truncate_table(full_table_name=full_table_name) for property_name, property_def in schema["properties"].items(): self.prepare_column( diff --git a/singer_sdk/helpers/capabilities.py b/singer_sdk/helpers/capabilities.py index f5b5fa305..1c33802c9 100644 --- a/singer_sdk/helpers/capabilities.py +++ b/singer_sdk/helpers/capabilities.py @@ -108,6 +108,39 @@ ).to_dict() +class TargetLoadMethods(Enum): + """Target-specific capabilities.""" + + # always write all input records whether that records already exists or not + APPEND_ONLY = "append-only" + + # update existing records and insert new records + UPSERT = "upsert" + + # delete all existing records and insert all input records + OVERWRITE = "overwrite" + + +TARGET_LOAD_METHOD_CONFIG = PropertiesList( + Property( + "load_method", + StringType(), + description=( + "The method to use when loading data into the destination. " + "`append-only` will always write all input records whether that records " + "already exists or not. `upsert` will update existing records and insert " + "new records. `overwrite` will delete all existing records and insert all " + "input records." + ), + allowed_values=[ + TargetLoadMethods.APPEND_ONLY, + TargetLoadMethods.UPSERT, + TargetLoadMethods.OVERWRITE + ], + default=TargetLoadMethods.APPEND_ONLY, + ), +).to_dict() + class DeprecatedEnum(Enum): """Base class for capabilities enumeration.""" diff --git a/singer_sdk/target_base.py b/singer_sdk/target_base.py index a5386199f..d902f2acd 100644 --- a/singer_sdk/target_base.py +++ b/singer_sdk/target_base.py @@ -19,6 +19,7 @@ from singer_sdk.helpers.capabilities import ( ADD_RECORD_METADATA_CONFIG, BATCH_CONFIG, + TARGET_LOAD_METHOD_CONFIG, TARGET_SCHEMA_CONFIG, CapabilitiesEnum, PluginCapabilities, @@ -597,6 +598,7 @@ def _merge_missing(source_jsonschema: dict, target_jsonschema: dict) -> None: target_jsonschema["properties"][k] = v _merge_missing(ADD_RECORD_METADATA_CONFIG, config_jsonschema) + _merge_missing(TARGET_LOAD_METHOD_CONFIG, config_jsonschema) capabilities = cls.capabilities From 8691b893dd88b55c0c320438f6382fcae960188e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Aug 2023 19:54:14 +0000 Subject: [PATCH 02/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- singer_sdk/helpers/capabilities.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/singer_sdk/helpers/capabilities.py b/singer_sdk/helpers/capabilities.py index 1c33802c9..dba94787a 100644 --- a/singer_sdk/helpers/capabilities.py +++ b/singer_sdk/helpers/capabilities.py @@ -135,12 +135,13 @@ class TargetLoadMethods(Enum): allowed_values=[ TargetLoadMethods.APPEND_ONLY, TargetLoadMethods.UPSERT, - TargetLoadMethods.OVERWRITE + TargetLoadMethods.OVERWRITE, ], default=TargetLoadMethods.APPEND_ONLY, ), ).to_dict() + class DeprecatedEnum(Enum): """Base class for capabilities enumeration.""" From 92defe7cef13a049a66af1f41ba939a03172dedf Mon Sep 17 00:00:00 2001 From: pnadolny13 Date: Thu, 3 Aug 2023 16:03:18 -0400 Subject: [PATCH 03/10] enum to values --- singer_sdk/helpers/capabilities.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/singer_sdk/helpers/capabilities.py b/singer_sdk/helpers/capabilities.py index 1c33802c9..c38c4bb2e 100644 --- a/singer_sdk/helpers/capabilities.py +++ b/singer_sdk/helpers/capabilities.py @@ -133,11 +133,11 @@ class TargetLoadMethods(Enum): "input records." ), allowed_values=[ - TargetLoadMethods.APPEND_ONLY, - TargetLoadMethods.UPSERT, - TargetLoadMethods.OVERWRITE + TargetLoadMethods.APPEND_ONLY.value, + TargetLoadMethods.UPSERT.value, + TargetLoadMethods.OVERWRITE.value, ], - default=TargetLoadMethods.APPEND_ONLY, + default=TargetLoadMethods.APPEND_ONLY.value, ), ).to_dict() From 04f4dc0c7606f8d25ffad308a25f1be69597fde3 Mon Sep 17 00:00:00 2001 From: pnadolny13 Date: Thu, 3 Aug 2023 16:24:59 -0400 Subject: [PATCH 04/10] fix enum comparisons --- singer_sdk/connectors/sql.py | 2 +- singer_sdk/helpers/capabilities.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/singer_sdk/connectors/sql.py b/singer_sdk/connectors/sql.py index ae5ff7d24..0245c320f 100644 --- a/singer_sdk/connectors/sql.py +++ b/singer_sdk/connectors/sql.py @@ -797,7 +797,7 @@ def prepare_table( as_temp_table=as_temp_table, ) return - elif self.config.get("load_method") == TargetLoadMethods.OVERWRITE: + elif self.config["load_method"] == TargetLoadMethods.OVERWRITE: self.truncate_table(full_table_name=full_table_name) for property_name, property_def in schema["properties"].items(): diff --git a/singer_sdk/helpers/capabilities.py b/singer_sdk/helpers/capabilities.py index f724c92a5..690f04db5 100644 --- a/singer_sdk/helpers/capabilities.py +++ b/singer_sdk/helpers/capabilities.py @@ -108,7 +108,7 @@ ).to_dict() -class TargetLoadMethods(Enum): +class TargetLoadMethods(str, Enum): """Target-specific capabilities.""" # always write all input records whether that records already exists or not @@ -133,11 +133,11 @@ class TargetLoadMethods(Enum): "input records." ), allowed_values=[ - TargetLoadMethods.APPEND_ONLY.value, - TargetLoadMethods.UPSERT.value, - TargetLoadMethods.OVERWRITE.value, + TargetLoadMethods.APPEND_ONLY, + TargetLoadMethods.UPSERT, + TargetLoadMethods.OVERWRITE, ], - default=TargetLoadMethods.APPEND_ONLY.value, + default=TargetLoadMethods.APPEND_ONLY, ), ).to_dict() From 9510c192c2543f1e782d9aa38646fc7adeb517f9 Mon Sep 17 00:00:00 2001 From: pnadolny13 Date: Tue, 8 Aug 2023 14:17:05 -0400 Subject: [PATCH 05/10] adds test for sqlite overwrite load method --- samples/sample_target_sqlite/__init__.py | 13 ++++++- singer_sdk/connectors/sql.py | 25 +++++++------ tests/samples/test_target_sqlite.py | 45 ++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 12 deletions(-) diff --git a/samples/sample_target_sqlite/__init__.py b/samples/sample_target_sqlite/__init__.py index a494d750c..9ab0ae3f7 100644 --- a/samples/sample_target_sqlite/__init__.py +++ b/samples/sample_target_sqlite/__init__.py @@ -4,6 +4,8 @@ import typing as t +from sqlalchemy.sql import text + from singer_sdk import SQLConnector, SQLSink, SQLTarget from singer_sdk import typing as th @@ -19,12 +21,21 @@ class SQLiteConnector(SQLConnector): allow_temp_tables = False allow_column_alter = False allow_merge_upsert = True - allow_overwrite: bool = False + allow_overwrite: bool = True def get_sqlalchemy_url(self, config: dict[str, t.Any]) -> str: """Generates a SQLAlchemy URL for SQLite.""" return f"sqlite:///{config[DB_PATH_CONFIG]}" + @staticmethod + def get_truncate_table_ddl( + table_name: str, + ) -> tuple[text, dict]: + return ( + text(f"DELETE FROM {table_name}"), # noqa: S608 + {}, + ) + class SQLiteSink(SQLSink): """The Sink class for SQLite. diff --git a/singer_sdk/connectors/sql.py b/singer_sdk/connectors/sql.py index 0245c320f..285448fad 100644 --- a/singer_sdk/connectors/sql.py +++ b/singer_sdk/connectors/sql.py @@ -11,6 +11,7 @@ import sqlalchemy from sqlalchemy.engine import Engine +from sqlalchemy.sql import text from singer_sdk import typing as th from singer_sdk._singerlib import CatalogEntry, MetadataMapping, Schema @@ -737,8 +738,8 @@ def prepare_schema(self, schema_name: str) -> None: @staticmethod def get_truncate_table_ddl( table_name: str, - ) -> sqlalchemy.DDL: - """Get the truncate table DDL statement. + ) -> tuple[text, dict]: + """Get the truncate table SQL statement. Override this if your database uses a different syntax for truncating tables. @@ -746,13 +747,11 @@ def get_truncate_table_ddl( table_name: Fully qualified table name of column to alter. Returns: - A sqlalchemy DDL instance. + A tuple of the SQL statement and a dictionary of bind parameters. """ - return sqlalchemy.DDL( - "TRUNCATE TABLE %(table_name)s", - { - "table_name": table_name, - }, + return ( + text(f"TRUNCATE TABLE {table_name}"), + {}, ) def truncate_table(self, full_table_name: str) -> None: @@ -760,16 +759,20 @@ def truncate_table(self, full_table_name: str) -> None: Args: full_table_name: Fully qualified table name of column to alter. + + Raises: + NotImplementedError: if truncating tables is not supported. """ if not self.allow_overwrite: msg = "Truncating tables is not supported." raise NotImplementedError(msg) - truncate_table_ddl = self.get_truncate_table_ddl( + truncate_table_ddl, kwargs = self.get_truncate_table_ddl( table_name=full_table_name, ) with self._connect() as conn: - conn.execute(truncate_table_ddl) + conn.execute(truncate_table_ddl, **kwargs) + conn.commit() def prepare_table( self, @@ -797,7 +800,7 @@ def prepare_table( as_temp_table=as_temp_table, ) return - elif self.config["load_method"] == TargetLoadMethods.OVERWRITE: + if self.config["load_method"] == TargetLoadMethods.OVERWRITE: self.truncate_table(full_table_name=full_table_name) for property_name, property_def in schema["properties"].items(): diff --git a/tests/samples/test_target_sqlite.py b/tests/samples/test_target_sqlite.py index abf1bccaf..995cb03aa 100644 --- a/tests/samples/test_target_sqlite.py +++ b/tests/samples/test_target_sqlite.py @@ -508,3 +508,48 @@ def test_hostile_to_sqlite( "hname_starts_with_number", "name_with_emoji_", } + + +def test_overwrite_load_method( + sqlite_target_test_config: dict, +): + sqlite_target_test_config["load_method"] = "overwrite" + target = SQLiteTarget(config=sqlite_target_test_config) + test_tbl = f"zzz_tmp_{str(uuid4()).split('-')[-1]}" + schema_msg = { + "type": "SCHEMA", + "stream": test_tbl, + "schema": { + "type": "object", + "properties": {"col_a": th.StringType().to_dict()}, + }, + } + + tap_output_a = "\n".join( + json.dumps(msg) + for msg in [ + schema_msg, + {"type": "RECORD", "stream": test_tbl, "record": {"col_a": "123"}}, + ] + ) + # Assert + db = sqlite3.connect(sqlite_target_test_config["path_to_db"]) + cursor = db.cursor() + + target_sync_test(target, input=StringIO(tap_output_a), finalize=True) + cursor.execute(f"SELECT col_a FROM {test_tbl} ;") # noqa: S608 + records = [res[0] for res in cursor.fetchall()] + assert records == ["123"] + + tap_output_b = "\n".join( + json.dumps(msg) + for msg in [ + schema_msg, + {"type": "RECORD", "stream": test_tbl, "record": {"col_a": "456"}}, + ] + ) + target = SQLiteTarget(config=sqlite_target_test_config) + target_sync_test(target, input=StringIO(tap_output_b), finalize=True) + cursor.execute(f"SELECT col_a FROM {test_tbl} ;") # noqa: S608 + records = [res[0] for res in cursor.fetchall()] + assert records == ["456"] From 046044f9de808eff215681a3c8fd0009efe9fbfb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Ram=C3=ADrez=20Mondrag=C3=B3n?= Date: Tue, 8 Aug 2023 14:51:56 -0600 Subject: [PATCH 06/10] Address issues --- singer_sdk/connectors/sql.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/singer_sdk/connectors/sql.py b/singer_sdk/connectors/sql.py index fd1368565..4eec00599 100644 --- a/singer_sdk/connectors/sql.py +++ b/singer_sdk/connectors/sql.py @@ -23,6 +23,7 @@ if t.TYPE_CHECKING: from sqlalchemy.engine.reflection import Inspector + from sqlalchemy.sql.elements import TextClause class SQLConnector: @@ -746,7 +747,7 @@ def prepare_schema(self, schema_name: str) -> None: @staticmethod def get_truncate_table_ddl( table_name: str, - ) -> tuple[text, dict]: + ) -> tuple[TextClause, dict]: """Get the truncate table SQL statement. Override this if your database uses a different syntax for truncating tables. @@ -778,9 +779,8 @@ def truncate_table(self, full_table_name: str) -> None: truncate_table_ddl, kwargs = self.get_truncate_table_ddl( table_name=full_table_name, ) - with self._connect() as conn: + with self._connect() as conn, conn.begin(): conn.execute(truncate_table_ddl, **kwargs) - conn.commit() def prepare_table( self, From 42961af285910b99dcc66d082d0ab5962259cc30 Mon Sep 17 00:00:00 2001 From: pnadolny13 Date: Thu, 7 Sep 2023 15:49:37 -0400 Subject: [PATCH 07/10] drop table instead of truncating --- samples/sample_target_sqlite/__init__.py | 9 ----- singer_sdk/connectors/sql.py | 48 +++++------------------- 2 files changed, 9 insertions(+), 48 deletions(-) diff --git a/samples/sample_target_sqlite/__init__.py b/samples/sample_target_sqlite/__init__.py index 9ab0ae3f7..7781f2413 100644 --- a/samples/sample_target_sqlite/__init__.py +++ b/samples/sample_target_sqlite/__init__.py @@ -27,15 +27,6 @@ def get_sqlalchemy_url(self, config: dict[str, t.Any]) -> str: """Generates a SQLAlchemy URL for SQLite.""" return f"sqlite:///{config[DB_PATH_CONFIG]}" - @staticmethod - def get_truncate_table_ddl( - table_name: str, - ) -> tuple[text, dict]: - return ( - text(f"DELETE FROM {table_name}"), # noqa: S608 - {}, - ) - class SQLiteSink(SQLSink): """The Sink class for SQLite. diff --git a/singer_sdk/connectors/sql.py b/singer_sdk/connectors/sql.py index 4eec00599..2fff7f33f 100644 --- a/singer_sdk/connectors/sql.py +++ b/singer_sdk/connectors/sql.py @@ -744,44 +744,6 @@ def prepare_schema(self, schema_name: str) -> None: if not schema_exists: self.create_schema(schema_name) - @staticmethod - def get_truncate_table_ddl( - table_name: str, - ) -> tuple[TextClause, dict]: - """Get the truncate table SQL statement. - - Override this if your database uses a different syntax for truncating tables. - - Args: - table_name: Fully qualified table name of column to alter. - - Returns: - A tuple of the SQL statement and a dictionary of bind parameters. - """ - return ( - text(f"TRUNCATE TABLE {table_name}"), - {}, - ) - - def truncate_table(self, full_table_name: str) -> None: - """Truncate target table. - - Args: - full_table_name: Fully qualified table name of column to alter. - - Raises: - NotImplementedError: if truncating tables is not supported. - """ - if not self.allow_overwrite: - msg = "Truncating tables is not supported." - raise NotImplementedError(msg) - - truncate_table_ddl, kwargs = self.get_truncate_table_ddl( - table_name=full_table_name, - ) - with self._connect() as conn, conn.begin(): - conn.execute(truncate_table_ddl, **kwargs) - def prepare_table( self, full_table_name: str, @@ -809,7 +771,15 @@ def prepare_table( ) return if self.config["load_method"] == TargetLoadMethods.OVERWRITE: - self.truncate_table(full_table_name=full_table_name) + self.get_table(full_table_name=full_table_name).drop(self._engine) + self.create_empty_table( + full_table_name=full_table_name, + schema=schema, + primary_keys=primary_keys, + partition_keys=partition_keys, + as_temp_table=as_temp_table, + ) + return for property_name, property_def in schema["properties"].items(): self.prepare_column( From 430cebbe94c9f6d4d27d638bb3905c5d0d019fa6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 Sep 2023 19:50:04 +0000 Subject: [PATCH 08/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- samples/sample_target_sqlite/__init__.py | 2 -- singer_sdk/connectors/sql.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/samples/sample_target_sqlite/__init__.py b/samples/sample_target_sqlite/__init__.py index 7781f2413..40384facf 100644 --- a/samples/sample_target_sqlite/__init__.py +++ b/samples/sample_target_sqlite/__init__.py @@ -4,8 +4,6 @@ import typing as t -from sqlalchemy.sql import text - from singer_sdk import SQLConnector, SQLSink, SQLTarget from singer_sdk import typing as th diff --git a/singer_sdk/connectors/sql.py b/singer_sdk/connectors/sql.py index 2fff7f33f..8ed4e10f1 100644 --- a/singer_sdk/connectors/sql.py +++ b/singer_sdk/connectors/sql.py @@ -14,7 +14,6 @@ import simplejson import sqlalchemy from sqlalchemy.engine import Engine -from sqlalchemy.sql import text from singer_sdk import typing as th from singer_sdk._singerlib import CatalogEntry, MetadataMapping, Schema @@ -23,7 +22,6 @@ if t.TYPE_CHECKING: from sqlalchemy.engine.reflection import Inspector - from sqlalchemy.sql.elements import TextClause class SQLConnector: From 55b10b6fe31cb4180e9af78fc5885a47b32d4dd7 Mon Sep 17 00:00:00 2001 From: pnadolny13 Date: Thu, 7 Sep 2023 15:59:10 -0400 Subject: [PATCH 09/10] tinkering with how to backup/rollback instead of dropping in override mode --- singer_sdk/connectors/sql.py | 49 ++++++++++++++++++++++++++++- singer_sdk/sinks/sql.py | 14 +++++++++ tests/samples/test_target_sqlite.py | 20 ++++++++++++ 3 files changed, 82 insertions(+), 1 deletion(-) diff --git a/singer_sdk/connectors/sql.py b/singer_sdk/connectors/sql.py index 8ed4e10f1..fa5567263 100644 --- a/singer_sdk/connectors/sql.py +++ b/singer_sdk/connectors/sql.py @@ -742,6 +742,47 @@ def prepare_schema(self, schema_name: str) -> None: if not schema_exists: self.create_schema(schema_name) + @staticmethod + def get_alter_table_name_ddl( + existing_table_name: str, + new_table_name: str, + ) -> tuple[TextClause, dict]: + """Get the alter table name SQL statement. + + Override this if your database uses a different syntax for alter table names. + + Args: + existing_table_name: Fully qualified table name of column to alter. + new_table_name: The new fully qualified table name. + + Returns: + A tuple of the SQL statement and a dictionary of bind parameters. + """ + return ( + text(f"ALTER TABLE {existing_table_name} RENAME TO {new_table_name}"), + {}, + ) + + def alter_table_name(self, existing_table_name: str, new_full_table_name: str) -> None: + """Alter target table name. + + Args: + full_table_name: Fully qualified table name of column to alter. + + Raises: + NotImplementedError: if truncating tables is not supported. + """ + if not self.allow_overwrite: + msg = "Altering table names for overwrite load method is not supported." + raise NotImplementedError(msg) + + alter_table_name_ddl, kwargs = self.get_alter_table_name_ddl( + existing_table_name, + new_full_table_name, + ) + with self._connect() as conn, conn.begin(): + conn.execute(alter_table_name_ddl, **kwargs) + def prepare_table( self, full_table_name: str, @@ -749,6 +790,7 @@ def prepare_table( primary_keys: list[str], partition_keys: list[str] | None = None, as_temp_table: bool = False, # noqa: FBT002, FBT001 + sync_started_at: int | None = None, ) -> None: """Adapt target table to provided schema if possible. @@ -769,7 +811,8 @@ def prepare_table( ) return if self.config["load_method"] == TargetLoadMethods.OVERWRITE: - self.get_table(full_table_name=full_table_name).drop(self._engine) + self.alter_table_name(full_table_name, f"{full_table_name}_{sync_started_at}") + self.create_empty_table( full_table_name=full_table_name, schema=schema, @@ -786,6 +829,10 @@ def prepare_table( self.to_sql_type(property_def), ) + def drop_backup_table(self, full_table_name: str) -> None: + table = self.get_table(full_table_name=full_table_name) + table.drop(self._engine) + def prepare_column( self, full_table_name: str, diff --git a/singer_sdk/sinks/sql.py b/singer_sdk/sinks/sql.py index 238e83dec..aec73fc3c 100644 --- a/singer_sdk/sinks/sql.py +++ b/singer_sdk/sinks/sql.py @@ -16,6 +16,7 @@ from singer_sdk.exceptions import ConformedNameClashException from singer_sdk.helpers._conformers import replace_leading_digit from singer_sdk.sinks.batch import BatchSink +from singer_sdk.helpers.capabilities import TargetLoadMethods if t.TYPE_CHECKING: from sqlalchemy.sql import Executable @@ -242,6 +243,7 @@ def setup(self) -> None: schema=self.conform_schema(self.schema), primary_keys=self.key_properties, as_temp_table=False, + sync_started_at=self.sync_started_at ) @property @@ -264,6 +266,7 @@ def process_batch(self, context: dict) -> None: """ # If duplicates are merged, these can be tracked via # :meth:`~singer_sdk.Sink.tally_duplicate_merged()`. + # Append only or overwrite use inserts self.bulk_insert_records( full_table_name=self.full_table_name, schema=self.schema, @@ -417,5 +420,16 @@ def activate_version(self, new_version: int) -> None: with self.connector._connect() as conn, conn.begin(): conn.execute(query) + def clean_up(self) -> None: + """Clean up any resources held by the sink.""" + if ( + self.config["load_method"] == TargetLoadMethods.OVERWRITE + and self.connector.table_exists(full_table_name=f"{self.full_table_name}_{self.sync_started_at}") + ): + # if success then drop the backup table, if fail revert to old table + backup_table = f"{self.full_table_name}_{self.sync_started_at}" + self.logger.info(f"Dropping backup table {backup_table}") + self.connector.drop_backup_table(backup_table) + super().clean_up() __all__ = ["SQLSink", "SQLConnector"] diff --git a/tests/samples/test_target_sqlite.py b/tests/samples/test_target_sqlite.py index 995cb03aa..e637de255 100644 --- a/tests/samples/test_target_sqlite.py +++ b/tests/samples/test_target_sqlite.py @@ -541,6 +541,7 @@ def test_overwrite_load_method( records = [res[0] for res in cursor.fetchall()] assert records == ["123"] + # Assert overwrite works tap_output_b = "\n".join( json.dumps(msg) for msg in [ @@ -553,3 +554,22 @@ def test_overwrite_load_method( cursor.execute(f"SELECT col_a FROM {test_tbl} ;") # noqa: S608 records = [res[0] for res in cursor.fetchall()] assert records == ["456"] + + # Assert overwrite rollback works on error + tap_output_c = "\n".join( + json.dumps(msg) + for msg in [ + schema_msg, + # Type mismatch to cause an Exception + {"type": "RECORD", "stream": test_tbl, "record": {"col_a": 789}}, + ] + ) + target2 = SQLiteTarget(config=sqlite_target_test_config) + try: + target_sync_test(target2, input=StringIO(tap_output_c), finalize=True) + except Exception: + cursor.execute(f"SELECT col_a FROM {test_tbl} ;") # noqa: S608 + records = [res[0] for res in cursor.fetchall()] + assert records == ["456"] + else: + assert False, "Should have raised an exception" \ No newline at end of file From 6c2f17fc48786db571d1e29f553a3ffeea22f5e4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 Sep 2023 20:01:28 +0000 Subject: [PATCH 10/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- singer_sdk/connectors/sql.py | 8 ++++++-- singer_sdk/sinks/sql.py | 12 +++++++----- tests/samples/test_target_sqlite.py | 3 ++- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/singer_sdk/connectors/sql.py b/singer_sdk/connectors/sql.py index fa5567263..54f8ed300 100644 --- a/singer_sdk/connectors/sql.py +++ b/singer_sdk/connectors/sql.py @@ -763,7 +763,9 @@ def get_alter_table_name_ddl( {}, ) - def alter_table_name(self, existing_table_name: str, new_full_table_name: str) -> None: + def alter_table_name( + self, existing_table_name: str, new_full_table_name: str + ) -> None: """Alter target table name. Args: @@ -811,7 +813,9 @@ def prepare_table( ) return if self.config["load_method"] == TargetLoadMethods.OVERWRITE: - self.alter_table_name(full_table_name, f"{full_table_name}_{sync_started_at}") + self.alter_table_name( + full_table_name, f"{full_table_name}_{sync_started_at}" + ) self.create_empty_table( full_table_name=full_table_name, diff --git a/singer_sdk/sinks/sql.py b/singer_sdk/sinks/sql.py index aec73fc3c..d677b2dec 100644 --- a/singer_sdk/sinks/sql.py +++ b/singer_sdk/sinks/sql.py @@ -15,8 +15,8 @@ from singer_sdk.connectors import SQLConnector from singer_sdk.exceptions import ConformedNameClashException from singer_sdk.helpers._conformers import replace_leading_digit -from singer_sdk.sinks.batch import BatchSink from singer_sdk.helpers.capabilities import TargetLoadMethods +from singer_sdk.sinks.batch import BatchSink if t.TYPE_CHECKING: from sqlalchemy.sql import Executable @@ -243,7 +243,7 @@ def setup(self) -> None: schema=self.conform_schema(self.schema), primary_keys=self.key_properties, as_temp_table=False, - sync_started_at=self.sync_started_at + sync_started_at=self.sync_started_at, ) @property @@ -422,9 +422,10 @@ def activate_version(self, new_version: int) -> None: def clean_up(self) -> None: """Clean up any resources held by the sink.""" - if ( - self.config["load_method"] == TargetLoadMethods.OVERWRITE - and self.connector.table_exists(full_table_name=f"{self.full_table_name}_{self.sync_started_at}") + if self.config[ + "load_method" + ] == TargetLoadMethods.OVERWRITE and self.connector.table_exists( + full_table_name=f"{self.full_table_name}_{self.sync_started_at}" ): # if success then drop the backup table, if fail revert to old table backup_table = f"{self.full_table_name}_{self.sync_started_at}" @@ -432,4 +433,5 @@ def clean_up(self) -> None: self.connector.drop_backup_table(backup_table) super().clean_up() + __all__ = ["SQLSink", "SQLConnector"] diff --git a/tests/samples/test_target_sqlite.py b/tests/samples/test_target_sqlite.py index e637de255..1dc974f7d 100644 --- a/tests/samples/test_target_sqlite.py +++ b/tests/samples/test_target_sqlite.py @@ -572,4 +572,5 @@ def test_overwrite_load_method( records = [res[0] for res in cursor.fetchall()] assert records == ["456"] else: - assert False, "Should have raised an exception" \ No newline at end of file + msg = "Should have raised an exception" + raise AssertionError(msg)