Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Standard configurable load methods #1893

Merged
merged 13 commits into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class {{ cookiecutter.destination_name }}Connector(SQLConnector):
allow_column_rename: bool = True # Whether RENAME COLUMN is supported.
allow_column_alter: bool = False # Whether altering column types is supported.
allow_merge_upsert: bool = False # Whether MERGE UPSERT is supported.
allow_overwrite: bool = False # Whether overwrite load method is supported.
allow_temp_tables: bool = True # Whether temp tables are supported.

def get_sqlalchemy_url(self, config: dict) -> str:
Expand Down
12 changes: 12 additions & 0 deletions samples/sample_target_sqlite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import typing as t

from sqlalchemy.sql import text
from sqlalchemy.sql.elements import TextClause

from singer_sdk import SQLConnector, SQLSink, SQLTarget
from singer_sdk import typing as th

Expand All @@ -19,11 +21,21 @@ class SQLiteConnector(SQLConnector):
allow_temp_tables = False
allow_column_alter = False
allow_merge_upsert = True
allow_overwrite: bool = True

def get_sqlalchemy_url(self, config: dict[str, t.Any]) -> str:
    """Build the SQLAlchemy connection URL for the configured SQLite file.

    Args:
        config: Target configuration; the value stored under
            ``DB_PATH_CONFIG`` is used as the database path.

    Returns:
        A ``sqlite:///`` URL that points at that path.
    """
    db_path = config[DB_PATH_CONFIG]
    return f"sqlite:///{db_path}"

@staticmethod
def get_truncate_table_ddl(
    table_name: str,
) -> tuple[TextClause, dict]:
    """Get the SQLite statement used to remove every row from a table.

    SQLite has no ``TRUNCATE TABLE`` statement, so an unfiltered
    ``DELETE FROM`` is issued instead.

    Args:
        table_name: Name of the table to empty.

    Returns:
        A tuple of the SQL statement and a dictionary of bind parameters.
    """
    return (
        text(f"DELETE FROM {table_name}"),  # noqa: S608
        {},
    )


