diff --git a/src/main/java/org/jpmml/sparkml/ModelConverter.java b/src/main/java/org/jpmml/sparkml/ModelConverter.java index 3f1e35c3..c23f9767 100644 --- a/src/main/java/org/jpmml/sparkml/ModelConverter.java +++ b/src/main/java/org/jpmml/sparkml/ModelConverter.java @@ -156,6 +156,8 @@ public Schema encodeSchema(SparkMLEncoder encoder){ Schema result = new Schema(label, features); + SchemaUtil.checkSchema(result); + return result; } diff --git a/src/main/java/org/jpmml/sparkml/SchemaUtil.java b/src/main/java/org/jpmml/sparkml/SchemaUtil.java index b89f53a3..8b00871e 100644 --- a/src/main/java/org/jpmml/sparkml/SchemaUtil.java +++ b/src/main/java/org/jpmml/sparkml/SchemaUtil.java @@ -19,15 +19,34 @@ package org.jpmml.sparkml; import java.util.List; +import java.util.Objects; import org.jpmml.converter.CategoricalLabel; import org.jpmml.converter.Feature; +import org.jpmml.converter.Label; +import org.jpmml.converter.Schema; public class SchemaUtil { private SchemaUtil(){ } + static + public void checkSchema(Schema schema){ + Label label = schema.getLabel(); + List features = schema.getFeatures(); + + if(label != null){ + + for(Feature feature : features){ + + if(Objects.equals(label.getName(), feature.getName())){ + throw new IllegalArgumentException("Label column '" + label.getName() + "' is contained in the list of feature columns"); + } + } + } + } + static public void checkSize(int size, CategoricalLabel categoricalLabel){ diff --git a/src/main/java/org/jpmml/sparkml/SparkMLEncoder.java b/src/main/java/org/jpmml/sparkml/SparkMLEncoder.java index 99bef8fe..fb8b9cbe 100644 --- a/src/main/java/org/jpmml/sparkml/SparkMLEncoder.java +++ b/src/main/java/org/jpmml/sparkml/SparkMLEncoder.java @@ -59,21 +59,6 @@ public boolean hasFeatures(String column){ return this.columnFeatures.containsKey(column); } - public List getSchemaFeatures(){ - StructType schema = getSchema(); - - List result = new ArrayList<>(); - - StructField[] fields = schema.fields(); - for(StructField field : fields){ - Feature feature = getOnlyFeature(field.name()); - - result.add(feature); - } - - return result; - } - public Feature getOnlyFeature(String column){ List features = getFeatures(column); @@ -149,7 +134,7 @@ public void putFeatures(String column, List features){ Feature feature = features.get(i); if(!(feature.getName()).equals(existingFeature.getName())){ - throw new IllegalArgumentException("Expected '" + existingFeature.getName() + "' feature, got '" + feature.getName() + "' feature"); + throw new IllegalArgumentException("Expected feature column '" + existingFeature.getName() + "', got feature column '" + feature.getName() + "'"); } } }