From 9a1184c917204e4a533fabea1de6e72f68a60ef4 Mon Sep 17 00:00:00 2001 From: Jesse Rosalia Date: Thu, 7 Mar 2013 14:08:15 -0500 Subject: [PATCH] Merge from theJenix/ABAGAIL: DataSetFilter implemented to split a dataset into a testing and training set Added code to recombine separated binary labels into discrete labels (for the test metrics) Fixed bugs in metrics introduced by splitting labels into binary labels Added a framework for defining experiment runners, which expose a standard set of test metrics, and support writing to a CSV file Added a CSV writer and writer framework, to write data to a file --- .gitignore | 1 + src/shared/filt/TestTrainSplitFilter.java | 46 +++++++++ .../reader/DataSetLabelBinarySeperator.java | 39 +++++++- src/shared/runner/MultiRunner.java | 95 +++++++++++++++++++ src/shared/runner/Runner.java | 77 +++++++++++++++ src/shared/tester/AccuracyTestMetric.java | 18 ++-- src/shared/tester/Comparison.java | 1 + .../tester/ConfusionMatrixTestMetric.java | 50 +++++++--- src/shared/tester/NeuralNetworkTester.java | 10 +- src/shared/writer/CSVWriter.java | 68 +++++++++++++ src/shared/writer/Writer.java | 47 +++++++++ src/util/TimeUtil.java | 17 ++++ 12 files changed, 443 insertions(+), 26 deletions(-) create mode 100644 src/shared/filt/TestTrainSplitFilter.java create mode 100644 src/shared/runner/MultiRunner.java create mode 100644 src/shared/runner/Runner.java create mode 100644 src/shared/writer/CSVWriter.java create mode 100644 src/shared/writer/Writer.java create mode 100644 src/util/TimeUtil.java diff --git a/.gitignore b/.gitignore index 0779e5ff..b95ddf71 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ ABAGAIL.jar *.jar *.war *.ear +manifest.mf diff --git a/src/shared/filt/TestTrainSplitFilter.java b/src/shared/filt/TestTrainSplitFilter.java new file mode 100644 index 00000000..de7d5700 --- /dev/null +++ b/src/shared/filt/TestTrainSplitFilter.java @@ -0,0 +1,46 @@ +package shared.filt; + +import shared.DataSet; +import 
shared.Instance; + +public class TestTrainSplitFilter implements DataSetFilter { + + private double pctTrain; + private DataSet trainingSet; + private DataSet testingSet; + + /** + * + * + * @param pctTrain A percentage from 0 to 100 + */ + public TestTrainSplitFilter(int pctTrain) { + this.pctTrain = 1.0 * pctTrain / 100; // + } + + @Override + public void filter(DataSet dataSet) { + int totalInstances = dataSet.getInstances().length; + int trainInstances = (int) (totalInstances * pctTrain); + int testInstances = totalInstances - trainInstances; + Instance[] train = new Instance[trainInstances]; + Instance[] test = new Instance[testInstances]; + for (int ii = 0; ii < trainInstances; ii++) { + train[ii] = dataSet.get(ii); + } + for (int ii = trainInstances; ii < totalInstances; ii++) { + test[ii - trainInstances] = dataSet.get(ii); + } + + this.trainingSet = new DataSet(train); + this.testingSet = new DataSet(test); + } + + public DataSet getTrainingSet() { + return this.trainingSet; + } + + public DataSet getTestingSet() { + return this.testingSet; + } +} diff --git a/src/shared/reader/DataSetLabelBinarySeperator.java b/src/shared/reader/DataSetLabelBinarySeperator.java index bc68436a..52839ad3 100644 --- a/src/shared/reader/DataSetLabelBinarySeperator.java +++ b/src/shared/reader/DataSetLabelBinarySeperator.java @@ -5,15 +5,17 @@ import shared.DataSet; import shared.Instance; +import shared.tester.Comparison; /** - * Separates Labels into Binary representation for better use in Neural Networks + * Separates Discrete Labels into Binary representation for better use in Neural Networks * @author Alex Linton * @date 2013-03-05 */ public class DataSetLabelBinarySeperator { - public static void seperateLabels(DataSet set){ + + public static void seperateLabels(DataSet set) { int numberOfLabels = 0; ArrayList labels = new ArrayList(); //count up the number of distinct labels @@ -35,4 +37,37 @@ public static void seperateLabels(DataSet set){ values[labelValue] = 0; } } + 
+ /** + * Combine separated labels into a single valued instance representing + * the output label, based on the max value found in the instance. + * + * NOTE: This assumes labels that were split using seperateLabels, and + * a function that maps output values to the range 0 to 1. + * + * @param instance + * @return + */ + public static Instance combineLabels(Instance instance) { + //if it's already a size 1 instance, we assume it's already collapsed...otherwise + // the code below will adversely affect the value + if (instance.size() == 1) { + return instance; + } + + //we have values to collapse into a discrete measurement based + // on the instance datapoint with the biggest value. This is meant + // to be a reversal of seperateLabels + int maxInx = -1; + double max = 0; + for (int ii = 0; ii < instance.size(); ii++) { + double inst = instance.getContinuous(ii); + if (inst > max) { + maxInx = ii; + max = inst; + } + } + //max will be the max value (between 0 and 1), and maxInx will be the index of that value + return new Instance(max + maxInx - 1); + } } diff --git a/src/shared/runner/MultiRunner.java b/src/shared/runner/MultiRunner.java new file mode 100644 index 00000000..625f1eba --- /dev/null +++ b/src/shared/runner/MultiRunner.java @@ -0,0 +1,95 @@ +package shared.runner; + +import java.io.File; + +import shared.tester.AccuracyTestMetric; +import shared.tester.ConfusionMatrixTestMetric; +import shared.writer.CSVWriter; +import shared.writer.Writer; +import util.TimeUtil; + +/** + * A runner for multiple tests/experiments. 
This class takes in a Runner, an array of iteration values to use, + * and an array of test/train splits to use + * + * @author Jesse Rosalia + * @date 2013-03-06 + */ +public class MultiRunner { + + private Runner runner; + private int[] iterArray; + private int[] pctTrainArray; + private Writer writer; + private File outputFolder; + + public MultiRunner(Runner runner, int[] iterArray, int[] pctTrainArray) { + this.runner = runner; + this.iterArray = iterArray; + this.pctTrainArray = pctTrainArray; + } + + /** + * Run all combinations of iterations and test/train splits, and output the results. + * + * @throws Exception + */ + public void runAll() throws Exception { + String[] outputFields = { + "iterations", + "% train", + "% correct", +// "% incorrect", + }; + Writer writer = null; + if (this.outputFolder != null) { + File outputFile = new File(this.outputFolder, String.format("%s.csv", runner.getName())); + writer = new CSVWriter(outputFile.getAbsolutePath(), outputFields); + writer.open(); + } + for (int iterations : iterArray) { + for (int pctTrain : pctTrainArray) { + runner.run(iterations, pctTrain); + + AccuracyTestMetric am = runner.getAccuracyMetric(); + ConfusionMatrixTestMetric cm = runner.getConfusionMatrix(); + + //print results to the console + String trainTime = TimeUtil.formatTime(runner.getTrainingTime()); + String testTime = TimeUtil.formatTime(runner.getTestTime()); + am.printResults(); + System.out.println("Training time: " + trainTime); + System.out.println("Testing time: " + testTime); + System.out.println("Number of iterations: " + iterations); + if (pctTrain > 0 && pctTrain < 100) { + System.out.println(String.format("%02d%% training / %02d%% testing", pctTrain, 100 - pctTrain)); + } else { + System.out.println("Testing using training set."); + } + cm.printResults(); + System.out.println(); + + //write results to a file if available + if (writer != null) { + writer.write("" + iterations); + writer.write("" + pctTrain); + writer.write("" + 
am.getPctCorrect()); + writer.nextRecord(); + } + } + } + if (writer != null) { + writer.close(); + } + } + + /** + * Set the output folder for results files. If not set, this class + * will not output any data to file. + * + * @param outputFolder + */ + public void setOutputFolder(File outputFolder) { + this.outputFolder = outputFolder; + } +} diff --git a/src/shared/runner/Runner.java b/src/shared/runner/Runner.java new file mode 100644 index 00000000..2ea67c58 --- /dev/null +++ b/src/shared/runner/Runner.java @@ -0,0 +1,77 @@ +package shared.runner; +import shared.tester.AccuracyTestMetric; +import shared.tester.ConfusionMatrixTestMetric; +import shared.tester.RawOutputTestMetric; + +/** + * A runner for a given experiment or test. The runner will be responsible + * for constructing a classifier, loading the data into a DataSet, constructing + * the necessary objects to train and test the classifier, and collecting the + * results in a series of TestMetric objects. + * + * NOTE: most of the metrics in this API refer to the last call to run, and are + * replaced after each subsequent call. You must call run at least once for them + * to be valid. + * + * @author Jesse Rosalia + * @date 2013-03-07 + */ +public interface Runner { + + /** + * Get the accuracy test metric for the last run, for reporting % correct + * and % incorrect. + * + * @return + */ + public AccuracyTestMetric getAccuracyMetric(); + + /** + * Get the confusion matrix for the last run. + * + * @return + */ + public ConfusionMatrixTestMetric getConfusionMatrix(); + + /** + * Get a name for this runner. This is likely the name of the implementation + * combined with the name of the data set. + * + * @return + */ + public String getName(); + + /** + * Get the raw output metric for the last run. + * + * @return + */ + public RawOutputTestMetric getRawOutput(); + + /** + * Get the training time for the last run. 
+ * + * @return + */ + public long getTrainingTime(); + + /** + * Get the testing time for the last run. + * + * @return + */ + public long getTestTime(); + + /** + * Run the runner with the specified number of training iterations and + * specified train/test split. + * + * @param iterations The number of iterations used to train the classifier. + * @param pctTrain The % of the data to use in a training set (the remainder is + * in the testing set). Note that values of 0 or 100 will result in the whole + * set being used for training and testing. + * + * @throws Exception + */ + public void run(int iterations, int pctTrain) throws Exception; +} \ No newline at end of file diff --git a/src/shared/tester/AccuracyTestMetric.java b/src/shared/tester/AccuracyTestMetric.java index 0ee9c906..badb5830 100644 --- a/src/shared/tester/AccuracyTestMetric.java +++ b/src/shared/tester/AccuracyTestMetric.java @@ -1,6 +1,7 @@ package shared.tester; import shared.Instance; +import shared.reader.DataSetLabelBinarySeperator; import util.linalg.Vector; /** @@ -17,20 +18,21 @@ public class AccuracyTestMetric implements TestMetric { @Override public void addResult(Instance expected, Instance actual) { Comparison c = new Comparison(expected, actual); - for (int ii = 0; ii < expected.size(); ii++) { - //count up one for each instance - count++; - if (c.isCorrect(ii)) { - //count up one for each correct instance - countCorrect++; - } + + count++; + if (c.isAllCorrect()) { + countCorrect++; } } + public double getPctCorrect() { + return count > 0 ? ((double)countCorrect)/count : 1; //if count is 0, we consider it all correct + } + public void printResults() { //only report results if there were any results to report. 
if (count > 0) { - double pctCorrect = ((double)countCorrect)/count; + double pctCorrect = getPctCorrect(); double pctIncorrect = (1 - pctCorrect); System.out.println(String.format("Correctly Classified Instances: %.02f%%", 100 * pctCorrect)); System.out.println(String.format("Incorrectly Classified Instances: %.02f%%", 100 * pctIncorrect)); diff --git a/src/shared/tester/Comparison.java b/src/shared/tester/Comparison.java index 837d73e1..a56cbe84 100644 --- a/src/shared/tester/Comparison.java +++ b/src/shared/tester/Comparison.java @@ -41,6 +41,7 @@ public boolean isAllCorrect() { } return equals; } + /** * A generic comparison function. This should work for continuous, discrete and boolean * output values, but will not make inferences for boolean or discrete values represented diff --git a/src/shared/tester/ConfusionMatrixTestMetric.java b/src/shared/tester/ConfusionMatrixTestMetric.java index 4369e78b..29cd83fa 100644 --- a/src/shared/tester/ConfusionMatrixTestMetric.java +++ b/src/shared/tester/ConfusionMatrixTestMetric.java @@ -3,7 +3,10 @@ import java.util.HashMap; import java.util.Map; +import shared.AttributeType; +import shared.DataSetDescription; import shared.Instance; +import shared.reader.DataSetLabelBinarySeperator; /** * A test metric to generate a confusion matrix. This metric expects the true labels @@ -111,22 +114,37 @@ public ConfusionMatrixTestMetric(boolean[] labels) { } } - @Override - public void addResult(Instance expected, Instance actual) { - Comparison c = new Comparison(expected, actual); - for (int ii = 0; ii < expected.size(); ii++) { - - //find the actual value in the list of classes - //...this makes sure we work with homogeneous label values, so our - // matrix is readable. 
- Instance found = findLabel(this.labels, actual); - MatrixEntry e = new MatrixEntry(expected, found); - if (matrix.containsKey(e)) { - matrix.put(e, matrix.get(e) + 1); - } else { - matrix.put(e, 1); + /** + * Construct the test metric with discrete values, contained in the label desc. + * + * @param labelDesc + */ + public ConfusionMatrixTestMetric(DataSetDescription labelDesc) { + for (AttributeType type : labelDesc.getAttributeTypes()) { + if (type == AttributeType.CONTINUOUS) { + throw new IllegalStateException("This metric only works with discrete or binary labels"); } - + } + int range = labelDesc.getDiscreteRange(); + this.labels = new Instance[range]; + this.labelStrs = new String [range]; + for (int i = 0; i < labelDesc.getDiscreteRange(); i++) { + this.labels [i] = new Instance(i); + this.labelStrs[i] = Integer.toString(i); + } + } + + @Override + public void addResult(Instance expected, Instance actual) { + //find the actual value in the list of classes + //...this makes sure we work with homogeneous label values, so our + // matrix is readable. 
+ Instance found = findLabel(this.labels, actual); + MatrixEntry e = new MatrixEntry(expected, found); + if (matrix.containsKey(e)) { + matrix.put(e, matrix.get(e) + 1); + } else { + matrix.put(e, 1); } } @@ -183,5 +201,7 @@ public void printResults() { System.out.print("\t"); System.out.print(val); } + + System.out.println(); } } diff --git a/src/shared/tester/NeuralNetworkTester.java b/src/shared/tester/NeuralNetworkTester.java index c2044de2..6c696ace 100644 --- a/src/shared/tester/NeuralNetworkTester.java +++ b/src/shared/tester/NeuralNetworkTester.java @@ -1,6 +1,7 @@ package shared.tester; import shared.Instance; +import shared.reader.DataSetLabelBinarySeperator; import func.nn.NeuralNetwork; /** @@ -31,9 +32,16 @@ public void test(Instance[] instances) { Instance expected = instances[i].getLabel(); Instance actual = new Instance(network.getOutputValues()); + //collapse the values, for statistics reporting + //NOTE: assumes discrete labels, with n output nodes for n + // potential labels, and an activation function that outputs + // values between 0 and 1. + Instance expectedOne = DataSetLabelBinarySeperator.combineLabels(expected); + Instance actualOne = DataSetLabelBinarySeperator.combineLabels(actual); + //run this result past all of the available test metrics for (TestMetric metric : metrics) { - metric.addResult(expected, actual); + metric.addResult(expectedOne, actualOne); } } } diff --git a/src/shared/writer/CSVWriter.java b/src/shared/writer/CSVWriter.java new file mode 100644 index 00000000..f4e429f9 --- /dev/null +++ b/src/shared/writer/CSVWriter.java @@ -0,0 +1,68 @@ +package shared.writer; + +import java.io.FileWriter; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Write arbitrary data to a CSV file. This is used to write results out, + * to be consumed by another program (GNUPlot, etc). 
+ * + * @author Jesse Rosalia + * @date 2013-03-07 + * + */ +public class CSVWriter implements Writer { + + private String fileName; + private List fields; + private List buffer; + private FileWriter fileWriter; + + public CSVWriter(String fileName, String[] fields) { + this.fileName = fileName; + this.fields = Arrays.asList(fields); + this.buffer = new ArrayList(); + } + + @Override + public void close() throws IOException { + this.fileWriter.close(); + } + + @Override + public void open() throws IOException { + this.fileWriter = new FileWriter(fileName); + writeRow(this.fields); + } + + /** + * @param toWrite + * @throws IOException + */ + private void writeRow(List toWrite) throws IOException { + boolean addComma = false; + for (String field : toWrite) { + if (addComma) { + this.fileWriter.append(","); + } + this.fileWriter.append(field); + addComma = true; + } + this.fileWriter.append('\n'); + } + + @Override + public void write(String str) throws IOException { + this.buffer.add(str); + } + + @Override + public void nextRecord() throws IOException { + writeRow(buffer); + //clear the buffer for the next record + buffer.clear(); + } +} diff --git a/src/shared/writer/Writer.java b/src/shared/writer/Writer.java new file mode 100644 index 00000000..2bcc6452 --- /dev/null +++ b/src/shared/writer/Writer.java @@ -0,0 +1,47 @@ +package shared.writer; + +import java.io.IOException; + +/** + * This interface defines an API for a Writer object. Writers are used to write results + * to a file of a certain type. + * + * The writer lets a caller write to a given record, or advance to the next record. + * As an example, a CSVWriter might consider each line a record. A user can write + * to a line, which will create comma separated values. The call to nextRecord + * will then go to the next line. + * + * @author Jesse Rosalia + * @date 2013-03-07 + */ +public interface Writer { + + /** + * Close a writer and flush its contents. 
+ * + * @throws IOException + */ + public void close() throws IOException; + + /** + * Open a writer for writing. + * + * @throws IOException + */ + public void open() throws IOException; + + /** + * Write a datapoint to a record. + * + * @param str + * @throws IOException + */ + public void write(String str) throws IOException; + + /** + * Advance to the next record. + * + * @throws IOException + */ + public void nextRecord() throws IOException; +} diff --git a/src/util/TimeUtil.java b/src/util/TimeUtil.java new file mode 100644 index 00000000..b100e4b0 --- /dev/null +++ b/src/util/TimeUtil.java @@ -0,0 +1,17 @@ +package util; + +/** + * A utility for preparing and presenting run time metrics. + * + * @author Jesse Rosalia + * @date 2013-03-07 + */ +public class TimeUtil { + + public static String formatTime(long time) { + long secs = ((long) time) / 1000; + long min = secs / 60; + secs -= min * 60; + return String.format("%02d:%02d", min, secs); + } +}