Skip to content

Commit

Permalink
Added support for row-level exception handling. Fixes #26
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Dec 14, 2024
1 parent e3a5f30 commit f30da79
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 10 deletions.
26 changes: 23 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,20 @@ print("Output fields: " + str([outputField.getName() for outputField in outputFi
Evaluating a single data record:

```python
from jpmml_evaluator import JavaError

arguments = {
"Sepal_Length" : 5.1,
"Sepal_Width" : 3.5,
"Petal_Length" : 1.4,
"Petal_Width" : 0.2
}

results = evaluator.evaluate(arguments)
print(results)
try:
results = evaluator.evaluate(arguments)
print(results)
except JavaError as je:
pass
```

Evaluating a collection of data records:
Expand All @@ -132,8 +137,23 @@ import pandas

arguments_df = pandas.read_csv("Iris.csv", sep = ",")

results_df = evaluator.evaluateAll(arguments_df)
results_df = evaluator.evaluateAll(arguments_df, error_col = "errors")
print(results_df)

# The error column is added to the results DataFrame only if the evaluation of at least one row failed
errors = results_df["errors"] if "errors" in results_df.columns else None
if errors is not None:
pass
```

Alternatively, getting the results DataFrame and errors Series as separate objects:

```python
results_df, errors = evaluator.evaluateAll(arguments_df, error_col = None)
print(results_df)

if errors is not None:
pass
```

# License #
Expand Down
16 changes: 13 additions & 3 deletions jpmml_evaluator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pickle

from abc import abstractmethod, abstractclassmethod, ABC
from pandas import DataFrame
from pandas import DataFrame, Series
from pathlib import Path

import numpy
Expand Down Expand Up @@ -163,7 +163,7 @@ def evaluate(self, arguments, nan_as_missing = True):
results = self.backend.loads(results)
return results

def evaluateAll(self, arguments_df, nan_as_missing = True):
def evaluateAll(self, arguments_df, nan_as_missing = True, error_col = "errors"):
arguments_df = _canonicalizeAll(arguments_df, nan_as_missing = nan_as_missing)
columns = arguments_df.columns.tolist()
data = []
Expand All @@ -182,12 +182,22 @@ def evaluateAll(self, arguments_df, nan_as_missing = True):
results_dict = self.backend.loads(results)
columns = results_dict["columns"]
data = results_dict["data"]
errors = results_dict["errors"]
results_df = DataFrame()
for idx, column in enumerate(columns):
results_df[column] = data[idx]
if len(arguments_df) == len(results_df):
results_df.index = arguments_df.index.copy()
return results_df
if errors is not None:
errors = Series(errors, name = error_col, dtype = str)
if len(arguments_df) == len(errors):
errors.index = arguments_df.index.copy()
if error_col:
if errors is not None:
results_df[error_col] = errors
return results_df
else:
return (results_df, errors)

def predict(self, X):
return self.evaluateAll(X)
Expand Down
Binary file modified jpmml_evaluator/resources/jpmml-evaluator-python-1.4-SNAPSHOT.jar
Binary file not shown.
30 changes: 28 additions & 2 deletions jpmml_evaluator/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def workflow(self, backend, lax):
print(arguments_df.head(5))

results_df = evaluator.evaluateAll(arguments_df)
print(results_df.head(5))
#print(results_df.head(5))

self.assertEqual((150, 5), results_df.shape)
self.assertEqual(arguments_df.index.tolist(), results_df.index.tolist())
Expand All @@ -170,7 +170,7 @@ def workflow(self, backend, lax):

evaluator.suppressResultFields([reportOutputField])

results_df = evaluator.evaluateAll(arguments_df)
results_df, errors = evaluator.evaluateAll(arguments_df, error_col = None)

self.assertEqual((150, 4), results_df.shape)
self.assertEqual(arguments_df.index.tolist(), results_df.index.tolist())
Expand All @@ -187,4 +187,30 @@ def workflow(self, backend, lax):
probabilityOutputFieldNames = [probabilityOutputField.getName() for probabilityOutputField in probabilityOutputFields]
self.assertTrue(numpy.allclose(expected_results_df[probabilityOutputFieldNames], results_df[probabilityOutputFieldNames], rtol = 1e-13, atol = 1e-13))

self.assertIsNone(errors)

arguments_df.iloc[13, :] = "error"

results_df = evaluator.evaluateAll(arguments_df)

self.assertEqual((150, 5), results_df.shape)
self.assertEqual(arguments_df.index.tolist(), results_df.index.tolist())

self.assertEqual(1, results_df["errors"].count())
self.assertEqual(None, results_df["Species"][13])
self.assertEqual("org.jpmml.evaluator.ValueCheckException: Field \"Petal.Length\" cannot accept invalid value \"error\"", results_df["errors"][13])

results_df, errors = evaluator.evaluateAll(arguments_df, error_col = None)

self.assertEqual((150, 4), results_df.shape)
self.assertEqual(arguments_df.index.tolist(), results_df.index.tolist())

self.assertEqual(None, results_df["Species"][13])

self.assertEqual((150,), errors.shape)
self.assertEqual(arguments_df.index.tolist(), errors.index.tolist())

self.assertEqual(1, errors.count())
self.assertEqual("org.jpmml.evaluator.ValueCheckException: Field \"Petal.Length\" cannot accept invalid value \"error\"", errors[13])

evaluator.suppressResultFields(None)
33 changes: 31 additions & 2 deletions src/main/java/org/jpmml/evaluator/python/PythonUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import net.razorvine.pickle.Pickler;
import net.razorvine.pickle.Unpickler;
Expand Down Expand Up @@ -128,9 +129,13 @@ public Object put(String key, Object value){

resultsWriter.next();

Map<String, ?> results = evaluator.evaluate(arguments);
try {
Map<String, ?> results = evaluator.evaluate(arguments);

resultsWriter.putAll(results);
resultsWriter.putAll(results);
} catch(Exception e){
resultsWriter.put(e);
}
}

resultsTable.canonicalize();
Expand Down Expand Up @@ -240,13 +245,37 @@ private Table parseDict(Map<String, ?> dict){
data.add(values);
}

List<Exception> exceptions = table.getExceptions();
List<String> errors = null;

if(containsNonNull(exceptions)){
errors = exceptions.stream()
.map(exception -> (exception != null ? exception.toString() : null))
.collect(Collectors.toList());
}

Map<String, List<?>> result = new HashMap<>();
result.put("columns", columns);
result.put("data", data);
result.put("errors", errors);

return result;
}

/**
 * Tells whether the list holds at least one element that is not {@code null}.
 *
 * @param values The list to scan; individual elements may be {@code null}.
 * @return {@code true} as soon as a non-null element is encountered,
 *  {@code false} otherwise (including for an empty list).
 */
static
private <E> boolean containsNonNull(List<E> values){

	for(E element : values){

		// The first non-null element settles the answer; no need to scan further
		if(element != null){
			return true;
		}
	}

	return false;
}

static
private Object unpickle(byte[] bytes) throws IOException {
Unpickler unpickler = new Unpickler();
Expand Down

0 comments on commit f30da79

Please sign in to comment.