Skip to content

Commit

Permalink
add Unit test #4 TVM, SYM
Browse files Browse the repository at this point in the history
  • Loading branch information
walterxie committed May 24, 2017
1 parent 4aef94f commit 5865377
Show file tree
Hide file tree
Showing 3 changed files with 279 additions and 0 deletions.
89 changes: 89 additions & 0 deletions src/test/substmodels/nucleotide/SYMTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package test.substmodels.nucleotide;


import beast.core.Description;
import beast.core.parameter.RealParameter;
import junit.framework.TestCase;
import substmodels.nucleotide.SYM;

/**
* Test SYM matrix exponentiation
*
*/
@Description("Test SYM matrix exponentiation")
public class SYMTest extends TestCase {

/*
* import numpy as np
* from scipy.linalg import expm
*
* piQ = np.diag([.25, .25, .25, .25])
* d = 0.1
* # Q matrix with zeroed diagonal
* XQ = np.matrix([[0, 0.2, 10, .3], [0.2, 0, 0.4, 5], [10, 0.4, 0, 0.5], [0.3, 5, 0.5, 0]])
*
* xx = XQ * piQ
*
* # fill diagonal and normalize by total substitution rate
* q0 = (xx + np.diag(np.squeeze(np.asarray(-np.sum(xx, axis=1))))) / np.sum(piQ * np.sum(xx, axis=1))
* expm(q0 * d)
*/
protected EqualBaseFrequencies test0 = new EqualBaseFrequencies() {
@Override
public Double [] getRates() {
return new Double[] {0.2, 10.0, 0.3, 0.4, 5.0, 0.5};
}

@Override
public double getDistance() {
return 0.1;
}

@Override
public double[] getExpectedResult() {
return new double[]{
0.886360401447, 0.002594129576, 0.107315348219, 0.003730120758,
0.002594129576, 0.93573730447 , 0.004733198723, 0.056935367232,
0.107315348219, 0.004733198723, 0.882087475436, 0.005863977622,
0.003730120758, 0.056935367232, 0.005863977622, 0.933470534387
};
}
};


EqualBaseFrequencies[] all = {test0};

public void testSYM() throws Exception {
for (EqualBaseFrequencies test : all) {

SYM sym = new SYM();
RealParameter symRates = new RealParameter(test.getRates());
sym.initByName("rates", symRates);
sym.printQ(System.out); // to obtain XQ for python script
// for (int i = 0; i < 6; ++i)
// System.out.println("Rate " + sym.getSubstitution(i) + " : " + sym.getRate(i));

assertEquals(false, sym.getRateAC()== sym.getRateAG());
assertEquals(false, sym.getRateAC()== sym.getRateAT());
assertEquals(false, sym.getRateAC()== sym.getRateCG());
assertEquals(false, sym.getRateAC()== sym.getRateCT());
assertEquals(false, sym.getRateAC()== sym.getRateGT());
assertEquals(false, sym.getRateAG()== sym.getRateAT());
assertEquals(false, sym.getRateAG()== sym.getRateCG());
assertEquals(false, sym.getRateAG()== sym.getRateCT());
assertEquals(false, sym.getRateAG()== sym.getRateGT());
assertEquals(false, sym.getRateCG()== sym.getRateGT());
assertEquals(false, sym.getRateCT()== sym.getRateGT());

double distance = test.getDistance();
double[] mat = new double[4 * 4];
sym.getTransitionProbabilities(null, distance, 0, 1, mat);

final double[] result = test.getExpectedResult();
for (int k = 0; k < mat.length; ++k) {
assertEquals(mat[k], result[k], 1e-10);
System.out.println(k + " : " + (mat[k] - result[k]));
}
}
}
}
100 changes: 100 additions & 0 deletions src/test/substmodels/nucleotide/TVMTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package test.substmodels.nucleotide;


import beast.core.Description;
import beast.core.parameter.RealParameter;
import beast.evolution.substitutionmodel.Frequencies;
import junit.framework.TestCase;
import substmodels.nucleotide.TVM;

