Skip to content

Commit

Permalink
another fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
RobyBen committed Nov 23, 2024
1 parent b9de248 commit d2c6ed5
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 8 deletions.
2 changes: 1 addition & 1 deletion tap_postgres/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def get_records(self, context: Context | None) -> Iterable[dict[str, Any]]:
replication_slot_name = self.config.get("replication_slot_name", "tappostgres")

logical_replication_cursor.start_replication(
slot_name=replication_slot_name,, #use slot name
slot_name=replication_slot_name, #use slot name
decode=True,
start_lsn=start_lsn,
status_interval=status_interval,
Expand Down
9 changes: 5 additions & 4 deletions tap_postgres/tap.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,10 @@ def __init__(
):
slot_name = self.config["replication_slot_name"]
assert slot_name.isalnum() or "_" in slot_name, (
"Replication slot name must contain only letters, numbers and underscores"
"Replication slot name must contain letters, numbers and underscores"
)
assert len(slot_name) <= 63, (
max_slot_name_len = 63
assert len(slot_name) <= max_slot_name_len, (
"Replication slot name must be less than 63 characters"
)
assert not slot_name.startswith("pg_"), (
Expand Down Expand Up @@ -462,7 +463,7 @@ def connector(self) -> PostgresConnector:
url = self.ssh_tunnel_connect(ssh_config=ssh_config, url=url)

return PostgresConnector(
config=dict(self.config), #Pass the entire configuration, including replication_slot_name
config=dict(self.config),
sqlalchemy_url=url.render_as_string(hide_password=False),
)

Expand Down Expand Up @@ -662,4 +663,4 @@ def discover_streams(self) -> Sequence[Stream]:
"port": 5432,
"dbname": "example_db_2",
"replication_slot_name": "slot_2"
}
}
17 changes: 14 additions & 3 deletions tests/test_slot_name.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import unittest

from tap_postgres.tap import TapPostgres


class TestReplicationSlot(unittest.TestCase):
def setUp(self):
self.default_config = {
Expand All @@ -13,7 +15,9 @@ def test_default_slot_name(self):
# Test backward compatibility when slot name is not provided.
config = self.default_config
tap = TapPostgres(config)
self.assertEqual(tap.config.get("replication_slot_name", "tappostgres"), "tappostgres")
self.assertEqual(
tap.config.get("replication_slot_name", "tappostgres"),
"tappostgres")

def test_custom_slot_name(self):
# Test if the custom slot name is used.
Expand All @@ -29,12 +33,19 @@ def test_multiple_slots(self):
tap_1 = TapPostgres(config_1)
tap_2 = TapPostgres(config_2)

self.assertNotEqual(tap_1.config["replication_slot_name"], tap_2.config["replication_slot_name"])
self.assertNotEqual(
tap_1.config["replication_slot_name"],
tap_2.config["replication_slot_name"],
)
self.assertEqual(tap_1.config["replication_slot_name"], "slot_1")
self.assertEqual(tap_2.config["replication_slot_name"], "slot_2")

def test_invalid_slot_name(self):
# Test validation for invalid slot names (if any validation rules exist).
invalid_config = {**self.default_config, "replication_slot_name": "invalid slot name!"}
invalid_config = {
**self.default_config,
"replication_slot_name": "invalid slot name!",
}

with self.assertRaises(ValueError):
TapPostgres(invalid_config)

0 comments on commit d2c6ed5

Please sign in to comment.