diff --git a/csvkit/utilities/csvstat.py b/csvkit/utilities/csvstat.py index 2a6342c0b..a21ac64a9 100644 --- a/csvkit/utilities/csvstat.py +++ b/csvkit/utilities/csvstat.py @@ -213,13 +213,7 @@ def main(self): def is_finite_decimal(self, value): return isinstance(value, Decimal) and value.is_finite() - def print_one(self, table, column_id, operation, label=True, **kwargs): - """ - Print data for a single statistic. - """ - column_name = table.column_names[column_id] - - op_name = operation + def _calculate_stat(self, table, column_id, op_name, op_data, **kwargs): getter = globals().get(f'get_{op_name}') with warnings.catch_warnings(): @@ -227,15 +221,25 @@ def print_one(self, table, column_id, operation, label=True, **kwargs): try: if getter: - stat = getter(table, column_id, **kwargs) - else: - op = OPERATIONS[op_name]['aggregation'] - stat = table.aggregate(op(column_id)) + return getter(table, column_id, **kwargs) - if self.is_finite_decimal(stat): - stat = format_decimal(stat, self.args.decimal_format, self.args.no_grouping_separator) + op = op_data['aggregation'] + v = table.aggregate(op(column_id)) + + if self.is_finite_decimal(v) and not self.args.json_output: + return format_decimal(v, self.args.decimal_format, self.args.no_grouping_separator) + + return v except Exception: - stat = None + pass + + def print_one(self, table, column_id, op_name, label=True, **kwargs): + """ + Print data for a single statistic. + """ + column_name = table.column_names[column_id] + + stat = self._calculate_stat(table, column_id, op_name, OPERATIONS[op_name], **kwargs) # Formatting if op_name == 'freq': @@ -251,29 +255,10 @@ def calculate_stats(self, table, column_id, **kwargs): """ Calculate stats for all valid operations. """ - stats = {} - - for op_name, op_data in OPERATIONS.items(): - getter = globals().get(f'get_{op_name}') - - with warnings.catch_warnings(): - warnings.simplefilter('ignore', agate.NullCalculationWarning) - - try: - if getter: - stats[op_name] = getter(table, column_id, **kwargs) - else: - op = op_data['aggregation'] - v = table.aggregate(op(column_id)) - - if self.is_finite_decimal(v) and not self.args.json_output: - v = format_decimal(v, self.args.decimal_format, self.args.no_grouping_separator) - - stats[op_name] = v - except Exception: - stats[op_name] = None - - return stats + return { + op_name: self._calculate_stat(table, column_id, op_name, op_data, **kwargs) + for op_name, op_data in OPERATIONS.items() + } def print_stats(self, table, column_ids, stats): """ diff --git a/tests/test_utilities/test_csvstat.py b/tests/test_utilities/test_csvstat.py index e27d01af6..bc01dd58f 100644 --- a/tests/test_utilities/test_csvstat.py +++ b/tests/test_utilities/test_csvstat.py @@ -118,7 +118,7 @@ def test_json(self): self.assertEqual(row[1], 'state') self.assertEqual(row[2], 'Text') self.assertNotIn('min', data[0]) - self.assertEqual(row[-2], '2') + self.assertEqual(row[-2], 2.0) def test_json_columns(self): output = self.get_output_as_io(['--json', '-c', '4', 'examples/realdata/ks_1033_data.csv']) @@ -135,7 +135,7 @@ def test_json_columns(self): self.assertEqual(row[1], 'nsn') self.assertEqual(row[2], 'Text') self.assertNotIn('min', data[0]) - self.assertEqual(row[-2], '16') + self.assertEqual(row[-2], 16.0) def test_decimal_format(self): output = self.get_output(['-c', 'TOTAL', '--mean', 'examples/realdata/FY09_EDU_Recipients_by_State.csv'])