Skip to content

Commit

Permalink
Checking the sanity of supervised learning model schemas. Fixes #47
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Feb 20, 2019
1 parent 4c2a4ae commit 81ffd6f
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 16 deletions.
2 changes: 2 additions & 0 deletions src/main/java/org/jpmml/sparkml/ModelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ public Schema encodeSchema(SparkMLEncoder encoder){

Schema result = new Schema(label, features);

SchemaUtil.checkSchema(result);

return result;
}

Expand Down
19 changes: 19 additions & 0 deletions src/main/java/org/jpmml/sparkml/SchemaUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,34 @@
package org.jpmml.sparkml;

import java.util.List;
import java.util.Objects;

import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.Schema;

public class SchemaUtil {

private SchemaUtil(){
}

static
public void checkSchema(Schema schema){
Label label = schema.getLabel();
List<? extends Feature> features = schema.getFeatures();

if(label != null){

for(Feature feature : features){

if(Objects.equals(label.getName(), feature.getName())){
throw new IllegalArgumentException("Label column '" + label.getName() + "' is contained in the list of feature columns");
}
}
}
}

static
public void checkSize(int size, CategoricalLabel categoricalLabel){

Expand Down
17 changes: 1 addition & 16 deletions src/main/java/org/jpmml/sparkml/SparkMLEncoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,6 @@ public boolean hasFeatures(String column){
return this.columnFeatures.containsKey(column);
}

public List<Feature> getSchemaFeatures(){
StructType schema = getSchema();

List<Feature> result = new ArrayList<>();

StructField[] fields = schema.fields();
for(StructField field : fields){
Feature feature = getOnlyFeature(field.name());

result.add(feature);
}

return result;
}

public Feature getOnlyFeature(String column){
List<Feature> features = getFeatures(column);

Expand Down Expand Up @@ -149,7 +134,7 @@ public void putFeatures(String column, List<Feature> features){
Feature feature = features.get(i);

if(!(feature.getName()).equals(existingFeature.getName())){
throw new IllegalArgumentException("Expected '" + existingFeature.getName() + "' feature, got '" + feature.getName() + "' feature");
throw new IllegalArgumentException("Expected feature column '" + existingFeature.getName() + "', got feature column '" + feature.getName() + "'");
}
}
}
Expand Down

0 comments on commit 81ffd6f

Please sign in to comment.