diff --git a/sklearn2pmml/__init__.py b/sklearn2pmml/__init__.py index 07eb98f..82857a2 100644 --- a/sklearn2pmml/__init__.py +++ b/sklearn2pmml/__init__.py @@ -1,4 +1,5 @@ from pandas import CategoricalDtype +from pathlib import Path try: from sklearn_pandas import DataFrameMapper except ImportError: @@ -242,7 +243,7 @@ def sklearn2pmml(estimator, pmml_path, with_repr = False, pmml_schema = None, ja Parameters: ---------- - estimator: BaseEstimator + estimator: BaseEstimator or path-like The estimator or pipeline object. pmml_path: string @@ -285,36 +286,47 @@ def sklearn2pmml(estimator, pmml_path, with_repr = False, pmml_schema = None, ja print("dill: {0}".format(dill.__version__)) print("joblib: {0}".format(joblib.__version__)) print("{0}: {1}".format(java_version[0], java_version[1])) - if not _is_supported(estimator): - raise TypeError("The estimator object is not an instance of {0}".format(BaseEstimator.__name__)) - # if isinstance(estimator, Pipeline): - if hasattr(estimator, "_final_estimator"): - final_estimator = estimator._final_estimator - else: - final_estimator = estimator + dumps = [] try: - java_args = ["-cp", os.pathsep.join(_classpath(user_classpath)), "com.sklearn2pmml.Main"] - if with_repr: - estimator.repr_ = repr(estimator) - # if isinstance(final_estimator, H2OEstimator): - if hasattr(final_estimator, "download_mojo"): - if dump_flavour != "dill": - warnings.warn("Changing dump flavour to dill") - dump_flavour = "dill" - # Avoid MOJO (re-)download if the indicator attribute is set - if not (hasattr(final_estimator, "_mojo_path") or hasattr(final_estimator, "_mojo_bytes")): - mojo_path = final_estimator.download_mojo() - dumps.append(mojo_path) - final_estimator._mojo_path = mojo_path - if dump_flavour == "dill": - pkl_path = _dill_dump(estimator, "estimator") - elif dump_flavour == "joblib": - pkl_path = _joblib_dump(estimator, "estimator") + if isinstance(estimator, (str, Path)): + if with_repr: + warnings.warn("Ignoring 'with_repr' flag") + + pkl_path = str(estimator) else: - raise ValueError("Dump flavour {0} not in {1}".format(dump_flavour, ["dill", "joblib"])) + if not _is_supported(estimator): + raise TypeError("The estimator object is not an instance of {0}".format(BaseEstimator.__name__)) + + if with_repr: + estimator.repr_ = repr(estimator) + + # if isinstance(estimator, Pipeline): + if hasattr(estimator, "_final_estimator"): + final_estimator = estimator._final_estimator + else: + final_estimator = estimator + # if isinstance(final_estimator, H2OEstimator): + if hasattr(final_estimator, "download_mojo"): + if dump_flavour != "dill": + warnings.warn("Changing dump flavour to dill") + dump_flavour = "dill" + # Avoid MOJO (re-)download if the indicator attribute is set + if not (hasattr(final_estimator, "_mojo_path") or hasattr(final_estimator, "_mojo_bytes")): + mojo_path = final_estimator.download_mojo() + final_estimator._mojo_path = mojo_path + dumps.append(mojo_path) + + if dump_flavour == "dill": + pkl_path = _dill_dump(estimator, "estimator") + elif dump_flavour == "joblib": + pkl_path = _joblib_dump(estimator, "estimator") + else: + raise ValueError("Dump flavour {0} not in {1}".format(dump_flavour, ["dill", "joblib"])) + dumps.append(pkl_path) + + java_args = ["-cp", os.pathsep.join(_classpath(user_classpath)), "com.sklearn2pmml.Main"] java_args.extend(["--pkl-input", pkl_path]) - dumps.append(pkl_path) java_args.extend(["--pmml-output", pmml_path]) if pmml_schema: java_args.extend(["--pmml-schema", pmml_schema]) diff --git a/sklearn2pmml/cli.py b/sklearn2pmml/cli.py index 7a205c1..efe6e55 100644 --- a/sklearn2pmml/cli.py +++ b/sklearn2pmml/cli.py @@ -1,8 +1,6 @@ from argparse import ArgumentParser from sklearn2pmml import __version__, sklearn2pmml -import joblib - def main(): version = "SkLearn2PMML {}".format(__version__) @@ -14,6 +12,4 @@ def main(): args = parser.parse_args() - estimator = joblib.load(args.input) - - sklearn2pmml(estimator, pmml_path = args.output, pmml_schema = args.schema) + sklearn2pmml(args.input, pmml_path = args.output, pmml_schema = args.schema)