diff --git a/examples/HCV_ad_e1_type4_rev.xml b/examples/HCV_ad_e1_type4_rev.xml new file mode 100644 index 0000000..65a2e35 --- /dev/null +++ b/examples/HCV_ad_e1_type4_rev.xml @@ -0,0 +1,249 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/examples/HCV_oup_40steps.xml b/examples/HCV_oup_40steps.xml new file mode 100644 index 0000000..d36f948 --- /dev/null +++ b/examples/HCV_oup_40steps.xml @@ -0,0 +1,258 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/examples/HCV_oup_bdsky.R b/examples/HCV_oup_bdsky.R new file mode 100644 index 0000000..c1f30ca --- /dev/null +++ b/examples/HCV_oup_bdsky.R @@ -0,0 +1,41 @@ +# This script plots OU-BDSKY plot + +# origin_post is a posterior vector of origins +# r0 is a data table of posterior samples of r0 vectors, one row per sample +# time_grid is a vector of times to evaluate the skyline at +bdsky_post <- function(origin_post, r0, time_grid) { + + r0_time_gridded <- list() + + n <- ncol(r0) + + for (s in 1:length(origin_post)) { + + origin <- origin_post[s] + r0_vec <- r0[s,] + + ind <- pmax(1,n - floor(time_grid / origin * n)) + r0_time_gridded[[s]] <- r0_vec[ind] + } + + return (r0_time_gridded) +} + +setwd("~/Git/bdsky/examples/") + +lf <- read.table("HCV_oup_40_1447852063188.log", sep="\t", header=T) + +origin_post <- lf$orig_root + +r0_subset <- lf[grepl("R0", names(lf))] + +time_grid <- 1:400 + +bdskypost <- bdsky_post(origin_post, r0_subset, time_grid) + +pdf("hcv_oubdsky.pdf") +plot(time_grid,bdskypost[[950]],type='S', xlab="Time (years before present)", ylab="R0",col=rgb(0,0,1,0.1)) +for (s in 20:200*50) { + lines(time_grid, bdskypost[[s]], type='S',col=rgb(0,0,1,0.1)) +} +dev.off() diff --git a/examples/testOUPrior.xml b/examples/testOUPrior.xml new file mode 100644 index 0000000..a11a111 --- /dev/null +++ b/examples/testOUPrior.xml @@ -0,0 +1,31 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/lib/jchart2d-3.2.2.jar b/lib/jchart2d-3.2.2.jar new file mode 100644 index 0000000..698ff25 Binary files /dev/null and b/lib/jchart2d-3.2.2.jar differ diff --git a/src/bdsky/BDSSkylineSegment.java b/src/bdsky/BDSSkylineSegment.java new file mode 100644 index 0000000..209c340 --- /dev/null +++ b/src/bdsky/BDSSkylineSegment.java @@ -0,0 +1,35 @@ +package bdsky; + +/** + * A piecewise constant segment of a skyline. + */ +public class BDSSkylineSegment extends SkylineSegment { + + public BDSSkylineSegment(double lambda, double mu, double psi, double r, double t1, double t2) { + + super(t1, t2, new double[]{lambda, mu, psi, r}); + } + + /** + * @return the birth rate per unit time. + */ + public double lambda() { return value[0]; }; + + /** + * @return the death rate per unit time. + */ + public double mu() { return value[1]; }; + + /** + * @return the sampling rate per unit time. + */ + public double psi() { return value[2]; }; + + /** + * @return the removal probability, i.e. the probability that sampling causes recovery/removal/death. + */ + public double r() { return value[3]; }; + + // TODO fold rho sampling events into the Skyline + // public boolean hasRho(); +} diff --git a/src/bdsky/MultiSkyline.java b/src/bdsky/MultiSkyline.java new file mode 100644 index 0000000..23ab6c5 --- /dev/null +++ b/src/bdsky/MultiSkyline.java @@ -0,0 +1,135 @@ +package bdsky; + +import beast.core.CalculationNode; +import beast.core.Input; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * A multiskyline made up of simple skylines. + * This skyline will have a number of segments equal to the union of the number of unique boundaries in the daughter skylines. + */ +public class MultiSkyline extends CalculationNode implements Skyline { + + public Input> skylineInput = new Input<>("skyline", "the simple skylines making up this multiple skyline", new ArrayList<>()); + + public MultiSkyline(SimpleSkyline... skyline) { + try { + initByName("skyline", Arrays.asList(skyline)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public void initAndValidate() throws Exception {} + + @Override + public List getSegments() { + + List skylines = skylineInput.get(); + + + List boundaries = new ArrayList<>(); + + for (int i = 0; i < skylines.size(); i++) { + Skyline skyline = skylines.get(i); + List segments = skyline.getSegments(); + for (int j = 0; j < segments.size(); j++) { + boundaries.add(new Boundary2(j, segments.get(j).t1, skyline, i)); + } + } + Collections.sort(boundaries, (o1, o2) -> Double.compare(o1.time, o2.time)); + + System.out.println(boundaries); + + int[] index = new int[skylines.size()]; + + List segments = new ArrayList<>(); + + + System.out.println("Boundaries.size = " + boundaries.size()); + + int i = 0; + double start = boundaries.get(0).time; + while (i < boundaries.size()) { + + int j = i + 1; + + double end = Double.POSITIVE_INFINITY; + if (j != boundaries.size()) { + Boundary2 boundary = boundaries.get(j); + end = boundary.time; + + while (j < boundaries.size() && end == start) { + j += 1; + if (j == boundaries.size()) { + end = Double.POSITIVE_INFINITY; + } else { + end = boundaries.get(j).time; + } + } + System.out.println("next end = " + boundaries.get(j)); + } + + double[] value = new double[index.length]; + for (int k = 0; k < index.length; k++) { + + int ind = index[k]; + + value[k] = skylines.get(k).getValues().get(ind)[0]; + } + + SkylineSegment segment = new SkylineSegment(start, end, value); + segments.add(segment); + System.out.println("Added segment: " + segment); + + if (j != boundaries.size()) { + index[boundaries.get(j).skylineIndex] += 1; + System.out.println("incremented index for skyline " + boundaries.get(j).skylineIndex); + } + i = j; + start = end; + } + + return segments; + } + + @Override + public int getDimension() { + + int dim = 0; + for (SimpleSkyline skyline : skylineInput.get()) { + + dim += skyline.getDimension(); + } + return dim; + } + + class Boundary2 { + // the index + int index; + + // time of the boundary + double time; + + // the skyline the boundary is in + Skyline skyline; + + int skylineIndex; + + Boundary2(int index, double time, Skyline skyline, int skylineIndex) { + this.index = index; + this.time = time; + this.skyline = skyline; + this.skylineIndex = skylineIndex; + } + + public String toString() { + return "skyline[" + skylineIndex + "].time(" + index + ")=" + time; + } + } +} diff --git a/src/bdsky/SimpleSkyline.java b/src/bdsky/SimpleSkyline.java new file mode 100644 index 0000000..2d204a1 --- /dev/null +++ b/src/bdsky/SimpleSkyline.java @@ -0,0 +1,157 @@ +package bdsky; + +import bdsky.Skyline; +import bdsky.SkylineSegment; +import beast.core.CalculationNode; +import beast.core.Input; +import beast.core.parameter.RealParameter; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * A skyline function for a parameter + */ +public class SimpleSkyline extends CalculationNode implements Skyline { + + // the interval times for the skyline function (e.g. "0 1 2 3") + public Input timesInput = + new Input("times", "The times t_i specifying when the parameter changes occur. " + + "Times must be in ascending order.", (RealParameter) null); + + // the parameter values, must have the same length as times vector + // (e.g. "-1 3 4 -1.5", means -1 between [0,1), 3 between [1,2), ..., -1.5 between [3,infinity)) + public Input parameterInput = + new Input("parameter", + "The parameter values specifying the value for each piecewise constant segment of the skyline function. " + + "The first value is between t_0 and t_1, the last value is between t_n and infinity. " + + "Should be the same length as time vector", (RealParameter) null); + + @Override + public void initAndValidate() throws Exception { + + Double[] times = getTimes(); + + double smallest = Double.NEGATIVE_INFINITY; + for (Double time : times) { + if (time < smallest) { + throw new RuntimeException("Times must be in ascending order!"); + } + smallest = time; + } + } + + /** + * + * @return the times for this skyline function + */ + public Double[] getTimes() { + return timesInput.get().getValues(); + } + + /** + * + * @return the values for this skyline function + */ + public List getValues() { + + List values = new ArrayList<>(); + for (double val : parameterInput.get().getValues()) { + values.add(new double[] {val}); + } + + return values; + } + + private Double[] rawValues() { + return parameterInput.get().getValues(); + } + + public double[] getValue(double time) { + Double[] times = getTimes(); + + if (time < times[0]) { + throw new RuntimeException("Time is smaller than smallest time in skyline function!"); + } + + int index = Arrays.binarySearch(times,time); + + Double[] values = rawValues(); + if (index < 0) { + //returns (-(insertion point) - 1) + int insertionPoint = -(index+1); + return new double[] {values[insertionPoint-1]}; + } else { + return new double[] {values[index]}; + } + } + + /** + * @param time1 + * @param time2 + * @return the segments of the skyline plot between the two times. + */ + public List getSegments(double time1, double time2) { + + Double[] times = getTimes(); + + if (time1 < times[0] || time2 < times[0]) { + throw new RuntimeException("Time is smaller than smallest time in skyline function!"); + } + + if (time1 > time2 || time1 == time2) { + throw new RuntimeException("time1 must be smaller than time2!"); + } + + List segments = new ArrayList<>(); + + int index1 = Arrays.binarySearch(times, time1); + int index2 = Arrays.binarySearch(times, time2); + + Double[] rawValues = rawValues(); + + // same insertion point + if (index1 == index2) { + int insertionPoint = -(index1 + 1); + segments.add(new SkylineSegment(time1, time2, rawValues[insertionPoint-1])); + return segments; + } + + // not same insertion points + if (index1 < 0) { + int insertionPoint = -(index1 + 1); + segments.add(new SkylineSegment(time1,times[insertionPoint], rawValues[insertionPoint-1])); + index1 = insertionPoint; + if (index1 == index2) return segments; + } else { + segments.add(new SkylineSegment(times[index1],times[index1+1], rawValues[index1])); + index1 += 1; + if (index1 == index2) return segments; + } + if (index2 < 0) { + int insertionPoint = -(index2 + 1); + for (int i = index1; i < insertionPoint-1; i++ ) { + segments.add(new SkylineSegment(times[i],times[i+1], rawValues[i])); + } + segments.add(new SkylineSegment(times[insertionPoint-1],time2, rawValues[insertionPoint-1])); + return segments; + } else { + for (int i = index1; i < index2; i++ ) { + segments.add(new SkylineSegment(times[i],times[i+1], rawValues[i])); + } + } + return segments; + } + + + @Override + public List getSegments() { + return getSegments(0, Double.POSITIVE_INFINITY); + } + + @Override + public int getDimension() { + return 1; + } +} diff --git a/src/bdsky/Skyline.java b/src/bdsky/Skyline.java new file mode 100644 index 0000000..b85435b --- /dev/null +++ b/src/bdsky/Skyline.java @@ -0,0 +1,79 @@ +package bdsky; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * A multivariate skyline interface + */ +public interface Skyline { + + /** + * @return a list of segments in increasing time order. + */ + List getSegments(); + + /** + * @param time the time of interest + * @return the value of this skyline at the given time + */ + default double[] getValue(double time) { + + Double[] times = getTimes(); + + if (time < times[0]) { + throw new RuntimeException("Time is smaller than smallest time in skyline function!"); + } + + int index = Arrays.binarySearch(times,time); + + List values = getValues(); + if (index < 0) { + //returns (-(insertion point) - 1) + int insertionPoint = -(index+1); + return values.get(insertionPoint-1); + } else { + return values.get(index); + } + } + + /** + * @return the start times of the segments in index order. + */ + default Double[] getTimes() { + + List segments = getSegments(); + + Double[] times = new Double[segments.size()]; + for (int i = 0; i < times.length; i++) { + times[i] = segments.get(i).t1; + } + + if (segments.get(segments.size()-1).t2 < Double.POSITIVE_INFINITY) { + throw new RuntimeException("Last segment should extend to positive infinity!"); + } + + return times; + } + + /** + * @return the values of the segments in index order. + */ + default List getValues() { + + List segments = getSegments(); + + List values = new ArrayList<>(); + for (int i = 0; i < segments.size(); i++) { + values.add(segments.get(i).value); + } + return values; + } + + /** + * This is not the number of segments, but the dimension of each segment. + * @return the dimension of the parameter in this skyline. + */ + int getDimension(); +} diff --git a/src/bdsky/SkylinePlot.java b/src/bdsky/SkylinePlot.java new file mode 100644 index 0000000..768a480 --- /dev/null +++ b/src/bdsky/SkylinePlot.java @@ -0,0 +1,91 @@ +package bdsky; + +import java.awt.*; +import java.awt.event.WindowAdapter; +import java.awt.event.WindowEvent; +import java.util.List; +import java.util.Random; +import javax.swing.JFrame; + +import beast.core.BEASTObject; +import beast.core.parameter.RealParameter; +import info.monitorenter.gui.chart.Chart2D; +import info.monitorenter.gui.chart.ITrace2D; +import info.monitorenter.gui.chart.traces.Trace2DSimple; + +public class SkylinePlot { + + private SkylinePlot() { + super(); + } + + public static void addTrace(Skyline skyline, Chart2D chart, Color color) { + + + List segments = skyline.getSegments(); + int size = segments.get(0).value.length; + + for (int i = 0; i < size; i++) { + + ITrace2D trace = new Trace2DSimple(); + // Add the trace to the chart. This has to be done before adding points (deadlock prevention): + + chart.addTrace(trace); + trace.setColor(color); + + for (SkylineSegment segment : segments) { + trace.addPoint(segment.t1, segment.value[0]); + + if (segment.t2 != Double.POSITIVE_INFINITY) { + trace.addPoint(segment.t2, segment.value[0]); + } else { + double extra = 1.0; + if (segments.size() > 1) { + extra = (segment.t1 - trace.getMinX()) / (segments.size() - 1); + } + trace.addPoint(segment.t1 + extra, segment.value[0]); + } + } + } + } + + public static void main(String[] args) throws Exception { + + SimpleSkyline skyline = new SimpleSkyline(); + skyline.setInputValue("times", new RealParameter("0 1 2 3 4")); + skyline.setInputValue("parameter", new RealParameter("0.5 0.0 3 2 5.5")); + skyline.setID("skyline1"); + + SimpleSkyline skyline2 = new SimpleSkyline(); + skyline2.setInputValue("times", new RealParameter("0 1.1 2.2 3.3 4.4")); + skyline2.setInputValue("parameter", new RealParameter("1.5 3.2 0.2 5 2.5")); + skyline2.setID("skyline2"); + + // Create a chart: + Chart2D chart = new Chart2D(); + // Create an ITrace: + + addTrace(skyline, chart, Color.red); + addTrace(skyline2, chart, Color.blue); + // Add all points, as it is static: + + + + + // Make it visible: + // Create a frame. + JFrame frame = new JFrame("SkylinePlot"); + // add the chart to the frame: + frame.getContentPane().add(chart); + frame.setSize(800, 600); + // Enable the termination button [cross on the upper right edge]: + frame.addWindowListener( + new WindowAdapter() { + public void windowClosing(WindowEvent e) { + System.exit(0); + } + } + ); + frame.setVisible(true); + } +} \ No newline at end of file diff --git a/src/bdsky/SkylineSegment.java b/src/bdsky/SkylineSegment.java new file mode 100644 index 0000000..afc9343 --- /dev/null +++ b/src/bdsky/SkylineSegment.java @@ -0,0 +1,70 @@ +package bdsky; + +import java.util.Arrays; + +/** + * A piecewise constant segment of a skyline. + */ +public class SkylineSegment { + + // the start time of this segment + public double t1; + + // the end time of this segment + public double t2; + + // the parameter values + public double[] value; + + SkylineSegment next, prev = null; + + public SkylineSegment(double start, double end, double value) { + this.t1 = start; + this.t2 = end; + this.value = new double[] {value}; + } + + public SkylineSegment(double start, double end, double[] value) { + this.t1 = start; + this.t2 = end; + this.value = value; + } + + void setNextSegment(SkylineSegment next) { + if (next.t1 == t2) { + this.next = next; + } else { + throw new RuntimeException("next.t1 must equal this.t2!"); + } + if (next.prev != this) { + next.setPreviousSegment(this); + } + } + + void setPreviousSegment(SkylineSegment prev) { + if (prev.t2 == t1) { + this.prev = prev; + } else { + throw new RuntimeException("prev.t2 must equal this.t1!"); + } + if (prev.next != this) { + prev.setNextSegment(this); + } + } + + /** + * @return the start time of the segment. + */ + public final double start() { return t1; } + + /** + * @return the end time of the segment. + */ + public final double end() { return t2; } + + public String toString() { + return "segment(" + t1 + "," + t2 + ") = " + Arrays.toString(value); + } + + +} diff --git a/src/beast/evolution/speciation/BDSParameterization.java b/src/beast/evolution/speciation/BDSParameterization.java new file mode 100644 index 0000000..68c702f --- /dev/null +++ b/src/beast/evolution/speciation/BDSParameterization.java @@ -0,0 +1,83 @@ +package beast.evolution.speciation; + +import bdsky.BDSSkylineSegment; +import bdsky.MultiSkyline; +import bdsky.Skyline; +import bdsky.SkylineSegment; +import beast.core.CalculationNode; +import beast.core.Input; +import beast.core.parameter.RealParameter; + +import java.util.ArrayList; +import java.util.List; + +/** + * A parameterization of the birth-death skyline model + */ +public abstract class BDSParameterization extends CalculationNode { + + MultiSkyline multiSkyline; + + public Input origin = + new Input("origin", "The time from origin to last sample (must be larger than tree height)", Input.Validate.REQUIRED); + + public final void setMultiSkyline(MultiSkyline multiSkyline) { + this.multiSkyline = multiSkyline; + } + + /** + * @return the canonical segments for this skyline model. + */ + public final List canonicalSegments(){ + + List canonical = new ArrayList<>(); + + for (SkylineSegment seg : multiSkyline.getSegments()) { + canonical.add(toCanonicalSegment(seg)); + } + return canonical; + } + + /** + * @return the number of segments in this parameterization. + */ + public final int size() { + return multiSkyline.getSegments().size(); + } + + public final void populateCanonical(Double[] birth, Double[] death, Double[] psi, Double[] r, Double[] times) { + int size = size(); + if (birth.length != size || death.length != size || psi.length != size || r.length != size) { + throw new RuntimeException("array size unexpected!"); + } + List canonicalSegments = canonicalSegments(); + for (int i = 0; i < size; i++) { + BDSSkylineSegment seg = canonicalSegments.get(i); + birth[i] = seg.lambda(); + death[i] = seg.mu(); + psi[i] = seg.psi(); + r[i] = seg.r(); + times[i] = seg.start(); + } + } + + /** + * @return the time of the origin of the process before the present. + */ + public double origin() { + return origin.get().getValue(); + } + + /** + * @return true if any segments have a removalProbability < 1 + */ + public final boolean isSampledAncestorModel() { + for (BDSSkylineSegment seg : canonicalSegments()) { + if (seg.r() < 1.0) return true; + } + return false; + } + + public abstract BDSSkylineSegment toCanonicalSegment(SkylineSegment segment); + +} diff --git a/src/beast/evolution/speciation/BirthDeathSkylineModel.java b/src/beast/evolution/speciation/BirthDeathSkylineModel.java index 34c2aa1..9359f7b 100644 --- a/src/beast/evolution/speciation/BirthDeathSkylineModel.java +++ b/src/beast/evolution/speciation/BirthDeathSkylineModel.java @@ -24,7 +24,7 @@ "to allow for birth and death rates to change at times t_i") @Citation("Stadler, T., Kuehnert, D., Bonhoeffer, S., and Drummond, A. J. (2013):\n Birth-death skyline " + "plot reveals temporal changes of\n epidemic spread in HIV and hepatitis C virus (HCV). PNAS 110(1): 228–33.\n" + - "If sampled ancestors are used then please also site: Gavryushkina A, Welch D, Stadler T, Drummond AJ (2014) \n" + + "If sampled ancestors are used then please also cite: Gavryushkina A, Welch D, Stadler T, Drummond AJ (2014) \n" + "Bayesian inference of sampled ancestor trees for epidemiology and fossil calibration. \n" + "PLoS Comput Biol 10(12): e1003919. doi:10.1371/journal.pcbi.1003919") public class BirthDeathSkylineModel extends SpeciesTreeDistribution { @@ -988,4 +988,4 @@ public Boolean isSeasonalBDSIR() { public int getSIRdimension() { throw new RuntimeException("This is not an SIR"); } -} \ No newline at end of file +} diff --git a/src/beast/evolution/speciation/CanonicalParameterization.java b/src/beast/evolution/speciation/CanonicalParameterization.java new file mode 100644 index 0000000..e3a20ae --- /dev/null +++ b/src/beast/evolution/speciation/CanonicalParameterization.java @@ -0,0 +1,48 @@ +package beast.evolution.speciation; + +import bdsky.BDSSkylineSegment; +import bdsky.MultiSkyline; +import bdsky.SimpleSkyline; +import bdsky.SkylineSegment; +import beast.core.Input; +import beast.core.parameter.RealParameter; + +/** + * Created by alexeid on 7/12/15. + */ +public class CanonicalParameterization extends BDSParameterization { + + public Input birthRate = + new Input<>("birthRate", "BirthRate = BirthRateVector * birthRateScalar, birthrate can change over time"); + public Input deathRate = + new Input<>("deathRate", "The deathRate vector with birthRates between times"); + public Input samplingRate = + new Input<>("samplingRate", "The sampling rate per individual"); // psi + public Input removalProbability = + new Input<>("removalProbability", "The probability of an individual to become noninfectious immediately after the sampling"); + + @Override + public void initAndValidate() throws Exception { + + MultiSkyline multiSkyline = new MultiSkyline( + birthRate.get(), + deathRate.get(), + samplingRate.get(), + removalProbability.get() + ); + setMultiSkyline(multiSkyline); + } + + + @Override + public BDSSkylineSegment toCanonicalSegment(SkylineSegment segment) { + + double birth = segment.value[0]; // lambda = birth rate + double death = segment.value[1]; // mu = death rate + double psi = segment.value[2]; // psi = sampling rate + double r = segment.value[3]; // removal probability + + return new BDSSkylineSegment(birth, death, psi, r, segment.t1, segment.t2); + } + +} diff --git a/src/beast/evolution/speciation/OUPrior.java b/src/beast/evolution/speciation/OUPrior.java new file mode 100644 index 0000000..8b9916f --- /dev/null +++ b/src/beast/evolution/speciation/OUPrior.java @@ -0,0 +1,113 @@ +package beast.evolution.speciation; + +import beast.core.Distribution; +import beast.core.Function; +import beast.core.Input; +import beast.core.State; +import beast.core.parameter.RealParameter; +import beast.math.distributions.ParametricDistribution; + +import java.util.List; +import java.util.Random; + +/** + * @author Alexei Drummond. + */ +public class OUPrior extends Distribution { + + + // the trajectory to compute Ornstein-Uhlenbeck prior of + public Input xInput = + new Input<>("x", "The x_i values", (Function) null); + + // the times associated with the x_i values + public Input timeInput = + new Input<>("times", "The times t_i specifying when x changes", (Function) null); + + // mean + public Input meanInput = + new Input("mean", "The mean of the equilibrium distribution", (RealParameter) null); + + // sigma + public Input sigmaInput = + new Input("sigma", "The standard deviation parameter of the equilibrium distribution", (RealParameter) null); + + // nu + public Input nuInput = + new Input("nu", "The reversion parameter of the Ornstein-Uhlenbeck mean reversion process", (RealParameter) null); + + public Input x0PriorInput = + new Input<>("x0Prior", "The prior to use on x0, or null if none.", (ParametricDistribution) null); + + public Input logSpace = new Input<>("logspace", "true if prior should be applied to log(x).", false); + + public double calculateLogP() throws Exception { + + double mu = meanInput.get().getValue(); + double sigma = sigmaInput.get().getValue(); + double sigsq = sigma * sigma; + double nu = nuInput.get().getValue(); + + ParametricDistribution x0Prior = x0PriorInput.get(); + + double[] t = timeInput.get().getDoubleValues(); + double[] x = xInput.get().getDoubleValues(); + + boolean logspace = logSpace.get(); + if (logspace) { + for (int i = 0; i < x.length; i++) { + x[i] = Math.log(x[i]); + } + } + + int n = x.length - 1; + + double logL = -n/2.0 * Math.log(sigsq / (2.0*nu)); + + for (int i = 1; i <= n; i++) { + + double relterm = 1.0-Math.exp(-2.0*nu*(t[i]-t[i-1])); + + logL -= Math.log(relterm)/2.0; + + double term = x[i] - mu - (x[i-1]-mu) * Math.exp(-nu*(t[i]-t[i-1])); + + logL -= nu / sigsq * (term*term / relterm); + } + + if (x0Prior != null) logL += x0Prior.calcLogP(new Function() { + @Override + public int getDimension() { + return 1; + } + + @Override + public double getArrayValue() { + return x[0]; + } + + @Override + public double getArrayValue(int iDim) { + return x[0]; + } + }); + + logP = logL; + return logP; + } + + @Override + public List getArguments() { + return null; + } + + @Override + public List getConditions() { + return null; + } + + @Override + public void sample(State state, Random random) { + + } +} diff --git a/src/beast/evolution/speciation/ParameterizedBirthDeathSkylineModel.java b/src/beast/evolution/speciation/ParameterizedBirthDeathSkylineModel.java new file mode 100644 index 0000000..ebd066f --- /dev/null +++ b/src/beast/evolution/speciation/ParameterizedBirthDeathSkylineModel.java @@ -0,0 +1,779 @@ +package beast.evolution.speciation; + + +import bdsky.BDSSkylineSegment; +import beast.core.Citation; +import beast.core.Description; +import beast.core.Input; +import beast.core.parameter.BooleanParameter; +import beast.core.parameter.RealParameter; +import beast.evolution.alignment.Taxon; +import beast.evolution.tree.Tree; +import beast.evolution.tree.TreeInterface; + +import java.util.*; + +/** + * @author Alexei Drummond + * @author Denise Kuehnert + * @author Alexandra Gavryushkina + *

