Skip to content

Commit

Permalink
Merge pull request #100 from sanity/staticBuilders
Browse files Browse the repository at this point in the history
Static builders
  • Loading branch information
athawk81 committed Mar 7, 2015
2 parents 0492423 + f14c6c9 commit da04812
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 19 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
4) *ANY* change to the master branch (ie. when a feature branch is merged) must
be accompanied by a bump in version number, regardless of how minor the change.
-->
<version>0.6.0</version>
<version>0.6.1</version>

<repositories>
<repository>
Expand Down
17 changes: 14 additions & 3 deletions src/main/java/quickml/supervised/Utils.java
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
package quickml.supervised;

import com.google.common.collect.Lists;
import org.joda.time.DateTime;
import quickml.data.Instance;
import quickml.data.PredictionMap;
import quickml.supervised.crossValidation.PredictionMapResult;
import quickml.supervised.crossValidation.PredictionMapResults;
import quickml.data.ClassifierInstance;
import quickml.supervised.classifier.Classifier;
import quickml.supervised.crossValidation.lossfunctions.LabelPredictionWeight;
import quickml.supervised.crossValidation.utils.DateTimeExtractor;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.*;

/**
* Created by alexanderhawk on 7/31/14.
Expand Down Expand Up @@ -63,6 +63,17 @@ public static PredictionMapResults calcResultpredictionsWithoutAttrs(Classifier
return new PredictionMapResults(results);
}

/**
 * Sorts trainingData in place, oldest instance first, ordering by the
 * timestamp the supplied dateTimeExtractor produces for each instance.
 */
public static void sortTrainingInstancesByTime(List<ClassifierInstance> trainingData, final DateTimeExtractor<ClassifierInstance> dateTimeExtractor) {
    Comparator<ClassifierInstance> byExtractedTime = new Comparator<ClassifierInstance>() {
        @Override
        public int compare(ClassifierInstance left, ClassifierInstance right) {
            DateTime leftTime = dateTimeExtractor.extractDateTime(left);
            DateTime rightTime = dateTimeExtractor.extractDateTime(right);
            return leftTime.compareTo(rightTime);
        }
    };
    Collections.sort(trainingData, byExtractedTime);
}


}

101 changes: 101 additions & 0 deletions src/main/java/quickml/supervised/classifier/StaticBuilders.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package quickml.supervised.classifier;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.javatuples.Pair;
import org.joda.time.DateTime;
import org.joda.time.Hours;
import org.joda.time.Period;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickml.data.ClassifierInstance;
import quickml.data.OnespotDateTimeExtractor;
import quickml.supervised.Utils;
import quickml.supervised.classifier.decisionTree.TreeBuilder;
import quickml.supervised.classifier.decisionTree.scorers.GiniImpurityScorer;
import quickml.supervised.classifier.decisionTree.scorers.InformationGainScorer;
import quickml.supervised.classifier.downsampling.DownsamplingClassifier;
import quickml.supervised.classifier.downsampling.DownsamplingClassifierBuilder;
import quickml.supervised.classifier.randomForest.RandomForest;
import quickml.supervised.classifier.randomForest.RandomForestBuilder;
import quickml.supervised.crossValidation.ClassifierLossChecker;
import quickml.supervised.crossValidation.data.OutOfTimeData;
import quickml.supervised.crossValidation.lossfunctions.ClassifierLogCVLossFunction;
import quickml.supervised.crossValidation.lossfunctions.ClassifierLossFunction;
import quickml.supervised.crossValidation.lossfunctions.LossFunction;
import quickml.supervised.crossValidation.lossfunctions.WeightedAUCCrossValLossFunction;
import quickml.supervised.crossValidation.utils.DateTimeExtractor;
import quickml.supervised.predictiveModelOptimizer.FieldValueRecommender;
import quickml.supervised.predictiveModelOptimizer.PredictiveModelOptimizer;
import quickml.supervised.predictiveModelOptimizer.PredictiveModelOptimizerBuilder;
import quickml.supervised.predictiveModelOptimizer.fieldValueRecommenders.FixedOrderRecommender;
import quickml.supervised.predictiveModelOptimizer.fieldValueRecommenders.MonotonicConvergenceRecommender;

import java.io.Serializable;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

import static java.util.Arrays.asList;
import static quickml.supervised.classifier.decisionTree.TreeBuilder.*;
import static quickml.supervised.classifier.randomForest.RandomForestBuilder.NUM_TREES;

