Skip to content

Commit

Permalink
Added support for the 'FPGrowth' model type. Fixes #50
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Mar 28, 2021
1 parent 6fd5dd7 commit 6ec0d18
Show file tree
Hide file tree
Showing 14 changed files with 3,139 additions and 5 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ Java library and command-line application for converting Apache Spark ML pipelin
* [`classification.NaiveBayesModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/classification/NaiveBayesModel.html)
* [`classification.RandomForestClassificationModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/classification/RandomForestClassificationModel.html)
* [`clustering.KMeansModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/clustering/KMeansModel.html)
* [`fpm.FPGrowthModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/fpm/FPGrowthModel.html)
* [`regression.DecisionTreeRegressionModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/regression/DecisionTreeRegressionModel.html)
* [`regression.GBTRegressionModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/regression/GBTRegressionModel.html)
* [`regression.GeneralizedLinearRegressionModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/regression/GeneralizedLinearRegressionModel.html)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright (c) 2021 Villu Ruusmann
*
* This file is part of JPMML-SparkML
*
* JPMML-SparkML is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-SparkML is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-SparkML. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.sparkml;

import org.apache.spark.ml.Model;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.dmg.pmml.MiningFunction;

abstract
public class AssociationRulesModelConverter<T extends Model<T> & HasPredictionCol> extends ModelConverter<T> {

public AssociationRulesModelConverter(T model){
super(model);
}

@Override
public MiningFunction getMiningFunction(){
return MiningFunction.ASSOCIATION_RULES;
}
}
35 changes: 35 additions & 0 deletions src/main/java/org/jpmml/sparkml/ItemSetFeature.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) 2021 Villu Ruusmann
*
* This file is part of JPMML-SparkML
*
* JPMML-SparkML is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-SparkML is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-SparkML. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.sparkml;

import org.dmg.pmml.Field;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;

public class ItemSetFeature extends Feature {

public ItemSetFeature(SparkMLEncoder encoder, Field<?> field){
super(encoder, field.getName(), field.getDataType());
}

@Override
public ContinuousFeature toContinuousFeature(){
throw new UnsupportedOperationException();
}
}
24 changes: 24 additions & 0 deletions src/main/java/org/jpmml/sparkml/SparkMLEncoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,19 @@
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Visitor;
import org.dmg.pmml.VisitorAction;
import org.dmg.pmml.association.Item;
import org.jpmml.converter.BooleanFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.WildcardFeature;
import org.jpmml.model.visitors.AbstractVisitor;

public class SparkMLEncoder extends ModelEncoder {

Expand All @@ -56,6 +62,24 @@ public SparkMLEncoder(StructType schema, ConverterFactory converterFactory){
setConverterFactory(converterFactory);
}

@Override
public PMML encodePMML(Model model){
PMML pmml = super.encodePMML(model);

Visitor visitor = new AbstractVisitor(){

@Override
public VisitorAction visit(Item item){
item.setField(null);

return super.visit(item);
}
};
visitor.applyTo(pmml);

return pmml;
}

public boolean hasFeatures(String column){
return this.columnFeatures.containsKey(column);
}
Expand Down
7 changes: 5 additions & 2 deletions src/main/java/org/jpmml/sparkml/ZipUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,11 @@ public void uncompress(ZipFile zipFile, File dir) throws IOException {
File file = new File(dir, entry.getName());

File parentDir = file.getParentFile();
if(!parentDir.mkdirs()){
throw new IOException();
if(!parentDir.exists()){

if(!parentDir.mkdirs()){
throw new IOException(parentDir.getAbsolutePath());
}
}

try(OutputStream os = new FileOutputStream(file)){
Expand Down
172 changes: 172 additions & 0 deletions src/main/java/org/jpmml/sparkml/model/FPGrowthModelConverter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/*
* Copyright (c) 2021 Villu Ruusmann
*
* This file is part of JPMML-SparkML
*
* JPMML-SparkML is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-SparkML is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-SparkML. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.sparkml.model;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.apache.spark.ml.fpm.FPGrowthModel;
import org.apache.spark.sql.Row;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.OpType;
import org.dmg.pmml.association.AssociationModel;
import org.dmg.pmml.association.AssociationRule;
import org.dmg.pmml.association.Item;
import org.dmg.pmml.association.ItemRef;
import org.dmg.pmml.association.Itemset;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.sparkml.AssociationRulesModelConverter;
import org.jpmml.sparkml.ItemSetFeature;
import org.jpmml.sparkml.SparkMLEncoder;
import scala.collection.JavaConversions;
import scala.collection.Seq;

public class FPGrowthModelConverter extends AssociationRulesModelConverter<FPGrowthModel> {

public FPGrowthModelConverter(FPGrowthModel model){
super(model);
}

@Override
public List<Feature> getFeatures(SparkMLEncoder encoder){
FPGrowthModel model = getTransformer();

String itemsCol = model.getItemsCol();

// Convert from plural to singular
if(itemsCol.endsWith("s")){
itemsCol = itemsCol.substring(0, itemsCol.length() - 1);
}

DataField dataField = encoder.createDataField(FieldName.create(itemsCol), OpType.CATEGORICAL, DataType.STRING);

Feature feature = new ItemSetFeature(encoder, dataField);

return Collections.singletonList(feature);
}

@Override
public AssociationModel encodeModel(Schema schema){
FPGrowthModel model = getTransformer();

List<? extends Feature> features = schema.getFeatures();

SchemaUtil.checkSize(1, features);

Feature feature = features.get(0);

Map<String, Item> items = new LinkedHashMap<>();
Map<List<String>, Itemset> itemsets = new LinkedHashMap<>();

List<AssociationRule> associationRules = new ArrayList<>();

List<Row> associationRuleRows = (model.associationRules()).collectAsList();
for(Row associationRuleRow : associationRuleRows){
List<String> antecedent = JavaConversions.seqAsJavaList((Seq)associationRuleRow.apply(0));
List<String> consequent = JavaConversions.seqAsJavaList((Seq)associationRuleRow.apply(1));

Double confidence = (Double)associationRuleRow.apply(2);

// XXX
Double lift = 0d;
Double support = 0d;

Itemset antecedentItemset = ensureItemset(feature, antecedent, itemsets, items);
Itemset consequentItemset = ensureItemset(feature, consequent, itemsets, items);

AssociationRule associationRule = new AssociationRule()
.setAntecedent(antecedentItemset.getId())
.setConsequent(consequentItemset.getId());

associationRule = associationRule
.setConfidence(confidence)
.setLift(lift)
.setSupport(support);

associationRules.add(associationRule);
}

// XXX
int numberOfTransactions = 0;

MiningSchema miningSchema = new MiningSchema();

AssociationModel associationModel = new AssociationModel(MiningFunction.ASSOCIATION_RULES, numberOfTransactions, model.getMinSupport(), model.getMinConfidence(), items.size(), itemsets.size(), associationRules.size(), miningSchema)
.setScorable(Boolean.FALSE);

(associationModel.getItems()).addAll(items.values());
(associationModel.getItemsets()).addAll(itemsets.values());
(associationModel.getAssociationRules()).addAll(associationRules);

return associationModel;
}

static
private Itemset ensureItemset(Feature feature, List<String> values, Map<List<String>, Itemset> itemsets, Map<String, Item> items){
Itemset itemset = itemsets.get(values);

if(itemset == null){
itemset = new Itemset(String.valueOf(itemsets.size() + 1));

for(String value : values){
Item item = items.get(value);

if(item == null){
item = new Item(String.valueOf(items.size() + 1), value)
// XXX: See SparkMLEncoder#encodePMML(Model)
.setField(feature.getName());

items.put(value, item);
}

itemset.addItemRefs(new ItemRef(item.getId()));
}

List<ItemRef> itemRefs = itemset.getItemRefs();
if(itemRefs.size() > 1){
Comparator<ItemRef> comparator = new Comparator<ItemRef>(){

@Override
public int compare(ItemRef left, ItemRef right){
int leftId = Integer.parseInt(left.getItemRef());
int rightId = Integer.parseInt(right.getItemRef());

return Integer.compare(leftId, rightId);
}
};

Collections.sort(itemRefs, comparator);
}

itemsets.put(values, itemset);
}

return itemset;
}
}
1 change: 1 addition & 0 deletions src/main/resources/META-INF/sparkml2pmml.properties
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel = org
org.apache.spark.ml.classification.NaiveBayesModel = org.jpmml.sparkml.model.NaiveBayesModelConverter
org.apache.spark.ml.classification.RandomForestClassificationModel = org.jpmml.sparkml.model.RandomForestClassificationModelConverter
org.apache.spark.ml.clustering.KMeansModel = org.jpmml.sparkml.model.KMeansModelConverter
org.apache.spark.ml.fpm.FPGrowthModel = org.jpmml.sparkml.model.FPGrowthModelConverter
org.apache.spark.ml.regression.DecisionTreeRegressionModel = org.jpmml.sparkml.model.DecisionTreeRegressionModelConverter
org.apache.spark.ml.regression.GBTRegressionModel = org.jpmml.sparkml.model.GBTRegressionModelConverter
org.apache.spark.ml.regression.GeneralizedLinearRegressionModel = org.jpmml.sparkml.model.GeneralizedLinearRegressionModelConverter
Expand Down
1 change: 1 addition & 0 deletions src/test/java/org/jpmml/sparkml/Algorithms.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
interface Algorithms {

String DECISION_TREE = "DecisionTree";
String FP_GROWTH = "FPGrowth";
String GBT = "GBT";
String GLM = "GLM";
String K_MEANS = "KMeans";
Expand Down
48 changes: 48 additions & 0 deletions src/test/java/org/jpmml/sparkml/AssociationRulesTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright (c) 2021 Villu Ruusmann
*
* This file is part of JPMML-SparkML
*
* JPMML-SparkML is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-SparkML is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-SparkML. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.sparkml;

import java.util.function.Predicate;

import com.google.common.base.Equivalence;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.dmg.pmml.association.AssociationModel;
import org.jpmml.evaluator.ResultField;
import org.junit.Test;
import org.spark_project.guava.collect.Iterables;

import static org.junit.Assert.assertTrue;

public class AssociationRulesTest extends SparkMLTest implements Algorithms, Datasets {

@Test
public void evaluateFPGrowthShopping() throws Exception {
Predicate<ResultField> predicate = (resultField -> true);
Equivalence<Object> equivalence = getEquivalence();

try(SparkMLTestBatch batch = (SparkMLTestBatch)createBatch(FP_GROWTH, SHOPPING, predicate, equivalence)){
PMML pmml = batch.getPMML();

Model model = Iterables.getOnlyElement(pmml.getModels());

assertTrue(model instanceof AssociationModel);
}
}
}
1 change: 1 addition & 0 deletions src/test/java/org/jpmml/sparkml/Datasets.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ interface Datasets {
String HOUSING = "Housing";
String IRIS = "Iris";
String SENTIMENT = "Sentiment";
String SHOPPING = "Shopping";
String VISIT = "Visit";

FieldName AUDIT_PROBABILITY_TRUE = FieldNameUtil.create("probability", 1);
Expand Down
Loading

0 comments on commit 6ec0d18

Please sign in to comment.