diff --git a/stlearn/spatials/trajectory/__init__.py b/stlearn/spatials/trajectory/__init__.py index 0a1dc6c..bd6c482 100644 --- a/stlearn/spatials/trajectory/__init__.py +++ b/stlearn/spatials/trajectory/__init__.py @@ -11,4 +11,4 @@ from .compare_transitions import compare_transitions from .set_root import set_root -from .shortest_path_spatial_PAGA import shortest_path_spatial_PAGA \ No newline at end of file +from .shortest_path_spatial_PAGA import shortest_path_spatial_PAGA diff --git a/stlearn/spatials/trajectory/global_level.py b/stlearn/spatials/trajectory/global_level.py index 60f3965..6a89823 100644 --- a/stlearn/spatials/trajectory/global_level.py +++ b/stlearn/spatials/trajectory/global_level.py @@ -19,7 +19,6 @@ def global_level( verbose: bool = True, copy: bool = False, ) -> Optional[AnnData]: - """\ Perform global sptial trajectory inference. @@ -152,7 +151,6 @@ def global_level( labels = nx.get_edge_attributes(H_sub, "weight") for edge, _ in labels.items(): - dm = dm_list[order_big_dict[query_dict[edge[0]]]] sdm = sdm_list[order_big_dict[query_dict[edge[0]]]] @@ -160,7 +158,11 @@ def global_level( order_dict[edge[0]], order_dict[edge[1]] ] * (1 - w) H_sub[edge[0]][edge[1]]["weight"] = weight - # tmp = H_sub + + # Set edges with weight=None to weight=0 + for u, v, tmp in H_sub.edges(data=True): + if tmp.get("weight") is None: + H_sub[u][v]["weight"] = 0 H_sub = nx.algorithms.tree.minimum_spanning_arborescence(H_sub) H_nodes = list(range(len(H_sub.nodes))) @@ -236,7 +238,6 @@ def ordering_nodes(node_list, use_label, adata): def spatial_distance_matrix(adata, cluster1, cluster2, use_label): - tmp = adata.obs[adata.obs[use_label] == str(cluster1)] chosen_adata1 = adata[list(tmp.index)] tmp = adata.obs[adata.obs[use_label] == str(cluster2)] @@ -267,7 +268,6 @@ def spatial_distance_matrix(adata, cluster1, cluster2, use_label): def ge_distance_matrix(adata, cluster1, cluster2, use_label, use_rep, n_dims): - tmp = adata.obs[adata.obs[use_label] == str(cluster1)] chosen_adata1 = adata[list(tmp.index)] tmp = adata.obs[adata.obs[use_label] == str(cluster2)] diff --git a/stlearn/spatials/trajectory/pseudotime.py b/stlearn/spatials/trajectory/pseudotime.py index f0f674b..0c9df49 100644 --- a/stlearn/spatials/trajectory/pseudotime.py +++ b/stlearn/spatials/trajectory/pseudotime.py @@ -246,7 +246,6 @@ def store_available_paths(adata, threshold, use_label, max_nodes, pseudotime_key if len(path) < max_nodes: all_paths[str(i) + "_" + str(source) + "_" + str(target)] = path - adata.uns["available_paths"] = all_paths print( "All available trajectory paths are stored in adata.uns['available_paths'] with length < " diff --git a/stlearn/spatials/trajectory/shortest_path_spatial_PAGA.py b/stlearn/spatials/trajectory/shortest_path_spatial_PAGA.py index 6340d27..bfd6b35 100644 --- a/stlearn/spatials/trajectory/shortest_path_spatial_PAGA.py +++ b/stlearn/spatials/trajectory/shortest_path_spatial_PAGA.py @@ -2,16 +2,21 @@ import numpy as np from stlearn.utils import _read_graph -def shortest_path_spatial_PAGA(adata,use_label,key="dpt_pseudotime",): + +def shortest_path_spatial_PAGA( + adata, + use_label, + key="dpt_pseudotime", +): # Read original PAGA graph G = nx.from_numpy_array(adata.uns["paga"]["connectivities"].toarray()) edge_weights = nx.get_edge_attributes(G, "weight") - G.remove_edges_from((e for e, w in edge_weights.items() if w <0)) + G.remove_edges_from((e for e, w in edge_weights.items() if w < 0)) H = G.to_directed() - + # Get min_node and max_node - min_node,max_node = find_min_max_node(adata,key,use_label) - + min_node, max_node = find_min_max_node(adata, key, use_label) + # Calculate pseudotime for each node node_pseudotime = {} @@ -26,7 +31,7 @@ def shortest_path_spatial_PAGA(adata,use_label,key="dpt_pseudotime",): if node_pseudotime[edge[0]] - node_pseudotime[edge[1]] > 0: edge_to_remove.append(edge) H.remove_edges_from(edge_to_remove) - + # Extract all available paths all_paths = {} j = 0 @@ -34,36 +39,37 @@ def shortest_path_spatial_PAGA(adata,use_label,key="dpt_pseudotime",): for target in H.nodes: paths = nx.all_simple_paths(H, source=source, target=target) for i, path in enumerate(paths): - j+=1 + j += 1 all_paths[j] = path - + # Filter the target paths from min_node to max_node target_paths = [] for path in list(all_paths.values()): if path[0] == min_node and path[-1] == max_node: target_paths.append(path) - + # Get the global graph G = _read_graph(adata, "global_graph") - + centroid_dict = adata.uns["centroid_dict"] centroid_dict = {int(key): centroid_dict[key] for key in centroid_dict} - + # Generate total length of every path. Store by dictionary dist_dict = {} for path in target_paths: - path_name = ",".join(list(map(str,path))) + path_name = ",".join(list(map(str, path))) result = [] query_node = get_node(path, adata.uns["split_node"]) for edge in G.edges(): if (edge[0] in query_node) and (edge[1] in query_node): result.append(edge) if len(result) >= len(path): - dist_dict[path_name] = calculate_total_dist(result,centroid_dict) - + dist_dict[path_name] = calculate_total_dist(result, centroid_dict) + # Find the shortest path shortest_path = min(dist_dict, key=lambda x: dist_dict[x]) - return shortest_path.split(',') + return shortest_path.split(",") + # get name of cluster by subcluster def get_cluster(search, dictionary): @@ -71,24 +77,28 @@ def get_cluster(search, dictionary): if search in sub: return cl + def get_node(node_list, split_node): result = np.array([]) for node in node_list: result = np.append(result, np.array(split_node[int(node)]).astype(int)) return result.astype(int) -def find_min_max_node(adata,key="dpt_pseudotime",use_label="leiden"): - min_cluster = int(adata.obs[adata.obs[key]==0][use_label].values[0]) - max_cluster = int(adata.obs[adata.obs[key]==1][use_label].values[0]) - - return [min_cluster,max_cluster] -def calculate_total_dist(result,centroid_dict): +def find_min_max_node(adata, key="dpt_pseudotime", use_label="leiden"): + min_cluster = int(adata.obs[adata.obs[key] == 0][use_label].values[0]) + max_cluster = int(adata.obs[adata.obs[key] == 1][use_label].values[0]) + + return [min_cluster, max_cluster] + + +def calculate_total_dist(result, centroid_dict): import math + total_dist = 0 for edge in result: source = centroid_dict[edge[0]] target = centroid_dict[edge[1]] - dist =math.dist(source,target) + dist = math.dist(source, target) total_dist += dist - return total_dist \ No newline at end of file + return total_dist