Skip to content

Commit

Permalink
Merge from theJenix/ABAGAIL:
Browse files Browse the repository at this point in the history
DataSetFilter implemented to split a dataset into a testing and training set
Added code to recombine separated binary labels into discrete labels (for the
test metrics)

Fixed bugs in metrics introduced by splitting labels into binary labels

Added a framework for defining experiment runners, which expose a standard set
of test metrics, and support writing to a CSV file

Added a CSV writer and writer framework, to write data to a file
  • Loading branch information
theJenix committed Mar 7, 2013
1 parent 309a8e5 commit 9a1184c
Show file tree
Hide file tree
Showing 12 changed files with 443 additions and 26 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ ABAGAIL.jar
*.jar
*.war
*.ear
manifest.mf
46 changes: 46 additions & 0 deletions src/shared/filt/TestTrainSplitFilter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package shared.filt;

import shared.DataSet;
import shared.Instance;

/**
 * A {@link DataSetFilter} that partitions a data set, in its existing order, into
 * a leading training set and a trailing testing set.
 *
 * The data set passed to {@link #filter(DataSet)} is only read; the two resulting
 * sets reference the same {@link Instance} objects and are retrieved afterwards
 * with {@link #getTrainingSet()} and {@link #getTestingSet()}.
 */
public class TestTrainSplitFilter implements DataSetFilter {

    /** Whole-number percentage (0 to 100) of the instances that go to the training set. */
    private final int pctTrain;
    private DataSet trainingSet;
    private DataSet testingSet;

    /**
     * Create a filter that sends the first {@code pctTrain} percent of the
     * instances to the training set and the remainder to the testing set.
     *
     * @param pctTrain A percentage from 0 to 100
     * @throws IllegalArgumentException if pctTrain is outside 0 to 100
     */
    public TestTrainSplitFilter(int pctTrain) {
        //reject bad percentages up front; they would otherwise only surface later
        // in filter as a negative array size
        if (pctTrain < 0 || pctTrain > 100) {
            throw new IllegalArgumentException(
                    "pctTrain must be between 0 and 100, got " + pctTrain);
        }
        this.pctTrain = pctTrain;
    }

    /**
     * Split the data set. The first floor(total * pctTrain / 100) instances
     * become the training set, and all remaining instances become the testing set.
     *
     * @param dataSet the data set to split
     */
    @Override
    public void filter(DataSet dataSet) {
        int totalInstances = dataSet.getInstances().length;
        //exact integer arithmetic: multiplying by a double fraction can land just
        // below a whole number (29% of 100 gives 28.999...) and truncate one short
        int trainInstances = (int) ((long) totalInstances * pctTrain / 100);
        int testInstances = totalInstances - trainInstances;
        Instance[] train = new Instance[trainInstances];
        Instance[] test = new Instance[testInstances];
        for (int ii = 0; ii < trainInstances; ii++) {
            train[ii] = dataSet.get(ii);
        }
        for (int ii = trainInstances; ii < totalInstances; ii++) {
            test[ii - trainInstances] = dataSet.get(ii);
        }

        this.trainingSet = new DataSet(train);
        this.testingSet = new DataSet(test);
    }

    /**
     * Get the training set produced by the last call to filter.
     *
     * @return the training set, or null if filter has not been called yet
     */
    public DataSet getTrainingSet() {
        return this.trainingSet;
    }

    /**
     * Get the testing set produced by the last call to filter.
     *
     * @return the testing set, or null if filter has not been called yet
     */
    public DataSet getTestingSet() {
        return this.testingSet;
    }
}
39 changes: 37 additions & 2 deletions src/shared/reader/DataSetLabelBinarySeperator.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@

import shared.DataSet;
import shared.Instance;
import shared.tester.Comparison;

/**
* Separates Labels into Binary representation for better use in Neural Networks
* Separates Discrete Labels into Binary representation for better use in Neural Networks
* @author Alex Linton <https://github.com/lexlinton>
* @date 2013-03-05
*/

