diff --git a/pom.xml b/pom.xml
index 02b684e3..f65afb10 100644
--- a/pom.xml
+++ b/pom.xml
@@ -32,7 +32,7 @@
be accompanied by a bump in version number, regardless of how minor the change.
0.10.1 -->
- 0.10.7
+ 0.10.9
diff --git a/src/main/java/quickml/supervised/classifier/logisticRegression/LogisticRegressionBuilder.java b/src/main/java/quickml/supervised/classifier/logisticRegression/LogisticRegressionBuilder.java
index b5ec73b0..7c45634b 100644
--- a/src/main/java/quickml/supervised/classifier/logisticRegression/LogisticRegressionBuilder.java
+++ b/src/main/java/quickml/supervised/classifier/logisticRegression/LogisticRegressionBuilder.java
@@ -32,12 +32,11 @@ public class LogisticRegressionBuilder> imple
public StandardDataTransformer logisticRegressionDataTransformer;
private ProductFeatureAppender productFeatureAppender;
- private DataTransformer dataTransformer;
GradientDescent gradientDescent = new SparseSGD();
private int minWeightForPavBuckets =2;
public LogisticRegressionBuilder(StandardDataTransformer dataTransformer) {
- this.dataTransformer = dataTransformer;
+ this.logisticRegressionDataTransformer = dataTransformer;
}
public LogisticRegressionBuilder productFeatureAppender(ProductFeatureAppender productFeatureAppender) {
@@ -67,7 +66,7 @@ public LogisticRegressionBuilder poolAdjacentViolatorsMinWeight(int minWeight
@Override
public D transformData(List rawInstances){
- return dataTransformer.transformData(rawInstances);
+ return logisticRegressionDataTransformer.transformData(rawInstances);
}
diff --git a/src/main/java/quickml/supervised/classifier/logisticRegression/StandardDataTransformer.java b/src/main/java/quickml/supervised/classifier/logisticRegression/StandardDataTransformer.java
index 8bc45cac..06e0a3a3 100644
--- a/src/main/java/quickml/supervised/classifier/logisticRegression/StandardDataTransformer.java
+++ b/src/main/java/quickml/supervised/classifier/logisticRegression/StandardDataTransformer.java
@@ -18,12 +18,6 @@
* Created by alexanderhawk on 10/14/15.
*/
public abstract class StandardDataTransformer> implements DataTransformer {
- //to do: get label to digit Map and stick in DTO (and transform to logistic regression eventually)
- //make LogisticRegressionBuilder use this class and not be tightly coupled to mean normalization (e.g. allow log^2 values)
- //make cross validator take a datetransformer (specifically, the Logistic regression PMB, and then do the data normalization
- // and set the date time extractor)
-
-
/**
* class provides the method: transformInstances, to convert a set of classifier instances into instances that can be processed by
* the LogisticRegressionBuilder.
@@ -32,9 +26,7 @@ public abstract class StandardDataTransformer
* product feature appendation as well as common co-occurences should be hyper-params within logistic regression.
*
*/
- /*Options, wrap logistic regression? in a new logistic regression class that has a logistic reg transformer?
- * Or change sparse classifier instance as the the type of Logistic Regression? I almost prefer this. So now to use it...one just passes in a normal list of training instances
- */
+
protected ProductFeatureAppender productFeatureAppender;
diff --git a/src/test/java/quickml/supervised/classifier/logRegression/LogisticRegressionBuilderTest.java b/src/test/java/quickml/supervised/classifier/logRegression/LogisticRegressionBuilderTest.java
index 8502fdf1..40ebda9b 100644
--- a/src/test/java/quickml/supervised/classifier/logRegression/LogisticRegressionBuilderTest.java
+++ b/src/test/java/quickml/supervised/classifier/logRegression/LogisticRegressionBuilderTest.java
@@ -1,5 +1,7 @@
package quickml.supervised.classifier.logRegression;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Maps;
import org.junit.Ignore;
import org.junit.Test;
import org.slf4j.Logger;
@@ -21,9 +23,16 @@
import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.WeightedAUCCrossValLossFunction;
import quickml.supervised.dataProcessing.instanceTranformer.CommonCoocurrenceProductFeatureAppender;
import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForestBuilder;
+import quickml.supervised.predictiveModelOptimizer.FieldValueRecommender;
+import quickml.supervised.predictiveModelOptimizer.PredictiveModelOptimizer;
+import quickml.supervised.predictiveModelOptimizer.fieldValueRecommenders.FixedOrderRecommender;
import quickml.supervised.tree.decisionTree.DecisionTreeBuilder;
import java.util.List;
+import java.util.Map;
+
+import static quickml.supervised.classifier.logisticRegression.LogisticRegressionBuilder.MIN_OBSERVATIONS_OF_ATTRIBUTE;
+import static quickml.supervised.classifier.logisticRegression.SparseSGD.*;
/**
* Created by alexanderhawk on 10/13/15.
@@ -121,4 +130,56 @@ public void testDiabetesInstances() {
logger.info("RF out of time loss: {}", simpleCrossValidator.getLossForModel());
}
+ @Ignore
+ @Test
+ public void optimizerTest(){
+
+ List instances = InstanceLoader.getAdvertisingInstances().subList(0,1000);
+ CommonCoocurrenceProductFeatureAppender productFeatureAppender = new CommonCoocurrenceProductFeatureAppender<>()
+ .setMinObservationsOfRawAttribute(35)
+ .setAllowCategoricalProductFeatures(false)
+ .setAllowNumericProductFeatures(false)
+ .setApproximateOverlap(true)
+ .setMinOverlap(20)
+ .setIgnoreAttributesCommonToAllInsances(true);
+
+ DatedAndMeanNormalizedLogisticRegressionDataTransformer lrdt = new DatedAndMeanNormalizedLogisticRegressionDataTransformer()
+ .minObservationsOfAttribute(35)
+ .usingProductFeatures(false)
+ .productFeatureAppender(productFeatureAppender);
+
+ LogisticRegressionBuilder logisticRegressionBuilder = new LogisticRegressionBuilder(lrdt)
+ .calibrateWithPoolAdjacentViolators(false)
+ .gradientDescent(new SparseSGD()
+ .ridgeRegularizationConstant(0.1)
+ .learningRate(.0025)
+ .minibatchSize(1000)
+ .minEpochs(500)
+ .maxEpochs(500)
+ .minPredictedProbablity(1E-3)
+ .sparseParallelization(true)
+ );
+ double start = System.nanoTime();
+ EnhancedCrossValidator enhancedCrossValidator = new EnhancedCrossValidator<>(logisticRegressionBuilder,
+ new ClassifierLossChecker(new WeightedAUCCrossValLossFunction(1.0)),
+ new OutOfTimeDataFactory(0.25, 48), instances);
+
+
+
+
+
+ Map sgdParams = Maps.newHashMap();
+ sgdParams.put(RIDGE, new FixedOrderRecommender(.0001));//;, .001, .01, .1, 1));//MonotonicConvergenceRecommender(numTreesList, 0.01));
+ sgdParams.put(MIN_EPOCHS, new FixedOrderRecommender(8000));// 16000));
+ sgdParams.put(MAX_EPOCHS, new FixedOrderRecommender(16000));//, 3200));
+ sgdParams.put(LEARNING_RATE, new FixedOrderRecommender(.0025));//, .001, .005));//11, 14, 16 //Pbest 12
+ sgdParams.put(MIN_OBSERVATIONS_OF_ATTRIBUTE, new FixedOrderRecommender(20, 50));// 16000));
+ PredictiveModelOptimizer modelOptimizer = new PredictiveModelOptimizer(sgdParams, enhancedCrossValidator, 2);
+
+
+
+
+ logger.info("Optimal sgd parameters: {}", modelOptimizer.determineOptimalConfig());
+ }
+
}
\ No newline at end of file