Skip to content

Commit

Permalink
COPY TO STDOUT shouldn't put None where a function is expected
Browse files Browse the repository at this point in the history
using the command was failing like the following
```
cassandra@cqlsh> COPY system_schema.tables TO STDOUT ;
'NoneType' object is not callable
cassandra@cqlsh>
```

the logic was not working as expected, and `self.printmsg` was
assigned with `None`

* introduced a unit test to cover this case

Ref: scylladb/scylla-enterprise#3940
  • Loading branch information
fruch committed Feb 21, 2024
1 parent b8d86b7 commit 387f769
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
6 changes: 5 additions & 1 deletion pylib/cqlshlib/copyutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def printmsg(msg, eol='\n'):
sys.stdout.flush()


def noop(*arg, **kwargs):
pass


class OneWayPipe(object):
"""
A one way pipe protected by two process level locks, one for reading and one for writing.
Expand Down Expand Up @@ -259,7 +263,7 @@ def __init__(self, shell, ks, table, columns, fname, opts, protocol_version, con
DEBUG = True

# do not display messages when exporting to STDOUT unless --debug is set
self.printmsg = printmsg if self.fname is not None or direction == 'from' or DEBUG else None
self.printmsg = printmsg if self.fname is not None or direction == 'from' or DEBUG else noop
self.options = self.parse_options(opts, direction)

self.num_processes = self.options.copy['numprocesses']
Expand Down
11 changes: 9 additions & 2 deletions pylib/cqlshlib/test/test_copyutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from cqlshlib.copyutil import ExportTask


Default = object()


class CopyTaskTest(unittest.TestCase):

def setUp(self):
Expand Down Expand Up @@ -60,10 +63,11 @@ def mock_shell(self):

class TestExportTask(CopyTaskTest):

def _test_get_ranges_murmur3_base(self, opts, expected_ranges):
def _test_get_ranges_murmur3_base(self, opts, expected_ranges, fname=Default):
"""
Set up a mock shell with a simple token map to test the ExportTask get_ranges function.
"""
fname = self.fname if fname is Default else fname
shell = self.mock_shell()
shell.conn.metadata.partitioner = 'Murmur3Partitioner'
# token range for a cluster of 4 nodes with replication factor 3
Expand All @@ -77,7 +81,7 @@ def _test_get_ranges_murmur3_base(self, opts, expected_ranges):
overridden_opts = dict(self.opts)
for k, v in opts.items():
overridden_opts[k] = v
export_task = ExportTask(shell, self.ks, self.table, self.columns, self.fname, overridden_opts, self.protocol_version, self.config_file)
export_task = ExportTask(shell, self.ks, self.table, self.columns, fname, overridden_opts, self.protocol_version, self.config_file)
assert export_task.get_ranges() == expected_ranges
export_task.close()

Expand Down Expand Up @@ -114,3 +118,6 @@ def test_get_ranges_murmur3(self):
(None, MIN_LONG + 1): {'hosts': ('10.0.0.2', '10.0.0.3', '10.0.0.4'), 'attempts': 0, 'rows': 0, 'workerno': -1}
}
self._test_get_ranges_murmur3_base({'endtoken': MIN_LONG + 1}, expected_ranges)

def test_exporting_to_std(self):
self._test_get_ranges_murmur3_base({'begintoken': MIN_LONG - 1}, {}, fname=None)

0 comments on commit 387f769

Please sign in to comment.