diff --git a/jpmml_evaluator/__init__.py b/jpmml_evaluator/__init__.py index 8391848..2baa1b4 100644 --- a/jpmml_evaluator/__init__.py +++ b/jpmml_evaluator/__init__.py @@ -1,8 +1,10 @@ +import os import pickle from abc import abstractmethod, abstractclassmethod, ABC from importlib.resources import files from pandas import DataFrame +from pathlib import Path import numpy @@ -228,6 +230,8 @@ def setLocatable(self, locatable = False): return self def loadFile(self, path): + if isinstance(path, Path): + path = str(path) javaFile = self.backend.newObject("java.io.File", path) self.javaModelEvaluatorBuilder.load(javaFile) return self @@ -270,13 +274,14 @@ def make_backend(alias): aliases = ["jpype", "pyjnius", "py4j"] raise ValueError("Java backend alias {0} not in {1}".format(alias, aliases)) -def make_evaluator(path, backend = "jpype", lax = False, locatable = False, reporting = False, transpile = False): +def make_evaluator(obj, backend = "jpype", lax = False, locatable = False, reporting = False, transpile = False): """ Builds an Evaluator based on a PMML file. Parameters: ---------- - path: string - The path to the PMML file in local filesystem. + obj: string or bytes + The object to load. Either a path to a PMML file in local filesystem, + or a PMML string or byte array. backend: JavaBackend or string The Java backend or its alias @@ -303,12 +308,26 @@ def make_evaluator(path, backend = "jpype", lax = False, locatable = False, repo raise TypeError() evaluatorBuilder = LoadingModelEvaluatorBuilder(backend, lax) \ - .setLocatable(locatable) \ - .loadFile(path) + .setLocatable(locatable) + + if isinstance(obj, Path): + evaluatorBuilder = evaluatorBuilder.loadFile(obj) + elif isinstance(obj, str): + if len(obj) < 1024 and os.path.isfile(obj): + evaluatorBuilder = evaluatorBuilder.loadFile(obj) + else: + evaluatorBuilder = evaluatorBuilder.loadString(obj) + elif isinstance(obj, bytes): + evaluatorBuilder = evaluatorBuilder.loadBytes(obj) + else: + raise TypeError() + if reporting: - evaluatorBuilder.setReportingValueFactoryFactory() + evaluatorBuilder = evaluatorBuilder.setReportingValueFactoryFactory() + if transpile: - evaluatorBuilder.transpile(transpile if isinstance(transpile, str) else None) + evaluatorBuilder = evaluatorBuilder.transpile(transpile if isinstance(transpile, str) else None) + return evaluatorBuilder.build() def _package_data_jars(package_data_dir): diff --git a/jpmml_evaluator/tests/__init__.py b/jpmml_evaluator/tests/__init__.py index d4fa29c..fd4cf4f 100644 --- a/jpmml_evaluator/tests/__init__.py +++ b/jpmml_evaluator/tests/__init__.py @@ -1,11 +1,12 @@ import os +from pathlib import Path from unittest import TestCase import numpy import pandas -from jpmml_evaluator import make_backend, make_evaluator, Evaluator, JavaError, LoadingModelEvaluatorBuilder +from jpmml_evaluator import make_backend, make_evaluator, Evaluator, JavaError def _resource(name): return os.path.join(os.path.dirname(__file__), "resources", name) @@ -27,9 +28,15 @@ def workflow(self, backend): self.assertIsInstance(resource, str) - evaluator = LoadingModelEvaluatorBuilder(backend) \ - .loadFile(resource) \ - .build() + evaluator = make_evaluator(resource, backend = backend) + + self.assertIsInstance(evaluator, Evaluator) + + resource_path = Path(resource) + + self.assertIsInstance(resource_path, Path) + + evaluator = make_evaluator(resource_path, backend = backend) self.assertIsInstance(evaluator, Evaluator) @@ -38,9 +45,7 @@ def workflow(self, backend): self.assertIsInstance(resource_bytes, bytes) - evaluator = LoadingModelEvaluatorBuilder(backend) \ - .loadBytes(resource_bytes) \ - .build() + evaluator = make_evaluator(resource_bytes, backend = backend) self.assertIsInstance(evaluator, Evaluator) @@ -49,9 +54,7 @@ def workflow(self, backend): self.assertIsInstance(resource_string, str) - evaluator = LoadingModelEvaluatorBuilder(backend) \ - .loadString(resource_string) \ - .build() + evaluator = make_evaluator(resource_string, backend = backend) self.assertIsInstance(evaluator, Evaluator)