Skip to content

Commit

Permalink
Merge pull request #345 from bacpop/network_relabelling
Browse files Browse the repository at this point in the history
Updates to assigning and visualisation for beebop and mandrake
  • Loading branch information
nickjcroucher authored Jan 23, 2025
2 parents 7459bb1 + 1d74b1e commit e8d7921
Show file tree
Hide file tree
Showing 8 changed files with 211 additions and 85 deletions.
2 changes: 1 addition & 1 deletion PopPUNK/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

'''PopPUNK (POPulation Partitioning Using Nucleotide Kmers)'''

__version__ = '2.7.3'
__version__ = '2.7.4'

# Minimum sketchlib version
SKETCHLIB_MAJOR = 2
Expand Down
2 changes: 1 addition & 1 deletion PopPUNK/assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def get_options():
oGroup.add_argument('--update-db', help='Update reference database with query sequences', default=False, action='store_true')
oGroup.add_argument('--overwrite', help='Overwrite any existing database files', default=False, action='store_true')
oGroup.add_argument('--graph-weights', help='Save within-strain Euclidean distances into the graph', default=False, action='store_true')
oGroup.add_argument('--save-partial-query-graph', help='Save the network components to which queries are assigned', default=False, action='store_true')
oGroup.add_argument('--save-partial-query-graph', help='Save only the network components to which queries are assigned', default=False, action='store_true')

