Skip to content

Commit

Permalink
add onnx
Browse files Browse the repository at this point in the history
  • Loading branch information
zkurtz committed Nov 28, 2024
1 parent 484c063 commit 6120690
Show file tree
Hide file tree
Showing 10 changed files with 280 additions and 13 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ So far we support:
- pandas dataframes:
- csv
- parquet
- onnx.ModelProto instances

## Dependencies

Expand Down
9 changes: 8 additions & 1 deletion dummio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,16 @@

from dummio import json as json
from dummio import text as text
from dummio.constants import ModuleProtocol as ModuleProtocol

try:
from dummio import yaml as yaml
except ImportError:
# yaml is an optional dependency
# this would require an optional yaml dependency such as pyyaml
pass

try:
from dummio import onnx as onnx
except ImportError:
# this would require the optional dependency onnx
pass
20 changes: 19 additions & 1 deletion dummio/constants.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,28 @@
"""Constants for dummio."""

import typing
from pathlib import Path
from typing import Any, TypeAlias
from typing import Any, Protocol, TypeAlias

PathType: TypeAlias = str | Path
AnyDict: TypeAlias = dict[Any, Any]

# pyright expect a type var to be used at least twice within a single method. It's having
# trouble respecting how it's use *accross* methods of a class.
T = typing.TypeVar("T") # pyright: ignore

DEFAULT_ENCODING = "utf-8"
DEFAULT_WRITE_MODE = "w"


@typing.runtime_checkable
class ModuleProtocol(Protocol):
"""Protocol for dummio's IO modules."""

def save(self, data: T, *, filepath: PathType) -> None: # pyright: ignore[reportInvalidTypeVarUse]
"""Declares the signature of an IO module save method."""
...

def load(self, filepath: PathType) -> T: # pyright: ignore[reportInvalidTypeVarUse]
"""Declares the signature of an IO module load method."""
...
28 changes: 28 additions & 0 deletions dummio/onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""IO methods for sklearn models using ONNX serialization."""

import onnx

from dummio.constants import PathType


def save(data: onnx.ModelProto, *, filepath: PathType) -> None:
"""Saves a sklearn model to a file using ONNX serialization.
Args:
data: Data to save. This needs to be an sklearn model.
filepath: Path to save the data.
"""
byte_str = data.SerializeToString()
with open(filepath, "wb") as file:
file.write(byte_str)


def load(filepath: PathType) -> onnx.ModelProto:
"""Loads a sklearn model from a file using ONNX serialization.
Args:
filepath: Path to read the data.
"""
with open(filepath, "rb") as file:
byte_str = file.read()
return onnx.load_model_from_string(byte_str)
7 changes: 0 additions & 7 deletions dummio/pandas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,2 @@
try:
import pandas

del pandas
except ImportError:
raise ImportError("Please install pandas to use dummio.pandas")

from dummio.pandas import df_csv as df_csv
from dummio.pandas import df_parquet as df_parquet
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "dummio"
version = "1.0.0"
version = "1.1.0"
description = "Easiest-possible IO for basic file types."
authors = [{ name = "Zach Kurtz", email = "[email protected]" }]
readme = "README.md"
Expand All @@ -12,11 +12,14 @@ dev = [
"pyright >=1.1.378",
"ruff >=0.7.4",
"pytest >=8.3.2",
"scikit-learn >=1.0.2",
"skl2onnx >=1.10.1",
]
extras = [
"fastparquet>=2024.11.0",
"pandas>=1.5.0",
"pyyaml>=6.0.2",
"onnx>=1.10.1",
]

[project.urls]
Expand Down
20 changes: 20 additions & 0 deletions tests/test_assert_module_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Assert that every IO module implements the ModuleProtocol."""

import importlib

from dummio import ModuleProtocol

IO_MODULES = [
"dummio.json",
"dummio.onnx",
"dummio.text",
"dummio.yaml",
"dummio.pandas.df_csv",
"dummio.pandas.df_parquet",
]


def test_assert_module_protocol() -> None:
for module_name in IO_MODULES:
module = importlib.import_module(module_name)
assert isinstance(module, ModuleProtocol)
28 changes: 28 additions & 0 deletions tests/test_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import onnx
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

from dummio import onnx as onnx_io


def test_onnx_io_cycle(tmp_path) -> None:
# Empty model:
model = onnx.ModelProto()
filepath = tmp_path / "model.onnx"
onnx_io.save(model, filepath=filepath)
loaded_model = onnx_io.load(filepath=filepath)
assert model.SerializeToString() == loaded_model.SerializeToString()

# Minimal sklearn model:
iris = load_iris()
clr = LogisticRegression(solver="saga", max_iter=10000)
clr.fit(iris.data, iris.target) # pyright: ignore[reportAttributeAccessIssue]
initial_type = [("float_input", FloatTensorType([None, 4]))]
onx = convert_sklearn(clr, initial_types=initial_type)
assert isinstance(onx, onnx.ModelProto) # pyright wasn't sure
filepath = tmp_path / "model.onnx"
onnx_io.save(onx, filepath=filepath)
loaded_onx = onnx_io.load(filepath=filepath)
assert onx.SerializeToString() == loaded_onx.SerializeToString()
2 changes: 0 additions & 2 deletions tests/test_pandas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
"""Test IO methods for tabulare data types."""

from pathlib import Path
from types import ModuleType

Expand Down
Loading

0 comments on commit 6120690

Please sign in to comment.