From bdf4eb5e3b8ef79ce4c524c33015234277bbe5e5 Mon Sep 17 00:00:00 2001
From: Sourabh Gandhi
Date: Mon, 9 Dec 2024 04:45:39 +0000
Subject: [PATCH] Capture transaction-based records for log-based replication

---
 tap_mongodb/sync_strategies/oplog.py    |  98 +++++++-----
 tests/test_mongodb_oplog_transaction.py | 227 ++++++++++++++++++++++++
 2 files changed, 291 insertions(+), 34 deletions(-)
 create mode 100644 tests/test_mongodb_oplog_transaction.py

diff --git a/tap_mongodb/sync_strategies/oplog.py b/tap_mongodb/sync_strategies/oplog.py
index f39125a..9cc6a75 100644
--- a/tap_mongodb/sync_strategies/oplog.py
+++ b/tap_mongodb/sync_strategies/oplog.py
@@ -124,6 +124,43 @@ def maybe_get_session(client):
     # return an object that works with a 'with' block
     return SessionNotAvailable()
 
+def process_row(schema, row, stream, update_buffer, rows_saved, version, time_extracted):
+    """Handle a single oplog operation: emit records for inserts and deletes,
+    and buffer the _ids of updates so the docs can be fetched after the scan."""
+    row_op = row['op']
+
+    if row_op == 'i':
+        write_schema(schema, row['o'], stream)
+        record_message = common.row_to_singer_record(stream,
+                                                     row['o'],
+                                                     version,
+                                                     time_extracted)
+        singer.write_message(record_message)
+
+        rows_saved += 1
+
+    elif row_op == 'u':
+        update_buffer.add(row['o2']['_id'])
+
+    elif row_op == 'd':
+
+        # remove update from buffer if that document has been deleted
+        if row['o']['_id'] in update_buffer:
+            update_buffer.remove(row['o']['_id'])
+
+        # Delete ops only contain the _id of the row deleted
+        row['o'][SDC_DELETED_AT] = row['ts']
+
+        write_schema(schema, row['o'], stream)
+        record_message = common.row_to_singer_record(stream,
+                                                     row['o'],
+                                                     version,
+                                                     time_extracted)
+        singer.write_message(record_message)
+
+        rows_saved += 1
+
+    return (rows_saved, update_buffer)
 
 # pylint: disable=too-many-locals, too-many-branches, too-many-statements
 def sync_collection(client, stream, state, stream_projection, max_oplog_ts=None):
@@ -151,8 +188,15 @@ def sync_collection(client, stream, state, stream_projection, max_oplog_ts=None)
     start_time = time.time()
 
     oplog_query = {
-        'ts': {'$gte': oplog_ts},
-        'ns': {'$eq' : '{}.{}'.format(database_name, collection_name)}
+        '$and': [
+            {'ts': {'$gte': oplog_ts}},
+            {
+                '$or': [
+                    {'ns': '{}.{}'.format(database_name, collection_name)},
+                    {'op': 'c', 'o.applyOps.ns': '{}.{}'.format(database_name, collection_name)}
+                ]
+            }
+        ]
     }
 
     projection = transform_projection(stream_projection)
@@ -197,39 +241,25 @@ def sync_collection(client, stream, state, stream_projection, max_oplog_ts=None)
                                        stream_state['oplog_ts_inc']):
                 raise common.MongoAssertionException(
                     "Mongo is not honoring the sort ascending param")
+
+            row_namespace = row['ns']
+            if row_namespace == 'admin.$cmd':
+                # If the namespace is 'admin.$cmd', the operations were performed inside
+                # a transaction and are recorded as a single transactional applyOps entry.
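+                # Illustrative sketch of the entry shape handled here (layout taken
+                # from the MongoDB oplog applyOps format; the values are hypothetical):
+                #   {'ts': Timestamp(1733712000, 1), 'op': 'c', 'ns': 'admin.$cmd',
+                #    'o': {'applyOps': [{'op': 'i', 'ns': 'simple_db.coll', 'o': {'_id': 1}}]}}
+                # The inner applyOps docs carry no 'ts' of their own, so each one is
+                # stamped with the transaction's commit timestamp before processing.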
+                for transaction_row in row['o']['applyOps']:
+                    transaction_row['ts'] = row['ts']
+                    rows_saved, update_buffer = process_row(schema, transaction_row, stream, update_buffer,
+                                                            rows_saved, version, time_extracted)
+            else:
+                rows_saved, update_buffer = process_row(schema, row, stream, update_buffer,
+                                                        rows_saved, version, time_extracted)
 
-            row_op = row['op']
-
-            if row_op == 'i':
-                write_schema(schema, row['o'], stream)
-                record_message = common.row_to_singer_record(stream,
-                                                             row['o'],
-                                                             version,
-                                                             time_extracted)
-                singer.write_message(record_message)
-
-                rows_saved += 1
-
-            elif row_op == 'u':
-                update_buffer.add(row['o2']['_id'])
-
-            elif row_op == 'd':
-
-                # remove update from buffer if that document has been deleted
-                if row['o']['_id'] in update_buffer:
-                    update_buffer.remove(row['o']['_id'])
-
-                # Delete ops only contain the _id of the row deleted
-                row['o'][SDC_DELETED_AT] = row['ts']
-
-                write_schema(schema, row['o'], stream)
-                record_message = common.row_to_singer_record(stream,
-                                                             row['o'],
-                                                             version,
-                                                             time_extracted)
-                singer.write_message(record_message)
-
-                rows_saved += 1
 
             state = update_bookmarks(state,
                                      tap_stream_id,
diff --git a/tests/test_mongodb_oplog_transaction.py b/tests/test_mongodb_oplog_transaction.py
new file mode 100644
index 0000000..6f3e59d
--- /dev/null
+++ b/tests/test_mongodb_oplog_transaction.py
@@ -0,0 +1,227 @@
+import tap_tester.connections as connections
+import tap_tester.menagerie as menagerie
+import tap_tester.runner as runner
+import os
+import datetime
+import unittest
+import pymongo
+import string
+import random
+import time
+import re
+import bson
+from bson import ObjectId
+import singer
+from functools import reduce
+from singer import utils, metadata
+from mongodb_common import drop_all_collections, get_test_connection
+import decimal
+
+
+RECORD_COUNT = {}
+
+def random_string_generator(size=6, chars=string.ascii_uppercase + string.digits):
+    return ''.join(random.choice(chars) for x in range(size))
+
+def generate_simple_coll_docs(num_docs):
+    docs = []
+    for int_value in range(num_docs):
+        docs.append({"int_field": int_value, "string_field": random_string_generator()})
+    return docs
+
+class MongoDBOplog(unittest.TestCase):
+    table_name = 'collection_with_transaction_1'
+    def setUp(self):
+
+        if not all([os.getenv('TAP_MONGODB_HOST'),
+                    os.getenv('TAP_MONGODB_USER'),
+                    os.getenv('TAP_MONGODB_PASSWORD'),
+                    os.getenv('TAP_MONGODB_PORT'),
+                    os.getenv('TAP_MONGODB_DBNAME')]):
+            #pylint: disable=line-too-long
+            raise Exception("set TAP_MONGODB_HOST, TAP_MONGODB_USER, TAP_MONGODB_PASSWORD, TAP_MONGODB_PORT, TAP_MONGODB_DBNAME")
+
+        with get_test_connection() as client:
+            ############# Drop all dbs/collections #############
+            drop_all_collections(client)
+
+            ############# Add simple collections #############
+            client["simple_db"][self.table_name].insert_many(generate_simple_coll_docs(50))
+
+
+
+    def expected_check_streams(self):
+        return {
+            'simple_db-collection_with_transaction_1',
+        }
+
+    def expected_pks(self):
+        return {
+            self.table_name: {'_id'}
+        }
+
+    def expected_row_counts(self):
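+        # setUp seeds the collection with 50 docs, so both the discovered
+        # row-count metadata and the initial full table sync should see 50 rows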
+        return {
+            self.table_name: 50
+        }
+
+    def expected_sync_streams(self):
+        return {
+            self.table_name
+        }
+
+    def name(self):
+        return "tap_tester_mongodb_oplog_with_transaction"
+
+    def tap_name(self):
+        return "tap-mongodb"
+
+    def get_type(self):
+        return "platform.mongodb"
+
+    def get_credentials(self):
+        return {'password': os.getenv('TAP_MONGODB_PASSWORD')}
+
+    def get_properties(self):
+        return {'host' : os.getenv('TAP_MONGODB_HOST'),
+                'port' : os.getenv('TAP_MONGODB_PORT'),
+                'user' : os.getenv('TAP_MONGODB_USER'),
+                'database' : os.getenv('TAP_MONGODB_DBNAME')
+                }
+
+
+    def test_run(self):
+
+        conn_id = connections.ensure_connection(self)
+
+        #  -------------------------------
+        #  ----------- Discovery ----------
+        #  -------------------------------
+
+        # run in discovery mode
+        check_job_name = runner.run_check_mode(self, conn_id)
+
+        # verify check exit codes
+        exit_status = menagerie.get_exit_status(conn_id, check_job_name)
+        menagerie.verify_check_exit_status(self, exit_status, check_job_name)
+
+        # verify the tap discovered the right streams
+        found_catalogs = menagerie.get_catalogs(conn_id)
+
+        # assert we find the correct streams
+        self.assertEqual(self.expected_check_streams(),
+                         {c['tap_stream_id'] for c in found_catalogs})
+
+
+
+        for tap_stream_id in self.expected_check_streams():
+            found_stream = [c for c in found_catalogs if c['tap_stream_id'] == tap_stream_id][0]
+
+            # assert that the pks are correct
+            self.assertEqual(self.expected_pks()[found_stream['stream_name']],
+                             set(found_stream.get('metadata', {}).get('table-key-properties')))
+
+            # assert that the row counts are correct
+            self.assertEqual(self.expected_row_counts()[found_stream['stream_name']],
+                             found_stream.get('metadata', {}).get('row-count'))
+
+        #  -----------------------------------
+        #  ----------- Initial Full Table ---------
+        #  -----------------------------------
+        # Select the collection_with_transaction_1 stream and set its replication method to LOG_BASED
+        for stream_catalog in found_catalogs:
+            annotated_schema = menagerie.get_annotated_schema(conn_id, stream_catalog['stream_id'])
+            additional_md = [{ "breadcrumb" : [], "metadata" : {'replication-method' : 'LOG_BASED'}}]
+            selected_metadata = connections.select_catalog_and_fields_via_metadata(conn_id,
+                                                                                   stream_catalog,
+                                                                                   annotated_schema,
+                                                                                   additional_md)
+
+        # Run sync
+        sync_job_name = runner.run_sync_mode(self, conn_id)
+
+        exit_status = menagerie.get_exit_status(conn_id, sync_job_name)
+        menagerie.verify_sync_exit_status(self, exit_status, sync_job_name)
+
+
+        # grab the records that were emitted during the sync
+        records_by_stream = runner.get_records_from_target_output()
+
+        # assert that each of the streams that we synced are the ones that we expect to see
+        record_count_by_stream = runner.examine_target_output_file(self,
+                                                                   conn_id,
+                                                                   self.expected_sync_streams(),
+                                                                   self.expected_pks())
+
+        # Verify that the full table was synced
+        for tap_stream_id in self.expected_sync_streams():
+            self.assertGreaterEqual(record_count_by_stream[tap_stream_id], self.expected_row_counts()[tap_stream_id])
+
+        # Verify that we have 'initial_full_table_complete' bookmark
+        state = menagerie.get_state(conn_id)
+        first_versions = {}
+
+        for tap_stream_id in self.expected_check_streams():
+            # assert that the state has an initial_full_table_complete == True
+            self.assertTrue(state['bookmarks'][tap_stream_id]['initial_full_table_complete'])
+            # assert that there is a version bookmark in state
+            first_versions[tap_stream_id] = state['bookmarks'][tap_stream_id]['version']
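+            # the version bookmark is written by the initial full table sync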
+            self.assertIsNotNone(first_versions[tap_stream_id])
+            # Verify that we have an oplog_ts_time and oplog_ts_inc bookmark
+            self.assertIsNotNone(state['bookmarks'][tap_stream_id]['oplog_ts_time'])
+            self.assertIsNotNone(state['bookmarks'][tap_stream_id]['oplog_ts_inc'])
+
+        # Create records for the oplog sync
+        with get_test_connection() as client:
+            db = client['simple_db'][self.table_name]
+            # Insert 10 docs in a single transaction
+            with client.start_session() as session:
+                with session.start_transaction():
+                    db.insert_many([{"int_field": x, "string_field": str(x)} for x in range(51, 61)], session=session)
+
+            # Insert 10 docs in one transaction and update 5 of them in that same transaction
+            with client.start_session() as session:
+                with session.start_transaction():
+                    db.insert_many([{"int_field": x, "string_field": str(x)} for x in range(61, 71)], session=session)
+                    for x in range(61, 66):
+                        db.update_one({"string_field": str(x)}, {"$inc": {"int_field": 1}}, session=session)
+
+            # In a separate transaction, update 5 of the docs inserted by the first transaction
+            with client.start_session() as session:
+                with session.start_transaction():
+                    for x in range(51, 56):
+                        db.update_one({"string_field": str(x)}, {"$inc": {"int_field": 1}}, session=session)
+
+
+        #  -----------------------------------
+        #  ----------- Subsequent Oplog Sync ---------
+        #  -----------------------------------
+
+        # Run sync
+
+        sync_job_name = runner.run_sync_mode(self, conn_id)
+
+        exit_status = menagerie.get_exit_status(conn_id, sync_job_name)
+        menagerie.verify_sync_exit_status(self, exit_status, sync_job_name)
+
+
+        # grab the emitted messages and keep only the upserts for each stream
+        messages_by_stream = runner.get_records_from_target_output()
+        records_by_stream = {}
+        for stream_name in self.expected_sync_streams():
+            records_by_stream[stream_name] = [x for x in messages_by_stream[stream_name]['messages'] if x.get('action') == 'upsert']
+
+
+        # assert that each of the streams that we synced are the ones that we expect to see
+        record_count_by_stream = runner.examine_target_output_file(self,
+                                                                   conn_id,
+                                                                   self.expected_sync_streams(),
+                                                                   self.expected_pks())
+
+        # Verify that we got exactly 30 upsert records: 20 transactional inserts
+        # plus the 10 updated docs emitted when the update buffer is flushed
+        self.assertEqual(30,
+                         record_count_by_stream[self.table_name])