public class DataSetLabelBinarySeperator {
public static void seperateLabels(DataSet set){

public static void seperateLabels(DataSet set) {
int numberOfLabels = 0;
ArrayList<Integer> labels = new ArrayList<Integer>();
//count up the number of distinct labels
Expand All @@ -35,4 +37,37 @@ public static void seperateLabels(DataSet set){
values[labelValue] = 0;
}
}

/**
 * Collapse a label that was spread over several values back into a
 * single-valued instance, based on the position holding the largest value.
 * This is meant to be a reversal of seperateLabels.
 *
 * The value of the returned instance is (largest value + its position - 1),
 * so a largest value of exactly 1 yields the position itself.
 *
 * NOTE: This assumes labels that were split using seperateLabels; the values
 * are presumably expected to lie between 0 and 1 (confirm against callers).
 *
 * NOTE(review): only values greater than 0 are considered. If no value is
 * greater than 0 the position stays at -1 and the result is -2.
 *
 * @param instance the label to collapse
 * @return the instance itself if its size is 1, otherwise a new single-valued instance
 */
public static Instance combineLabels(Instance instance) {
    final int width = instance.size();
    //a single-valued instance is taken to be collapsed already; pushing it
    // through the arithmetic below would distort its value
    if (width == 1) {
        return instance;
    }

    //scan for the largest value and remember where it was found
    int bestIndex = -1;
    double bestValue = 0;
    for (int pos = 0; pos < width; pos++) {
        double candidate = instance.getContinuous(pos);
        if (candidate > bestValue) {
            bestValue = candidate;
            bestIndex = pos;
        }
    }

    //fold the winning value and its position into one number
    return new Instance(bestValue + bestIndex - 1);
}
}
95 changes: 95 additions & 0 deletions src/shared/runner/MultiRunner.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package shared.runner;

import java.io.File;

import shared.tester.AccuracyTestMetric;
import shared.tester.ConfusionMatrixTestMetric;
import shared.writer.CSVWriter;
import shared.writer.Writer;
import util.TimeUtil;

/**
 * A runner for multiple tests/experiments. This class takes in a Runner, an array of iteration values to use,
 * and an array of test/train splits to use, and runs the Runner once for every combination of the two.
 *
 * Results are always printed to the console. If an output folder has been set, they are also
 * written to a CSV file in that folder, named after the runner.
 *
 * @author Jesse Rosalia <https://github.com/theJenix>
 * @date 2013-03-06
 */
public class MultiRunner {

    private final Runner runner;
    private final int[] iterArray;
    private final int[] pctTrainArray;
    private File outputFolder;

    /**
     * Create a multi runner.
     *
     * @param runner The runner to execute for each combination
     * @param iterArray The iteration counts to run with
     * @param pctTrainArray The training percentages to run with
     */
    public MultiRunner(Runner runner, int[] iterArray, int[] pctTrainArray) {
        this.runner = runner;
        this.iterArray = iterArray;
        this.pctTrainArray = pctTrainArray;
    }

    /**
     * Run all combinations of iterations and test/train splits, and output the results.
     * The results file, if one was opened, is closed even when a run or a write fails.
     *
     * @throws Exception
     */
    public void runAll() throws Exception {
        Writer writer = openWriter();
        try {
            for (int iterations : iterArray) {
                for (int pctTrain : pctTrainArray) {
                    runner.run(iterations, pctTrain);

                    AccuracyTestMetric am = runner.getAccuracyMetric();
                    printResults(iterations, pctTrain, am);

                    //write results to a file if available
                    if (writer != null) {
                        writer.write("" + iterations);
                        writer.write("" + pctTrain);
                        writer.write("" + am.getPctCorrect());
                        writer.nextRecord();
                    }
                }
            }
        } finally {
            //close in finally so a failing run does not leave the file open
            if (writer != null) {
                writer.close();
            }
        }
    }

    /**
     * Open the CSV results writer, if an output folder has been set.
     *
     * @return an opened writer, or null if no output folder is set
     * @throws Exception
     */
    private Writer openWriter() throws Exception {
        if (this.outputFolder == null) {
            return null;
        }
        String[] outputFields = {
            "iterations",
            "% train",
            "% correct",
        };
        File outputFile = new File(this.outputFolder, String.format("%s.csv", runner.getName()));
        Writer writer = new CSVWriter(outputFile.getAbsolutePath(), outputFields);
        writer.open();
        return writer;
    }

    /**
     * Print the results of the runner's last run to the console.
     *
     * @param iterations The number of iterations used for the run
     * @param pctTrain The training percentage used for the run
     * @param am The accuracy metric of the run
     */
    private void printResults(int iterations, int pctTrain, AccuracyTestMetric am) {
        ConfusionMatrixTestMetric cm = runner.getConfusionMatrix();
        String trainTime = TimeUtil.formatTime(runner.getTrainingTime());
        String testTime = TimeUtil.formatTime(runner.getTestTime());
        am.printResults();
        System.out.println("Training time: " + trainTime);
        System.out.println("Testing time: " + testTime);
        System.out.println("Number of iterations: " + iterations);
        if (pctTrain > 0 && pctTrain < 100) {
            System.out.println(String.format("%02d%% training / %02d%% testing", pctTrain, 100 - pctTrain));
        } else {
            System.out.println("Testing using training set.");
        }
        cm.printResults();
        System.out.println();
    }

