Merged version 1.1.23

vruusmann committed Feb 20, 2019
2 parents c1e3f54 + d4c5685 commit e061e7c
Showing 21 changed files with 206 additions and 139 deletions.
6 changes: 3 additions & 3 deletions pom.xml
@@ -77,7 +77,7 @@
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>jpmml-converter</artifactId>
-<version>1.3.4</version>
+<version>1.3.5</version>
</dependency>

<dependency>
@@ -90,13 +90,13 @@
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator</artifactId>
-<version>1.4.4</version>
+<version>1.4.7</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator-test</artifactId>
-<version>1.4.4</version>
+<version>1.4.7</version>
<scope>test</scope>
</dependency>
</dependencies>
15 changes: 15 additions & 0 deletions src/main/java/org/jpmml/sparkml/MatrixUtil.java
@@ -28,6 +28,21 @@ public class MatrixUtil {
private MatrixUtil(){
}

static
public void checkColumns(int columns, Matrix matrix){

if(matrix.numCols() != columns){
throw new IllegalArgumentException("Expected " + columns + " column(s), got " + matrix.numCols() + " column(s)");
}
}

static
public void checkRows(int rows, Matrix matrix){

if(matrix.numRows() != rows){
throw new IllegalArgumentException("Expected " + rows + " row(s), got " + matrix.numRows() + " row(s)");
}
}

static
public List<Double> getRow(Matrix matrix, int row){
List<Double> result = new ArrayList<>();
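As a quick aside (not part of this commit), the new MatrixUtil checks could be exercised against a Spark matrix roughly as follows; the class name MatrixUtilExample and the 2 x 3 matrix values are made up for illustration.

import org.apache.spark.ml.linalg.Matrices;
import org.apache.spark.ml.linalg.Matrix;
import org.jpmml.sparkml.MatrixUtil;

public class MatrixUtilExample {

	public static void main(String[] args){
		// A 2 x 3 dense matrix, specified in column-major order
		Matrix matrix = Matrices.dense(2, 3, new double[]{1d, 2d, 3d, 4d, 5d, 6d});

		// Both checks pass, because the dimensions match
		MatrixUtil.checkRows(2, matrix);
		MatrixUtil.checkColumns(3, matrix);

		// This would throw IllegalArgumentException("Expected 4 column(s), got 3 column(s)"):
		// MatrixUtil.checkColumns(4, matrix);
	}
}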
55 changes: 19 additions & 36 deletions src/main/java/org/jpmml/sparkml/ModelConverter.java
@@ -34,9 +34,6 @@
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.mining.MiningModel;
-import org.dmg.pmml.mining.Segment;
-import org.dmg.pmml.mining.Segmentation;
-import org.dmg.pmml.mining.Segmentation.MultipleModelMethod;
import org.jpmml.converter.BooleanFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
@@ -46,6 +43,7 @@
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
+import org.jpmml.converter.mining.MiningModelUtil;

abstract
public class ModelConverter<T extends Model<T> & HasFeaturesCol & HasPredictionCol> extends TransformerConverter<T> {
@@ -136,12 +134,11 @@ public Schema encodeSchema(SparkMLEncoder encoder){
if(model instanceof ClassificationModel){
ClassificationModel<?, ?> classificationModel = (ClassificationModel<?, ?>)model;

+int numClasses = classificationModel.numClasses();
+
CategoricalLabel categoricalLabel = (CategoricalLabel)label;

-int numClasses = classificationModel.numClasses();
-if(numClasses != categoricalLabel.size()){
-throw new IllegalArgumentException("Expected " + numClasses + " target categories, got " + categoricalLabel.size() + " target categories");
-}
+SchemaUtil.checkSize(numClasses, categoricalLabel);
}

String featuresCol = model.getFeaturesCol();
@@ -152,13 +149,15 @@ public Schema encodeSchema(SparkMLEncoder encoder){
PredictionModel<?, ?> predictionModel = (PredictionModel<?, ?>)model;

int numFeatures = predictionModel.numFeatures();
-if(numFeatures != -1 && features.size() != numFeatures){
-throw new IllegalArgumentException("Expected " + numFeatures + " features, got " + features.size() + " features");
+if(numFeatures != -1){
+SchemaUtil.checkSize(numFeatures, features);
}
}

Schema result = new Schema(label, features);

+SchemaUtil.checkSchema(result);

return result;
}

@@ -175,39 +174,23 @@ public org.dmg.pmml.Model registerModel(SparkMLEncoder encoder){

List<OutputField> sparkOutputFields = registerOutputFields(label, encoder);
if(sparkOutputFields != null && sparkOutputFields.size() > 0){
-org.dmg.pmml.Model lastModel = getLastModel(model);

-Output output = ModelUtil.ensureOutput(lastModel);

-List<OutputField> outputFields = output.getOutputFields();

-outputFields.addAll(0, sparkOutputFields);
-}
+Output output;

-return model;
-}

-protected org.dmg.pmml.Model getLastModel(org.dmg.pmml.Model model){
+if(model instanceof MiningModel){
+MiningModel miningModel = (MiningModel)model;

-if(model instanceof MiningModel){
-MiningModel miningModel = (MiningModel)model;
+org.dmg.pmml.Model finalModel = MiningModelUtil.getFinalModel(miningModel);

-Segmentation segmentation = miningModel.getSegmentation();
+output = ModelUtil.ensureOutput(finalModel);
+} else

-MultipleModelMethod multipleModelMethod = segmentation.getMultipleModelMethod();
-switch(multipleModelMethod){
-case MODEL_CHAIN:
-List<Segment> segments = segmentation.getSegments();
+{
+output = ModelUtil.ensureOutput(model);
+}

-if(segments.size() > 0){
-Segment lastSegment = segments.get(segments.size() - 1);
+List<OutputField> outputFields = output.getOutputFields();

-return lastSegment.getModel();
-}
-break;
-default:
-break;
-}
+outputFields.addAll(0, sparkOutputFields);
}

return model;
65 changes: 65 additions & 0 deletions src/main/java/org/jpmml/sparkml/SchemaUtil.java
@@ -0,0 +1,65 @@
/*
* Copyright (c) 2019 Villu Ruusmann
*
* This file is part of JPMML-SparkML
*
* JPMML-SparkML is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-SparkML is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-SparkML. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.sparkml;

import java.util.List;
import java.util.Objects;

import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.Schema;

public class SchemaUtil {

private SchemaUtil(){
}

static
public void checkSchema(Schema schema){
Label label = schema.getLabel();
List<? extends Feature> features = schema.getFeatures();

if(label != null){

for(Feature feature : features){

if(Objects.equals(label.getName(), feature.getName())){
throw new IllegalArgumentException("Label column '" + label.getName() + "' is contained in the list of feature columns");
}
}
}
}

static
public void checkSize(int size, CategoricalLabel categoricalLabel){

if(categoricalLabel.size() != size){
throw new IllegalArgumentException("Expected " + size + " target categories, got " + categoricalLabel.size() + " target categories");
}
}

static
public void checkSize(int size, List<? extends Feature> features){

if(features.size() != size){
throw new IllegalArgumentException("Expected " + size + " feature(s), got " + features.size() + " feature(s)");
}
}
}
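For orientation (again not part of the commit), the new SchemaUtil.checkSize(int, CategoricalLabel) overload could be exercised along these lines; the label name and categories are invented, and the CategoricalLabel(FieldName, DataType, List) constructor from JPMML-Converter is assumed.

import java.util.Arrays;

import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.sparkml.SchemaUtil;

public class SchemaUtilExample {

	public static void main(String[] args){
		// A binary target with two (made-up) categories
		CategoricalLabel label = new CategoricalLabel(FieldName.create("y"), DataType.STRING, Arrays.asList("no", "yes"));

		// Passes: the label declares exactly two target categories
		SchemaUtil.checkSize(2, label);

		// This would throw IllegalArgumentException("Expected 3 target categories, got 2 target categories"):
		// SchemaUtil.checkSize(3, label);
	}
}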
19 changes: 2 additions & 17 deletions src/main/java/org/jpmml/sparkml/SparkMLEncoder.java
@@ -59,21 +59,6 @@ public boolean hasFeatures(String column){
return this.columnFeatures.containsKey(column);
}

-public List<Feature> getSchemaFeatures(){
-StructType schema = getSchema();
-
-List<Feature> result = new ArrayList<>();
-
-StructField[] fields = schema.fields();
-for(StructField field : fields){
-Feature feature = getOnlyFeature(field.name());
-
-result.add(feature);
-}
-
-return result;
-}

public Feature getOnlyFeature(String column){
List<Feature> features = getFeatures(column);

@@ -141,15 +126,15 @@ public void putFeatures(String column, List<Feature> features){
if(existingFeatures != null && existingFeatures.size() > 0){

if(features.size() != existingFeatures.size()){
-throw new IllegalArgumentException("Expected " + existingFeatures.size() + " features, got " + features.size() + " features");
+throw new IllegalArgumentException("Expected " + existingFeatures.size() + " feature(s), got " + features.size() + " feature(s)");
}

for(int i = 0; i < existingFeatures.size(); i++){
Feature existingFeature = existingFeatures.get(i);
Feature feature = features.get(i);

if(!(feature.getName()).equals(existingFeature.getName())){
-throw new IllegalArgumentException();
+throw new IllegalArgumentException("Expected feature column '" + existingFeature.getName() + "', got feature column '" + feature.getName() + "'");
}
}
}
11 changes: 11 additions & 0 deletions src/main/java/org/jpmml/sparkml/VectorUtil.java
@@ -29,6 +29,17 @@ public class VectorUtil {
private VectorUtil(){
}

static
public void checkSize(int size, Vector... vectors){

for(Vector vector : vectors){

if(vector.size() != size){
throw new IllegalArgumentException("Expected " + size + " element(s), got " + vector.size() + " element(s)");
}
}
}

static
public List<Double> toList(Vector vector){
DenseVector denseVector = vector.toDense();
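As with the matrix helper above, a minimal (illustrative, not from this commit) exercise of VectorUtil.checkSize against a Spark vector might look like this; the vector values are made up.

import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;
import org.jpmml.sparkml.VectorUtil;

public class VectorUtilExample {

	public static void main(String[] args){
		Vector idf = Vectors.dense(0.25d, 0.50d, 0.75d);

		// Passes: the vector holds three elements
		VectorUtil.checkSize(3, idf);

		// This would throw IllegalArgumentException("Expected 4 element(s), got 3 element(s)"):
		// VectorUtil.checkSize(4, idf);
	}
}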
@@ -114,7 +114,7 @@ public List<Feature> encodeFeatures(SparkMLEncoder encoder){
String term = vocabulary[i];

if(TermUtil.hasPunctuation(term)){
-throw new IllegalArgumentException(term);
+throw new IllegalArgumentException("Punctuated vocabulary terms (" + term + ") are not supported");
}

result.add(new TermFeature(encoder, defineFunction, documentFeature, term));
@@ -29,6 +29,7 @@
import org.jpmml.sparkml.FeatureConverter;
import org.jpmml.sparkml.SparkMLEncoder;
import org.jpmml.sparkml.TermFeature;
+import org.jpmml.sparkml.VectorUtil;
import org.jpmml.sparkml.WeightedTermFeature;

public class IDFModelConverter extends FeatureConverter<IDFModel> {
@@ -41,12 +42,11 @@ public IDFModelConverter(IDFModel transformer){
public List<Feature> encodeFeatures(SparkMLEncoder encoder){
IDFModel transformer = getTransformer();

+Vector idf = transformer.idf();
+
List<Feature> features = encoder.getFeatures(transformer.getInputCol());

-Vector idf = transformer.idf();
-if(idf.size() != features.size()){
-throw new IllegalArgumentException();
-}
+VectorUtil.checkSize(features.size(), idf);

List<Feature> result = new ArrayList<>();

@@ -33,6 +33,7 @@
import org.jpmml.converter.ValueUtil;
import org.jpmml.sparkml.FeatureConverter;
import org.jpmml.sparkml.SparkMLEncoder;
+import org.jpmml.sparkml.VectorUtil;

public class MaxAbsScalerModelConverter extends FeatureConverter<MaxAbsScalerModel> {

@@ -44,12 +45,11 @@ public MaxAbsScalerModelConverter(MaxAbsScalerModel transformer){
public List<Feature> encodeFeatures(SparkMLEncoder encoder){
MaxAbsScalerModel transformer = getTransformer();

+Vector maxAbs = transformer.maxAbs();
+
List<Feature> features = encoder.getFeatures(transformer.getInputCol());

-Vector maxAbs = transformer.maxAbs();
-if(maxAbs.size() != features.size()){
-throw new IllegalArgumentException();
-}
+VectorUtil.checkSize(features.size(), maxAbs);

List<Feature> result = new ArrayList<>();

@@ -33,6 +33,7 @@
import org.jpmml.converter.ValueUtil;
import org.jpmml.sparkml.FeatureConverter;
import org.jpmml.sparkml.SparkMLEncoder;
+import org.jpmml.sparkml.VectorUtil;

public class MinMaxScalerModelConverter extends FeatureConverter<MinMaxScalerModel> {

@@ -47,17 +48,12 @@ public List<Feature> encodeFeatures(SparkMLEncoder encoder){
double rescaleFactor = (transformer.getMax() - transformer.getMin());
double rescaleConstant = transformer.getMin();

-List<Feature> features = encoder.getFeatures(transformer.getInputCol());

Vector originalMax = transformer.originalMax();
-if(originalMax.size() != features.size()){
-throw new IllegalArgumentException();
-}

Vector originalMin = transformer.originalMin();
-if(originalMin.size() != features.size()){
-throw new IllegalArgumentException();
-}

+List<Feature> features = encoder.getFeatures(transformer.getInputCol());

+VectorUtil.checkSize(features.size(), originalMax, originalMin);

List<Feature> result = new ArrayList<>();

@@ -33,6 +33,7 @@
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.sparkml.FeatureConverter;
+import org.jpmml.sparkml.MatrixUtil;
import org.jpmml.sparkml.SparkMLEncoder;

public class PCAModelConverter extends FeatureConverter<PCAModel> {
@@ -45,12 +46,11 @@ public PCAModelConverter(PCAModel transformer){
public List<Feature> encodeFeatures(SparkMLEncoder encoder){
PCAModel transformer = getTransformer();

+DenseMatrix pc = transformer.pc();
+
List<Feature> features = encoder.getFeatures(transformer.getInputCol());

-DenseMatrix pc = transformer.pc();
-if(pc.numRows() != features.size()){
-throw new IllegalArgumentException();
-}
+MatrixUtil.checkRows(features.size(), pc);

List<Feature> result = new ArrayList<>();
