diff --git a/pom.xml b/pom.xml index 02b684e3..f65afb10 100644 --- a/pom.xml +++ b/pom.xml @@ -32,7 +32,7 @@ be accompanied by a bump in version number, regardless of how minor the change. 0.10.1 --> - 0.10.7 + 0.10.9 diff --git a/src/main/java/quickml/supervised/classifier/logisticRegression/LogisticRegressionBuilder.java b/src/main/java/quickml/supervised/classifier/logisticRegression/LogisticRegressionBuilder.java index b5ec73b0..7c45634b 100644 --- a/src/main/java/quickml/supervised/classifier/logisticRegression/LogisticRegressionBuilder.java +++ b/src/main/java/quickml/supervised/classifier/logisticRegression/LogisticRegressionBuilder.java @@ -32,12 +32,11 @@ public class LogisticRegressionBuilder> imple public StandardDataTransformer logisticRegressionDataTransformer; private ProductFeatureAppender productFeatureAppender; - private DataTransformer dataTransformer; GradientDescent gradientDescent = new SparseSGD(); private int minWeightForPavBuckets =2; public LogisticRegressionBuilder(StandardDataTransformer dataTransformer) { - this.dataTransformer = dataTransformer; + this.logisticRegressionDataTransformer = dataTransformer; } public LogisticRegressionBuilder productFeatureAppender(ProductFeatureAppender productFeatureAppender) { @@ -67,7 +66,7 @@ public LogisticRegressionBuilder poolAdjacentViolatorsMinWeight(int minWeight @Override public D transformData(List rawInstances){ - return dataTransformer.transformData(rawInstances); + return logisticRegressionDataTransformer.transformData(rawInstances); } diff --git a/src/main/java/quickml/supervised/classifier/logisticRegression/StandardDataTransformer.java b/src/main/java/quickml/supervised/classifier/logisticRegression/StandardDataTransformer.java index 8bc45cac..06e0a3a3 100644 --- a/src/main/java/quickml/supervised/classifier/logisticRegression/StandardDataTransformer.java +++ b/src/main/java/quickml/supervised/classifier/logisticRegression/StandardDataTransformer.java @@ -18,12 +18,6 @@ * Created by alexanderhawk on 10/14/15. */ public abstract class StandardDataTransformer> implements DataTransformer { - //to do: get label to digit Map and stick in DTO (and transform to logistic regression eventually) - //make LogisticRegressionBuilder use this class and not be tightly coupled to mean normalization (e.g. allow log^2 values) - //make cross validator take a datetransformer (specifically, the Logistic regression PMB, and then do the data normalization - // and set the date time extractor) - - /** * class provides the method: transformInstances, to convert a set of classifier instances into instances that can be processed by * the LogisticRegressionBuilder. @@ -32,9 +26,7 @@ public abstract class StandardDataTransformer * product feature appendation as well as common co-occurences should be hyper-params within logistic regression. * */ - /*Options, wrap logistic regression? in a new logistic regression class that has a logistic reg transformer? - * Or change sparse classifier instance as the the type of Logistic Regression? I almost prefer this. So now to use it...one just passes in a normal list of training instances - */ + protected ProductFeatureAppender productFeatureAppender; diff --git a/src/test/java/quickml/supervised/classifier/logRegression/LogisticRegressionBuilderTest.java b/src/test/java/quickml/supervised/classifier/logRegression/LogisticRegressionBuilderTest.java index 8502fdf1..40ebda9b 100644 --- a/src/test/java/quickml/supervised/classifier/logRegression/LogisticRegressionBuilderTest.java +++ b/src/test/java/quickml/supervised/classifier/logRegression/LogisticRegressionBuilderTest.java @@ -1,5 +1,7 @@ package quickml.supervised.classifier.logRegression; +import com.google.common.collect.Iterables; +import com.google.common.collect.Maps; import org.junit.Ignore; import org.junit.Test; import org.slf4j.Logger; @@ -21,9 +23,16 @@ import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.WeightedAUCCrossValLossFunction; import quickml.supervised.dataProcessing.instanceTranformer.CommonCoocurrenceProductFeatureAppender; import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForestBuilder; +import quickml.supervised.predictiveModelOptimizer.FieldValueRecommender; +import quickml.supervised.predictiveModelOptimizer.PredictiveModelOptimizer; +import quickml.supervised.predictiveModelOptimizer.fieldValueRecommenders.FixedOrderRecommender; import quickml.supervised.tree.decisionTree.DecisionTreeBuilder; import java.util.List; +import java.util.Map; + +import static quickml.supervised.classifier.logisticRegression.LogisticRegressionBuilder.MIN_OBSERVATIONS_OF_ATTRIBUTE; +import static quickml.supervised.classifier.logisticRegression.SparseSGD.*; /** * Created by alexanderhawk on 10/13/15. @@ -121,4 +130,56 @@ public void testDiabetesInstances() { logger.info("RF out of time loss: {}", simpleCrossValidator.getLossForModel()); } + @Ignore + @Test + public void optimizerTest(){ + + List instances = InstanceLoader.getAdvertisingInstances().subList(0,1000); + CommonCoocurrenceProductFeatureAppender productFeatureAppender = new CommonCoocurrenceProductFeatureAppender<>() + .setMinObservationsOfRawAttribute(35) + .setAllowCategoricalProductFeatures(false) + .setAllowNumericProductFeatures(false) + .setApproximateOverlap(true) + .setMinOverlap(20) + .setIgnoreAttributesCommonToAllInsances(true); + + DatedAndMeanNormalizedLogisticRegressionDataTransformer lrdt = new DatedAndMeanNormalizedLogisticRegressionDataTransformer() + .minObservationsOfAttribute(35) + .usingProductFeatures(false) + .productFeatureAppender(productFeatureAppender); + + LogisticRegressionBuilder logisticRegressionBuilder = new LogisticRegressionBuilder(lrdt) + .calibrateWithPoolAdjacentViolators(false) + .gradientDescent(new SparseSGD() + .ridgeRegularizationConstant(0.1) + .learningRate(.0025) + .minibatchSize(1000) + .minEpochs(500) + .maxEpochs(500) + .minPredictedProbablity(1E-3) + .sparseParallelization(true) + ); + double start = System.nanoTime(); + EnhancedCrossValidator enhancedCrossValidator = new EnhancedCrossValidator<>(logisticRegressionBuilder, + new ClassifierLossChecker(new WeightedAUCCrossValLossFunction(1.0)), + new OutOfTimeDataFactory(0.25, 48), instances); + + + + + + Map sgdParams = Maps.newHashMap(); + sgdParams.put(RIDGE, new FixedOrderRecommender(.0001));//;, .001, .01, .1, 1));//MonotonicConvergenceRecommender(numTreesList, 0.01)); + sgdParams.put(MIN_EPOCHS, new FixedOrderRecommender(8000));// 16000)); + sgdParams.put(MAX_EPOCHS, new FixedOrderRecommender(16000));//, 3200)); + sgdParams.put(LEARNING_RATE, new FixedOrderRecommender(.0025));//, .001, .005));//11, 14, 16 //Pbest 12 + sgdParams.put(MIN_OBSERVATIONS_OF_ATTRIBUTE, new FixedOrderRecommender(20, 50));// 16000)); + PredictiveModelOptimizer modelOptimizer = new PredictiveModelOptimizer(sgdParams, enhancedCrossValidator, 2); + + + + + logger.info("Optimal sgd parameters: {}", modelOptimizer.determineOptimalConfig()); + } + } \ No newline at end of file