# ---------------------------------------------------------------------------
# sgkit/display.py
#
# Relies on names imported at the top of that module: ``Any``, ``Dict``,
# ``Hashable`` and ``Optional`` from ``typing``, ``numpy as np``,
# ``xarray as xr``, ``sgkit.variables``, ``parent_indices`` from
# ``sgkit.stats.pedigree``, ``ArrayLike`` from ``sgkit.typing`` and
# ``define_variable_if_absent`` from ``sgkit.utils``.
# ---------------------------------------------------------------------------


def display_pedigree(
    ds: xr.Dataset,
    parent: Hashable = variables.parent,
    graph_attrs: Optional[Dict[Hashable, str]] = None,
    node_attrs: Optional[Dict[Hashable, ArrayLike]] = None,
    edge_attrs: Optional[Dict[Hashable, ArrayLike]] = None,
) -> Any:
    """Display a pedigree dataset as a directed acyclic graph.

    Each sample becomes a node named after its integer position along the
    ``samples`` dimension, and an edge is drawn from every known parent to
    its child.

    Parameters
    ----------
    ds
        Dataset containing pedigree structure.
    parent
        Input variable name holding parents of each sample as defined by
        :data:`sgkit.variables.parent_spec`.
        If the variable is not present in ``ds``, it will be computed
        using :func:`parent_indices`.
    graph_attrs
        Key-value pairs to pass through to graphviz as graph attributes.
    node_attrs
        Key-value pairs to pass through to graphviz as node attributes.
        Values will be broadcast to have shape (samples, ).
        If no ``"label"`` entry is given and ``ds`` has a ``samples``
        coordinate, that coordinate is used to label the nodes.
    edge_attrs
        Key-value pairs to pass through to graphviz as edge attributes.
        Values will be broadcast to have shape (samples, parents).

    Returns
    -------
    A digraph representation of the pedigree.

    Raises
    ------
    RuntimeError
        If the ``graphviz`` python library cannot be imported.

    Notes
    -----
    The dictionaries passed as ``graph_attrs``, ``node_attrs`` and
    ``edge_attrs`` are copied before use and are never modified.
    """
    try:
        from graphviz import Digraph
    except ImportError as err:  # pragma: no cover
        raise RuntimeError(
            "Visualizing pedigrees requires the `graphviz` python library "
            "and the `graphviz` system library to be installed."
        ) from err

    ds = define_variable_if_absent(ds, variables.parent, parent, parent_indices)
    variables.validate(ds, {parent: variables.parent_spec})
    # Keep the variable *name* (``parent``) distinct from the index array.
    parent_index = ds[parent].values
    n_samples, n_parent_types = parent_index.shape

    # Work on copies: a default label may be inserted below and that must
    # not leak into a dictionary owned by the caller.
    graph_attrs = dict(graph_attrs or {})
    node_attrs = dict(node_attrs or {})
    edge_attrs = dict(edge_attrs or {})

    # default to using samples coordinates for labels
    if ("label" not in node_attrs) and ("samples" in ds.coords):
        node_attrs["label"] = ds.samples.values

    # numpy broadcasting lets callers pass scalars or full arrays alike
    node_attrs = {
        key: np.broadcast_to(value, n_samples) for key, value in node_attrs.items()
    }
    edge_attrs = {
        key: np.broadcast_to(value, parent_index.shape)
        for key, value in edge_attrs.items()
    }

    graph = Digraph()
    graph.attr(**graph_attrs)

    # one node per sample; graphviz attributes are passed as strings
    for i in range(n_samples):
        graph.node(str(i), **{key: str(value[i]) for key, value in node_attrs.items()})

    # one edge per known parent; negative indices are skipped (no parent drawn)
    for i in range(n_samples):
        for j in range(n_parent_types):
            p = parent_index[i, j]
            if p >= 0:
                attrs = {key: str(value[i, j]) for key, value in edge_attrs.items()}
                graph.edge(str(p), str(i), **attrs)
    return graph