    /**
     * Set the output folder for results files. If not set, this class
     * will not output any data to file.
     *
     * @param outputFolder
     */
    public void setOutputFolder(File outputFolder) {
        this.outputFolder = outputFolder;
    }
}
77 changes: 77 additions & 0 deletions src/shared/runner/Runner.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package shared.runner;
import shared.tester.AccuracyTestMetric;
import shared.tester.ConfusionMatrixTestMetric;
import shared.tester.RawOutputTestMetric;

/**
 * Contract for a single experiment or test. An implementation is responsible
 * for building the classifier, loading the data into a DataSet, creating
 * whatever objects are needed to train and test the classifier, and gathering
 * the outcome in a series of TestMetric objects.
 *
 * NOTE: the metric and timing getters describe the most recent call to run
 * and are replaced by each later call. They are only valid once run has been
 * called at least once.
 *
 * @author Jesse Rosalia <https://github.com/theJenix>
 * @date 2013-03-07
 */
public interface Runner {

    /**
     * Execute the experiment with the given number of training iterations
     * and the given train/test split.
     *
     * @param iterations The number of iterations used to train the classifier.
     * @param pctTrain The % of the data to place in the training set (the rest
     * goes to the testing set). Note that values of 0 or 100 will result in the
     * whole set being used for both training and testing.
     *
     * @throws Exception
     */
    void run(int iterations, int pctTrain) throws Exception;

    /**
     * Name of this runner, likely the name of the implementation combined
     * with the name of the data set.
     *
     * @return the runner name
     */
    String getName();

    /**
     * Accuracy metric of the last run, used for reporting % correct and
     * % incorrect.
     *
     * @return the accuracy metric
     */
    AccuracyTestMetric getAccuracyMetric();

    /**
     * Confusion matrix of the last run.
     *
     * @return the confusion matrix metric
     */
    ConfusionMatrixTestMetric getConfusionMatrix();

    /**
     * Raw output metric of the last run.
     *
     * @return the raw output metric
     */
    RawOutputTestMetric getRawOutput();

    /**
     * Training time of the last run.
     *
     * @return the training time
     */
    long getTrainingTime();

    /**
     * Testing time of the last run.
     *
     * @return the testing time
     */
    long getTestTime();
}
18 changes: 10 additions & 8 deletions src/shared/tester/AccuracyTestMetric.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package shared.tester;

import shared.Instance;
import shared.reader.DataSetLabelBinarySeperator;
import util.linalg.Vector;

/**
Expand All @@ -17,20 +18,21 @@ public class AccuracyTestMetric implements TestMetric {
@Override
public void addResult(Instance expected, Instance actual) {
Comparison c = new Comparison(expected, actual);
for (int ii = 0; ii < expected.size(); ii++) {
//count up one for each instance
count++;
if (c.isCorrect(ii)) {
//count up one for each correct instance
countCorrect++;
}

count++;
if (c.isAllCorrect()) {
countCorrect++;
}
}

/**
 * Report the share of recorded results that were counted as correct.
 *
 * @return countCorrect divided by count, as a fraction (not scaled to 100),
 * or 1 when count is not greater than 0
 */
public double getPctCorrect() {
    if (count > 0) {
        return ((double) countCorrect) / count;
    }
    //if count is 0, we consider it all correct
    return 1;
}

public void printResults() {
//only report results if there were any results to report.
if (count > 0) {
double pctCorrect = ((double)countCorrect)/count;
double pctCorrect = getPctCorrect();
double pctIncorrect = (1 - pctCorrect);
System.out.println(String.format("Correctly Classified Instances: %.02f%%", 100 * pctCorrect));
System.out.println(String.format("Incorrectly Classified Instances: %.02f%%", 100 * pctIncorrect));
Expand Down
1 change: 1 addition & 0 deletions src/shared/tester/Comparison.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public boolean isAllCorrect() {
}
return equals;
}

/**
* A generic comparison function. This should work for continuous, discrete and boolean
* output values, but will not make inferences for boolean or discrete values represented
Expand Down
Loading

0 comments on commit 9a1184c

Please sign in to comment.