+ * maths: Tanja Stadler, sampled ancestor extension Alexandra Gavryushkina + */ + +@Description("BirthDeathSkylineModel with generalized parameterizations") +@Citation("Stadler, T., Kuehnert, D., Bonhoeffer, S., and Drummond, A. J. (2013):\n Birth-death skyline " + + "plot reveals temporal changes of\n epidemic spread in HIV and hepatitis C virus (HCV). PNAS 110(1): 228–33.\n" + + "If sampled ancestors are used then please also cite: Gavryushkina A, Welch D, Stadler T, Drummond AJ (2014) \n" + + "Bayesian inference of sampled ancestor trees for epidemiology and fossil calibration. \n" + + "PLoS Comput Biol 10(12): e1003919. doi:10.1371/journal.pcbi.1003919") +public class ParameterizedBirthDeathSkylineModel extends SpeciesTreeDistribution { + + public Input parameterizationInput = new Input<>("parameterization", "The parameterization to use.", Input.Validate.REQUIRED); + + // the times for rho sampling + public Input rhoSamplingTimes = + new Input("rhoSamplingTimes", "The times t_i specifying when rho-sampling occurs", (RealParameter) null); + + // the rho parameter, one for each rho sampling time + public Input rhoInput = + new Input("rho", "The proportion of lineages sampled at rho-sampling times (default 0.)"); + + public Input originIsRootEdge = + new Input<>("originIsRootEdge", "The origin is only the length of the root edge", false); + + public Input contemp = + new Input("contemp", "Only contemporaneous sampling (i.e. all tips are from same sampling time, default false)", false); + + public Input conditionOnSurvival = + new Input("conditionOnSurvival", "if is true then condition on sampling at least one individual (psi-sampling).", true); + public Input conditionOnRhoSampling = + new Input ("conditionOnRhoSampling","if is true then condition on sampling at least one individual at present.", false); + + double t_root; + protected double[] p0, p0hat; + protected double[] Ai, Aihat; + protected double[] Bi, Bihat; + protected int[] N; // number of leaves sampled at each time t_i + + // these four arrays are totalIntervals in length + protected Double[] birth; + Double[] death; + Double[] psi; + Double[] rho; + Double[] r; + + // true if the node of the given index occurs at the time of a rho-sampling event + boolean[] isRhoTip; + + /** + * The number of change points in the birth rate + */ + protected int birthChanges; + + /** + * The number of change points in the death rate + */ + int deathChanges; + + /** + * The number of change points in the sampling rate + */ + int samplingChanges; + int rhoChanges; + + /** + * The number of change point in the removal probability + */ + int rChanges; + + /** + * The number of times rho-sampling occurs + */ + int rhoSamplingCount; + Boolean constantRho; + + /** + * Total interval count + */ + protected int totalIntervals; + + protected List birthRateChangeTimes = new ArrayList(); + protected List deathRateChangeTimes = new ArrayList(); + protected List samplingRateChangeTimes = new ArrayList(); + protected List rhoSamplingChangeTimes = new ArrayList(); + protected List rChangeTimes = new ArrayList(); + + Boolean contempData; + //List intervals = new ArrayList(); + SortedSet timesSet = new TreeSet(); + + protected Double[] times = new Double[]{0.}; + + protected Boolean transform; + Boolean m_forceRateChange; + + Boolean birthRateTimesRelative = false; + Boolean deathRateTimesRelative = false; + Boolean samplingRateTimesRelative = false; + Boolean rTimesRelative = false; + Boolean[] reverseTimeArrays; + + public boolean SAModel; + + enum ConditionOn {NONE, SURVIVAL, RHO_SAMPLING}; + protected ConditionOn conditionOn= ConditionOn.SURVIVAL; + + public Boolean printTempResults; + + @Override + public void initAndValidate() throws Exception { + super.initAndValidate(); + + if (!originIsRootEdge.get() && treeInput.get().getRoot().getHeight() >= origin()) + throw new RuntimeException("Origin parameter ("+ origin() +" ) must be larger than tree height("+treeInput.get().getRoot().getHeight()+" ). Please change initial origin value!"); + + // check if this is a sampled ancestor model + if (parameterizationInput.get().isSampledAncestorModel()) SAModel = true; + + birth = null; + death = null; + psi = null; + rho = null; + r = null; + birthRateChangeTimes.clear(); + deathRateChangeTimes.clear(); + samplingRateChangeTimes.clear(); + if (SAModel) rChangeTimes.clear(); + totalIntervals = 0; + + contempData = contemp.get(); + rhoSamplingCount = 0; + printTempResults = false; + + //if (SAModel) rChanges = removalProbability.get().getDimension() -1; + + if (rhoInput.get()!=null) { + rho = rhoInput.get().getValues(); + rhoChanges = rhoInput.get().getDimension() - 1; + } + + collectTimes(); + + if (rhoInput.get() != null) { + + constantRho = !(rhoInput.get().getDimension() > 1); + + if (rhoInput.get().getDimension() == 1 && rhoSamplingTimes.get()==null || rhoSamplingTimes.get().getDimension() < 2) { + // TODO figure this out! + //if (!contempData && ((samplingProportion.get() != null && samplingProportion.get().getDimension() == 1 && samplingProportion.get().getValue() == 0.) || + // (samplingRate.get() != null && samplingRate.get().getDimension() == 1 && samplingRate.get().getValue() == 0.))) { + // contempData = true; + // if (printTempResults) + // System.out.println("Parameters were chosen for contemporaneously sampled data. Setting contemp=true."); + //} + } + + if (contempData) { + if (rhoInput.get().getDimension() != 1) + throw new RuntimeException("when contemp=true, rho must have dimension 1"); + + else { + rho = new Double[totalIntervals]; + Arrays.fill(rho, 0.); + rho[totalIntervals - 1] = rhoInput.get().getValue(); + rhoSamplingCount = 1; + } + } + + } else { + rho = new Double[totalIntervals]; + Arrays.fill(rho, 0.); + } + isRhoTip = new boolean[treeInput.get().getLeafNodeCount()]; + + if (conditionOnSurvival.get()) { + conditionOn = ConditionOn.SURVIVAL; + if (conditionOnRhoSampling.get()) { + throw new RuntimeException("conditionOnSurvival and conditionOnRhoSampling can not be both true at the same time." + + "Set one of them to true and another one to false."); + } + } else if (conditionOnRhoSampling.get()) { + if (!rhoSamplingConditionHolds()) { + throw new RuntimeException("Conditioning on rho-sampling is only available for sampled ancestor analyses where r " + + "is set to zero and all except the last rho are zero"); + } + conditionOn = ConditionOn.RHO_SAMPLING; + } else { + conditionOn = ConditionOn.NONE; + } + + printTempResults = false; + } + + private double origin() { + return parameterizationInput.get().origin(); + } + + + /** + * checks if r is zero, all elements of rho except the last one are + * zero and the last one is not zero + * @return + */ + private boolean rhoSamplingConditionHolds() { + + if (SAModel) { + for (BDSSkylineSegment segment : parameterizationInput.get().canonicalSegments()) { + if (segment.r() != 0.0) { + return false; + } + } + } else return false; + + for (int i=0; i changeTimes, RealParameter intervalTimes, int numChanges, boolean relative, + boolean reverse) { + changeTimes.clear(); + + if (printTempResults) System.out.println("relative = " + relative); + + double maxTime = originIsRootEdge.get()? treeInput.get().getRoot().getHeight() + origin() : origin(); + + if (intervalTimes == null) { //equidistant + + double intervalWidth = maxTime / (numChanges + 1); + + double end; + for (int i = 1; i <= numChanges; i++) { + end = (intervalWidth) * i; + changeTimes.add(end); + } + end = maxTime; + changeTimes.add(end); + + } else { + + int dim = intervalTimes.getDimension(); + + ArrayList sortedIntervalTimes = new ArrayList<>(); + for (int i=0; i< dim; i++) { + sortedIntervalTimes.add(intervalTimes.getValue(i)); + } + Collections.sort(sortedIntervalTimes); + + if (!reverse && sortedIntervalTimes.get(0) != 0.0) { + throw new RuntimeException("First time in interval times parameter should always be zero."); + } + +// if(intervalTimes.getValue(dim-1)==maxTime) changeTimes.add(0.); //rhoSampling + + double end; + for (int i = (reverse?0:1); i < dim; i++) { + end = reverse ? (maxTime - sortedIntervalTimes.get(dim - i - 1)) : sortedIntervalTimes.get(i); + if (relative) end *= maxTime; + if (end != maxTime) changeTimes.add(end); + } + end = maxTime; + changeTimes.add(end); + } +// } + } + + /* + * Counts the number of tips at each of the contemporaneous sampling times ("rho" sampling time) + * @return negative infinity if tips are found at a time when rho is zero, zero otherwise. + */ + private double computeN(TreeInterface tree) { + + isRhoTip = new boolean[tree.getLeafNodeCount()]; + + N = new int[totalIntervals]; + + int tipCount = tree.getLeafNodeCount(); + + double[] dates = new double[tipCount]; + + for (int i = 0; i < tipCount; i++) { + dates[i] = tree.getNode(i).getHeight(); + } + + for (int k = 0; k < totalIntervals; k++) { + + + for (int i = 0; i < tipCount; i++) { + + if (Math.abs((times[totalIntervals - 1] - times[k]) - dates[i]) < 1e-10) { + if (rho[k] == 0 && psi[k] == 0) { + return Double.NEGATIVE_INFINITY; + } + if (rho[k] > 0) { + N[k] += 1; + isRhoTip[i] = true; + } + } + } + } + return 0.; + } + + /** + * Collect all the times of multiskyline parameterization and the rho-sampling events + */ + private void collectTimes() { + + timesSet.clear(); + + for (BDSSkylineSegment seg : parameterizationInput.get().canonicalSegments()) { + timesSet.add(seg.start()); + } + + getChangeTimes(rhoSamplingChangeTimes, rhoSamplingTimes.get(), rhoChanges, false, reverseTimeArrays[3]); + + if (printTempResults) System.out.println("times = " + timesSet); + + times = timesSet.toArray(new Double[timesSet.size()]); + totalIntervals = times.length; + + if (printTempResults) System.out.println("total intervals = " + totalIntervals); + + } + + protected Double updateRatesAndTimes(TreeInterface tree) { + + collectTimes(); + + t_root = tree.getRoot().getHeight(); + + parameterizationInput.get().populateCanonical(birth, death, psi, r, times); +// for (int i = 0; i < totalIntervals; i++) { +// death[i] = deathRates[index(times[i], deathRateChangeTimes)]; +// psi[i] = samplingRates[index(times[i], samplingRateChangeTimes)]; +// if (SAModel) r[i] = removalProbabilities[index(times[i], rChangeTimes)]; +// +// if (printTempResults) { +// System.out.println("death[" + i + "]=" + death[i]); +// System.out.println("psi[" + i + "]=" + psi[i]); +// if (SAModel) System.out.println("r[" + i + "]=" + r[i]); +// } +// } + + if (rhoInput.get() != null && (rhoInput.get().getDimension()==1 || rhoSamplingTimes.get() != null)) { + + Double[] rhos = rhoInput.get().getValues(); + rho = new Double[totalIntervals]; + +// rho[totalIntervals-1]=rhos[rhos.length-1]; + for (int i = 0; i < totalIntervals; i++) { + + rho[i]= //rhoSamplingChangeTimes.contains(times[i]) ? rhos[rhoSamplingChangeTimes.indexOf(times[i])] : 0.; + rhoChanges>0? + rhoSamplingChangeTimes.contains(times[i]) ? rhos[rhoSamplingChangeTimes.indexOf(times[i])] : 0. + : rhos[0]; + } + } + + return 0.; + } + + + /* calculate and store Ai, Bi and p0 */ + public Double preCalculation(TreeInterface tree) { + + if (!originIsRootEdge.get() && tree.getRoot().getHeight() >= parameterizationInput.get().origin()) { + return Double.NEGATIVE_INFINITY; + } + + // updateRatesAndTimes must be called before calls to index() below + if (updateRatesAndTimes(tree) < 0) { + return Double.NEGATIVE_INFINITY; + } + + if (printTempResults) System.out.println("After update rates and times"); + + if (rhoInput.get() != null) { + if (contempData) { + rho = new Double[totalIntervals]; + Arrays.fill(rho, 0.); + rho[totalIntervals-1] = rhoInput.get().getValue(); + } + + } else { + rho = new Double[totalIntervals]; + Arrays.fill(rho, 0.0); + } + + if (rhoInput.get() != null) + if (computeN(tree) < 0) + return Double.NEGATIVE_INFINITY; + + int intervalCount = times.length; + + Ai = new double[intervalCount]; + Bi = new double[intervalCount]; + p0 = new double[intervalCount]; + + if (conditionOn == ConditionOn.RHO_SAMPLING) { + Aihat = new double[intervalCount]; + Bihat = new double[intervalCount]; + p0hat = new double[intervalCount]; + } + + for (int i = 0; i < intervalCount; i++) { + + Ai[i] = Ai(birth[i], death[i], psi[i]); + + if (conditionOn == ConditionOn.RHO_SAMPLING) { + Aihat[i] = Ai(birth[i], death[i], 0.0); + } + + if (printTempResults) System.out.println("Ai[" + i + "] = " + Ai[i] + " " + Math.log(Ai[i])); + } + + if (printTempResults) { + System.out.println("birth[m-1]=" + birth[totalIntervals - 1]); + System.out.println("death[m-1]=" + death[totalIntervals - 1]); + System.out.println("psi[m-1]=" + psi[totalIntervals - 1]); + System.out.println("rho[m-1]=" + rho[totalIntervals - 1]); + System.out.println("Ai[m-1]=" + Ai[totalIntervals - 1]); + } + + Bi[totalIntervals - 1] = Bi( + birth[totalIntervals - 1], + death[totalIntervals - 1], + psi[totalIntervals - 1], + rho[totalIntervals - 1], + Ai[totalIntervals - 1], 1.); // (p0[m-1] = 1) + + if (conditionOn == ConditionOn.RHO_SAMPLING) { + Bihat[totalIntervals - 1] = Bi( + birth[totalIntervals - 1], + death[totalIntervals - 1], + 0.0, + rho[totalIntervals - 1], + Aihat[totalIntervals - 1], 1.); // (p0[m-1] = 1) + } + + if (printTempResults) + System.out.println("Bi[m-1] = " + Bi[totalIntervals - 1] + " " + Math.log(Bi[totalIntervals - 1])); + for (int i = totalIntervals - 2; i >= 0; i--) { + + p0[i + 1] = p0(birth[i + 1], death[i + 1], psi[i + 1], Ai[i + 1], Bi[i + 1], times[i + 1], times[i]); + if (Math.abs(p0[i + 1] - 1) < 1e-10) { + return Double.NEGATIVE_INFINITY; + } + if (conditionOn == ConditionOn.RHO_SAMPLING) { + p0hat[i + 1] = p0(birth[i + 1], death[i + 1], 0.0, Aihat[i + 1], Bihat[i + 1], times[i + 1], times[i]); + if (Math.abs(p0hat[i + 1] - 1) < 1e-10) { + return Double.NEGATIVE_INFINITY; + } + } + if (printTempResults) System.out.println("p0[" + (i + 1) + "] = " + p0[i + 1]); + + Bi[i] = Bi(birth[i], death[i], psi[i], rho[i], Ai[i], p0[i + 1]); + if (conditionOn == ConditionOn.RHO_SAMPLING) { + Bihat[i] = Bi(birth[i], death[i], 0.0, rho[i], Aihat[i], p0hat[i + 1]); + } + + if (printTempResults) System.out.println("Bi[" + i + "] = " + Bi[i] + " " + Math.log(Bi[i])); + } + + if (printTempResults) { + System.out.println("g(0, x0, 0):" + g(0, times[0], 0)); + System.out.println("g(index(1),times[index(1)],1.) :" + g(index(1), times[index(1)], 1.)); + System.out.println("g(index(2),times[index(2)],2.) :" + g(index(2), times[index(2)], 2)); + System.out.println("g(index(4),times[index(4)],4.):" + g(index(4), times[index(4)], 4)); + } + + return 0.; + } + + public double Ai(double b, double g, double psi) { + + return Math.sqrt((b - g - psi) * (b - g - psi) + 4 * b * psi); + } + + public double Bi(double b, double g, double psi, double r, double A, double p0) { + + return ((1 - 2 * p0 * (1 - r)) * b + g + psi) / A; + } + + public double p0(int index, double t, double ti) { + + return p0(birth[index], death[index], psi[index], Ai[index], Bi[index], t, ti); + } + + public double p0(double b, double g, double psi, double A, double B, double ti, double t) { + + if (printTempResults) + System.out.println("in p0: b = " + b + "; g = " + g + "; psi = " + psi + "; A = " + A + " ; B = " + B + "; ti = " + ti + "; t = " + t); +// return ((b + g + psi - A *((Math.exp(A*(ti - t))*(1+B)-(1-B)))/(Math.exp(A*(ti - t))*(1+B)+(1-B)) ) / (2*b)); + // formula from manuscript slightly rearranged for numerical stability + return ((b + g + psi - A * ((1 + B) - (1 - B) * (Math.exp(A * (t - ti)))) / ((1 + B) + Math.exp(A * (t - ti)) * (1 - B))) / (2 * b)); + + } + + public double p0hat(int index, double t, double ti) { + + return p0(birth[index], death[index], 0.0, Aihat[index], Bihat[index], t, ti); + } + + + public double g(int index, double ti, double t) { + +// return (Math.exp(Ai[index]*(ti - t))) / (0.25*Math.pow((Math.exp(Ai[index]*(ti - t))*(1+Bi[index])+(1-Bi[index])),2)); + // formula from manuscript slightly rearranged for numerical stability + return (4 * Math.exp(Ai[index] * (t - ti))) / (Math.exp(Ai[index] * (t - ti)) * (1 - Bi[index]) + (1 + Bi[index])) / (Math.exp(Ai[index] * (t - ti)) * (1 - Bi[index]) + (1 + Bi[index])); + } + + /** + * @param t the time in question + * @return the index of the given time in the list of times, or if the time is not in the list, the index of the + * next smallest time + */ + public int index(double t, List times) { + + int epoch = Collections.binarySearch(times, t); + + if (epoch < 0) { + epoch = -epoch - 1; + } + + return epoch; + } + + + /** + * @param t the time in question + * @return the index of the given time in the times array, or if the time is not in the array the index of the time + * next smallest + */ + public int index(double t) { + + if (t >= times[totalIntervals - 1]) + return totalIntervals - 1; + + int epoch = Arrays.binarySearch(times, t); + + if (epoch < 0) { + epoch = -epoch - 1; + } + + return epoch; + } + + + /** + * @param time the time + * @param tree the tree + * @return the number of lineages that exist at the given time in the given tree. + */ + public int lineageCountAtTime(double time, TreeInterface tree) { + + int count = 1; + int tipCount = tree.getLeafNodeCount(); + for (int i = tipCount; i < tipCount + tree.getInternalNodeCount(); i++) { + if (tree.getNode(i).getHeight() > time) count += 1; + + } + for (int i = 0; i < tipCount; i++) { + if (tree.getNode(i).getHeight() >= time) count -= 1; + } + return count; + } + + /** + * @param time the time + * @param tree the tree + * @param k count the number of sampled ancestors at the given time + * @return the number of lineages that exist at the given time in the given tree. + */ + public int lineageCountAtTime(double time, TreeInterface tree, int[] k) { + + int count = 1; + k[0]=0; + int tipCount = tree.getLeafNodeCount(); + for (int i = tipCount; i < tipCount + tree.getInternalNodeCount(); i++) { + if (tree.getNode(i).getHeight() >= time) count += 1; + + } + for (int i = 0; i < tipCount; i++) { + if (tree.getNode(i).getHeight() > time) count -= 1; + if (Math.abs(tree.getNode(i).getHeight() - time) < 1e-10) { + count -= 1; + if (tree.getNode(i).isDirectAncestor()) { + count -= 1; + k[0]++; + } + + } + } + return count; + } + + @Override + public double calculateTreeLogLikelihood(TreeInterface tree) { + + int nTips = tree.getLeafNodeCount(); + + if (preCalculation(tree) < 0) { + return Double.NEGATIVE_INFINITY; + } + + // number of lineages at each time ti + int[] n = new int[totalIntervals]; + + double x0 = 0; + int index = 0; + + if (times[index] < 0.) + index = index(0.); + + double temp=0; + + switch (conditionOn) { + case NONE: + temp = Math.log(g(index, times[index], x0)); + break; + case SURVIVAL: + temp = p0(index, times[index], x0); + if (temp == 1) + return Double.NEGATIVE_INFINITY; + temp = Math.log(g(index, times[index], x0) / (1 - temp)); + break; + case RHO_SAMPLING: + temp = p0hat(index, times[index], x0); + if (temp == 1) + return Double.NEGATIVE_INFINITY; + temp = Math.log(g(index, times[index], x0) / (1 - temp)); + break; + default: + break; + } + + logP = temp; + if (Double.isInfinite(logP)) + return logP; + + if (printTempResults) System.out.println("first factor for origin = " + temp); + + // first product term in f[T] + for (int i = 0; i < tree.getInternalNodeCount(); i++) { + + double x = times[totalIntervals - 1] - tree.getNode(nTips + i).getHeight(); + index = index(x); + if (!(tree.getNode(nTips + i)).isFake()) { + temp = Math.log(birth[index] * g(index, times[index], x)); + logP += temp; + if (printTempResults) System.out.println("1st pwd" + + " = " + temp + "; interval = " + i); + if (Double.isInfinite(logP)) + return logP; + } + } + + // middle product term in f[T] + for (int i = 0; i < nTips; i++) { + + if (!isRhoTip[i] || rhoInput.get() == null) { + double y = times[totalIntervals - 1] - tree.getNode(i).getHeight(); + index = index(y); + + if (!(tree.getNode(i)).isDirectAncestor()) { + if (!SAModel) { + temp = Math.log(psi[index]) - Math.log(g(index, times[index], y)); + } else { + temp = Math.log(psi[index] * (r[index] + (1 - r[index]) * p0(index, times[index], y))) - Math.log(g(index, times[index], y)); + } + logP += temp; + if (printTempResults) System.out.println("2nd PI = " + temp); + if (psi[index] == 0 || Double.isInfinite(logP)) + return logP; + } else { + if (r[index] != 1) { + logP += Math.log((1 - r[index])*psi[index]); + if (Double.isInfinite(logP)) { + return logP; + } + } else { + //throw new Exception("There is a sampled ancestor in the tree while r parameter is 1"); + System.out.println("There is a sampled ancestor in the tree while r parameter is 1"); + System.exit(0); + } + } + } + } + + // last product term in f[T], factorizing from 1 to m // + double time; + for (int j = 0; j < totalIntervals; j++) { + time = j < 1 ? 0 : times[j - 1]; + int[] k = {0}; + if (!SAModel) { + n[j] = ((j == 0) ? 0 : lineageCountAtTime(times[totalIntervals - 1] - time, tree)); + } else { + n[j] = ((j == 0) ? 0 : lineageCountAtTime(times[totalIntervals - 1] - time, tree, k)); + } + if (n[j] > 0) { + temp = n[j] * (Math.log(g(j, times[j], time)) + Math.log(1 - rho[j-1])); + logP += temp; + if (printTempResults) + System.out.println("3rd factor (nj loop) = " + temp + "; interval = " + j + "; n[j] = " + n[j]);//+ "; Math.log(g(j, times[j], time)) = " + Math.log(g(j, times[j], time))); + if (Double.isInfinite(logP)) + return logP; + + } + + if (SAModel && j>0 && N != null) { // term for sampled leaves and two-degree nodes at time t_i + logP += k[0] * (Math.log(g(j, times[j], time)) + Math.log(1-r[j])) + //here g(j,..) corresponds to q_{i+1}, r[j] to r_{i+1}, + (N[j-1]-k[0])*(Math.log(r[j]+ (1-r[j])*p0(j, times[j], time))); //N[j-1] to N_i, k[0] to K_i,and thus N[j-1]-k[0] to M_i + if (Double.isInfinite(logP)) { + return logP; + } + } + + if (rho[j] > 0 && N[j] > 0) { + temp = N[j] * Math.log(rho[j]); // term for contemporaneous sampling + logP += temp; + if (printTempResults) + System.out.println("3rd factor (Nj loop) = " + temp + "; interval = " + j + "; N[j] = " + N[j]); + if (Double.isInfinite(logP)) + return logP; + + } + } + + if (SAModel) { + int internalNodeCount = tree.getLeafNodeCount() - ((Tree)tree).getDirectAncestorNodeCount()- 1; + logP += Math.log(2)*internalNodeCount; + } + + return logP; + } + + public double calculateTreeLogLikelihood(Tree tree, Set exclude) { + if (exclude.size() == 0) return calculateTreeLogLikelihood(tree); + throw new RuntimeException("Not implemented!"); + } + + @Override + protected boolean requiresRecalculation() { + return true; + } + + @Override + public boolean canHandleTipDates() { + return (rhoInput.get() == null); + } +} diff --git a/src/beast/evolution/speciation/R0Parameterization.java b/src/beast/evolution/speciation/R0Parameterization.java new file mode 100644 index 0000000..6bf5b9d --- /dev/null +++ b/src/beast/evolution/speciation/R0Parameterization.java @@ -0,0 +1,54 @@ +package beast.evolution.speciation; + +import bdsky.*; +import beast.core.Input; + +/** + * Created by alexeid on 7/12/15. + */ +public class R0Parameterization extends BDSParameterization { + + public Input R0 = + new Input<>("R0", + "The skyline of the basic reproduction number"); + public Input becomeUninfectiousRate = + new Input<>("becomeUninfectiousRate", + "Rate at which individuals become uninfectious (through recovery or sampling)"); + public Input samplingProportion = + new Input<>("samplingProportion", + "The samplingProportion = samplingRate / becomeUninfectiousRate"); + public Input removalProbability = + new Input<>("removalProbability", + "The probability of death/removal/recovery upon sampling. " + + "If 1.0 then no sampled ancestors are produced in that interval."); + + @Override + public void initAndValidate() throws Exception { + + MultiSkyline multiSkyline = new MultiSkyline( + R0.get(), + becomeUninfectiousRate.get(), + samplingProportion.get(), + removalProbability.get() + ); + setMultiSkyline(multiSkyline); + } + + + @Override + public BDSSkylineSegment toCanonicalSegment(SkylineSegment segment) { + + double R = segment.value[0]; // R + double b = segment.value[1]; // become uninfectious + double p = segment.value[2]; // sampling proportion + double r = segment.value[3]; // removal probability + + double birth = R * b; + double psi = p * b; + double death = b - psi*r; + + + return new BDSSkylineSegment(birth, death, psi, r, segment.t1, segment.t2); + } + +} diff --git a/src/test/bdsky/MultiSkylineTest.java b/src/test/bdsky/MultiSkylineTest.java new file mode 100644 index 0000000..cb06117 --- /dev/null +++ b/src/test/bdsky/MultiSkylineTest.java @@ -0,0 +1,64 @@ +package test.bdsky; + +import bdsky.MultiSkyline; +import bdsky.SimpleSkyline; +import bdsky.SkylineSegment; +import beast.core.parameter.RealParameter; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +/** + * Tests for simple skyline + */ +public class MultiSkylineTest { + + @Test + public void testGetValue() throws Exception { + + SimpleSkyline skyline = new SimpleSkyline(); + skyline.setInputValue("times", new RealParameter("0 1 2 3")); + skyline.setInputValue("parameter", new RealParameter("-1 3 4 -1.5")); + + SimpleSkyline skyline2 = new SimpleSkyline(); + skyline2.setInputValue("times", new RealParameter("0 1.5 2.5 3.5")); + skyline2.setInputValue("parameter", new RealParameter("2 1.5 -2.7 4.5")); + + List simpleSkylines = new ArrayList<>(); + simpleSkylines.add(skyline); + simpleSkylines.add(skyline2); + + MultiSkyline multiSkyline = new MultiSkyline(); + multiSkyline.skylineInput.setValue(simpleSkylines,multiSkyline); + multiSkyline.initAndValidate(); + + assertEquals(2, multiSkyline.getDimension()); + + assertEquals(7, multiSkyline.getSegments().size()); + + assertEquals(-1, multiSkyline.getValue(0.75)[0], 0); + assertEquals(2, multiSkyline.getValue(0.75)[1], 0); + + assertEquals(3, multiSkyline.getValue(1.25)[0], 0); + assertEquals(2, multiSkyline.getValue(1.25)[1], 0); + + assertEquals(3, multiSkyline.getValue(1.75)[0], 0); + assertEquals(1.5, multiSkyline.getValue(1.75)[1], 0); + + assertEquals(4, multiSkyline.getValue(2.25)[0], 0); + assertEquals(1.5, multiSkyline.getValue(2.25)[1], 0); + + assertEquals(4, multiSkyline.getValue(2.75)[0], 0); + assertEquals(-2.7, multiSkyline.getValue(2.75)[1], 0); + + assertEquals(-1.5, multiSkyline.getValue(3.25)[0], 0); + assertEquals(-2.7, multiSkyline.getValue(3.25)[1], 0); + + assertEquals(-1.5, multiSkyline.getValue(3.75)[0], 0); + assertEquals(4.5, multiSkyline.getValue(3.75)[1], 0); + } + +} \ No newline at end of file diff --git a/src/test/bdsky/SimpleSkylineTest.java b/src/test/bdsky/SimpleSkylineTest.java new file mode 100644 index 0000000..2138235 --- /dev/null +++ b/src/test/bdsky/SimpleSkylineTest.java @@ -0,0 +1,193 @@ +package test.bdsky; + +import bdsky.SkylineSegment; +import beast.core.parameter.RealParameter; +import bdsky.SimpleSkyline; +import org.junit.Test; + +import java.util.List; + +import static org.junit.Assert.*; + +/** + * Tests for simple skyline + */ +public class SimpleSkylineTest { + + @Test + public void testGetValue() throws Exception { + + SimpleSkyline skyline = new SimpleSkyline(); + skyline.setInputValue("times", new RealParameter("0 1 2 3")); + skyline.setInputValue("parameter", new RealParameter("-1 3 4 -1.5")); + + assertEquals(-1, skyline.getValue(0)[0], 0); + assertEquals(-1, skyline.getValue(0.5)[0], 0); + + assertEquals(3, skyline.getValue(1)[0], 0); + assertEquals(3, skyline.getValue(1.5)[0], 0); + + assertEquals(4, skyline.getValue(2)[0], 0); + assertEquals(4, skyline.getValue(2.5)[0], 0); + + assertEquals(-1.5, skyline.getValue(3)[0], 0); + assertEquals(-1.5, skyline.getValue(3.5)[0], 0); + } + + @Test + public void testGetSegments() throws Exception { + + SimpleSkyline skyline = new SimpleSkyline(); + skyline.setInputValue("times", new RealParameter("0 1 2 3")); + skyline.setInputValue("parameter", new RealParameter("-1 3 4 -1.5")); + + List segments = skyline.getSegments(); + + assertEquals("Checking number of segments", 4, segments.size()); + + assertEquals(0, segments.get(0).t1, 0.0); + assertEquals(1, segments.get(1).t1, 0.0); + assertEquals(2, segments.get(2).t1, 0.0); + assertEquals(3, segments.get(3).t1, 0.0); + + assertEquals(1, segments.get(0).t2, 0.0); + assertEquals(2, segments.get(1).t2, 0.0); + assertEquals(3, segments.get(2).t2, 0.0); + assertEquals(Double.POSITIVE_INFINITY, segments.get(3).t2, 0.0); + + assertEquals(-1, segments.get(0).value[0], 0.0); + assertEquals(3, segments.get(1).value[0], 0.0); + assertEquals(4, segments.get(2).value[0], 0.0); + assertEquals(-1.5, segments.get(3).value[0], 0.0); + } + + @Test + public void testGetSegments1() throws Exception { + + SimpleSkyline skyline = new SimpleSkyline(); + skyline.setInputValue("times", new RealParameter("0 1 2 3")); + skyline.setInputValue("parameter", new RealParameter("-1 3 4 -1.5")); + + List segments = skyline.getSegments(0,1); + + assertEquals("Checking number of segments", 1, segments.size()); + + assertEquals(0, segments.get(0).t1, 0.0); + + assertEquals(1, segments.get(0).t2, 0.0); + + assertEquals(-1, segments.get(0).value[0], 0.0); + + } + + @Test + public void testGetSegments2() throws Exception { + + SimpleSkyline skyline = new SimpleSkyline(); + skyline.setInputValue("times", new RealParameter("0 1 2 3")); + skyline.setInputValue("parameter", new RealParameter("-1 3 4 -1.5")); + + List segments = skyline.getSegments(0.2,0.3); + + assertEquals("Checking number of segments", 1, segments.size()); + + assertEquals(0.2, segments.get(0).t1, 0.0); + + assertEquals(0.3, segments.get(0).t2, 0.0); + + assertEquals(-1, segments.get(0).value[0], 0.0); + + } + + @Test + public void testGetSegments3() throws Exception { + + SimpleSkyline skyline = new SimpleSkyline(); + skyline.setInputValue("times", new RealParameter("0 1 2 3")); + skyline.setInputValue("parameter", new RealParameter("-1 3 4 -1.5")); + + List segments = skyline.getSegments(0.2,1.3); + + assertEquals("Checking number of segments", 2, segments.size()); + + assertEquals(0.2, segments.get(0).t1, 0.0); + assertEquals(1, segments.get(1).t1, 0.0); + + assertEquals(1, segments.get(0).t2, 0.0); + assertEquals(1.3, segments.get(1).t2, 0.0); + + assertEquals(-1, segments.get(0).value[0], 0.0); + assertEquals(3, segments.get(1).value[0], 0.0); + } + + @Test + public void testGetSegments4() throws Exception { + + SimpleSkyline skyline = new SimpleSkyline(); + skyline.setInputValue("times", new RealParameter("0 1 2 3")); + skyline.setInputValue("parameter", new RealParameter("-1 3 4 -1.5")); + + List segments = skyline.getSegments(0.2,2.3); + + assertEquals("Checking number of segments", 3, segments.size()); + + assertEquals(0.2, segments.get(0).t1, 0.0); + assertEquals(1, segments.get(1).t1, 0.0); + assertEquals(2, segments.get(2).t1, 0.0); + + assertEquals(1, segments.get(0).t2, 0.0); + assertEquals(2, segments.get(1).t2, 0.0); + assertEquals(2.3, segments.get(2).t2, 0.0); + + assertEquals(-1, segments.get(0).value[0], 0.0); + assertEquals(3, segments.get(1).value[0], 0.0); + assertEquals(4, segments.get(2).value[0], 0.0); + } + + @Test + public void testGetSegments5() throws Exception { + + SimpleSkyline skyline = new SimpleSkyline(); + skyline.setInputValue("times", new RealParameter("0 1 2 3")); + skyline.setInputValue("parameter", new RealParameter("-1 3 4 -1.5")); + + List segments = skyline.getSegments(1.1,5.3); + + assertEquals("Checking number of segments", 3, segments.size()); + + assertEquals(1.1, segments.get(0).t1, 0.0); + assertEquals(2, segments.get(1).t1, 0.0); + assertEquals(3, segments.get(2).t1, 0.0); + + assertEquals(2, segments.get(0).t2, 0.0); + assertEquals(3, segments.get(1).t2, 0.0); + assertEquals(5.3, segments.get(2).t2, 0.0); + + assertEquals(3, segments.get(0).value[0], 0.0); + assertEquals(4, segments.get(1).value[0], 0.0); + assertEquals(-1.5, segments.get(2).value[0], 0.0); + } + + @Test + public void testGetSegments6() throws Exception { + + SimpleSkyline skyline = new SimpleSkyline(); + skyline.setInputValue("times", new RealParameter("0 1 2 3")); + skyline.setInputValue("parameter", new RealParameter("-1 3 4 -1.5")); + + List segments = skyline.getSegments(1,3); + + assertEquals("Checking number of segments", 2, segments.size()); + + assertEquals(1, segments.get(0).t1, 0.0); + assertEquals(2, segments.get(1).t1, 0.0); + + assertEquals(2, segments.get(0).t2, 0.0); + assertEquals(3, segments.get(1).t2, 0.0); + + assertEquals(3, segments.get(0).value[0], 0.0); + assertEquals(4, segments.get(1).value[0], 0.0); + } + + +} \ No newline at end of file