forked from pushkar/ABAGAIL
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implemented a DataSetFilter to split a dataset into a testing and a training set. Added code to recombine separated binary labels into discrete labels (for the test metrics). Fixed bugs in the metrics introduced by splitting labels into binary labels. Added a framework for defining experiment runners, which expose a standard set of test metrics and support writing to a CSV file. Added a CSV writer and writer framework to write data to a file.
- Loading branch information
Showing
12 changed files
with
443 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,3 +6,4 @@ ABAGAIL.jar | |
*.jar | ||
*.war | ||
*.ear | ||
manifest.mf |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
package shared.filt; | ||
|
||
import shared.DataSet; | ||
import shared.Instance; | ||
|
||
public class TestTrainSplitFilter implements DataSetFilter { | ||
|
||
private double pctTrain; | ||
private DataSet trainingSet; | ||
private DataSet testingSet; | ||
|
||
/** | ||
* | ||
* | ||
* @param pctTrain A percentage from 0 to 100 | ||
*/ | ||
public TestTrainSplitFilter(int pctTrain) { | ||
this.pctTrain = 1.0 * pctTrain / 100; // | ||
} | ||
|
||
@Override | ||
public void filter(DataSet dataSet) { | ||
int totalInstances = dataSet.getInstances().length; | ||
int trainInstances = (int) (totalInstances * pctTrain); | ||
int testInstances = totalInstances - trainInstances; | ||
Instance[] train = new Instance[trainInstances]; | ||
Instance[] test = new Instance[testInstances]; | ||
for (int ii = 0; ii < trainInstances; ii++) { | ||
train[ii] = dataSet.get(ii); | ||
} | ||
for (int ii = trainInstances; ii < totalInstances; ii++) { | ||
test[ii - trainInstances] = dataSet.get(ii); | ||
} | ||
|
||
this.trainingSet = new DataSet(train); | ||
this.testingSet = new DataSet(test); | ||
} | ||
|
||
public DataSet getTrainingSet() { | ||
return this.trainingSet; | ||
} | ||
|
||
public DataSet getTestingSet() { | ||
return this.testingSet; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
package shared.runner; | ||
|
||
import java.io.File; | ||
|
||
import shared.tester.AccuracyTestMetric; | ||
import shared.tester.ConfusionMatrixTestMetric; | ||
import shared.writer.CSVWriter; | ||
import shared.writer.Writer; | ||
import util.TimeUtil; | ||
|
||
/** | ||
* A runner for multiple tests/experiments. This class takes in a Runner, an array of iteration values to use, | ||
* and an array of test/train splits to use | ||
* | ||
* @author Jesse Rosalia <https://github.com/theJenix> | ||
* @date 2013-03-06 | ||
*/ | ||
public class MultiRunner { | ||
|
||
private Runner runner; | ||
private int[] iterArray; | ||
private int[] pctTrainArray; | ||
private Writer writer; | ||
private File outputFolder; | ||
|
||
public MultiRunner(Runner runner, int[] iterArray, int[] pctTrainArray) { | ||
this.runner = runner; | ||
this.iterArray = iterArray; | ||
this.pctTrainArray = pctTrainArray; | ||
} | ||
|
||
/** | ||
* Run all combinations of iterations and test/train splits, and output the results. | ||
* | ||
* @throws Exception | ||
*/ | ||
public void runAll() throws Exception { | ||
String[] outputFields = { | ||
"iterations", | ||
"% train", | ||
"% correct", | ||
// "% incorrect", | ||
}; | ||
Writer writer = null; | ||
if (this.outputFolder != null) { | ||
File outputFile = new File(this.outputFolder, String.format("%s.csv", runner.getName())); | ||
writer = new CSVWriter(outputFile.getAbsolutePath(), outputFields); | ||
writer.open(); | ||
} | ||
for (int iterations : iterArray) { | ||
for (int pctTrain : pctTrainArray) { | ||
runner.run(iterations, pctTrain); | ||
|
||
AccuracyTestMetric am = runner.getAccuracyMetric(); | ||
ConfusionMatrixTestMetric cm = runner.getConfusionMatrix(); | ||
|
||
//print results to the console | ||
String trainTime = TimeUtil.formatTime(runner.getTrainingTime()); | ||
String testTime = TimeUtil.formatTime(runner.getTestTime()); | ||
am.printResults(); | ||
System.out.println("Training time: " + trainTime); | ||
System.out.println("Testing time: " + testTime); | ||
System.out.println("Number of iterations: " + iterations); | ||
if (pctTrain > 0 && pctTrain < 100) { | ||
System.out.println(String.format("%02d%% training / %02d%% testing", pctTrain, 100 - pctTrain)); | ||
} else { | ||
System.out.println("Testing using training set."); | ||
} | ||
cm.printResults(); | ||
System.out.println(); | ||
|
||
//write results to a file if available | ||
if (writer != null) { | ||
writer.write("" + iterations); | ||
writer.write("" + pctTrain); | ||
writer.write("" + am.getPctCorrect()); | ||
writer.nextRecord(); | ||
} | ||
} | ||
} | ||
if (writer != null) { | ||
writer.close(); | ||
} | ||
} | ||
|
||
/** | ||
* Set the output folder for results files. If not set, this class | ||
* will not output any data to file. | ||
* | ||
* @param outputFolder | ||
*/ | ||
public void setOutputFolder(File outputFolder) { | ||
this.outputFolder = outputFolder; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
package shared.runner; | ||
import shared.tester.AccuracyTestMetric; | ||
import shared.tester.ConfusionMatrixTestMetric; | ||
import shared.tester.RawOutputTestMetric; | ||
|
||
/** | ||
* A runner for a given experiment or test. The runner will be responsible | ||
* for constructing a classifier, loading the data into a DataSet, constructing | ||
* the necessary objects to train and test the classifier, and collecting the | ||
* results in a series of TestMetric objects. | ||
* | ||
* NOTE: most of the metrics in this API refer to the last call to run, and are | ||
* replaced after each subsequent call. You must call run at least once for them | ||
* to be valid. | ||
* | ||
* @author Jesse Rosalia <https://github.com/theJenix> | ||
* @date 2013-03-07 | ||
*/ | ||
public interface Runner { | ||
|
||
/** | ||
* Get the accuracy metric test metric for the last run, for reporting % correct | ||
* and % incorrect. | ||
* | ||
* @return | ||
*/ | ||
public AccuracyTestMetric getAccuracyMetric(); | ||
|
||
/** | ||
* Get the confusion matrix for the last run. | ||
* | ||
* @return | ||
*/ | ||
public ConfusionMatrixTestMetric getConfusionMatrix(); | ||
|
||
/** | ||
* Get a name for this runner. This is likely the name of the implementation | ||
* combined with the name of the data set. | ||
* | ||
* @return | ||
*/ | ||
public String getName(); | ||
|
||
/** | ||
* Get the raw output metric for the last run. | ||
* | ||
* @return | ||
*/ | ||
public RawOutputTestMetric getRawOutput(); | ||
|
||
/** | ||
* Get the training time for the last run. | ||
* | ||
* @return | ||
*/ | ||
public long getTrainingTime(); | ||
|
||
/** | ||
* Get the testing time for the last run. | ||
* | ||
* @return | ||
*/ | ||
public long getTestTime(); | ||
|
||
/** | ||
* Run the runner with the specified number of training iterations and | ||
* specified train/test split. | ||
* | ||
* @param iterations The number of iterations used to train the classifier. | ||
* @param pctTrain The % of the data to use in a training set (the remainder is | ||
* in the testing set). Note that values of 0 or 100 will result in the whole | ||
* set being used for training and testing. | ||
* | ||
* @throws Exception | ||
*/ | ||
public void run(int iterations, int pctTrain) throws Exception; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.