Merge pull request #103 from sanity/treeBuildCompat
fixed static builders issue
athawk81 committed Mar 9, 2015
2 parents 5dbe2ab + 18fd6f1 commit 62c8780
Showing 4 changed files with 15 additions and 12 deletions.
pom.xml: 2 changes (1 addition & 1 deletion)
@@ -32,7 +32,7 @@
be accompanied by a bump in version number, regardless of how minor the change.
-->

-<version>0.7.1</version>
+<version>0.7.2</version>


<repositories>
src/main/java/quickml/supervised/Utils.java: 2 changes (1 addition & 1 deletion)
@@ -63,7 +63,7 @@ public static PredictionMapResults calcResultpredictionsWithoutAttrs(Classifier
return new PredictionMapResults(results);
}

-public static void sortTrainingInstancesByTime(List<ClassifierInstance> trainingData, final DateTimeExtractor<ClassifierInstance> dateTimeExtractor) {
+public static void sortTrainingInstancesByTime(List<? extends ClassifierInstance> trainingData, final DateTimeExtractor<ClassifierInstance> dateTimeExtractor) {
Collections.sort(trainingData, new Comparator<ClassifierInstance>() {
@Override
public int compare(ClassifierInstance o1, ClassifierInstance o2) {
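The widened parameter type matters for callers that hold a list typed to a subclass of ClassifierInstance: a List<SomeSubclass> is not a List<ClassifierInstance>, but it is a List<? extends ClassifierInstance>. A minimal sketch of the new call pattern; SessionInstance and the surrounding class are hypothetical illustrations rather than quickml types, and the quickml import paths for ClassifierInstance and DateTimeExtractor are omitted because they vary by version.

import java.util.List;

import quickml.supervised.Utils;

class SortBySubtypeExample {
    // SessionInstance stands in for any application-specific subclass of ClassifierInstance.
    static void sortSessions(List<SessionInstance> sessions,
                             DateTimeExtractor<ClassifierInstance> extractor) {
        // Rejected by the old List<ClassifierInstance> signature; accepted by
        // the new List<? extends ClassifierInstance> parameter.
        Utils.sortTrainingInstancesByTime(sessions, extractor);
    }
}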
src/main/java/quickml/supervised/classifier/StaticBuilders.java: 19 changes (11 additions & 8 deletions)
@@ -20,6 +20,7 @@
import quickml.supervised.classifier.randomForest.RandomForestBuilder;
import quickml.supervised.crossValidation.ClassifierLossChecker;
import quickml.supervised.crossValidation.data.OutOfTimeData;
+import quickml.supervised.crossValidation.data.TrainingDataCycler;
import quickml.supervised.crossValidation.lossfunctions.ClassifierLogCVLossFunction;
import quickml.supervised.crossValidation.lossfunctions.ClassifierLossFunction;
import quickml.supervised.crossValidation.lossfunctions.LossFunction;
@@ -47,7 +48,7 @@
public class StaticBuilders {
private static final Logger logger = LoggerFactory.getLogger(StaticBuilders.class);

-public static Pair<Map<String, Object>, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List<ClassifierInstance> trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor, Map<String, FieldValueRecommender> config) {
+public static <T extends ClassifierInstance> Pair<Map<String, Object>, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List<T> trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor, Map<String, FieldValueRecommender> config) {
/**
* @param rebuildsPerValidation is the number of times the model will be rebuilt with a new training set while estimating the loss of a model
* with a prarticular set of hyperparameters
@@ -57,27 +58,29 @@ public static Pair<Map<String, Object>, DownsamplingClassifier> getOptimizedDown

int timeSliceHours = getTimeSliceHours(trainingData, rebuildsPerValidation, dateTimeExtractor);
double crossValidationFraction = 0.2;
-PredictiveModelOptimizer optimizer= new PredictiveModelOptimizerBuilder<Classifier, ClassifierInstance>()
+TrainingDataCycler<T> outOfTimeData = new OutOfTimeData<T>(trainingData, crossValidationFraction, timeSliceHours, dateTimeExtractor);
+ClassifierLossChecker<T> classifierInstanceClassifierLossChecker = new ClassifierLossChecker<>(lossFunction);
+PredictiveModelOptimizer optimizer= new PredictiveModelOptimizerBuilder<Classifier, T>()
.modelBuilder(new RandomForestBuilder<>())
-.dataCycler(new OutOfTimeData<>(trainingData, crossValidationFraction, timeSliceHours,dateTimeExtractor))
-.lossChecker(new ClassifierLossChecker<>(lossFunction))
+.dataCycler(outOfTimeData)
+.lossChecker(classifierInstanceClassifierLossChecker)
.valuesToTest(config)
.iterations(3).build();
Map<String, Object> bestParams = optimizer.determineOptimalConfig();

-RandomForestBuilder<ClassifierInstance> randomForestBuilder = new RandomForestBuilder<>(new TreeBuilder<>().attributeIgnoringStrategy(new IgnoreAttributesWithConstantProbability(0.7))).numTrees(24);
-DownsamplingClassifierBuilder<ClassifierInstance> downsamplingClassifierBuilder = new DownsamplingClassifierBuilder<>(randomForestBuilder,0.1);
+RandomForestBuilder<T> randomForestBuilder = new RandomForestBuilder<T>(new TreeBuilder<T>().attributeIgnoringStrategy(new IgnoreAttributesWithConstantProbability(0.7))).numTrees(24);
+DownsamplingClassifierBuilder<T> downsamplingClassifierBuilder = new DownsamplingClassifierBuilder<>(randomForestBuilder,0.1);
downsamplingClassifierBuilder.updateBuilderConfig(bestParams);

DownsamplingClassifier downsamplingClassifier = downsamplingClassifierBuilder.buildPredictiveModel(trainingData);
return new Pair<Map<String, Object>, DownsamplingClassifier>(bestParams, downsamplingClassifier);
}
-public static Pair<Map<String, Object>, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List<ClassifierInstance> trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor) {
+public static Pair<Map<String, Object>, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List<? extends ClassifierInstance> trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor) {
Map<String, FieldValueRecommender> config = createConfig();
return getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, lossFunction, dateTimeExtractor, config);
}

-private static int getTimeSliceHours(List<ClassifierInstance> trainingData, int rebuildsPerValidation, DateTimeExtractor<ClassifierInstance> dateTimeExtractor) {
+private static int getTimeSliceHours(List<? extends ClassifierInstance> trainingData, int rebuildsPerValidation, DateTimeExtractor<ClassifierInstance> dateTimeExtractor) {

Utils.sortTrainingInstancesByTime(trainingData, dateTimeExtractor);
DateTime latestDateTime = dateTimeExtractor.extractDateTime(trainingData.get(trainingData.size()-1));
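After this change both public entry points accept subtype lists: the six-argument method is generic in T extends ClassifierInstance, and the five-argument convenience overload takes List<? extends ClassifierInstance> and delegates to it. A hedged caller-side sketch, relying only on the signatures shown in this diff; SessionInstance and the wrapper class are hypothetical, and the quickml import paths for Pair, DownsamplingClassifier, ClassifierInstance, DateTimeExtractor and WeightedAUCCrossValLossFunction are omitted because they vary by version.

import java.util.List;
import java.util.Map;

import quickml.supervised.classifier.StaticBuilders;

class OptimizedForestExample {
    // SessionInstance stands in for any application-specific subclass of ClassifierInstance.
    static Pair<Map<String, Object>, DownsamplingClassifier> buildForest(
            List<SessionInstance> trainingData,
            DateTimeExtractor<SessionInstance> extractor) {
        // Before this commit the call required a List<ClassifierInstance>;
        // the widened overload now accepts the subtype list directly.
        return StaticBuilders.getOptimizedDownsampledRandomForest(
                trainingData,
                1,                                        // rebuildsPerValidation
                0.2,                                      // fractionOfDataForValidation
                new WeightedAUCCrossValLossFunction(1.0), // same loss function as the test below
                extractor);
    }
}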
StaticBuildersTest.java
@@ -19,14 +19,14 @@
public class StaticBuildersTest {
private static final Logger logger = LoggerFactory.getLogger(StaticBuildersTest.class);


@Test
public void getOptimizedDownsampledRandomForestIntegrationTest() throws Exception {
double fractionOfDataForValidation = .2;
int rebuildsPerValidation = 1;
List<ClassifierInstance> trainingData = getAdvertisingInstances().subList(0, 3000);
OnespotDateTimeExtractor dateTimeExtractor = new OnespotDateTimeExtractor();
Pair<Map<String, Object>, DownsamplingClassifier> downsamplingClassifierPair =
-StaticBuilders.getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, new WeightedAUCCrossValLossFunction(1.0), dateTimeExtractor);
+StaticBuilders.<ClassifierInstance>getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, new WeightedAUCCrossValLossFunction(1.0), dateTimeExtractor);
logger.info("logged weighted auc loss should be between 0.25 and 0.28");
}
}
