From 3575817c321afb7b085ef0efc5a4143054fc7d16 Mon Sep 17 00:00:00 2001 From: Denise Date: Mon, 2 May 2016 16:59:39 +0200 Subject: [PATCH] Some refactoring to remove duplicates, and getting package ready for BEAUTI. --- build.xml | 211 +++++++++ examples/BDMM_migration_example.xml | 8 +- .../BDMUC_example_SequenceSimAnaLyzer.xml | 4 +- .../speciation/BirthDeathMigrationModel.java | 296 +----------- .../BirthDeathMigrationModelUncoloured.java | 406 +--------------- ...ewiseBirthDeathMigrationDistribution.java} | 346 ++++++++++++-- .../speciation/BirthDeathMigrationTest.java | 4 +- .../BirthDeathMigrationUncolouredTest.java | 107 +++-- templates/BDMM.xml | 442 ++++++++++++++++++ version.xml | 6 +- 10 files changed, 1096 insertions(+), 734 deletions(-) create mode 100644 build.xml rename src/beast/evolution/speciation/{PiecewiseBirthDeathSamplingDistribution.java => PiecewiseBirthDeathMigrationDistribution.java} (60%) create mode 100644 templates/BDMM.xml diff --git a/build.xml b/build.xml new file mode 100644 index 0000000..5993402 --- /dev/null +++ b/build.xml @@ -0,0 +1,211 @@ + + + + Build BDMM. + Also used by Hudson BDMM project. + JUnit test is available for this build. + $Id: build_BDMM.xml $ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/examples/BDMM_migration_example.xml b/examples/BDMM_migration_example.xml index 4fe259f..e4c6c65 100644 --- a/examples/BDMM_migration_example.xml +++ b/examples/BDMM_migration_example.xml @@ -215,7 +215,7 @@ - + 1.0 @@ -223,7 +223,7 @@ - + @@ -286,11 +286,11 @@ - + diff --git a/examples/BDMUC_example_SequenceSimAnaLyzer.xml b/examples/BDMUC_example_SequenceSimAnaLyzer.xml index 0571887..f16c187 100644 --- a/examples/BDMUC_example_SequenceSimAnaLyzer.xml +++ b/examples/BDMUC_example_SequenceSimAnaLyzer.xml @@ -305,7 +305,7 @@ - + diff --git a/src/beast/evolution/speciation/BirthDeathMigrationModel.java b/src/beast/evolution/speciation/BirthDeathMigrationModel.java index a38765e..de21ba3 100755 --- a/src/beast/evolution/speciation/BirthDeathMigrationModel.java +++ b/src/beast/evolution/speciation/BirthDeathMigrationModel.java @@ -3,15 +3,8 @@ import beast.core.Description; import beast.core.util.Utils; import beast.evolution.tree.*; -import beast.core.parameter.RealParameter; import beast.core.Input; -import math.p0_ODE; -import math.p0ge_ODE; - -import org.apache.commons.math3.ode.FirstOrderIntegrator; -import org.apache.commons.math3.ode.nonstiff.*; - /** * @author Denise Kuehnert @@ -21,44 +14,13 @@ @Description("This model implements a multi-deme version of the BirthDeathSkylineModel with discrete locations and migration events among demes. " + "This should only be used when the migration process along the phylogeny is important. Otherwise the computationally less intense BirthDeathMigrationModelUncoloured can be employed.") -public class BirthDeathMigrationModel extends PiecewiseBirthDeathSamplingDistribution { - - public Input frequencies = - new Input<>("frequencies", "state frequencies", Input.Validate.REQUIRED); - - public Input origin = - new Input<>("origin", "The origin of infection x1"); +public class BirthDeathMigrationModel extends PiecewiseBirthDeathMigrationDistribution { public Input originBranchInput = new Input<>("originBranch", "MultiTypeRootBranch for origin coloring"); - public Input originIsRootEdge = - new Input<>("originIsRootEdge", "The origin is only the length of the root edge", false); - - public Input maxEvaluations = - new Input<>("maxEvaluations", "The maximum number of evaluations for ODE solver", 20000); - - public Input conditionOnSurvival = - new Input<>("conditionOnSurvival", "condition on at least one survival? Default true.", true); - - public Input tolerance = - new Input<>("tolerance", "tolerance for numerical integration", 1e-14); - MultiTypeTree coltree; - - Double[] freq; - double T; - double orig; MultiTypeRootBranch originBranch; - int ntaxa; - - p0_ODE P; - p0ge_ODE PG; - - FirstOrderIntegrator pg_integrator; - public int maxEvalsUsed; - public Double minstep; - public Double maxstep; Boolean print = false; @@ -70,69 +32,32 @@ public void initAndValidate() { coltree = (MultiTypeTree) treeInput.get(); if (origin.get()==null){ + T = coltree.getRoot().getHeight(); } - - else{ + else { originBranch = originBranchInput.get(); if (originBranch==null) throw new RuntimeException("Error: Origin specified but originBranch missing!"); - updateOrigin(coltree.getRoot()); - - if (!Boolean.valueOf(System.getProperty("beast.resume")) && treeInput.get().getRoot().getHeight() >= origin.get().getValue()) - throw new RuntimeException("Error: origin("+T+") must be larger than tree height("+coltree.getRoot().getHeight()+")!"); + checkOrigin(coltree); } ntaxa = coltree.getLeafNodeCount(); - if (birthRate.get() != null && deathRate.get() != null && samplingRate.get() != null){ - - transform = false; - death = deathRate.get().getValues(); - psi = samplingRate.get().getValues(); - birth = birthRate.get().getValues(); - } - else if (R0.get() != null && becomeUninfectiousRate.get() != null && samplingProportion.get() != null){ - - transform = true; - } - - else{ - throw new RuntimeException("Either specify birthRate, deathRate and samplingRate OR specify R0, becomeUninfectiousRate and samplingProportion!"); - } - - freq = frequencies.get().getValues(); + int contempCount = 0; + for (Node node : coltree.getExternalNodes()) + if (node.getHeight()==0.) + contempCount++; + if (checkRho.get() && contempCount>1 && rho==null) + throw new RuntimeException("Error: multiple tips given at present, but sampling probability \'rho\' is not specified."); collectTimes(T); setRho(); - } - - void setupIntegrators(){ // set up ODE's and integrators - - if (minstep == null) minstep = tolerance.get(); - if (maxstep == null) maxstep = 1000.; - - P = new p0_ODE(birth,null, death,psi,M, n, totalIntervals, times); - PG = new p0ge_ODE(birth, null, death,psi,M, n, totalIntervals, T, times, P, maxEvaluations.get(), true); - - if (!useRKInput.get()) { - pg_integrator = new DormandPrince853Integrator(minstep, maxstep, tolerance.get(), tolerance.get()); // - pg_integrator.setMaxEvaluations(maxEvaluations.get()); - - PG.p_integrator = new DormandPrince853Integrator(minstep, maxstep, tolerance.get(), tolerance.get()); // - PG.p_integrator.setMaxEvaluations(maxEvaluations.get()); - } else { - pg_integrator = new ClassicalRungeKuttaIntegrator(T/1000); - PG.p_integrator = new ClassicalRungeKuttaIntegrator(T/1000); - } - } - - double updateRates(){ birth = new Double[n*totalIntervals]; @@ -141,66 +66,17 @@ void setupIntegrators(){ // set up ODE's and integrators M = new Double[totalIntervals*(n*(n-1))]; if (SAModel) r = new Double[n * totalIntervals]; - if (transform) { + if (transform) transformParameters(); - } - else { - - Double[] birthRates = birthRate.get().getValues(); - Double[] deathRates = deathRate.get().getValues(); - Double[] samplingRates = samplingRate.get().getValues(); - Double[] removalProbabilities = new Double[1]; - - if (SAModel) { - removalProbabilities = removalProbability.get().getValues(); - r = new Double[n*totalIntervals]; - } - - int state; - - for (int i = 0; i < n*totalIntervals; i++) { - state = i/totalIntervals; - - birth[i] = birthRates[birthRates.length > n ? (birthChanges+1)*state+index(times[i%totalIntervals], birthRateChangeTimes) : state]; - death[i] = deathRates[deathRates.length > n ? (deathChanges+1)*state+index(times[i%totalIntervals], deathRateChangeTimes) : state]; - psi[i] = samplingRates[samplingRates.length > n ? (samplingChanges+1)*state+index(times[i%totalIntervals], samplingRateChangeTimes) : state]; - if (SAModel) r[i] = removalProbabilities[removalProbabilities.length > n ? (rChanges+1)*state+index(times[i%totalIntervals], rChangeTimes) : state]; - - } - } + else + updateBirthDeathPsiParams(); Double[] migRates = migrationMatrix.get().getValues(); - for (int i = 0; i < n; i++) { - for (int j = 0; j < n; j++) { - for (int dt = 0; dt < totalIntervals; dt++) { - if (i != j) { - M[(i * (n - 1) + (j < i ? j : j - 1)) * totalIntervals + dt] - = migRates[(migRates.length > (n * (n - 1))) - ? (migChanges + 1) * (n - 1) * i + index(times[dt], migChangeTimes) - : (i * (n - 1) + (j < i ? j : j - 1))]; - } - } - } - } + updateAmongParameter(M, migRates, migChanges, migChangeTimes); - - if (m_rho.get() != null && (m_rho.get().getDimension()==1 || rhoSamplingTimes.get() != null)) { - - Double[] rhos = m_rho.get().getValues(); - rho = new Double[n*totalIntervals]; - int state; - - for (int i = 0; i < totalIntervals*n; i++) { - - state = i/totalIntervals; - - rho[i]= rhoChanges>0? - rhoSamplingChangeTimes.contains(times[i]) ? rhos[rhos.length > n ? (rhoChanges+1)*state+index(times[i%totalIntervals], rhoSamplingChangeTimes) : state] : 0. - : rhos[0]; - } - } + updateRho(); freq = frequencies.get().getValues(); @@ -221,7 +97,6 @@ void computeRhoTips(){ for (Double time:rhoSamplingChangeTimes){ if (Math.abs(time-tipTime) < 1e-10 && rho[((MultiTypeNode)tip).getNodeType()*totalIntervals + Utils.index(time, times, totalIntervals)]>0) isRhoTip[tip.getNr()] = true; - } } } @@ -232,84 +107,11 @@ public double[] getG(double t, double[] PG0, double t0, Node node){ // PG0 conta System.arraycopy(PG.getP(t0, m_rho.get()!=null, rho), 0, PG0, 0, n); } -// else if (rhoSamplingChangeTimes.contains(t)){ -// -// int nodestate = ((MultiTypeNode)node).getNodeType(); -// PG0[nodestate] *= (1-rho[nodestate*totalIntervals+ Utils.index(t,times,totalIntervals)]); -// } - - return getG(t, PG0, t0); - } - - - public double[] getG(double t, double[] PG0, double t0){ // PG0 contains initial condition for p0 (0..n-1) and for ge (n..2n-1) - - try { - - if (Math.abs(T-t)<1e-10 || Math.abs(t0-t)<1e-10 || T < t) { - return PG0; - } - - double from = t; - double to = t0; - double oneMinusRho; - - int indexFrom = Utils.index(from, times, times.length); - int index = Utils.index(to, times, times.length); - - int steps = index - indexFrom; - if (Math.abs(from-times[indexFrom])<1e-10) steps--; - if (index>0 && Math.abs(to-times[index-1])<1e-10) { - steps--; - index--; - } - index--; - - while (steps > 0){ - - from = times[index];// + 1e-14; - - pg_integrator.integrate(PG, to, PG0, from, PG0); // solve PG , store solution in PG0 - - if (rhoChanges>0){ - for (int i=0; i maxEvalsUsed) maxEvalsUsed = pg_integrator.getEvaluations(); - return PG0; + return getG(t, PG0, t0, pg_integrator, PG, T, maxEvalsUsed); } - void updateOrigin(Node root){ - - T = origin.get().getValue(); - orig = T - root.getHeight(); - - if (originIsRootEdge.get()) { - - orig = origin.get().getValue(); - T = orig + coltree.getRoot().getHeight(); - } - - } @Override public double calculateTreeLogLikelihood(TreeInterface tree) { @@ -319,8 +121,6 @@ public double calculateTreeLogLikelihood(TreeInterface tree) { coltree = (MultiTypeTree) tree; MultiTypeNode root = (MultiTypeNode) coltree.getRoot(); -// int node_state = ((MultiTypeNode) coltree.getRoot()).getNodeType(); //.getNodeColour(root); - if (!coltree.isValid() || (origin.get()!=null && !originBranchIsValid(root))){ logP = Double.NEGATIVE_INFINITY; @@ -330,13 +130,12 @@ public double calculateTreeLogLikelihood(TreeInterface tree) { int node_state; if (origin.get()==null) { T = root.getHeight(); - node_state = ((MultiTypeNode) coltree.getRoot()).getNodeType(); //.getNodeColour(root); + node_state = ((MultiTypeNode) coltree.getRoot()).getNodeType(); } else{ updateOrigin(root); node_state = (originBranch.getChangeCount()>0) ? originBranch.getChangeType(originBranch.getChangeCount()-1) : ((MultiTypeNode) coltree.getRoot()).getNodeType(); - //originBranch.getFinalType(); if (orig < 0){ logP = Double.NEGATIVE_INFINITY; @@ -438,10 +237,9 @@ public double calculateTreeLogLikelihood(TreeInterface tree) { g = calculateOriginLikelihood(migIndex, to, T - originBranch.getChangeTime(migIndex)); System.arraycopy(g, 0, init, 0, n); -// init[n+prevcol] = M[prevcol*(n-1)+(col 0)? ((MultiTypeNode) node).getChangeType(migIndex-1): ((MultiTypeNode) node).getNodeType(); // (migIndex > 0)? coltree.getChangeColour(node, migIndex-1): coltree.getNodeColour(node); + int prevcol = ((MultiTypeNode) node).getChangeType(migIndex); + int col = (migIndex > 0)? ((MultiTypeNode) node).getChangeType(migIndex-1): ((MultiTypeNode) node).getNodeType(); double time ; migIndex--; - time = (migIndex >= 0)? ((MultiTypeNode) node).getChangeTime(migIndex) :node.getHeight();// (migIndex >= 0)?coltree.getChangeTime(node, migIndex):node.getHeight(); + time = (migIndex >= 0)? ((MultiTypeNode) node).getChangeTime(migIndex) :node.getHeight(); double[] g = calculateSubtreeLikelihood(node, (migIndex >= 0), migIndex, to, T-time); System.arraycopy(g, 0, init, 0, n); -// init[n+prevcol] = M[prevcol*(n-1)+(col 0) t1 = T - ((MultiTypeNode)node.getChild(childIndex)).getChangeTime(childChangeCount-1); @@ -538,52 +331,11 @@ else if (node.getChildCount()==2){ // birth / infection event return getG(from, init, to, node); } + public void transformParameters() { - - - public void transformParameters(){ - - Double[] p = samplingProportion.get().getValues(); - Double[] ds = becomeUninfectiousRate.get().getValues(); - Double[] R = R0.get().getValues(); - Double[] removalProbabilities = new Double[1]; - if (SAModel) removalProbabilities = removalProbability.get().getValues(); - - int state; - - for (int i = 0; i < totalIntervals*n; i++){ - - state = i/totalIntervals; - - birth[i] = R[R.length > n ? (birthChanges+1)*state+index(times[i%totalIntervals], birthRateChangeTimes) : state] - * ds[ds.length > n ? (deathChanges+1)*state+index(times[i%totalIntervals], deathRateChangeTimes) : state]; - -// psi[i] = p[p.length > n ? (samplingChanges+1)*state+index(times[i%totalIntervals], samplingRateChangeTimes) : state] -// * ds[ds.length > n ? (deathChanges+1)*state+index(times[i%totalIntervals], deathRateChangeTimes) : state] ; -// -// death[i] = ds[ds.length > n ? (deathChanges+1)*state+index(times[i%totalIntervals], deathRateChangeTimes) : state] - psi[i]; -// - if (!SAModel) { - psi[i] = p[p.length > n ? (samplingChanges + 1) * state + index(times[i % totalIntervals], samplingRateChangeTimes) : state] - * ds[ds.length > n ? (deathChanges + 1) * state + index(times[i % totalIntervals], deathRateChangeTimes) : state]; - - death[i] = ds[ds.length > n ? (deathChanges + 1) * state + index(times[i % totalIntervals], deathRateChangeTimes) : state] - psi[i]; - } - else { - r[i] = removalProbabilities[removalProbabilities.length > n ? (rChanges+1)*state+index(times[i%totalIntervals], rChangeTimes) : state]; - - psi[i] = p[p.length > n ? (samplingChanges+1)*state+index(times[i%totalIntervals], samplingRateChangeTimes) : state] - * ds[ds.length > n ? (deathChanges+1)*state+index(times[i%totalIntervals], deathRateChangeTimes) : state] - / (1+(r[i]-1)*p[p.length > n ? (samplingChanges+1)*state+index(times[i%totalIntervals], samplingRateChangeTimes) : state]); - - - death[i] = ds[ds.length > n ? (deathChanges+1)*state+index(times[i%totalIntervals], deathRateChangeTimes) : state] - psi[i]*r[i]; - } - } - + transformWithinParameters(); } - public Boolean originBranchIsValid(MultiTypeNode root){ int count = originBranch.getChangeCount(); diff --git a/src/beast/evolution/speciation/BirthDeathMigrationModelUncoloured.java b/src/beast/evolution/speciation/BirthDeathMigrationModelUncoloured.java index 32eb52e..b65ef86 100755 --- a/src/beast/evolution/speciation/BirthDeathMigrationModelUncoloured.java +++ b/src/beast/evolution/speciation/BirthDeathMigrationModelUncoloured.java @@ -1,17 +1,10 @@ package beast.evolution.speciation; import beast.evolution.tree.*; -import beast.core.parameter.RealParameter; import beast.core.Input; import beast.core.Description; import beast.core.util.Utils; -import math.p0_ODE; -import math.p0ge_ODE; - -import org.apache.commons.math3.ode.FirstOrderIntegrator; -import org.apache.commons.math3.ode.nonstiff.*; - /** * @author Denise Kuehnert @@ -23,45 +16,13 @@ @Description("This model implements a multi-deme version of the BirthDeathSkylineModel with discrete locations and migration events among demes. " + "This should be used when the migration process along the phylogeny is irrelevant. Otherwise the BirthDeathMigrationModel can be employed." + "This implementation also works with sampled ancestor trees.") -public class BirthDeathMigrationModelUncoloured extends PiecewiseBirthDeathSamplingDistribution { - - - public Input frequencies = - new Input<>("frequencies", "state frequencies", Input.Validate.REQUIRED); - - public Input origin = - new Input<>("origin", "The origin of infection x1"); - - public Input originIsRootEdge = - new Input<>("originIsRootEdge", "The origin is only the length of the root edge", false); +public class BirthDeathMigrationModelUncoloured extends PiecewiseBirthDeathMigrationDistribution { - public Input maxEvaluations = - new Input<>("maxEvaluations", "The maximum number of evaluations for ODE solver", 20000); - - public Input conditionOnSurvival = - new Input<>("conditionOnSurvival", "condition on at least one survival? Default true.", true); - - public Input tolerance = - new Input<>("tolerance", "tolerance for numerical integration", 1e-14); public Input tiptypes = new Input<>("tiptypes", "trait information for initializing traits (like node types/locations) in the tree", Input.Validate.REQUIRED); public Input typeLabel = new Input<>("typeLabel", "type label in tree for initializing traits (like node types/locations) in the tree", Input.Validate.XOR, tiptypes); public Input storeNodeTypes = new Input<>("storeNodeTypes", "store tip node types? this assumes that tip types cannot change (default false)", false); - public Input checkRho = new Input<>("checkRho", "check if rho is set if multiple tips are given at present (default true)", true); - - Double[] freq; - double T; - double orig; - int ntaxa; - - p0_ODE P; - p0ge_ODE PG; - - FirstOrderIntegrator pg_integrator; - public int maxEvalsUsed; - public Double minstep; - public Double maxstep; private int[] nodeStates; @@ -74,46 +35,19 @@ public void initAndValidate() { TreeInterface tree = treeInput.get(); - if (origin.get()==null){ - T = tree.getRoot().getHeight(); - } - - else{ - - T = origin.get().getValue(); - orig = T - tree.getRoot().getHeight(); - - - if (originIsRootEdge.get()) { - - orig = origin.get().getValue(); - T = orig + tree.getRoot().getHeight(); - } - - if (!Boolean.valueOf(System.getProperty("beast.resume")) && orig < 0) - throw new RuntimeException("Error: origin("+T+") must be larger than tree height("+tree.getRoot().getHeight()+")!"); - } + checkOrigin(tree); ntaxa = tree.getLeafNodeCount(); birthAmongDemes = (birthRateAmongDemes.get() !=null || R0AmongDemes.get()!=null); - if (birthRate.get() != null && deathRate.get() != null && samplingRate.get() != null){ - - if (birthAmongDemes) b_ij = birthRateAmongDemes.get().getValues(); - - transform = false; - death = deathRate.get().getValues(); - psi = samplingRate.get().getValues(); - birth = birthRate.get().getValues(); - } - else if (R0.get() != null && becomeUninfectiousRate.get() != null && samplingProportion.get() != null){ + if (storeNodeTypes.get()) { - transform = true; - } + nodeStates = new int[ntaxa]; - else{ - throw new RuntimeException("Either specify birthRate, deathRate and samplingRate OR specify R0, becomeUninfectiousRate and samplingProportion!"); + for (Node node : tree.getExternalNodes()){ + nodeStates[node.getNr()] = getNodeState(node, true); + } } int contempCount = 0; @@ -123,80 +57,13 @@ else if (R0.get() != null && becomeUninfectiousRate.get() != null && samplingPro if (checkRho.get() && contempCount>1 && rho==null) throw new RuntimeException("Error: multiple tips given at present, but sampling probability \'rho\' is not specified."); - freq = frequencies.get().getValues(); - - double freqSum = 0; - for (double f : freq) freqSum+= f; - if (freqSum!=1.) throw new RuntimeException("Error: frequencies must add up to 1 but currently add to " + freqSum + "."); - -// // calculate equilibrium frequencies for 2 types: -// double LambMu = -b[0]-b[1]-(d[0]+s[0])+(d[1]+s[1]); -// double c = Math.sqrt(Math.pow(LambMu,2) +4*b_ij[0]*b_ij[1]); -// freq[0] = (c+LambMu)/(c+LambMu+2*b_ij[0]) ; -// freq[1] = 1 - freq[0]; - - collectTimes(T); setRho(); - - maxEvalsUsed = 0; - - if (storeNodeTypes.get()) { - - nodeStates = new int[ntaxa]; - - for (Node node : tree.getExternalNodes()){ - nodeStates[node.getNr()] = getNodeState(node, true); - } - } - } - void setupIntegrators(){ // set up ODE's and integrators - - if (minstep == null) minstep = tolerance.get(); - if (maxstep == null) maxstep = 1000.; - - P = new p0_ODE(birth, (birthAmongDemes ? b_ij : null), death,psi,M, n, totalIntervals, times); - PG = new p0ge_ODE(birth, (birthAmongDemes ? b_ij : null), death,psi,M, n, totalIntervals, T, times, P, maxEvaluations.get(), false); - - - if (!useRKInput.get()) { - pg_integrator = new DormandPrince853Integrator(minstep, maxstep, tolerance.get(), tolerance.get()); // - pg_integrator.setMaxEvaluations(maxEvaluations.get()); - - PG.p_integrator = new DormandPrince853Integrator(minstep, maxstep, tolerance.get(), tolerance.get()); // - PG.p_integrator.setMaxEvaluations(maxEvaluations.get()); - } else { - pg_integrator = new ClassicalRungeKuttaIntegrator(T / 1000); - PG.p_integrator = new ClassicalRungeKuttaIntegrator(T / 1000); - - } - } - protected Double updateRates(TreeInterface tree) { - if (origin.get()==null){ - T = tree.getRoot().getHeight(); - } - - else{ - - T = origin.get().getValue(); - orig = T - tree.getRoot().getHeight(); - - - if (originIsRootEdge.get()) { - - orig = origin.get().getValue(); - T = orig + tree.getRoot().getHeight(); - } - - if (!Boolean.valueOf(System.getProperty("beast.resume")) && orig < 0) - throw new RuntimeException("Error: origin("+T+") must be larger than tree height("+tree.getRoot().getHeight()+")!"); - } - birth = new Double[n*totalIntervals]; death = new Double[n*totalIntervals]; psi = new Double[n*totalIntervals]; @@ -209,80 +76,23 @@ protected Double updateRates(TreeInterface tree) { } else { - Double[] birthRates = birthRate.get().getValues(); - Double[] deathRates = deathRate.get().getValues(); - Double[] samplingRates = samplingRate.get().getValues(); Double[] birthAmongDemesRates = new Double[1]; - if (birthAmongDemes) birthAmongDemesRates = birthRateAmongDemes.get().getValues(); - Double[] removalProbabilities = new Double[1]; - - if (SAModel) { - removalProbabilities = removalProbability.get().getValues(); - r = new Double[n*totalIntervals]; - } - - int state; - - for (int i = 0; i < n*totalIntervals; i++) { - state = i/totalIntervals; - - birth[i] = birthRates[birthRates.length > n ? (birthChanges+1)*state+index(times[i%totalIntervals], birthRateChangeTimes) : state]; - death[i] = deathRates[deathRates.length > n ? (deathChanges+1)*state+index(times[i%totalIntervals], deathRateChangeTimes) : state]; - psi[i] = samplingRates[samplingRates.length > n ? (samplingChanges+1)*state+index(times[i%totalIntervals], samplingRateChangeTimes) : state]; - if (SAModel) r[i] = removalProbabilities[removalProbabilities.length > n ? (rChanges+1)*state+index(times[i%totalIntervals], rChangeTimes) : state]; + if (birthAmongDemes) birthAmongDemesRates = birthRateAmongDemes.get().getValues(); - } + updateBirthDeathPsiParams(); if (birthAmongDemes) { - for (int i = 0; i < n; i++) { - for (int j = 0; j < n; j++) { - for (int dt = 0; dt < totalIntervals; dt++) { - if (i != j) { - b_ij[(i * (n - 1) + (j < i ? j : j - 1)) * totalIntervals + dt] - = birthAmongDemesRates[(birthAmongDemesRates.length > (n * (n - 1))) - ? (b_ij_Changes + 1) * (n - 1) * i + index(times[dt], b_ijChangeTimes) - : (i * (n - 1) + (j < i ? j : j - 1))]; - } - } - } - } - } + updateAmongParameter(b_ij, birthAmongDemesRates, b_ij_Changes, b_ijChangeTimes); + } } Double[] migRates = migrationMatrix.get().getValues(); - for (int i = 0; i < n; i++) { - for (int j = 0; j < n; j++) { - for (int dt = 0; dt < totalIntervals; dt++) { - if (i != j) { - M[(i * (n - 1) + (j < i ? j : j - 1)) * totalIntervals + dt] - = migRates[(migRates.length > (n * (n - 1))) - ? (migChanges + 1) * (n - 1) * i + index(times[dt], migChangeTimes) - : (i * (n - 1) + (j < i ? j : j - 1))]; - } - } - } - } - //todo: remove duplicate (make it a new method) - - - if (m_rho.get() != null && (m_rho.get().getDimension()==1 || rhoSamplingTimes.get() != null)) { - - Double[] rhos = m_rho.get().getValues(); - rho = new Double[n*totalIntervals]; - int state; + updateAmongParameter(M, migRates, migChanges, migChangeTimes); - for (int i = 0; i < totalIntervals*n; i++) { - - state = i/totalIntervals; - - rho[i]= rhoChanges>0? - rhoSamplingChangeTimes.contains(times[i]) ? rhos[rhos.length > n ? (rhoChanges+1)*state+index(times[i%totalIntervals], rhoSamplingChangeTimes) : state] : 0. - : rhos[0]; - } - } + updateRho(); freq = frequencies.get().getValues(); @@ -310,62 +120,13 @@ void computeRhoTips(){ public double[] getG(double t, double[] PG0, double t0, Node node){ // PG0 contains initial condition for p0 (0..n-1) and for ge (n..2n-1) - try { - - if (node.isLeaf()) { - - System.arraycopy(PG.getP(t0, m_rho.get()!=null, rho), 0, PG0, 0, n); - } - - if (Math.abs(T-t)<1e-10 || Math.abs(t0-t)<1e-10 || T < t) { - return PG0; - } - - double from = t; - double to = t0; - double oneMinusRho; - - int indexFrom = Utils.index(from, times, times.length); - int index = Utils.index(to, times, times.length); - - int steps = index - indexFrom; - if (Math.abs(from-times[indexFrom])<1e-10) steps--; - if (index>0 && Math.abs(to-times[index-1])<1e-10) { - steps--; - index--; - } - index--; - - while (steps > 0){ - - from = times[index];// + 1e-14; - - pg_integrator.integrate(PG, to, PG0, from, PG0); // solve PG , store solution in PG0 - - if (rhoChanges>0){ - for (int i=0; i maxEvalsUsed) maxEvalsUsed = pg_integrator.getEvaluations(); + return getG(t, PG0, t0, pg_integrator, PG, T, maxEvalsUsed); - return PG0; } @@ -374,24 +135,7 @@ public double calculateTreeLogLikelihood(TreeInterface tree) { Node root = tree.getRoot(); - if (origin.get()==null){ - T =tree.getRoot().getHeight(); - } - else{ - - T = origin.get().getValue(); - orig = T - root.getHeight(); - - - if (originIsRootEdge.get()) { - - orig = origin.get().getValue(); - T = orig + tree.getRoot().getHeight(); - } - - if (orig < 0) - return Double.NEGATIVE_INFINITY; - } + checkOrigin(tree); collectTimes(T); setRho(); @@ -422,10 +166,12 @@ public double calculateTreeLogLikelihood(TreeInterface tree) { } double[] p; + if ( orig > 0 ) { p = calculateSubtreeLikelihood(root,0,orig); } else { + int childIndex = 0; if (root.getChild(1).getNr() > root.getChild(0).getNr()) childIndex = 1; // always start with the same child to avoid numerical differences @@ -532,7 +278,6 @@ private int getNodeState(Node node, Boolean init){ else { if (!isRhoTip[node.getNr()]) -// init[n+nodestate] = psi[nodestate*totalIntervals+index]; init[n + nodestate] = SAModel ? psi[nodestate * totalIntervals + index]* (r[nodestate * totalIntervals + index] + (1-r[nodestate * totalIntervals + index])*PG.getP(to, m_rho.get()!=null, rho)[nodestate]) // with SA: ψ_i(r + (1 − r)p_i(τ)) : psi[nodestate * totalIntervals + index]; @@ -596,26 +341,16 @@ else if (node.getChildCount()==2){ // birth / infection event or sampled ancest for (int j = 0; j < n; j++) { if (childstate != j) { -// if (b_ij.length>(n*(n-1))){ // b_ij can change over time -// throw new RuntimeException("ratechanges in b_ij not implemented!"); init[n + childstate] += 0.5 * b_ij[totalIntervals * (childstate * (n - 1) + (j < childstate ? j : j - 1)) + index] * (g0[n + childstate] * g1[n + j] + g0[n + j] * g1[n + childstate]); -// } else { // b_ij cannot change over time -// init[n+childstate] += 0.5 * b_ij[childstate*(n-1)+(j(n*(n-1))) ? (b_ij_Changes+1)*n*(n-1)*state + index(times[i%totalIntervals], b_ijChangeTimes) : (state*(n-1)+(j n ? (deathChanges+1)*state+index(times[i%totalIntervals], deathRateChangeTimes) : state] ; -// } -// } - - birth[i] = R[R.length > n ? (birthChanges+1)*state+index(times[i%totalIntervals], birthRateChangeTimes) : state] - * ds[ds.length > n ? (deathChanges+1)*state+index(times[i%totalIntervals], deathRateChangeTimes) : state] ; - - if (!SAModel) { - psi[i] = p[p.length > n ? (samplingChanges + 1) * state + index(times[i % totalIntervals], samplingRateChangeTimes) : state] - * ds[ds.length > n ? (deathChanges + 1) * state + index(times[i % totalIntervals], deathRateChangeTimes) : state]; - - death[i] = ds[ds.length > n ? (deathChanges + 1) * state + index(times[i % totalIntervals], deathRateChangeTimes) : state] - psi[i]; - } - - else { - r[i] = removalProbabilities[removalProbabilities.length > n ? (rChanges+1)*state+index(times[i%totalIntervals], rChangeTimes) : state]; - - psi[i] = p[p.length > n ? (samplingChanges+1)*state+index(times[i%totalIntervals], samplingRateChangeTimes) : state] - * ds[ds.length > n ? (deathChanges+1)*state+index(times[i%totalIntervals], deathRateChangeTimes) : state] - / (1+(r[i]-1)*p[p.length > n ? (samplingChanges+1)*state+index(times[i%totalIntervals], samplingRateChangeTimes) : state]); - - - death[i] = ds[ds.length > n ? (deathChanges+1)*state+index(times[i%totalIntervals], deathRateChangeTimes) : state] - psi[i]*r[i]; - } - } - - if (birthAmongDemes) { - - for (int i = 0; i < n; i++){ - - for (int j=0; j(n*(n-1))) - ? (b_ij_Changes+1)*(n-1)*i + index(times[dt], b_ijChangeTimes) - : (i*(n-1)+(j n ? (deathChanges+1)*i+index(times[dt], deathRateChangeTimes) : i]; - } - } - } - - } - } + transformWithinParameters(); + transformAmongParameters(); } + // used to indicate that the state assignment went wrong protected class ConstraintViolatedException extends RuntimeException { private static final long serialVersionUID = 1L; diff --git a/src/beast/evolution/speciation/PiecewiseBirthDeathSamplingDistribution.java b/src/beast/evolution/speciation/PiecewiseBirthDeathMigrationDistribution.java similarity index 60% rename from src/beast/evolution/speciation/PiecewiseBirthDeathSamplingDistribution.java rename to src/beast/evolution/speciation/PiecewiseBirthDeathMigrationDistribution.java index 4fc2cce..c9d3b3a 100644 --- a/src/beast/evolution/speciation/PiecewiseBirthDeathSamplingDistribution.java +++ b/src/beast/evolution/speciation/PiecewiseBirthDeathMigrationDistribution.java @@ -5,7 +5,15 @@ import beast.core.State; import beast.core.parameter.BooleanParameter; import beast.core.parameter.RealParameter; +import beast.core.util.Utils; +import beast.evolution.tree.Node; import beast.evolution.tree.Tree; +import beast.evolution.tree.TreeInterface; +import math.p0_ODE; +import math.p0ge_ODE; +import org.apache.commons.math3.ode.FirstOrderIntegrator; +import org.apache.commons.math3.ode.nonstiff.ClassicalRungeKuttaIntegrator; +import org.apache.commons.math3.ode.nonstiff.DormandPrince853Integrator; import java.util.*; @@ -17,9 +25,27 @@ */ @Description("Piece-wise constant rates are assumed to be ordered by state and time. First k entries of an array give " + "values belonging to type 1, for intervals 1 to k, second k intervals for type 2 etc.") -public abstract class PiecewiseBirthDeathSamplingDistribution extends SpeciesTreeDistribution { +public abstract class PiecewiseBirthDeathMigrationDistribution extends SpeciesTreeDistribution { + public Input frequencies = + new Input<>("frequencies", "The frequencies for each type", Input.Validate.REQUIRED); + + public Input origin = + new Input<>("origin", "The origin of infection x1"); + + public Input originIsRootEdge = + new Input<>("originIsRootEdge", "The origin is only the length of the root edge", false); + + public Input maxEvaluations = + new Input<>("maxEvaluations", "The maximum number of evaluations for ODE solver", 20000); + + public Input conditionOnSurvival = + new Input<>("conditionOnSurvival", "condition on at least one survival? Default true.", true); + + public Input tolerance = + new Input<>("tolerance", "tolerance for numerical integration", 1e-14); + // the interval times for the migration rates public Input migChangeTimesInput = new Input<>("migChangeTimes", "The times t_i specifying when migration rate changes occur", (RealParameter) null); @@ -98,7 +124,7 @@ public abstract class PiecewiseBirthDeathSamplingDistribution extends SpeciesTre new Input<>("migrationMatrix", "Flattened migration matrix, can be asymmetric, diagnonal entries omitted", Input.Validate.REQUIRED); public Input birthRateAmongDemes = - new Input<>("birthRateAmongDemes", "birth rate vector with rate at which transmissions occur among locations"); + new Input<>("birthRateAmongDemes", "birth rate vector with rate at which transmissions occur among locations"); public Input R0AmongDemes = new Input<>("R0AmongDemes", "The basic reproduction number determining transmissions occur among locations"); @@ -108,24 +134,29 @@ public abstract class PiecewiseBirthDeathSamplingDistribution extends SpeciesTre new Input("removalProbability", "The probability of an individual to become noninfectious immediately after the sampling"); - //coupling R0 changes: assume there are 2 types with R0 = [r1,r2], R0AmongDemes = [r12,r21] and coupledR0Changes=[c1a,c1b,c2a,c2b], - // i.e. there are dim(coupledR0Changes)/(#types) = 2 rate changes through time for the R0's - // this translates to R0 = [r1,r1*c1a,r1*c1b,r2,r2*c2a,r2*c2b] - // and R0AmongDemes = [r12,r12*c2a,r12*c2b,r21,r21*c1a,r21*c1b] // note that the scale of change is determined by the "receiving" deme here - public Input coupledR0Changes = - new Input<>("coupledR0Changes", "The scale of change in R0 and R0AmongDemes per interval, when they are assumed to be equal for both"); - - public Input stateNumber = new Input<>("stateNumber", "The number of states or locations", Input.Validate.REQUIRED); public Input adjustTimesInput = new Input<>("adjustTimes", "Origin of MASTER sims which has to be deducted from the change time arrays"); - // + // public Input useRKInput = new Input<>("useRK", "Use fixed step size Runge-Kutta with 1000 steps. Default true", true); + public Input checkRho = new Input<>("checkRho", "check if rho is set if multiple tips are given at present (default true)", true); + + double T; + double orig; + int ntaxa; + + p0_ODE P; + p0ge_ODE PG; + + FirstOrderIntegrator pg_integrator; + public int maxEvalsUsed; + public Double minstep; + public Double maxstep; // these four arrays are totalIntervals in length protected Double[] birth; @@ -188,6 +219,8 @@ public abstract class PiecewiseBirthDeathSamplingDistribution extends SpeciesTre Double[] b_ij; Boolean birthAmongDemes = false; + Double[] freq; + @Override public void initAndValidate() { @@ -281,20 +314,77 @@ public void initAndValidate() { rhoChanges = m_rho.get().getDimension()/n - 1; } - if (coupledR0Changes.get()!=null){ + freq = frequencies.get().getValues(); + + double freqSum = 0; + for (double f : freq) freqSum+= f; + if (freqSum!=1.) throw new RuntimeException("Error: frequencies must add up to 1 but currently add to " + freqSum + "."); + +// // calculate equilibrium frequencies for 2 types: +// double LambMu = -b[0]-b[1]-(d[0]+s[0])+(d[1]+s[1]); +// double c = Math.sqrt(Math.pow(LambMu,2) +4*b_ij[0]*b_ij[1]); +// freq[0] = (c+LambMu)/(c+LambMu+2*b_ij[0]) ; +// freq[1] = 1 - freq[0]; + + } + + public double[] getG(double t, double[] PG0, double t0, + FirstOrderIntegrator pg_integrator, p0ge_ODE PG, Double T, int maxEvalsUsed){ // PG0 contains initial condition for p0 (0..n-1) and for ge (n..2n-1) + + try { + + if (Math.abs(T-t)<1e-10 || Math.abs(t0-t)<1e-10 || T < t) { + return PG0; + } + + double from = t; + double to = t0; + double oneMinusRho; + + int indexFrom = Utils.index(from, times, times.length); + int index = Utils.index(to, times, times.length); + + int steps = index - indexFrom; + if (Math.abs(from-times[indexFrom])<1e-10) steps--; + if (index>0 && Math.abs(to-times[index-1])<1e-10) { + steps--; + index--; + } + index--; + + while (steps > 0){ + + from = times[index];// + 1e-14; + + pg_integrator.integrate(PG, to, PG0, from, PG0); // solve PG , store solution in PG0 + + if (rhoChanges>0){ + for (int i=0; i0 && birthChanges!=n) || (b_ij_Changes>0 && b_ij_Changes!=n*(n-1))) throw new RuntimeException("if coupledR0Changes!=null R0 and R0AmongDemes must be of dimension 1"); + }catch(Exception e){ - birthChanges = coupledR0Changes.get().getDimension()/n; - b_ij_Changes = coupledR0Changes.get().getDimension()/n; + throw new RuntimeException("couldn't calculate g"); } -// collectTimes(T); // these need to be called from implementing initAndValidate -// setRho(); + if (pg_integrator.getEvaluations() > maxEvalsUsed) maxEvalsUsed = pg_integrator.getEvaluations(); + + return PG0; } + void setRho(){ isRhoTip = new Boolean[ treeInput.get().getLeafNodeCount()]; @@ -330,7 +420,7 @@ void setRho(){ Arrays.fill(rho, 0.); for (int i = 0; i < totalIntervals; i++) { for (int j=0;j changeTimes, RealParamet throw new RuntimeException("First time in interval times parameter should always be zero."); } - if (numChanges > 0 && coupledR0Changes.get()==null && intervalTimes.getDimension() != numChanges + 1) { + if (numChanges > 0 && intervalTimes.getDimension() != numChanges + 1) { throw new RuntimeException("The time interval parameter should be numChanges + 1 long (" + (numChanges + 1) + ")."); } @@ -471,6 +561,73 @@ public void getChangeTimes(double maxTime, List changeTimes, RealParamet } } + + + void updateBirthDeathPsiParams(){ + + Double[] birthRates = birthRate.get().getValues(); + Double[] deathRates = deathRate.get().getValues(); + Double[] samplingRates = samplingRate.get().getValues(); + Double[] removalProbabilities = new Double[1]; + + if (SAModel) { + removalProbabilities = removalProbability.get().getValues(); + r = new Double[n*totalIntervals]; + } + + int state; + + for (int i = 0; i < n*totalIntervals; i++) { + + state = i/totalIntervals; + + birth[i] = birthRates[birthRates.length > n ? (birthChanges+1)*state+index(times[i%totalIntervals], birthRateChangeTimes) : state]; + death[i] = deathRates[deathRates.length > n ? (deathChanges+1)*state+index(times[i%totalIntervals], deathRateChangeTimes) : state]; + psi[i] = samplingRates[samplingRates.length > n ? (samplingChanges+1)*state+index(times[i%totalIntervals], samplingRateChangeTimes) : state]; + if (SAModel) r[i] = removalProbabilities[removalProbabilities.length > n ? (rChanges+1)*state+index(times[i%totalIntervals], rChangeTimes) : state]; + + } + + } + + + void updateAmongParameter(Double[] param, Double[] paramFrom, int nrChanges, List changeTimes){ + + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + for (int dt = 0; dt < totalIntervals; dt++) { + if (i != j) { + param[(i * (n - 1) + (j < i ? j : j - 1)) * totalIntervals + dt] + = paramFrom[(paramFrom.length > (n * (n - 1))) + ? (nrChanges + 1) * (n - 1) * i + index(times[dt], changeTimes) + : (i * (n - 1) + (j < i ? j : j - 1))]; + } + } + } + } + + } + + void updateRho(){ + if (m_rho.get() != null && (m_rho.get().getDimension()==1 || rhoSamplingTimes.get() != null)) { + + Double[] rhos = m_rho.get().getValues(); + rho = new Double[n*totalIntervals]; + int state; + + for (int i = 0; i < totalIntervals*n; i++) { + + state = i/totalIntervals; + + rho[i]= rhoChanges>0? + rhoSamplingChangeTimes.contains(times[i]) ? rhos[rhos.length > n ? (rhoChanges+1)*state+index(times[i%totalIntervals], rhoSamplingChangeTimes) : state] : 0. + : rhos[0]; + } + } + } + + + /** * @param t the time in question * @return the index of the given time in the list of times, or if the time is not in the list, the index of the @@ -488,21 +645,140 @@ public int index(double t, List times) { return epoch; } - // Interface requirements: - @Override - public List getArguments() { - return null; - } + public void transformWithinParameters(){ + + Double[] p = samplingProportion.get().getValues(); + Double[] ds = becomeUninfectiousRate.get().getValues(); + Double[] R = R0.get().getValues(); + Double[] removalProbabilities = new Double[1]; + if (SAModel) removalProbabilities = removalProbability.get().getValues(); + + int state; + + for (int i = 0; i < totalIntervals*n; i++){ + + state = i/totalIntervals; + + birth[i] = R[R.length > n ? (birthChanges+1)*state+index(times[i%totalIntervals], birthRateChangeTimes) : state] + * ds[ds.length > n ? (deathChanges+1)*state+index(times[i%totalIntervals], deathRateChangeTimes) : state]; + + if (!SAModel) { + psi[i] = p[p.length > n ? (samplingChanges + 1) * state + index(times[i % totalIntervals], samplingRateChangeTimes) : state] + * ds[ds.length > n ? (deathChanges + 1) * state + index(times[i % totalIntervals], deathRateChangeTimes) : state]; + + death[i] = ds[ds.length > n ? (deathChanges + 1) * state + index(times[i % totalIntervals], deathRateChangeTimes) : state] - psi[i]; + } + + else { + r[i] = removalProbabilities[removalProbabilities.length > n ? (rChanges+1)*state+index(times[i%totalIntervals], rChangeTimes) : state]; + + psi[i] = p[p.length > n ? (samplingChanges+1)*state+index(times[i%totalIntervals], samplingRateChangeTimes) : state] + * ds[ds.length > n ? (deathChanges+1)*state+index(times[i%totalIntervals], deathRateChangeTimes) : state] + / (1+(r[i]-1)*p[p.length > n ? (samplingChanges+1)*state+index(times[i%totalIntervals], samplingRateChangeTimes) : state]); + + + death[i] = ds[ds.length > n ? (deathChanges+1)*state+index(times[i%totalIntervals], deathRateChangeTimes) : state] - psi[i]*r[i]; + } + } + + } + + public void transformAmongParameters(){ + + Double[] RaD = (birthAmongDemes) ? R0AmongDemes.get().getValues() : new Double[1]; + Double[] ds = becomeUninfectiousRate.get().getValues(); + + if (birthAmongDemes) { + + for (int i = 0; i < n; i++){ + + for (int j=0; j(n*(n-1))) + ? (b_ij_Changes+1)*(n-1)*i + index(times[dt], b_ijChangeTimes) + : (i*(n-1)+(j n ? (deathChanges+1)*i+index(times[dt], deathRateChangeTimes) : i]; + } + } + } + + } + } + } + + void checkOrigin(TreeInterface tree){ + + if (origin.get()==null){ + T = tree.getRoot().getHeight(); + } + else { + + updateOrigin(tree.getRoot()); + + if (!Boolean.valueOf(System.getProperty("beast.resume")) && orig < 0) + throw new RuntimeException("Error: origin("+T+") must be larger than tree height("+tree.getRoot().getHeight()+")!"); + } + + } + + void updateOrigin(Node root){ + + T = origin.get().getValue(); + orig = T - root.getHeight(); + + if (originIsRootEdge.get()) { + + orig = origin.get().getValue(); + T = orig + root.getHeight(); + } + + } + + void setupIntegrators(){ // set up ODE's and integrators + + if (minstep == null) minstep = tolerance.get(); + if (maxstep == null) maxstep = 1000.; + + Boolean augmented = this instanceof BirthDeathMigrationModel; + + P = new p0_ODE(birth, ((!augmented && birthAmongDemes) ? b_ij : null), death,psi,M, n, totalIntervals, times); + PG = new p0ge_ODE(birth, ((!augmented && birthAmongDemes) ? b_ij : null), death,psi,M, n, totalIntervals, T, times, P, maxEvaluations.get(), augmented); + + + if (!useRKInput.get()) { + pg_integrator = new DormandPrince853Integrator(minstep, maxstep, tolerance.get(), tolerance.get()); // + pg_integrator.setMaxEvaluations(maxEvaluations.get()); + + PG.p_integrator = new DormandPrince853Integrator(minstep, maxstep, tolerance.get(), tolerance.get()); // + PG.p_integrator.setMaxEvaluations(maxEvaluations.get()); + } else { + pg_integrator = new ClassicalRungeKuttaIntegrator(T / 1000); + PG.p_integrator = new ClassicalRungeKuttaIntegrator(T / 1000); + + } + } + + + // Interface requirements: + + @Override + public List getArguments() { + return null; + } - @Override - public List getConditions() { - return null; - } + @Override + public List getConditions() { + return null; + } - @Override - public void sample(State state, Random random) { - } + @Override + public void sample(State state, Random random) { + } @Override public boolean requiresRecalculation(){ diff --git a/src/test/beast/evolution/speciation/BirthDeathMigrationTest.java b/src/test/beast/evolution/speciation/BirthDeathMigrationTest.java index d8309fc..3fa7aa2 100755 --- a/src/test/beast/evolution/speciation/BirthDeathMigrationTest.java +++ b/src/test/beast/evolution/speciation/BirthDeathMigrationTest.java @@ -44,11 +44,11 @@ public void testSALikelihoodCalculationWithoutAncestors() throws Exception { bdm.setInputValue("becomeUninfectiousRate", new RealParameter("1.5")); bdm.setInputValue("samplingProportion", new RealParameter("0.3") ); bdm.setInputValue("removalProbability", new RealParameter("0.9") ); - bdm.setInputValue("conditionOnSurvival", false); + bdm.setInputValue("conditionOnSurvival", true); bdm.initAndValidate(); - assertEquals(-16.281647428602657, bdm.calculateLogP(), 1e-4); // this result is from BEAST (BirthDeathMigrationModelUncoloured), not double checked in R + assertEquals(-15.99699690815937, bdm.calculateLogP(), 1e-4); // this result is from BEAST (BirthDeathMigrationModelUncoloured), not double checked in R } diff --git a/src/test/beast/evolution/speciation/BirthDeathMigrationUncolouredTest.java b/src/test/beast/evolution/speciation/BirthDeathMigrationUncolouredTest.java index 30c2196..4bfbd2c 100755 --- a/src/test/beast/evolution/speciation/BirthDeathMigrationUncolouredTest.java +++ b/src/test/beast/evolution/speciation/BirthDeathMigrationUncolouredTest.java @@ -7,7 +7,7 @@ import beast.evolution.alignment.TaxonSet; import beast.evolution.alignment.Taxon; import beast.util.TreeParser; -import beast.util.ZeroBranchSATreeParser; +//import beast.util.ZeroBranchSATreeParser; import junit.framework.TestCase; import org.junit.Test; @@ -207,49 +207,95 @@ public void testLikelihood1dim() throws Exception { } +// @Test +// public void testSALikelihoodCalculation1() throws Exception { +// +// for (int i=0; i<2; i++) { +// +// BirthDeathMigrationModelUncoloured model = new BirthDeathMigrationModelUncoloured(); +// +// ZeroBranchSATreeParser tree = (i==0)? new ZeroBranchSATreeParser("((1[&type=0]:1.0)2[&type=0]:1.0)3[&type=0]:0.0", true, false, 1) +// : new ZeroBranchSATreeParser("((1:1.5,2:0.5):0.5)3:0.0", true, false, 1); +// +// model.setInputValue("tree", tree); +// model.setInputValue("origin", new RealParameter("10.")); +// +// model.setInputValue("birthRate", new RealParameter("2.")); +// model.setInputValue("deathRate", new RealParameter("0.99")); +// model.setInputValue("samplingRate", new RealParameter("0.5")); +// +// // model.setInputValue("R0", new RealParameter(new Double[]{2./1.49})); +// // model.setInputValue("becomeUninfectiousRate", new RealParameter("1.49")); +// // model.setInputValue("samplingProportion", new RealParameter(new Double[]{0.5/1.49}) ); +// +// model.setInputValue("removalProbability", new RealParameter("0.9")); +// model.setInputValue("conditionOnSurvival", false); +// +// model.setInputValue("stateNumber", "1"); +// model.setInputValue("typeLabel", "type"); +// model.setInputValue("migrationMatrix", "0."); +// model.setInputValue("frequencies", "1"); +// +// model.setInputValue("R0", new RealParameter("1.5")); +// model.setInputValue("becomeUninfectiousRate", new RealParameter("1.5")); +// model.setInputValue("samplingProportion", new RealParameter("0.3")); +// +// model.initAndValidate(); +// +// // these values ate calculated with Mathematica +// if (i==0) assertEquals(-25.3707, model.calculateTreeLogLikelihood(tree), 1e-5); // likelihood conditioning only on parameters and origin time +// else assertEquals(-22.524157039646802, model.calculateTreeLogLikelihood(tree), 1e-5); +// } +// } + @Test - public void testSALikelihoodCalculation1() throws Exception { + public void testSALikelihoodCalculationWithoutAncestors() throws Exception { - for (int i=0; i<2; i++) { - BirthDeathMigrationModelUncoloured model = new BirthDeathMigrationModelUncoloured(); + BirthDeathMigrationModelUncoloured bdm = new BirthDeathMigrationModelUncoloured(); + + ArrayList taxa = new ArrayList(); - ZeroBranchSATreeParser tree = (i==0)? new ZeroBranchSATreeParser("((1[&type=0]:1.0)2[&type=0]:1.0)3[&type=0]:0.0", true, false, 1) - : new ZeroBranchSATreeParser("((1:1.5,2:0.5):0.5)3:0.0", true, false, 1); + for (int i=1; i<=4; i++){ + taxa.add(new Taxon(""+i)); + } - model.setInputValue("tree", tree); - model.setInputValue("origin", new RealParameter("10.")); + Tree tree = new TreeParser(); + tree.setInputValue("taxonset", new TaxonSet(taxa)); + tree.setInputValue("adjustTipHeights", "false"); + tree.setInputValue("IsLabelledNewick", "true"); + tree.setInputValue("newick", "((3 : 1.5, 4 : 0.5) : 1 , (1 : 2, 2 : 1) : 3);"); + tree.initAndValidate(); - model.setInputValue("birthRate", new RealParameter("2.")); - model.setInputValue("deathRate", new RealParameter("0.99")); - model.setInputValue("samplingRate", new RealParameter("0.5")); + TraitSet trait = new TraitSet(); + trait.setInputValue("taxa", new TaxonSet(taxa)); + trait.setInputValue("value", "1=0,2=0,3=0,4=0"); + trait.setInputValue("traitname", "tiptypes"); + trait.initAndValidate(); - // model.setInputValue("R0", new RealParameter(new Double[]{2./1.49})); - // model.setInputValue("becomeUninfectiousRate", new RealParameter("1.49")); - // model.setInputValue("samplingProportion", new RealParameter(new Double[]{0.5/1.49}) ); + bdm.setInputValue("tree", tree); + bdm.setInputValue("tiptypes", trait); - model.setInputValue("removalProbability", new RealParameter("0.9")); - model.setInputValue("conditionOnSurvival", false); + bdm.setInputValue("origin", "10."); + bdm.setInputValue("stateNumber", "1"); + bdm.setInputValue("migrationMatrix", "0."); + bdm.setInputValue("frequencies", "1"); - model.setInputValue("stateNumber", "1"); - model.setInputValue("typeLabel", "type"); - model.setInputValue("migrationMatrix", "0."); - model.setInputValue("frequencies", "1"); + bdm.setInputValue("R0", new RealParameter("1.5")); + bdm.setInputValue("becomeUninfectiousRate", new RealParameter("1.5")); + bdm.setInputValue("samplingProportion", new RealParameter("0.3") ); + bdm.setInputValue("removalProbability", new RealParameter("0.9") ); + bdm.setInputValue("conditionOnSurvival", true); - model.setInputValue("R0", new RealParameter("1.5")); - model.setInputValue("becomeUninfectiousRate", new RealParameter("1.5")); - model.setInputValue("samplingProportion", new RealParameter("0.3")); + bdm.initAndValidate(); - model.initAndValidate(); + // likelihood conditioning on at least one sampled individual - "true" result from BEAST one-deme SA model 09 June 2015 (DK) + assertEquals(-25.991511346557598, bdm.calculateLogP(), 1e-4); - // these values ate calculated with Mathematica - if (i==0) assertEquals(-25.3707, model.calculateTreeLogLikelihood(tree), 1e-5); // likelihood conditioning only on parameters and origin time - else assertEquals(-22.524157039646802, model.calculateTreeLogLikelihood(tree), 1e-5); - } } @Test - public void testSALikelihoodCalculationWithoutAncestors() throws Exception { + public void testSALikelihoodCalculationWithoutAncestorsWithoutOrigin() throws Exception { BirthDeathMigrationModelUncoloured bdm = new BirthDeathMigrationModelUncoloured(); @@ -276,7 +322,6 @@ public void testSALikelihoodCalculationWithoutAncestors() throws Exception { bdm.setInputValue("tree", tree); bdm.setInputValue("tiptypes", trait); - bdm.setInputValue("origin", "10."); bdm.setInputValue("stateNumber", "1"); bdm.setInputValue("migrationMatrix", "0."); bdm.setInputValue("frequencies", "1"); @@ -290,7 +335,7 @@ public void testSALikelihoodCalculationWithoutAncestors() throws Exception { bdm.initAndValidate(); // likelihood conditioning on at least one sampled individual - "true" result from BEAST one-deme SA model 09 June 2015 (DK) - assertEquals(-25.671303367076007, bdm.calculateLogP(), 1e-4); + assertEquals(-15.99699690815937, bdm.calculateLogP(), 1e-4); } diff --git a/templates/BDMM.xml b/templates/BDMM.xml new file mode 100644 index 0000000..bd4d8ac --- /dev/null +++ b/templates/BDMM.xml @@ -0,0 +1,442 @@ + + + beast.app.beauti.BeautiConnector + beast.app.beauti.BeautiSubTemplate + beast.math.distributions.Uniform + beast.math.distributions.Normal + beast.math.distributions.OneOnX + beast.math.distributions.LogNormalDistributionModel + beast.math.distributions.Exponential + beast.math.distributions.Gamma + beast.math.distributions.Beta + beast.math.distributions.LaplaceDistribution + beast.math.distributions.InverseGamma + beast.math.distributions.Prior + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +]]> + + + + + + + + + + + + + + + + + + + Prior on gamma shape for partition s:$(n) + Prior on proportion invariant for partition s:$(n) + + + + + + + Scales proportion of invariant sites parameter of partition $(n) + Scales mutation rate of partition s:$(n) + Scales gamma shape parameter of partition s:$(n) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/version.xml b/version.xml index 652ef5b..1164005 100644 --- a/version.xml +++ b/version.xml @@ -1,6 +1,4 @@ - + - - - +