Refactor into read_matrix()
tanghaibao committed Apr 28, 2024
1 parent 3f1c377 commit ad86bf6
Showing 1 changed file with 63 additions and 48 deletions.
111 changes: 63 additions & 48 deletions jcvi/assembly/hic.py
@@ -10,9 +10,11 @@
import os
import os.path as op
import sys

from collections import defaultdict
from functools import partial
from multiprocessing import Pool
from typing import List, Tuple

import numpy as np

@@ -683,6 +685,65 @@ def generate_groups(groupsfile):
yield seqids, color


def read_matrix(
npyfile: str, header: dict, contig: str, groups: List[Tuple[str, str]], opts
):
"""
Read the matrix from the npy file and apply log transformation and thresholding.
"""
# Load the matrix
A = np.load(npyfile)
total_bins = header["total_bins"]

# Select specific submatrix
if contig:
contig_start = header["starts"][contig]
contig_size = header["sizes"][contig]
contig_end = contig_start + contig_size
A = A[contig_start:contig_end, contig_start:contig_end]
else:
A = A[:total_bins, :total_bins]

# Convert seqids to positions for each group
new_groups = []
for seqids, color in groups:
seqids = seqids.split(",")
assert all(
x in header["starts"] for x in seqids
), f"{seqids} contain ids not found in starts"
assert all(
x in header["sizes"] for x in seqids
), f"{seqids} contain ids not found in sizes"
start = min(header["starts"][x] for x in seqids)
end = max(header["starts"][x] + header["sizes"][x] for x in seqids)
position_seqids = []
for seqid in seqids:
seqid_start = header["starts"][seqid]
seqid_size = header["sizes"][seqid]
position_seqids.append((seqid_start + seqid_size / 2, seqid))
new_groups.append((start, end, position_seqids, color))

    # Several concerns in practice:
    # The diagonal counts may be too strong. This can be resolved either by
    # masking the diagonal or by applying a log transform to the entire heatmap.
B = A.astype("float64")
B += 1.0
B = np.log(B)
vmin, vmax = opts.vmin, opts.vmax
B[B < vmin] = vmin
B[B > vmax] = vmax
print(B)
logger.debug("Matrix log-transformation and thresholding (%d-%d) done", vmin, vmax)

breaks = list(header["starts"].values())
breaks += [total_bins] # This is actually discarded
breaks = sorted(breaks)[1:]
if contig or opts.nobreaks:
breaks = []

return B, new_groups, breaks
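
As a side illustration (not part of this commit): read_matrix() only relies on the genome.json header exposing total_bins, starts, and sizes, and on an options object with vmin, vmax, and nobreaks attributes. A minimal sketch of a standalone call, with made-up contig names, bin counts, and group color, and a SimpleNamespace standing in for the parsed CLI options:

```python
# Illustrative only -- the file name, header values, and color are made up.
from types import SimpleNamespace

import numpy as np

header = {
    "resolution": 500_000,
    "total_bins": 300,
    "starts": {"chr1": 0, "chr2": 200},
    "sizes": {"chr1": 200, "chr2": 100},
}
np.save("contacts.npy", np.random.poisson(1.0, size=(300, 300)))  # toy contact matrix
opts = SimpleNamespace(vmin=0, vmax=6, nobreaks=False)

# contig=None keeps the full matrix; "chr1,chr2" is one highlighted group
B, new_groups, breaks = read_matrix(
    "contacts.npy", header, None, [("chr1,chr2", "tomato")], opts
)
```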


def heatmap(args):
"""
%prog heatmap input.npy genome.json
@@ -727,64 +788,18 @@ def heatmap(args):
groups = list(generate_groups(opts.groups)) if opts.groups else []

# Load contig/chromosome starts and sizes
header = json.loads(open(jsonfile).read())
header = json.loads(open(jsonfile, encoding="utf-8").read())
resolution = header.get("resolution")
assert resolution is not None, "`resolution` not found in `{}`".format(jsonfile)
logger.debug("Resolution set to %d", resolution)
# Load the matrix
A = np.load(npyfile)
total_bins = header["total_bins"]

# Select specific submatrix
if contig:
contig_start = header["starts"][contig]
contig_size = header["sizes"][contig]
contig_end = contig_start + contig_size
A = A[contig_start:contig_end, contig_start:contig_end]
else:
A = A[:total_bins, :total_bins]

# Convert seqids to positions for each group
new_groups = []
for seqids, color in groups:
seqids = seqids.split(",")
assert all(
x in header["starts"] for x in seqids
), f"{seqids} contain ids not found in starts"
assert all(
x in header["sizes"] for x in seqids
), f"{seqids} contain ids not found in sizes"
start = min(header["starts"][x] for x in seqids)
end = max(header["starts"][x] + header["sizes"][x] for x in seqids)
position_seqids = []
for seqid in seqids:
seqid_start = header["starts"][seqid]
seqid_size = header["sizes"][seqid]
position_seqids.append((seqid_start + seqid_size / 2, seqid))
new_groups.append((start, end, position_seqids, color))

    # Several concerns in practice:
    # The diagonal counts may be too strong. This can be resolved either by
    # masking the diagonal or by applying a log transform to the entire heatmap.
B = A.astype("float64")
B += 1.0
B = np.log(B)
vmin, vmax = opts.vmin, opts.vmax
B[B < vmin] = vmin
B[B > vmax] = vmax
print(B)
logger.debug("Matrix log-transformation and thresholding (%d-%d) done", vmin, vmax)
B, new_groups, breaks = read_matrix(npyfile, header, contig, groups, opts)

# Canvas
fig = plt.figure(1, (iopts.w, iopts.h))
root = fig.add_axes((0, 0, 1, 1)) # whole canvas
ax = fig.add_axes((0.05, 0.05, 0.9, 0.9)) # just the heatmap

breaks = list(header["starts"].values())
breaks += [total_bins] # This is actually discarded
breaks = sorted(breaks)[1:]
if contig or opts.nobreaks:
breaks = []
plot_heatmap(ax, B, breaks, groups=new_groups, binsize=resolution)

# Title
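One more side note, also not part of the diff: the transform that now lives in read_matrix() (add 1, take the natural log, then clamp to [vmin, vmax]) is equivalent to a clipped log1p, which can be handy when experimenting with thresholds outside the plotting code:

```python
import numpy as np

def log_clip(A: np.ndarray, vmin: float, vmax: float) -> np.ndarray:
    # Same result as B = np.log(A + 1.0) followed by clamping to [vmin, vmax]
    return np.clip(np.log1p(A.astype("float64")), vmin, vmax)
```

np.log1p also behaves slightly better numerically for very small counts.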
