Commit 29c0ef9
wip
mlin committed Apr 6, 2024
1 parent 892ed58 commit 29c0ef9
Showing 4 changed files with 31 additions and 25 deletions.
11 changes: 4 additions & 7 deletions WDL/Error.py
@@ -69,9 +69,6 @@ def __init__(self, pos: SourcePosition, import_uri: str, message: Optional[str]
         self.pos = pos
 
 
-TVSourceNode = TypeVar("TVSourceNode", bound="SourceNode")
-
-
 @total_ordering
 class SourceNode:
     """Base class for an AST node, recording the source position"""
@@ -83,7 +80,7 @@ class SourceNode:
     Source position for this AST node
     """
 
-    parent: Optional[TVSourceNode] = None
+    parent: Optional["SourceNode"] = None
     """
     :type: Optional[SourceNode]
@@ -93,7 +90,7 @@ class SourceNode:
     def __init__(self, pos: SourcePosition) -> None:
         self.pos = pos
 
-    def __lt__(self, rhs: TVSourceNode) -> bool:
+    def __lt__(self, rhs: "SourceNode") -> bool:
         if isinstance(rhs, SourceNode):
             return (
                 self.pos.abspath,
@@ -110,12 +107,12 @@ def __lt__(self, rhs: TVSourceNode) -> bool:
             )
         return False
 
-    def __eq__(self, rhs: TVSourceNode) -> bool:
+    def __eq__(self, rhs: "SourceNode") -> bool:
         assert isinstance(rhs, SourceNode)
         return self.pos == rhs.pos
 
     @property
-    def children(self: TVSourceNode) -> Iterable[TVSourceNode]:
+    def children(self) -> Iterable["SourceNode"]:
         """
         :type: Iterable[SourceNode]
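The Error.py hunks replace the module-level TypeVar with quoted forward references. As a rough illustration of that pattern (not taken from this commit; the class name Node is invented), a quoted class name in an annotation can refer to the class from inside its own body, with no TypeVar needed:

    # Minimal sketch of the forward-reference pattern adopted above.
    # "Node" is a hypothetical stand-in for SourceNode, not miniwdl code.
    from typing import Iterable, Optional

    class Node:
        parent: Optional["Node"] = None

        def __lt__(self, rhs: "Node") -> bool:
            # placeholder ordering purely for illustration
            return id(self) < id(rhs)

        @property
        def children(self) -> Iterable["Node"]:
            return []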
4 changes: 3 additions & 1 deletion WDL/__init__.py
@@ -287,7 +287,9 @@ def values_from_json(
     return ans
 
 
-def values_to_json(values_env: Env.Bindings[Union[Value.Base, Tree.Decl, Type.Base]], namespace: str = "") -> Dict[str, Any]:
+def values_to_json(
+    values_env: Env.Bindings[Union[Value.Base, Tree.Decl, Type.Base]], namespace: str = ""
+) -> Dict[str, Any]:
     """
     Convert a ``WDL.Env.Bindings[WDL.Value.Base]`` to a dict which ``json.dumps`` to
     Cromwell-style JSON.
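The only change to values_to_json is a line-length reformat of its signature. For orientation, a rough usage sketch (assuming the public WDL.Env and WDL.Value APIs used elsewhere in this diff; the binding names and printed output are illustrative, not from the commit):

    # Hedged sketch of calling WDL.values_to_json; the env contents are invented.
    import json
    import WDL
    from WDL import Env, Value

    env: Env.Bindings[Value.Base] = Env.Bindings()
    env = env.bind("threads", Value.Int(4))
    env = env.bind("sample", Value.String("NA12878"))

    # namespace= prefixes each key, producing Cromwell-style input/output JSON
    print(json.dumps(WDL.values_to_json(env, namespace="my_wf"), indent=2))
    # Expected shape (illustrative): {"my_wf.threads": 4, "my_wf.sample": "NA12878"}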
38 changes: 21 additions & 17 deletions WDL/runtime/workflow.py
@@ -172,7 +172,7 @@ def __init__(
 
         from .. import values_to_json
 
-        self.values_to_json = values_to_json  # pyre-ignore
+        self.values_to_json = values_to_json  # type: ignore
 
         # Preprocess inputs: if None value is supplied for an input declared with a default but
         # without the ? type quantifier, remove the binding entirely so that the default will be
@@ -264,7 +264,7 @@ def step(
         doing so after initialization and after each ``call_finished()`` invocation, until at last
         the workflow outputs are available.
         """
-        runnable = []
+        runnable: List[str] = []
         while True:
             # select a job whose dependencies are all finished
             if not runnable:
@@ -346,7 +346,7 @@ def _do_job(
 
         # for all non-Gather nodes, derive the environment by merging the outputs of all the
         # dependencies (+ any current scatter variable bindings)
-        scatter_vars = Env.Bindings()
+        scatter_vars: Env.Bindings[Value.Base] = Env.Bindings()
         for p in job.scatter_stack:
             scatter_vars = Env.Bindings(p[1], scatter_vars)
         # pyre-ignore
@@ -396,11 +396,12 @@ def _do_job(
 
         if isinstance(job.node, Tree.Call):
             # evaluate input expressions
-            call_inputs = Env.Bindings()
+            call_name = job.node.name
+            call_inputs: Env.Bindings[Value.Base] = Env.Bindings()
             for name, expr in job.node.inputs.items():
                 call_inputs = call_inputs.bind(name, expr.eval(env, stdlib=stdlib))
             # check workflow inputs for additional inputs supplied to this call
-            for b in self.inputs.enter_namespace(job.node.name):
+            for b in self.inputs.enter_namespace(call_name):
                 call_inputs = call_inputs.bind(b.name, b.value)
 
             # coerce inputs to required types (treating inputs with defaults as optional even if
@@ -426,7 +427,7 @@ def _do_job(
             call_inputs = Value.rewrite_env_paths(
                 call_inputs,
                 lambda v: _check_path_allowed(
-                    cfg, self.fspath_allowlist, f"call {job.node.name} input", v
+                    cfg, self.fspath_allowlist, f"call {call_name} input", v
                 ),
             )
             # issue CallInstructions
@@ -470,7 +471,7 @@ def _scatter(
     ) -> Iterable[_Job]:
         # we'll be tracking, for each body node ID, the IDs of the potentially multiple corresponding
         # jobs scheduled
-        multiplex = {}
+        multiplex: Dict[str, Set[str]] = {}
         for body_node in section.body:
             multiplex[body_node.workflow_node_id] = set()
             if isinstance(body_node, Tree.WorkflowSection):
@@ -479,10 +480,11 @@ def _scatter(
 
         # evaluate scatter array or boolean condition
         v = section.expr.eval(env, stdlib=stdlib)
-        array = []
+        array: List[Optional[Value.Base]] = []
         if isinstance(section, Tree.Scatter):
             assert isinstance(v, Value.Array)
-            array = v.value
+            for v_i in v.value:
+                array.append(v_i)
         else:
             assert isinstance(v, Value.Boolean)
             if v.value:
@@ -653,7 +655,7 @@ def _gather(
             assert False
 
         # for each such name,
-        ans = Env.Bindings()
+        ans: Env.Bindings[Value.Base] = Env.Bindings()
         ns = [leaf.name] if isinstance(leaf, Tree.Call) else []
         for name in names:
             # gather the corresponding values
@@ -662,7 +664,7 @@ def _gather(
             assert v0 is None or isinstance(v0, Value.Base)
             # bind the array, singleton value, or None as appropriate
             if isinstance(gather.section, Tree.Scatter):
-                rhs = Value.Array((v0.type if v0 else Type.Any()), values)
+                rhs: Value.Base = Value.Array((v0.type if v0 else Type.Any()), values)
             else:
                 assert isinstance(gather.section, Tree.Conditional)
                 assert len(values) <= 1
@@ -707,17 +709,19 @@ def _virtualize_filename(self, filename: str) -> str:
         with open(filename, "rb") as f:
             for chunk in iter(lambda: f.read(1048576), b""):
                 hasher.update(chunk)
-        cache_in = Env.Bindings().bind("file_sha256", Value.String(hasher.hexdigest()))
+        cache_in: Env.Bindings[Value.Base] = Env.Bindings()
+        cache_in = cache_in.bind("file_sha256", Value.String(hasher.hexdigest()))
         cache_key = "write_/" + Value.digest_env(cache_in)
-        cache_out_types = Env.Bindings().bind("file", Type.File())
+        cache_out_types: Env.Bindings[Type.Base] = Env.Bindings()
+        cache_out_types = cache_out_types.bind("file", Type.File())
         cache_out = self.cache.get(cache_key, cache_in, cache_out_types)
         if cache_out:
             filename = cache_out.resolve("file").value
         else:
             # otherwise, put our newly-written file to the cache, and proceed to use it
             self.cache.put(
                 cache_key,
-                Env.Bindings().bind("file", Value.File(filename)),
+                Env.Bindings(Env.Binding("file", Value.File(filename))),
                 run_dir=self.state.run_dir,
)

Expand Down Expand Up @@ -873,7 +877,7 @@ def run_local_workflow(
cache = _cache
if not cache:
cache = cleanup.enter_context(new_call_cache(cfg, logger))
assert _thread_pools is None
assert cache and _thread_pools is None
if not _thread_pools:
cache.flock(logfile, exclusive=True) # flock top-level workflow.log
write_values_json(inputs, os.path.join(run_dir, "inputs.json"), namespace=workflow.name)
@@ -975,7 +979,7 @@ def _workflow_main_loop(
         with compose_coroutines(
             [
                 (
-                    lambda kwargs, cor=cor: cor(
+                    lambda kwargs, cor=cor: cor(  # type: ignore
                         cfg, logger, run_id_stack, run_dir, workflow, **kwargs
                     )
                 )
@@ -1071,7 +1075,7 @@ def _workflow_main_loop(
     except Exception as exn:
         tbtxt = traceback.format_exc()
         logger.debug(tbtxt)
-        cause = exn
+        cause: BaseException = exn
         while isinstance(cause, RunFailed) and cause.__cause__:
             cause = cause.__cause__
         wrapper = RunFailed(workflow, run_id_stack[-1], run_dir)
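Most of the workflow.py edits add explicit annotations so a type checker can see the element type of an initially empty Env.Bindings, plus one switch from the fluent .bind() style to the Env.Binding constructor in cache.put(). A small sketch of the idea (variable names invented; assumes the Env/Value API already used in the hunks above):

    # Hedged sketch: why annotating an empty Env.Bindings() helps static checking.
    from WDL import Env, Value

    untyped = Env.Bindings()  # element type is unknown to the checker here
    typed: Env.Bindings[Value.Base] = Env.Bindings()  # element type pinned explicitly
    typed = typed.bind("n", Value.Int(1))  # OK: Value.Int is a Value.Base

    # Equivalent single-binding construction, as in the cache.put() change:
    one = Env.Bindings(Env.Binding("file", Value.File("/tmp/example.txt")))
    assert one.resolve("file").value == "/tmp/example.txt"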
3 changes: 3 additions & 0 deletions stubs/logging/__init__.py
@@ -4,3 +4,6 @@
 class Logger(OriginalLogger):
     def notice(self, *args, **kwargs) -> None: ...
     def verbose(self, *args, **kwargs) -> None: ...
+
+def getLogger(name: str) -> Logger:
+    ...