# ---------------------------------------------------------------------------
# sgkit/tests/test_display.py
#
# Uses ``display_pedigree`` (imported from ``sgkit``) and ``xarray as xr``;
# ``np`` and ``dedent`` are presumably imported earlier in the test module,
# outside the visible hunk.
# ---------------------------------------------------------------------------


def pedigree_Hamilton_Kerr():
    """Build an 8-sample pedigree dataset with Hamilton-Kerr tau/lambda values."""
    ds = xr.Dataset()
    ds["sample_id"] = "samples", ["S1", "S2", "S3", "S4", "S5", "S6", "S7", "S8"]
    ds["parent_id"] = ["samples", "parents"], [
        [".", "."],
        [".", "."],
        [".", "S2"],
        ["S1", "."],
        ["S1", "S3"],
        ["S1", "S3"],
        ["S6", "S2"],
        ["S6", "S2"],
    ]
    ds["stat_Hamilton_Kerr_tau"] = ["samples", "parents"], [
        [1, 1],
        [2, 2],
        [0, 2],
        [2, 0],
        [1, 1],
        [2, 2],
        [2, 2],
        [2, 2],
    ]
    ds["stat_Hamilton_Kerr_lambda"] = ["samples", "parents"], [
        [0.0, 0.0],
        [0.167, 0.167],
        [0.0, 0.167],
        [0.041, 0.0],
        [0.0, 0.0],
        [0.918, 0.041],
        [0.167, 0.167],
        [0.167, 0.167],
    ]
    return ds


def test_display_pedigree__no_coords():
    # Without a ``samples`` coordinate the nodes carry no label.
    ds = pedigree_Hamilton_Kerr()
    graph = display_pedigree(ds)
    expect = """    digraph {
    \t0
    \t1
    \t2
    \t3
    \t4
    \t5
    \t6
    \t7
    \t1 -> 2
    \t0 -> 3
    \t0 -> 4
    \t2 -> 4
    \t0 -> 5
    \t2 -> 5
    \t5 -> 6
    \t1 -> 6
    \t5 -> 7
    \t1 -> 7
    }
    """
    assert str(graph) == dedent(expect)


def test_display_pedigree__samples_coords():
    # A ``samples`` coordinate is used as the default node label.
    ds = pedigree_Hamilton_Kerr()
    ds = ds.assign_coords(samples=ds.sample_id)
    graph = display_pedigree(ds)
    expect = """    digraph {
    \t0 [label=S1]
    \t1 [label=S2]
    \t2 [label=S3]
    \t3 [label=S4]
    \t4 [label=S5]
    \t5 [label=S6]
    \t6 [label=S7]
    \t7 [label=S8]
    \t1 -> 2
    \t0 -> 3
    \t0 -> 4
    \t2 -> 4
    \t0 -> 5
    \t2 -> 5
    \t5 -> 6
    \t1 -> 6
    \t5 -> 7
    \t1 -> 7
    }
    """
    assert str(graph) == dedent(expect)


def test_display_pedigree__samples_coords_reorder():
    # Node names follow sample position, labels follow the reordered ids.
    ds = pedigree_Hamilton_Kerr()
    ds = ds.sel(samples=[7, 3, 5, 0, 4, 1, 2, 6])
    ds = ds.assign_coords(samples=ds.sample_id)
    graph = display_pedigree(ds)
    expect = """    digraph {
    \t0 [label=S8]
    \t1 [label=S4]
    \t2 [label=S6]
    \t3 [label=S1]
    \t4 [label=S5]
    \t5 [label=S2]
    \t6 [label=S3]
    \t7 [label=S7]
    \t2 -> 0
    \t5 -> 0
    \t3 -> 1
    \t3 -> 2
    \t6 -> 2
    \t3 -> 4
    \t6 -> 4
    \t5 -> 6
    \t2 -> 7
    \t5 -> 7
    }
    """
    assert str(graph) == dedent(expect)


