diff --git a/target_csv/serialization.py b/target_csv/serialization.py index adc8acd..87782b6 100644 --- a/target_csv/serialization.py +++ b/target_csv/serialization.py @@ -1,8 +1,26 @@ import csv # noqa: D100 from pathlib import Path -from typing import Any, List +from typing import Any, List, Callable +import os +def create_folder_if_not_exists(func: Any) -> Callable[..., int]: + """Decorator to create folder if it does not exist.""" + + def wrapper(*args: Any, **kwargs: Any) -> int: + try: + filepath = kwargs["filepath"] + except KeyError: + filepath = args[0] + folder = os.path.dirname(filepath) + if not os.path.exists(folder) and folder != "": + os.makedirs(folder) + return func(*args, **kwargs) + + return wrapper + + +@create_folder_if_not_exists def write_csv(filepath: Path, records: List[dict], schema: dict, **kwargs: Any) -> int: """Write a CSV file.""" if "properties" not in schema: diff --git a/tests/test_csv.py b/tests/test_csv.py index a371196..11d46c0 100644 --- a/tests/test_csv.py +++ b/tests/test_csv.py @@ -58,11 +58,30 @@ def output_filepath(output_dir) -> Path: return result +@pytest.fixture +def test_file_paths(output_dir) -> List[Path]: + paths = [] + for dir in range(4): + path = Path(output_dir / f"test-dir-{dir}/csv-test-output-{dir}.csv") + if path.exists(): + path.unlink() + + paths.append(path) + + return paths + + def test_csv_write(output_filepath) -> None: for schema, records in SAMPLE_DATASETS: write_csv(filepath=output_filepath, records=records, schema=schema) +def test_csv_write_if_not_exists(test_file_paths) -> None: + for path in test_file_paths: + for schema, records in SAMPLE_DATASETS: + write_csv(filepath=path, records=records, schema=schema) + + def test_csv_roundtrip(output_filepath) -> None: for schema, records in SAMPLE_DATASETS: write_csv(filepath=output_filepath, records=records, schema=schema)