Skip to content

Commit

Permalink
support pydantic
Browse files Browse the repository at this point in the history
  • Loading branch information
zkurtz committed Nov 29, 2024
1 parent c661796 commit d4a6c48
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ So far we support:
- csv
- parquet
- onnx.ModelProto instances
- pydantic models (relying on the built-in json serialization methods)

## Dependencies

Expand Down
6 changes: 6 additions & 0 deletions dummio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,9 @@
except ImportError:
# this would require the optional dependency onnx
pass

try:
from dummio import pydantic as pydantic
except ImportError:
# this would require the optional dependency pydantic
pass
28 changes: 28 additions & 0 deletions dummio/pydantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""IO for pydantic models."""

from pathlib import Path
from typing import Type

import pydantic

from dummio.constants import PathType


def save(
data: pydantic.BaseModel,
*,
filepath: PathType,
) -> None:
"""Save a pydantic model instance to a json text file."""
data_json_str = data.model_dump_json()
Path(filepath).write_text(data_json_str)


def load(
filepath: PathType,
*,
model: Type[pydantic.BaseModel],
) -> pydantic.BaseModel:
"""Load a pydantic model instance from a json text file."""
data_json_str = Path(filepath).read_text()
return model.model_validate_json(data_json_str)
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "dummio"
version = "1.1.0"
version = "1.2.0"
description = "Easiest-possible IO for basic file types."
authors = [{ name = "Zach Kurtz", email = "[email protected]" }]
readme = "README.md"
Expand All @@ -14,6 +14,7 @@ dev = [
"pytest >=8.3.2",
"scikit-learn >=1.0.2",
"skl2onnx >=1.10.1",
"pydantic>=2.10.2",
]
extras = [
"fastparquet>=2024.11.0",
Expand Down
1 change: 1 addition & 0 deletions tests/test_assert_module_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
IO_MODULES = [
"dummio.json",
"dummio.onnx",
"dummio.pydantic",
"dummio.text",
"dummio.yaml",
"dummio.pandas.df_csv",
Expand Down
30 changes: 30 additions & 0 deletions tests/test_pydantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from datetime import datetime, timezone
from pathlib import Path
from uuid import UUID, uuid4

from pydantic import BaseModel

import dummio.pydantic


# make a pydantic model equivalent to the dataclass above
class Data(BaseModel):
id: UUID
documentation: str
config: dict
rmse: float
trained_at: datetime


def test_pydantic_io(tmp_path: Path) -> None:
data = Data(
id=uuid4(),
documentation="This is a test data instance.",
config={"n_estimators": 100, "learning_rate": 0.01},
rmse=0.1,
trained_at=datetime.now(timezone.utc),
)
filepath = tmp_path / "data.json"
dummio.pydantic.save(data, filepath=filepath)
loaded_data = dummio.pydantic.load(filepath=filepath, model=Data)
assert data == loaded_data
100 changes: 100 additions & 0 deletions uv.lock

Large diffs are not rendered by default.

0 comments on commit d4a6c48

Please sign in to comment.