Skip to content

Commit

Permalink
Add integration and unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sylwiaszunejko committed Jan 5, 2024
1 parent 2b5e408 commit ff00514
Show file tree
Hide file tree
Showing 7 changed files with 262 additions and 8 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,11 @@ jobs:
- name: Test with pytest
run: |
export EVENT_LOOP_MANAGER=${{ matrix.event_loop_manager }}
export SCYLLA_VERSION='release:5.1'
./ci/run_integration_test.sh tests/integration/standard/ tests/integration/cqlengine/
- name: Test tablets
run: |
export EVENT_LOOP_MANAGER=${{ matrix.event_loop_manager }}
export SCYLLA_VERSION='unstable/master:2024-01-03T08:06:57Z'
./ci/run_integration_test.sh tests/integration/experiments/
5 changes: 1 addition & 4 deletions ci/run_integration_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ if (( aio_max_nr != aio_max_nr_recommended_value )); then
fi
fi

SCYLLA_RELEASE='release:5.1'

python3 -m venv .test-venv
source .test-venv/bin/activate
pip install -U pip wheel setuptools
Expand All @@ -33,12 +31,11 @@ pip install https://github.com/scylladb/scylla-ccm/archive/master.zip

# download version

ccm create scylla-driver-temp -n 1 --scylla --version ${SCYLLA_RELEASE}
ccm create scylla-driver-temp -n 1 --scylla --version ${SCYLLA_VERSION}
ccm remove

# run test

export SCYLLA_VERSION=${SCYLLA_RELEASE}
export MAPPED_SCYLLA_VERSION=3.11.4
PROTOCOL_VERSION=4 pytest -rf --import-mode append $*

