-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #10 from zkurtz/io-protocol
add onnx
- Loading branch information
Showing
10 changed files
with
280 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,7 @@ So far we support: | |
- pandas dataframes: | ||
- csv | ||
- parquet | ||
- onnx.ModelProto instances | ||
|
||
## Dependencies | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,28 @@ | ||
"""Constants for dummio.""" | ||
|
||
import typing | ||
from pathlib import Path | ||
from typing import Any, TypeAlias | ||
from typing import Any, Protocol, TypeAlias | ||
|
||
PathType: TypeAlias = str | Path | ||
AnyDict: TypeAlias = dict[Any, Any] | ||
|
||
# pyright expect a type var to be used at least twice within a single method. It's having | ||
# trouble respecting how it's use *accross* methods of a class. | ||
T = typing.TypeVar("T") # pyright: ignore | ||
|
||
DEFAULT_ENCODING = "utf-8" | ||
DEFAULT_WRITE_MODE = "w" | ||
|
||
|
||
@typing.runtime_checkable | ||
class ModuleProtocol(Protocol): | ||
"""Protocol for dummio's IO modules.""" | ||
|
||
def save(self, data: T, *, filepath: PathType) -> None: # pyright: ignore[reportInvalidTypeVarUse] | ||
"""Declares the signature of an IO module save method.""" | ||
... | ||
|
||
def load(self, filepath: PathType) -> T: # pyright: ignore[reportInvalidTypeVarUse] | ||
"""Declares the signature of an IO module load method.""" | ||
... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
"""IO methods for sklearn models using ONNX serialization.""" | ||
|
||
import onnx | ||
|
||
from dummio.constants import PathType | ||
|
||
|
||
def save(data: onnx.ModelProto, *, filepath: PathType) -> None: | ||
"""Saves a sklearn model to a file using ONNX serialization. | ||
Args: | ||
data: Data to save. This needs to be an sklearn model. | ||
filepath: Path to save the data. | ||
""" | ||
byte_str = data.SerializeToString() | ||
with open(filepath, "wb") as file: | ||
file.write(byte_str) | ||
|
||
|
||
def load(filepath: PathType) -> onnx.ModelProto: | ||
"""Loads a sklearn model from a file using ONNX serialization. | ||
Args: | ||
filepath: Path to read the data. | ||
""" | ||
with open(filepath, "rb") as file: | ||
byte_str = file.read() | ||
return onnx.load_model_from_string(byte_str) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,2 @@ | ||
try: | ||
import pandas | ||
|
||
del pandas | ||
except ImportError: | ||
raise ImportError("Please install pandas to use dummio.pandas") | ||
|
||
from dummio.pandas import df_csv as df_csv | ||
from dummio.pandas import df_parquet as df_parquet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[project] | ||
name = "dummio" | ||
version = "1.0.0" | ||
version = "1.1.0" | ||
description = "Easiest-possible IO for basic file types." | ||
authors = [{ name = "Zach Kurtz", email = "[email protected]" }] | ||
readme = "README.md" | ||
|
@@ -12,11 +12,14 @@ dev = [ | |
"pyright >=1.1.378", | ||
"ruff >=0.7.4", | ||
"pytest >=8.3.2", | ||
"scikit-learn >=1.0.2", | ||
"skl2onnx >=1.10.1", | ||
] | ||
extras = [ | ||
"fastparquet>=2024.11.0", | ||
"pandas>=1.5.0", | ||
"pyyaml>=6.0.2", | ||
"onnx>=1.10.1", | ||
] | ||
|
||
[project.urls] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
"""Assert that every IO module implements the ModuleProtocol.""" | ||
|
||
import importlib | ||
|
||
from dummio import ModuleProtocol | ||
|
||
IO_MODULES = [ | ||
"dummio.json", | ||
"dummio.onnx", | ||
"dummio.text", | ||
"dummio.yaml", | ||
"dummio.pandas.df_csv", | ||
"dummio.pandas.df_parquet", | ||
] | ||
|
||
|
||
def test_assert_module_protocol() -> None: | ||
for module_name in IO_MODULES: | ||
module = importlib.import_module(module_name) | ||
assert isinstance(module, ModuleProtocol) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import onnx | ||
from skl2onnx import convert_sklearn | ||
from skl2onnx.common.data_types import FloatTensorType | ||
from sklearn.datasets import load_iris | ||
from sklearn.linear_model import LogisticRegression | ||
|
||
from dummio import onnx as onnx_io | ||
|
||
|
||
def test_onnx_io_cycle(tmp_path) -> None: | ||
# Empty model: | ||
model = onnx.ModelProto() | ||
filepath = tmp_path / "model.onnx" | ||
onnx_io.save(model, filepath=filepath) | ||
loaded_model = onnx_io.load(filepath=filepath) | ||
assert model.SerializeToString() == loaded_model.SerializeToString() | ||
|
||
# Minimal sklearn model: | ||
iris = load_iris() | ||
clr = LogisticRegression(solver="saga", max_iter=10000) | ||
clr.fit(iris.data, iris.target) # pyright: ignore[reportAttributeAccessIssue] | ||
initial_type = [("float_input", FloatTensorType([None, 4]))] | ||
onx = convert_sklearn(clr, initial_types=initial_type) | ||
assert isinstance(onx, onnx.ModelProto) # pyright wasn't sure | ||
filepath = tmp_path / "model.onnx" | ||
onnx_io.save(onx, filepath=filepath) | ||
loaded_onx = onnx_io.load(filepath=filepath) | ||
assert onx.SerializeToString() == loaded_onx.SerializeToString() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,3 @@ | ||
"""Test IO methods for tabulare data types.""" | ||
|
||
from pathlib import Path | ||
from types import ModuleType | ||
|
||
|
Oops, something went wrong.