Skip to content

Commit

Permalink
Add display_pedigree method #1097
Browse files Browse the repository at this point in the history
  • Loading branch information
timothymillar authored and jeromekelleher committed Sep 8, 2023
1 parent c897aff commit b532598
Show file tree
Hide file tree
Showing 4 changed files with 271 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ Utilities
convert_call_to_index
convert_probability_to_call
display_genotypes
display_pedigree
filter_partial_calls
infer_call_ploidy
infer_sample_ploidy
Expand Down
3 changes: 2 additions & 1 deletion sgkit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pkg_resources import DistributionNotFound, get_distribution # type: ignore[import]

from .display import display_genotypes
from .display import display_genotypes, display_pedigree
from .distance.api import pairwise_distance
from .io.dataset import load_dataset, save_dataset
from .io.vcfzarr_reader import read_scikit_allel_vcfzarr
Expand Down Expand Up @@ -88,6 +88,7 @@
"count_variant_genotypes",
"create_genotype_dosage_dataset",
"display_genotypes",
"display_pedigree",
"filter_partial_calls",
"genee",
"genomic_relationship",
Expand Down
79 changes: 78 additions & 1 deletion sgkit/display.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Any, Hashable, Mapping, Tuple
from typing import Any, Dict, Hashable, Mapping, Optional, Tuple

import numpy as np
import pandas as pd
import xarray as xr

from sgkit import variables
from sgkit.stats.pedigree import parent_indices
from sgkit.typing import ArrayLike
from sgkit.utils import define_variable_if_absent


class GenotypeDisplay:
Expand Down Expand Up @@ -209,3 +212,77 @@ def display_genotypes(
max_variants,
max_samples,
)


def display_pedigree(
ds: xr.Dataset,
parent: Hashable = variables.parent,
graph_attrs: Optional[Dict[Hashable, str]] = None,
node_attrs: Optional[Dict[Hashable, ArrayLike]] = None,
edge_attrs: Optional[Dict[Hashable, ArrayLike]] = None,
) -> Any:
"""Display a pedigree dataset as a directed acyclic graph.
Parameters
----------
ds
Dataset containing pedigree structure.
parent
Input variable name holding parents of each sample as defined by
:data:`sgkit.variables.parent_spec`.
If the variable is not present in ``ds``, it will be computed
using :func:`parent_indices`.
graph_attrs
Key-value pairs to pass through to graphviz as graph attributes.
node_attrs
Key-value pairs to pass through to graphviz as node attributes.
Values will be broadcast to have shape (samples, ).
edge_attrs
Key-value pairs to pass through to graphviz as edge attributes.
Values will be broadcast to have shape (samples, parents).
Raises
------
RuntimeError
If the `Graphviz library <https://graphviz.readthedocs.io/en/stable/>`_ is not installed.
Returns
-------
A digraph representation of the pedigree.
"""
try:
from graphviz import Digraph
except ImportError: # pragma: no cover
raise RuntimeError(
"Visualizing pedigrees requires the `graphviz` python library and the `graphviz` system library to be installed."
)
ds = define_variable_if_absent(ds, variables.parent, parent, parent_indices)
variables.validate(ds, {parent: variables.parent_spec})
parent = ds[parent].values
n_samples, n_parent_types = parent.shape
graph_attrs = graph_attrs or {}
node_attrs = node_attrs or {}
edge_attrs = edge_attrs or {}
# default to using samples coordinates for labels
if ("label" not in node_attrs) and ("samples" in ds.coords):
node_attrs["label"] = ds.samples.values
# numpy broadcasting
node_attrs = {k: np.broadcast_to(v, n_samples) for k, v in node_attrs.items()}
edge_attrs = {k: np.broadcast_to(v, parent.shape) for k, v in edge_attrs.items()}
# initialize graph
graph = Digraph()
graph.attr(**graph_attrs)
# add nodes
for i in range(n_samples):
d = {k: str(v[i]) for k, v in node_attrs.items()}
graph.node(str(i), **d)
# add edges
for i in range(n_samples):
for j in range(n_parent_types):
p = parent[i, j]
if p >= 0:
d = {}
for k, v in edge_attrs.items():
d[k] = str(v[i, j])
graph.edge(str(p), str(i), **d)
return graph
191 changes: 190 additions & 1 deletion sgkit/tests/test_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import xarray as xr

from sgkit import display_genotypes
from sgkit import display_genotypes, display_pedigree
from sgkit.display import genotype_as_bytes
from sgkit.testing import simulate_genotype_call_dataset

Expand Down Expand Up @@ -417,3 +417,192 @@ def test_genotype_as_bytes(genotype, phased, max_allele_chars, expect):
expect,
genotype_as_bytes(genotype, phased, max_allele_chars),
)


def pedigree_Hamilton_Kerr():
ds = xr.Dataset()
ds["sample_id"] = "samples", ["S1", "S2", "S3", "S4", "S5", "S6", "S7", "S8"]
ds["parent_id"] = ["samples", "parents"], [
[".", "."],
[".", "."],
[".", "S2"],
["S1", "."],
["S1", "S3"],
["S1", "S3"],
["S6", "S2"],
["S6", "S2"],
]
ds["stat_Hamilton_Kerr_tau"] = ["samples", "parents"], [
[1, 1],
[2, 2],
[0, 2],
[2, 0],
[1, 1],
[2, 2],
[2, 2],
[2, 2],
]
ds["stat_Hamilton_Kerr_lambda"] = ["samples", "parents"], [
[0.0, 0.0],
[0.167, 0.167],
[0.0, 0.167],
[0.041, 0.0],
[0.0, 0.0],
[0.918, 0.041],
[0.167, 0.167],
[0.167, 0.167],
]
return ds