def test_display_pedigree__samples_labels():
    # An explicit ``label`` node attribute works without a coordinate.
    ds = pedigree_Hamilton_Kerr()
    graph = display_pedigree(ds, node_attrs=dict(label=ds.sample_id))
    expect = """    digraph {
    \t0 [label=S1]
    \t1 [label=S2]
    \t2 [label=S3]
    \t3 [label=S4]
    \t4 [label=S5]
    \t5 [label=S6]
    \t6 [label=S7]
    \t7 [label=S8]
    \t1 -> 2
    \t0 -> 3
    \t0 -> 4
    \t2 -> 4
    \t0 -> 5
    \t2 -> 5
    \t5 -> 6
    \t1 -> 6
    \t5 -> 7
    \t1 -> 7
    }
    """
    assert str(graph) == dedent(expect)


def test_display_pedigree__broadcast():
    # Scalar attributes are broadcast; array attributes are applied per item.
    ds = pedigree_Hamilton_Kerr()
    inbreeding = np.array([0.0, 0.077, 0.231, 0.041, 0.0, 0.197, 0.196, 0.196])
    label = (ds.sample_id.str + "\n").str + inbreeding.astype("U")
    edges = xr.where(
        ds.stat_Hamilton_Kerr_tau == 2,
        "black:black",
        "black",
    )
    graph = display_pedigree(
        ds,
        graph_attrs=dict(splines="false", outputorder="edgesfirst"),
        node_attrs=dict(
            style="filled", fillcolor="black", fontcolor="white", label=label
        ),
        edge_attrs=dict(arrowhead="crow", color=edges),
    )
    expect = """    digraph {
    \toutputorder=edgesfirst splines=false
    \t0 [label="S1\n    0.0" fillcolor=black fontcolor=white style=filled]
    \t1 [label="S2\n    0.077" fillcolor=black fontcolor=white style=filled]
    \t2 [label="S3\n    0.231" fillcolor=black fontcolor=white style=filled]
    \t3 [label="S4\n    0.041" fillcolor=black fontcolor=white style=filled]
    \t4 [label="S5\n    0.0" fillcolor=black fontcolor=white style=filled]
    \t5 [label="S6\n    0.197" fillcolor=black fontcolor=white style=filled]
    \t6 [label="S7\n    0.196" fillcolor=black fontcolor=white style=filled]
    \t7 [label="S8\n    0.196" fillcolor=black fontcolor=white style=filled]
    \t1 -> 2 [arrowhead=crow color="black:black"]
    \t0 -> 3 [arrowhead=crow color="black:black"]
    \t0 -> 4 [arrowhead=crow color=black]
    \t2 -> 4 [arrowhead=crow color=black]
    \t0 -> 5 [arrowhead=crow color="black:black"]
    \t2 -> 5 [arrowhead=crow color="black:black"]
    \t5 -> 6 [arrowhead=crow color="black:black"]
    \t1 -> 6 [arrowhead=crow color="black:black"]
    \t5 -> 7 [arrowhead=crow color="black:black"]
    \t1 -> 7 [arrowhead=crow color="black:black"]
    }
    """
    assert str(graph) == dedent(expect)


def test_display_pedigree__attrs_not_mutated():
    # Regression: the default label must not be written into the caller's dict.
    ds = pedigree_Hamilton_Kerr()
    ds = ds.assign_coords(samples=ds.sample_id)
    graph_attrs = dict(splines="false")
    node_attrs = dict(style="filled")
    edge_attrs = dict(arrowhead="crow")
    display_pedigree(
        ds,
        graph_attrs=graph_attrs,
        node_attrs=node_attrs,
        edge_attrs=edge_attrs,
    )
    assert graph_attrs == dict(splines="false")
    assert node_attrs == dict(style="filled")
    assert edge_attrs == dict(arrowhead="crow")