forked from pushkar/ABAGAIL
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request pushkar#4 from theJenix/master
Bug fixes, improvements to dimensionality reduction classes
- Loading branch information
Showing
10 changed files
with
304 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,15 @@ | ||
package shared; | ||
|
||
import java.util.Arrays; | ||
import java.util.Iterator; | ||
|
||
|
||
/** | ||
* A data set is just a collection of instances | ||
* @author Andrew Guillory [email protected] | ||
* @version 1.0 | ||
*/ | ||
public class DataSet { | ||
public class DataSet implements Copyable, Iterable<Instance> { | ||
/** | ||
* The list of instances | ||
*/ | ||
|
@@ -122,5 +125,19 @@ public String toString() { | |
return result; | ||
} | ||
|
||
@Override | ||
public DataSet copy() { | ||
Instance[] copy = new Instance[this.size()]; | ||
for (int i = 0; i < copy.length; i++) { | ||
copy[i] = (Instance) this.get(i).copy(); | ||
} | ||
DataSet newSet = new DataSet(copy); | ||
newSet.setDescription(new DataSetDescription(newSet)); | ||
return newSet; | ||
} | ||
|
||
@Override | ||
public Iterator<Instance> iterator() { | ||
return Arrays.asList(instances).iterator(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
package shared.filt; | ||
|
||
import shared.DataSet; | ||
import shared.DataSetDescription; | ||
import shared.Instance; | ||
import util.linalg.Vector; | ||
|
||
/** | ||
* A filter that selects a specified value as the label. | ||
* This is useful for processing datasets with an extra | ||
* attribute appended to the end (such as files Weka | ||
* spits out with the cluster appended to each instance) | ||
* | ||
* @author Jesse Rosalia <https://github.com/theJenix> | ||
*/ | ||
public class LabelSelectFilter implements DataSetFilter { | ||
/** | ||
* The size of the data | ||
*/ | ||
private int labelIndex; | ||
|
||
/** | ||
* Make a new label select filter | ||
* @param labelIndex the index of the value to use as the label | ||
*/ | ||
public LabelSelectFilter(int labelIndex) { | ||
this.labelIndex = labelIndex; | ||
} | ||
|
||
/** | ||
* @see shared.filt.DataSetFilter#filter(shared.DataSet) | ||
*/ | ||
public void filter(DataSet dataSet) { | ||
int dataCount = dataSet.get(0).size() - labelIndex; | ||
for (int i = 0; i < dataSet.size(); i++) { | ||
Instance instance = dataSet.get(i); | ||
Vector input = | ||
instance.getData().get(0, instance.getData().size()); | ||
double output = | ||
instance.getData().get(this.labelIndex); | ||
input = input.remove(this.labelIndex); | ||
instance.setData(input); | ||
instance.setLabel(new Instance(output)); | ||
} | ||
dataSet.setDescription(new DataSetDescription(dataSet)); | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
package shared.filt; | ||
import util.linalg.Matrix; | ||
import util.linalg.Vector; | ||
|
||
/** | ||
* A helper class used in PCA and other dimensionality reduction filters. | ||
* | ||
* This class assumes that the eigenvalues represent the % of variation captured by each | ||
* component, and can therefore determine how many components to keep to capture | ||
* a specified % of variance | ||
* | ||
* @author Jesse Rosalia | ||
* | ||
*/ | ||
public class VarianceCounter { | ||
|
||
private Matrix eigenValues; | ||
private double sum = 0; | ||
|
||
public VarianceCounter(Matrix eigenValues) { | ||
if (eigenValues.m() != eigenValues.n() || ! isDiagonal(eigenValues)) { | ||
throw new IllegalStateException("Expected a square diagonal matrix"); | ||
} | ||
this.eigenValues = eigenValues; | ||
for (int ii = 0; ii < eigenValues.m(); ii++) { | ||
sum += eigenValues.get(ii, ii); | ||
} | ||
} | ||
|
||
/** | ||
* Count from the left...this captures the biggest components first. | ||
* | ||
* @param varianceToKeep | ||
* @return | ||
*/ | ||
public int countLeft(double varianceToKeep) { | ||
int toKeep = 0; | ||
double kept = 0; | ||
for (int ii = 0; ii < eigenValues.m(); ii++) { | ||
double var = eigenValues.get(ii, ii) / sum; | ||
if (kept + var > varianceToKeep) { | ||
break; | ||
} | ||
toKeep = ii; | ||
kept += var; | ||
} | ||
|
||
return toKeep; | ||
} | ||
|
||
/** | ||
* Count from the right...this captures the smallest components first. | ||
* | ||
* @param varianceToKeep | ||
* @return | ||
*/ | ||
public int countRight(double varianceToKeep) { | ||
int toKeep = 0; | ||
double kept = 0; | ||
for (int ii = eigenValues.m() - 1; ii >= 0; ii--) { | ||
double var = eigenValues.get(ii, ii) / sum; | ||
if (kept + var > varianceToKeep) { | ||
break; | ||
} | ||
toKeep = ii; | ||
kept += var; | ||
} | ||
|
||
return toKeep; | ||
} | ||
|
||
private boolean isDiagonal(Matrix eigenValues) { | ||
boolean diagonal = true; | ||
for (int ii = 0; ii < eigenValues.m(); ii++) { | ||
Vector v = eigenValues.getColumn(ii); | ||
if (v.sum() != v.get(ii)) { | ||
diagonal = false; | ||
break; | ||
} | ||
} | ||
return diagonal; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
package shared.test; | ||
|
||
import java.io.File; | ||
|
||
import shared.DataSet; | ||
import shared.DataSetDescription; | ||
import shared.filt.ContinuousToDiscreteFilter; | ||
import shared.filt.LabelSelectFilter; | ||
import shared.filt.LabelSplitFilter; | ||
import shared.reader.ArffDataSetReader; | ||
import shared.reader.DataSetLabelBinarySeperator; | ||
import shared.reader.DataSetReader; | ||
|
||
public class LabelSelectFilterTest { | ||
/** | ||
* The test main | ||
* @param args ignored parameters | ||
*/ | ||
public static void main(String[] args) throws Exception { | ||
DataSetReader dsr = new ArffDataSetReader(new File("").getAbsolutePath() + "/src/shared/test/abalone.arff"); | ||
// read in the raw data | ||
DataSet ds = dsr.read(); | ||
// split out the label | ||
LabelSelectFilter lsf = new LabelSelectFilter(1); | ||
lsf.filter(ds); | ||
ContinuousToDiscreteFilter ctdf = new ContinuousToDiscreteFilter(10); | ||
ctdf.filter(ds); | ||
System.out.println(ds); | ||
System.out.println(new DataSetDescription(ds)); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.