Skip to content

Commit

Permalink
Merge pull request #30 from pinellolab/jon_dev5
Browse files Browse the repository at this point in the history
fix markers nnz
  • Loading branch information
j-bac authored Mar 18, 2024
2 parents c55b064 + 8499b37 commit c4e92a7
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 1 deletion.
105 changes: 105 additions & 0 deletions stream2/tools/_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,41 @@ def extend_leaves(
TrimmingRadius=float("inf"),
key="epg",
):
"""Extend leaves with additional nodes
Parameters
-----------
Mode: str, the mode used to extend the graph.
"QuantDists","QuantCentroid", "WeigthedCentroid"
LeafIDs: int vector,
The id of nodes to extend. If None, all the vertices will be extended.
TrimmingRadius: positive numeric
The trimming radius used to control distance
DoSA: bool
Should optimization (via simulated annealing)
be performed when Mode = "QuantDists"?
ControlPar: positive numeric
The parameter used to control the contribution of
the different data points
The value of ControlPar has a different interpretation
depending on the valus of Mode.
In each case, for only the extreme points,
i.e., the points associated with the leaf node that
do not have a projection on any edge are considered.
If Mode = "QuantCentroid", for each leaf node,
the extreme points are ordered by their distance from the node
and the centroid of the points farther away
than ControlPar is returned.
If Mode = "WeightedCentroid", for each leaf node,
a weight is computed for each points
by raising the distance to the ControlPar power.
Hence, larger values of ControlPar result in a larger influence
of points farther from the node
"""

X = _get_graph_data(adata, key)

PG = elpigraph.ExtendLeaves(
Expand All @@ -737,6 +772,76 @@ def extend_leaves(
adata.uns[key]["edge"] = PG["Edges"][0]
_store_graph_attributes(adata, X, key)

def grow_leaves(
adata,
n_nodes=20,
use_weights=False,
epg_mu=None,
epg_lambda=None,
epg_cycle_mu=None,
epg_cycle_lambda=None,
key="epg",
):
"""Grow leaves using elpigraph optimization
Parameters
----------
use_weights: bool
Whether to weight points with adata.obs['pointweights']
shift_nodes_pos: dict
Optional dict to hold some nodes fixed at specified positions
e.g., {2:[.5,.2]} will hold node 2 at coordinates [.5,.2]
epg_mu: float
ElPiGraph Mu parameter
epg_lambda: float
ElPiGraph Lambda parameter
cycle_epg_mu: float
ElPiGraph Mu parameter, specific for nodes that are part of cycles
cycle_epg_lambda: float
ElPiGraph Lambda parameter, specific for nodes that are part of cycles
"""
# --- Init parameters, variables
if epg_mu is None:
epg_mu = adata.uns[key]["params"]["epg_mu"]
if epg_lambda is None:
epg_lambda = adata.uns[key]["params"]["epg_lambda"]
if epg_cycle_mu is None:
epg_cycle_mu = epg_mu
if epg_cycle_lambda is None:
epg_cycle_lambda = epg_lambda
if use_weights:
weights = np.array(adata.obs["pointweights"])[:, None]
else:
weights = None

X = _get_graph_data(adata, key)
PG = elpigraph.GrowLeaves(
X,
NumNodes=n_nodes+len(adata.uns["epg"]["node_pos"]),
InitNodePositions=adata.uns["epg"]["node_pos"],
InitEdges=adata.uns["epg"]["edge"],
PointWeights=weights,
Mu=epg_mu,
Lambda=epg_lambda,
verbose=1,
Do_PCA=False,
CenterData=False
)[0]

adata.uns["epg"]["node_pos"] = PG["NodePositions"]
adata.uns["epg"]["edge"] = PG["Edges"][0]

# update edge_len, conn, data projection
_store_graph_attributes(adata, X, key)

def nodes_info(adata,key='epg'):
'''Return dict of graph nodes classified into leaf, branching, branch
'''
g = elpigraph.src.graphs.ConstructGraph(stream2elpi(adata, key=key))
leaf = np.where(np.array(g.degree()) == 1)[0]
branching = np.where(np.array(g.degree()) > 2)[0]
branch = np.where(np.array(g.degree()) == 2)[0]
return {'leaf':leaf, 'branching':branching, 'branch':branch}

def use_graph_with_n_nodes(adata, n_nodes):
"""Use the graph at n_nodes.
Expand Down
2 changes: 1 addition & 1 deletion stream2/tools/_markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def detect_transition_markers(
+ " cells ..."
)
input_markers_expressed = np.array(input_markers)[
np.where((df_sc[input_markers] > 0).sum(axis=0) > min_num_cells)[0]
np.where((df_sc[input_markers] != 0).sum(axis=0) > min_num_cells)[0]
].tolist()
df_marker_detection = df_sc[input_markers_expressed].copy()

Expand Down

0 comments on commit c4e92a7

Please sign in to comment.