diff --git a/magicparse/schema.py b/magicparse/schema.py index f312443..40f1002 100644 --- a/magicparse/schema.py +++ b/magicparse/schema.py @@ -80,13 +80,17 @@ class CsvSchema(Schema): def __init__(self, options: Dict[str, Any]) -> None: super().__init__(options) self.delimiter = options.get("delimiter", ",") + self.quotechar = options.get("quotechar", '"') def get_reader(self, stream: BytesIO) -> Iterable[List[str]]: stream_reader = codecs.getreader(self.encoding) stream_content = stream_reader(stream) return csv.reader( - stream_content, delimiter=self.delimiter, quoting=csv.QUOTE_NONE + stream_content, + delimiter=self.delimiter, + quoting=csv.QUOTE_MINIMAL, + quotechar=self.quotechar, ) @staticmethod diff --git a/tests/test_schema.py b/tests/test_schema.py index 7dfbe29..65e3fd7 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,3 +1,4 @@ +from decimal import Decimal from magicparse import Schema from magicparse.schema import ColumnarSchema, CsvSchema from magicparse.fields import ColumnarField, CsvField @@ -215,6 +216,45 @@ def test_errors_do_not_halt_parsing(self): ] +class TestQuotingSetting(TestCase): + def test_no_quote(self): + schema = Schema.build( + { + "file_type": "csv", + "has_header": True, + "fields": [{"key": "column_1", "type": "decimal", "column-number": 1}], + } + ) + rows, errors = schema.parse(b"column_1\n6.66") + assert rows == [{"column_1": Decimal("6.66")}] + assert not errors + + def test_single_quote(self): + schema = Schema.build( + { + "file_type": "csv", + "quotechar": "'", + "has_header": True, + "fields": [{"key": "column_1", "type": "decimal", "column-number": 1}], + } + ) + rows, errors = schema.parse(b"column_1\n'6.66'") + assert rows == [{"column_1": Decimal("6.66")}] + assert not errors + + def test_double_quote(self): + schema = Schema.build( + { + "file_type": "csv", + "has_header": True, + "fields": [{"key": "column_1", "type": "decimal", "column-number": 1}], + } + ) + rows, errors = schema.parse(b'column_1\n"6.66"') + assert rows == [{"column_1": Decimal("6.66")}] + assert not errors + + class TestRegister(TestCase): class PipedSchema(Schema): @staticmethod