Skip to content

Commit

Permalink
add Unit test #4 TIM1,2,3
Browse files Browse the repository at this point in the history
  • Loading branch information
walterxie committed May 24, 2017
1 parent 5657fb5 commit 4aef94f
Show file tree
Hide file tree
Showing 6 changed files with 524 additions and 0 deletions.
82 changes: 82 additions & 0 deletions src/test/substmodels/nucleotide/TIM1Test.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package test.substmodels.nucleotide;

import beast.core.Description;
import beast.core.parameter.RealParameter;
import junit.framework.TestCase;
import substmodels.nucleotide.TIM1;

@Description("Test TIM1 matrix exponentiation")
public class TIM1Test extends TestCase {

/*
* import numpy as np
* from scipy.linalg import expm
*
* piQ = np.diag([.25, .25, .25, .25])
* d = 0.3
* # Q matrix with zeroed diagonal
* XQ = np.matrix([[0, 1, 2, 3], [1, 0, 3, 4], [2, 3, 0, 1], [3, 4, 1, 0]])
*
* xx = XQ * piQ
*
* # fill diagonal and normalize by total substitution rate
* q0 = (xx + np.diag(np.squeeze(np.asarray(-np.sum(xx, axis=1))))) / np.sum(piQ * np.sum(xx, axis=1))
* expm(q0 * d)
*/
protected EqualBaseFrequencies test0 = new EqualBaseFrequencies() {
@Override
public Double [] getRates() {
return new Double[] {1.0, 2.0, 3.0, 4.0};
}

@Override
public double getDistance() {
return 0.3;
}

@Override
public double[] getExpectedResult() {
return new double[]{
0.783537640428, 0.044577861463, 0.071332157416, 0.100552340692,
0.044577861463, 0.7275631612 , 0.100552340692, 0.127306636645,
0.071332157416, 0.100552340692, 0.783537640428, 0.044577861463,
0.100552340692, 0.127306636645, 0.044577861463, 0.7275631612
};
}
};


EqualBaseFrequencies[] all = {test0};

public void testTIM1() throws Exception {
for (EqualBaseFrequencies test : all) {

TIM1 tim1 = new TIM1();
RealParameter rates = new RealParameter(test.getRates());
tim1.initByName("rates", rates);
tim1.printQ(System.out); // to obtain XQ for python script
// for (int i = 0; i < 6; ++i)
// System.out.println("Rate " + tim1.getSubstitution(i) + " : " + tim1.getRate(i));

// AC=GT
assertEquals(true, tim1.getRateAC()== tim1.getRateGT() &&
tim1.getRateAC()!=tim1.getRateAT() && tim1.getRateAC()!=tim1.getRateAG());
// AT=CG
assertEquals(true, tim1.getRateAT()== tim1.getRateCG() &&
tim1.getRateAT()!=tim1.getRateAC() && tim1.getRateAT()!=tim1.getRateAG());
// AG!=CT
assertEquals(true, tim1.getRateAC()!=tim1.getRateAG() &&
tim1.getRateAC()!=tim1.getRateCT() && tim1.getRateAG()!=tim1.getRateCT());

double distance = test.getDistance();
double[] mat = new double[4 * 4];
tim1.getTransitionProbabilities(null, distance, 0, 1, mat);

final double[] result = test.getExpectedResult();
for (int k = 0; k < mat.length; ++k) {
assertEquals(mat[k], result[k], 1e-10);
System.out.println(k + " : " + (mat[k] - result[k]));
}
}
}
}
93 changes: 93 additions & 0 deletions src/test/substmodels/nucleotide/TIM1ufTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package test.substmodels.nucleotide;

import beast.core.Description;
import beast.core.parameter.RealParameter;
import beast.evolution.substitutionmodel.Frequencies;
import junit.framework.TestCase;
import substmodels.nucleotide.TIM1uf;