def test_display_pedigree__no_coords():
ds = pedigree_Hamilton_Kerr()
graph = display_pedigree(ds)
expect = """ digraph {
\t0
\t1
\t2
\t3
\t4
\t5
\t6
\t7
\t1 -> 2
\t0 -> 3
\t0 -> 4
\t2 -> 4
\t0 -> 5
\t2 -> 5
\t5 -> 6
\t1 -> 6
\t5 -> 7
\t1 -> 7
}
"""
assert str(graph) == dedent(expect)


def test_display_pedigree__samples_coords():
ds = pedigree_Hamilton_Kerr()
ds = ds.assign_coords(samples=ds.sample_id)
graph = display_pedigree(ds)
expect = """ digraph {
\t0 [label=S1]
\t1 [label=S2]
\t2 [label=S3]
\t3 [label=S4]
\t4 [label=S5]
\t5 [label=S6]
\t6 [label=S7]
\t7 [label=S8]
\t1 -> 2
\t0 -> 3
\t0 -> 4
\t2 -> 4
\t0 -> 5
\t2 -> 5
\t5 -> 6
\t1 -> 6
\t5 -> 7
\t1 -> 7
}
"""
assert str(graph) == dedent(expect)


def test_display_pedigree__samples_coords_reorder():
ds = pedigree_Hamilton_Kerr()
ds = ds.sel(samples=[7, 3, 5, 0, 4, 1, 2, 6])
ds = ds.assign_coords(samples=ds.sample_id)
graph = display_pedigree(ds)
expect = """ digraph {
\t0 [label=S8]
\t1 [label=S4]
\t2 [label=S6]
\t3 [label=S1]
\t4 [label=S5]
\t5 [label=S2]
\t6 [label=S3]
\t7 [label=S7]
\t2 -> 0
\t5 -> 0
\t3 -> 1
\t3 -> 2
\t6 -> 2
\t3 -> 4
\t6 -> 4
\t5 -> 6
\t2 -> 7
\t5 -> 7
}
"""
assert str(graph) == dedent(expect)


def test_display_pedigree__samples_labels():
ds = pedigree_Hamilton_Kerr()
graph = display_pedigree(ds, node_attrs=dict(label=ds.sample_id))
expect = """ digraph {
\t0 [label=S1]
\t1 [label=S2]
\t2 [label=S3]
\t3 [label=S4]
\t4 [label=S5]
\t5 [label=S6]
\t6 [label=S7]
\t7 [label=S8]
\t1 -> 2
\t0 -> 3
\t0 -> 4
\t2 -> 4
\t0 -> 5
\t2 -> 5
\t5 -> 6
\t1 -> 6
\t5 -> 7
\t1 -> 7
}
"""
assert str(graph) == dedent(expect)


def test_display_pedigree__broadcast():
ds = pedigree_Hamilton_Kerr()
inbreeding = np.array([0.0, 0.077, 0.231, 0.041, 0.0, 0.197, 0.196, 0.196])
label = (ds.sample_id.str + "\n").str + inbreeding.astype("U")
edges = xr.where(
ds.stat_Hamilton_Kerr_tau == 2,
"black:black",
"black",
)
graph = display_pedigree(
ds,
graph_attrs=dict(splines="false", outputorder="edgesfirst"),
node_attrs=dict(
style="filled", fillcolor="black", fontcolor="white", label=label
),
edge_attrs=dict(arrowhead="crow", color=edges),
)
expect = """ digraph {
\toutputorder=edgesfirst splines=false
\t0 [label="S1\n 0.0" fillcolor=black fontcolor=white style=filled]
\t1 [label="S2\n 0.077" fillcolor=black fontcolor=white style=filled]
\t2 [label="S3\n 0.231" fillcolor=black fontcolor=white style=filled]
\t3 [label="S4\n 0.041" fillcolor=black fontcolor=white style=filled]
\t4 [label="S5\n 0.0" fillcolor=black fontcolor=white style=filled]
\t5 [label="S6\n 0.197" fillcolor=black fontcolor=white style=filled]
\t6 [label="S7\n 0.196" fillcolor=black fontcolor=white style=filled]
\t7 [label="S8\n 0.196" fillcolor=black fontcolor=white style=filled]
\t1 -> 2 [arrowhead=crow color="black:black"]
\t0 -> 3 [arrowhead=crow color="black:black"]
\t0 -> 4 [arrowhead=crow color=black]
\t2 -> 4 [arrowhead=crow color=black]
\t0 -> 5 [arrowhead=crow color="black:black"]
\t2 -> 5 [arrowhead=crow color="black:black"]
\t5 -> 6 [arrowhead=crow color="black:black"]
\t1 -> 6 [arrowhead=crow color="black:black"]
\t5 -> 7 [arrowhead=crow color="black:black"]
\t1 -> 7 [arrowhead=crow color="black:black"]
}
"""
assert str(graph) == dedent(expect)

0 comments on commit b532598

Please sign in to comment.