From 32c16a5b489689d1ef00a4fa89df5223b73c1d76 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 29 Oct 2024 13:04:16 +0100 Subject: [PATCH] Re-implement changes made by @luca-patrignani in #1541 --- dace/config_schema.yml | 7 + dace/sdfg/validation.py | 51 ++- .../sdfg/warn_on_potential_data_race_test.py | 316 ++++++++++++++++++ 3 files changed, 367 insertions(+), 7 deletions(-) create mode 100644 tests/sdfg/warn_on_potential_data_race_test.py diff --git a/dace/config_schema.yml b/dace/config_schema.yml index da35e61997..7afb06a50a 100644 --- a/dace/config_schema.yml +++ b/dace/config_schema.yml @@ -919,6 +919,13 @@ required: description: > Check for undefined symbols in memlets during SDFG validation. + check_race_conditions: + type: bool + default: false + title: Check race conditions + description: > + Check for potential race conditions during validation. + ############################################# # Features for unit testing diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index e75099276f..f02a5003e9 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -1,17 +1,22 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Exception classes and methods for validation of SDFGs. """ + import copy -from dace.dtypes import DebugInfo import os -from typing import TYPE_CHECKING, Dict, List, Set import warnings +from collections import defaultdict +from typing import TYPE_CHECKING, Dict, List, Set + +import networkx as nx + from dace import dtypes, subsets, symbolic +from dace.dtypes import DebugInfo if TYPE_CHECKING: import dace + from dace.memlet import Memlet from dace.sdfg import SDFG from dace.sdfg import graph as gr - from dace.memlet import Memlet from dace.sdfg.state import ControlFlowRegion ########################################### @@ -34,8 +39,8 @@ def validate_control_flow_region(sdfg: 'SDFG', symbols: dict, references: Set[int] = None, **context: bool): - from dace.sdfg.state import SDFGState, ControlFlowRegion, ConditionalBlock from dace.sdfg.scope import is_in_scope + from dace.sdfg.state import ConditionalBlock, ControlFlowRegion, SDFGState if len(region.source_nodes()) > 1 and region.start_block is None: raise InvalidSDFGError("Starting block undefined", sdfg, None) @@ -200,7 +205,7 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context # Avoid import loop from dace import data as dt from dace.codegen.targets import fpga - from dace.sdfg.scope import is_devicelevel_gpu, is_devicelevel_fpga + from dace.sdfg.scope import is_devicelevel_fpga, is_devicelevel_gpu references = references or set() @@ -383,7 +388,8 @@ def validate_state(state: 'dace.sdfg.SDFGState', from dace.sdfg import SDFG from dace.sdfg import nodes as nd from dace.sdfg import utils as sdutil - from dace.sdfg.scope import scope_contains_scope, is_devicelevel_gpu, is_devicelevel_fpga + from dace.sdfg.scope import (is_devicelevel_fpga, is_devicelevel_gpu, + scope_contains_scope) sdfg = sdfg or state.parent state_id = state_id if state_id is not None else state.parent_graph.node_id(state) @@ -839,6 +845,37 @@ def validate_state(state: 'dace.sdfg.SDFGState', continue raise error + if Config.get_bool('experimental.check_race_conditions'): + node_labels = [] + write_accesses = defaultdict(list) + read_accesses = defaultdict(list) + for node in state.data_nodes(): + node_labels.append(node.label) + write_accesses[node.label].extend( + [{'subset': e.data.dst_subset, 'node': node, 'wcr': e.data.wcr} for e in state.in_edges(node)]) + read_accesses[node.label].extend( + [{'subset': e.data.src_subset, 'node': node} for e in state.out_edges(node)]) + + for node_label in node_labels: + writes = write_accesses[node_label] + reads = read_accesses[node_label] + # Check write-write data races. + for i in range(len(writes)): + for j in range(i+1, len(writes)): + same_or_unreachable_nodes = (writes[i]['node'] == writes[j]['node'] or + not nx.has_path(state.nx, writes[i]['node'], writes[j]['node'])) + no_wcr = writes[i]['wcr'] is None and writes[j]['wcr'] is None + if same_or_unreachable_nodes and no_wcr: + subsets_intersect = subsets.intersects(writes[i]['subset'], writes[j]['subset']) + if subsets_intersect: + warnings.warn(f'Memlet range overlap while writing to "{node}" in state "{state.label}"') + # Check read-write data races. + for write in writes: + for read in reads: + if (not nx.has_path(state.nx, read['node'], write['node']) and + subsets.intersects(write['subset'], read['subset'])): + warnings.warn(f'Memlet range overlap while writing to "{node}" in state "{state.label}"') + ######################################## diff --git a/tests/sdfg/warn_on_potential_data_race_test.py b/tests/sdfg/warn_on_potential_data_race_test.py new file mode 100644 index 0000000000..8f17409a2f --- /dev/null +++ b/tests/sdfg/warn_on_potential_data_race_test.py @@ -0,0 +1,316 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + +import warnings +import dace +import pytest + +def test_memlet_range_not_overlap_ranges(): + sdfg = dace.SDFG('memlet_range_not_overlap_ranges') + state = sdfg.add_state() + N = dace.symbol("N", dtype=dace.int32) + sdfg.add_array("A", (N//2,), dace.int32) + A = state.add_access("A") + sdfg.add_array("B", (N,), dace.int32) + B = state.add_access("B") + state.add_mapped_tasklet( + name="first_tasklet", + code="b = a + 10", + inputs={"a": dace.Memlet(data="A", subset="k")}, + outputs={"b": dace.Memlet(data="B", subset="k")}, + map_ranges={"k": "0:N//2"}, + external_edges=True, + input_nodes={"A": A}, + output_nodes={"B": B} + ) + state.add_mapped_tasklet( + name="second_tasklet", + code="b = a - 20", + inputs={"a": dace.Memlet(data="A", subset="k")}, + outputs={"b": dace.Memlet(data="B", subset="k+N//2")}, + map_ranges={"k": "0:N//2"}, + external_edges=True, + input_nodes={"A": A}, + output_nodes={"B": B} + ) + + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + with dace.config.set_temporary("experimental.check_race_conditions", value=True): + sdfg.validate() + + +def test_memlet_range_write_write_overlap_ranges(): + sdfg = dace.SDFG('memlet_range_overlap_ranges') + state = sdfg.add_state() + N = dace.symbol("N", dtype=dace.int32) + sdfg.add_array("A", (N,), dace.int32) + A = state.add_access("A") + sdfg.add_array("B", (N,), dace.int32) + B = state.add_access("B") + state.add_mapped_tasklet( + name="first_tasklet", + code="b = a + 10", + inputs={"a": dace.Memlet(data="A", subset="k")}, + outputs={"b": dace.Memlet(data="B", subset="k")}, + map_ranges={"k": "0:N"}, + external_edges=True, + input_nodes={"A": A}, + output_nodes={"B": B} + ) + state.add_mapped_tasklet( + name="second_tasklet", + code="b = a - 20", + inputs={"a": dace.Memlet(data="A", subset="k")}, + outputs={"b": dace.Memlet(data="B", subset="k")}, + map_ranges={"k": "0:N"}, + external_edges=True, + input_nodes={"A": A}, + output_nodes={"B": B} + ) + + with pytest.warns(UserWarning): + with dace.config.set_temporary("experimental.check_race_conditions", value=True): + sdfg.validate() + +def test_memlet_range_write_read_overlap_ranges(): + sdfg = dace.SDFG('memlet_range_write_read_overlap_ranges') + state = sdfg.add_state() + N = dace.symbol("N", dtype=dace.int32) + sdfg.add_array("A", (N,), dace.int32) + A_read = state.add_read("A") + A_write = state.add_write("A") + sdfg.add_array("B", (N,), dace.int32) + B = state.add_access("B") + sdfg.add_array("C", (N,), dace.int32) + C = state.add_access("C") + state.add_mapped_tasklet( + name="first_tasklet", + code="b = a + 10", + inputs={"a": dace.Memlet(data="A", subset="k")}, + outputs={"b": dace.Memlet(data="B", subset="k")}, + map_ranges={"k": "0:N"}, + external_edges=True, + input_nodes={"A": A_read}, + output_nodes={"B": B} + ) + state.add_mapped_tasklet( + name="second_tasklet", + code="a = c - 20", + inputs={"c": dace.Memlet(data="C", subset="k")}, + outputs={"a": dace.Memlet(data="A", subset="k")}, + map_ranges={"k": "0:N"}, + external_edges=True, + input_nodes={"C": C}, + output_nodes={"A": A_write} + ) + + with pytest.warns(UserWarning): + with dace.config.set_temporary('experimental', 'check_race_conditions', value=True): + sdfg.validate() + +def test_memlet_overlap_ranges_two_access_nodes(): + sdfg = dace.SDFG('memlet_range_write_read_overlap_ranges') + state = sdfg.add_state() + N = dace.symbol("N", dtype=dace.int32) + sdfg.add_array("A", (N,), dace.int32) + A1 = state.add_access("A") + A2 = state.add_access("A") + sdfg.add_array("B", (N,), dace.int32) + B1 = state.add_access("B") + B2 = state.add_access("B") + + state.add_mapped_tasklet( + name="first_tasklet", + code="b = a + 10", + inputs={"a": dace.Memlet(data="A", subset="k")}, + outputs={"b": dace.Memlet(data="B", subset="k")}, + map_ranges={"k": "0:N"}, + external_edges=True, + input_nodes={"A": A1}, + output_nodes={"B": B1} + ) + state.add_mapped_tasklet( + name="second_tasklet", + code="b = a - 20", + inputs={"a": dace.Memlet(data="A", subset="k")}, + outputs={"b": dace.Memlet(data="B", subset="k")}, + map_ranges={"k": "0:N"}, + external_edges=True, + input_nodes={"A": A2}, + output_nodes={"B": B2} + ) + + with pytest.warns(UserWarning): + with dace.config.set_temporary('experimental', 'check_race_conditions', value=True): + sdfg.validate() + +def test_memlet_overlap_symbolic_ranges(): + sdfg = dace.SDFG('memlet_overlap_symbolic_ranges') + state = sdfg.add_state() + N = dace.symbol("N", dtype=dace.int32) + sdfg.add_array("A", (2*N,), dace.int32) + A = state.add_access("A") + sdfg.add_array("B", (2*N,), dace.int32) + B = state.add_access("B") + + state.add_mapped_tasklet( + name="first_tasklet", + code="b = a + 10", + inputs={"a": dace.Memlet(data="A", subset="k")}, + outputs={"b": dace.Memlet(data="B", subset="k")}, + map_ranges={"k": "0:N"}, + external_edges=True, + input_nodes={"A": A}, + output_nodes={"B": B} + ) + state.add_mapped_tasklet( + name="second_tasklet", + code="b = a - 20", + inputs={"a": dace.Memlet(data="A", subset="k")}, + outputs={"b": dace.Memlet(data="B", subset="k")}, + map_ranges={"k": "0:2*N"}, + external_edges=True, + input_nodes={"A": A}, + output_nodes={"B": B} + ) + + with pytest.warns(UserWarning): + with dace.config.set_temporary('experimental', 'check_race_conditions', value=True): + sdfg.validate() + +def test_constant_memlet_overlap(): + sdfg = dace.SDFG('constant_memlet_overlap') + state = sdfg.add_state() + sdfg.add_array("A", (12,), dace.int32) + A = state.add_access("A") + sdfg.add_array("B", (12,), dace.int32) + B = state.add_access("B") + + state.add_mapped_tasklet( + name="first_tasklet", + code="b = a + 10", + inputs={"a": dace.Memlet(data="A", subset="k")}, + outputs={"b": dace.Memlet(data="B", subset="k")}, + map_ranges={"k": "3:10"}, + external_edges=True, + input_nodes={"A": A}, + output_nodes={"B": B} + ) + state.add_mapped_tasklet( + name="second_tasklet", + code="b = a - 20", + inputs={"a": dace.Memlet(data="A", subset="k")}, + outputs={"b": dace.Memlet(data="B", subset="k")}, + map_ranges={"k": "6:12"}, + external_edges=True, + input_nodes={"A": A}, + output_nodes={"B": B} + ) + + with pytest.warns(UserWarning): + with dace.config.set_temporary('experimental', 'check_race_conditions', value=True): + sdfg.validate() + +def test_constant_memlet_almost_overlap(): + sdfg = dace.SDFG('constant_memlet_almost_overlap') + state = sdfg.add_state() + sdfg.add_array("A", (20,), dace.int32) + A = state.add_access("A") + sdfg.add_array("B", (20,), dace.int32) + B = state.add_access("B") + + state.add_mapped_tasklet( + name="first_tasklet", + code="b = a + 10", + inputs={"a": dace.Memlet(data="A", subset="k")}, + outputs={"b": dace.Memlet(data="B", subset="k")}, + map_ranges={"k": "3:10"}, + external_edges=True, + input_nodes={"A": A}, + output_nodes={"B": B} + ) + state.add_mapped_tasklet( + name="second_tasklet", + code="b = a - 20", + inputs={"a": dace.Memlet(data="A", subset="k")}, + outputs={"b": dace.Memlet(data="B", subset="k")}, + map_ranges={"k": "10:20"}, + external_edges=True, + input_nodes={"A": A}, + output_nodes={"B": B} + ) + + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + with dace.config.set_temporary('experimental', 'check_race_conditions', value=True): + sdfg.validate() + +def test_elementwise_map(): + sdfg = dace.SDFG('elementwise_map') + state = sdfg.add_state() + sdfg.add_array("A", (20,), dace.int32) + A_read = state.add_read("A") + A_write = state.add_write("A") + + state.add_mapped_tasklet( + name="first_tasklet", + code="aa = a + 10", + inputs={"a": dace.Memlet(data="A", subset="k")}, + outputs={"aa": dace.Memlet(data="A", subset="k")}, + map_ranges={"k": "0:20"}, + external_edges=True, + input_nodes={"A": A_read}, + output_nodes={"A": A_write} + ) + + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + with dace.config.set_temporary('experimental', 'check_race_conditions', value=True): + sdfg.validate() + +def test_memlet_overlap_with_wcr(): + sdfg = dace.SDFG('memlet_overlap_with_wcr') + state = sdfg.add_state() + sdfg.add_array("A", (20,), dace.int32) + sdfg.add_array("B", (1,), dace.int32) + A = state.add_read("A") + B = state.add_write("B") + + state.add_mapped_tasklet( + name="first_reduction", + code="b = a", + inputs={"a": dace.Memlet(data="A", subset="k")}, + outputs={"b": dace.Memlet(data="B", subset="0", wcr="lambda old, new: old + new")}, + map_ranges={"k": "0:20"}, + external_edges=True, + input_nodes={"A": A}, + output_nodes={"B": B} + ) + + state.add_mapped_tasklet( + name="second_reduction", + code="b = a", + inputs={"a": dace.Memlet(data="A", subset="k")}, + outputs={"b": dace.Memlet(data="B", subset="0", wcr="lambda old, new: old + new")}, + map_ranges={"k": "0:20"}, + external_edges=True, + input_nodes={"A": A}, + output_nodes={"B": B} + ) + + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + with dace.config.set_temporary('experimental', 'check_race_conditions', value=True): + sdfg.validate() + + +if __name__ == '__main__': + test_memlet_range_not_overlap_ranges() + test_memlet_range_write_write_overlap_ranges() + test_memlet_range_write_read_overlap_ranges() + test_memlet_overlap_ranges_two_access_nodes() + test_memlet_overlap_symbolic_ranges() + test_constant_memlet_overlap() + test_constant_memlet_almost_overlap() + test_elementwise_map() + test_memlet_overlap_with_wcr()