@Description("Test TIM1uf matrix exponentiation")
public class TIM1ufTest extends TestCase {

/*
* import numpy as np
* from scipy.linalg import expm
*
* piQ = np.diag([0.4, 0.3, 0.2, 0.1])
* d = 0.3
* # Q matrix with zeroed diagonal
* XQ = np.matrix([[0, 1, 2, 3], [1, 0, 3, 4], [2, 3, 0, 1], [3, 4, 1, 0]])
*
* xx = XQ * piQ
*
* # fill diagonal and normalize by total substitution rate
* q0 = (xx + np.diag(np.squeeze(np.asarray(-np.sum(xx, axis=1))))) / np.sum(piQ * np.sum(xx, axis=1))
* expm(q0 * d)
*/
protected UnequalBaseFrequencies test0 = new UnequalBaseFrequencies() {
@Override
public Double[] getPi() {
return new Double[]{0.4, 0.3, 0.2, 0.1};
}

@Override
public Double [] getRates() {
return new Double[] {1.0, 2.0, 3.0, 4.0};
}

@Override
public double getDistance() {
return 0.3;
}

@Override
public double[] getExpectedResult() {
return new double[]{
0.825871651293, 0.060874601338, 0.067043906735, 0.046209840634,
0.081166135118, 0.765617592587, 0.094589810557, 0.058626461738,
0.134087813471, 0.141884715835, 0.701565807625, 0.022461663069,
0.184839362534, 0.175879385215, 0.044923326138, 0.594357926113
};
}
};


UnequalBaseFrequencies[] all = {test0};

public void testTIM1uf() throws Exception {
for (UnequalBaseFrequencies test : all) {

RealParameter f = new RealParameter(test.getPi());
Frequencies freqs = new Frequencies();
freqs.initByName("frequencies", f); // "estimate", true

TIM1uf tim1uf = new TIM1uf();
RealParameter rates = new RealParameter(test.getRates());
tim1uf.initByName("rates", rates, "frequencies", freqs);
tim1uf.printQ(System.out); // to obtain XQ for python script
// for (int i = 0; i < 6; ++i)
// System.out.println("Rate " + tim1uf.getSubstitution(i) + " : " + tim1uf.getRate(i));

// AC=GT
assertEquals(true, tim1uf.getRateAC()== tim1uf.getRateGT() &&
tim1uf.getRateAC()!=tim1uf.getRateAT() && tim1uf.getRateAC()!=tim1uf.getRateAG());
// AT=CG
assertEquals(true, tim1uf.getRateAT()== tim1uf.getRateCG() &&
tim1uf.getRateAT()!=tim1uf.getRateAC() && tim1uf.getRateAT()!=tim1uf.getRateAG());
// AG!=CT
assertEquals(true, tim1uf.getRateAC()!=tim1uf.getRateAG() &&
tim1uf.getRateAC()!=tim1uf.getRateCT() && tim1uf.getRateAG()!=tim1uf.getRateCT());


double distance = test.getDistance();
double[] mat = new double[4 * 4];
tim1uf.getTransitionProbabilities(null, distance, 0, 1, mat);

final double[] result = test.getExpectedResult();
for (int k = 0; k < mat.length; ++k) {
assertEquals(mat[k], result[k], 1e-10);
System.out.println(k + " : " + (mat[k] - result[k]));
}
}
}
}
82 changes: 82 additions & 0 deletions src/test/substmodels/nucleotide/TIM2Test.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package test.substmodels.nucleotide;

import beast.core.Description;
import beast.core.parameter.RealParameter;
import junit.framework.TestCase;
import substmodels.nucleotide.TIM2;

@Description("Test TIM2 matrix exponentiation")
public class TIM2Test extends TestCase {

/*
* import numpy as np
* from scipy.linalg import expm
*
* piQ = np.diag([.25, .25, .25, .25])
* d = 0.3
* # Q matrix with zeroed diagonal
* XQ = np.matrix([[0, 1, 2, 1], [1, 0, 3, 4], [2, 3, 0, 3], [1, 4, 3, 0]])
*
* xx = XQ * piQ
*
* # fill diagonal and normalize by total substitution rate
* q0 = (xx + np.diag(np.squeeze(np.asarray(-np.sum(xx, axis=1))))) / np.sum(piQ * np.sum(xx, axis=1))
* expm(q0 * d)
*/
protected EqualBaseFrequencies test0 = new EqualBaseFrequencies() {
@Override
public Double [] getRates() {
return new Double[] {1.0, 2.0, 3.0, 4.0};
}

@Override
public double getDistance() {
return 0.3;
}

@Override
public double[] getExpectedResult() {
return new double[]{
0.847204213099, 0.040767943097, 0.071259900708, 0.040767943097,
0.040767943097, 0.727703865901, 0.101751858319, 0.129776332684,
0.071259900708, 0.101751858319, 0.725236382655, 0.101751858319,
0.040767943097, 0.129776332684, 0.101751858319, 0.727703865901
};
}
};


EqualBaseFrequencies[] all = {test0};

public void testTIM2() throws Exception {
for (EqualBaseFrequencies test : all) {

TIM2 tim2 = new TIM2();
RealParameter rates = new RealParameter(test.getRates());
tim2.initByName("rates", rates);
tim2.printQ(System.out); // to obtain XQ for python script
// for (int i = 0; i < 6; ++i)
// System.out.println("Rate " + tim2.getSubstitution(i) + " : " + tim2.getRate(i));

// AC=AT
assertEquals(true, tim2.getRateAC()== tim2.getRateAT() &&
tim2.getRateAC()!=tim2.getRateCG() && tim2.getRateAC()!=tim2.getRateAG());
// CG=GT
assertEquals(true, tim2.getRateCG()== tim2.getRateGT() &&
tim2.getRateCG()!=tim2.getRateCT() && tim2.getRateCG()!=tim2.getRateAG());
// AG!=CT
assertEquals(true, tim2.getRateAC()!=tim2.getRateAG() &&
tim2.getRateAC()!=tim2.getRateCT() && tim2.getRateAG()!=tim2.getRateCT());

double distance = test.getDistance();
double[] mat = new double[4 * 4];
tim2.getTransitionProbabilities(null, distance, 0, 1, mat);

final double[] result = test.getExpectedResult();
for (int k = 0; k < mat.length; ++k) {
assertEquals(mat[k], result[k], 1e-10);
System.out.println(k + " : " + (mat[k] - result[k]));
}
}
}
}
93 changes: 93 additions & 0 deletions src/test/substmodels/nucleotide/TIM2ufTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package test.substmodels.nucleotide;

