Capture transaction-based records for log-based replication
sgandhi1311 committed Dec 9, 2024
1 parent b4c73f1 commit bdf4eb5
Showing 2 changed files with 295 additions and 34 deletions.
93 changes: 59 additions & 34 deletions tap_mongodb/sync_strategies/oplog.py
@@ -8,6 +8,9 @@
from bson import timestamp
import tap_mongodb.sync_strategies.common as common


LOGGER = singer.get_logger()

@@ -124,6 +127,41 @@ def maybe_get_session(client):
# return an object that works with a 'with' block
return SessionNotAvailable()
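For context, SessionNotAvailable is defined elsewhere in oplog.py and is not part of this diff; presumably it is just a no-op context manager so callers can always write `with maybe_get_session(client) as session:`. A minimal sketch under that assumption:

class SessionNotAvailable:
    """Assumed no-op stand-in returned when the server cannot start a session."""

    def __enter__(self):
        # Nothing to hand back; callers only need something usable in a 'with' block.
        return None

    def __exit__(self, exc_type, exc_value, traceback):
        # No cleanup, and exceptions are not suppressed.
        return False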

def process_row(schema, row, stream, update_buffer, rows_saved, version, time_extracted):
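    """Emit a Singer record for one oplog operation and keep the update buffer in sync."""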
row_op = row['op']

if row_op == 'i':
write_schema(schema, row['o'], stream)
record_message = common.row_to_singer_record(stream,
row['o'],
version,
time_extracted)
singer.write_message(record_message)

rows_saved += 1

elif row_op == 'u':
update_buffer.add(row['o2']['_id'])

elif row_op == 'd':

# remove update from buffer if that document has been deleted
if row['o']['_id'] in update_buffer:
update_buffer.remove(row['o']['_id'])

# Delete ops only contain the _id of the row deleted
row['o'][SDC_DELETED_AT] = row['ts']

write_schema(schema, row['o'], stream)
record_message = common.row_to_singer_record(stream,
row['o'],
version,
time_extracted)
singer.write_message(record_message)

rows_saved += 1

return (rows_saved, update_buffer)

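The `row` documents passed to process_row are raw oplog entries. As a rough illustration (not part of the commit; field values are made up), the three op types it handles look like this:

from bson import ObjectId, timestamp

ts = timestamp.Timestamp(1733720000, 1)
doc_id = ObjectId()

# 'i' (insert): the whole document is in 'o'
insert_entry = {'op': 'i', 'ts': ts, 'ns': 'simple_db.collection_with_transaction_1',
                'o': {'_id': doc_id, 'int_field': 1, 'string_field': 'A'}}

# 'u' (update): only the _id is reliable, under 'o2'; the document is re-read later
# via the update buffer, which is why process_row only records the _id here
update_entry = {'op': 'u', 'ts': ts, 'ns': 'simple_db.collection_with_transaction_1',
                'o2': {'_id': doc_id}, 'o': {'$set': {'int_field': 2}}}

# 'd' (delete): 'o' carries just the _id, so the record is emitted with SDC_DELETED_AT set
delete_entry = {'op': 'd', 'ts': ts, 'ns': 'simple_db.collection_with_transaction_1',
                'o': {'_id': doc_id}}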
# pylint: disable=too-many-locals, too-many-branches, too-many-statements
def sync_collection(client, stream, state, stream_projection, max_oplog_ts=None):
Expand Down Expand Up @@ -151,8 +189,15 @@ def sync_collection(client, stream, state, stream_projection, max_oplog_ts=None)
start_time = time.time()

oplog_query = {
'ts': {'$gte': oplog_ts},
'ns': {'$eq' : '{}.{}'.format(database_name, collection_name)}
'$and': [
{'ts': {'$gte': oplog_ts}},
{
'$or': [
{'ns': '{}.{}'.format(database_name, collection_name)},
{'op': 'c', 'o.applyOps.ns': '{}.{}'.format(database_name, collection_name)}
]
}
]
}

projection = transform_projection(stream_projection)
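The widened oplog_query above is needed because writes made inside a multi-document transaction are not logged under the collection's own namespace; the whole transaction is recorded as a single applyOps command entry (op 'c', namespace 'admin.$cmd'). The $or branch on 'o.applyOps.ns' picks those entries up, and the loop added in the next hunk flattens them back into per-document rows. An illustrative (made-up) entry of that shape:

from bson import ObjectId, timestamp

transaction_entry = {
    'op': 'c',
    'ns': 'admin.$cmd',
    'ts': timestamp.Timestamp(1733720000, 7),   # the inner ops inherit this ts
    'o': {
        'applyOps': [
            {'op': 'i', 'ns': 'simple_db.collection_with_transaction_1',
             'o': {'_id': ObjectId(), 'int_field': 51, 'string_field': '51'}},
            {'op': 'u', 'ns': 'simple_db.collection_with_transaction_1',
             'o2': {'_id': ObjectId()}, 'o': {'$set': {'int_field': 52}}},
        ]
    },
}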
@@ -197,39 +242,19 @@ def sync_collection(client, stream, state, stream_projection, max_oplog_ts=None)
stream_state['oplog_ts_inc']):
raise common.MongoAssertionException(
"Mongo is not honoring the sort ascending param")

row_namespace = row['ns']
if row_namespace == 'admin.$cmd':
# If the namespace is 'admin.$cmd', then the operation on the record was performed as part
# of a transaction and is recorded as a transactional applyOps entry.
for transaction_row in row['o']['applyOps']:
transaction_row['ts'] = row['ts']
rows_saved, update_buffer = process_row(schema, transaction_row, stream, update_buffer,
rows_saved, version, time_extracted)
else:
rows_saved, update_buffer = process_row(schema, row, stream, update_buffer,
rows_saved, version, time_extracted)

row_op = row['op']

if row_op == 'i':
write_schema(schema, row['o'], stream)
record_message = common.row_to_singer_record(stream,
row['o'],
version,
time_extracted)
singer.write_message(record_message)

rows_saved += 1

elif row_op == 'u':
update_buffer.add(row['o2']['_id'])

elif row_op == 'd':

# remove update from buffer if that document has been deleted
if row['o']['_id'] in update_buffer:
update_buffer.remove(row['o']['_id'])

# Delete ops only contain the _id of the row deleted
row['o'][SDC_DELETED_AT] = row['ts']

write_schema(schema, row['o'], stream)
record_message = common.row_to_singer_record(stream,
row['o'],
version,
time_extracted)
singer.write_message(record_message)

rows_saved += 1

state = update_bookmarks(state,
tap_stream_id,
236 changes: 236 additions & 0 deletions tests/test_mongodb_oplog_transaction.py
@@ -0,0 +1,236 @@
import tap_tester.connections as connections
import tap_tester.menagerie as menagerie
import tap_tester.runner as runner
import os
import datetime
import unittest
import pymongo
import string
import random
import time
import re
import pprint
import bson
from bson import ObjectId
import singer
from functools import reduce
from singer import utils, metadata
from mongodb_common import drop_all_collections, get_test_connection
import decimal


RECORD_COUNT = {}


def random_string_generator(size=6, chars=string.ascii_uppercase + string.digits):
return ''.join(random.choice(chars) for x in range(size))

def generate_simple_coll_docs(num_docs):
docs = []
for int_value in range(num_docs):
docs.append({"int_field": int_value, "string_field": random_string_generator()})
return docs

class MongoDBOplog(unittest.TestCase):
table_name = 'collection_with_transaction_1'
def setUp(self):

if not all([x for x in [os.getenv('TAP_MONGODB_HOST'),
os.getenv('TAP_MONGODB_USER'),
os.getenv('TAP_MONGODB_PASSWORD'),
os.getenv('TAP_MONGODB_PORT'),
os.getenv('TAP_MONGODB_DBNAME')]]):
#pylint: disable=line-too-long
raise Exception("set TAP_MONGODB_HOST, TAP_MONGODB_USER, TAP_MONGODB_PASSWORD, TAP_MONGODB_PORT, TAP_MONGODB_DBNAME")

with get_test_connection() as client:
############# Drop all dbs/collections #############
drop_all_collections(client)

############# Add simple collections #############
client["simple_db"][self.table_name].insert_many(generate_simple_coll_docs(50))



def expected_check_streams(self):
return {
'simple_db-collection_with_transaction_1',
}

def expected_pks(self):
return {
self.table_name: {'_id'}
}

def expected_row_counts(self):
return {
self.table_name: 50
}

def expected_sync_streams(self):
return {
self.table_name
}

def name(self):
return "tap_tester_mongodb_oplog_with_transaction"

def tap_name(self):
return "tap-mongodb"

def get_type(self):
return "platform.mongodb"

def get_credentials(self):
return {'password': os.getenv('TAP_MONGODB_PASSWORD')}

def get_properties(self):
return {'host' : os.getenv('TAP_MONGODB_HOST'),
'port' : os.getenv('TAP_MONGODB_PORT'),
'user' : os.getenv('TAP_MONGODB_USER'),
'database' : os.getenv('TAP_MONGODB_DBNAME')
}


def test_run(self):

conn_id = connections.ensure_connection(self)

# -------------------------------
# ----------- Discovery ----------
# -------------------------------

# run in discovery mode
check_job_name = runner.run_check_mode(self, conn_id)

# verify check exit codes
exit_status = menagerie.get_exit_status(conn_id, check_job_name)
menagerie.verify_check_exit_status(self, exit_status, check_job_name)

# verify the tap discovered the right streams
found_catalogs = menagerie.get_catalogs(conn_id)

# assert we find the correct streams
self.assertEqual(self.expected_check_streams(),
{c['tap_stream_id'] for c in found_catalogs})



for tap_stream_id in self.expected_check_streams():
found_stream = [c for c in found_catalogs if c['tap_stream_id'] == tap_stream_id][0]

# assert that the pks are correct
self.assertEqual(self.expected_pks()[found_stream['stream_name']],
set(found_stream.get('metadata', {}).get('table-key-properties')))

# assert that the row counts are correct
self.assertEqual(self.expected_row_counts()[found_stream['stream_name']],
found_stream.get('metadata', {}).get('row-count'))

# -----------------------------------
# ----------- Initial Full Table ---------
# -----------------------------------
# Select the collection_with_transaction_1 stream and add replication-method metadata
for stream_catalog in found_catalogs:
annotated_schema = menagerie.get_annotated_schema(conn_id, stream_catalog['stream_id'])
additional_md = [{ "breadcrumb" : [], "metadata" : {'replication-method' : 'LOG_BASED'}}]
selected_metadata = connections.select_catalog_and_fields_via_metadata(conn_id,
stream_catalog,
annotated_schema,
additional_md)

# Run sync
sync_job_name = runner.run_sync_mode(self, conn_id)

exit_status = menagerie.get_exit_status(conn_id, sync_job_name)
menagerie.verify_sync_exit_status(self, exit_status, sync_job_name)


# verify the persisted schema was correct
records_by_stream = runner.get_records_from_target_output()

# assert that each of the streams that we synced are the ones that we expect to see
record_count_by_stream = runner.examine_target_output_file(self,
conn_id,
self.expected_sync_streams(),
self.expected_pks())

# Verify that the full table was synced
for tap_stream_id in self.expected_sync_streams():
self.assertGreaterEqual(record_count_by_stream[tap_stream_id],self.expected_row_counts()[tap_stream_id])

# Verify that we have 'initial_full_table_complete' bookmark
state = menagerie.get_state(conn_id)
first_versions = {}

for tap_stream_id in self.expected_check_streams():
# assert that the state has an initial_full_table_complete == True
self.assertTrue(state['bookmarks'][tap_stream_id]['initial_full_table_complete'])
# assert that there is a version bookmark in state
first_versions[tap_stream_id] = state['bookmarks'][tap_stream_id]['version']
self.assertIsNotNone(first_versions[tap_stream_id])
# Verify that we have a oplog_ts_time and oplog_ts_inc bookmark
self.assertIsNotNone(state['bookmarks'][tap_stream_id]['oplog_ts_time'])
self.assertIsNotNone(state['bookmarks'][tap_stream_id]['oplog_ts_inc'])

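The bookmark keys asserted above suggest state of roughly this shape after the initial full-table sync (values are made up for illustration):

state = {
    'bookmarks': {
        'simple_db-collection_with_transaction_1': {
            'initial_full_table_complete': True,
            'version': 1733720000000,
            'oplog_ts_time': 1733720000,
            'oplog_ts_inc': 3,
        }
    }
}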
# Create records for oplog sync
with get_test_connection() as client:
db = client['simple_db'][self.table_name]
with client.start_session() as session:
with session.start_transaction():
db.insert_many([{"int_field": x, "string_field": str(x)} for x in range(51, 61)], session=session)

# Insert 10 docs in one transaction, update 5 of them
with client.start_session() as session:
with session.start_transaction():
db.insert_many([{"int_field": x, "string_field": str(x)} for x in range(61, 71)], session=session)
for x in range(61, 66):
db.update_one({"string_field": str(x)}, {"$inc": {"int_field": 1}}, session=session)

# In a separate transaction, update 5 of the docs inserted by the first transaction
with client.start_session() as session:
with session.start_transaction():
for x in range(51, 56):
db.update_one({"string_field": str(x)}, {"$inc": {"int_field": 1}}, session=session)


# -----------------------------------
# ----------- Subsequent Oplog Sync ---------
# -----------------------------------

# Run sync

sync_job_name = runner.run_sync_mode(self, conn_id)

exit_status = menagerie.get_exit_status(conn_id, sync_job_name)
menagerie.verify_sync_exit_status(self, exit_status, sync_job_name)


# verify the persisted schema was correct
messages_by_stream = runner.get_records_from_target_output()
records_by_stream = {}
for stream_name in self.expected_sync_streams():
records_by_stream[stream_name] = [x for x in messages_by_stream[stream_name]['messages'] if x.get('action') == 'upsert']


# assert that each of the streams that we synced are the ones that we expect to see
record_count_by_stream = runner.examine_target_output_file(self,
conn_id,
self.expected_sync_streams(),
self.expected_pks())

# Verify that the oplog sync emitted exactly 30 upsert records for the transactional changes
self.assertEqual(30,
record_count_by_stream[self.table_name])

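The final assertion expects exactly 30 upsert records from the oplog run. Presumably that breaks down as the 20 transactional inserts plus the 10 distinct updated documents re-emitted when the update buffer is flushed; the arithmetic below is an inference from the test data, not stated in the commit:

inserted = 10 + 10                   # two insert_many transactions (51-60 and 61-70)
updated_ids = 5 + 5                  # distinct _ids touched by the two update transactions
expected = inserted + updated_ids    # 30 records in total
assert expected == 30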