From 54b911143cd2b02a50dc240d06f9dc86482a2327 Mon Sep 17 00:00:00 2001
From: Sebastian Smiley
Date: Mon, 16 Oct 2023 23:56:08 -0400
Subject: [PATCH] Fails for test_reserved_keywords when name_conforming_strategy is enabled.

---
 target_postgres/sinks.py                      | 49 ++++++++++++++++---
 target_postgres/target.py                     | 10 ++++
 target_postgres/tests/test_standard_target.py |  1 +
 3 files changed, 53 insertions(+), 7 deletions(-)

diff --git a/target_postgres/sinks.py b/target_postgres/sinks.py
index fe1b923e..a466c766 100644
--- a/target_postgres/sinks.py
+++ b/target_postgres/sinks.py
@@ -1,5 +1,6 @@
 """Postgres target sink class, which handles writing streams."""
 
+import copy
 import uuid
 from typing import Any, Dict, Iterable, List, Optional, Union
 
@@ -23,6 +24,29 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.temp_table_name = self.generate_temp_table_name()
 
+    def conform_schema(self, schema: dict) -> dict:
+        """Return schema dictionary with property names conformed.
+
+        Override from self.conform_name(key) to self.conform_name(key, "column")
+
+        Args:
+            schema: JSON schema dictionary.
+
+        Returns:
+            A schema dictionary with the property names conformed.
+        """
+        conformed_schema = copy.copy(schema)
+        conformed_property_names = {
+            key: self.conform_name(key, "column")
+            for key in conformed_schema["properties"]
+        }
+        self._check_conformed_names_not_duplicated(conformed_property_names)
+        conformed_schema["properties"] = {
+            conformed_property_names[key]: value
+            for key, value in conformed_schema["properties"].items()
+        }
+        return conformed_schema
+
     @property
     def append_only(self) -> bool:
         """Return True if the target is append only."""
@@ -48,7 +72,7 @@ def setup(self) -> None:
         with self.connector._connect() as connection:
             self.connector.prepare_table(
                 full_table_name=self.full_table_name,
-                schema=self.schema,
+                schema=self.conform_schema(self.schema),
                 primary_keys=self.key_properties,
                 connection=connection,
                 as_temp_table=False,
@@ -63,12 +87,20 @@ def process_batch(self, context: dict) -> None:
         Args:
             context: Stream partition or context dictionary.
         """
+        records: list = []
+
+        for record in context["records"]:
+            new_record: dict = {}
+            for k, v in record.items():
+                new_record.update({self.conform_name(k, "column"): v})
+            records.append(new_record)
+
         # Use one connection so we do this all in a single transaction
         with self.connector._connect() as connection:
             # Check structure of table
             table: sqlalchemy.Table = self.connector.prepare_table(
                 full_table_name=self.full_table_name,
-                schema=self.schema,
+                schema=self.conform_schema(self.schema),
                 primary_keys=self.key_properties,
                 as_temp_table=False,
                 connection=connection,
@@ -83,16 +115,16 @@ def process_batch(self, context: dict) -> None:
             # Insert into temp table
             self.bulk_insert_records(
                 table=temp_table,
-                schema=self.schema,
+                schema=self.conform_schema(self.schema),
                 primary_keys=self.key_properties,
-                records=context["records"],
+                records=records,
                 connection=connection,
             )
             # Merge data from Temp table to main table
             self.upsert(
                 from_table=temp_table,
                 to_table=table,
-                schema=self.schema,
+                schema=self.conform_schema(self.schema),
                 join_keys=self.key_properties,
                 connection=connection,
             )
@@ -218,7 +250,7 @@ def upsert(
             # Update
             where_condition = join_condition
             update_columns = {}
-            for column_name in self.schema["properties"].keys():
+            for column_name in self.conform_schema(self.schema)["properties"].keys():
                 from_table_column: sqlalchemy.Column = from_table.columns[column_name]
                 to_table_column: sqlalchemy.Column = to_table.columns[column_name]
                 update_columns[to_table_column] = from_table_column
@@ -263,7 +295,10 @@ def generate_insert_statement(
 
     def conform_name(self, name: str, object_type: Optional[str] = None) -> str:
         """Conforming names of tables, schemas, column names."""
-        return name
+        if object_type in self.config["name_conforming_strategy"]:
+            return super().conform_name(name, object_type)
+        else:
+            return name
 
     @property
     def schema_name(self) -> Optional[str]:
diff --git a/target_postgres/target.py b/target_postgres/target.py
index 6d21947c..bf999dd6 100644
--- a/target_postgres/target.py
+++ b/target_postgres/target.py
@@ -306,6 +306,16 @@ def __init__(
             required=False,
             description="SSH Tunnel Configuration, this is a json object",
         ),
+        th.Property(
+            "name_conforming_strategy",
+            th.ArrayType(th.StringType),
+            default=[],
+            description=(
+                "If left as an empty array (the default), will not perform any name "
+                "conforming. Add `table` to the array to conform table names to snake "
+                "case. Add `column` to the array to conform column names to snake case."
+            ),
+        ),
     ).to_dict()
 
     default_sink_class = PostgresSink
diff --git a/target_postgres/tests/test_standard_target.py b/target_postgres/tests/test_standard_target.py
index 4658b1d3..e7d31be4 100644
--- a/target_postgres/tests/test_standard_target.py
+++ b/target_postgres/tests/test_standard_target.py
@@ -40,6 +40,7 @@ def postgres_config():
         "add_record_metadata": True,
         "hard_delete": False,
         "default_target_schema": "melty",
+        "name_conforming_strategy": ["table", "column"],
     }