diff --git a/stream2/tools/_graph_utils.py b/stream2/tools/_graph_utils.py index d8fdde1..537ce6c 100644 --- a/stream2/tools/_graph_utils.py +++ b/stream2/tools/_graph_utils.py @@ -720,6 +720,41 @@ def extend_leaves( TrimmingRadius=float("inf"), key="epg", ): + """Extend leaves with additional nodes + + Parameters + ----------- + Mode: str, the mode used to extend the graph. + "QuantDists","QuantCentroid", "WeigthedCentroid" + LeafIDs: int vector, + The id of nodes to extend. If None, all the vertices will be extended. + TrimmingRadius: positive numeric + The trimming radius used to control distance + DoSA: bool + Should optimization (via simulated annealing) + be performed when Mode = "QuantDists"? + ControlPar: positive numeric + The parameter used to control the contribution of + the different data points + + The value of ControlPar has a different interpretation + depending on the valus of Mode. + In each case, for only the extreme points, + i.e., the points associated with the leaf node that + do not have a projection on any edge are considered. + + If Mode = "QuantCentroid", for each leaf node, + the extreme points are ordered by their distance from the node + and the centroid of the points farther away + than ControlPar is returned. + + If Mode = "WeightedCentroid", for each leaf node, + a weight is computed for each points + by raising the distance to the ControlPar power. + Hence, larger values of ControlPar result in a larger influence + of points farther from the node + """ + X = _get_graph_data(adata, key) PG = elpigraph.ExtendLeaves( @@ -737,6 +772,76 @@ def extend_leaves( adata.uns[key]["edge"] = PG["Edges"][0] _store_graph_attributes(adata, X, key) +def grow_leaves( + adata, + n_nodes=20, + use_weights=False, + epg_mu=None, + epg_lambda=None, + epg_cycle_mu=None, + epg_cycle_lambda=None, + key="epg", +): + """Grow leaves using elpigraph optimization + + Parameters + ---------- + use_weights: bool + Whether to weight points with adata.obs['pointweights'] + shift_nodes_pos: dict + Optional dict to hold some nodes fixed at specified positions + e.g., {2:[.5,.2]} will hold node 2 at coordinates [.5,.2] + epg_mu: float + ElPiGraph Mu parameter + epg_lambda: float + ElPiGraph Lambda parameter + cycle_epg_mu: float + ElPiGraph Mu parameter, specific for nodes that are part of cycles + cycle_epg_lambda: float + ElPiGraph Lambda parameter, specific for nodes that are part of cycles + """ + # --- Init parameters, variables + if epg_mu is None: + epg_mu = adata.uns[key]["params"]["epg_mu"] + if epg_lambda is None: + epg_lambda = adata.uns[key]["params"]["epg_lambda"] + if epg_cycle_mu is None: + epg_cycle_mu = epg_mu + if epg_cycle_lambda is None: + epg_cycle_lambda = epg_lambda + if use_weights: + weights = np.array(adata.obs["pointweights"])[:, None] + else: + weights = None + + X = _get_graph_data(adata, key) + PG = elpigraph.GrowLeaves( + X, + NumNodes=n_nodes+len(adata.uns["epg"]["node_pos"]), + InitNodePositions=adata.uns["epg"]["node_pos"], + InitEdges=adata.uns["epg"]["edge"], + PointWeights=weights, + Mu=epg_mu, + Lambda=epg_lambda, + verbose=1, + Do_PCA=False, + CenterData=False + )[0] + + adata.uns["epg"]["node_pos"] = PG["NodePositions"] + adata.uns["epg"]["edge"] = PG["Edges"][0] + + # update edge_len, conn, data projection + _store_graph_attributes(adata, X, key) + +def nodes_info(adata,key='epg'): + '''Return dict of graph nodes classified into leaf, branching, branch + ''' + g = elpigraph.src.graphs.ConstructGraph(stream2elpi(adata, key=key)) + leaf = np.where(np.array(g.degree()) == 1)[0] + branching = np.where(np.array(g.degree()) > 2)[0] + branch = np.where(np.array(g.degree()) == 2)[0] + return {'leaf':leaf, 'branching':branching, 'branch':branch} def use_graph_with_n_nodes(adata, n_nodes): """Use the graph at n_nodes. diff --git a/stream2/tools/_markers.py b/stream2/tools/_markers.py index d57bead..c71816b 100644 --- a/stream2/tools/_markers.py +++ b/stream2/tools/_markers.py @@ -366,7 +366,7 @@ def detect_transition_markers( + " cells ..." ) input_markers_expressed = np.array(input_markers)[ - np.where((df_sc[input_markers] > 0).sum(axis=0) > min_num_cells)[0] + np.where((df_sc[input_markers] != 0).sum(axis=0) > min_num_cells)[0] ].tolist() df_marker_detection = df_sc[input_markers_expressed].copy()