From d0eb400b99c2a953f6e5c9eb249fae202b1d61b4 Mon Sep 17 00:00:00 2001 From: Cliff Hodel <111381329+hodelcl@users.noreply.github.com> Date: Mon, 25 Sep 2023 19:41:28 +0200 Subject: [PATCH] Improvements to work depth analysis (#1363) * initial push of work_depth analysis script * adding tests to work_depth analysis * rename work depth analysis * todos added * code ready for PR * yapf for formatting * put tests into dace/tests/sdfg * fixed import after merge * merged propgatate_states_symbolically into propagate_states * fixed format issue in work_depth.py * small bugfix * include wcr edges into analysis, improve LibraryNodes analysis * imporved work depth. wcr now analyses, performance improved, assumptions can be passed * formatting with yapf * minor changes * start of op_in analysis * Revert "start of op_in analysis" This reverts commit eb5a6f427d47f314e3254f681639cf3f155f77c8. * changes according to comments --------- Co-authored-by: Cliff Hodel Co-authored-by: Cliff Hodel Co-authored-by: Philipp Schaad --- dace/sdfg/work_depth_analysis/assumptions.py | 285 +++++++++++++++ dace/sdfg/work_depth_analysis/helpers.py | 2 + dace/sdfg/work_depth_analysis/work_depth.py | 366 ++++++++++++++----- tests/sdfg/work_depth_tests.py | 97 ++++- 4 files changed, 638 insertions(+), 112 deletions(-) create mode 100644 dace/sdfg/work_depth_analysis/assumptions.py diff --git a/dace/sdfg/work_depth_analysis/assumptions.py b/dace/sdfg/work_depth_analysis/assumptions.py new file mode 100644 index 0000000000..6e311cde0c --- /dev/null +++ b/dace/sdfg/work_depth_analysis/assumptions.py @@ -0,0 +1,285 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import sympy as sp +from typing import Dict + + +class UnionFind: + """ + Simple, not really optimized UnionFind implementation. + """ + + def __init__(self, elements) -> None: + self.ids = {e: e for e in elements} + + def add_element(self, e): + if e in self.ids: + return False + self.ids.update({e: e}) + return True + + def find(self, e): + prev = e + curr = self.ids[e] + while prev != curr: + prev = curr + curr = self.ids[curr] + # shorten the path + self.ids[e] = curr + return curr + + def union(self, e, f): + if f not in self.ids: + self.add_element(f) + self.ids[self.find(e)] = f + + +class ContradictingAssumptions(Exception): + pass + + +class Assumptions: + """ + Summarises the assumptions for a single symbol in three lists: equal, greater, lesser. + """ + + def __init__(self) -> None: + self.greater = [] + self.lesser = [] + self.equal = [] + + def add_greater(self, g): + if isinstance(g, sp.Symbol): + self.greater.append(g) + else: + self.greater = [x for x in self.greater if isinstance(x, sp.Symbol) or x > g] + if len([y for y in self.greater if not isinstance(y, sp.Symbol)]) == 0: + self.greater.append(g) + self.check_consistency() + + def add_lesser(self, l): + if isinstance(l, sp.Symbol): + self.lesser.append(l) + else: + self.lesser = [x for x in self.lesser if isinstance(x, sp.Symbol) or x < l] + if len([y for y in self.lesser if not isinstance(y, sp.Symbol)]) == 0: + self.lesser.append(l) + self.check_consistency() + + def add_equal(self, e): + for x in self.equal: + if not (isinstance(x, sp.Symbol) or isinstance(e, sp.Symbol)) and x != e: + raise ContradictingAssumptions() + self.equal.append(e) + self.check_consistency() + + def check_consistency(self): + if len(self.equal) > 0: + # we know exact value + for e in self.equal: + for g in self.greater: + if (e <= g) == True: + raise ContradictingAssumptions() + for l in self.lesser: + if (e >= l) == True: + raise ContradictingAssumptions() + else: + # check if any greater > any lesser + for g in self.greater: + for l in self.lesser: + if (g > l) == True: + raise ContradictingAssumptions() + return True + + def num_assumptions(self): + # returns the number of individual assumptions for this symbol + return len(self.greater) + len(self.lesser) + len(self.equal) + + +def propagate_assumptions(x, y, condensed_assumptions): + """ + Assuming x is equal to y, we propagate the assumptions on x to y. E.g. we have x==y and + x<5. Then, this method adds y<5 to the assumptions. + + :param x: A symbol. + :param y: Another symbol equal to x. + :param condensed_assumptions: Current assumptions over all symbols. + """ + if x == y: + return + assum_x = condensed_assumptions[x] + if y not in condensed_assumptions: + condensed_assumptions[y] = Assumptions() + assum_y = condensed_assumptions[y] + for e in assum_x.equal: + if e is not sp.Symbol(y): + assum_y.add_equal(e) + for g in assum_x.greater: + assum_y.add_greater(g) + for l in assum_x.lesser: + assum_y.add_lesser(l) + assum_y.check_consistency() + + +def propagate_assumptions_equal_symbols(condensed_assumptions): + """ + This method handles two things: 1) It generates the substitution dict for all equality assumptions. + And 2) it propagates assumptions too all equal symbols. For each equivalence class, we find a unique + representative using UnionFind. Then, all assumptions get propagates to this symbol using + ``propagate_assumptions``. + + :param condensed_assumptions: Current assumptions over all symbols. + :return: Returns a tuple consisting of 2 substitution dicts. The first one replaces each symbol with + the unique representative of its equivalence class. The second dict replaces each symbol with its numeric + value (if we assume it to be equal some value, e.g. N==5). + """ + # Make one set with unique identifier for each equality class + uf = UnionFind(list(condensed_assumptions)) + for sym in condensed_assumptions: + for other in condensed_assumptions[sym].equal: + if isinstance(other, sp.Symbol): + # we assume sym == other --> union these + uf.union(sym, other.name) + + equality_subs1 = {} + + # For each equivalence class, we now have one unique identifier. + # For each class, we give all the assumptions to this single symbol. + # And we swap each symbol in class for this symbol. + for sym in list(condensed_assumptions): + for other in condensed_assumptions[sym].equal: + if isinstance(other, sp.Symbol): + propagate_assumptions(sym, uf.find(sym), condensed_assumptions) + equality_subs1.update({sym: sp.Symbol(uf.find(sym))}) + + equality_subs2 = {} + # In a second step, each symbol gets replace with its equal number (if present) + # using equality_subs2. + for sym, assum in condensed_assumptions.items(): + for e in assum.equal: + if not isinstance(e, sp.Symbol): + equality_subs2.update({sym: e}) + + # Imagine we have M>N and M==10. We need to deduce N<10 from that. Following code handles that: + for sym, assum in condensed_assumptions.items(): + for g in assum.greater: + if isinstance(g, sp.Symbol): + for e in condensed_assumptions[g.name].equal: + if not isinstance(e, sp.Symbol): + condensed_assumptions[sym].add_greater(e) + assum.greater.remove(g) + for l in assum.lesser: + if isinstance(l, sp.Symbol): + for e in condensed_assumptions[l.name].equal: + if not isinstance(e, sp.Symbol): + condensed_assumptions[sym].add_lesser(e) + assum.lesser.remove(l) + return equality_subs1, equality_subs2 + + +def parse_assumptions(assumptions, array_symbols): + """ + Parses a list of assumptions into substitution dictionaries. Firstly, it gathers all assumptions and + keeps only the strongest ones. Afterwards it constructs two substitution dicts for the equality + assumptions: First dict for symbol==symbol assumptions; second dict for symbol==number assumptions. + The other assumptions get handles by N tuples of substitution dicts (N = max number of concurrent + assumptions for a single symbol). Each tuple is responsible for at most one assumption for each symbol. + First dict in the tuple substitutes the symbol with the assumption; second dict restores the initial symbol. + + :param assumptions: List of assumption strings. + :param array_symbols: List of symbols we assume to be positive, since they are the size of a data container. + :return: Tuple consisting of the 2 dicts responsible for the equality assumptions and the list of size N + reponsible for all other assumptions. + """ + + # TODO: This assumptions system can be improved further, especially the deduction of further assumptions + # from the ones we already have. An example of what is not working currently: + # We have assumptions N>0 N<5 and M>5. + # In the first substitution round we use N>0 and M>5. + # In the second substitution round we use N<5. + # Therefore, Max(M, N) will not be evaluated to M, even though from the input assumptions + # one can clearly deduce M>N. + # This happens since N<5 and M>5 are not in the same substitution round. + # The easiest way to fix this is probably to actually deduce the M>N assumption. + # This guarantees that in some substitution round, we will replace M with N + _p_M, where + # _p_M is some positive symbol. Hence, we would resolve Max(M, N) to N + _p_M, which is M. + + # I suspect there to be many more cases where further assumptions will not be deduced properly. + # But if the user enters assumptions as explicitly as possible, e.g. N<5 M>5 M>N, then everything + # works fine. + + # For each symbol x appearing as a data container size, we can assume x>0. + # TODO (later): Analyze size of shapes more, such that e.g. shape N + 1 --> We can assume N > -1. + # For now we only extract assumptions out of shapes if shape consists of only a single symbol. + for sym in array_symbols: + assumptions.append(f'{sym.name}>0') + + if assumptions is None: + return {}, [({}, {})] + + # Gather assumptions, keeping only the strongest ones for each symbol. + condensed_assumptions: Dict[str, Assumptions] = {} + for a in assumptions: + if '==' in a: + symbol, rhs = a.split('==') + if symbol not in condensed_assumptions: + condensed_assumptions[symbol] = Assumptions() + try: + condensed_assumptions[symbol].add_equal(int(rhs)) + except ValueError: + condensed_assumptions[symbol].add_equal(sp.Symbol(rhs)) + elif '>' in a: + symbol, rhs = a.split('>') + if symbol not in condensed_assumptions: + condensed_assumptions[symbol] = Assumptions() + try: + condensed_assumptions[symbol].add_greater(int(rhs)) + except ValueError: + condensed_assumptions[symbol].add_greater(sp.Symbol(rhs)) + # add the opposite, i.e. for x>y, we add yx + if rhs not in condensed_assumptions: + condensed_assumptions[rhs] = Assumptions() + condensed_assumptions[rhs].add_greater(sp.Symbol(symbol)) + + # Handle equal assumptions. + equality_subs = propagate_assumptions_equal_symbols(condensed_assumptions) + + # How many assumptions does symbol with most assumptions have? + curr_max = -1 + for _, assum in condensed_assumptions.items(): + if assum.num_assumptions() > curr_max: + curr_max = assum.num_assumptions() + + all_subs = [] + for i in range(curr_max): + all_subs.append(({}, {})) + + # Construct all the substitution dicts. In each substitution round we take at most one assumption for each + # symbol. Each round has two dicts: First one swaps in the assumption and second one restores the initial + # symbol. + for sym, assum in condensed_assumptions.items(): + i = 0 + for g in assum.greater: + replacement_symbol = sp.Symbol(f'_p_{sym}', positive=True, integer=True) + all_subs[i][0].update({sp.Symbol(sym): replacement_symbol + g}) + all_subs[i][1].update({replacement_symbol: sp.Symbol(sym) - g}) + i += 1 + for l in assum.lesser: + replacement_symbol = sp.Symbol(f'_n_{sym}', negative=True, integer=True) + all_subs[i][0].update({sp.Symbol(sym): replacement_symbol + l}) + all_subs[i][1].update({replacement_symbol: sp.Symbol(sym) - l}) + i += 1 + + return equality_subs, all_subs diff --git a/dace/sdfg/work_depth_analysis/helpers.py b/dace/sdfg/work_depth_analysis/helpers.py index a80e769f64..e592fd11b5 100644 --- a/dace/sdfg/work_depth_analysis/helpers.py +++ b/dace/sdfg/work_depth_analysis/helpers.py @@ -328,4 +328,6 @@ def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph): # now we have a triple (node, oNode, exitCandidates) nodes_oNodes_exits.append((node, oNode, exitCandidates)) + # remove artificial end node + sdfg_nx.remove_node(artificial_end_node) return nodes_oNodes_exits diff --git a/dace/sdfg/work_depth_analysis/work_depth.py b/dace/sdfg/work_depth_analysis/work_depth.py index a05fe10266..3549e86a20 100644 --- a/dace/sdfg/work_depth_analysis/work_depth.py +++ b/dace/sdfg/work_depth_analysis/work_depth.py @@ -19,6 +19,9 @@ import warnings from dace.sdfg.work_depth_analysis.helpers import get_uuid, find_loop_guards_tails_exits +from dace.sdfg.work_depth_analysis.assumptions import parse_assumptions +from dace.transformation.passes.symbol_ssa import StrictSymbolSSA +from dace.transformation.pass_pipeline import FixedPointPipeline def get_array_size_symbols(sdfg): @@ -39,22 +42,6 @@ def get_array_size_symbols(sdfg): return symbols -def posify_certain_symbols(expr, syms_to_posify): - """ - Takes an expression and evaluates it while assuming that certain symbols are positive. - - :param expr: The expression to evaluate. - :param syms_to_posify: List of symbols we assume to be positive. - :note: This is adapted from the Sympy function posify. - """ - - expr = sp.sympify(expr) - - reps = {s: sp.Dummy(s.name, positive=True, **s.assumptions0) for s in syms_to_posify if s.is_positive is None} - expr = expr.subs(reps) - return expr.subs({r: s for s, r in reps.items()}) - - def symeval(val, symbols): """ Takes a sympy expression and substitutes its symbols according to a dict { old_symbol: new_symbol}. @@ -64,7 +51,7 @@ def symeval(val, symbols): """ first_replacement = {pystr_to_symbolic(k): pystr_to_symbolic('__REPLSYM_' + k) for k in symbols.keys()} second_replacement = {pystr_to_symbolic('__REPLSYM_' + k): v for k, v in symbols.items()} - return val.subs(first_replacement).subs(second_replacement) + return sp.simplify(val.subs(first_replacement).subs(second_replacement)) def evaluate_symbols(base, new): @@ -87,7 +74,14 @@ def count_work_matmul(node, symbols, state): result *= symeval(C_memlet.data.subset.size()[-1], symbols) # K result *= symeval(A_memlet.data.subset.size()[-1], symbols) - return result + return sp.sympify(result) + + +def count_depth_matmul(node, symbols, state): + # optimal depth of a matrix multiplication is O(log(size of shared dimension)): + A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') + size_shared_dimension = symeval(A_memlet.data.subset.size()[-1], symbols) + return bigo(sp.log(size_shared_dimension)) def count_work_reduce(node, symbols, state): @@ -102,7 +96,12 @@ def count_work_reduce(node, symbols, state): result *= in_memlet.data.volume else: result = 0 - return result + return sp.sympify(result) + + +def count_depth_reduce(node, symbols, state): + # optimal depth of reduction is log of the work + return bigo(sp.log(count_work_reduce(node, symbols, state))) LIBNODES_TO_WORK = { @@ -111,22 +110,6 @@ def count_work_reduce(node, symbols, state): Reduce: count_work_reduce, } - -def count_depth_matmul(node, symbols, state): - # For now we set it equal to work: see comments in count_depth_reduce just below - return count_work_matmul(node, symbols, state) - - -def count_depth_reduce(node, symbols, state): - # depth of reduction is log2 of the work - # TODO: Can we actually assume this? Or is it equal to the work? - # Another thing to consider is that we essetially do NOT count wcr edges as operations for now... - - # return sp.ceiling(sp.log(count_work_reduce(node, symbols, state), 2)) - # set it equal to work for now - return count_work_reduce(node, symbols, state) - - LIBNODES_TO_DEPTH = { MatMul: count_depth_matmul, Transpose: lambda *args: 0, @@ -254,9 +237,9 @@ def count_depth_code(code): def tasklet_work(tasklet_node, state): if tasklet_node.code.language == dtypes.Language.CPP: + # simplified work analysis for CPP tasklets. for oedge in state.out_edges(tasklet_node): - return bigo(oedge.data.num_accesses) - + return oedge.data.num_accesses elif tasklet_node.code.language == dtypes.Language.Python: return count_arithmetic_ops_code(tasklet_node.code.code) else: @@ -267,11 +250,10 @@ def tasklet_work(tasklet_node, state): def tasklet_depth(tasklet_node, state): - # TODO: how to get depth of CPP tasklets? - # For now we use depth == work: if tasklet_node.code.language == dtypes.Language.CPP: + # Depth == work for CPP tasklets. for oedge in state.out_edges(tasklet_node): - return bigo(oedge.data.num_accesses) + return oedge.data.num_accesses if tasklet_node.code.language == dtypes.Language.Python: return count_depth_code(tasklet_node.code.code) else: @@ -282,19 +264,41 @@ def tasklet_depth(tasklet_node, state): def get_tasklet_work(node, state): - return tasklet_work(node, state), -1 + return sp.sympify(tasklet_work(node, state)), sp.sympify(-1) def get_tasklet_work_depth(node, state): - return tasklet_work(node, state), tasklet_depth(node, state) + return sp.sympify(tasklet_work(node, state)), sp.sympify(tasklet_depth(node, state)) def get_tasklet_avg_par(node, state): - return tasklet_work(node, state), tasklet_depth(node, state) + return sp.sympify(tasklet_work(node, state)), sp.sympify(tasklet_depth(node, state)) + + +def update_value_map(old, new): + # add new assignments to old + old.update({k: v for k, v in new.items() if k not in old}) + # check for conflicts: + for k, v in new.items(): + if k in old and old[k] != v: + # conflict detected --> forget this mapping completely + old.pop(k) -def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], analyze_tasklet, - symbols) -> Tuple[sp.Expr, sp.Expr]: +def do_initial_subs(w, d, eq, subs1): + """ + Calls subs three times for the give (w)ork and (d)epth values. + """ + return sp.simplify(w.subs(eq[0]).subs(eq[1]).subs(subs1)), sp.simplify(d.subs(eq[0]).subs(eq[1]).subs(subs1)) + + +def sdfg_work_depth(sdfg: SDFG, + w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], + analyze_tasklet, + symbols: Dict[str, str], + equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], + subs1: Dict[str, sp.Expr], + detailed_analysis: bool = False) -> Tuple[sp.Expr, sp.Expr]: """ Analyze the work and depth of a given SDFG. First we determine the work and depth of each state. Then we break loops in the state machine, such that we get a DAG. @@ -304,6 +308,11 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana :param w_d_map: Dictionary which will save the result. :param analyze_tasklet: Function used to analyze tasklet nodes. :param symbols: A dictionary mapping local nested SDFG symbols to global symbols. + :param detailed_analysis: If True, detailed analysis gets used. For each branch, we keep track of its condition + and work depth values for both branches. If False, the worst-case branch is taken. Discouraged to use on bigger SDFGs, + as computation time sky-rockets, since expression can became HUGE (depending on number of branches etc.). + :param equality_subs: Substitution dict taking care of the equality assumptions. + :param subs1: First substitution dict for greater/lesser assumptions. :return: A tuple containing the work and depth of the SDFG. """ @@ -313,9 +322,16 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana state_depths: Dict[SDFGState, sp.Expr] = {} state_works: Dict[SDFGState, sp.Expr] = {} for state in sdfg.nodes(): - state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, symbols) - state_works[state] = sp.simplify(state_work * state.executions) - state_depths[state] = sp.simplify(state_depth * state.executions) + state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, symbols, equality_subs, subs1, + detailed_analysis) + + # Substitutions for state_work and state_depth already performed, but state.executions needs to be subs'd now. + state_work = sp.simplify(state_work * + state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + state_depth = sp.simplify(state_depth * + state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + + state_works[state], state_depths[state] = state_work, state_depth w_d_map[get_uuid(state)] = (state_works[state], state_depths[state]) # Prepare the SDFG for a depth analysis by breaking loops. This removes the edge between the last loop state and @@ -329,12 +345,18 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana # Now we need to go over each triple (node, oNode, exits). For each triple, we # - remove edge (oNode, node), i.e. the backward edge # - for all exits e, add edge (oNode, e). This edge may already exist + # - remove edge from node to exit (if present, i.e. while-do loop) + # - This ensures that every node with > 1 outgoing edge is a branch guard + # - useful for detailed anaylsis. for node, oNode, exits in nodes_oNodes_exits: sdfg.remove_edge(sdfg.edges_between(oNode, node)[0]) for e in exits: if len(sdfg.edges_between(oNode, e)) == 0: # no edge there yet sdfg.add_edge(oNode, e, InterstateEdge()) + if len(sdfg.edges_between(node, e)) > 0: + # edge present --> remove it + sdfg.remove_edge(sdfg.edges_between(node, e)[0]) # add a dummy exit to the SDFG, such that each path ends there. dummy_exit = sdfg.add_state('dummy_exit') @@ -345,6 +367,8 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana # These two dicts save the current length of the "heaviest", resp. "deepest", paths at each state. work_map: Dict[SDFGState, sp.Expr] = {} depth_map: Dict[SDFGState, sp.Expr] = {} + # Keeps track of assignments done on InterstateEdges. + state_value_map: Dict[SDFGState, Dict[sp.Symbol, sp.Symbol]] = {} # The dummy state has 0 work and depth. state_depths[dummy_exit] = sp.sympify(0) state_works[dummy_exit] = sp.sympify(0) @@ -353,40 +377,67 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana # the next state in the BFS if all incoming edges have been visited, to ensure the maximum work / depth expressions # have been calculated. traversal_q = deque() - traversal_q.append((sdfg.start_state, sp.sympify(0), sp.sympify(0), None)) + traversal_q.append((sdfg.start_state, sp.sympify(0), sp.sympify(0), None, [], [], {})) visited = set() + while traversal_q: - state, depth, work, ie = traversal_q.popleft() + state, depth, work, ie, condition_stack, common_subexpr_stack, value_map = traversal_q.popleft() if ie is not None: visited.add(ie) - n_depth = sp.simplify(depth + state_depths[state]) - n_work = sp.simplify(work + state_works[state]) + if state in state_value_map: + # update value map: + update_value_map(state_value_map[state], value_map) + else: + state_value_map[state] = value_map + + # ignore assignments such as tmp=x[0], as those do not give much information. + value_map = {k: v for k, v in state_value_map[state].items() if '[' not in k and '[' not in v} + n_depth = sp.simplify((depth + state_depths[state]).subs(value_map)) + n_work = sp.simplify((work + state_works[state]).subs(value_map)) # If we are analysing average parallelism, we don't search "heaviest" and "deepest" paths separately, but we want one # single path with the least average parallelsim (of all paths with more than 0 work). if analyze_tasklet == get_tasklet_avg_par: - if state in depth_map: # and hence als state in work_map - # if current path has 0 depth, we don't do anything. + if state in depth_map: # this means we have already visited this state before + cse = common_subexpr_stack.pop() + # if current path has 0 depth (--> 0 work as well), we don't do anything. if n_depth != 0: - # see if we need to update the work and depth of the current state + # check if we need to update the work and depth of the current state # we update if avg parallelism of new incoming path is less than current avg parallelism - old_avg_par = sp.simplify(work_map[state] / depth_map[state]) - new_avg_par = sp.simplify(n_work / n_depth) - - if depth_map[state] == 0 or new_avg_par < old_avg_par: - # old value was divided by zero or new path gives actually worse avg par, then we keep new value - depth_map[state] = n_depth - work_map[state] = n_work + if depth_map[state] == 0: + # old value was divided by zero --> we take new value anyway + depth_map[state] = cse[1] + n_depth + work_map[state] = cse[0] + n_work + else: + old_avg_par = (cse[0] + work_map[state]) / (cse[1] + depth_map[state]) + new_avg_par = (cse[0] + n_work) / (cse[1] + n_depth) + # we take either old work/depth or new work/depth (or both if we cannot determine which one is greater) + depth_map[state] = cse[1] + sp.Piecewise((n_depth, sp.simplify(new_avg_par < old_avg_par)), + (depth_map[state], True)) + work_map[state] = cse[0] + sp.Piecewise((n_work, sp.simplify(new_avg_par < old_avg_par)), + (work_map[state], True)) else: depth_map[state] = n_depth work_map[state] = n_work else: # search heaviest and deepest path separately if state in depth_map: # and consequently also in work_map - depth_map[state] = sp.Max(depth_map[state], n_depth) - work_map[state] = sp.Max(work_map[state], n_work) + # This cse value would appear in both arguments of the Max. Hence, for performance reasons, + # we pull it out of the Max expression. + # Example: We do cse + Max(a, b) instead of Max(cse + a, cse + b). + # This increases performance drastically, expecially since we avoid nesting Max expressions + # for cases where cse itself contains Max operators. + cse = common_subexpr_stack.pop() + if detailed_analysis: + # This MAX should be covered in the more detailed analysis + cond = condition_stack.pop() + work_map[state] = cse[0] + sp.Piecewise((work_map[state], sp.Not(cond)), (n_work, cond)) + depth_map[state] = cse[1] + sp.Piecewise((depth_map[state], sp.Not(cond)), (n_depth, cond)) + else: + work_map[state] = cse[0] + sp.Max(work_map[state], n_work) + depth_map[state] = cse[1] + sp.Max(depth_map[state], n_depth) else: depth_map[state] = n_depth work_map[state] = n_work @@ -397,7 +448,22 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana pass else: for oedge in out_edges: - traversal_q.append((oedge.dst, depth_map[state], work_map[state], oedge)) + if len(out_edges) > 1: + # It is important to copy these stacks. Else both branches operate on the same stack. + # state is a branch guard --> save condition on stack + new_cond_stack = list(condition_stack) + new_cond_stack.append(oedge.data.condition_sympy()) + # same for common_subexr_stack + new_cse_stack = list(common_subexpr_stack) + new_cse_stack.append((work_map[state], depth_map[state])) + # same for value_map + new_value_map = dict(state_value_map[state]) + new_value_map.update({sp.Symbol(k): sp.Symbol(v) for k, v in oedge.data.assignments.items()}) + traversal_q.append((oedge.dst, 0, 0, oedge, new_cond_stack, new_cse_stack, new_value_map)) + else: + value_map.update(oedge.data.assignments) + traversal_q.append((oedge.dst, depth_map[state], work_map[state], oedge, condition_stack, + common_subexpr_stack, value_map)) try: max_depth = depth_map[dummy_exit] @@ -408,16 +474,21 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana raise Exception( 'Analysis failed, since not all loops got detected. It may help to use more structured loop constructs.') - sdfg_result = (sp.simplify(max_work), sp.simplify(max_depth)) + sdfg_result = (max_work, max_depth) w_d_map[get_uuid(sdfg)] = sdfg_result return sdfg_result -def scope_work_depth(state: SDFGState, - w_d_map: Dict[str, sp.Expr], - analyze_tasklet, - symbols, - entry: nd.EntryNode = None) -> Tuple[sp.Expr, sp.Expr]: +def scope_work_depth( + state: SDFGState, + w_d_map: Dict[str, sp.Expr], + analyze_tasklet, + symbols: Dict[str, str], + equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], + subs1: Dict[str, sp.Expr], + entry: nd.EntryNode = None, + detailed_analysis: bool = False, +) -> Tuple[sp.Expr, sp.Expr]: """ Analyze the work and depth of a scope. This works by traversing through the scope analyzing the work and depth of each encountered node. @@ -430,7 +501,14 @@ def scope_work_depth(state: SDFGState, this can be done in linear time by traversing the graph in topological order. :param state: The state in which the scope to analyze is contained. - :param sym_map: A dictionary mapping symbols to their values. + :param w_d_map: Dictionary saving the final result for each SDFG element. + :param analyze_tasklet: Function used to analyze tasklets. Either analyzes just work, work and depth or average parallelism. + :param symbols: A dictionary mapping local nested SDFG symbols to global symbols. + :param detailed_analysis: If True, detailed analysis gets used. For each branch, we keep track of its condition + and work depth values for both branches. If False, the worst-case branch is taken. Discouraged to use on bigger SDFGs, + as computation time sky-rockets, since expression can became HUGE (depending on number of branches etc.). + :param equality_subs: Substitution dict taking care of the equality assumptions. + :param subs1: First substitution dict for greater/lesser assumptions. :param entry: The entry node of the scope to analyze. If None, the entire state is analyzed. :return: A tuple containing the work and depth of the scope. """ @@ -447,7 +525,9 @@ def scope_work_depth(state: SDFGState, if isinstance(node, nd.EntryNode): # If the scope contains an entry node, we need to recursively analyze the sub-scope of the entry node first. # The resulting work/depth are summarized into the entry node - s_work, s_depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, node) + s_work, s_depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, equality_subs, subs1, node, + detailed_analysis) + s_work, s_depth = do_initial_subs(s_work, s_depth, equality_subs, subs1) # add up work for whole state, but also save work for this sub-scope scope in w_d_map work += s_work w_d_map[get_uuid(node, state)] = (s_work, s_depth) @@ -457,8 +537,13 @@ def scope_work_depth(state: SDFGState, elif isinstance(node, nd.Tasklet): # add up work for whole state, but also save work for this node in w_d_map t_work, t_depth = analyze_tasklet(node, state) + # check if tasklet has any outgoing wcr edges + for e in state.out_edges(node): + if e.data.wcr is not None: + t_work += count_arithmetic_ops_code(e.data.wcr) + t_work, t_depth = do_initial_subs(t_work, t_depth, equality_subs, subs1) work += t_work - w_d_map[get_uuid(node, state)] = (sp.sympify(t_work), sp.sympify(t_depth)) + w_d_map[get_uuid(node, state)] = (t_work, t_depth) elif isinstance(node, nd.NestedSDFG): # keep track of nested symbols: "symbols" maps local nested SDFG symbols to global symbols. # We only want global symbols in our final work depth expressions. @@ -466,18 +551,35 @@ def scope_work_depth(state: SDFGState, nested_syms.update(symbols) nested_syms.update(evaluate_symbols(symbols, node.symbol_mapping)) # Nested SDFGs are recursively analyzed first. - nsdfg_work, nsdfg_depth = sdfg_work_depth(node.sdfg, w_d_map, analyze_tasklet, nested_syms) + nsdfg_work, nsdfg_depth = sdfg_work_depth(node.sdfg, w_d_map, analyze_tasklet, nested_syms, equality_subs, + subs1, detailed_analysis) + nsdfg_work, nsdfg_depth = do_initial_subs(nsdfg_work, nsdfg_depth, equality_subs, subs1) # add up work for whole state, but also save work for this nested SDFG in w_d_map work += nsdfg_work w_d_map[get_uuid(node, state)] = (nsdfg_work, nsdfg_depth) elif isinstance(node, nd.LibraryNode): - lib_node_work = LIBNODES_TO_WORK[type(node)](node, symbols, state) - work += lib_node_work - lib_node_depth = -1 # not analyzed + try: + lib_node_work = LIBNODES_TO_WORK[type(node)](node, symbols, state) + except KeyError: + # add a symbol to the top level sdfg, such that the user can define it in the extension + top_level_sdfg = state.parent + # TODO: This symbol should now appear in the VS code extension in the SDFG analysis tab, + # such that the user can define its value. But it doesn't... + # How to achieve this? + top_level_sdfg.add_symbol(f'{node.name}_work', dtypes.int64) + lib_node_work = sp.Symbol(f'{node.name}_work', positive=True) + lib_node_depth = sp.sympify(-1) # not analyzed if analyze_tasklet != get_tasklet_work: # we are analyzing depth - lib_node_depth = LIBNODES_TO_DEPTH[type(node)](node, symbols, state) + try: + lib_node_depth = LIBNODES_TO_DEPTH[type(node)](node, symbols, state) + except KeyError: + top_level_sdfg = state.parent + top_level_sdfg.add_symbol(f'{node.name}_depth', dtypes.int64) + lib_node_depth = sp.Symbol(f'{node.name}_depth', positive=True) + lib_node_work, lib_node_depth = do_initial_subs(lib_node_work, lib_node_depth, equality_subs, subs1) + work += lib_node_work w_d_map[get_uuid(node, state)] = (lib_node_work, lib_node_depth) if entry is not None: @@ -485,8 +587,8 @@ def scope_work_depth(state: SDFGState, if isinstance(entry, nd.MapEntry): nmap: nd.Map = entry.map range: Range = nmap.range - n_exec = range.num_elements_exact() - work = work * sp.simplify(n_exec) + n_exec = range.num_elements() + work = sp.simplify(work * n_exec.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) else: print('WARNING: Only Map scopes are supported in work analysis for now. Assuming 1 iteration.') @@ -510,6 +612,7 @@ def scope_work_depth(state: SDFGState, traversal_q.append((node, sp.sympify(0), None)) # this map keeps track of the length of the longest path ending at each state so far seen. depth_map = {} + wcr_depth_map = {} while traversal_q: node, in_depth, in_edge = traversal_q.popleft() @@ -534,19 +637,51 @@ def scope_work_depth(state: SDFGState, # replace out_edges with the out_edges of the scope exit node out_edges = state.out_edges(exit_node) for oedge in out_edges: - traversal_q.append((oedge.dst, depth_map[node], oedge)) + # check for wcr + wcr_depth = sp.sympify(0) + if oedge.data.wcr is not None: + # This division gives us the number of writes to each single memory location, which is the depth + # as these need to be sequential (without assumptions on HW etc). + wcr_depth = oedge.data.volume / oedge.data.subset.num_elements() + if get_uuid(node, state) in wcr_depth_map: + # max + wcr_depth_map[get_uuid(node, state)] = sp.Max(wcr_depth_map[get_uuid(node, state)], + wcr_depth) + else: + wcr_depth_map[get_uuid(node, state)] = wcr_depth + # We do not need to propagate the wcr_depth to MapExits, since else this will result in depth N + 1 for Maps of range N. + wcr_depth = wcr_depth if not isinstance(oedge.dst, nd.MapExit) else sp.sympify(0) + + # only append if it's actually new information + # this e.g. helps for huge nested SDFGs with lots of inputs/outputs inside a map scope + append = True + for n, d, _ in traversal_q: + if oedge.dst == n and depth_map[node] + wcr_depth == d: + append = False + break + if append: + traversal_q.append((oedge.dst, depth_map[node] + wcr_depth, oedge)) + else: + visited.add(oedge) if len(out_edges) == 0 or node == scope_exit: # We have reached an end node --> update max_depth max_depth = sp.Max(max_depth, depth_map[node]) + for uuid in wcr_depth_map: + w_d_map[uuid] = (w_d_map[uuid][0], w_d_map[uuid][1] + wcr_depth_map[uuid]) # summarise work / depth of the whole scope in the dictionary - scope_result = (sp.simplify(work), sp.simplify(max_depth)) + scope_result = (work, max_depth) w_d_map[get_uuid(state)] = scope_result return scope_result -def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet, - symbols) -> Tuple[sp.Expr, sp.Expr]: +def state_work_depth(state: SDFGState, + w_d_map: Dict[str, sp.Expr], + analyze_tasklet, + symbols, + equality_subs, + subs1, + detailed_analysis=False) -> Tuple[sp.Expr, sp.Expr]: """ Analyze the work and depth of a state. @@ -554,13 +689,23 @@ def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_task :param w_d_map: The result will be saved to this map. :param analyze_tasklet: Function used to analyze tasklet nodes. :param symbols: A dictionary mapping local nested SDFG symbols to global symbols. + :param detailed_analysis: If True, detailed analysis gets used. For each branch, we keep track of its condition + and work depth values for both branches. If False, the worst-case branch is taken. Discouraged to use on bigger SDFGs, + as computation time sky-rockets, since expression can became HUGE (depending on number of branches etc.). + :param equality_subs: Substitution dict taking care of the equality assumptions. + :param subs1: First substitution dict for greater/lesser assumptions. :return: A tuple containing the work and depth of the state. """ - work, depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, None) + work, depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, equality_subs, subs1, None, + detailed_analysis) return work, depth -def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet) -> None: +def analyze_sdfg(sdfg: SDFG, + w_d_map: Dict[str, sp.Expr], + analyze_tasklet, + assumptions: [str], + detailed_analysis: bool = False) -> None: """ Analyze a given SDFG. We can either analyze work, work and depth or average parallelism. @@ -568,12 +713,24 @@ def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet) -> No condition and an assignment. :param sdfg: The SDFG to analyze. :param w_d_map: Dictionary of SDFG elements to (work, depth) tuples. Result will be saved in here. - :param analyze_tasklet: The function used to analyze tasklet nodes. Analyzes either just work, work and depth or average parallelism. + :param analyze_tasklet: Function used to analyze tasklet nodes. Analyzes either just work, work and depth or average parallelism. + :param assumptions: List of strings. Each string corresponds to one assumption for some symbol, e.g. 'N>5'. + :param detailed_analysis: If True, detailed analysis gets used. For each branch, we keep track of its condition + and work depth values for both branches. If False, the worst-case branch is taken. Discouraged to use on bigger SDFGs, + as computation time sky-rockets, since expression can became HUGE (depending on number of branches etc.). """ # deepcopy such that original sdfg not changed sdfg = deepcopy(sdfg) + # apply SSA pass + pipeline = FixedPointPipeline([StrictSymbolSSA()]) + pipeline.apply_pass(sdfg, {}) + + array_symbols = get_array_size_symbols(sdfg) + # parse assumptions + equality_subs, all_subs = parse_assumptions(assumptions if assumptions is not None else [], array_symbols) + # Run state propagation for all SDFGs recursively. This is necessary to determine the number of times each state # will be executed, or to determine upper bounds for that number (such as in the case of branching) for sd in sdfg.all_sdfgs_recursive(): @@ -581,17 +738,36 @@ def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet) -> No # Analyze the work and depth of the SDFG. symbols = {} - sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, symbols) + sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, symbols, equality_subs, all_subs[0][0] if len(all_subs) > 0 else {}, + detailed_analysis) - # Note: This posify could be done more often to improve performance. - array_symbols = get_array_size_symbols(sdfg) for k, (v_w, v_d) in w_d_map.items(): # The symeval replaces nested SDFG symbols with their global counterparts. - v_w = posify_certain_symbols(symeval(v_w, symbols), array_symbols) - v_d = posify_certain_symbols(symeval(v_d, symbols), array_symbols) + v_w, v_d = do_subs(v_w, v_d, all_subs) + v_w = symeval(v_w, symbols) + v_d = symeval(v_d, symbols) w_d_map[k] = (v_w, v_d) +def do_subs(work, depth, all_subs): + """ + Handles all substitutions beyond the equality substitutions and the first substitution. + :param work: Some work expression. + :param depth: Some depth expression. + :param all_subs: List of substitution pairs to perform. + :return: Work depth expressions after doing all substitutions. + """ + # first do subs2 of first sub + # then do all the remaining subs + subs2 = all_subs[0][1] if len(all_subs) > 0 else {} + work, depth = sp.simplify(sp.sympify(work).subs(subs2)), sp.simplify(sp.sympify(depth).subs(subs2)) + for i in range(1, len(all_subs)): + subs1, subs2 = all_subs[i] + work, depth = sp.simplify(work.subs(subs1)), sp.simplify(depth.subs(subs1)) + work, depth = sp.simplify(work.subs(subs2)), sp.simplify(depth.subs(subs2)) + return work, depth + + ################################################################################ # Utility functions for running the analysis from the command line ############# ################################################################################ @@ -608,7 +784,9 @@ def main() -> None: choices=['work', 'workDepth', 'avgPar'], default='workDepth', help='Choose what to analyze. Default: workDepth') + parser.add_argument('--assume', nargs='*', help='Collect assumptions about symbols, e.g. x>0 x>y y==5') + parser.add_argument("--detailed", action="store_true", help="Turns on detailed mode.") args = parser.parse_args() if not os.path.exists(args.filename): @@ -624,7 +802,7 @@ def main() -> None: sdfg = SDFG.from_file(args.filename) work_depth_map = {} - analyze_sdfg(sdfg, work_depth_map, analyze_tasklet) + analyze_sdfg(sdfg, work_depth_map, analyze_tasklet, args.assume, args.detailed) if args.analyze == 'workDepth': for k, v, in work_depth_map.items(): diff --git a/tests/sdfg/work_depth_tests.py b/tests/sdfg/work_depth_tests.py index 133afe8ae4..05375007df 100644 --- a/tests/sdfg/work_depth_tests.py +++ b/tests/sdfg/work_depth_tests.py @@ -1,14 +1,18 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ Contains test cases for the work depth analysis. """ import dace as dc -from dace.sdfg.work_depth_analysis.work_depth import analyze_sdfg, get_tasklet_work_depth +from dace.sdfg.work_depth_analysis.work_depth import analyze_sdfg, get_tasklet_work_depth, parse_assumptions from dace.sdfg.work_depth_analysis.helpers import get_uuid +from dace.sdfg.work_depth_analysis.assumptions import ContradictingAssumptions import sympy as sp from dace.transformation.interstate import NestSDFG from dace.transformation.dataflow import MapExpansion +from pytest import raises + # TODO: add tests for library nodes (e.g. reduce, matMul) +# TODO: add tests for average parallelism N = dc.symbol('N') M = dc.symbol('M') @@ -65,11 +69,11 @@ def nested_for_loops(x: dc.float64[N], y: dc.float64[K]): @dc.program def nested_if_else(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], sum: dc.int64[1]): if x[10] > 50: - if x[9] > 50: + if x[9] > 40: z[:] = x + y # N work, 1 depth z[:] += 2 * x # 2*N work, 2 depth --> total outer if: 3*N work, 3 depth else: - if y[9] > 50: + if y[9] > 30: for i in range(K): sum += x[i] # K work, K depth else: @@ -153,6 +157,22 @@ def break_while_loop(x: dc.float64[N]): x += 1 +@dc.program +def sequntial_ifs(x: dc.float64[N + 1], y: dc.float64[M + 1]): # --> cannot assume N, M to be positive + if x[0] > 5: + x[:] += 1 # N+1 work, 1 depth + else: + for i in range(M): # M work, M depth + y[i + 1] += y[i] + if M > N: + y[:N + 1] += x[:] # N+1 work, 1 depth + else: + x[:M + 1] += y[:] # M+1 work, 1 depth + # --> Work: Max(N+1, M) + Max(N+1, M+1) + # Depth: Max(1, M) + 1 + + +#(sdfg, (expected_work, expected_depth)) tests_cases = [ (single_map, (N, 1)), (single_for_loop, (N, N)), @@ -164,25 +184,18 @@ def break_while_loop(x: dc.float64[N]): (nested_if_else, (sp.Max(K, 3 * N, M + N), sp.Max(3, K, M + 1))), (max_of_positive_symbol, (3 * N**2, 3 * N)), (multiple_array_sizes, (sp.Max(2 * K, 3 * N, 2 * M + 3), 5)), - (unbounded_while_do, (sp.Symbol('num_execs_0_2', nonnegative=True) * N, sp.Symbol('num_execs_0_2', - nonnegative=True))), + (unbounded_while_do, (sp.Symbol('num_execs_0_2') * N, sp.Symbol('num_execs_0_2'))), # We get this Max(1, num_execs), since it is a do-while loop, but the num_execs symbol does not capture this. - (unbounded_do_while, (sp.Max(1, sp.Symbol('num_execs_0_1', nonnegative=True)) * N, - sp.Max(1, sp.Symbol('num_execs_0_1', nonnegative=True)))), - (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_7', nonnegative=True) * N, - 2 * sp.Symbol('num_execs_0_7', nonnegative=True))), - (continue_for_loop, (sp.Symbol('num_execs_0_6', nonnegative=True) * N, sp.Symbol('num_execs_0_6', - nonnegative=True))), + (unbounded_do_while, (sp.Max(1, sp.Symbol('num_execs_0_1')) * N, sp.Max(1, sp.Symbol('num_execs_0_1')))), + (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_7') * N, 2 * sp.Symbol('num_execs_0_7'))), + (continue_for_loop, (sp.Symbol('num_execs_0_6') * N, sp.Symbol('num_execs_0_6'))), (break_for_loop, (N**2, N)), - (break_while_loop, (sp.Symbol('num_execs_0_5', nonnegative=True) * N, sp.Symbol('num_execs_0_5', nonnegative=True))) + (break_while_loop, (sp.Symbol('num_execs_0_5') * N, sp.Symbol('num_execs_0_5'))), + (sequntial_ifs, (sp.Max(N + 1, M) + sp.Max(N + 1, M + 1), sp.Max(1, M) + 1)) ] def test_work_depth(): - good = 0 - failed = 0 - exception = 0 - failed_tests = [] for test, correct in tests_cases: w_d_map = {} sdfg = test.to_sdfg() @@ -190,12 +203,60 @@ def test_work_depth(): sdfg.apply_transformations(NestSDFG) if 'nested_maps' in test.name: sdfg.apply_transformations(MapExpansion) - - analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth) + analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth, [], False) res = w_d_map[get_uuid(sdfg)] + # substitue each symbol without assumptions. + # We do this since sp.Symbol('N') == Sp.Symbol('N', positive=True) --> False. + reps = {s: sp.Symbol(s.name) for s in (res[0].free_symbols | res[1].free_symbols)} + res = (res[0].subs(reps), res[1].subs(reps)) + reps = { + s: sp.Symbol(s.name) + for s in (sp.sympify(correct[0]).free_symbols | sp.sympify(correct[1]).free_symbols) + } + correct = (sp.sympify(correct[0]).subs(reps), sp.sympify(correct[1]).subs(reps)) # check result assert correct == res +x, y, z, a = sp.symbols('x y z a') + +# (expr, assumptions, result) +assumptions_tests = [ + (sp.Max(x, y), ['x>y'], x), (sp.Max(x, y, z), ['x>y'], sp.Max(x, z)), (sp.Max(x, y), ['x==y'], y), + (sp.Max(x, 11) + sp.Max(x, 3), ['x<11'], 11 + sp.Max(x, 3)), (sp.Max(x, 11) + sp.Max(x, 3), ['x<11', + 'x>3'], 11 + x), + (sp.Max(x, 11), ['x>5', 'x>3', 'x>11'], x), (sp.Max(x, 11), ['x==y', 'x>11'], y), + (sp.Max(x, 11) + sp.Max(a, 5), ['a==b', 'b==c', 'c==x', 'a<11', 'c>7'], x + 11), + (sp.Max(x, 11) + sp.Max(a, 5), ['a==b', 'b==c', 'c==x', 'b==7'], 18), (sp.Max(x, y), ['y>x', 'y==1000'], 1000), + (sp.Max(x, y), ['y0', 'N<5', 'M>5'], M) +] + +# These assumptions should trigger the ContradictingAssumptions exception. +tests_for_exception = [['x>10', 'x<9'], ['x==y', 'x>10', 'y<9'], + ['a==b', 'b==c', 'c==d', 'd==e', 'e==f', 'x==y', 'y==z', 'z>b', 'x==5', 'd==100'], + ['x==5', 'x<4']] + + +def test_assumption_system(): + for expr, assums, res in assumptions_tests: + equality_subs, all_subs = parse_assumptions(assums, set()) + initial_expr = expr + expr = expr.subs(equality_subs[0]) + expr = expr.subs(equality_subs[1]) + for subs1, subs2 in all_subs: + expr = expr.subs(subs1) + expr = expr.subs(subs2) + assert expr == res + + for assums in tests_for_exception: + # check that the Exception gets raised. + with raises(ContradictingAssumptions): + parse_assumptions(assums, set()) + + if __name__ == '__main__': test_work_depth() + test_assumption_system()