From 6dbf5b7d5cd8a3e39224be209d3c98b46f7f51fc Mon Sep 17 00:00:00 2001 From: Haibao Tang Date: Sun, 28 Apr 2024 00:24:56 -0700 Subject: [PATCH] Add test_formula.py --- jcvi/algorithms/formula.py | 44 ++++++++++++++++++---- jcvi/apps/base.py | 2 +- jcvi/assembly/geneticmap.py | 63 +++++++++----------------------- jcvi/assembly/hic.py | 8 ++-- jcvi/graphics/base.py | 9 ++++- tests/algorithms/test_formula.py | 19 ++++++++++ 6 files changed, 85 insertions(+), 60 deletions(-) create mode 100644 tests/algorithms/test_formula.py diff --git a/jcvi/algorithms/formula.py b/jcvi/algorithms/formula.py index 92d4b1c4..64384c26 100644 --- a/jcvi/algorithms/formula.py +++ b/jcvi/algorithms/formula.py @@ -5,16 +5,19 @@ Some math formula for various calculations """ import sys -import numpy as np +from collections import Counter +from functools import lru_cache from math import log, exp, sqrt -from jcvi.utils.cbook import human_size +import numpy as np +import scipy + +from ..utils.cbook import human_size def mean_confidence_interval(data, confidence=0.95): # Compute the confidence interval around the mean - import scipy a = 1.0 * np.array(data) n = len(a) @@ -176,7 +179,7 @@ def jukesCantorD(p, L=100): rD = 1 - 4.0 / 3 * p D = -0.75 * log(rD) - varD = p * (1 - p) / (rD ** 2 * L) + varD = p * (1 - p) / (rD**2 * L) return D, varD @@ -218,8 +221,35 @@ def velvet(readsize, genomesize, numreads, K): print("RAM usage: {0} (MAXKMERLENGTH=31)".format(ram), file=sys.stderr) -if __name__ == "__main__": +@lru_cache(maxsize=None) +def calc_ldscore(a: str, b: str) -> float: + """ + Calculate Linkage disequilibrium (r2) between two genotypes. + """ + assert len(a) == len(b), f"{a}\n{b}" + # Assumes markers as A/B + c = Counter(zip(a, b)) + c_aa = c[("A", "A")] + c_ab = c[("A", "B")] + c_ba = c[("B", "A")] + c_bb = c[("B", "B")] + n = c_aa + c_ab + c_ba + c_bb + if n == 0: + return 0 - import doctest + f = 1.0 / n + x_aa = c_aa * f + x_ab = c_ab * f + x_ba = c_ba * f + x_bb = c_bb * f + p_a = x_aa + x_ab + p_b = x_ba + x_bb + q_a = x_aa + x_ba + q_b = x_ab + x_bb + D = x_aa - p_a * q_a + denominator = p_a * p_b * q_a * q_b + if denominator == 0: + return 0 - doctest.testmod() + r2 = D * D / denominator + return r2 diff --git a/jcvi/apps/base.py b/jcvi/apps/base.py index 0d101bf9..5c96eec8 100644 --- a/jcvi/apps/base.py +++ b/jcvi/apps/base.py @@ -39,7 +39,7 @@ os.environ["LC_ALL"] = "C" JCVIHELP = "JCVI utility libraries {} [{}]\n".format(__version__, __copyright__) -TextCollection = Union[str, List[str], Tuple[str]] +TextCollection = Union[str, List[str], Tuple[str, ...]] def debug(level=logging.DEBUG): diff --git a/jcvi/assembly/geneticmap.py b/jcvi/assembly/geneticmap.py index 612bb000..35fe7ec5 100644 --- a/jcvi/assembly/geneticmap.py +++ b/jcvi/assembly/geneticmap.py @@ -8,20 +8,24 @@ import os.path as op import sys -from collections import Counter -from functools import lru_cache from itertools import combinations, groupby from random import sample -from typing import List import numpy as np import seaborn as sns from ..apps.base import OptionParser, ActionDispatcher, logger, need_update +from ..algorithms.formula import calc_ldscore from ..algorithms.matrix import symmetrize from ..formats.base import BaseFile, LineFile, must_open, read_block from ..formats.bed import Bed, fastaFromBed -from ..graphics.base import Rectangle, draw_cmap, plt, plot_heatmap, savefig +from ..graphics.base import ( + Rectangle, + draw_cmap, + plt, + plot_heatmap, + savefig, +) MSTheader = """population_type {0} @@ -332,40 +336,6 @@ def dotplot(args): fig.clear() -@lru_cache(maxsize=None) -def calc_ldscore(a: List[str], b: List[str]) -> float: - """ - Calculate Linkage disequilibrium (r2) between two genotypes. - """ - assert len(a) == len(b), "{0}\n{1}".format(a, b) - # Assumes markers as A/B - c = Counter(zip(a, b)) - c_aa = c[("A", "A")] - c_ab = c[("A", "B")] - c_ba = c[("B", "A")] - c_bb = c[("B", "B")] - n = c_aa + c_ab + c_ba + c_bb - if n == 0: - return 0 - - f = 1.0 / n - x_aa = c_aa * f - x_ab = c_ab * f - x_ba = c_ba * f - x_bb = c_bb * f - p_a = x_aa + x_ab - p_b = x_ba + x_bb - q_a = x_aa + x_ba - q_b = x_ab + x_bb - D = x_aa - p_a * q_a - denominator = p_a * p_b * q_a * q_b - if denominator == 0: - return 0 - - r2 = D * D / denominator - return r2 - - def heatmap(args): """ %prog heatmap map @@ -424,9 +394,6 @@ def heatmap(args): fig = plt.figure(1, (iopts.w, iopts.h)) root = fig.add_axes((0, 0, 1, 1)) ax = fig.add_axes((0.1, 0.1, 0.8, 0.8)) # the heatmap - cmap = sns.cubehelix_palette(rot=0.5, as_cmap=True) - - ax.imshow(M, cmap=cmap, interpolation="none") # Plot chromosomes breaks bed = Bed(markerbedfile) @@ -435,23 +402,27 @@ def heatmap(args): chr_labels = [] ignore_size = 20 + breaks = [] for seqid, beg, end in bed.get_breaks(): ignore = abs(end - beg) < ignore_size pos = (beg + end) / 2 chr_labels.append((seqid, pos, ignore)) if ignore: continue - ax.plot((end, end), extent, "w-", lw=1) - ax.plot(extent, (end, end), "w-", lw=1) + breaks.append(end) + + cmap = sns.color_palette("rocket", as_cmap=True) + plot_heatmap(ax, M, breaks, cmap=cmap, plot_breaks=True) # Plot chromosome labels for label, pos, ignore in chr_labels: - pos = 0.1 + pos * 0.8 / xsize if not ignore: + xpos = 0.1 + pos * 0.8 / xsize root.text( - pos, 0.91, label, ha="center", va="bottom", rotation=45, color="grey" + xpos, 0.91, label, ha="center", va="bottom", rotation=45, color="grey" ) - root.text(0.09, pos, label, ha="right", va="center", color="grey") + ypos = 0.9 - pos * 0.8 / xsize + root.text(0.09, ypos, label, ha="right", va="center", color="grey") ax.set_xlim(extent) ax.set_ylim((nmarkers, 0)) # Invert y-axis diff --git a/jcvi/assembly/hic.py b/jcvi/assembly/hic.py index d5bb099e..ca8cf367 100644 --- a/jcvi/assembly/hic.py +++ b/jcvi/assembly/hic.py @@ -1631,10 +1631,10 @@ def movieframe(args): M = read_clm(clm, totalbins, bins) fig = plt.figure(1, (iopts.w, iopts.h)) - root = fig.add_axes([0, 0, 1, 1]) # whole canvas - ax1 = fig.add_axes([0.05, 0.1, 0.4, 0.8]) # heatmap - ax2 = fig.add_axes([0.55, 0.1, 0.4, 0.8]) # dot plot - ax2_root = fig.add_axes([0.5, 0, 0.5, 1]) # dot plot canvas + root = fig.add_axes((0, 0, 1, 1)) # whole canvas + ax1 = fig.add_axes((0.05, 0.1, 0.4, 0.8)) # heatmap + ax2 = fig.add_axes((0.55, 0.1, 0.4, 0.8)) # dot plot + ax2_root = fig.add_axes((0.5, 0, 0.5, 1)) # dot plot canvas # Left axis: heatmap plot_heatmap(ax1, M, breaks, binsize=BINSIZE) diff --git a/jcvi/graphics/base.py b/jcvi/graphics/base.py index 9e00a620..ab9d1eed 100644 --- a/jcvi/graphics/base.py +++ b/jcvi/graphics/base.py @@ -24,6 +24,7 @@ from brewer2mpl import get_map from matplotlib import cm, rc, rcParams +from matplotlib.colors import Colormap from matplotlib.patches import ( Rectangle, Polygon, @@ -34,7 +35,7 @@ FancyArrowPatch, FancyBboxPatch, ) -from typing import Optional, List, Tuple +from typing import Optional, List, Tuple, Union from ..apps.base import datadir, glob, listify, logger, sample_N, which from ..formats.base import LineFile @@ -506,6 +507,7 @@ def plot_heatmap( breaks: List[int], groups: List[Tuple[int, int, List[Tuple[int, str]], str]] = [], plot_breaks: bool = False, + cmap: Optional[Union[str, Colormap]] = None, binsize: Optional[int] = None, ): """Plot heatmap illustrating the contact probabilities in Hi-C data. @@ -517,9 +519,10 @@ def plot_heatmap( iopts (OptionParser options): Graphical options passed in from commandline groups (List, optional): [(start, end, [(position, seqid)], color)]. Defaults to []. plot_breaks (bool): Whether to plot white breaks. Defaults to False. + cmap (str | Colormap, optional): Colormap. Defaults to None, which uses cubehelix. binsize (int, optional): Resolution of the heatmap. """ - cmap = sns.cubehelix_palette(rot=0.5, as_cmap=True) + cmap = cmap or sns.cubehelix_palette(rot=0.5, as_cmap=True) ax.imshow(M, cmap=cmap, interpolation="none") _, xmax = ax.get_xlim() xlim = (0, xmax) @@ -546,6 +549,8 @@ def simplify_seqid(seqid): ax.set_xlim(xlim) ax.set_ylim((xlim[1], xlim[0])) # Flip the y-axis so the origin is at the top + ax.set_xticks(ax.get_xticks()) + ax.set_yticks(ax.get_yticks()) ax.set_xticklabels(ax.get_xticks(), family="Helvetica", color="gray") ax.set_yticklabels(ax.get_yticks(), family="Helvetica", color="gray", rotation=90) ax.tick_params(left=True, bottom=True, labelleft=True, labelbottom=True) diff --git a/tests/algorithms/test_formula.py b/tests/algorithms/test_formula.py new file mode 100644 index 00000000..8a2d4ae1 --- /dev/null +++ b/tests/algorithms/test_formula.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +import pytest + +from jcvi.algorithms.formula import calc_ldscore + + +@pytest.mark.parametrize( + "a,b,expected", + [ + ("AAA", "AAA", 0.0), + ("AAB", "ABB", 0.25), + ("AAB", "BBB", 0.0), + ("AABB", "BBAA", 1.0), + ], +) +def test_calc_ldscore(a: str, b: str, expected: float): + assert calc_ldscore(a, b) == expected