diff --git a/dace/sdfg/work_depth_analysis/assumptions.py b/dace/sdfg/work_depth_analysis/assumptions.py index c7e439cf51..6e311cde0c 100644 --- a/dace/sdfg/work_depth_analysis/assumptions.py +++ b/dace/sdfg/work_depth_analysis/assumptions.py @@ -1,7 +1,7 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. import sympy as sp -from typing import Tuple, Dict +from typing import Dict class UnionFind: diff --git a/dace/sdfg/work_depth_analysis/work_depth.py b/dace/sdfg/work_depth_analysis/work_depth.py index b05ccc70ae..3549e86a20 100644 --- a/dace/sdfg/work_depth_analysis/work_depth.py +++ b/dace/sdfg/work_depth_analysis/work_depth.py @@ -5,7 +5,7 @@ import argparse from collections import deque from dace.sdfg import nodes as nd, propagation, InterstateEdge -from dace import SDFG, SDFGState, dtypes, int64 +from dace import SDFG, SDFGState, dtypes from dace.subsets import Range from typing import Tuple, Dict import os @@ -251,7 +251,7 @@ def tasklet_work(tasklet_node, state): def tasklet_depth(tasklet_node, state): if tasklet_node.code.language == dtypes.Language.CPP: - # For now we simply take depth == work for CPP tasklets. + # Depth == work for CPP tasklets. for oedge in state.out_edges(tasklet_node): return oedge.data.num_accesses if tasklet_node.code.language == dtypes.Language.Python: @@ -292,9 +292,13 @@ def do_initial_subs(w, d, eq, subs1): return sp.simplify(w.subs(eq[0]).subs(eq[1]).subs(subs1)), sp.simplify(d.subs(eq[0]).subs(eq[1]).subs(subs1)) -def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], analyze_tasklet, symbols: Dict[str, str], - detailed_analysis: bool, equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], - subs1: Dict[str, sp.Expr]) -> Tuple[sp.Expr, sp.Expr]: +def sdfg_work_depth(sdfg: SDFG, + w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], + analyze_tasklet, + symbols: Dict[str, str], + equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], + subs1: Dict[str, sp.Expr], + detailed_analysis: bool = False) -> Tuple[sp.Expr, sp.Expr]: """ Analyze the work and depth of a given SDFG. First we determine the work and depth of each state. Then we break loops in the state machine, such that we get a DAG. @@ -318,8 +322,8 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana state_depths: Dict[SDFGState, sp.Expr] = {} state_works: Dict[SDFGState, sp.Expr] = {} for state in sdfg.nodes(): - state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis, - equality_subs, subs1) + state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, symbols, equality_subs, subs1, + detailed_analysis) # Substitutions for state_work and state_depth already performed, but state.executions needs to be subs'd now. state_work = sp.simplify(state_work * @@ -475,14 +479,16 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana return sdfg_result -def scope_work_depth(state: SDFGState, - w_d_map: Dict[str, sp.Expr], - analyze_tasklet, - symbols: Dict[str, str], - detailed_analysis: bool, - equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], - subs1: Dict[str, sp.Expr], - entry: nd.EntryNode = None) -> Tuple[sp.Expr, sp.Expr]: +def scope_work_depth( + state: SDFGState, + w_d_map: Dict[str, sp.Expr], + analyze_tasklet, + symbols: Dict[str, str], + equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], + subs1: Dict[str, sp.Expr], + entry: nd.EntryNode = None, + detailed_analysis: bool = False, +) -> Tuple[sp.Expr, sp.Expr]: """ Analyze the work and depth of a scope. This works by traversing through the scope analyzing the work and depth of each encountered node. @@ -519,8 +525,8 @@ def scope_work_depth(state: SDFGState, if isinstance(node, nd.EntryNode): # If the scope contains an entry node, we need to recursively analyze the sub-scope of the entry node first. # The resulting work/depth are summarized into the entry node - s_work, s_depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis, - equality_subs, subs1, node) + s_work, s_depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, equality_subs, subs1, node, + detailed_analysis) s_work, s_depth = do_initial_subs(s_work, s_depth, equality_subs, subs1) # add up work for whole state, but also save work for this sub-scope scope in w_d_map work += s_work @@ -545,8 +551,8 @@ def scope_work_depth(state: SDFGState, nested_syms.update(symbols) nested_syms.update(evaluate_symbols(symbols, node.symbol_mapping)) # Nested SDFGs are recursively analyzed first. - nsdfg_work, nsdfg_depth = sdfg_work_depth(node.sdfg, w_d_map, analyze_tasklet, nested_syms, - detailed_analysis, equality_subs, subs1) + nsdfg_work, nsdfg_depth = sdfg_work_depth(node.sdfg, w_d_map, analyze_tasklet, nested_syms, equality_subs, + subs1, detailed_analysis) nsdfg_work, nsdfg_depth = do_initial_subs(nsdfg_work, nsdfg_depth, equality_subs, subs1) # add up work for whole state, but also save work for this nested SDFG in w_d_map @@ -561,7 +567,7 @@ def scope_work_depth(state: SDFGState, # TODO: This symbol should now appear in the VS code extension in the SDFG analysis tab, # such that the user can define its value. But it doesn't... # How to achieve this? - top_level_sdfg.add_symbol(f'{node.name}_work', int64) + top_level_sdfg.add_symbol(f'{node.name}_work', dtypes.int64) lib_node_work = sp.Symbol(f'{node.name}_work', positive=True) lib_node_depth = sp.sympify(-1) # not analyzed if analyze_tasklet != get_tasklet_work: @@ -570,7 +576,7 @@ def scope_work_depth(state: SDFGState, lib_node_depth = LIBNODES_TO_DEPTH[type(node)](node, symbols, state) except KeyError: top_level_sdfg = state.parent - top_level_sdfg.add_symbol(f'{node.name}_depth', int64) + top_level_sdfg.add_symbol(f'{node.name}_depth', dtypes.int64) lib_node_depth = sp.Symbol(f'{node.name}_depth', positive=True) lib_node_work, lib_node_depth = do_initial_subs(lib_node_work, lib_node_depth, equality_subs, subs1) work += lib_node_work @@ -581,7 +587,7 @@ def scope_work_depth(state: SDFGState, if isinstance(entry, nd.MapEntry): nmap: nd.Map = entry.map range: Range = nmap.range - n_exec = range.num_elements_exact() + n_exec = range.num_elements() work = sp.simplify(work * n_exec.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) else: print('WARNING: Only Map scopes are supported in work analysis for now. Assuming 1 iteration.') @@ -669,8 +675,13 @@ def scope_work_depth(state: SDFGState, return scope_result -def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet, symbols, detailed_analysis, - equality_subs, subs1) -> Tuple[sp.Expr, sp.Expr]: +def state_work_depth(state: SDFGState, + w_d_map: Dict[str, sp.Expr], + analyze_tasklet, + symbols, + equality_subs, + subs1, + detailed_analysis=False) -> Tuple[sp.Expr, sp.Expr]: """ Analyze the work and depth of a state. @@ -685,13 +696,16 @@ def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_task :param subs1: First substitution dict for greater/lesser assumptions. :return: A tuple containing the work and depth of the state. """ - work, depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, subs1, - None) + work, depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, equality_subs, subs1, None, + detailed_analysis) return work, depth -def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet, assumptions: [str], - detailed_analysis: bool) -> None: +def analyze_sdfg(sdfg: SDFG, + w_d_map: Dict[str, sp.Expr], + analyze_tasklet, + assumptions: [str], + detailed_analysis: bool = False) -> None: """ Analyze a given SDFG. We can either analyze work, work and depth or average parallelism. @@ -724,8 +738,8 @@ def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet, assum # Analyze the work and depth of the SDFG. symbols = {} - sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, - all_subs[0][0] if len(all_subs) > 0 else {}) + sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, symbols, equality_subs, all_subs[0][0] if len(all_subs) > 0 else {}, + detailed_analysis) for k, (v_w, v_d) in w_d_map.items(): # The symeval replaces nested SDFG symbols with their global counterparts.