Skip to content

Commit

Permalink
Merge pull request #151 from sanity/regressionTrees
Browse files Browse the repository at this point in the history
Regression trees
  • Loading branch information
athawk81 committed May 5, 2016
2 parents ae2feec + 7774d96 commit 5b706cc
Show file tree
Hide file tree
Showing 96 changed files with 5,608 additions and 114 deletions.
7 changes: 6 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
be accompanied by a bump in version number, regardless of how minor the change.
0.10.1 -->

<version> 0.10.9</version>
<version> 0.10.10</version>

<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
Expand All @@ -56,6 +56,11 @@
<artifactId>super-csv</artifactId>
<version>2.2.0</version>
</dependency>
<dependency>
<groupId>com.googlecode.efficient-java-matrix-library</groupId>
<artifactId>ejml</artifactId>
<version>0.23</version>
</dependency>
<dependency>
<groupId>net.sf.opencsv</groupId>
<artifactId>opencsv</artifactId>
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/quickml/data/AttributesMap.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ public class AttributesMap implements Map<String, Serializable>, Serializable {
private final Map<String, Serializable> map;

public AttributesMap(Map<String, Serializable> map) {
this.map = map;
this.map = new HashMap<String, Serializable>();
this.map.putAll(map);
}

public AttributesMap() {
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/quickml/data/OnespotDateTimeExtractor.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

import org.joda.time.DateTime;
import quickml.data.instances.ClassifierInstance;
import quickml.data.instances.InstanceWithAttributesMap;
import quickml.supervised.crossValidation.utils.DateTimeExtractor;


public class OnespotDateTimeExtractor<T extends ClassifierInstance> implements DateTimeExtractor<T> {
public class OnespotDateTimeExtractor<T extends InstanceWithAttributesMap> implements DateTimeExtractor<T> {

@Override
public DateTime extractDateTime(T instance) {
Expand Down
25 changes: 25 additions & 0 deletions src/main/java/quickml/data/instances/RegressionInstance.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package quickml.data.instances;

import quickml.data.AttributesMap;

import java.io.Serializable;

/**
* Created by alexanderhawk on 4/14/15.
*/
/**
 * An instance whose label is a {@code Double}, built on top of
 * {@link InstanceWithAttributesMap}.
 */
public class RegressionInstance extends InstanceWithAttributesMap<Double> {

    /**
     * Secondary target value; only the four-argument constructor assigns it, otherwise it keeps
     * the field default of 0.0.
     */
    public double alternativeTarget;

    /** Identifier of this instance; never assigned inside this class (field default is 0). */
    public long id;

    /**
     * Creates an instance with a weight of 1.0.
     *
     * @param attributes the attributes of the instance
     * @param label      the regression target
     */
    public RegressionInstance(AttributesMap attributes, Double label) {
        this(attributes, label, 1.0);
    }

    /**
     * Creates an instance with an explicit weight.
     *
     * @param attributes the attributes of the instance
     * @param label      the regression target
     * @param weight     the weight handed to the superclass
     */
    public RegressionInstance(AttributesMap attributes, Double label, double weight) {
        super(attributes, label, weight);
    }

    /**
     * Creates an instance with an explicit weight and an alternative target.
     *
     * @param attributes        the attributes of the instance
     * @param label             the regression target
     * @param weight            the weight handed to the superclass
     * @param alternativeTarget value stored in {@link #alternativeTarget}
     */
    public RegressionInstance(AttributesMap attributes, Double label, double weight, double alternativeTarget) {
        this(attributes, label, weight);
        this.alternativeTarget = alternativeTarget;
    }
}

84 changes: 84 additions & 0 deletions src/main/java/quickml/data/instances/SparseRegressionInstance.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package quickml.data.instances;

import org.javatuples.Pair;
import quickml.data.AttributesMap;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Map;

/**
* Created by alexanderhawk on 10/12/15.
*/
/**
 * A {@link RegressionInstance} that also stores its attributes as two parallel arrays: the index
 * of the weight each attribute corresponds to, and the attribute's numeric value. Slot 0 of both
 * arrays is reserved for a bias term (index 0, value 1.0).
 */
public class SparseRegressionInstance extends RegressionInstance {
    private int[] indicesOfCorrespondingWeights;
    private double[] values;

    /**
     * Creates an instance with the default weight of the two-argument superclass constructor.
     *
     * @param attributes          attributes of the instance; every value must be a {@code Double}
     * @param label               the regression target
     * @param nameToValueIndexMap maps each attribute name to the index of its weight
     * @throws RuntimeException     if an attribute value is not a {@code Double}
     * @throws NullPointerException if an attribute name has no entry in {@code nameToValueIndexMap}
     */
    public SparseRegressionInstance(AttributesMap attributes, Double label, Map<String, Integer> nameToValueIndexMap) {
        super(attributes, label);
        setIndicesAndValues(attributes, nameToValueIndexMap);
    }

    /**
     * Creates an instance with an explicit weight.
     *
     * @param attributes          attributes of the instance; every value must be a {@code Double}
     * @param label               the regression target
     * @param weight              the weight handed to the superclass
     * @param nameToValueIndexMap maps each attribute name to the index of its weight
     * @throws RuntimeException     if an attribute value is not a {@code Double}
     * @throws NullPointerException if an attribute name has no entry in {@code nameToValueIndexMap}
     */
    public SparseRegressionInstance(AttributesMap attributes, Double label, double weight, Map<String, Integer> nameToValueIndexMap) {
        super(attributes, label, weight);
        setIndicesAndValues(attributes, nameToValueIndexMap);
    }

    /** Fills the parallel index/value arrays: the bias term first, then one slot per attribute. */
    private void setIndicesAndValues(AttributesMap attributes, Map<String, Integer> nameToIndexMap) {
        int numSlots = attributes.size() + 1;
        indicesOfCorrespondingWeights = new int[numSlots];
        values = new double[numSlots];
        // add bias term
        indicesOfCorrespondingWeights[0] = 0;
        values[0] = 1.0;
        // add non bias terms
        int i = 1;
        for (Map.Entry<String, Serializable> entry : attributes.entrySet()) {
            if (!(entry.getValue() instanceof Double)) {
                throw new RuntimeException("wrong type of values in attributes");
            }
            indicesOfCorrespondingWeights[i] = lookUpIndex(nameToIndexMap, entry.getKey());
            values[i] = (Double) entry.getValue();
            i++;
        }
    }

    /**
     * Returns the index registered for {@code attributeName}.
     *
     * @throws NullPointerException if the name has no entry. This is the exception type that the
     *                              implicit unboxing of the missing value used to raise, now with
     *                              a message naming the attribute.
     */
    private static int lookUpIndex(Map<String, Integer> nameToIndexMap, String attributeName) {
        Integer index = nameToIndexMap.get(attributeName);
        if (index == null) {
            throw new NullPointerException("no index registered for attribute '" + attributeName + "'");
        }
        return index;
    }

    /**
     * Copies the attribute values of an instance into an array, placing each value at the
     * position {@code nameToIndexMap} assigns to its attribute name.
     *
     * @param regressionInstance the instance whose attributes are read
     * @param nameToIndexMap     maps each attribute name to its position in the returned array
     * @param useBias            if {@code true} the array has one extra slot and element 0 is set
     *                           to 1.0 before the attribute values are written
     * @return an array of length {@code attributes.size()} (plus one when {@code useBias})
     * @throws NullPointerException           if an attribute name has no entry in the map
     * @throws ClassCastException             if an attribute value is not a {@code Double}
     * @throws ArrayIndexOutOfBoundsException if a mapped index does not fit in the array
     */
    public static double[] getArrayOfValues(RegressionInstance regressionInstance, Map<String, Integer> nameToIndexMap, boolean useBias) {
        AttributesMap attributesMap = regressionInstance.getAttributes();
        int numAttributes = attributesMap.size();
        double[] valuesArray;

        if (useBias) {
            valuesArray = new double[numAttributes + 1];
            valuesArray[0] = 1.0;
        } else {
            valuesArray = new double[numAttributes];
        }
        // NOTE(review): the array is sized from this instance's attribute count but indexed with
        // the values of nameToIndexMap; this assumes every mapped index is below that length and,
        // when useBias is set, that no attribute maps to index 0 - confirm against callers.
        for (Map.Entry<String, Serializable> attributeEntry : attributesMap.entrySet()) {
            int attributeIndex = lookUpIndex(nameToIndexMap, attributeEntry.getKey());
            valuesArray[attributeIndex] = (Double) attributeEntry.getValue();
        }
        return valuesArray;
    }

    @Override
    public AttributesMap getAttributes() {
        return super.getAttributes();
    }

    /**
     * Returns the internal index and value arrays as a pair. The arrays are not copied, so
     * changes made by the caller are visible to this instance.
     */
    public Pair<int[], double[]> getSparseAttributes() {
        return new Pair<>(indicesOfCorrespondingWeights, values);
    }

    /**
     * Computes the sum over all stored slots of {@code omega[index] * value}.
     *
     * @param omega the weight vector, addressed by the stored weight indices
     * @return the dot product of this instance's sparse values with {@code omega}
     */
    public double dotProduct(double[] omega) {
        double result = 0;
        for (int i = 0; i < indicesOfCorrespondingWeights.length; i++) {
            int indexOfFeature = indicesOfCorrespondingWeights[i];
            double valueOfFeature = values[i];
            result += omega[indexOfFeature] * valueOfFeature;
        }
        return result;
    }
}
76 changes: 76 additions & 0 deletions src/main/java/quickml/experiments/GeoDistance.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package quickml.experiments;

/**
* Created by alexanderhawk on 4/9/16.
*
import java.util.*;
import java.lang.*;
import java.io.*;
*/
public class GeoDistance {

    /*::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::*/
    /*::                                                                         :*/
    /*::  This routine calculates the distance between two points (given the    :*/
    /*::  latitude/longitude of those points). It is being used to calculate    :*/
    /*::  the distance between two locations using GeoDataSource (TM) products  :*/
    /*::                                                                         :*/
    /*::  Definitions:                                                           :*/
    /*::    South latitudes are negative, east longitudes are positive          :*/
    /*::                                                                         :*/
    /*::  Passed to function:                                                    :*/
    /*::    lat1, lon1 = Latitude and Longitude of point 1 (in decimal degrees)  :*/
    /*::    lat2, lon2 = Latitude and Longitude of point 2 (in decimal degrees)  :*/
    /*::    unit = the unit you desire for results                               :*/
    /*::           where: 'M' is statute miles (default)                         :*/
    /*::                  'K' is kilometers                                      :*/
    /*::                  'N' is nautical miles                                  :*/
    /*::  Worldwide cities and other features databases with latitude longitude  :*/
    /*::  are available at http://www.geodatasource.com                          :*/
    /*::                                                                         :*/
    /*::  For enquiries, please contact [email protected]                        :*/
    /*::                                                                         :*/
    /*::  Official Web site: http://www.geodatasource.com                        :*/
    /*::                                                                         :*/
    /*::           GeoDataSource.com (C) All Rights Reserved 2015               :*/
    /*::                                                                         :*/
    /*::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::*/

    /** Prints the distance between two sample coordinates in each supported unit. */
    public static void main(String[] args) throws java.lang.Exception {
        System.out.println(distance(32.9697, -96.80322, 29.46786, -98.53506, "M") + " Miles\n");
        System.out.println(distance(32.9697, -96.80322, 29.46786, -98.53506, "K") + " Kilometers\n");
        System.out.println(distance(32.9697, -96.80322, 29.46786, -98.53506, "N") + " Nautical Miles\n");
    }

    /**
     * Calculates the great-circle distance between two points with the spherical law of cosines.
     *
     * @param lat1 latitude of point 1 in decimal degrees
     * @param lon1 longitude of point 1 in decimal degrees
     * @param lat2 latitude of point 2 in decimal degrees
     * @param lon2 longitude of point 2 in decimal degrees
     * @param unit {@code "K"} for kilometers, {@code "N"} for nautical miles; any other value
     *             (including {@code null}) yields statute miles
     * @return the distance in the requested unit; exactly 0 when both points are identical
     */
    public static double distance(double lat1, double lon1, double lat2, double lon2, String unit) {
        // Identical coordinates: answer directly instead of relying on acos() of a value that
        // rounding may have pushed away from exactly 1.
        if (lat1 == lat2 && lon1 == lon2) {
            return 0;
        }
        double theta = lon1 - lon2;
        double cosine = Math.sin(deg2rad(lat1)) * Math.sin(deg2rad(lat2))
                + Math.cos(deg2rad(lat1)) * Math.cos(deg2rad(lat2)) * Math.cos(deg2rad(theta));
        // Rounding can leave the cosine marginally outside [-1, 1], for which Math.acos returns
        // NaN; clamp it back into the valid domain (NaN inputs still propagate as NaN).
        cosine = Math.max(-1.0, Math.min(1.0, cosine));
        double dist = rad2deg(Math.acos(cosine));
        // Degrees of arc to statute miles, using the factors of the original routine.
        dist = dist * 60 * 1.1515;
        // Compare unit strings by content: == only matches when both sides are the same
        // String object, which fails for units that are read or built at runtime.
        if ("K".equals(unit)) {
            dist = dist * 1.609344;
        } else if ("N".equals(unit)) {
            dist = dist * 0.8684;
        }
        return dist;
    }

    /** Converts decimal degrees to radians. */
    private static double deg2rad(double deg) {
        return (deg * Math.PI / 180.0);
    }

    /** Converts radians to decimal degrees. */
    private static double rad2deg(double rad) {
        return (rad * 180 / Math.PI);
    }
}

52 changes: 50 additions & 2 deletions src/main/java/quickml/experiments/kin88nm.java
Original file line number Diff line number Diff line change
@@ -1,11 +1,26 @@
package quickml.experiments;

import org.javatuples.Pair;
import quickml.data.instances.ClassifierInstance;
import quickml.data.instances.RegressionInstance;
import quickml.supervised.crossValidation.RegressionLossChecker;
import quickml.supervised.crossValidation.SimpleCrossValidator;
import quickml.supervised.crossValidation.data.FoldedData;
import quickml.supervised.crossValidation.lossfunctions.regressionLossFunctions.RegressionRMSELossFunction;
import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForest;
import quickml.supervised.ensembles.randomForest.randomRegressionForest.RandomRegressionForest;
import quickml.supervised.ensembles.randomForest.randomRegressionForest.RandomRegressionForestBuilder;
import quickml.supervised.tree.attributeIgnoringStrategies.IgnoreAttributesWithConstantProbability;
import quickml.supervised.tree.regressionTree.OptimizedRegressionForests;
import quickml.supervised.tree.regressionTree.RegressionTree;
import quickml.supervised.tree.regressionTree.RegressionTreeBuilder;
import quickml.utlities.CSVToInstanceReader;
import quickml.utlities.CSVToInstanceReaderBuilder;
import quickml.utlities.selectors.NumericSelector;

import java.io.Serializable;
import java.util.List;
import java.util.Map;

/**
* Created by alexanderhawk on 9/16/15.
Expand All @@ -23,10 +38,43 @@ public boolean isNumeric(String columnName) {
public String cleanValue(String value) {
return value;
}
}).delimiter(',').collumnNameForLabel("y");
}).delimiter(',').collumnNameForLabel("x8").hasHeader(false);
CSVToInstanceReader csvToInstanceReader =csvToInstanceReaderBuilder.buildCsvReader();
try {
List<ClassifierInstance> allTrainingData = csvToInstanceReader.readCsv("");
List<RegressionInstance> allTrainingData = csvToInstanceReader.readRegressionInstancesFromCsv("uci-20070111-kin8nm.csv");
List<RegressionInstance> trainData = csvToInstanceReader.readRegressionInstancesFromCsv("/Users/alexanderhawk/msda-denoising/spearmint/data/kin8nm_train.csv");
List<RegressionInstance> valData = csvToInstanceReader.readRegressionInstancesFromCsv("/Users/alexanderhawk/msda-denoising/spearmint/data/kin8nm_test.csv");
RegressionTreeBuilder<RegressionInstance> regressionTreeBuilder
= new RegressionTreeBuilder<>()
.degreeOfGainRatioPenalty(1.0)
.attributeIgnoringStrategy(new IgnoreAttributesWithConstantProbability(0.5))
.maxDepth(18)
.minLeafInstances(2)
.minSplitFraction(0.1)
.numNumericBins(10)
.numSamplesPerNumericBin(20);
RandomRegressionForestBuilder<RegressionInstance> regressionForestBuilder = new RandomRegressionForestBuilder<>(regressionTreeBuilder).numTrees(400);
//RegressionTree regressionTree = regressionTreeBuilder.buildPredictiveModel(trainData);


RandomRegressionForest randomRegressionForest = regressionForestBuilder.buildPredictiveModel(trainData);
//Pair<Map<String, Serializable>, RandomRegressionForest> randomForestPair = OptimizedRegressionForests.<RegressionInstance>getOptimizedRandomForest(trainData);
//RandomRegressionForest randomRegressionForest = randomForestPair.getValue1();


double loss =0;
for (RegressionInstance instance: valData) {
loss+=(instance.getLabel() - randomRegressionForest.predict(instance.getAttributes()))
*(instance.getLabel() - randomRegressionForest.predict(instance.getAttributes()));
}
loss=Math.sqrt(loss/valData.size());
System.out.println("loss " + loss);

SimpleCrossValidator simpleCrossValidator = new SimpleCrossValidator(regressionTreeBuilder,
new RegressionLossChecker(new RegressionRMSELossFunction()), new FoldedData(allTrainingData, 8, 8));
// double loss=simpleCrossValidator.getLossForModel();

System.out.println("here");
} catch (Exception e) {
e.printStackTrace();
throw new RuntimeException(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import org.javatuples.Pair;
import quickml.collections.ValueSummingMap;
import quickml.data.instances.ClassifierInstance;
import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter;

import java.io.Serializable;
import java.util.*;
Expand All @@ -25,7 +26,7 @@ public OldClassificationCounter(OldClassificationCounter classificationCounter)
this.counts.putAll(classificationCounter.counts);
}

public OldClassificationCounter(quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter classificationCounter) {
public OldClassificationCounter(ClassificationCounter classificationCounter) {
this.counts.putAll(classificationCounter.getCounts());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ public boolean ignoreAttribute(String attribute, OldBranch parent) {
return false;
}

/**
 * Returns the {@code ignoreAttributeProbability} value held by this object.
 */
public double getIgnoreAttributeProbability() {
    return this.ignoreAttributeProbability;
}

@Override
public String toString(){
return "ignoreAttributeProbability = " + ignoreAttributeProbability;
Expand Down
28 changes: 28 additions & 0 deletions src/main/java/quickml/supervised/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import quickml.data.instances.Instance;
import quickml.data.instances.InstanceWithAttributesMap;
import quickml.data.PredictionMap;
import quickml.data.instances.RegressionInstance;
import quickml.supervised.classifier.Classifier;
import quickml.supervised.classifier.logisticRegression.SparseClassifierInstance;
import quickml.supervised.crossValidation.PredictionMapResult;
Expand All @@ -22,8 +23,11 @@
import quickml.supervised.tree.nodes.Node;
import quickml.supervised.tree.summaryStatistics.ValueCounter;

import java.io.BufferedWriter;
import java.io.IOException;
import java.io.Serializable;
import java.math.BigDecimal;
import java.nio.DoubleBuffer;
import java.util.*;

/**
Expand Down Expand Up @@ -69,6 +73,21 @@ public static List<LabelPredictionWeight<Double, Double>> getRegLabelsPrediction
return results;
}

/**
 * Predicts every instance of the validation set, collects (label, prediction, weight) triples
 * and writes one {@code id,label,prediction} line per instance to the supplied writer.
 *
 * <p>Each instance is cast to {@link RegressionInstance} to read its {@code id}, so any other
 * instance type fails with a {@code ClassCastException}. An {@code IOException} while writing
 * is only printed; the remaining instances are still processed.
 *
 * @param predictiveModel the model used to predict each instance
 * @param validationSet   the instances to predict
 * @param bw              the writer that receives the per-instance lines
 * @return one {@code LabelPredictionWeight} per instance, in iteration order
 */
public static List<LabelPredictionWeight<Double, Double>> getRegLabelsPredictionsWeights(PredictiveModel<AttributesMap, Double> predictiveModel, List<? extends Instance<AttributesMap, Double>> validationSet, BufferedWriter bw) {
    List<LabelPredictionWeight<Double, Double>> labelsAndPredictions = new ArrayList<>();
    for (Instance<AttributesMap, Double> validationInstance : validationSet) {
        Double predicted = predictiveModel.predict(validationInstance.getAttributes());
        long instanceId = ((RegressionInstance) validationInstance).id;
        labelsAndPredictions.add(new LabelPredictionWeight<Double, Double>(
                validationInstance.getLabel(), predicted, validationInstance.getWeight()));
        try {
            String csvLine = "" + instanceId + "," + validationInstance.getLabel() + "," + predicted + "\n";
            bw.write(csvLine);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
    return labelsAndPredictions;
}

public static PredictionMapResults calcResultPredictions(Classifier predictiveModel, List<? extends InstanceWithAttributesMap<?>> validationSet) {
ArrayList<PredictionMapResult> results = new ArrayList<>();
for (InstanceWithAttributesMap<?> instance : validationSet) {
Expand All @@ -86,6 +105,15 @@ public static PredictionMapResults calcResultpredictionsWithoutAttrs(Classifier
return new PredictionMapResults(results);
}

/**
 * Predicts every instance of the validation set while ignoring the given attributes, and
 * collects one (label, prediction, weight) triple per instance.
 *
 * @param predictiveModel    the model used to predict each instance
 * @param validationSet      the instances to predict
 * @param attributesToIgnore attribute names passed to {@code predictWithoutAttributes}
 * @return one {@code LabelPredictionWeight} per instance, in iteration order
 */
public static List<LabelPredictionWeight<Double, Double>> calcLabelPredictionsWeightsWithoutAttrs(PredictiveModel<AttributesMap, Double> predictiveModel, List<? extends RegressionInstance> validationSet, Set<String> attributesToIgnore) {
    ArrayList<LabelPredictionWeight<Double, Double>> results = new ArrayList<>();
    for (RegressionInstance instance : validationSet) {
        Double prediction = predictiveModel.predictWithoutAttributes(instance.getAttributes(), attributesToIgnore);
        // Label first, prediction second - the same argument order the sibling
        // getRegLabelsPredictionsWeights uses; the two were previously swapped here.
        results.add(new LabelPredictionWeight<Double, Double>(instance.getLabel(), prediction, instance.getWeight()));
    }
    return results;
}

public static <T extends InstanceWithAttributesMap<?>> void sortTrainingInstancesByTime(List<T> trainingData, final DateTimeExtractor<T> dateTimeExtractor) {
Collections.sort(trainingData, new Comparator<T>() {
@Override
Expand Down
Loading

0 comments on commit 5b706cc

Please sign in to comment.