Skip to content

Commit

Permalink
refactor: rename _mk_func_nodes to ensure_func_nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
thorwhalen committed Jul 17, 2024
1 parent ca06ad8 commit 61870fa
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
10 changes: 8 additions & 2 deletions meshed/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Base functionality of meshed
"""

from collections import Counter
from dataclasses import dataclass, field, fields
from functools import partial, cached_property
Expand Down Expand Up @@ -529,7 +530,8 @@ def validate_that_func_node_names_are_sane(func_nodes: Iterable[FuncNode]):
)


def _mk_func_nodes(func_nodes):
def ensure_func_nodes(func_nodes):
"""Converts a list of objects to a list of FuncNodes."""
# TODO: Take care of names (or track and take care if collision)
if callable(func_nodes) and not isinstance(func_nodes, Iterable):
# if input is a single function, make it a list containing that function
Expand All @@ -544,6 +546,9 @@ def _mk_func_nodes(func_nodes):
raise TypeError(f"Can't convert this to a FuncNode: {func_node}")


_mk_func_nodes = ensure_func_nodes # backwards compatibility


def _func_nodes_to_graph_dict(func_nodes):
g = dict()

Expand Down Expand Up @@ -884,7 +889,8 @@ def gen():


def func_node_transformer(
fn: FuncNode, kwargs_transformers=(),
fn: FuncNode,
kwargs_transformers=(),
):
"""Get a modified ``FuncNode`` from an iterable of ``kwargs_trans`` modifiers."""
func_node_kwargs = fn.to_dict()
Expand Down
4 changes: 2 additions & 2 deletions meshed/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@
dflt_configs,
BindInfo,
ch_func_node_func,
_mk_func_nodes,
ensure_func_nodes,
_func_nodes_to_graph_dict,
is_func_node,
FuncNodeAble,
Expand Down Expand Up @@ -503,7 +503,7 @@ class DAG:
)

def __post_init__(self):
self.func_nodes = tuple(_mk_func_nodes(self.func_nodes))
self.func_nodes = tuple(ensure_func_nodes(self.func_nodes))
self.graph = _func_nodes_to_graph_dict(self.func_nodes)
self.nodes = topological_sort(self.graph)
# reorder the nodes to fit topological order
Expand Down

0 comments on commit 61870fa

Please sign in to comment.