import beast.core.Description;
import beast.core.parameter.RealParameter;
import beast.evolution.substitutionmodel.Frequencies;
import junit.framework.TestCase;
import substmodels.nucleotide.TIM2uf;

@Description("Test TIM2uf matrix exponentiation")
public class TIM2ufTest extends TestCase {

/*
* import numpy as np
* from scipy.linalg import expm
*
* piQ = np.diag([0.4, 0.3, 0.2, 0.1])
* d = 0.3
* # Q matrix with zeroed diagonal
* XQ = np.matrix([[0, 1, 2, 1], [1, 0, 3, 4], [2, 3, 0, 3], [1, 4, 3, 0]])
*
* xx = XQ * piQ
*
* # fill diagonal and normalize by total substitution rate
* q0 = (xx + np.diag(np.squeeze(np.asarray(-np.sum(xx, axis=1))))) / np.sum(piQ * np.sum(xx, axis=1))
* expm(q0 * d)
*/
protected UnequalBaseFrequencies test0 = new UnequalBaseFrequencies() {
@Override
public Double[] getPi() {
return new Double[]{0.4, 0.3, 0.2, 0.1};
}

@Override
public Double [] getRates() {
return new Double[] {1.0, 2.0, 3.0, 4.0};
}

@Override
public double getDistance() {
return 0.3;
}

@Override
public double[] getExpectedResult() {
return new double[]{
0.848004762603, 0.061442861874, 0.070071421565, 0.020480953958,
0.081923815832, 0.755054682896, 0.099180935215, 0.063840566057,
0.140142843131, 0.148771402822, 0.661495286439, 0.049590467607,
0.081923815832, 0.19152169817 , 0.099180935215, 0.627373550783
};
}
};


UnequalBaseFrequencies[] all = {test0};

public void testTIM2uf() throws Exception {
for (UnequalBaseFrequencies test : all) {

RealParameter f = new RealParameter(test.getPi());
Frequencies freqs = new Frequencies();
freqs.initByName("frequencies", f); // "estimate", true

TIM2uf tim2uf = new TIM2uf();
RealParameter rates = new RealParameter(test.getRates());
tim2uf.initByName("rates", rates, "frequencies", freqs);
tim2uf.printQ(System.out); // to obtain XQ for python script
// for (int i = 0; i < 6; ++i)
// System.out.println("Rate " + tim2uf.getSubstitution(i) + " : " + tim2uf.getRate(i));

// AC=AT
assertEquals(true, tim2uf.getRateAC()== tim2uf.getRateAT() &&
tim2uf.getRateAC()!=tim2uf.getRateCG() && tim2uf.getRateAC()!=tim2uf.getRateAG());
// CG=GT
assertEquals(true, tim2uf.getRateCG()== tim2uf.getRateGT() &&
tim2uf.getRateCG()!=tim2uf.getRateCT() && tim2uf.getRateCG()!=tim2uf.getRateAG());
// AG!=CT
assertEquals(true, tim2uf.getRateAC()!=tim2uf.getRateAG() &&
tim2uf.getRateAC()!=tim2uf.getRateCT() && tim2uf.getRateAG()!=tim2uf.getRateCT());


double distance = test.getDistance();
double[] mat = new double[4 * 4];
tim2uf.getTransitionProbabilities(null, distance, 0, 1, mat);

final double[] result = test.getExpectedResult();
for (int k = 0; k < mat.length; ++k) {
assertEquals(mat[k], result[k], 1e-10);
System.out.println(k + " : " + (mat[k] - result[k]));
}
}
}
}
Loading

0 comments on commit 4aef94f

Please sign in to comment.