10 changes: 7 additions & 3 deletions tests/integration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,8 @@ def _id_and_mark(f):
# 1. unittest doesn't skip setUpClass when used on class and we need it sometimes
# 2. unittest doesn't have conditional xfail, and I prefer to use pytest than custom decorator
# 3. unittest doesn't have a reason argument, so you don't see the reason in pytest report
requires_collection_indexes = pytest.mark.skipif(SCYLLA_VERSION is not None and Version(SCYLLA_VERSION.split(':')[1]) < Version('5.2'),
# TODO remove second check when we stop using unstable version in CI for tablets
requires_collection_indexes = pytest.mark.skipif(SCYLLA_VERSION is not None and (len(SCYLLA_VERSION.split('/')) != 0 or Version(SCYLLA_VERSION.split(':')[1]) < Version('5.2')),
reason='Scylla supports collection indexes from 5.2 onwards')
requires_custom_indexes = pytest.mark.skipif(SCYLLA_VERSION is not None,
reason='Scylla does not support SASI or any other CUSTOM INDEX class')
Expand Down Expand Up @@ -501,7 +502,7 @@ def start_cluster_wait_for_up(cluster):


def use_cluster(cluster_name, nodes, ipformat=None, start=True, workloads=None, set_keyspace=True, ccm_options=None,
configuration_options=None, dse_options=None, use_single_interface=USE_SINGLE_INTERFACE):
configuration_options=None, dse_options=None, use_single_interface=USE_SINGLE_INTERFACE, use_tablets=False):
configuration_options = configuration_options or {}
dse_options = dse_options or {}
workloads = workloads or []
Expand Down Expand Up @@ -611,7 +612,10 @@ def use_cluster(cluster_name, nodes, ipformat=None, start=True, workloads=None,
# CDC is causing an issue (can't start cluster with multiple seeds)
# Selecting only features we need for tests, i.e. anything but CDC.
CCM_CLUSTER = CCMScyllaCluster(path, cluster_name, **ccm_options)
CCM_CLUSTER.set_configuration_options({'experimental_features': ['lwt', 'udf'], 'start_native_transport': True})
if use_tablets:
CCM_CLUSTER.set_configuration_options({'experimental_features': ['lwt', 'udf', 'consistent-topology-changes', 'tablets'], 'start_native_transport': True})
else:
CCM_CLUSTER.set_configuration_options({'experimental_features': ['lwt', 'udf'], 'start_native_transport': True})

# Permit IS NOT NULL restriction on non-primary key columns of a materialized view
# This allows `test_metadata_with_quoted_identifiers` to run
Expand Down
156 changes: 156 additions & 0 deletions tests/integration/experiments/test_tablets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import time
import unittest
import pytest
import os
from cassandra.cluster import Cluster
from cassandra.policies import ConstantReconnectionPolicy, RoundRobinPolicy, TokenAwarePolicy

from tests.integration import PROTOCOL_VERSION, use_cluster
from tests.unit.test_host_connection_pool import LOGGER

def setup_module():
use_cluster('tablets', [3], start=True, use_tablets=True)

class TestTabletsIntegration(unittest.TestCase):
@classmethod
def setup_class(cls):
cls.cluster = Cluster(contact_points=["127.0.0.1", "127.0.0.2", "127.0.0.3"], protocol_version=PROTOCOL_VERSION,
load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
reconnection_policy=ConstantReconnectionPolicy(1))
cls.session = cls.cluster.connect()
cls.create_ks_and_cf(cls)
cls.create_data(cls.session)

@classmethod
def teardown_class(cls):
cls.cluster.shutdown()

def verify_same_host_in_tracing(self, results):
traces = results.get_query_trace()
events = traces.events
host_set = set()
for event in events:
LOGGER.info("TRACE EVENT: %s %s %s", event.source, event.thread_name, event.description)
host_set.add(event.source)

self.assertEqual(len(host_set), 1)
self.assertIn('locally', "\n".join([event.description for event in events]))

trace_id = results.response_future.get_query_trace_ids()[0]
traces = self.session.execute("SELECT * FROM system_traces.events WHERE session_id = %s", (trace_id,))
events = [event for event in traces]
host_set = set()
for event in events:
LOGGER.info("TRACE EVENT: %s %s", event.source, event.activity)
host_set.add(event.source)

self.assertEqual(len(host_set), 1)
self.assertIn('locally', "\n".join([event.activity for event in events]))

def verify_same_shard_in_tracing(self, results):
traces = results.get_query_trace()
events = traces.events
shard_set = set()
for event in events:
LOGGER.info("TRACE EVENT: %s %s %s", event.source, event.thread_name, event.description)
shard_set.add(event.thread_name)

self.assertEqual(len(shard_set), 1)
self.assertIn('locally', "\n".join([event.description for event in events]))

trace_id = results.response_future.get_query_trace_ids()[0]
traces = self.session.execute("SELECT * FROM system_traces.events WHERE session_id = %s", (trace_id,))
events = [event for event in traces]
shard_set = set()
for event in events:
LOGGER.info("TRACE EVENT: %s %s", event.thread, event.activity)
shard_set.add(event.thread)

self.assertEqual(len(shard_set), 1)
self.assertIn('locally', "\n".join([event.activity for event in events]))

def create_ks_and_cf(self):
self.session.execute(
"""
DROP KEYSPACE IF EXISTS test1
"""
)
self.session.execute(
"""
CREATE KEYSPACE test1
WITH replication = {
'class': 'NetworkTopologyStrategy',
'replication_factor': 1,
'initial_tablets': 8
}
""")

self.session.execute(
"""
CREATE TABLE test1.table1 (pk int, ck int, v int, PRIMARY KEY (pk, ck));
""")

@staticmethod
def create_data(session):
prepared = session.prepare(
"""
INSERT INTO test1.table1 (pk, ck, v) VALUES (?, ?, ?)
""")

for i in range(50):
bound = prepared.bind((i, i%5, i%2))
session.execute(bound)

def query_data_shard_select(self, session, verify_in_tracing=True):
prepared = session.prepare(
"""
SELECT pk, ck, v FROM test1.table1 WHERE pk = ?
""")

bound = prepared.bind([(2)])
results = session.execute(bound, trace=True)
self.assertEqual(results, [(2, 2, 0)])
if verify_in_tracing:
self.verify_same_shard_in_tracing(results)

def query_data_host_select(self, session, verify_in_tracing=True):
prepared = session.prepare(
"""
SELECT pk, ck, v FROM test1.table1 WHERE pk = ?
""")

bound = prepared.bind([(2)])
results = session.execute(bound, trace=True)
self.assertEqual(results, [(2, 2, 0)])
if verify_in_tracing:
self.verify_same_host_in_tracing(results)

def query_data_shard_insert(self, session, verify_in_tracing=True):
prepared = session.prepare(
"""
INSERT INTO test1.table1 (pk, ck, v) VALUES (?, ?, ?)
""")

bound = prepared.bind([(51), (1), (2)])
results = session.execute(bound, trace=True)
if verify_in_tracing:
self.verify_same_shard_in_tracing(results)

def query_data_host_insert(self, session, verify_in_tracing=True):
prepared = session.prepare(
"""
INSERT INTO test1.table1 (pk, ck, v) VALUES (?, ?, ?)
""")

bound = prepared.bind([(52), (1), (2)])
results = session.execute(bound, trace=True)
if verify_in_tracing:
self.verify_same_host_in_tracing(results)

def test_tablets(self):
self.query_data_host_select(self.session)
self.query_data_host_insert(self.session)

def test_tablets_shard_awareness(self):
self.query_data_shard_select(self.session)
self.query_data_shard_insert(self.session)
3 changes: 2 additions & 1 deletion tests/unit/test_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from threading import Thread

from cassandra import ConsistencyLevel
from cassandra.cluster import Cluster
from cassandra.cluster import Cluster, ControlConnection
from cassandra.metadata import Metadata
from cassandra.policies import (RoundRobinPolicy, WhiteListRoundRobinPolicy, DCAwareRoundRobinPolicy,
TokenAwarePolicy, SimpleConvictionPolicy,
Expand Down Expand Up @@ -601,6 +601,7 @@ def get_replicas(keyspace, packed_key):
class FakeCluster:
def __init__(self):
self.metadata = Mock(spec=Metadata)
self.control_connection = Mock(spec=ControlConnection)

def test_get_distance(self):
"""
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_response_future.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class ResponseFutureTests(unittest.TestCase):
def make_basic_session(self):
s = Mock(spec=Session)
s.row_factory = lambda col_names, rows: [(col_names, rows)]
s.cluster.control_connection._tablets_routing_v1 = False
return s

def make_pool(self):
Expand Down
88 changes: 88 additions & 0 deletions tests/unit/test_tablets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import unittest

from cassandra.tablets import Tablets, Tablet

class TabletsTest(unittest.TestCase):
def compare_ranges(self, tablets, ranges):
self.assertEqual(len(tablets), len(ranges))

for idx, tablet in enumerate(tablets):
self.assertEqual(tablet.first_token, ranges[idx][0], "First token is not correct in tablet: {}".format(tablet))
self.assertEqual(tablet.last_token, ranges[idx][1], "Last token is not correct in tablet: {}".format(tablet))

def test_add_tablet_to_empty_tablets(self):
tablets = Tablets({("test_ks", "test_tb"): []})

tablets.add_tablet("test_ks", "test_tb", Tablet(-6917529027641081857, -4611686018427387905, None))

tablets_list = tablets._tablets.get(("test_ks", "test_tb"))

self.compare_ranges(tablets_list, [(-6917529027641081857, -4611686018427387905)])

def test_add_tablet_at_the_beggining(self):
tablets = Tablets({("test_ks", "test_tb"): [Tablet(-6917529027641081857, -4611686018427387905, None)]})

tablets.add_tablet("test_ks", "test_tb", Tablet(-8611686018427387905, -7917529027641081857, None))

tablets_list = tablets._tablets.get(("test_ks", "test_tb"))

self.compare_ranges(tablets_list, [(-8611686018427387905, -7917529027641081857),
(-6917529027641081857, -4611686018427387905)])

def test_add_tablet_at_the_end(self):
tablets = Tablets({("test_ks", "test_tb"): [Tablet(-6917529027641081857, -4611686018427387905, None)]})

tablets.add_tablet("test_ks", "test_tb", Tablet(-1, 2305843009213693951, None))

tablets_list = tablets._tablets.get(("test_ks", "test_tb"))

self.compare_ranges(tablets_list, [(-6917529027641081857, -4611686018427387905),
(-1, 2305843009213693951)])

def test_add_tablet_in_the_middle(self):
tablets = Tablets({("test_ks", "test_tb"): [Tablet(-6917529027641081857, -4611686018427387905, None),
Tablet(-1, 2305843009213693951, None)]},)

tablets.add_tablet("test_ks", "test_tb", Tablet(-4611686018427387905, -2305843009213693953, None))

tablets_list = tablets._tablets.get(("test_ks", "test_tb"))

self.compare_ranges(tablets_list, [(-6917529027641081857, -4611686018427387905),
(-4611686018427387905, -2305843009213693953),
(-1, 2305843009213693951)])

def test_add_tablet_intersecting(self):
tablets = Tablets({("test_ks", "test_tb"): [Tablet(-6917529027641081857, -4611686018427387905, None),
Tablet(-4611686018427387905, -2305843009213693953, None),
Tablet(-2305843009213693953, -1, None),
Tablet(-1, 2305843009213693951, None)]})

tablets.add_tablet("test_ks", "test_tb", Tablet(-3611686018427387905, -6, None))

tablets_list = tablets._tablets.get(("test_ks", "test_tb"))

self.compare_ranges(tablets_list, [(-6917529027641081857, -4611686018427387905),
(-3611686018427387905, -6),
(-1, 2305843009213693951)])

def test_add_tablet_intersecting_with_first(self):
tablets = Tablets({("test_ks", "test_tb"): [Tablet(-8611686018427387905, -7917529027641081857, None),
Tablet(-6917529027641081857, -4611686018427387905, None)]})

tablets.add_tablet("test_ks", "test_tb", Tablet(-8011686018427387905, -7987529027641081857, None))

tablets_list = tablets._tablets.get(("test_ks", "test_tb"))

self.compare_ranges(tablets_list, [(-8011686018427387905, -7987529027641081857),
(-6917529027641081857, -4611686018427387905)])

def test_add_tablet_intersecting_with_last(self):
tablets = Tablets({("test_ks", "test_tb"): [Tablet(-8611686018427387905, -7917529027641081857, None),
Tablet(-6917529027641081857, -4611686018427387905, None)]})

tablets.add_tablet("test_ks", "test_tb", Tablet(-5011686018427387905, -2987529027641081857, None))

tablets_list = tablets._tablets.get(("test_ks", "test_tb"))

self.compare_ranges(tablets_list, [(-8611686018427387905, -7917529027641081857),
(-5011686018427387905, -2987529027641081857)])

0 comments on commit ff00514

Please sign in to comment.