class SQLiteSink(SQLSink):
"""The Sink class for SQLite.
Expand Down
44 changes: 44 additions & 0 deletions singer_sdk/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@
import simplejson
import sqlalchemy
from sqlalchemy.engine import Engine
from sqlalchemy.sql import text

from singer_sdk import typing as th
from singer_sdk._singerlib import CatalogEntry, MetadataMapping, Schema
from singer_sdk.exceptions import ConfigValidationError
from singer_sdk.helpers.capabilities import TargetLoadMethods

if t.TYPE_CHECKING:
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.elements import TextClause


class SQLConnector:
Expand All @@ -40,6 +43,7 @@
allow_column_rename: bool = True # Whether RENAME COLUMN is supported.
allow_column_alter: bool = False # Whether altering column types is supported.
allow_merge_upsert: bool = False # Whether MERGE UPSERT is supported.
allow_overwrite: bool = False # Whether overwrite load method is supported.
allow_temp_tables: bool = True # Whether temp tables are supported.
_cached_engine: Engine | None = None

Expand Down Expand Up @@ -740,6 +744,44 @@
if not schema_exists:
self.create_schema(schema_name)

@staticmethod
def get_truncate_table_ddl(
    table_name: str,
) -> tuple[TextClause, dict]:
    """Build the SQL statement that removes every row from a table.

    Override this if your database uses a different syntax for truncating tables.

    Args:
        table_name: Fully qualified name of the table to truncate.

    Returns:
        A tuple of the SQL statement and a dictionary of bind parameters.
    """
    statement = text(f"TRUNCATE TABLE {table_name}")
    bind_params: dict = {}
    return statement, bind_params

def truncate_table(self, full_table_name: str) -> None:
    """Remove all rows from the target table.

    The statement comes from ``get_truncate_table_ddl`` and is executed
    inside a transaction on a connection obtained from ``_connect``.

    Args:
        full_table_name: Fully qualified name of the table to truncate.

    Raises:
        NotImplementedError: if truncating tables is not supported.
    """
    if not self.allow_overwrite:
        msg = "Truncating tables is not supported."
        raise NotImplementedError(msg)

    truncate_table_ddl, bind_params = self.get_truncate_table_ddl(
        table_name=full_table_name,
    )
    with self._connect() as conn, conn.begin():
        # Bind parameters go in as one positional mapping: SQLAlchemy 2.0's
        # ``Connection.execute`` no longer accepts them as keyword arguments,
        # and the positional form is also valid on 1.4.
        conn.execute(truncate_table_ddl, bind_params)

def prepare_table(
self,
full_table_name: str,
Expand All @@ -766,6 +808,8 @@
as_temp_table=as_temp_table,
)
return
if self.config["load_method"] == TargetLoadMethods.OVERWRITE:
self.truncate_table(full_table_name=full_table_name)
edgarrmondragon marked this conversation as resolved.
Show resolved Hide resolved

for property_name, property_def in schema["properties"].items():
self.prepare_column(
Expand Down
34 changes: 34 additions & 0 deletions singer_sdk/helpers/capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,40 @@
).to_dict()


class TargetLoadMethods(str, Enum):
    """Strategies a target can use when loading records into its destination."""

    # Write every input record, whether or not that record already exists.
    APPEND_ONLY = "append-only"

    # Update records that already exist and insert the ones that do not.
    UPSERT = "upsert"

    # Delete all existing records, then insert every input record.
    OVERWRITE = "overwrite"


# JSON schema fragment describing the ``load_method`` target setting; the
# permitted values are the members of ``TargetLoadMethods`` in definition order.
TARGET_LOAD_METHOD_CONFIG = PropertiesList(
    Property(
        "load_method",
        StringType(),
        description=(
            "The method to use when loading data into the destination. "
            "`append-only` will always write all input records "
            "whether that records already exists or not. "
            "`upsert` will update existing records and insert new records. "
            "`overwrite` will delete all existing records and insert all input "
            "records."
        ),
        allowed_values=list(TargetLoadMethods),
        default=TargetLoadMethods.APPEND_ONLY,
    ),
).to_dict()


class DeprecatedEnum(Enum):
"""Base class for capabilities enumeration."""

Expand Down
2 changes: 2 additions & 0 deletions singer_sdk/target_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from singer_sdk.helpers.capabilities import (
ADD_RECORD_METADATA_CONFIG,
BATCH_CONFIG,
TARGET_LOAD_METHOD_CONFIG,
TARGET_SCHEMA_CONFIG,
CapabilitiesEnum,
PluginCapabilities,
Expand Down Expand Up @@ -597,6 +598,7 @@ def _merge_missing(source_jsonschema: dict, target_jsonschema: dict) -> None:
target_jsonschema["properties"][k] = v

_merge_missing(ADD_RECORD_METADATA_CONFIG, config_jsonschema)
_merge_missing(TARGET_LOAD_METHOD_CONFIG, config_jsonschema)

capabilities = cls.capabilities

Expand Down
45 changes: 45 additions & 0 deletions tests/samples/test_target_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,3 +508,48 @@ def test_hostile_to_sqlite(
"hname_starts_with_number",
"name_with_emoji_",
}


def test_overwrite_load_method(
    sqlite_target_test_config: dict,
):
    """Check that the ``overwrite`` load method replaces previously loaded rows.

    The same stream is synced twice with a different single record each time;
    after each sync the table must contain only the record from that sync.
    """
    sqlite_target_test_config["load_method"] = "overwrite"
    test_tbl = f"zzz_tmp_{str(uuid4()).split('-')[-1]}"
    schema_msg = {
        "type": "SCHEMA",
        "stream": test_tbl,
        "schema": {
            "type": "object",
            "properties": {"col_a": th.StringType().to_dict()},
        },
    }

    def sync_and_fetch(cursor: sqlite3.Cursor, value: str) -> list:
        """Sync one record holding ``value`` and return the table's col_a values."""
        tap_output = "\n".join(
            json.dumps(msg)
            for msg in [
                schema_msg,
                {"type": "RECORD", "stream": test_tbl, "record": {"col_a": value}},
            ]
        )
        # A fresh target per sync, so each run starts from table preparation.
        target = SQLiteTarget(config=sqlite_target_test_config)
        target_sync_test(target, input=StringIO(tap_output), finalize=True)
        cursor.execute(f"SELECT col_a FROM {test_tbl} ;")  # noqa: S608
        return [row[0] for row in cursor.fetchall()]

    db = sqlite3.connect(sqlite_target_test_config["path_to_db"])
    try:
        cursor = db.cursor()

        # First load: the table holds exactly the first record.
        assert sync_and_fetch(cursor, "123") == ["123"]

        # Second load: the first record is gone, only the new one remains.
        assert sync_and_fetch(cursor, "456") == ["456"]
    finally:
        # Always release the SQLite file handle, even when an assertion fails.
        db.close()