Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add onnx #10

Merged
merged 1 commit into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ So far we support:
- pandas dataframes:
- csv
- parquet
- onnx.ModelProto instances

## Dependencies

Expand Down
9 changes: 8 additions & 1 deletion dummio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,16 @@

from dummio import json as json
from dummio import text as text
from dummio.constants import ModuleProtocol as ModuleProtocol

try:
from dummio import yaml as yaml
except ImportError:
# yaml is an optional dependency
# this would require an optional yaml dependency such as pyyaml
pass

try:
from dummio import onnx as onnx
except ImportError:
# this would require the optional dependency onnx
pass
20 changes: 19 additions & 1 deletion dummio/constants.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,28 @@
"""Constants for dummio."""

import typing
from pathlib import Path
from typing import Any, TypeAlias
from typing import Any, Protocol, TypeAlias

PathType: TypeAlias = str | Path
AnyDict: TypeAlias = dict[Any, Any]

# pyright expect a type var to be used at least twice within a single method. It's having
# trouble respecting how it's use *accross* methods of a class.
T = typing.TypeVar("T") # pyright: ignore

DEFAULT_ENCODING = "utf-8"
DEFAULT_WRITE_MODE = "w"


@typing.runtime_checkable
class ModuleProtocol(Protocol):
"""Protocol for dummio's IO modules."""

def save(self, data: T, *, filepath: PathType) -> None: # pyright: ignore[reportInvalidTypeVarUse]
"""Declares the signature of an IO module save method."""
...

def load(self, filepath: PathType) -> T: # pyright: ignore[reportInvalidTypeVarUse]
"""Declares the signature of an IO module load method."""
...
28 changes: 28 additions & 0 deletions dummio/onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""IO methods for sklearn models using ONNX serialization."""

import onnx

from dummio.constants import PathType


def save(data: onnx.ModelProto, *, filepath: PathType) -> None:
"""Saves a sklearn model to a file using ONNX serialization.

Args:
data: Data to save. This needs to be an sklearn model.
filepath: Path to save the data.
"""
byte_str = data.SerializeToString()
with open(filepath, "wb") as file:
file.write(byte_str)


def load(filepath: PathType) -> onnx.ModelProto:
"""Loads a sklearn model from a file using ONNX serialization.

Args:
filepath: Path to read the data.
"""
with open(filepath, "rb") as file:
byte_str = file.read()
return onnx.load_model_from_string(byte_str)
7 changes: 0 additions & 7 deletions dummio/pandas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,2 @@
try:
import pandas

del pandas
except ImportError:
raise ImportError("Please install pandas to use dummio.pandas")

from dummio.pandas import df_csv as df_csv
from dummio.pandas import df_parquet as df_parquet
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "dummio"
version = "1.0.0"
version = "1.1.0"
description = "Easiest-possible IO for basic file types."
authors = [{ name = "Zach Kurtz", email = "[email protected]" }]
readme = "README.md"
Expand All @@ -12,11 +12,14 @@ dev = [
"pyright >=1.1.378",
"ruff >=0.7.4",
"pytest >=8.3.2",
"scikit-learn >=1.0.2",
"skl2onnx >=1.10.1",
]
extras = [
"fastparquet>=2024.11.0",
"pandas>=1.5.0",
"pyyaml>=6.0.2",
"onnx>=1.10.1",
]

[project.urls]
Expand Down
20 changes: 20 additions & 0 deletions tests/test_assert_module_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Assert that every IO module implements the ModuleProtocol."""

import importlib

from dummio import ModuleProtocol

IO_MODULES = [
"dummio.json",
"dummio.onnx",
"dummio.text",
"dummio.yaml",
"dummio.pandas.df_csv",
"dummio.pandas.df_parquet",
]


def test_assert_module_protocol() -> None:
for module_name in IO_MODULES:
module = importlib.import_module(module_name)
assert isinstance(module, ModuleProtocol)
28 changes: 28 additions & 0 deletions tests/test_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import onnx
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

from dummio import onnx as onnx_io


def test_onnx_io_cycle(tmp_path) -> None:
# Empty model:
model = onnx.ModelProto()
filepath = tmp_path / "model.onnx"
onnx_io.save(model, filepath=filepath)
loaded_model = onnx_io.load(filepath=filepath)
assert model.SerializeToString() == loaded_model.SerializeToString()

# Minimal sklearn model:
iris = load_iris()
clr = LogisticRegression(solver="saga", max_iter=10000)
clr.fit(iris.data, iris.target) # pyright: ignore[reportAttributeAccessIssue]
initial_type = [("float_input", FloatTensorType([None, 4]))]
onx = convert_sklearn(clr, initial_types=initial_type)
assert isinstance(onx, onnx.ModelProto) # pyright wasn't sure
filepath = tmp_path / "model.onnx"
onnx_io.save(onx, filepath=filepath)
loaded_onx = onnx_io.load(filepath=filepath)
assert onx.SerializeToString() == loaded_onx.SerializeToString()
2 changes: 0 additions & 2 deletions tests/test_pandas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
"""Test IO methods for tabulare data types."""

from pathlib import Path
from types import ModuleType

Expand Down
Loading