/**
* Created by alexanderhawk on 3/5/15.
*/
public class StaticBuilders {
    private static final Logger logger = LoggerFactory.getLogger(StaticBuilders.class);

    /**
     * Optimizes the hyperparameters of a downsampled random forest with out-of-time
     * cross validation, then trains a final model on ALL of trainingData using the
     * best configuration found.
     *
     * @param trainingData                labeled instances; NOTE: sorted chronologically in place as a side effect
     * @param rebuildsPerValidation       number of times the model is rebuilt with a new training set
     *                                    while estimating the loss of a particular set of hyperparameters
     * @param fractionOfDataForValidation fraction of the training data that out-of-time validation is
     *                                    performed on during parameter optimization; the final model
     *                                    returned by this method uses all data
     * @param lossFunction                loss function used to compare candidate configurations
     * @param dateTimeExtractor           supplies each instance's timestamp for the out-of-time split
     * @param config                      hyperparameter search space, keyed by builder property name
     * @return the best hyperparameter map paired with the classifier trained on all data
     */
    public static Pair<Map<String, Object>, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List<ClassifierInstance> trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor, Map<String, FieldValueRecommender> config) {
        int timeSliceHours = getTimeSliceHours(trainingData, rebuildsPerValidation, fractionOfDataForValidation, dateTimeExtractor);
        // Bug fix: the caller's fractionOfDataForValidation was previously ignored in
        // favor of a hard-coded 0.2 validation fraction.
        PredictiveModelOptimizer optimizer = new PredictiveModelOptimizerBuilder<Classifier, ClassifierInstance>()
                .modelBuilder(new RandomForestBuilder<>())
                .dataCycler(new OutOfTimeData<>(trainingData, fractionOfDataForValidation, timeSliceHours, dateTimeExtractor))
                .lossChecker(new ClassifierLossChecker<>(lossFunction))
                .valuesToTest(config)
                .iterations(3).build();
        Map<String, Object> bestParams = optimizer.determineOptimalConfig();

        // Rebuild on the full data set, applying the winning configuration on top of
        // these defaults via updateBuilderConfig.
        RandomForestBuilder<ClassifierInstance> randomForestBuilder = new RandomForestBuilder<>(new TreeBuilder<>().ignoreAttributeAtNodeProbability(0.7)).numTrees(24);
        DownsamplingClassifierBuilder<ClassifierInstance> downsamplingClassifierBuilder = new DownsamplingClassifierBuilder<>(randomForestBuilder, 0.1);
        downsamplingClassifierBuilder.updateBuilderConfig(bestParams);

        DownsamplingClassifier downsamplingClassifier = downsamplingClassifierBuilder.buildPredictiveModel(trainingData);
        return new Pair<Map<String, Object>, DownsamplingClassifier>(bestParams, downsamplingClassifier);
    }

    /**
     * Convenience overload that searches the default hyperparameter space produced by
     * {@link #createConfig()}.
     */
    public static Pair<Map<String, Object>, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List<ClassifierInstance> trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor) {
        Map<String, FieldValueRecommender> config = createConfig();
        return getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, lossFunction, dateTimeExtractor, config);
    }

    /**
     * Returns the width, in hours, of one out-of-time validation slice, chosen so the
     * validation window is split into roughly rebuildsPerValidation slices.
     * Sorts trainingData chronologically in place as a side effect.
     */
    private static int getTimeSliceHours(List<ClassifierInstance> trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, DateTimeExtractor<ClassifierInstance> dateTimeExtractor) {
        Utils.sortTrainingInstancesByTime(trainingData, dateTimeExtractor);
        DateTime latestDateTime = dateTimeExtractor.extractDateTime(trainingData.get(trainingData.size() - 1));
        // Derive the validation boundary from the caller's fraction (previously a
        // hard-coded 0.8 training fraction); clamp so the index stays in range.
        int indexOfEarliestValidationInstance = Math.max(0, (int) ((1.0 - fractionOfDataForValidation) * trainingData.size()) - 1);
        DateTime earliestValidationTime = dateTimeExtractor.extractDateTime(trainingData.get(indexOfEarliestValidationInstance));
        // Bug fix: Period.getHours() returns only the hours *field* of the period
        // (its days/weeks/months components are excluded), which badly underestimates
        // any validation window longer than a day. Hours.hoursBetween yields the total.
        int validationPeriodHours = Hours.hoursBetween(earliestValidationTime, latestDateTime).getHours();
        // Never return a zero-width slice (guards short windows / large rebuild counts,
        // and the division below when validationPeriodHours < rebuildsPerValidation).
        return Math.max(1, validationPeriodHours / rebuildsPerValidation);
    }

    /**
     * Default hyperparameter search space for the optimizer; keys are the builder
     * property names consumed by updateBuilderConfig.
     */
    private static Map<String, FieldValueRecommender> createConfig() {
        Map<String, FieldValueRecommender> config = Maps.newHashMap();
        config.put(MAX_DEPTH, new FixedOrderRecommender(4, 8, 16));//Integer.MAX_VALUE, 2, 3, 5, 6, 9));
        config.put(MIN_CAT_ATTR_OCC, new FixedOrderRecommender(7, 14));
        config.put(MIN_LEAF_INSTANCES, new FixedOrderRecommender(0, 15));
        config.put(DownsamplingClassifierBuilder.MINORITY_INSTANCE_PROPORTION, new FixedOrderRecommender(.1, .2));
        config.put(DEGREE_OF_GAIN_RATIO_PENALTY, new FixedOrderRecommender(1.0, 0.75));
        return config;
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import quickml.collections.MapUtils;
import quickml.data.ClassifierInstance;
import quickml.supervised.PredictiveModelBuilder;
import quickml.supervised.classifier.decisionTree.scorers.GiniImpurityScorer;
import quickml.supervised.classifier.decisionTree.scorers.MSEScorer;
import quickml.supervised.classifier.decisionTree.tree.*;

Expand Down Expand Up @@ -41,7 +42,7 @@ public final class TreeBuilder<T extends ClassifierInstance> implements Predicti
private Scorer scorer;
private int maxDepth = 5;
private double ignoreAttributeAtNodeProbability = 0.0;
private double minimumScore = 0.00000000000001;
private double minimumScore = 0.000001;
private int minCategoricalAttributeValueOccurances = 0;
private int minLeafInstances = 0;

Expand All @@ -58,7 +59,7 @@ public final class TreeBuilder<T extends ClassifierInstance> implements Predicti
private boolean binaryClassifications = true;

public TreeBuilder() {
// NOTE(review): this scrape merged both sides of a diff — the MSEScorer call is the
// removed default and the GiniImpurityScorer call the added one; two this(...)
// delegations in one constructor cannot compile. The intended default scorer after
// this commit is GiniImpurityScorer — confirm against the repository.
this(new MSEScorer(MSEScorer.CrossValidationCorrection.FALSE));
this(new GiniImpurityScorer());
}

public TreeBuilder(final Scorer scorer) {
Expand Down Expand Up @@ -658,8 +659,10 @@ private Pair<? extends Branch, Double> createNumericNode(Node parent, final Stri
ClassificationCounter outClassificationCounts = ClassificationCounter.countAll(outSet);

if (binaryClassifications) {
if (attributeValueOrIntervalOfValuesHasInsufficientStatistics(inClassificationCounts) ||
attributeValueOrIntervalOfValuesHasInsufficientStatistics(outClassificationCounts)) {
if (attributeValueOrIntervalOfValuesHasInsufficientStatistics(inClassificationCounts)
|| inClassificationCounts.getTotal() < minLeafInstances
|| attributeValueOrIntervalOfValuesHasInsufficientStatistics(outClassificationCounts)
|| outClassificationCounts.getTotal() < minLeafInstances) {
continue;
}
} else if (shouldWeIgnoreThisValue(inClassificationCounts) || shouldWeIgnoreThisValue(outClassificationCounts)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package quickml.supervised.classifier;

import junit.framework.TestCase;
import org.javatuples.Pair;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickml.data.ClassifierInstance;
import quickml.data.OnespotDateTimeExtractor;
import quickml.supervised.classifier.downsampling.DownsamplingClassifier;
import quickml.supervised.crossValidation.lossfunctions.WeightedAUCCrossValLossFunction;

import java.util.List;
import java.util.Map;

import static quickml.supervised.InstanceLoader.getAdvertisingInstances;


/**
 * Integration test for StaticBuilders: runs the full optimize-then-train pipeline
 * on a 3000-instance slice of the advertising data set. The expected loss range is
 * only logged for manual inspection, not asserted.
 */
public class StaticBuildersTest {
    private static final Logger logger = LoggerFactory.getLogger(StaticBuildersTest.class);

    @Test
    public void getOptimizedDownsampledRandomForestIntegrationTest() throws Exception {
        List<ClassifierInstance> trainingData = getAdvertisingInstances().subList(0, 3000);
        OnespotDateTimeExtractor dateTimeExtractor = new OnespotDateTimeExtractor();
        int rebuildsPerValidation = 1;
        double fractionOfDataForValidation = .2;
        Pair<Map<String, Object>, DownsamplingClassifier> optimizedForest =
                StaticBuilders.getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, new WeightedAUCCrossValLossFunction(1.0), dateTimeExtractor);
        logger.info("logged weighted auc loss should be between 0.25 and 0.28");
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import quickml.supervised.crossValidation.ClassifierLossChecker;
import quickml.supervised.crossValidation.data.OutOfTimeData;
import quickml.supervised.crossValidation.lossfunctions.ClassifierLogCVLossFunction;
import quickml.supervised.crossValidation.lossfunctions.WeightedAUCCrossValLossFunction;
import quickml.supervised.predictiveModelOptimizer.fieldValueRecommenders.FixedOrderRecommender;
import quickml.supervised.predictiveModelOptimizer.fieldValueRecommenders.MonotonicConvergenceRecommender;

Expand All @@ -34,10 +35,10 @@ public void setUp() throws Exception {
advertisingInstances = advertisingInstances.subList(0, 3000);
optimizer = new PredictiveModelOptimizerBuilder<Classifier, ClassifierInstance>()
.modelBuilder(new RandomForestBuilder<>())
.dataCycler(new OutOfTimeData<>(advertisingInstances, 0.5, 12, new OnespotDateTimeExtractor()))
.lossChecker(new ClassifierLossChecker<>(new ClassifierLogCVLossFunction(0.000001)))
.dataCycler(new OutOfTimeData<>(advertisingInstances, 0.2, 12, new OnespotDateTimeExtractor()))
.lossChecker(new ClassifierLossChecker<>(new WeightedAUCCrossValLossFunction(1.0)))
.valuesToTest(createConfig())
.iterations(1)
.iterations(3)
.build();
}

Expand All @@ -50,18 +51,15 @@ public void testOptimizer() throws Exception {

// Hyperparameter search space handed to the optimizer; keys are builder property names.
private Map<String, FieldValueRecommender> createConfig() {
Map<String, FieldValueRecommender> config = Maps.newHashMap();
// NOTE(review): this scrape shows both sides of a diff with the -/+ markers stripped:
// the first seven put() calls are the pre-change search space and the remaining six
// the post-change one. Where a key repeats (NUM_TREES, IGNORE_ATTR_PROB, MAX_DEPTH,
// MIN_CAT_ATTR_OCC, MIN_LEAF_INSTANCES) the later put() wins at runtime; only the
// second group should survive in the real file — confirm against the repository.
config.put(NUM_TREES, new MonotonicConvergenceRecommender(asList(5, 10, 20)));
config.put(IGNORE_ATTR_PROB, new FixedOrderRecommender(0.2, 0.4, 0.7));
config.put(MAX_DEPTH, new FixedOrderRecommender( 4, 8, 16));//Integer.MAX_VALUE, 2, 3, 5, 6, 9));
config.put(MIN_SCORE, new FixedOrderRecommender(0.00000000000001));//, Double.MIN_VALUE, 0.0, 0.000001, 0.0001, 0.001, 0.01, 0.1));
config.put(MIN_CAT_ATTR_OCC, new FixedOrderRecommender(2, 11, 16, 30 ));
config.put(MIN_LEAF_INSTANCES, new FixedOrderRecommender(0, 20, 40));
config.put(SCORER, new FixedOrderRecommender(new InformationGainScorer(), new GiniImpurityScorer()));
config.put(NUM_TREES, new FixedOrderRecommender(12, 24));
config.put(IGNORE_ATTR_PROB, new FixedOrderRecommender(0.7, 0.5));
config.put(MAX_DEPTH, new FixedOrderRecommender(4, 8, 16));//Integer.MAX_VALUE, 2, 3, 5, 6, 9));
config.put(MIN_CAT_ATTR_OCC, new FixedOrderRecommender(7, 10, 15));
config.put(MIN_LEAF_INSTANCES, new FixedOrderRecommender(0, 15, 30));
config.put(DEGREE_OF_GAIN_RATIO_PENALTY, new FixedOrderRecommender(1.0, 0.75, .5 ));
return config;
}




}

0 comments on commit da04812

Please sign in to comment.