diff --git a/pom.xml b/pom.xml index 029599c2..a14482ec 100644 --- a/pom.xml +++ b/pom.xml @@ -32,7 +32,7 @@ be accompanied by a bump in version number, regardless of how minor the change. --> - 0.7.1 + 0.7.2 diff --git a/src/main/java/quickml/supervised/Utils.java b/src/main/java/quickml/supervised/Utils.java index f0a6eb57..400bebdc 100644 --- a/src/main/java/quickml/supervised/Utils.java +++ b/src/main/java/quickml/supervised/Utils.java @@ -63,7 +63,7 @@ public static PredictionMapResults calcResultpredictionsWithoutAttrs(Classifier return new PredictionMapResults(results); } - public static void sortTrainingInstancesByTime(List trainingData, final DateTimeExtractor dateTimeExtractor) { + public static void sortTrainingInstancesByTime(List trainingData, final DateTimeExtractor dateTimeExtractor) { Collections.sort(trainingData, new Comparator() { @Override public int compare(ClassifierInstance o1, ClassifierInstance o2) { diff --git a/src/main/java/quickml/supervised/classifier/StaticBuilders.java b/src/main/java/quickml/supervised/classifier/StaticBuilders.java index 20406097..51789c96 100644 --- a/src/main/java/quickml/supervised/classifier/StaticBuilders.java +++ b/src/main/java/quickml/supervised/classifier/StaticBuilders.java @@ -20,6 +20,7 @@ import quickml.supervised.classifier.randomForest.RandomForestBuilder; import quickml.supervised.crossValidation.ClassifierLossChecker; import quickml.supervised.crossValidation.data.OutOfTimeData; +import quickml.supervised.crossValidation.data.TrainingDataCycler; import quickml.supervised.crossValidation.lossfunctions.ClassifierLogCVLossFunction; import quickml.supervised.crossValidation.lossfunctions.ClassifierLossFunction; import quickml.supervised.crossValidation.lossfunctions.LossFunction; @@ -47,7 +48,7 @@ public class StaticBuilders { private static final Logger logger = LoggerFactory.getLogger(StaticBuilders.class); - public static Pair, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor, Map config) { + public static Pair, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor, Map config) { /** * @param rebuildsPerValidation is the number of times the model will be rebuilt with a new training set while estimating the loss of a model * with a prarticular set of hyperparameters @@ -57,27 +58,29 @@ public static Pair, DownsamplingClassifier> getOptimizedDown int timeSliceHours = getTimeSliceHours(trainingData, rebuildsPerValidation, dateTimeExtractor); double crossValidationFraction = 0.2; - PredictiveModelOptimizer optimizer= new PredictiveModelOptimizerBuilder() + TrainingDataCycler outOfTimeData = new OutOfTimeData(trainingData, crossValidationFraction, timeSliceHours, dateTimeExtractor); + ClassifierLossChecker classifierInstanceClassifierLossChecker = new ClassifierLossChecker<>(lossFunction); + PredictiveModelOptimizer optimizer= new PredictiveModelOptimizerBuilder() .modelBuilder(new RandomForestBuilder<>()) - .dataCycler(new OutOfTimeData<>(trainingData, crossValidationFraction, timeSliceHours,dateTimeExtractor)) - .lossChecker(new ClassifierLossChecker<>(lossFunction)) + .dataCycler(outOfTimeData) + .lossChecker(classifierInstanceClassifierLossChecker) .valuesToTest(config) .iterations(3).build(); Map bestParams = optimizer.determineOptimalConfig(); - RandomForestBuilder randomForestBuilder = new RandomForestBuilder<>(new TreeBuilder<>().attributeIgnoringStrategy(new IgnoreAttributesWithConstantProbability(0.7))).numTrees(24); - DownsamplingClassifierBuilder downsamplingClassifierBuilder = new DownsamplingClassifierBuilder<>(randomForestBuilder,0.1); + RandomForestBuilder randomForestBuilder = new RandomForestBuilder(new TreeBuilder().attributeIgnoringStrategy(new IgnoreAttributesWithConstantProbability(0.7))).numTrees(24); + DownsamplingClassifierBuilder downsamplingClassifierBuilder = new DownsamplingClassifierBuilder<>(randomForestBuilder,0.1); downsamplingClassifierBuilder.updateBuilderConfig(bestParams); DownsamplingClassifier downsamplingClassifier = downsamplingClassifierBuilder.buildPredictiveModel(trainingData); return new Pair, DownsamplingClassifier>(bestParams, downsamplingClassifier); } - public static Pair, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor) { + public static Pair, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor) { Map config = createConfig(); return getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, lossFunction, dateTimeExtractor, config); } - private static int getTimeSliceHours(List trainingData, int rebuildsPerValidation, DateTimeExtractor dateTimeExtractor) { + private static int getTimeSliceHours(List trainingData, int rebuildsPerValidation, DateTimeExtractor dateTimeExtractor) { Utils.sortTrainingInstancesByTime(trainingData, dateTimeExtractor); DateTime latestDateTime = dateTimeExtractor.extractDateTime(trainingData.get(trainingData.size()-1)); diff --git a/src/test/java/quickml/supervised/classifier/StaticBuildersTest.java b/src/test/java/quickml/supervised/classifier/StaticBuildersTest.java index 52ef2c7d..4457f2ea 100644 --- a/src/test/java/quickml/supervised/classifier/StaticBuildersTest.java +++ b/src/test/java/quickml/supervised/classifier/StaticBuildersTest.java @@ -19,14 +19,14 @@ public class StaticBuildersTest { private static final Logger logger = LoggerFactory.getLogger(StaticBuildersTest.class); - +@Test public void getOptimizedDownsampledRandomForestIntegrationTest() throws Exception { double fractionOfDataForValidation = .2; int rebuildsPerValidation = 1; List trainingData = getAdvertisingInstances().subList(0, 3000); OnespotDateTimeExtractor dateTimeExtractor = new OnespotDateTimeExtractor(); Pair, DownsamplingClassifier> downsamplingClassifierPair = - StaticBuilders.getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, new WeightedAUCCrossValLossFunction(1.0), dateTimeExtractor); + StaticBuilders.getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, new WeightedAUCCrossValLossFunction(1.0), dateTimeExtractor); logger.info("logged weighted auc loss should be between 0.25 and 0.28"); } } \ No newline at end of file