Skip to content

Commit

Permalink
Merge pull request pushkar#4 from theJenix/master
Browse files Browse the repository at this point in the history
Bug fixes, improvements to dimensionality reduction classes
  • Loading branch information
pushkar committed Mar 31, 2013
2 parents a9abda1 + f8f8cf6 commit 1a93c08
Show file tree
Hide file tree
Showing 10 changed files with 304 additions and 8 deletions.
19 changes: 18 additions & 1 deletion src/shared/DataSet.java
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package shared;

import java.util.Arrays;
import java.util.Iterator;


/**
* A data set is just a collection of instances
* @author Andrew Guillory [email protected]
* @version 1.0
*/
public class DataSet {
public class DataSet implements Copyable, Iterable<Instance> {
/**
* The list of instances
*/
Expand Down Expand Up @@ -122,5 +125,19 @@ public String toString() {
return result;
}

/**
 * Build a new data set whose instances are copies of this set's instances.
 * Each instance is duplicated through its own {@code copy()} method, so the
 * depth of the copy is whatever that method provides. The new set receives
 * a description freshly computed from the copied instances.
 * @return the copied data set
 */
@Override
public DataSet copy() {
    final int count = this.size();
    Instance[] cloned = new Instance[count];
    int idx = 0;
    while (idx < count) {
        cloned[idx] = (Instance) this.get(idx).copy();
        idx++;
    }
    DataSet duplicate = new DataSet(cloned);
    // the description is rebuilt from the copies rather than shared with this set
    duplicate.setDescription(new DataSetDescription(duplicate));
    return duplicate;
}

/**
 * Iterate over the instances of this data set in array order.
 * The iterator is the one supplied by {@link Arrays#asList(Object[])},
 * i.e. it walks a fixed-size list view over the instance array.
 * @return an iterator over the instances
 */
@Override
public Iterator<Instance> iterator() {
    final Instance[] backing = instances;
    return Arrays.asList(backing).iterator();
}
}
57 changes: 52 additions & 5 deletions src/shared/DataSetWriter.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,78 @@ public class DataSetWriter {
*/
private String filename;

/**
* True to append the results to the end of
* the file, false to overwrite. This is
* useful if we need to write some kind of
* header to the file before writing the
* dataset.
*
*/
private boolean append;

private String[] labelStrings;

/**
 * Make a new data set writer with appending turned off
 * and no label strings.
 * @param set the data set to write
 * @param filename the name of the file to write to
 */
public DataSetWriter(DataSet set, String filename) {
    this(set, filename, false, null);
}

/**
 * Make a new data set writer with no label strings.
 * @param set the data set to write
 * @param filename the name of the file to write to
 * @param append true to append to the end of the file, false to overwrite it
 */
public DataSetWriter(DataSet set, String filename, boolean append) {
    this(set, filename, append, null);
}

/**
 * Make a new data set writer, specifying every option.
 * @param set the data set to write
 * @param filename the name of the file to write to
 * @param append true to append to the end of the file, false to overwrite it
 * @param labelStrings strings stored for use when the set is written;
 *        may be null
 */
public DataSetWriter(DataSet set, String filename, boolean append, String[] labelStrings) {
    this.labelStrings = labelStrings;
    this.append = append;
    this.filename = filename;
    this.set = set;
}

/**
* Write the file out
* @throws IOException when something goes bad
*/
public void write() throws IOException {
PrintWriter pw = new PrintWriter(new FileWriter(filename));
PrintWriter pw = new PrintWriter(new FileWriter(filename, this.append));
for (int i = 0; i < set.size(); i++) {
Instance data = set.get(i);
boolean label = false;
while (data != null) {
for (int j = 0; j < data.size(); j++) {
pw.print(data.getContinuous(j));
if (j + 1 < data.size() || data.getLabel() != null) {
pw.print(", ");
if (label && this.labelStrings != null) {
for (int j = 0; j < data.size(); j++) {
pw.print(this.labelStrings[data.getDiscrete(j)]);
if (j + 1 < data.size() || data.getLabel() != null) {
pw.print(", ");
}
}
} else {
for (int j = 0; j < data.size(); j++) {
pw.print(data.getContinuous(j));
if (j + 1 < data.size() || data.getLabel() != null) {
pw.print(", ");
}
}
}
data = data.getLabel();
label = true;
}
pw.println();
}
Expand Down
19 changes: 19 additions & 0 deletions src/shared/filt/InsignificantComponentAnalysis.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,25 @@ public InsignificantComponentAnalysis(DataSet dataSet, int toKeep, double thresh
}
}

/**
 * Make a new filter whose number of components is chosen by variance.
 * A multivariate Gaussian is estimated from the data, its covariance
 * matrix is eigen-decomposed, and the number of components to keep is
 * whatever {@code VarianceCounter.countRight} reports for the requested
 * variance. The projection rows are filled from the last eigenvector
 * columns, working backwards from the final column.
 * @param dataSet the set from which to estimate the components
 * @param varianceToKeep the amount of variance handed to the variance counter
 */
public InsignificantComponentAnalysis(DataSet dataSet, double varianceToKeep) {
    MultivariateGaussian gaussian = new MultivariateGaussian();
    gaussian.estimate(dataSet);
    Matrix covariance = gaussian.getCovarianceMatrix();
    mean = gaussian.getMean();

    SymmetricEigenvalueDecomposition decomposition =
        new SymmetricEigenvalueDecomposition(covariance);
    Matrix eigenVectors = decomposition.getU();
    eigenValues = decomposition.getD();

    int toKeep = new VarianceCounter(eigenValues).countRight(varianceToKeep);
    projection = new RectangularMatrix(toKeep, eigenVectors.m());
    // row r of the projection takes the r-th column counted from the right
    for (int row = 0; row < toKeep; row++) {
        int column = eigenVectors.m() - 1 - row;
        projection.setRow(row, eigenVectors.getColumn(column));
    }
}

/**
* Make a new PCA filter
* @param numberOfComponents the number to keep
Expand Down
48 changes: 48 additions & 0 deletions src/shared/filt/LabelSelectFilter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package shared.filt;

import shared.DataSet;
import shared.DataSetDescription;
import shared.Instance;
import util.linalg.Vector;

/**
 * A filter that selects a specified value as the label.
 * This is useful for processing datasets with an extra
 * attribute appended to the end (such as files Weka
 * spits out with the cluster appended to each instance)
 *
 * @author Jesse Rosalia <https://github.com/theJenix>
 */
public class LabelSelectFilter implements DataSetFilter {
    /**
     * The index, within each instance's data, of the value to use as the label
     */
    private final int labelIndex;

    /**
     * Make a new label select filter
     * @param labelIndex the index of the value to use as the label
     */
    public LabelSelectFilter(int labelIndex) {
        this.labelIndex = labelIndex;
    }

    /**
     * For every instance, take the value at the label index out of the
     * data, store the remaining values as the instance's data, and attach
     * the removed value as the instance's label. The data set description
     * is rebuilt afterwards. Nothing is read from the data set before the
     * per-instance loop, so the first instance is no longer dereferenced
     * up front.
     * @param dataSet the data set to modify in place
     * @see shared.filt.DataSetFilter#filter(shared.DataSet)
     */
    public void filter(DataSet dataSet) {
        for (int i = 0; i < dataSet.size(); i++) {
            Instance instance = dataSet.get(i);
            // take the full range first so remove() works on that result
            // (presumably a copy rather than the instance's own vector -- TODO confirm)
            Vector input =
                instance.getData().get(0, instance.getData().size());
            double output =
                instance.getData().get(this.labelIndex);
            input = input.remove(this.labelIndex);
            instance.setData(input);
            instance.setLabel(new Instance(output));
        }
        // the attribute layout changed, so recompute the description
        dataSet.setDescription(new DataSetDescription(dataSet));
    }

}
29 changes: 28 additions & 1 deletion src/shared/filt/PrincipalComponentAnalysis.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import dist.MultivariateGaussian;
import shared.DataSet;
import shared.DataSetDescription;
import shared.Instance;
import util.linalg.Matrix;
import util.linalg.RectangularMatrix;
Expand Down Expand Up @@ -63,6 +64,32 @@ public PrincipalComponentAnalysis(DataSet dataSet, int toKeep, double threshold)
}
}

/**
 * Make a new PCA filter that chooses the number of components by variance.
 * A multivariate Gaussian is estimated from the data, its covariance
 * matrix is eigen-decomposed, and the number of components to keep is
 * whatever {@code VarianceCounter.countLeft} reports for the requested
 * variance. The projection rows are the leading eigenvector columns.
 * @param dataSet the set from which to estimate components
 * @param varianceToKeep the amount of variance to keep; the variance counter
 *        treats the sum of the eigenvalues as all of the variance
 */
public PrincipalComponentAnalysis(DataSet dataSet, double varianceToKeep) {
    MultivariateGaussian gaussian = new MultivariateGaussian();
    gaussian.estimate(dataSet);
    Matrix covariance = gaussian.getCovarianceMatrix();
    mean = gaussian.getMean();

    SymmetricEigenvalueDecomposition decomposition =
        new SymmetricEigenvalueDecomposition(covariance);
    Matrix eigenVectors = decomposition.getU();
    eigenValues = decomposition.getD();

    int toKeep = new VarianceCounter(eigenValues).countLeft(varianceToKeep);
    projection = new RectangularMatrix(toKeep, eigenVectors.m());
    int component = 0;
    while (component < toKeep) {
        projection.setRow(component, eigenVectors.getColumn(component));
        component++;
    }
}

/**
* Make a new PCA filter
* @param numberOfComponents the number to keep
Expand Down Expand Up @@ -90,7 +117,7 @@ public void filter(DataSet dataSet) {
instance.setData(instance.getData().minus(mean));
instance.setData(projection.times(instance.getData()));
}
dataSet.setDescription(null);
dataSet.setDescription(new DataSetDescription(dataSet));
}


Expand Down
83 changes: 83 additions & 0 deletions src/shared/filt/VarianceCounter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package shared.filt;
import util.linalg.Matrix;
import util.linalg.Vector;

/**
 * A helper class used in PCA and other dimensionality reduction filters.
 *
 * Each diagonal entry of the supplied eigenvalue matrix is divided by the
 * sum of the diagonal to obtain that component's share of the variance.
 * The counting methods report how many components, taken from one end of
 * the diagonal, fit within a requested share.
 *
 * @author Jesse Rosalia
 *
 */
public class VarianceCounter {

    /** The square diagonal matrix of eigenvalues */
    private final Matrix eigenValues;

    /** The sum of the diagonal entries, used as the total variance */
    private final double sum;

    /**
     * Make a new variance counter
     * @param eigenValues a square diagonal matrix of eigenvalues
     * @throws IllegalStateException if the matrix is not square or not diagonal
     */
    public VarianceCounter(Matrix eigenValues) {
        if (eigenValues.m() != eigenValues.n() || ! isDiagonal(eigenValues)) {
            throw new IllegalStateException("Expected a square diagonal matrix");
        }
        this.eigenValues = eigenValues;
        double total = 0;
        for (int ii = 0; ii < eigenValues.m(); ii++) {
            total += eigenValues.get(ii, ii);
        }
        // NOTE(review): a zero total makes every share NaN or infinite in
        // the counting methods; that case is not guarded here
        this.sum = total;
    }

    /**
     * Count from the left, starting at diagonal index 0.
     *
     * Components are accepted one at a time for as long as the accumulated
     * share, including the next component, does not exceed the requested
     * share.
     *
     * @param varianceToKeep the share of the total variance to stay within,
     *        on the same scale as eigenvalue / sum
     * @return the number of components accepted (a count, not an index)
     */
    public int countLeft(double varianceToKeep) {
        int toKeep = 0;
        double kept = 0;
        for (int ii = 0; ii < eigenValues.m(); ii++) {
            double var = eigenValues.get(ii, ii) / sum;
            if (kept + var > varianceToKeep) {
                break;
            }
            // count the accepted component; callers size their projection
            // matrix with this value, so it must be a count
            toKeep++;
            kept += var;
        }

        return toKeep;
    }

    /**
     * Count from the right, starting at the last diagonal index.
     *
     * Components are accepted one at a time for as long as the accumulated
     * share, including the next component, does not exceed the requested
     * share.
     *
     * @param varianceToKeep the share of the total variance to stay within,
     *        on the same scale as eigenvalue / sum
     * @return the number of components accepted (a count, not an index)
     */
    public int countRight(double varianceToKeep) {
        int toKeep = 0;
        double kept = 0;
        for (int ii = eigenValues.m() - 1; ii >= 0; ii--) {
            double var = eigenValues.get(ii, ii) / sum;
            if (kept + var > varianceToKeep) {
                break;
            }
            // count the accepted component rather than recording its index
            toKeep++;
            kept += var;
        }

        return toKeep;
    }

    /**
     * Check whether each column's sum equals its diagonal entry.
     * @param eigenValues the matrix to check
     * @return false as soon as a column's sum differs from its diagonal entry,
     *         true otherwise
     */
    private boolean isDiagonal(Matrix eigenValues) {
        for (int ii = 0; ii < eigenValues.m(); ii++) {
            Vector v = eigenValues.getColumn(ii);
            if (v.sum() != v.get(ii)) {
                return false;
            }
        }
        return true;
    }
}
4 changes: 3 additions & 1 deletion src/shared/reader/ArffDataSetReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ private Instance[] processInstances(BufferedReader in,
String[] values = pattern.split(line.trim());
double[] ins = new double[values.length];
for (int i = 0; i < values.length; i++) {
String v = values[i];
//some values are single quoted (especially in datafiles bundled
// with weka)
String v = values[i].replaceAll("'", "");
// defaulting to 0 if attribute value unknown.
double d = 0;
try {
Expand Down
31 changes: 31 additions & 0 deletions src/shared/test/LabelSelectFilterTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package shared.test;

import java.io.File;

import shared.DataSet;
import shared.DataSetDescription;
import shared.filt.ContinuousToDiscreteFilter;
import shared.filt.LabelSelectFilter;
import shared.filt.LabelSplitFilter;
import shared.reader.ArffDataSetReader;
import shared.reader.DataSetLabelBinarySeperator;
import shared.reader.DataSetReader;

public class LabelSelectFilterTest {
    /**
     * Reads the abalone ARFF file, applies a label select filter followed
     * by a continuous-to-discrete filter, and prints the resulting data set
     * and its description.
     * @param args ignored parameters
     */
    public static void main(String[] args) throws Exception {
        // resolve the data file relative to the current working directory
        String workingDirectory = new File("").getAbsolutePath();
        String arffPath = workingDirectory + "/src/shared/test/abalone.arff";
        DataSetReader reader = new ArffDataSetReader(arffPath);
        // read in the raw data
        DataSet data = reader.read();
        // use the value at index 1 as the label
        new LabelSelectFilter(1).filter(data);
        new ContinuousToDiscreteFilter(10).filter(data);
        System.out.println(data);
        System.out.println(new DataSetDescription(data));
    }
}
5 changes: 5 additions & 0 deletions src/shared/tester/ConfusionMatrixTestMetric.java
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,11 @@ private Instance findLabel(Instance[] labels, Instance toFind) {
}


/**
*
* NOTE: Rows are "expected", columns are "actual"
*
*/
@Override
public void printResults() {
System.out.println("Confusion Matrix:");
Expand Down
Loading

0 comments on commit 1a93c08

Please sign in to comment.