# comparison metrics
kmerGroup = parser.add_argument_group('Kmer comparison options')
Expand Down
6 changes: 3 additions & 3 deletions PopPUNK/mandrake.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pp_sketchlib
from SCE import wtsne
try:
from SCE import wtsne_gpu_fp64
from SCE import wtsne_gpu_fp32
gpu_fn_available = True
except ImportError:
gpu_fn_available = False
Expand Down Expand Up @@ -63,7 +63,7 @@ def generate_embedding(seqLabels, accMat, perplexity, outPrefix, overwrite, kNN
sys.stderr.write("Mandrake analysis already exists; add --overwrite to replace\n")
else:
sys.stderr.write("Running mandrake\n")
kNN = max(kNN, len(seqLabels) - 1)
kNN = min(kNN, len(seqLabels) - 1)
I, J, dists = poppunk_refine.get_kNN_distances(accMat, kNN, 1, n_threads)

# Set up function call with either CPU or GPU
Expand All @@ -76,7 +76,7 @@ def generate_embedding(seqLabels, accMat, perplexity, outPrefix, overwrite, kNN
sys.stderr.write("Running on GPU\n")
n_workers = 65536
maxIter = round(maxIter / n_workers)
wtsne_call = partial(wtsne_gpu_fp64,
wtsne_call = partial(wtsne_gpu_fp32,
perplexity=perplexity,
maxIter=maxIter,
blockSize=128,
Expand Down
79 changes: 79 additions & 0 deletions PopPUNK/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from multiprocessing import Pool
import pickle
import graph_tool.all as gt
import pp_sketchlib

# Load GPU libraries
try:
Expand Down Expand Up @@ -2016,3 +2017,81 @@ def remove_non_query_components(G, rlist, qlist, use_gpu = False):
query_subgraph = gt.GraphView(G, vfilt=query_filter)

return query_subgraph, pruned_names

def generate_network_from_distances(mode,
                                    model,
                                    core_distMat = None,
                                    acc_distMat = None,
                                    sparse_mat = None,
                                    previous_mst = None,
                                    combined_seq = None,
                                    rlist = None,
                                    old_rlist = None,
                                    distance_type = 'core',
                                    threads = 1,
                                    gpu_graph = False):
    """
    Generates a network from a distance matrix.

    Args:
        mode (str)
            Whether a 'dense' or 'sparse' distance matrix is being analysed
        model (ClusterFit or LineageFit)
            A fitted model object
        core_distMat (numpy.array)
            NxN array of core distances for N sequences (dense mode)
        acc_distMat (numpy.array)
            NxN array of accessory distances for N sequences (dense mode)
        sparse_mat (scipy or cupyx sparse matrix)
            Sparse matrix of kNN from lineage fit (sparse mode)
        previous_mst (str or graph object)
            Path of file containing existing network, or already-loaded
            graph object
        combined_seq (list)
            Ordered list of isolate names
        rlist (list)
            List of reference sequence labels
        old_rlist (list)
            List of reference sequence labels for previous MST
        distance_type (str)
            Whether to use core or accessory distances for MST calculation
            or dense network weighting
        threads (int)
            Number of threads to use in calculations
        gpu_graph (bool)
            Whether to use GPUs for network construction

    Returns:
        G (graph)
            The resulting network
    """
    if mode == 'sparse':
        G = generate_mst_from_sparse_input(sparse_mat,
                                           rlist,
                                           old_rlist = old_rlist,
                                           previous_mst = previous_mst,
                                           gpu_graph = gpu_graph)
    elif mode == 'dense':
        # Flatten the two square matrices into the two-column long form
        # the model's assign method expects
        complete_distMat = \
            np.hstack((pp_sketchlib.squareToLong(core_distMat, threads).reshape(-1, 1),
                       pp_sketchlib.squareToLong(acc_distMat, threads).reshape(-1, 1)))
        # Identify short distances and use these to extend the model
        indivAssignments = model.assign(complete_distMat)
        G = construct_network_from_assignments(combined_seq,
                                               combined_seq,
                                               indivAssignments,
                                               model.within_label,
                                               distMat = complete_distMat,
                                               weights_type = distance_type,
                                               use_gpu = gpu_graph,
                                               summarise = False)
        if gpu_graph:
            # cugraph graphs require an explicit MST reduction here
            G = cugraph.minimum_spanning_tree(G, weight='weights')
    else:
        # Fail fast on an unrecognised mode: previously execution fell
        # through to `return G` with G unbound (UnboundLocalError)
        sys.stderr.write('Unknown network mode - expect dense or sparse\n')
        sys.exit(1)

    return G
40 changes: 19 additions & 21 deletions PopPUNK/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,27 +558,26 @@ def outputsForCytoscape(G, G_mst, isolate_names, clustering, outPrefix, epiCsv,
if use_partial_query_graph is None:
save_network(G, prefix = outPrefix, suffix = suffix, use_graphml = True)

# Save each component too (useful for very large graphs)
# Store query names
querySet = frozenset(queryList) if queryList is not None else frozenset()

# Save each cluster too (useful for very large graphs)
example_cluster_title = list(clustering.keys())[0]
component_assignments, component_hist = gt.label_components(G)
for component_idx in range(len(component_hist)):
# Naming must reflect the full graph size
component_name = component_idx + 1
get_component_name = (use_partial_query_graph is not None)
if use_partial_query_graph is not None:
represented_clusters = set(clustering[example_cluster_title][isolate] for isolate in isolate_names)
else:
represented_clusters = set(clustering[example_cluster_title].values())
for cluster in represented_clusters:
# Filter the graph for the current component
comp_filter = G.new_vertex_property("bool")
for v in G.vertices():
comp_filter[v] = (component_assignments[v] == component_idx)
# If using partial query graph find the component name from the clustering
if get_component_name and comp_filter[v]:
example_isolate_name = seqLabels[int(v)]
component_name = clustering[example_cluster_title][example_isolate_name]
get_component_name = False
vertex_name = seqLabels[int(v)]
comp_filter[v] = (clustering[example_cluster_title][vertex_name] == cluster)
G_component = gt.GraphView(G, vfilt=comp_filter)
# Purge the component to remove unreferenced vertices (optional but recommended)
G_component.purge_vertices()
# Save the component network
save_network(G_component, prefix = outPrefix, suffix = "_component_" + str(component_name), use_graphml = True)
save_network(G_component, prefix = outPrefix, suffix = "_component_" + str(cluster), use_graphml = True)

if G_mst != None:
isolate_labels = isolateNameToLabel(G_mst.vp.id)
Expand Down Expand Up @@ -730,14 +729,13 @@ def writeClusterCsv(outfile, nodeNames, nodeLabels, clustering,
d['Status'].append("Reference")
if epiCsv is not None:
if label in epiData.index:
if label in epiData.index:
for col, value in zip(epiData.columns.values, epiData.loc[[label]].iloc[0].values):
if col not in columns_to_be_omitted:
d[col].append(str(value))
else:
for col in epiData.columns.values:
if col not in columns_to_be_omitted:
d[col].append('nan')
for col, value in zip(epiData.columns.values, epiData.loc[[label]].iloc[0].values):
if col not in columns_to_be_omitted:
d[col].append(str(value))
else:
for col in epiData.columns.values:
if col not in columns_to_be_omitted:
d[col].append('')

else:
sys.stderr.write("Cannot find " + name + " in clustering\n")
Expand Down
Loading

0 comments on commit e8d7921

Please sign in to comment.