diff --git a/cookiecutter/target-template/{{cookiecutter.target_id}}/{{cookiecutter.library_name}}/sinks.py b/cookiecutter/target-template/{{cookiecutter.target_id}}/{{cookiecutter.library_name}}/sinks.py index 4e84d1284..9edd13a11 100644 --- a/cookiecutter/target-template/{{cookiecutter.target_id}}/{{cookiecutter.library_name}}/sinks.py +++ b/cookiecutter/target-template/{{cookiecutter.target_id}}/{{cookiecutter.library_name}}/sinks.py @@ -35,6 +35,7 @@ class {{ cookiecutter.destination_name }}Connector(SQLConnector): allow_column_rename: bool = True # Whether RENAME COLUMN is supported. allow_column_alter: bool = False # Whether altering column types is supported. allow_merge_upsert: bool = False # Whether MERGE UPSERT is supported. + allow_overwrite: bool = False # Whether overwrite load method is supported. allow_temp_tables: bool = True # Whether temp tables are supported. def get_sqlalchemy_url(self, config: dict) -> str: diff --git a/samples/sample_target_sqlite/__init__.py b/samples/sample_target_sqlite/__init__.py index bd759e464..40384facf 100644 --- a/samples/sample_target_sqlite/__init__.py +++ b/samples/sample_target_sqlite/__init__.py @@ -19,6 +19,7 @@ class SQLiteConnector(SQLConnector): allow_temp_tables = False allow_column_alter = False allow_merge_upsert = True + allow_overwrite: bool = True def get_sqlalchemy_url(self, config: dict[str, t.Any]) -> str: """Generates a SQLAlchemy URL for SQLite.""" diff --git a/singer_sdk/connectors/sql.py b/singer_sdk/connectors/sql.py index fb49b3587..2693bfe2e 100644 --- a/singer_sdk/connectors/sql.py +++ b/singer_sdk/connectors/sql.py @@ -18,6 +18,7 @@ from singer_sdk import typing as th from singer_sdk._singerlib import CatalogEntry, MetadataMapping, Schema from singer_sdk.exceptions import ConfigValidationError +from singer_sdk.helpers.capabilities import TargetLoadMethods if t.TYPE_CHECKING: from sqlalchemy.engine.reflection import Inspector @@ -40,6 +41,7 @@ class SQLConnector: allow_column_rename: bool = True # Whether RENAME COLUMN is supported. allow_column_alter: bool = False # Whether altering column types is supported. allow_merge_upsert: bool = False # Whether MERGE UPSERT is supported. + allow_overwrite: bool = False # Whether overwrite load method is supported. allow_temp_tables: bool = True # Whether temp tables are supported. _cached_engine: Engine | None = None @@ -775,6 +777,16 @@ def prepare_table( as_temp_table=as_temp_table, ) return + if self.config["load_method"] == TargetLoadMethods.OVERWRITE: + self.get_table(full_table_name=full_table_name).drop(self._engine) + self.create_empty_table( + full_table_name=full_table_name, + schema=schema, + primary_keys=primary_keys, + partition_keys=partition_keys, + as_temp_table=as_temp_table, + ) + return for property_name, property_def in schema["properties"].items(): self.prepare_column( diff --git a/singer_sdk/helpers/capabilities.py b/singer_sdk/helpers/capabilities.py index f5b5fa305..690f04db5 100644 --- a/singer_sdk/helpers/capabilities.py +++ b/singer_sdk/helpers/capabilities.py @@ -108,6 +108,40 @@ ).to_dict() +class TargetLoadMethods(str, Enum): + """Target-specific capabilities.""" + + # always write all input records whether that records already exists or not + APPEND_ONLY = "append-only" + + # update existing records and insert new records + UPSERT = "upsert" + + # delete all existing records and insert all input records + OVERWRITE = "overwrite" + + +TARGET_LOAD_METHOD_CONFIG = PropertiesList( + Property( + "load_method", + StringType(), + description=( + "The method to use when loading data into the destination. " + "`append-only` will always write all input records whether that records " + "already exists or not. `upsert` will update existing records and insert " + "new records. `overwrite` will delete all existing records and insert all " + "input records." + ), + allowed_values=[ + TargetLoadMethods.APPEND_ONLY, + TargetLoadMethods.UPSERT, + TargetLoadMethods.OVERWRITE, + ], + default=TargetLoadMethods.APPEND_ONLY, + ), +).to_dict() + + class DeprecatedEnum(Enum): """Base class for capabilities enumeration.""" diff --git a/singer_sdk/target_base.py b/singer_sdk/target_base.py index a5386199f..d902f2acd 100644 --- a/singer_sdk/target_base.py +++ b/singer_sdk/target_base.py @@ -19,6 +19,7 @@ from singer_sdk.helpers.capabilities import ( ADD_RECORD_METADATA_CONFIG, BATCH_CONFIG, + TARGET_LOAD_METHOD_CONFIG, TARGET_SCHEMA_CONFIG, CapabilitiesEnum, PluginCapabilities, @@ -597,6 +598,7 @@ def _merge_missing(source_jsonschema: dict, target_jsonschema: dict) -> None: target_jsonschema["properties"][k] = v _merge_missing(ADD_RECORD_METADATA_CONFIG, config_jsonschema) + _merge_missing(TARGET_LOAD_METHOD_CONFIG, config_jsonschema) capabilities = cls.capabilities diff --git a/tests/samples/test_target_sqlite.py b/tests/samples/test_target_sqlite.py index abf1bccaf..995cb03aa 100644 --- a/tests/samples/test_target_sqlite.py +++ b/tests/samples/test_target_sqlite.py @@ -508,3 +508,48 @@ def test_hostile_to_sqlite( "hname_starts_with_number", "name_with_emoji_", } + + +def test_overwrite_load_method( + sqlite_target_test_config: dict, +): + sqlite_target_test_config["load_method"] = "overwrite" + target = SQLiteTarget(config=sqlite_target_test_config) + test_tbl = f"zzz_tmp_{str(uuid4()).split('-')[-1]}" + schema_msg = { + "type": "SCHEMA", + "stream": test_tbl, + "schema": { + "type": "object", + "properties": {"col_a": th.StringType().to_dict()}, + }, + } + + tap_output_a = "\n".join( + json.dumps(msg) + for msg in [ + schema_msg, + {"type": "RECORD", "stream": test_tbl, "record": {"col_a": "123"}}, + ] + ) + # Assert + db = sqlite3.connect(sqlite_target_test_config["path_to_db"]) + cursor = db.cursor() + + target_sync_test(target, input=StringIO(tap_output_a), finalize=True) + cursor.execute(f"SELECT col_a FROM {test_tbl} ;") # noqa: S608 + records = [res[0] for res in cursor.fetchall()] + assert records == ["123"] + + tap_output_b = "\n".join( + json.dumps(msg) + for msg in [ + schema_msg, + {"type": "RECORD", "stream": test_tbl, "record": {"col_a": "456"}}, + ] + ) + target = SQLiteTarget(config=sqlite_target_test_config) + target_sync_test(target, input=StringIO(tap_output_b), finalize=True) + cursor.execute(f"SELECT col_a FROM {test_tbl} ;") # noqa: S608 + records = [res[0] for res in cursor.fetchall()] + assert records == ["456"]