Skip to content

Commit

Permalink
Merge pull request #585 from roy0424/feature/format-description-event
Browse files Browse the repository at this point in the history
  • Loading branch information
sean-k1 authored Dec 4, 2023
2 parents 78f2114 + 59202a2 commit 579277e
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 9 deletions.
13 changes: 6 additions & 7 deletions pymysqlreplication/binlogstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ def __init__(
else:
self.pymysql_wrapper = pymysql.connect
self.mysql_version = (0, 0, 0)
self.dbms = None

def close(self):
if self.__connected_stream:
Expand Down Expand Up @@ -748,14 +749,12 @@ def _allowed_event_list(
def __get_dbms(self):
if not self.__connected_ctl:
self.__connect_to_ctl()

cur = self._ctl_connection.cursor()
cur.execute("SELECT VERSION();")

version_info = cur.fetchone().get("VERSION()", "")

if "MariaDB" in version_info:
if self.dbms:
return self.dbms
if "MariaDB" in self._ctl_connection.get_server_info():
self.dbms = "mariadb"
return "mariadb"
self.dbms = "mysql"
return "mysql"

def __log_valid_parameters(self):
Expand Down
19 changes: 18 additions & 1 deletion pymysqlreplication/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
self._processed = True
self.complete = True
self._verify_event()
self.dbms = self._ctl_connection._get_dbms()

def _read_table_id(self):
# Table ID is 6 byte
Expand Down Expand Up @@ -368,10 +369,26 @@ def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs)
self.mysql_version_str = self.packet.read(50).rstrip(b"\0").decode()
numbers = self.mysql_version_str.split("-")[0]
self.mysql_version = tuple(map(int, numbers.split(".")))
self.created = struct.unpack("<I", self.packet.read(4))[0]
self.common_header_len = struct.unpack("<B", self.packet.read(1))[0]
offset = (
4 + 2 + 50 + 1
) # created + binlog_version + mysql_version_str + common_header_len
checksum_algorithm = 1
checksum = 4
n = event_size - offset - self.common_header_len - checksum_algorithm - checksum
self.post_header_len = struct.unpack(f"<{n}B", self.packet.read(n))
self.server_version_split = struct.unpack("<3B", self.packet.read(3))
self.number_of_event_types = struct.unpack("<B", self.packet.read(1))[0]

def _dump(self):
print(f"Binlog version: {self.binlog_version}")
print(f"MySQL version: {self.mysql_version_str}")
print(f"mysql version: {self.mysql_version_str}")
print(f"Created: {self.created}")
print(f"Common header length: {self.common_header_len}")
print(f"Post header length: {self.post_header_len}")
print(f"Server version split: {self.server_version_split}")
print(f"Number of event types: {self.number_of_event_types}")


class StopEvent(BinLogEvent):
Expand Down
2 changes: 1 addition & 1 deletion pymysqlreplication/row_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs)
self.column_count = self.packet.read_length_coded_binary()

self.columns = []
self.dbms = self._ctl_connection._get_dbms()
self.dbms = self.dbms or self._ctl_connection._get_dbms()
# Read columns meta data
column_types = bytearray(self.packet.read(self.column_count))
self.packet.read_length_coded_binary()
Expand Down
50 changes: 50 additions & 0 deletions pymysqlreplication/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,32 @@ def test_json_update(self):
),
self.assertEqual(event.rows[0]["after_values"]["setting"], {b"btn": True}),

def test_format_description_event(self):
self.stream.close()
self.stream = BinLogStreamReader(
self.database,
server_id=1024,
blocking=False,
only_events=[FormatDescriptionEvent],
)

event = self.stream.fetchone()
self.assertIsInstance(event, FormatDescriptionEvent)
self.assertIsInstance(event.binlog_version, tuple)
self.assertIsInstance(event.mysql_version_str, str)
self.assertTrue(
event.mysql_version_str.startswith("5.")
or event.mysql_version_str.startswith("8.")
) # Example check
self.assertIsInstance(event.common_header_len, int)
self.assertIsInstance(event.post_header_len, tuple)
self.assertIsInstance(event.mysql_version, tuple)
self.assertEqual(len(event.mysql_version), 3)
self.assertEqual(event.dbms, "mysql")
self.assertIsInstance(event.server_version_split, tuple)
self.assertEqual(len(event.server_version_split), 3)
self.assertIsInstance(event.number_of_event_types, int)


class TestMultipleRowBinLogStreamReader(base.PyMySQLReplicationTestCase):
def setUp(self):
Expand Down Expand Up @@ -1485,6 +1511,30 @@ def test_gtid_list_event(self):
self.assertEqual(event.event_type, 163)
self.assertEqual(event.gtid_list[0].gtid, "0-1-15")

def test_format_description_event(self):
self.stream.close()
self.stream = BinLogStreamReader(
self.database,
server_id=1024,
blocking=False,
only_events=[FormatDescriptionEvent],
is_mariadb=True,
)

event = self.stream.fetchone()
self.assertIsInstance(event, FormatDescriptionEvent)
self.assertIsInstance(event.binlog_version, tuple)
self.assertIsInstance(event.mysql_version_str, str)
self.assertTrue(event.mysql_version_str.startswith("10."))
self.assertIsInstance(event.common_header_len, int)
self.assertIsInstance(event.post_header_len, tuple)
self.assertIsInstance(event.mysql_version, tuple)
self.assertEqual(len(event.mysql_version), 3)
self.assertEqual(event.dbms, "mariadb")
self.assertIsInstance(event.server_version_split, tuple)
self.assertEqual(len(event.server_version_split), 3)
self.assertIsInstance(event.number_of_event_types, int)


class TestRowsQueryLogEvents(base.PyMySQLReplicationTestCase):
def setUp(self):
Expand Down

0 comments on commit 579277e

Please sign in to comment.