Skip to content

Commit

Permalink
Re-implement changes made by @luca-patrignani in #1541
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Oct 29, 2024
1 parent 2070d39 commit 32c16a5
Show file tree
Hide file tree
Showing 3 changed files with 367 additions and 7 deletions.
7 changes: 7 additions & 0 deletions dace/config_schema.yml
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,13 @@ required:
description: >
Check for undefined symbols in memlets during SDFG validation.
check_race_conditions:
type: bool
default: false
title: Check race conditions
description: >
Check for potential race conditions during validation.
#############################################
# Features for unit testing

Expand Down
51 changes: 44 additions & 7 deletions dace/sdfg/validation.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
""" Exception classes and methods for validation of SDFGs. """

import copy
from dace.dtypes import DebugInfo
import os
from typing import TYPE_CHECKING, Dict, List, Set
import warnings
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Set

import networkx as nx

from dace import dtypes, subsets, symbolic
from dace.dtypes import DebugInfo

if TYPE_CHECKING:
import dace
from dace.memlet import Memlet
from dace.sdfg import SDFG
from dace.sdfg import graph as gr
from dace.memlet import Memlet
from dace.sdfg.state import ControlFlowRegion

###########################################
Expand All @@ -34,8 +39,8 @@ def validate_control_flow_region(sdfg: 'SDFG',
symbols: dict,
references: Set[int] = None,
**context: bool):
from dace.sdfg.state import SDFGState, ControlFlowRegion, ConditionalBlock
from dace.sdfg.scope import is_in_scope
from dace.sdfg.state import ConditionalBlock, ControlFlowRegion, SDFGState

if len(region.source_nodes()) > 1 and region.start_block is None:
raise InvalidSDFGError("Starting block undefined", sdfg, None)
Expand Down Expand Up @@ -200,7 +205,7 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context
# Avoid import loop
from dace import data as dt
from dace.codegen.targets import fpga
from dace.sdfg.scope import is_devicelevel_gpu, is_devicelevel_fpga
from dace.sdfg.scope import is_devicelevel_fpga, is_devicelevel_gpu

references = references or set()

Expand Down Expand Up @@ -383,7 +388,8 @@ def validate_state(state: 'dace.sdfg.SDFGState',
from dace.sdfg import SDFG
from dace.sdfg import nodes as nd
from dace.sdfg import utils as sdutil
from dace.sdfg.scope import scope_contains_scope, is_devicelevel_gpu, is_devicelevel_fpga
from dace.sdfg.scope import (is_devicelevel_fpga, is_devicelevel_gpu,
scope_contains_scope)

sdfg = sdfg or state.parent
state_id = state_id if state_id is not None else state.parent_graph.node_id(state)
Expand Down Expand Up @@ -839,6 +845,37 @@ def validate_state(state: 'dace.sdfg.SDFGState',
continue
raise error

if Config.get_bool('experimental.check_race_conditions'):
node_labels = []
write_accesses = defaultdict(list)
read_accesses = defaultdict(list)
for node in state.data_nodes():
node_labels.append(node.label)
write_accesses[node.label].extend(
[{'subset': e.data.dst_subset, 'node': node, 'wcr': e.data.wcr} for e in state.in_edges(node)])
read_accesses[node.label].extend(
[{'subset': e.data.src_subset, 'node': node} for e in state.out_edges(node)])

for node_label in node_labels:
writes = write_accesses[node_label]
reads = read_accesses[node_label]
# Check write-write data races.
for i in range(len(writes)):
for j in range(i+1, len(writes)):
same_or_unreachable_nodes = (writes[i]['node'] == writes[j]['node'] or
not nx.has_path(state.nx, writes[i]['node'], writes[j]['node']))
no_wcr = writes[i]['wcr'] is None and writes[j]['wcr'] is None
if same_or_unreachable_nodes and no_wcr:
subsets_intersect = subsets.intersects(writes[i]['subset'], writes[j]['subset'])
if subsets_intersect:
warnings.warn(f'Memlet range overlap while writing to "{node}" in state "{state.label}"')
# Check read-write data races.
for write in writes:
for read in reads:
if (not nx.has_path(state.nx, read['node'], write['node']) and
subsets.intersects(write['subset'], read['subset'])):
warnings.warn(f'Memlet range overlap while writing to "{node}" in state "{state.label}"')

########################################


Expand Down
Loading

0 comments on commit 32c16a5

Please sign in to comment.