diff --git a/magicparse/__init__.py b/magicparse/__init__.py
index 72d91d5..0d42af9 100644
--- a/magicparse/__init__.py
+++ b/magicparse/__init__.py
@@ -1,6 +1,6 @@
 from io import BytesIO
 
-from .schema import Schema, builtins as builtins_schemas
+from .schema import ParsedRow, Schema, builtins as builtins_schemas
 from .post_processors import PostProcessor, builtins as builtins_post_processors
 from .pre_processors import PreProcessor, builtins as builtins_pre_processors
 from .builders import (
@@ -9,16 +9,18 @@
 )
 from .transform import Transform
 from .type_converters import TypeConverter, builtins as builtins_type_converters
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, Iterable, List, Tuple, Union
 from .validators import Validator, builtins as builtins_validators
 
 
 __all__ = [
     "TypeConverter",
     "parse",
+    "stream_parse",
     "PostProcessor",
     "PreProcessor",
     "Schema",
+    "ParsedRow",
     "Validator",
 ]
 
@@ -30,6 +32,13 @@ def parse(
     return schema_definition.parse(data)
 
 
+def stream_parse(
+    data: Union[bytes, BytesIO], schema_options: Dict[str, Any]
+) -> Iterable[ParsedRow]:
+    schema_definition = Schema.build(schema_options)
+    return schema_definition.stream_parse(data)
+
+
 Registrable = Union[Schema, Transform]
 
 
diff --git a/magicparse/schema.py b/magicparse/schema.py
index 914d45c..3941743 100644
--- a/magicparse/schema.py
+++ b/magicparse/schema.py
@@ -1,11 +1,19 @@
 import codecs
 from abc import ABC, abstractmethod
 import csv
+from dataclasses import dataclass
 from .fields import Field, ComputedField
 from io import BytesIO
 from typing import Any, Dict, List, Tuple, Union, Iterable
 
 
+@dataclass(frozen=True, slots=True)
+class ParsedRow:
+    row_number: int
+    values: dict
+    errors: list[dict]
+
+
 class Schema(ABC):
     fields: List[Field]
     encoding: str
@@ -48,17 +56,15 @@ def parse(self, data: Union[bytes, BytesIO]) -> Tuple[List[dict], List[dict]]:
         items = []
         errors = []
 
-        for item, row_errors in self.stream_parse(data):
-            if row_errors:
-                errors.extend(row_errors)
+        for parsed_row in self.stream_parse(data):
+            if parsed_row.errors:
+                errors.extend(parsed_row.errors)
             else:
-                items.append(item)
+                items.append(parsed_row.values)
 
         return items, errors
 
-    def stream_parse(
-        self, data: Union[bytes, BytesIO]
-    ) -> Iterable[Tuple[dict, list[dict]]]:
+    def stream_parse(self, data: Union[bytes, BytesIO]) -> Iterable[ParsedRow]:
         if isinstance(data, bytes):
             stream = BytesIO(data)
         else:
@@ -98,7 +104,7 @@ def stream_parse(
 
                 item[computed_field.key] = value
 
-            yield item, errors
+            yield ParsedRow(row_number, item, errors)
 
 
 class CsvSchema(Schema):
diff --git a/tests/test_schema.py b/tests/test_schema.py
index 40e7ef3..8ad7135 100644
--- a/tests/test_schema.py
+++ b/tests/test_schema.py
@@ -1,7 +1,7 @@
 from decimal import Decimal
 
 from magicparse import Schema
-from magicparse.schema import ColumnarSchema, CsvSchema
+from magicparse.schema import ColumnarSchema, CsvSchema, ParsedRow
 from magicparse.fields import ColumnarField, CsvField
 import pytest
 from unittest import TestCase
@@ -337,10 +337,11 @@ def test_stream_parse_errors_do_not_halt_parsing(self):
         )
         rows = list(schema.stream_parse(b"1\na\n2"))
         assert rows == [
-            ({"age": 1}, []),
-            (
-                {},
-                [
+            ParsedRow(row_number=1, values={"age": 1}, errors=[]),
+            ParsedRow(
+                row_number=2,
+                values={},
+                errors=[
                     {
                         "row-number": 2,
                         "column-number": 1,
@@ -349,5 +350,29 @@
                     }
                 ],
             ),
-            ({"age": 2}, []),
+            ParsedRow(row_number=3, values={"age": 2}, errors=[]),
         ]
+
+    def test_stream_parse_with_header_first_row_number_is_2(self):
+        schema = Schema.build(
+            {
+                "has_header": True,
+                "file_type": "csv",
+                "fields": [{"key": "age", "type": "int", "column-number": 1}],
+            }
+        )
+        rows = list(schema.stream_parse(b"My age\n1"))
+        assert len(rows) == 1
+        assert rows[0].row_number == 2
+
+    def test_stream_parse_without_header_first_row_number_is_1(self):
+        schema = Schema.build(
+            {
+                "has_header": False,
+                "file_type": "csv",
+                "fields": [{"key": "age", "type": "int", "column-number": 1}],
+            }
+        )
+        rows = list(schema.stream_parse(b"1"))
+        assert len(rows) == 1
+        assert rows[0].row_number == 1
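
Usage note for reviewers: a minimal sketch of how the new API reads end to end, based only on the code in this diff. The schema options mirror the test fixtures above; the input bytes and the print formatting are illustrative, not part of the change.

```python
from magicparse import stream_parse

# Same shape of schema options as the tests above: a headerless CSV with a
# single integer column keyed "age".
options = {
    "file_type": "csv",
    "has_header": False,
    "fields": [{"key": "age", "type": "int", "column-number": 1}],
}

# stream_parse yields one ParsedRow per input row instead of aggregating
# (items, errors) at the end, so a faulty row no longer stops the stream.
for row in stream_parse(b"1\na\n2", options):
    if row.errors:
        print(f"row {row.row_number} failed: {row.errors}")
    else:
        print(f"row {row.row_number} parsed: {row.values}")
```

Because ParsedRow is a frozen dataclass, equality is generated automatically, which is what the updated test assertions rely on when comparing whole rows.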