/**
* Test TVM matrix exponentiation
*
*/
@Description("Test TVM matrix exponentiation")
public class TVMTest extends TestCase {

/*
* import numpy as np
* from scipy.linalg import expm
*
* piQ = np.diag([0.4, 0.3, 0.2, 0.1])
* d = 0.1
* # Q matrix with zeroed diagonal
* XQ = np.matrix([[0, 0.2, 10, .3], [0.2, 0, 0.4, 10], [10, 0.4, 0, 5], [0.3, 10, 5, 0]])
*
* xx = XQ * piQ
*
* # fill diagonal and normalize by total substitution rate
* q0 = (xx + np.diag(np.squeeze(np.asarray(-np.sum(xx, axis=1))))) / np.sum(piQ * np.sum(xx, axis=1))
* expm(q0 * d)
*/
protected UnequalBaseFrequencies test0 = new UnequalBaseFrequencies() {
@Override
public Double[] getPi() {
return new Double[]{0.4, 0.3, 0.2, 0.1};
}

@Override
public Double [] getRates() {
return new Double[] {0.2, 10.0, 0.3, 0.4, 5.0};
}

@Override
public double getDistance() {
return 0.1;
}

@Override
public double[] getExpectedResult() {
return new double[]{
0.926032879344, 0.002501473304, 0.069681075145, 0.001784572207,
0.003335297739, 0.957195914822, 0.00364943052 , 0.035819356919,
0.13936215029 , 0.005474145781, 0.838262134317, 0.016901569612,
0.00713828883 , 0.107458070758, 0.033803139225, 0.851600501187
};
}
};


UnequalBaseFrequencies[] all = {test0};

public void testTVM() throws Exception {
for (UnequalBaseFrequencies test : all) {

RealParameter f = new RealParameter(test.getPi());
Frequencies freqs = new Frequencies();
freqs.initByName("frequencies", f); // "estimate", true

TVM tvm = new TVM();
RealParameter tvmRates = new RealParameter(test.getRates());
tvm.initByName("rates", tvmRates, "frequencies", freqs);
tvm.printQ(System.out); // to obtain XQ for python script
// for (int i = 0; i < 6; ++i)
// System.out.println("Rate " + tvm.getSubstitution(i) + " : " + tvm.getRate(i));

assertEquals(false, tvm.getRateAC()== tvm.getRateAG());
assertEquals(false, tvm.getRateAC()== tvm.getRateAT());
assertEquals(false, tvm.getRateAC()== tvm.getRateCG());
assertEquals(false, tvm.getRateAC()== tvm.getRateCT());
assertEquals(false, tvm.getRateAC()== tvm.getRateGT());
assertEquals(false, tvm.getRateAG()== tvm.getRateAT());
assertEquals(false, tvm.getRateAG()== tvm.getRateCG());
assertEquals(false, tvm.getRateAG()== tvm.getRateGT());
assertEquals(false, tvm.getRateCG()== tvm.getRateGT());
assertEquals(false, tvm.getRateCT()== tvm.getRateGT());
// AG=CT
assertEquals(true, tvm.getRateAG()== tvm.getRateCT());

double distance = test.getDistance();
double[] mat = new double[4 * 4];
tvm.getTransitionProbabilities(null, distance, 0, 1, mat);

final double[] result = test.getExpectedResult();
for (int k = 0; k < mat.length; ++k) {
assertEquals(mat[k], result[k], 1e-10);
System.out.println(k + " : " + (mat[k] - result[k]));
}
}
}
}
90 changes: 90 additions & 0 deletions src/test/substmodels/nucleotide/TVMefTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package test.substmodels.nucleotide;


import beast.core.Description;
import beast.core.parameter.RealParameter;
import junit.framework.TestCase;
import substmodels.nucleotide.TVMef;

/**
* Test TVMef matrix exponentiation
*
*/
@Description("Test TVMef matrix exponentiation")
public class TVMefTest extends TestCase {

/*
* import numpy as np
* from scipy.linalg import expm
*
* piQ = np.diag([.25, .25, .25, .25])
* d = 0.1
* # Q matrix with zeroed diagonal
* XQ = np.matrix([[0, 0.2, 10, .3], [0.2, 0, 0.4, 10], [10, 0.4, 0, 5], [0.3, 10, 5, 0]])
*
* xx = XQ * piQ
*
* # fill diagonal and normalize by total substitution rate
* q0 = (xx + np.diag(np.squeeze(np.asarray(-np.sum(xx, axis=1))))) / np.sum(piQ * np.sum(xx, axis=1))
* expm(q0 * d)
*/
protected EqualBaseFrequencies test0 = new EqualBaseFrequencies() {
@Override
public Double [] getRates() {
return new Double[] {0.2, 10.0, 0.3, 0.4, 5.0};
}

@Override
public double getDistance() {
return 0.1;
}

@Override
public double[] getExpectedResult() {
return new double[]{
0.924841386585, 0.001651397824, 0.070005984376, 0.003501231215,
0.001651397824, 0.924131486211, 0.004198599114, 0.07001851685,
0.070005984376, 0.004198599114, 0.891234320911, 0.034561095599,
0.003501231215, 0.07001851685 , 0.034561095599, 0.891919156335
};
}
};


EqualBaseFrequencies[] all = {test0};

public void testTVMef() throws Exception {
for (EqualBaseFrequencies test : all) {

TVMef tvmef = new TVMef();
RealParameter tvmefRates = new RealParameter(test.getRates());
tvmef.initByName("rates", tvmefRates);
tvmef.printQ(System.out); // to obtain XQ for python script
// for (int i = 0; i < 6; ++i)
// System.out.println("Rate " + tvmef.getSubstitution(i) + " : " + tvmef.getRate(i));

assertEquals(false, tvmef.getRateAC()== tvmef.getRateAG());
assertEquals(false, tvmef.getRateAC()== tvmef.getRateAT());
assertEquals(false, tvmef.getRateAC()== tvmef.getRateCG());
assertEquals(false, tvmef.getRateAC()== tvmef.getRateCT());
assertEquals(false, tvmef.getRateAC()== tvmef.getRateGT());
assertEquals(false, tvmef.getRateAG()== tvmef.getRateAT());
assertEquals(false, tvmef.getRateAG()== tvmef.getRateCG());
assertEquals(false, tvmef.getRateAG()== tvmef.getRateGT());
assertEquals(false, tvmef.getRateCG()== tvmef.getRateGT());
assertEquals(false, tvmef.getRateCT()== tvmef.getRateGT());
// AG=CT
assertEquals(true, tvmef.getRateAG()== tvmef.getRateCT());

double distance = test.getDistance();
double[] mat = new double[4 * 4];
tvmef.getTransitionProbabilities(null, distance, 0, 1, mat);

final double[] result = test.getExpectedResult();
for (int k = 0; k < mat.length; ++k) {
assertEquals(mat[k], result[k], 1e-10);
System.out.println(k + " : " + (mat[k] - result[k]));
}
}
}
}

0 comments on commit 5865377

Please sign in to comment.