diff --git a/dace/cli/sdfv.py b/dace/cli/sdfv.py index 3be8e1ca45..f503775814 100644 --- a/dace/cli/sdfv.py +++ b/dace/cli/sdfv.py @@ -36,10 +36,15 @@ def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None): """ # If vscode is open, try to open it inside vscode if filename is None: - if 'VSCODE_IPC_HOOK_CLI' in os.environ or 'VSCODE_GIT_IPC_HANDLE' in os.environ: - filename = tempfile.mktemp(suffix='.sdfg') + if ( + 'VSCODE_IPC_HOOK' in os.environ + or 'VSCODE_IPC_HOOK_CLI' in os.environ + or 'VSCODE_GIT_IPC_HANDLE' in os.environ + ): + fd, filename = tempfile.mkstemp(suffix='.sdfg') sdfg.save(filename) os.system(f'code {filename}') + os.close(fd) return if type(sdfg) is dace.SDFG: diff --git a/dace/codegen/compiled_sdfg.py b/dace/codegen/compiled_sdfg.py index e06e2b6a02..e51ee16c2f 100644 --- a/dace/codegen/compiled_sdfg.py +++ b/dace/codegen/compiled_sdfg.py @@ -239,6 +239,7 @@ def get_state_struct(self) -> ctypes.Structure: return ctypes.cast(self._libhandle, ctypes.POINTER(self._try_parse_state_struct())).contents def _try_parse_state_struct(self) -> Optional[Type[ctypes.Structure]]: + from dace.codegen.targets.cpp import mangle_dace_state_struct_name # Avoid import cycle # the path of the main sdfg file containing the state struct main_src_path = os.path.join(os.path.dirname(os.path.dirname(self._lib._library_filename)), "src", "cpu", self._sdfg.name + ".cpp") @@ -247,7 +248,7 @@ def _try_parse_state_struct(self) -> Optional[Type[ctypes.Structure]]: code_flat = code.replace("\n", " ") # try to find the first struct definition that matches the name we are looking for in the sdfg file - match = re.search(f"struct {self._sdfg.name}_t {{(.*?)}};", code_flat) + match = re.search(f"struct {mangle_dace_state_struct_name(self._sdfg)} {{(.*?)}};", code_flat) if match is None or len(match.groups()) != 1: return None diff --git a/dace/codegen/instrumentation/data/data_dump.py b/dace/codegen/instrumentation/data/data_dump.py index 859f78bd79..2217524d19 100644 --- a/dace/codegen/instrumentation/data/data_dump.py +++ b/dace/codegen/instrumentation/data/data_dump.py @@ -195,7 +195,7 @@ def __init__(self): def _generate_report_setter(self, sdfg: SDFG) -> str: return f''' - DACE_EXPORTED void __dace_set_instrumented_data_report({sdfg.name}_t *__state, const char *dirpath) {{ + DACE_EXPORTED void __dace_set_instrumented_data_report({cpp.mangle_dace_state_struct_name(sdfg)} *__state, const char *dirpath) {{ __state->serializer->set_folder(dirpath); }} ''' diff --git a/dace/codegen/instrumentation/papi.py b/dace/codegen/instrumentation/papi.py index c0d3b657a1..4885611408 100644 --- a/dace/codegen/instrumentation/papi.py +++ b/dace/codegen/instrumentation/papi.py @@ -12,7 +12,7 @@ from dace.sdfg.graph import SubgraphView from dace.memlet import Memlet from dace.sdfg import scope_contains_scope -from dace.sdfg.state import StateGraphView +from dace.sdfg.state import DataflowGraphView import sympy as sp import os @@ -392,7 +392,7 @@ def should_instrument_entry(map_entry: EntryNode) -> bool: return cond @staticmethod - def has_surrounding_perfcounters(node, dfg: StateGraphView): + def has_surrounding_perfcounters(node, dfg: DataflowGraphView): """ Returns true if there is a possibility that this node is part of a section that is profiled. """ parent = dfg.entry_node(node) @@ -605,7 +605,7 @@ def get_memlet_byte_size(sdfg: dace.SDFG, memlet: Memlet): return memlet.volume * memdata.dtype.bytes @staticmethod - def get_out_memlet_costs(sdfg: dace.SDFG, state_id: int, node: nodes.Node, dfg: StateGraphView): + def get_out_memlet_costs(sdfg: dace.SDFG, state_id: int, node: nodes.Node, dfg: DataflowGraphView): scope_dict = sdfg.node(state_id).scope_dict() out_costs = 0 @@ -636,7 +636,10 @@ def get_out_memlet_costs(sdfg: dace.SDFG, state_id: int, node: nodes.Node, dfg: return out_costs @staticmethod - def get_tasklet_byte_accesses(tasklet: nodes.CodeNode, dfg: StateGraphView, sdfg: dace.SDFG, state_id: int) -> str: + def get_tasklet_byte_accesses(tasklet: nodes.CodeNode, + dfg: DataflowGraphView, + sdfg: dace.SDFG, + state_id: int) -> str: """ Get the amount of bytes processed by `tasklet`. The formula is sum(inedges * size) + sum(outedges * size) """ in_accum = [] @@ -693,7 +696,7 @@ def get_memory_input_size(node, sdfg, state_id) -> str: return sym2cpp(input_size) @staticmethod - def accumulate_byte_movement(outermost_node, node, dfg: StateGraphView, sdfg, state_id): + def accumulate_byte_movement(outermost_node, node, dfg: DataflowGraphView, sdfg, state_id): itvars = dict() # initialize an empty dict diff --git a/dace/codegen/targets/cpp.py b/dace/codegen/targets/cpp.py index 94159d12d5..f3f1424297 100644 --- a/dace/codegen/targets/cpp.py +++ b/dace/codegen/targets/cpp.py @@ -34,6 +34,22 @@ from dace.codegen.dispatcher import TargetDispatcher +def mangle_dace_state_struct_name(sdfg: Union[SDFG, str]) -> str: + """This function creates a unique type name for the `SDFG`'s state `struct`. + + The function uses the `compiler.codegen_state_struct_suffix` + configuration entry for deriving the type name of the state `struct`. + + :param sdfg: The SDFG for which the name should be generated. + """ + name = sdfg if isinstance(sdfg, str) else sdfg.name + state_suffix = Config.get("compiler", "codegen_state_struct_suffix") + type_name = f"{name}{state_suffix}" + if not dtypes.validate_name(type_name): + raise ValueError(f"The mangled type name `{type_name}` of the state struct of SDFG '{name}' is invalid.") + return type_name + + def copy_expr( dispatcher, sdfg, diff --git a/dace/codegen/targets/cpu.py b/dace/codegen/targets/cpu.py index 268c04b693..bea848c905 100644 --- a/dace/codegen/targets/cpu.py +++ b/dace/codegen/targets/cpu.py @@ -1532,7 +1532,7 @@ def generate_nsdfg_header(self, sdfg, state, state_id, node, memlet_references, if state_struct: toplevel_sdfg: SDFG = sdfg.sdfg_list[0] - arguments.append(f'{toplevel_sdfg.name}_t *__state') + arguments.append(f'{cpp.mangle_dace_state_struct_name(toplevel_sdfg)} *__state') # Add "__restrict__" keywords to arguments that do not alias with others in the context of this SDFG restrict_args = [] diff --git a/dace/codegen/targets/cuda.py b/dace/codegen/targets/cuda.py index a465d2bbc0..b729b34088 100644 --- a/dace/codegen/targets/cuda.py +++ b/dace/codegen/targets/cuda.py @@ -1,11 +1,8 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import ast -import copy import ctypes import functools -import os import warnings -from typing import Any, Dict, List, Set, Tuple, Union +from typing import Dict, List, Set, Tuple, Union import networkx as nx import sympy @@ -14,7 +11,6 @@ import dace from dace import data as dt from dace import dtypes, registry -from dace import sdfg as sd from dace import subsets, symbolic from dace.codegen import common, cppunparse from dace.codegen.codeobject import CodeObject @@ -23,7 +19,7 @@ from dace.codegen.targets import cpp from dace.codegen.common import update_persistent_desc from dace.codegen.targets.cpp import (codeblock_to_cpp, cpp_array_expr, memlet_copy_to_absolute_strides, sym2cpp, - synchronize_streams, unparse_cr, unparse_cr_split) + synchronize_streams, unparse_cr, mangle_dace_state_struct_name) from dace.codegen.targets.target import IllegalCopy, TargetCodeGenerator, make_absolute from dace.config import Config from dace.frontend import operations @@ -345,12 +341,12 @@ def get_generated_codeobjects(self): {file_header} -DACE_EXPORTED int __dace_init_cuda({sdfg.name}_t *__state{params}); -DACE_EXPORTED int __dace_exit_cuda({sdfg.name}_t *__state); +DACE_EXPORTED int __dace_init_cuda({sdfg_state_name} *__state{params}); +DACE_EXPORTED int __dace_exit_cuda({sdfg_state_name} *__state); {other_globalcode} -int __dace_init_cuda({sdfg.name}_t *__state{params}) {{ +int __dace_init_cuda({sdfg_state_name} *__state{params}) {{ int count; // Check that we are able to run {backend} code @@ -389,7 +385,7 @@ def get_generated_codeobjects(self): return 0; }} -int __dace_exit_cuda({sdfg.name}_t *__state) {{ +int __dace_exit_cuda({sdfg_state_name} *__state) {{ {exitcode} // Synchronize and check for CUDA errors @@ -409,7 +405,7 @@ def get_generated_codeobjects(self): return __err; }} -DACE_EXPORTED bool __dace_gpu_set_stream({sdfg.name}_t *__state, int streamid, gpuStream_t stream) +DACE_EXPORTED bool __dace_gpu_set_stream({sdfg_state_name} *__state, int streamid, gpuStream_t stream) {{ if (streamid < 0 || streamid >= {nstreams}) return false; @@ -419,7 +415,7 @@ def get_generated_codeobjects(self): return true; }} -DACE_EXPORTED void __dace_gpu_set_all_streams({sdfg.name}_t *__state, gpuStream_t stream) +DACE_EXPORTED void __dace_gpu_set_all_streams({sdfg_state_name} *__state, gpuStream_t stream) {{ for (int i = 0; i < {nstreams}; ++i) __state->gpu_context->streams[i] = stream; @@ -427,6 +423,7 @@ def get_generated_codeobjects(self): {localcode} """.format(params=params_comma, + sdfg_state_name=mangle_dace_state_struct_name(self._global_sdfg), initcode=initcode.getvalue(), exitcode=exitcode.getvalue(), other_globalcode=self._globalcode.getvalue(), @@ -445,7 +442,7 @@ def node_dispatch_predicate(self, sdfg, state, node): if hasattr(node, 'schedule'): # NOTE: Works on nodes and scopes if node.schedule in dtypes.GPU_SCHEDULES: return True - if isinstance(node, nodes.NestedSDFG) and CUDACodeGen._in_device_code: + if CUDACodeGen._in_device_code: return True return False @@ -1324,11 +1321,11 @@ def generate_devicelevel_state(self, sdfg, state, function_stream, callsite_stre if write_scope == 'grid': callsite_stream.write("if (blockIdx.x == 0 " - "&& threadIdx.x == 0) " - "{ // sub-graph begin", sdfg, state.node_id) + "&& threadIdx.x == 0) " + "{ // sub-graph begin", sdfg, state.node_id) elif write_scope == 'block': callsite_stream.write("if (threadIdx.x == 0) " - "{ // sub-graph begin", sdfg, state.node_id) + "{ // sub-graph begin", sdfg, state.node_id) else: callsite_stream.write("{ // subgraph begin", sdfg, state.node_id) else: @@ -1567,7 +1564,7 @@ def generate_scope(self, sdfg, dfg_scope, state_id, function_stream, callsite_st self.scope_entry_stream = old_entry_stream self.scope_exit_stream = old_exit_stream - state_param = [f'{self._global_sdfg.name}_t *__state'] + state_param = [f'{mangle_dace_state_struct_name(self._global_sdfg)} *__state'] # Write callback function definition self._localcode.write( @@ -2519,15 +2516,17 @@ def generate_devicelevel_scope(self, sdfg, dfg_scope, state_id, function_stream, def generate_node(self, sdfg, dfg, state_id, node, function_stream, callsite_stream): if self.node_dispatch_predicate(sdfg, dfg, node): # Dynamically obtain node generator according to class name - gen = getattr(self, '_generate_' + type(node).__name__) - gen(sdfg, dfg, state_id, node, function_stream, callsite_stream) - return + gen = getattr(self, '_generate_' + type(node).__name__, False) + if gen is not False: # Not every node type has a code generator here + gen(sdfg, dfg, state_id, node, function_stream, callsite_stream) + return if not CUDACodeGen._in_device_code: self._cpu_codegen.generate_node(sdfg, dfg, state_id, node, function_stream, callsite_stream) return - self._locals.clear_scope(self._code_state.indentation + 1) + if isinstance(node, nodes.ExitNode): + self._locals.clear_scope(self._code_state.indentation + 1) if CUDACodeGen._in_device_code and isinstance(node, nodes.MapExit): return # skip @@ -2591,6 +2590,78 @@ def _generate_MapExit(self, sdfg, dfg, state_id, node, function_stream, callsite self._cpu_codegen._generate_MapExit(sdfg, dfg, state_id, node, function_stream, callsite_stream) + def _get_thread_id(self) -> str: + result = 'threadIdx.x' + if self._block_dims[1] != 1: + result += f' + ({sym2cpp(self._block_dims[0])}) * threadIdx.y' + if self._block_dims[2] != 1: + result += f' + ({sym2cpp(self._block_dims[0] * self._block_dims[1])}) * threadIdx.z' + return result + + def _get_warp_id(self) -> str: + return f'(({self._get_thread_id()}) / warpSize)' + + def _get_block_id(self) -> str: + result = 'blockIdx.x' + if self._block_dims[1] != 1: + result += f' + gridDim.x * blockIdx.y' + if self._block_dims[2] != 1: + result += f' + gridDim.x * gridDim.y * blockIdx.z' + return result + + def _generate_condition_from_location(self, name: str, index_expr: str, node: nodes.Tasklet, + callsite_stream: CodeIOStream) -> str: + if name not in node.location: + return 0 + + location: Union[int, str, subsets.Range] = node.location[name] + if isinstance(location, str) and ':' in location: + location = subsets.Range.from_string(location) + elif symbolic.issymbolic(location): + location = sym2cpp(location) + + if isinstance(location, subsets.Range): + # Range of indices + if len(location) != 1: + raise ValueError(f'Only one-dimensional ranges are allowed for {name} specialization, {location} given') + begin, end, stride = location[0] + rb, re, rs = sym2cpp(begin), sym2cpp(end), sym2cpp(stride) + cond = '' + cond += f'(({index_expr}) >= {rb}) && (({index_expr}) <= {re})' + if stride != 1: + cond += f' && ((({index_expr}) - {rb}) % {rs} == 0)' + + callsite_stream.write(f'if ({cond}) {{') + else: + # Single-element + callsite_stream.write(f'if (({index_expr}) == {location}) {{') + + return 1 + + def _generate_Tasklet(self, sdfg: SDFG, dfg, state_id: int, node: nodes.Tasklet, function_stream: CodeIOStream, + callsite_stream: CodeIOStream): + generated_preamble_scopes = 0 + if self._in_device_code: + # If location dictionary prescribes that the code should run on a certain group of threads/blocks, + # add condition + generated_preamble_scopes += self._generate_condition_from_location('gpu_thread', self._get_thread_id(), + node, callsite_stream) + generated_preamble_scopes += self._generate_condition_from_location('gpu_warp', self._get_warp_id(), node, + callsite_stream) + generated_preamble_scopes += self._generate_condition_from_location('gpu_block', self._get_block_id(), node, + callsite_stream) + + # Call standard tasklet generation + old_codegen = self._cpu_codegen.calling_codegen + self._cpu_codegen.calling_codegen = self + self._cpu_codegen._generate_Tasklet(sdfg, dfg, state_id, node, function_stream, callsite_stream) + self._cpu_codegen.calling_codegen = old_codegen + + if generated_preamble_scopes > 0: + # Generate appropriate postamble + for i in range(generated_preamble_scopes): + callsite_stream.write('}', sdfg, state_id, node) + def make_ptr_vector_cast(self, *args, **kwargs): return cpp.make_ptr_vector_cast(*args, **kwargs) diff --git a/dace/codegen/targets/fpga.py b/dace/codegen/targets/fpga.py index 413cb751d6..8df8fe94fa 100644 --- a/dace/codegen/targets/fpga.py +++ b/dace/codegen/targets/fpga.py @@ -652,7 +652,7 @@ def generate_state(self, sdfg: dace.SDFG, state: dace.SDFGState, function_stream kernel_args_opencl = [] # Include state in args - kernel_args_opencl.append(f"{self._global_sdfg.name}_t *__state") + kernel_args_opencl.append(f"{cpp.mangle_dace_state_struct_name(self._global_sdfg)} *__state") kernel_args_call_host.append(f"__state") for is_output, arg_name, arg, interface_id in state_parameters: diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index b1eb42fe60..0db4062976 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -131,6 +131,7 @@ def generate_fileheader(self, sdfg: SDFG, global_stream: CodeIOStream, backend: :param global_stream: Stream to write to (global). :param backend: Whose backend this header belongs to. """ + from dace.codegen.targets.cpp import mangle_dace_state_struct_name # Avoid circular import # Hash file include if backend == 'frame': global_stream.write('#include "../../include/hash.h"\n', sdfg) @@ -181,7 +182,7 @@ def _emit_definitions(dtype: dtypes.typeclass, wrote_something: bool) -> bool: # Write state struct structstr = '\n'.join(self.statestruct) global_stream.write(f''' -struct {sdfg.name}_t {{ +struct {mangle_dace_state_struct_name(sdfg)} {{ {structstr} }}; @@ -226,6 +227,7 @@ def generate_footer(self, sdfg: SDFG, global_stream: CodeIOStream, callsite_stre :param callsite_stream: Stream to write to (at call site). """ import dace.library + from dace.codegen.targets.cpp import mangle_dace_state_struct_name # Avoid circular import fname = sdfg.name params = sdfg.signature(arglist=self.arglist) paramnames = sdfg.signature(False, for_call=True, arglist=self.arglist) @@ -255,7 +257,7 @@ def generate_footer(self, sdfg: SDFG, global_stream: CodeIOStream, callsite_stre initparamnames_comma = (', ' + initparamnames) if initparamnames else '' callsite_stream.write( f''' -DACE_EXPORTED void __program_{fname}({fname}_t *__state{params_comma}) +DACE_EXPORTED void __program_{fname}({mangle_dace_state_struct_name(fname)} *__state{params_comma}) {{ __program_{fname}_internal(__state{paramnames_comma}); }}''', sdfg) @@ -263,18 +265,17 @@ def generate_footer(self, sdfg: SDFG, global_stream: CodeIOStream, callsite_stre for target in self._dispatcher.used_targets: if target.has_initializer: callsite_stream.write( - 'DACE_EXPORTED int __dace_init_%s(%s_t *__state%s);\n' % - (target.target_name, sdfg.name, initparams_comma), sdfg) + f'DACE_EXPORTED int __dace_init_{target.target_name}({mangle_dace_state_struct_name(sdfg)} *__state{initparams_comma});\n', sdfg) if target.has_finalizer: callsite_stream.write( - 'DACE_EXPORTED int __dace_exit_%s(%s_t *__state);\n' % (target.target_name, sdfg.name), sdfg) + f'DACE_EXPORTED int __dace_exit_{target.target_name}({mangle_dace_state_struct_name(sdfg)} *__state);\n', sdfg) callsite_stream.write( f""" -DACE_EXPORTED {sdfg.name}_t *__dace_init_{sdfg.name}({initparams}) +DACE_EXPORTED {mangle_dace_state_struct_name(sdfg)} *__dace_init_{sdfg.name}({initparams}) {{ int __result = 0; - {sdfg.name}_t *__state = new {sdfg.name}_t; + {mangle_dace_state_struct_name(sdfg)} *__state = new {mangle_dace_state_struct_name(sdfg)}; """, sdfg) @@ -306,7 +307,7 @@ def generate_footer(self, sdfg: SDFG, global_stream: CodeIOStream, callsite_stre return __state; }} -DACE_EXPORTED int __dace_exit_{sdfg.name}({sdfg.name}_t *__state) +DACE_EXPORTED int __dace_exit_{sdfg.name}({mangle_dace_state_struct_name(sdfg)} *__state) {{ int __err = 0; """, sdfg) @@ -352,6 +353,7 @@ def generate_external_memory_management(self, sdfg: SDFG, callsite_stream: CodeI can be ``CPU_Heap`` or any other ``dtypes.StorageType``); and (2) set the externally-allocated pointer to the generated code's internal state (``__dace_set_external_memory_``). """ + from dace.codegen.targets.cpp import mangle_dace_state_struct_name # Avoid circular import # Collect external arrays ext_arrays: Dict[dtypes.StorageType, List[Tuple[SDFG, str, data.Data]]] = collections.defaultdict(list) @@ -374,7 +376,7 @@ def generate_external_memory_management(self, sdfg: SDFG, callsite_stream: CodeI # Size query functions callsite_stream.write( f''' -DACE_EXPORTED size_t __dace_get_external_memory_size_{storage.name}({sdfg.name}_t *__state{initparams_comma}) +DACE_EXPORTED size_t __dace_get_external_memory_size_{storage.name}({mangle_dace_state_struct_name(sdfg)} *__state{initparams_comma}) {{ return {sym2cpp(size)}; }} @@ -383,7 +385,7 @@ def generate_external_memory_management(self, sdfg: SDFG, callsite_stream: CodeI # Pointer set functions callsite_stream.write( f''' -DACE_EXPORTED void __dace_set_external_memory_{storage.name}({sdfg.name}_t *__state, char *ptr{initparams_comma}) +DACE_EXPORTED void __dace_set_external_memory_{storage.name}({mangle_dace_state_struct_name(sdfg)} *__state, char *ptr{initparams_comma}) {{''', sdfg) offset = 0 @@ -828,7 +830,6 @@ def generate_code(self, code, and a set of targets that have been used in the generation of this SDFG. """ - if len(sdfg_id) == 0 and sdfg.sdfg_id != 0: sdfg_id = '_%d' % sdfg.sdfg_id @@ -923,6 +924,7 @@ def generate_code(self, # Get all environments used in the generated code, including # dependent environments import dace.library # Avoid import loops + from dace.codegen.targets.cpp import mangle_dace_state_struct_name self.environments = dace.library.get_environments_and_dependencies(self._dispatcher.used_environments) self.generate_header(sdfg, header_global_stream, header_stream) @@ -931,7 +933,7 @@ def generate_code(self, params = sdfg.signature(arglist=self.arglist) if params: params = ', ' + params - function_signature = ('void __program_%s_internal(%s_t *__state%s)\n{\n' % (sdfg.name, sdfg.name, params)) + function_signature = f'void __program_{sdfg.name}_internal({mangle_dace_state_struct_name(sdfg)}*__state{params})\n{{' self.generate_footer(sdfg, footer_global_stream, footer_stream) self.generate_external_memory_management(sdfg, footer_stream) diff --git a/dace/codegen/targets/intel_fpga.py b/dace/codegen/targets/intel_fpga.py index d3c46b0069..03a04fda41 100644 --- a/dace/codegen/targets/intel_fpga.py +++ b/dace/codegen/targets/intel_fpga.py @@ -3,8 +3,6 @@ import functools import copy import itertools -import os -import re from six import StringIO import numpy as np @@ -143,19 +141,20 @@ def get_generated_codeobjects(self): params_comma = ', ' + params_comma host_code.write(""" -DACE_EXPORTED int __dace_init_intel_fpga({sdfg.name}_t *__state{signature}) {{{emulation_flag} +DACE_EXPORTED int __dace_init_intel_fpga({sdfg_state_name} *__state{signature}) {{{emulation_flag} __state->fpga_context = new dace_fpga_context(); __state->fpga_context->Get().MakeProgram({kernel_file_name}); return 0; }} -DACE_EXPORTED int __dace_exit_intel_fpga({sdfg.name}_t *__state) {{ +DACE_EXPORTED int __dace_exit_intel_fpga({sdfg_state_name} *__state) {{ delete __state->fpga_context; return 0; }} {host_code}""".format(signature=params_comma, sdfg=self._global_sdfg, + sdfg_state_name=cpp.mangle_dace_state_struct_name(self._global_sdfg), emulation_flag=emulation_flag, kernel_file_name=kernel_file_name, host_code="".join([ diff --git a/dace/codegen/targets/mpi.py b/dace/codegen/targets/mpi.py index 419334ba5a..0bb2b67a7e 100644 --- a/dace/codegen/targets/mpi.py +++ b/dace/codegen/targets/mpi.py @@ -4,6 +4,7 @@ from dace.codegen.prettycode import CodeIOStream from dace.codegen.codeobject import CodeObject from dace.codegen.targets.target import TargetCodeGenerator, make_absolute +from dace.codegen.targets.cpp import mangle_dace_state_struct_name from dace.sdfg import nodes, SDFG from dace.config import Config @@ -45,10 +46,10 @@ def get_generated_codeobjects(self): {file_header} -DACE_EXPORTED int __dace_init_mpi({sdfg.name}_t *__state{params}); -DACE_EXPORTED int __dace_exit_mpi({sdfg.name}_t *__state); +DACE_EXPORTED int __dace_init_mpi({sdfg_state_name} *__state{params}); +DACE_EXPORTED int __dace_exit_mpi({sdfg_state_name} *__state); -int __dace_init_mpi({sdfg.name}_t *__state{params}) {{ +int __dace_init_mpi({sdfg_state_name} *__state{params}) {{ int isinit = 0; if (MPI_Initialized(&isinit) != MPI_SUCCESS) return 1; @@ -66,7 +67,7 @@ def get_generated_codeobjects(self): return 0; }} -int __dace_exit_mpi({sdfg.name}_t *__state) {{ +int __dace_exit_mpi({sdfg_state_name} *__state) {{ MPI_Comm_free(&__dace_mpi_comm); MPI_Finalize(); @@ -74,7 +75,7 @@ def get_generated_codeobjects(self): __dace_comm_size); return 0; }} -""".format(params=params_comma, sdfg=sdfg, file_header=fileheader.getvalue()), 'cpp', MPICodeGen, 'MPI') +""".format(params=params_comma, sdfg=sdfg, sdfg_state_name=mangle_dace_state_struct_name(sdfg), file_header=fileheader.getvalue()), 'cpp', MPICodeGen, 'MPI') return [codeobj] @staticmethod diff --git a/dace/codegen/targets/rtl.py b/dace/codegen/targets/rtl.py index dcb752e215..935615fad6 100644 --- a/dace/codegen/targets/rtl.py +++ b/dace/codegen/targets/rtl.py @@ -1,8 +1,8 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. import itertools - from typing import List, Tuple, Dict +import warnings from dace import dtypes, config, registry, symbolic, nodes, sdfg, data from dace.sdfg import graph, state, find_input_arraynode, find_output_arraynode @@ -102,6 +102,21 @@ def copy_memory(self, sdfg: sdfg.SDFG, dfg: state.StateSubgraphView, state_id: i elif isinstance(arr, data.Scalar): line: str = "{} {} = {};".format(dst_node.in_connectors[edge.dst_conn].ctype, edge.dst_conn, edge.src.data) + elif isinstance(arr, data.Stream): + # TODO Streams are currently unsupported, as the proper + # behaviour has to be implemented to avoid deadlocking. It + # is only a warning, as the RTL backend is partially used + # by the Xilinx backend, which may hit this case, but will + # discard the errorneous code. + warnings.warn( + 'Streams are currently unsupported by the RTL backend.' \ + 'This may produce errors or deadlocks in the generated code.' + ) + line: str = "// WARNING: Unsupported read from ({}) variable '{}' from stream '{}'." \ + " This may lead to a deadlock if used in code.\n".format( + dst_node.in_connectors[edge.dst_conn].ctype, edge.dst_conn, edge.src_conn) + line += "{} {} = {}.pop();".format( + dst_node.in_connectors[edge.dst_conn].ctype, edge.dst_conn, edge.src.data) elif isinstance(edge.src, nodes.MapEntry) and isinstance(edge.dst, nodes.Tasklet): rtl_name = self.unique_name(edge.dst, sdfg.nodes()[state_id], sdfg) self.n_unrolled[rtl_name] = symbolic.evaluate(edge.src.map.range[0][1] + 1, sdfg.constants) diff --git a/dace/codegen/targets/xilinx.py b/dace/codegen/targets/xilinx.py index 5d82cfeafc..0c562c59c5 100644 --- a/dace/codegen/targets/xilinx.py +++ b/dace/codegen/targets/xilinx.py @@ -7,6 +7,7 @@ import re import numpy as np import ast + import dace from dace import data as dt, registry, dtypes, subsets from dace.config import Config @@ -141,7 +142,7 @@ def get_generated_codeobjects(self): params_comma = ', ' + params_comma host_code.write(""" -DACE_EXPORTED int __dace_init_xilinx({sdfg.name}_t *__state{signature}) {{ +DACE_EXPORTED int __dace_init_xilinx({sdfg_state_name} *__state{signature}) {{ {environment_variables} __state->fpga_context = new dace_fpga_context(); @@ -149,13 +150,14 @@ def get_generated_codeobjects(self): return 0; }} -DACE_EXPORTED int __dace_exit_xilinx({sdfg.name}_t *__state) {{ +DACE_EXPORTED int __dace_exit_xilinx({sdfg_state_name} *__state) {{ delete __state->fpga_context; return 0; }} {host_code}""".format(signature=params_comma, sdfg=self._global_sdfg, + sdfg_state_name=cpp.mangle_dace_state_struct_name(self._global_sdfg), environment_variables=set_env_vars, kernel_file_name=kernel_file_name, host_code="".join([ diff --git a/dace/config_schema.yml b/dace/config_schema.yml index 08a427aa52..063815e319 100644 --- a/dace/config_schema.yml +++ b/dace/config_schema.yml @@ -164,6 +164,15 @@ required: of the code generator that generated it. Used for debugging code generation. + codegen_state_struct_suffix: + type: str + default: "_state_t" + title: Suffix used by the code generator to mangle the state struct. + description: > + For every SDFG the code generator is is processing a state struct is generated. + The typename of this struct is derived by appending this value to the SDFG's name. + Note that the suffix may only contains letters, digits and underscores. + default_data_types: type : str default: Python diff --git a/dace/data.py b/dace/data.py index 3395b84fe2..3c9c9edbcf 100644 --- a/dace/data.py +++ b/dace/data.py @@ -1,8 +1,10 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import aenum import copy as cp import ctypes import functools +from abc import ABC, abstractmethod from collections import OrderedDict from numbers import Number from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union @@ -504,6 +506,701 @@ def __getitem__(self, s): if isinstance(s, list) or isinstance(s, tuple): return StructArray(self, tuple(s)) return StructArray(self, (s, )) + + +class TensorIterationTypes(aenum.AutoNumberEnum): + """ + Types of tensor iteration capabilities. + + Value (Coordinate Value Iteration) allows to directly iterate over + coordinates such as when using the Dense index type. + + Position (Coordinate Position Iteratation) iterates over coordinate + positions, at which the actual coordinates lie. This is for example the case + with a compressed index, in which the pos array enables one to iterate over + the positions in the crd array that hold the actual coordinates. + """ + Value = () + Position = () + + +class TensorAssemblyType(aenum.AutoNumberEnum): + """ + Types of possible assembly strategies for the individual indices. + + NoAssembly: Assembly is not possible as such. + + Insert: index allows inserting elements at random (e.g. Dense) + + Append: index allows appending to a list of existing coordinates. Depending + on append order, this affects whether the index is ordered or not. This + could be changed by sorting the index after assembly + """ + NoAssembly = () + Insert = () + Append = () + + +class TensorIndex(ABC): + """ + Abstract base class for tensor index implementations. + """ + + @property + @abstractmethod + def iteration_type(self) -> TensorIterationTypes: + """ + Iteration capability supported by this index. + + See TensorIterationTypes for reference. + """ + pass + + @property + @abstractmethod + def locate(self) -> bool: + """ + True if the index supports locate (aka random access), False otw. + """ + pass + + @property + @abstractmethod + def assembly(self) -> TensorAssemblyType: + """ + What assembly type is supported by the index. + + See TensorAssemblyType for reference. + """ + pass + + @property + @abstractmethod + def full(self) -> bool: + """ + True if the level is full, False otw. + + A level is considered full if it encompasses all valid coordinates along + the corresponding tensor dimension. + """ + pass + + @property + @abstractmethod + def ordered(self) -> bool: + """ + True if the level is ordered, False otw. + + A level is ordered when all coordinates that share the same ancestor are + ordered by increasing value (e.g. in typical CSR). + """ + pass + + @property + @abstractmethod + def unique(self) -> bool: + """ + True if coordinate in the level are unique, False otw. + + A level is considered unique if no collection of coordinates that share + the same ancestor contains duplicates. In CSR this is True, in COO it is + not. + """ + pass + + @property + @abstractmethod + def branchless(self) -> bool: + """ + True if the level doesn't branch, false otw. + + A level is considered branchless if no coordinate has a sibling (another + coordinate with same ancestor) and all coordinates in parent level have + a child. In other words if there is a bijection between the coordinates + in this level and the parent level. An example of the is the Singleton + index level in the COO format. + """ + pass + + @property + @abstractmethod + def compact(self) -> bool: + """ + True if the level is compact, false otw. + + A level is compact if no two coordinates are separated by an unlabled + node that does not encode a coordinate. An example of a compact level + can be found in CSR, while the DIA formats range and offset levels are + not compact (they have entries that would coorespond to entries outside + the tensors index range, e.g. column -1). + """ + pass + + @abstractmethod + def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: + """ + Generates the fields needed for the index. + + :returns: a Dict of fields that need to be present in the struct + """ + pass + + + def to_json(self): + attrs = serialize.all_properties_to_json(self) + + retdict = {"type": type(self).__name__, "attributes": attrs} + + return retdict + + + @classmethod + def from_json(cls, json_obj, context=None): + + # Selecting proper subclass + if json_obj['type'] == "TensorIndexDense": + self = TensorIndexDense.__new__(TensorIndexDense) + elif json_obj['type'] == "TensorIndexCompressed": + self = TensorIndexCompressed.__new__(TensorIndexCompressed) + elif json_obj['type'] == "TensorIndexSingleton": + self = TensorIndexSingleton.__new__(TensorIndexSingleton) + elif json_obj['type'] == "TensorIndexRange": + self = TensorIndexRange.__new__(TensorIndexRange) + elif json_obj['type'] == "TensorIndexOffset": + self = TensorIndexOffset.__new__(TensorIndexOffset) + else: + raise TypeError(f"Invalid data type, got: {json_obj['type']}") + + serialize.set_properties_from_json(self, json_obj['attributes'], context=context) + + return self + + +@make_properties +class TensorIndexDense(TensorIndex): + """ + Dense tensor index. + + Levels of this type encode the the coordinate in the interval [0, N), where + N is the size of the corresponding dimension. This level doesn't need any + index structure beyond the corresponding dimension size. + """ + + _ordered = Property(dtype=bool, default=False) + _unique = Property(dtype=bool) + + @property + def iteration_type(self) -> TensorIterationTypes: + return TensorIterationTypes.Value + + @property + def locate(self) -> bool: + return True + + @property + def assembly(self) -> TensorAssemblyType: + return TensorAssemblyType.Insert + + @property + def full(self) -> bool: + return True + + @property + def ordered(self) -> bool: + return self._ordered + + @property + def unique(self) -> bool: + return self._unique + + @property + def branchless(self) -> bool: + return False + + @property + def compact(self) -> bool: + return True + + def __init__(self, ordered: bool = True, unique: bool = True): + self._ordered = ordered + self._unique = unique + + def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: + return {} + + def __repr__(self) -> str: + s = "Dense" + + non_defaults = [] + if not self._ordered: + non_defaults.append("¬O") + if not self._unique: + non_defaults.append("¬U") + + if len(non_defaults) > 0: + s += f"({','.join(non_defaults)})" + + return s + + +@make_properties +class TensorIndexCompressed(TensorIndex): + """ + Tensor level that stores coordinates in segmented array. + + Levels of this type are compressed using a segented array. The pos array + holds the start and end positions of the segment in the crd (coordinate) + array that holds the child coordinates corresponding the parent. + """ + + _full = Property(dtype=bool, default=False) + _ordered = Property(dtype=bool, default=False) + _unique = Property(dtype=bool, default=False) + + @property + def iteration_type(self) -> TensorIterationTypes: + return TensorIterationTypes.Position + + @property + def locate(self) -> bool: + return False + + @property + def assembly(self) -> TensorAssemblyType: + return TensorAssemblyType.Append + + @property + def full(self) -> bool: + return self._full + + @property + def ordered(self) -> bool: + return self._ordered + + @property + def unique(self) -> bool: + return self._unique + + @property + def branchless(self) -> bool: + return False + + @property + def compact(self) -> bool: + return True + + def __init__(self, + full: bool = False, + ordered: bool = True, + unique: bool = True): + self._full = full + self._ordered = ordered + self._unique = unique + + def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: + return { + f"idx{lvl}_pos": dtypes.int32[dummy_symbol], # TODO (later) choose better length + f"idx{lvl}_crd": dtypes.int32[dummy_symbol], # TODO (later) choose better length + } + + def __repr__(self) -> str: + s = "Compressed" + + non_defaults = [] + if self._full: + non_defaults.append("F") + if not self._ordered: + non_defaults.append("¬O") + if not self._unique: + non_defaults.append("¬U") + + if len(non_defaults) > 0: + s += f"({','.join(non_defaults)})" + + return s + + +@make_properties +class TensorIndexSingleton(TensorIndex): + """ + Tensor index that encodes a single coordinate per parent coordinate. + + Levels of this type hold exactly one coordinate for every coordinate in the + parent level. An example can be seen in the COO format, where every + coordinate but the first is encoded in this manner. + """ + + _full = Property(dtype=bool, default=False) + _ordered = Property(dtype=bool, default=False) + _unique = Property(dtype=bool, default=False) + + @property + def iteration_type(self) -> TensorIterationTypes: + return TensorIterationTypes.Position + + @property + def locate(self) -> bool: + return False + + @property + def assembly(self) -> TensorAssemblyType: + return TensorAssemblyType.Append + + @property + def full(self) -> bool: + return self._full + + @property + def ordered(self) -> bool: + return self._ordered + + @property + def unique(self) -> bool: + return self._unique + + @property + def branchless(self) -> bool: + return True + + @property + def compact(self) -> bool: + return True + + def __init__(self, + full: bool = False, + ordered: bool = True, + unique: bool = True): + self._full = full + self._ordered = ordered + self._unique = unique + + def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: + return { + f"idx{lvl}_crd": dtypes.int32[dummy_symbol], # TODO (later) choose better length + } + + def __repr__(self) -> str: + s = "Singleton" + + non_defaults = [] + if self._full: + non_defaults.append("F") + if not self._ordered: + non_defaults.append("¬O") + if not self._unique: + non_defaults.append("¬U") + + if len(non_defaults) > 0: + s += f"({','.join(non_defaults)})" + + return s + + +@make_properties +class TensorIndexRange(TensorIndex): + """ + Tensor index that encodes a interval of coordinates for every parent. + + The interval is computed from an offset for each parent together with the + tensor dimension size of this level (M) and the parent level (N) parents + corresponding tensor. Given the parent coordinate i, the level encodes the + range of coordinates between max(0, -offset[i]) and min(N, M - offset[i]). + """ + + _ordered = Property(dtype=bool, default=False) + _unique = Property(dtype=bool, default=False) + + @property + def iteration_type(self) -> TensorIterationTypes: + return TensorIterationTypes.Value + + @property + def locate(self) -> bool: + return False + + @property + def assembly(self) -> TensorAssemblyType: + return TensorAssemblyType.NoAssembly + + @property + def full(self) -> bool: + return False + + @property + def ordered(self) -> bool: + return self._ordered + + @property + def unique(self) -> bool: + return self._unique + + @property + def branchless(self) -> bool: + return False + + @property + def compact(self) -> bool: + return False + + def __init__(self, ordered: bool = True, unique: bool = True): + self._ordered = ordered + self._unique = unique + + def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: + return { + f"idx{lvl}_offset": dtypes.int32[dummy_symbol], # TODO (later) choose better length + } + + def __repr__(self) -> str: + s = "Range" + + non_defaults = [] + if not self._ordered: + non_defaults.append("¬O") + if not self._unique: + non_defaults.append("¬U") + + if len(non_defaults) > 0: + s += f"({','.join(non_defaults)})" + + return s + + +@make_properties +class TensorIndexOffset(TensorIndex): + """ + Tensor index that encodes the next coordinates as offset from parent. + + Given a parent coordinate i and an offset index k, the level encodes the + coordinate j = i + offset[k]. + """ + + _ordered = Property(dtype=bool, default=False) + _unique = Property(dtype=bool, default=False) + + @property + def iteration_type(self) -> TensorIterationTypes: + return TensorIterationTypes.Position + + @property + def locate(self) -> bool: + return False + + @property + def assembly(self) -> TensorAssemblyType: + return TensorAssemblyType.NoAssembly + + @property + def full(self) -> bool: + return False + + @property + def ordered(self) -> bool: + return self._ordered + + @property + def unique(self) -> bool: + return self._unique + + @property + def branchless(self) -> bool: + return True + + @property + def compact(self) -> bool: + return False + + def __init__(self, ordered: bool = True, unique: bool = True): + self._ordered = ordered + self._unique = unique + + def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: + return { + f"idx{lvl}_offset": dtypes.int32[dummy_symbol], # TODO (later) choose better length + } + + def __repr__(self) -> str: + s = "Offset" + + non_defaults = [] + if not self._ordered: + non_defaults.append("¬O") + if not self._unique: + non_defaults.append("¬U") + + if len(non_defaults) > 0: + s += f"({','.join(non_defaults)})" + + return s + + +@make_properties +class Tensor(Structure): + """ + Abstraction for Tensor storage format. + + This abstraction is based on [https://doi.org/10.1145/3276493]. + """ + + value_dtype = TypeClassProperty(default=dtypes.int32, choices=dtypes.Typeclasses) + tensor_shape = ShapeProperty(default=[]) + indices = ListProperty(element_type=TensorIndex) + index_ordering = ListProperty(element_type=symbolic.SymExpr) + value_count = SymbolicProperty(default=0) + + def __init__( + self, + value_dtype: dtypes.Typeclasses, + tensor_shape, + indices: List[Tuple[TensorIndex, Union[int, symbolic.SymExpr]]], + value_count: symbolic.SymExpr, + name: str, + transient: bool = False, + storage: dtypes.StorageType = dtypes.StorageType.Default, + location: Dict[str, str] = None, + lifetime: dtypes.AllocationLifetime = dtypes.AllocationLifetime.Scope, + debuginfo: dtypes.DebugInfo = None): + """ + Constructor for Tensor storage format. + + Below are examples of common matrix storage formats: + + .. code-block:: python + + M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) + + csr = dace.data.Tensor( + dace.float32, + (M, N), + [(dace.data.Dense(), 0), (dace.data.Compressed(), 1)], + nnz, + "CSR_Matrix", + ) + + csc = dace.data.Tensor( + dace.float32, + (M, N), + [(dace.data.Dense(), 1), (dace.data.Compressed(), 0)], + nnz, + "CSC_Matrix", + ) + + coo = dace.data.Tensor( + dace.float32, + (M, N), + [ + (dace.data.Compressed(unique=False), 0), + (dace.data.Singleton(), 1), + ], + nnz, + "CSC_Matrix", + ) + + num_diags = dace.symbol('num_diags') # number of diagonals stored + + diag = dace.data.Tensor( + dace.float32, + (M, N), + [ + (dace.data.Dense(), num_diags), + (dace.data.Range(), 0), + (dace.data.Offset(), 1), + ], + nnz, + "DIA_Matrix", + ) + + Below you can find examples of common 3rd order tensor storage formats: + + .. code-block:: python + + I, J, K, nnz = (dace.symbol(s) for s in ('I', 'J', 'K', 'nnz')) + + coo = dace.data.Tensor( + dace.float32, + (I, J, K), + [ + (dace.data.Compressed(unique=False), 0), + (dace.data.Singleton(unique=False), 1), + (dace.data.Singleton(), 2), + ], + nnz, + "COO_3D_Tensor", + ) + + csf = dace.data.Tensor( + dace.float32, + (I, J, K), + [ + (dace.data.Compressed(), 0), + (dace.data.Compressed(), 1), + (dace.data.Compressed(), 2), + ], + nnz, + "CSF_3D_Tensor", + ) + + :param value_type: data type of the explicitly stored values. + :param tensor_shape: logical shape of tensor (#rows, #cols, etc...) + :param indices: + a list of tuples, each tuple represents a level in the tensor + storage hirachy, specifying the levels tensor index type, and the + corresponding dimension this level encodes (as index of the + tensor_shape tuple above). The order of the dimensions may differ + from the logical shape of the tensor, e.g. as seen in the CSC + format. If an index's dimension is unrelated to the tensor shape + (e.g. in diagonal format where the first index's dimension is the + number of diagonals stored), a symbol can be specified instead. + :param value_count: number of explicitly stored values. + :param name: name of resulting struct. + :param others: See Structure class for remaining arguments + """ + + self.value_dtype = value_dtype + self.tensor_shape = tensor_shape + self.value_count = value_count + + indices, index_ordering = zip(*indices) + self.indices, self.index_ordering = list(indices), list(index_ordering) + + num_dims = len(tensor_shape) + dimension_order = [idx for idx in self.index_ordering if isinstance(idx, int)] + + # all tensor dimensions must occure exactly once in indices + if not sorted(dimension_order) == list(range(num_dims)): + raise TypeError(( + f"All tensor dimensions must be refferenced exactly once in " + f"tensor indices. (referenced dimensions: {dimension_order}; " + f"tensor dimensions: {list(range(num_dims))})" + )) + + # assembling permanent and index specific fields + fields = dict( + order=Scalar(dtypes.int32), + dim_sizes=dtypes.int32[num_dims], + value_count=value_count, + values=dtypes.float32[value_count], + ) + + for (lvl, index) in enumerate(indices): + fields.update(index.fields(lvl, value_count)) + + super(Tensor, self).__init__(fields, name, transient, storage, location, + lifetime, debuginfo) + + def __repr__(self): + return f"{self.name} (dtype: {self.value_dtype}, shape: {list(self.tensor_shape)}, indices: {self.indices})" + + @staticmethod + def from_json(json_obj, context=None): + if json_obj['type'] != 'Tensor': + raise TypeError("Invalid data type") + + # Create dummy object + tensor = Tensor.__new__(Tensor) + serialize.set_properties_from_json(tensor, json_obj, context=context) + + return tensor @make_properties diff --git a/dace/frontend/fortran/ast_components.py b/dace/frontend/fortran/ast_components.py index d95fa87e58..332c3a563f 100644 --- a/dace/frontend/fortran/ast_components.py +++ b/dace/frontend/fortran/ast_components.py @@ -6,7 +6,10 @@ import copy from dace.frontend.fortran import ast_internal_classes from dace.frontend.fortran.ast_internal_classes import FNode, Name_Node -from typing import Any, List, Tuple, Type, TypeVar, Union, overload +from typing import Any, List, Tuple, Type, TypeVar, Union, overload, TYPE_CHECKING + +if TYPE_CHECKING: + from dace.frontend.fortran.intrinsics import FortranIntrinsics #We rely on fparser to provide an initial AST and convert to a version that is more suitable for our purposes @@ -122,6 +125,8 @@ def __init__(self, ast: f03.Program, tables: symbol_table.SymbolTables): "DOUBLE PRECISION": "DOUBLE", "REAL": "REAL", } + from dace.frontend.fortran.intrinsics import FortranIntrinsics + self.intrinsic_handler = FortranIntrinsics() self.supported_fortran_syntax = { "str": self.str_node, "tuple": self.tuple_node, @@ -242,7 +247,7 @@ def __init__(self, ast: f03.Program, tables: symbol_table.SymbolTables): "Level_2_Unary_Expr": self.level_2_expr, "Mult_Operand": self.power_expr, "Parenthesis": self.parenthesis_expr, - "Intrinsic_Name": self.intrinsic_name, + "Intrinsic_Name": self.intrinsic_handler.replace_function_name, "Intrinsic_Function_Reference": self.intrinsic_function_reference, "Only_List": self.only_list, "Structure_Constructor": self.structure_constructor, @@ -256,6 +261,9 @@ def __init__(self, ast: f03.Program, tables: symbol_table.SymbolTables): "Allocate_Shape_Spec_List": self.allocate_shape_spec_list, } + def fortran_intrinsics(self) -> "FortranIntrinsics": + return self.intrinsic_handler + def list_tables(self): for i in self.tables._symbol_tables: print(i) @@ -395,65 +403,12 @@ def structure_constructor(self, node: FASTNode): args = get_child(children, ast_internal_classes.Component_Spec_List_Node) return ast_internal_classes.Structure_Constructor_Node(name=name, args=args.args, type=None) - def intrinsic_name(self, node: FASTNode): - name = node.string - replacements = { - "INT": "__dace_int", - "DBLE": "__dace_dble", - "SQRT": "sqrt", - "COSH": "cosh", - "ABS": "abs", - "MIN": "min", - "MAX": "max", - "EXP": "exp", - "EPSILON": "__dace_epsilon", - "TANH": "tanh", - "SUM": "__dace_sum", - "SIGN": "__dace_sign", - "EXP": "exp", - "SELECTED_INT_KIND": "__dace_selected_int_kind", - "SELECTED_REAL_KIND": "__dace_selected_real_kind", - } - return ast_internal_classes.Name_Node(name=replacements[name]) - def intrinsic_function_reference(self, node: FASTNode): children = self.create_children(node) line = get_line(node) name = get_child(children, ast_internal_classes.Name_Node) args = get_child(children, ast_internal_classes.Arg_List_Node) - if name.name == "__dace_selected_int_kind": - import math - return ast_internal_classes.Int_Literal_Node(value=str( - math.ceil((math.log2(math.pow(10, int(args.args[0].value))) + 1) / 8)), - line_number=line) - # This selects the smallest kind that can hold the given number of digits (fp64,fp32 or fp16) - elif name.name == "__dace_selected_real_kind": - if int(args.args[0].value) >= 9 or int(args.args[1].value) > 126: - return ast_internal_classes.Int_Literal_Node(value="8", line_number=line) - elif int(args.args[0].value) >= 3 or int(args.args[1].value) > 14: - return ast_internal_classes.Int_Literal_Node(value="4", line_number=line) - else: - return ast_internal_classes.Int_Literal_Node(value="2", line_number=line) - - func_types = { - "__dace_int": "INT", - "__dace_dble": "DOUBLE", - "sqrt": "DOUBLE", - "cosh": "DOUBLE", - "abs": "DOUBLE", - "min": "DOUBLE", - "max": "DOUBLE", - "exp": "DOUBLE", - "__dace_epsilon": "DOUBLE", - "tanh": "DOUBLE", - "__dace_sum": "DOUBLE", - "__dace_sign": "DOUBLE", - "exp": "DOUBLE", - "__dace_selected_int_kind": "INT", - "__dace_selected_real_kind": "INT", - } - call_type = func_types[name.name] - return ast_internal_classes.Call_Expr_Node(name=name, type=call_type, args=args.args, line_number=line) + return self.intrinsic_handler.replace_function_reference(name, args, line) def function_stmt(self, node: FASTNode): raise NotImplementedError( diff --git a/dace/frontend/fortran/ast_internal_classes.py b/dace/frontend/fortran/ast_internal_classes.py index 70a43e21b8..d1e68572de 100644 --- a/dace/frontend/fortran/ast_internal_classes.py +++ b/dace/frontend/fortran/ast_internal_classes.py @@ -386,3 +386,7 @@ class Use_Stmt_Node(FNode): class Write_Stmt_Node(FNode): _attributes = () _fields = ('args', ) + +class Break_Node(FNode): + _attributes = () + _fields = () diff --git a/dace/frontend/fortran/ast_transforms.py b/dace/frontend/fortran/ast_transforms.py index e2a7246aed..0c96560fba 100644 --- a/dace/frontend/fortran/ast_transforms.py +++ b/dace/frontend/fortran/ast_transforms.py @@ -181,9 +181,11 @@ def __init__(self, funcs=None): if funcs is None: funcs = [] self.funcs = funcs + + from dace.frontend.fortran.intrinsics import FortranIntrinsics self.excepted_funcs = [ - "malloc", "exp", "pow", "sqrt", "cbrt", "max", "abs", "min", "__dace_sum", "__dace_sign", "tanh", - "__dace_epsilon" + "malloc", "exp", "pow", "sqrt", "cbrt", "max", "abs", "min", "__dace_sign", "tanh", + "__dace_epsilon", *FortranIntrinsics.function_names() ] def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): @@ -215,8 +217,10 @@ def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): if hasattr(node, "subroutine"): if node.subroutine is True: stop = True + + from dace.frontend.fortran.intrinsics import FortranIntrinsics if not stop and node.name.name not in [ - "malloc", "exp", "pow", "sqrt", "cbrt", "max", "min", "abs", "tanh", "__dace_epsilon" + "malloc", "exp", "pow", "sqrt", "cbrt", "max", "min", "abs", "tanh", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions() ]: self.nodes.append(node) return self.generic_visit(node) @@ -236,7 +240,8 @@ def __init__(self, count=0): def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): - if node.name.name in ["malloc", "exp", "pow", "sqrt", "cbrt", "max", "min", "abs", "tanh", "__dace_epsilon"]: + from dace.frontend.fortran.intrinsics import FortranIntrinsics + if node.name.name in ["malloc", "exp", "pow", "sqrt", "cbrt", "max", "min", "abs", "tanh", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions()]: return self.generic_visit(node) if hasattr(node, "subroutine"): if node.subroutine is True: @@ -262,31 +267,14 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No if res is not None: for i in range(0, len(res)): - if (res[i].name.name == "__dace_sum"): - newbody.append( - ast_internal_classes.Decl_Stmt_Node(vardecl=[ - ast_internal_classes.Var_Decl_Node( - name="tmp_call_" + str(temp), - type=res[i].type, - sizes=None, - ) - ])) - newbody.append( - ast_internal_classes.BinOp_Node(lval=ast_internal_classes.Name_Node(name="tmp_call_" + - str(temp)), - op="=", - rval=ast_internal_classes.Int_Literal_Node(value="0"), - line_number=child.line_number)) - else: - - newbody.append( - ast_internal_classes.Decl_Stmt_Node(vardecl=[ - ast_internal_classes.Var_Decl_Node( - name="tmp_call_" + str(temp), - type=res[i].type, - sizes=None, - ) - ])) + newbody.append( + ast_internal_classes.Decl_Stmt_Node(vardecl=[ + ast_internal_classes.Var_Decl_Node( + name="tmp_call_" + str(temp), + type=res[i].type, + sizes=None + ) + ])) newbody.append( ast_internal_classes.BinOp_Node(op="=", lval=ast_internal_classes.Name_Node(name="tmp_call_" + @@ -344,6 +332,8 @@ def visit(self, node: ast_internal_classes.FNode, parent_node: Optional[ast_inte elif isinstance(value, ast_internal_classes.FNode): self.visit(value, node) + return node + class ScopeVarsDeclarations(NodeVisitor): """ Creates a mapping (scope name, variable name) -> variable declaration. @@ -458,7 +448,11 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No if self.normalize_offsets: # Find the offset of a variable to which we are assigning - var_name = child.lval.name.name + var_name = "" + if isinstance(j, ast_internal_classes.Name_Node): + var_name = j.name + else: + var_name = j.name.name variable = self.scope_vars.get_var(child.parent, var_name) offset = variable.offsets[idx] @@ -714,35 +708,19 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No return -class SumLoopNodeLister(NodeVisitor): - """ - Finds all sum operations that have to be transformed to loops in the AST - """ - def __init__(self): - self.nodes: List[ast_internal_classes.FNode] = [] - - def visit_BinOp_Node(self, node: ast_internal_classes.BinOp_Node): - - if isinstance(node.rval, ast_internal_classes.Call_Expr_Node): - if node.rval.name.name == "__dace_sum": - self.nodes.append(node) - - def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node): - return - - def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node, ranges: list, rangepos: list, + rangeslen: list, count: int, newbody: list, scope_vars: ScopeVarsDeclarations, - declaration=True, - is_sum_to_loop=False): + declaration=True): """ Helper function for the transformation of array operations and sums to loops :param node: The AST to be transformed :param ranges: The ranges of the loop + :param rangeslength: The length of ranges of the loop :param rangepos: The positions of the ranges :param count: The current count of the loop :param newbody: The new basic block that will contain the loop @@ -753,6 +731,7 @@ def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node, currentindex = 0 indices = [] + offsets = scope_vars.get_var(node.parent, node.name.name).offsets for idx, i in enumerate(node.indices): @@ -786,9 +765,24 @@ def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node, rval=ast_internal_classes.Int_Literal_Node(value="1") ) ranges.append([lower_boundary, upper_boundary]) + rangeslen.append(-1) else: ranges.append([i.range[0], i.range[1]]) + + start = 0 + if isinstance(i.range[0], ast_internal_classes.Int_Literal_Node): + start = int(i.range[0].value) + else: + start = i.range[0] + + end = 0 + if isinstance(i.range[1], ast_internal_classes.Int_Literal_Node): + end = int(i.range[1].value) + else: + end = i.range[1] + + rangeslen.append(end - start + 1) rangepos.append(currentindex) if declaration: newbody.append( @@ -828,7 +822,7 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No val = child.rval ranges = [] rangepos = [] - par_Decl_Range_Finder(current, ranges, rangepos, self.count, newbody, self.scope_vars, True) + par_Decl_Range_Finder(current, ranges, rangepos, [], self.count, newbody, self.scope_vars, True) if res_range is not None and len(res_range) > 0: rvals = [i for i in mywalk(val) if isinstance(i, ast_internal_classes.Array_Subscript_Node)] @@ -836,7 +830,7 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No rangeposrval = [] rangesrval = [] - par_Decl_Range_Finder(i, rangesrval, rangeposrval, self.count, newbody, self.scope_vars, False) + par_Decl_Range_Finder(i, rangesrval, rangeposrval, [], self.count, newbody, self.scope_vars, False) for i, j in zip(ranges, rangesrval): if i != j: @@ -905,83 +899,6 @@ def mywalk(node): todo.extend(iter_child_nodes(node)) yield node - -class SumToLoop(NodeTransformer): - """ - Transforms the AST by removing array sums and replacing them with loops - """ - def __init__(self, ast): - self.count = 0 - ParentScopeAssigner().visit(ast) - self.scope_vars = ScopeVarsDeclarations() - self.scope_vars.visit(ast) - - def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node): - newbody = [] - for child in node.execution: - lister = SumLoopNodeLister() - lister.visit(child) - res = lister.nodes - if res is not None and len(res) > 0: - - current = child.lval - val = child.rval - rvals = [i for i in mywalk(val) if isinstance(i, ast_internal_classes.Array_Subscript_Node)] - if len(rvals) != 1: - raise NotImplementedError("Only one array can be summed") - val = rvals[0] - rangeposrval = [] - rangesrval = [] - - par_Decl_Range_Finder(val, rangesrval, rangeposrval, self.count, newbody, self.scope_vars, False, True) - - range_index = 0 - body = ast_internal_classes.BinOp_Node(lval=current, - op="=", - rval=ast_internal_classes.BinOp_Node( - lval=current, - op="+", - rval=val, - line_number=child.line_number), - line_number=child.line_number) - for i in rangesrval: - initrange = i[0] - finalrange = i[1] - init = ast_internal_classes.BinOp_Node( - lval=ast_internal_classes.Name_Node(name="tmp_parfor_" + str(self.count + range_index)), - op="=", - rval=initrange, - line_number=child.line_number) - cond = ast_internal_classes.BinOp_Node( - lval=ast_internal_classes.Name_Node(name="tmp_parfor_" + str(self.count + range_index)), - op="<=", - rval=finalrange, - line_number=child.line_number) - iter = ast_internal_classes.BinOp_Node( - lval=ast_internal_classes.Name_Node(name="tmp_parfor_" + str(self.count + range_index)), - op="=", - rval=ast_internal_classes.BinOp_Node( - lval=ast_internal_classes.Name_Node(name="tmp_parfor_" + str(self.count + range_index)), - op="+", - rval=ast_internal_classes.Int_Literal_Node(value="1")), - line_number=child.line_number) - current_for = ast_internal_classes.Map_Stmt_Node( - init=init, - cond=cond, - iter=iter, - body=ast_internal_classes.Execution_Part_Node(execution=[body]), - line_number=child.line_number) - body = current_for - range_index += 1 - - newbody.append(body) - - self.count = self.count + range_index - else: - newbody.append(self.visit(child)) - return ast_internal_classes.Execution_Part_Node(execution=newbody) - - class RenameVar(NodeTransformer): def __init__(self, oldname: str, newname: str): self.oldname = oldname diff --git a/dace/frontend/fortran/fortran_parser.py b/dace/frontend/fortran/fortran_parser.py index b15435f4ff..21f61a171a 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -66,6 +66,7 @@ def __init__(self, ast: ast_components.InternalFortranAst, source: str): ast_internal_classes.Program_Node: self.ast2sdfg, ast_internal_classes.Write_Stmt_Node: self.write2sdfg, ast_internal_classes.Allocate_Stmt_Node: self.allocate2sdfg, + ast_internal_classes.Break_Node: self.break2sdfg, } def get_dace_type(self, type): @@ -295,7 +296,7 @@ def forstmt2sdfg(self, node: ast_internal_classes.For_Stmt_Node, sdfg: SDFG): begin_loop_state = sdfg.add_state("BeginLoop" + name) end_loop_state = sdfg.add_state("EndLoop" + name) self.last_sdfg_states[sdfg] = begin_loop_state - self.last_loop_continues[sdfg] = end_loop_state + self.last_loop_continues[sdfg] = final_substate self.translate(node.body, sdfg) sdfg.add_edge(self.last_sdfg_states[sdfg], end_loop_state, InterstateEdge()) @@ -1015,6 +1016,11 @@ def vardecl2sdfg(self, node: ast_internal_classes.Var_Decl_Node, sdfg: SDFG): if node.name not in self.contexts[sdfg.name].containers: self.contexts[sdfg.name].containers.append(node.name) + def break2sdfg(self, node: ast_internal_classes.Break_Node, sdfg: SDFG): + + self.last_loop_breaks[sdfg] = self.last_sdfg_states[sdfg] + sdfg.add_edge(self.last_sdfg_states[sdfg], self.last_loop_continues.get(sdfg), InterstateEdge()) + def create_ast_from_string( source_string: str, sdfg_name: str, @@ -1045,7 +1051,10 @@ def create_ast_from_string( program = ast_transforms.CallExtractor().visit(program) program = ast_transforms.SignToIf().visit(program) program = ast_transforms.ArrayToLoop(program).visit(program) - program = ast_transforms.SumToLoop(program).visit(program) + + for transformation in own_ast.fortran_intrinsics().transformations(): + program = transformation(program).visit(program) + program = ast_transforms.ForDeclarer().visit(program) program = ast_transforms.IndexExtractor(program, normalize_offsets).visit(program) @@ -1077,7 +1086,10 @@ def create_sdfg_from_string( program = ast_transforms.CallExtractor().visit(program) program = ast_transforms.SignToIf().visit(program) program = ast_transforms.ArrayToLoop(program).visit(program) - program = ast_transforms.SumToLoop(program).visit(program) + + for transformation in own_ast.fortran_intrinsics().transformations(): + program = transformation(program).visit(program) + program = ast_transforms.ForDeclarer().visit(program) program = ast_transforms.IndexExtractor(program, normalize_offsets).visit(program) ast2sdfg = AST_translator(own_ast, __file__) @@ -1119,7 +1131,10 @@ def create_sdfg_from_fortran_file(source_string: str): program = ast_transforms.CallExtractor().visit(program) program = ast_transforms.SignToIf().visit(program) program = ast_transforms.ArrayToLoop(program).visit(program) - program = ast_transforms.SumToLoop(program).visit(program) + + for transformation in own_ast.fortran_intrinsics(): + program = transformation(program).visit(program) + program = ast_transforms.ForDeclarer().visit(program) program = ast_transforms.IndexExtractor(program).visit(program) ast2sdfg = AST_translator(own_ast, __file__) diff --git a/dace/frontend/fortran/intrinsics.py b/dace/frontend/fortran/intrinsics.py new file mode 100644 index 0000000000..c2e5afe79b --- /dev/null +++ b/dace/frontend/fortran/intrinsics.py @@ -0,0 +1,1033 @@ + +from abc import abstractmethod +import copy +import math +from typing import Any, List, Optional, Set, Tuple, Type + +from dace.frontend.fortran import ast_internal_classes +from dace.frontend.fortran.ast_utils import fortrantypes2dacetypes +from dace.frontend.fortran.ast_transforms import NodeVisitor, NodeTransformer, ParentScopeAssigner, ScopeVarsDeclarations, par_Decl_Range_Finder, mywalk + +FASTNode = Any + +class IntrinsicTransformation: + + @staticmethod + @abstractmethod + def replaced_name(func_name: str) -> str: + pass + + @staticmethod + @abstractmethod + def replace(func_name: ast_internal_classes.Name_Node, args: ast_internal_classes.Arg_List_Node, line) -> ast_internal_classes.FNode: + pass + + @staticmethod + def has_transformation() -> bool: + return False + +class SelectedKind(IntrinsicTransformation): + + FUNCTIONS = { + "SELECTED_INT_KIND": "__dace_selected_int_kind", + "SELECTED_REAL_KIND": "__dace_selected_real_kind", + } + + @staticmethod + def replaced_name(func_name: str) -> str: + return SelectedKind.FUNCTIONS[func_name] + + @staticmethod + def replace(func_name: ast_internal_classes.Name_Node, args: ast_internal_classes.Arg_List_Node, line) -> ast_internal_classes.FNode: + + if func_name.name == "__dace_selected_int_kind": + return ast_internal_classes.Int_Literal_Node(value=str( + math.ceil((math.log2(math.pow(10, int(args.args[0].value))) + 1) / 8)), + line_number=line) + # This selects the smallest kind that can hold the given number of digits (fp64,fp32 or fp16) + elif func_name.name == "__dace_selected_real_kind": + if int(args.args[0].value) >= 9 or int(args.args[1].value) > 126: + return ast_internal_classes.Int_Literal_Node(value="8", line_number=line) + elif int(args.args[0].value) >= 3 or int(args.args[1].value) > 14: + return ast_internal_classes.Int_Literal_Node(value="4", line_number=line) + else: + return ast_internal_classes.Int_Literal_Node(value="2", line_number=line) + + raise NotImplemented() + +class LoopBasedReplacement: + + INTRINSIC_TO_DACE = { + "SUM": "__dace_sum", + "PRODUCT": "__dace_product", + "ANY": "__dace_any", + "ALL": "__dace_all", + "COUNT": "__dace_count", + "MINVAL": "__dace_minval", + "MAXVAL": "__dace_maxval", + "MERGE": "__dace_merge" + } + + @staticmethod + def replaced_name(func_name: str) -> str: + return LoopBasedReplacement.INTRINSIC_TO_DACE[func_name] + + @staticmethod + def has_transformation() -> bool: + return True + +class LoopBasedReplacementVisitor(NodeVisitor): + + """ + Finds all intrinsic operations that have to be transformed to loops in the AST + """ + def __init__(self, func_name: str): + self._func_name = func_name + self.nodes: List[ast_internal_classes.FNode] = [] + + def visit_BinOp_Node(self, node: ast_internal_classes.BinOp_Node): + + if isinstance(node.rval, ast_internal_classes.Call_Expr_Node): + if node.rval.name.name == self._func_name: + self.nodes.append(node) + + def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node): + return + +class LoopBasedReplacementTransformation(NodeTransformer): + + """ + Transforms the AST by removing intrinsic call and replacing it with loops + """ + def __init__(self, ast): + self.count = 0 + + # We need to rerun the assignment because transformations could have created + # new AST nodes + ParentScopeAssigner().visit(ast) + self.scope_vars = ScopeVarsDeclarations() + self.scope_vars.visit(ast) + self.rvals = [] + + @staticmethod + @abstractmethod + def func_name() -> str: + pass + + @abstractmethod + def _initialize(self): + pass + + @abstractmethod + def _parse_call_expr_node(self, node: ast_internal_classes.Call_Expr_Node): + pass + + @abstractmethod + def _summarize_args(self, exec_node: ast_internal_classes.Execution_Part_Node, node: ast_internal_classes.FNode, new_func_body: List[ast_internal_classes.FNode]): + pass + + @abstractmethod + def _initialize_result(self, node: ast_internal_classes.FNode) -> Optional[ast_internal_classes.BinOp_Node]: + pass + + @abstractmethod + def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node: + pass + + def _skip_result_assignment(self): + return False + + """ + When replacing Fortran's AST reference to an intrinsic function, we set a dummy variable with VOID type. + The reason is that at the point, we do not know the types of arguments. For many intrinsics, the return + type will depend on the input types. + + When transforming the AST, we gather all scopes and variable declarations in that scope. + Then, we can query the types of input arguments and properly determine the return type. + + Both the type of the variable and its corresponding Var_Decl_node need to be updated! + """ + + @abstractmethod + def _update_result_type(self, var: ast_internal_classes.Name_Node): + pass + + def _parse_array(self, node: ast_internal_classes.Execution_Part_Node, arg: ast_internal_classes.FNode) -> ast_internal_classes.Array_Subscript_Node: + + # supports syntax func(arr) + if isinstance(arg, ast_internal_classes.Name_Node): + array_node = ast_internal_classes.Array_Subscript_Node(parent=arg.parent) + array_node.name = arg + + # If we access SUM(arr) where arr has many dimensions, + # We need to create a ParDecl_Node for each dimension + dims = len(self.scope_vars.get_var(node.parent, arg.name).sizes) + array_node.indices = [ast_internal_classes.ParDecl_Node(type='ALL')] * dims + + return array_node + + # supports syntax func(arr(:)) + if isinstance(arg, ast_internal_classes.Array_Subscript_Node): + return arg + + def _parse_binary_op(self, node: ast_internal_classes.Call_Expr_Node, arg: ast_internal_classes.BinOp_Node) -> Tuple[ + ast_internal_classes.Array_Subscript_Node, + Optional[ast_internal_classes.Array_Subscript_Node], + ast_internal_classes.BinOp_Node + ]: + + """ + Supports passing binary operations as an input to function. + In both cases, we extract the arrays used, and return a brand + new binary operation that has array references replaced. + We return both arrays (second optionaly None) and the binary op. + + The binary op can be: + + (1) arr1 op arr2 + where arr1 and arr2 are name node or array subscript node + #there, we need to extract shape and verify they are the same + + (2) arr1 op scalar + there, we ignore the scalar because it's not an array + + """ + if not isinstance(arg, ast_internal_classes.BinOp_Node): + return False + + first_array = self._parse_array(node, arg.lval) + second_array = self._parse_array(node, arg.rval) + has_two_arrays = first_array is not None and second_array is not None + + # array and scalar - simplified case + if not has_two_arrays: + + # if one side of the operator is scalar, then parsing array + # will return none + dominant_array = first_array + if dominant_array is None: + dominant_array = second_array + + # replace the array subscript node in the binary operation + # ignore this when the operand is a scalar + cond = copy.deepcopy(arg) + if first_array is not None: + cond.lval = dominant_array + if second_array is not None: + cond.rval = dominant_array + + return (dominant_array, None, cond) + + if len(first_array.indices) != len(second_array.indices): + raise TypeError("Can't parse Fortran binary op with different array ranks!") + + for left_idx, right_idx in zip(first_array.indices, second_array.indices): + if left_idx.type != right_idx.type: + raise TypeError("Can't parse Fortran binary op with different array ranks!") + + # Now, we need to convert the array to a proper subscript node + cond = copy.deepcopy(arg) + cond.lval = first_array + cond.rval = second_array + + return (first_array, second_array, cond) + + def _adjust_array_ranges(self, node: ast_internal_classes.FNode, array: ast_internal_classes.Array_Subscript_Node, loop_ranges_main: list, loop_ranges_array: list): + + """ + When given a binary operator with arrays as an argument to the intrinsic, + one array will dictate loop range. + However, the other array can potentially have a different access range. + Thus, we need to add an offset to the loop iterator when accessing array elements. + + If the access pattern on the right array is different, we need to shfit it - for every dimension. + For example, we can have arr(1:3) == arr2(3:5) + Then, loop_idx is from 1 to 3 + arr becomes arr[loop_idx] + but arr2 must be arr2[loop_idx + 2] + """ + for i in range(len(array.indices)): + + idx_var = array.indices[i] + start_loop = loop_ranges_main[i][0] + end_loop = loop_ranges_array[i][0] + + difference = int(end_loop.value) - int(start_loop.value) + if difference != 0: + new_index = ast_internal_classes.BinOp_Node( + lval=idx_var, + op="+", + rval=ast_internal_classes.Int_Literal_Node(value=str(difference)), + line_number=node.line_number + ) + array.indices[i] = new_index + + def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node): + + newbody = [] + + for child in node.execution: + lister = LoopBasedReplacementVisitor(self.func_name()) + lister.visit(child) + res = lister.nodes + + if res is None or len(res) == 0: + newbody.append(self.visit(child)) + continue + + self.loop_ranges = [] + # We need to reinitialize variables as the class is reused for transformation between different + # calls to the same intrinsic. + self._initialize() + + # Visit all intrinsic arguments and extract arrays + for i in mywalk(child.rval): + if isinstance(i, ast_internal_classes.Call_Expr_Node) and i.name.name == self.func_name(): + self._parse_call_expr_node(i) + + # Verify that all of intrinsic args are correct and prepare them for loop generation + self._summarize_args(node, child, newbody) + + # Change the type of result variable + self._update_result_type(child.lval) + + # Initialize the result variable + init_stm = self._initialize_result(child) + if init_stm is not None: + newbody.append(init_stm) + + # Generate the intrinsic-specific logic inside loop body + body = self._generate_loop_body(child) + + # Now generate the multi-dimensiona loop header and updates + range_index = 0 + for i in self.loop_ranges: + initrange = i[0] + finalrange = i[1] + init = ast_internal_classes.BinOp_Node( + lval=ast_internal_classes.Name_Node(name="tmp_parfor_" + str(self.count + range_index)), + op="=", + rval=initrange, + line_number=child.line_number) + cond = ast_internal_classes.BinOp_Node( + lval=ast_internal_classes.Name_Node(name="tmp_parfor_" + str(self.count + range_index)), + op="<=", + rval=finalrange, + line_number=child.line_number) + iter = ast_internal_classes.BinOp_Node( + lval=ast_internal_classes.Name_Node(name="tmp_parfor_" + str(self.count + range_index)), + op="=", + rval=ast_internal_classes.BinOp_Node( + lval=ast_internal_classes.Name_Node(name="tmp_parfor_" + str(self.count + range_index)), + op="+", + rval=ast_internal_classes.Int_Literal_Node(value="1")), + line_number=child.line_number) + current_for = ast_internal_classes.Map_Stmt_Node( + init=init, + cond=cond, + iter=iter, + body=ast_internal_classes.Execution_Part_Node(execution=[body]), + line_number=child.line_number) + body = current_for + range_index += 1 + + newbody.append(body) + + self.count = self.count + range_index + return ast_internal_classes.Execution_Part_Node(execution=newbody) + +class SumProduct(LoopBasedReplacementTransformation): + + def __init__(self, ast): + super().__init__(ast) + + def _initialize(self): + self.rvals = [] + self.argument_variable = None + + def _update_result_type(self, var: ast_internal_classes.Name_Node): + + """ + For both SUM and PRODUCT, the result type depends on the input variable. + """ + input_type = self.scope_vars.get_var(var.parent, self.argument_variable.name.name) + + var_decl = self.scope_vars.get_var(var.parent, var.name) + var.type = input_type.type + var_decl.type = input_type.type + + def _parse_call_expr_node(self, node: ast_internal_classes.Call_Expr_Node): + + for arg in node.args: + + array_node = self._parse_array(node, arg) + + if array_node is not None: + self.rvals.append(array_node) + else: + raise NotImplementedError("We do not support non-array arguments for SUM/PRODUCT") + + + def _summarize_args(self, exec_node: ast_internal_classes.Execution_Part_Node, node: ast_internal_classes.FNode, new_func_body: List[ast_internal_classes.FNode]): + + if len(self.rvals) != 1: + raise NotImplementedError("Only one array can be summed") + + self.argument_variable = self.rvals[0] + + par_Decl_Range_Finder(self.argument_variable, self.loop_ranges, [], [], self.count, new_func_body, self.scope_vars, True) + + def _initialize_result(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node: + + return ast_internal_classes.BinOp_Node( + lval=node.lval, + op="=", + rval=ast_internal_classes.Int_Literal_Node(value=self._result_init_value()), + line_number=node.line_number + ) + + def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node: + + return ast_internal_classes.BinOp_Node( + lval=node.lval, + op="=", + rval=ast_internal_classes.BinOp_Node( + lval=node.lval, + op=self._result_update_op(), + rval=self.argument_variable, + line_number=node.line_number + ), + line_number=node.line_number + ) + + +class Sum(LoopBasedReplacement): + + """ + In this class, we implement the transformation for Fortran intrinsic SUM(:) + We support two ways of invoking the function - by providing array name and array subscript. + We do NOT support the *DIM* argument. + + During the loop construction, we add a single variable storing the partial result. + Then, we generate a binary node accumulating the result. + """ + + class Transformation(SumProduct): + + def __init__(self, ast): + super().__init__(ast) + + @staticmethod + def func_name() -> str: + return "__dace_sum" + + def _result_init_value(self): + return "0" + + def _result_update_op(self): + return "+" + +class Product(LoopBasedReplacement): + + """ + In this class, we implement the transformation for Fortran intrinsic PRODUCT(:) + We support two ways of invoking the function - by providing array name and array subscript. + We do NOT support the *DIM* and *MASK* arguments. + + During the loop construction, we add a single variable storing the partial result. + Then, we generate a binary node accumulating the result. + """ + + class Transformation(SumProduct): + + def __init__(self, ast): + super().__init__(ast) + + @staticmethod + def func_name() -> str: + return "__dace_product" + + def _result_init_value(self): + return "1" + + def _result_update_op(self): + return "*" + +class AnyAllCountTransformation(LoopBasedReplacementTransformation): + + def __init__(self, ast): + super().__init__(ast) + + def _initialize(self): + self.rvals = [] + + self.first_array = None + self.second_array = None + self.dominant_array = None + self.cond = None + + def _update_result_type(self, var: ast_internal_classes.Name_Node): + + """ + For all functions, the result type is INTEGER. + Theoretically, we should return LOGICAL for ANY and ALL, + but we no longer use booleans on DaCe side. + """ + var_decl = self.scope_vars.get_var(var.parent, var.name) + var.type = "INTEGER" + var_decl.type = "INTEGER" + + def _parse_call_expr_node(self, node: ast_internal_classes.Call_Expr_Node): + + if len(node.args) > 1: + raise NotImplementedError("Fortran ANY with the DIM parameter is not supported!") + arg = node.args[0] + + array_node = self._parse_array(node, arg) + if array_node is not None: + self.first_array = array_node + self.cond = ast_internal_classes.BinOp_Node( + op="==", + rval=ast_internal_classes.Int_Literal_Node(value="1"), + lval=self.first_array, + line_number=node.line_number + ) + else: + self.first_array, self.second_array, self.cond = self._parse_binary_op(node, arg) + + def _summarize_args(self, exec_node: ast_internal_classes.Execution_Part_Node, node: ast_internal_classes.FNode, new_func_body: List[ast_internal_classes.FNode]): + + rangeslen_left = [] + par_Decl_Range_Finder(self.first_array, self.loop_ranges, [], rangeslen_left, self.count, new_func_body, self.scope_vars, True) + if self.second_array is None: + return + + loop_ranges_right = [] + rangeslen_right = [] + par_Decl_Range_Finder(self.second_array, loop_ranges_right, [], rangeslen_right, self.count, new_func_body, self.scope_vars, True) + + for left_len, right_len in zip(rangeslen_left, rangeslen_right): + if left_len != right_len: + raise TypeError("Can't support Fortran ANY with different array ranks!") + + # In this intrinsic, the left array dictates loop range. + # Thus, we only need to adjust the second array + self._adjust_array_ranges(node, self.second_array, self.loop_ranges, loop_ranges_right) + + + def _initialize_result(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node: + + init_value = self._result_init_value() + + return ast_internal_classes.BinOp_Node( + lval=node.lval, + op="=", + rval=ast_internal_classes.Int_Literal_Node(value=init_value), + line_number=node.line_number + ) + + def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node: + + """ + For any, we check if the condition is true and then set the value to true + For all, we check if the condition is NOT true and then set the value to false + """ + + body_if = ast_internal_classes.Execution_Part_Node(execution=[ + self._result_loop_update(node), + # TODO: we should make the `break` generation conditional based on the architecture + # For parallel maps, we should have no breaks + # For sequential loop, we want a break to be faster + #ast_internal_classes.Break_Node( + # line_number=node.line_number + #) + ]) + + return ast_internal_classes.If_Stmt_Node( + cond=self._loop_condition(), + body=body_if, + body_else=ast_internal_classes.Execution_Part_Node(execution=[]), + line_number=node.line_number + ) + +class Any(LoopBasedReplacement): + + """ + In this class, we implement the transformation for Fortran intrinsic ANY + We support three ways of invoking the function - by providing array name, array subscript, + and a binary operation. + We do NOT support the *DIM* argument. + + First, we split the implementation between three scenarios: + (1) ANY(arr) + (2) ANY(arr1 op arr2) + (3) ANY(arr1 op scalar) + Depending on the scenario, we verify if all participating arrays have the same rank. + We determine the loop range based on the arrays, and convert all array accesses to depend on + the loop. We take special care for situations where arrays have different subscripts, e.g., + arr1(1:3) op arr2(5:7) - the second array needs a shift when indexing based on loop iterator. + + During the loop construction, we add a single variable storing the partial result. + Then, we generate an if condition inside the loop to check if the value is true or not. + For (1), we check if the array entry is equal to 1. + For (2), we reuse the provided binary operation. + When the condition is true, we set the value to true and exit. + """ + class Transformation(AnyAllCountTransformation): + + def __init__(self, ast): + super().__init__(ast) + + def _result_init_value(self): + return "0" + + def _result_loop_update(self, node: ast_internal_classes.FNode): + + return ast_internal_classes.BinOp_Node( + lval=copy.deepcopy(node.lval), + op="=", + rval=ast_internal_classes.Int_Literal_Node(value="1"), + line_number=node.line_number + ) + + def _loop_condition(self): + return self.cond + + @staticmethod + def func_name() -> str: + return "__dace_any" + +class All(LoopBasedReplacement): + + """ + In this class, we implement the transformation for Fortran intrinsic ALL. + The implementation is very similar to ANY. + The main difference is that we initialize the partial result to 1, + and set it to 0 if any of the evaluated conditions is false. + """ + class Transformation(AnyAllCountTransformation): + + def __init__(self, ast): + super().__init__(ast) + + def _result_init_value(self): + return "1" + + def _result_loop_update(self, node: ast_internal_classes.FNode): + + return ast_internal_classes.BinOp_Node( + lval=copy.deepcopy(node.lval), + op="=", + rval=ast_internal_classes.Int_Literal_Node(value="0"), + line_number=node.line_number + ) + + def _loop_condition(self): + return ast_internal_classes.UnOp_Node( + op="not", + lval=self.cond + ) + + @staticmethod + def func_name() -> str: + return "__dace_all" + +class Count(LoopBasedReplacement): + + """ + In this class, we implement the transformation for Fortran intrinsic COUNT. + The implementation is very similar to ANY and ALL. + The main difference is that we initialize the partial result to 0 + and increment it if any of the evaluated conditions is true. + + We do not support the KIND argument. + """ + class Transformation(AnyAllCountTransformation): + + def __init__(self, ast): + super().__init__(ast) + + def _result_init_value(self): + return "0" + + def _result_loop_update(self, node: ast_internal_classes.FNode): + + update = ast_internal_classes.BinOp_Node( + lval=copy.deepcopy(node.lval), + op="+", + rval=ast_internal_classes.Int_Literal_Node(value="1"), + line_number=node.line_number + ) + return ast_internal_classes.BinOp_Node( + lval=copy.deepcopy(node.lval), + op="=", + rval=update, + line_number=node.line_number + ) + + def _loop_condition(self): + return self.cond + + @staticmethod + def func_name() -> str: + return "__dace_count" + + +class MinMaxValTransformation(LoopBasedReplacementTransformation): + + def __init__(self, ast): + super().__init__(ast) + + def _initialize(self): + self.rvals = [] + self.argument_variable = None + + def _update_result_type(self, var: ast_internal_classes.Name_Node): + + """ + For both MINVAL and MAXVAL, the result type depends on the input variable. + """ + + input_type = self.scope_vars.get_var(var.parent, self.argument_variable.name.name) + + var_decl = self.scope_vars.get_var(var.parent, var.name) + var.type = input_type.type + var_decl.type = input_type.type + + def _parse_call_expr_node(self, node: ast_internal_classes.Call_Expr_Node): + + for arg in node.args: + + array_node = self._parse_array(node, arg) + + if array_node is not None: + self.rvals.append(array_node) + else: + raise NotImplementedError("We do not support non-array arguments for MINVAL/MAXVAL") + + def _summarize_args(self, exec_node: ast_internal_classes.Execution_Part_Node, node: ast_internal_classes.FNode, new_func_body: List[ast_internal_classes.FNode]): + + if len(self.rvals) != 1: + raise NotImplementedError("Only one array can be summed") + + self.argument_variable = self.rvals[0] + + par_Decl_Range_Finder(self.argument_variable, self.loop_ranges, [], [], self.count, new_func_body, self.scope_vars, True) + + def _initialize_result(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node: + + return ast_internal_classes.BinOp_Node( + lval=node.lval, + op="=", + rval=self._result_init_value(self.argument_variable), + line_number=node.line_number + ) + + def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node: + + cond = ast_internal_classes.BinOp_Node( + lval=self.argument_variable, + op=self._condition_op(), + rval=node.lval, + line_number=node.line_number + ) + body_if = ast_internal_classes.BinOp_Node( + lval=node.lval, + op="=", + rval=self.argument_variable, + line_number=node.line_number + ) + return ast_internal_classes.If_Stmt_Node( + cond=cond, + body=body_if, + body_else=ast_internal_classes.Execution_Part_Node(execution=[]), + line_number=node.line_number + ) + +class MinVal(LoopBasedReplacement): + + """ + In this class, we implement the transformation for Fortran intrinsic MINVAL. + + We do not support the MASK and DIM argument. + """ + class Transformation(MinMaxValTransformation): + + def __init__(self, ast): + super().__init__(ast) + + def _result_init_value(self, array: ast_internal_classes.Array_Subscript_Node): + + var_decl = self.scope_vars.get_var(array.parent, array.name.name) + + # TODO: this should be used as a call to HUGE + fortran_type = var_decl.type + dace_type = fortrantypes2dacetypes[fortran_type] + from dace.dtypes import max_value + max_val = max_value(dace_type) + + if fortran_type == "INTEGER": + return ast_internal_classes.Int_Literal_Node(value=str(max_val)) + elif fortran_type == "DOUBLE": + return ast_internal_classes.Real_Literal_Node(value=str(max_val)) + + def _condition_op(self): + return "<" + + @staticmethod + def func_name() -> str: + return "__dace_minval" + + +class MaxVal(LoopBasedReplacement): + + """ + In this class, we implement the transformation for Fortran intrinsic MAXVAL. + + We do not support the MASK and DIM argument. + """ + class Transformation(MinMaxValTransformation): + + def __init__(self, ast): + super().__init__(ast) + + def _result_init_value(self, array: ast_internal_classes.Array_Subscript_Node): + + var_decl = self.scope_vars.get_var(array.parent, array.name.name) + + # TODO: this should be used as a call to HUGE + fortran_type = var_decl.type + dace_type = fortrantypes2dacetypes[fortran_type] + from dace.dtypes import min_value + min_val = min_value(dace_type) + + if fortran_type == "INTEGER": + return ast_internal_classes.Int_Literal_Node(value=str(min_val)) + elif fortran_type == "DOUBLE": + return ast_internal_classes.Real_Literal_Node(value=str(min_val)) + + def _condition_op(self): + return ">" + + @staticmethod + def func_name() -> str: + return "__dace_maxval" + +class Merge(LoopBasedReplacement): + + class Transformation(LoopBasedReplacementTransformation): + + def __init__(self, ast): + super().__init__(ast) + + def _initialize(self): + self.rvals = [] + + self.first_array = None + self.second_array = None + self.mask_first_array = None + self.mask_second_array = None + self.mask_cond = None + self.destination_array = None + + @staticmethod + def func_name() -> str: + return "__dace_merge" + + def _update_result_type(self, var: ast_internal_classes.Name_Node): + """ + We can ignore the result type, because we exempted this + transformation from generating a result. + In MERGE, we write directly to the destination array. + Thus, we store this result array for future use. + """ + pass + + def _parse_call_expr_node(self, node: ast_internal_classes.Call_Expr_Node): + + if len(node.args) != 3: + raise NotImplementedError("Expected three arguments to MERGE!") + + # First argument is always an array + self.first_array = self._parse_array(node, node.args[0]) + assert self.first_array is not None + + # Second argument is always an array + self.second_array = self._parse_array(node, node.args[1]) + assert self.second_array is not None + + # Last argument is either an array or a binary op + arg = node.args[2] + array_node = self._parse_array(node, node.args[2]) + if array_node is not None: + + self.mask_first_array = array_node + self.mask_cond = ast_internal_classes.BinOp_Node( + op="==", + rval=ast_internal_classes.Int_Literal_Node(value="1"), + lval=self.mask_first_array, + line_number=node.line_number + ) + + else: + + self.mask_first_array, self.mask_second_array, self.mask_cond = self._parse_binary_op(node, arg) + + def _summarize_args(self, exec_node: ast_internal_classes.Execution_Part_Node, node: ast_internal_classes.FNode, new_func_body: List[ast_internal_classes.FNode]): + + self.destination_array = self._parse_array(exec_node, node.lval) + + # The first main argument is an array -> this dictates loop boundaries + # Other arrays, regardless if they appear as the second array or mask, need to have the same loop boundary. + par_Decl_Range_Finder(self.first_array, self.loop_ranges, [], [], self.count, new_func_body, self.scope_vars, True) + + loop_ranges = [] + par_Decl_Range_Finder(self.second_array, loop_ranges, [], [], self.count, new_func_body, self.scope_vars, True) + self._adjust_array_ranges(node, self.second_array, self.loop_ranges, loop_ranges) + + par_Decl_Range_Finder(self.destination_array, [], [], [], self.count, new_func_body, self.scope_vars, True) + + if self.mask_first_array is not None: + loop_ranges = [] + par_Decl_Range_Finder(self.mask_first_array, loop_ranges, [], [], self.count, new_func_body, self.scope_vars, True) + self._adjust_array_ranges(node, self.mask_first_array, self.loop_ranges, loop_ranges) + + if self.mask_second_array is not None: + loop_ranges = [] + par_Decl_Range_Finder(self.mask_second_array, loop_ranges, [], [], self.count, new_func_body, self.scope_vars, True) + self._adjust_array_ranges(node, self.mask_second_array, self.loop_ranges, loop_ranges) + + def _initialize_result(self, node: ast_internal_classes.FNode) -> Optional[ast_internal_classes.BinOp_Node]: + """ + We don't use result variable in MERGE. + """ + return None + + def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_classes.BinOp_Node: + + """ + We check if the condition is true. If yes, then we write from the first array. + Otherwise, we copy data from the second array. + """ + + copy_first = ast_internal_classes.BinOp_Node( + lval=copy.deepcopy(self.destination_array), + op="=", + rval=self.first_array, + line_number=node.line_number + ) + + copy_second = ast_internal_classes.BinOp_Node( + lval=copy.deepcopy(self.destination_array), + op="=", + rval=self.second_array, + line_number=node.line_number + ) + + body_if = ast_internal_classes.Execution_Part_Node(execution=[ + copy_first + ]) + + body_else = ast_internal_classes.Execution_Part_Node(execution=[ + copy_second + ]) + + return ast_internal_classes.If_Stmt_Node( + cond=self.mask_cond, + body=body_if, + body_else=body_else, + line_number=node.line_number + ) + +class FortranIntrinsics: + + IMPLEMENTATIONS_AST = { + "SELECTED_INT_KIND": SelectedKind, + "SELECTED_REAL_KIND": SelectedKind, + "SUM": Sum, + "PRODUCT": Product, + "ANY": Any, + "COUNT": Count, + "ALL": All, + "MINVAL": MinVal, + "MAXVAL": MaxVal, + "MERGE": Merge + } + + DIRECT_REPLACEMENTS = { + "__dace_selected_int_kind": SelectedKind, + "__dace_selected_real_kind": SelectedKind + } + + EXEMPTED_FROM_CALL_EXTRACTION = [ + Merge + ] + + def __init__(self): + self._transformations_to_run = set() + + def transformations(self) -> Set[Type[NodeTransformer]]: + return self._transformations_to_run + + @staticmethod + def function_names() -> List[str]: + return list(LoopBasedReplacement.INTRINSIC_TO_DACE.values()) + + @staticmethod + def call_extraction_exemptions() -> List[str]: + return [func.Transformation.func_name() for func in FortranIntrinsics.EXEMPTED_FROM_CALL_EXTRACTION] + + def replace_function_name(self, node: FASTNode) -> ast_internal_classes.Name_Node: + + func_name = node.string + replacements = { + "INT": "__dace_int", + "DBLE": "__dace_dble", + "SQRT": "sqrt", + "COSH": "cosh", + "ABS": "abs", + "MIN": "min", + "MAX": "max", + "EXP": "exp", + "EPSILON": "__dace_epsilon", + "TANH": "tanh", + "SIGN": "__dace_sign", + "EXP": "exp" + } + if func_name in replacements: + return ast_internal_classes.Name_Node(name=replacements[func_name]) + else: + + if self.IMPLEMENTATIONS_AST[func_name].has_transformation(): + self._transformations_to_run.add(self.IMPLEMENTATIONS_AST[func_name].Transformation) + + return ast_internal_classes.Name_Node(name=self.IMPLEMENTATIONS_AST[func_name].replaced_name(func_name)) + + def replace_function_reference(self, name: ast_internal_classes.Name_Node, args: ast_internal_classes.Arg_List_Node, line): + + func_types = { + "__dace_int": "INT", + "__dace_dble": "DOUBLE", + "sqrt": "DOUBLE", + "cosh": "DOUBLE", + "abs": "DOUBLE", + "min": "DOUBLE", + "max": "DOUBLE", + "exp": "DOUBLE", + "__dace_epsilon": "DOUBLE", + "tanh": "DOUBLE", + "__dace_sign": "DOUBLE", + } + if name.name in func_types: + # FIXME: this will be progressively removed + call_type = func_types[name.name] + return ast_internal_classes.Call_Expr_Node(name=name, type=call_type, args=args.args, line_number=line) + elif name.name in self.DIRECT_REPLACEMENTS: + return self.DIRECT_REPLACEMENTS[name.name].replace(name, args, line) + else: + # We will do the actual type replacement later + # To that end, we need to know the input types - but these we do not know at the moment. + return ast_internal_classes.Call_Expr_Node( + name=name, type="VOID", + args=args.args, line_number=line + ) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index 991613a9ea..1b6817a7d0 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -240,6 +240,9 @@ def to_sdfg(self, *args, simplify=None, save=False, validate=False, use_cache=Fa warnings.warn("You are calling to_sdfg() on a dace program that " "has set 'recompile' to False. " "This may not be what you want.") + if self.autoopt == True: + warnings.warn("You are calling to_sdfg() on a dace program that " + "has set `auto_optimize` to True. Automatic optimization will not be applied.") if use_cache: # Update global variables with current closure diff --git a/dace/properties.py b/dace/properties.py index 44f8b4fbcc..e02a54ad1f 100644 --- a/dace/properties.py +++ b/dace/properties.py @@ -1153,7 +1153,7 @@ def allow_none(self): def __set__(self, obj, val): if isinstance(val, str): val = self.from_string(val) - if (val is not None and not isinstance(val, sbs.Range) and not isinstance(val, sbs.Indices)): + if (val is not None and not isinstance(val, sbs.Range) and not isinstance(val, sbs.Indices) and not isinstance(val, sbs.SubsetUnion)): raise TypeError("Subset property must be either Range or Indices: got {}".format(type(val).__name__)) super(SubsetProperty, self).__set__(obj, val) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 917f748cb8..084d46f47d 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -275,7 +275,7 @@ def remove_name_collisions(sdfg: SDFG): # Rename duplicate states for state in nsdfg.nodes(): if state.label in state_names_seen: - state.set_label(data.find_new_name(state.label, state_names_seen)) + state.label = data.find_new_name(state.label, state_names_seen) state_names_seen.add(state.label) replacements: Dict[str, str] = {} diff --git a/dace/sdfg/nodes.py b/dace/sdfg/nodes.py index 55ff69a994..796d8d7633 100644 --- a/dace/sdfg/nodes.py +++ b/dace/sdfg/nodes.py @@ -262,9 +262,8 @@ def label(self): def __label__(self, sdfg, state): return self.data - def desc(self, sdfg): - from dace.sdfg import SDFGState, ScopeSubgraphView - if isinstance(sdfg, (SDFGState, ScopeSubgraphView)): + def desc(self, sdfg: Union['dace.sdfg.SDFG', 'dace.sdfg.SDFGState', 'dace.sdfg.ScopeSubgraphView']): + if isinstance(sdfg, (dace.sdfg.SDFGState, dace.sdfg.ScopeSubgraphView)): sdfg = sdfg.parent return sdfg.arrays[self.data] diff --git a/dace/sdfg/replace.py b/dace/sdfg/replace.py index 4b36fad4fe..a2c7b9a43c 100644 --- a/dace/sdfg/replace.py +++ b/dace/sdfg/replace.py @@ -175,17 +175,18 @@ def replace_datadesc_names(sdfg, repl: Dict[str, str]): sdfg.constants_prop[repl[aname]] = sdfg.constants_prop[aname] del sdfg.constants_prop[aname] - # Replace in interstate edges - for e in sdfg.edges(): - e.data.replace_dict(repl, replace_keys=False) - - for state in sdfg.nodes(): - # Replace in access nodes - for node in state.data_nodes(): - if node.data in repl: - node.data = repl[node.data] - - # Replace in memlets - for edge in state.edges(): - if edge.data.data in repl: - edge.data.data = repl[edge.data.data] + for cf in sdfg.all_control_flow_regions(): + # Replace in interstate edges + for e in cf.edges(): + e.data.replace_dict(repl, replace_keys=False) + + for block in cf.nodes(): + if isinstance(block, dace.SDFGState): + # Replace in access nodes + for node in block.data_nodes(): + if node.data in repl: + node.data = repl[node.data] + # Replace in memlets + for edge in block.edges(): + if edge.data.data in repl: + edge.data.data = repl[edge.data.data] diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index a85e773337..fdf8835c7e 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -30,7 +30,7 @@ from dace.frontend.python import astutils, wrappers from dace.sdfg import nodes as nd from dace.sdfg.graph import OrderedDiGraph, Edge, SubgraphView -from dace.sdfg.state import SDFGState +from dace.sdfg.state import SDFGState, ControlFlowRegion from dace.sdfg.propagation import propagate_memlets_sdfg from dace.distr_types import ProcessGrid, SubArray, RedistrArray from dace.dtypes import validate_name @@ -402,7 +402,7 @@ def label(self): @make_properties -class SDFG(OrderedDiGraph[SDFGState, InterstateEdge]): +class SDFG(ControlFlowRegion): """ The main intermediate representation of code in DaCe. A Stateful DataFlow multiGraph (SDFG) is a directed graph of directed @@ -499,8 +499,6 @@ def __init__(self, self._parent_sdfg = None self._parent_nsdfg_node = None self._sdfg_list = [self] - self._start_state: Optional[int] = None - self._cached_start_state: Optional[SDFGState] = None self._arrays = NestedDict() # type: Dict[str, dt.Array] self._labels: Set[str] = set() self.global_code = {'frame': CodeBlock("", dtypes.Language.CPP)} @@ -531,14 +529,14 @@ def __deepcopy__(self, memo): memo[id(self)] = result for k, v in self.__dict__.items(): # Skip derivative attributes - if k in ('_cached_start_state', '_edges', '_nodes', '_parent', '_parent_sdfg', '_parent_nsdfg_node', + if k in ('_cached_start_block', '_edges', '_nodes', '_parent', '_parent_sdfg', '_parent_nsdfg_node', '_sdfg_list', '_transformation_hist'): continue setattr(result, k, copy.deepcopy(v, memo)) # Copy edges and nodes result._edges = copy.deepcopy(self._edges, memo) result._nodes = copy.deepcopy(self._nodes, memo) - result._cached_start_state = copy.deepcopy(self._cached_start_state, memo) + result._cached_start_block = copy.deepcopy(self._cached_start_block, memo) # Copy parent attributes for k in ('_parent', '_parent_sdfg', '_parent_nsdfg_node'): if id(getattr(self, k)) in memo: @@ -583,7 +581,7 @@ def to_json(self, hash=False): tmp['attributes']['constants_prop'] = json.loads(dace.serialize.dumps(tmp['attributes']['constants_prop'])) tmp['sdfg_list_id'] = int(self.sdfg_id) - tmp['start_state'] = self._start_state + tmp['start_state'] = self._start_block tmp['attributes']['name'] = self.name if hash: @@ -627,7 +625,7 @@ def from_json(cls, json_obj, context_info=None): ret.add_edge(nodelist[int(e.src)], nodelist[int(e.dst)], e.data) if 'start_state' in json_obj: - ret._start_state = json_obj['start_state'] + ret._start_block = json_obj['start_state'] return ret @@ -753,14 +751,7 @@ def replace_dict(self, for array in self.arrays.values(): replace_properties_dict(array, repldict, symrepl) - if replace_in_graph: - # Replace in inter-state edges - for edge in self.edges(): - edge.data.replace_dict(repldict, replace_keys=replace_keys) - - # Replace in states - for state in self.nodes(): - state.replace_dict(repldict, symrepl) + super().replace_dict(repldict, symrepl, replace_in_graph, replace_keys) def add_symbol(self, name, stype): """ Adds a symbol to the SDFG. @@ -787,34 +778,11 @@ def remove_symbol(self, name): @property def start_state(self): - """ Returns the starting state of this SDFG. """ - if self._cached_start_state is not None: - return self._cached_start_state - - source_nodes = self.source_nodes() - if len(source_nodes) == 1: - self._cached_start_state = source_nodes[0] - return source_nodes[0] - # If starting state is ambiguous (i.e., loop to initial state or more - # than one possible start state), allow manually overriding start state - if self._start_state is not None: - self._cached_start_state = self.node(self._start_state) - return self._cached_start_state - raise ValueError('Ambiguous or undefined starting state for SDFG, ' - 'please use "is_start_state=True" when adding the ' - 'starting state with "add_state"') + return self.start_block @start_state.setter def start_state(self, state_id): - """ Manually sets the starting state of this SDFG. - - :param state_id: The node ID (use `node_id(state)`) of the - state to set. - """ - if state_id < 0 or state_id >= self.number_of_nodes(): - raise ValueError("Invalid state ID") - self._start_state = state_id - self._cached_start_state = self.node(state_id) + self.start_block = state_id def set_global_code(self, cpp_code: str, location: str = 'frame'): """ @@ -1127,7 +1095,7 @@ def remove_data(self, name, validate=True): # Verify that there are no access nodes that use this data if validate: - for state in self.nodes(): + for state in self.states(): for node in state.nodes(): if isinstance(node, nd.AccessNode) and node.data == name: raise ValueError(f"Cannot remove data descriptor " @@ -1243,75 +1211,14 @@ def parent_sdfg(self, value): def parent_nsdfg_node(self, value): self._parent_nsdfg_node = value - def add_node(self, node, is_start_state=False): - """ Adds a new node to the SDFG. Must be an SDFGState or a subclass - thereof. - - :param node: The node to add. - :param is_start_state: If True, sets this node as the starting - state. - """ - if not isinstance(node, SDFGState): - raise TypeError("Expected SDFGState, got " + str(type(node))) - super(SDFG, self).add_node(node) - self._cached_start_state = None - if is_start_state is True: - self.start_state = len(self.nodes()) - 1 - self._cached_start_state = node - def remove_node(self, node: SDFGState): - if node is self._cached_start_state: - self._cached_start_state = None + if node is self._cached_start_block: + self._cached_start_block = None return super().remove_node(node) - def add_edge(self, u, v, edge): - """ Adds a new edge to the SDFG. Must be an InterstateEdge or a - subclass thereof. - - :param u: Source node. - :param v: Destination node. - :param edge: The edge to add. - """ - if not isinstance(u, SDFGState): - raise TypeError("Expected SDFGState, got: {}".format(type(u).__name__)) - if not isinstance(v, SDFGState): - raise TypeError("Expected SDFGState, got: {}".format(type(v).__name__)) - if not isinstance(edge, InterstateEdge): - raise TypeError("Expected InterstateEdge, got: {}".format(type(edge).__name__)) - if v is self._cached_start_state: - self._cached_start_state = None - return super(SDFG, self).add_edge(u, v, edge) - def states(self): - """ Alias that returns the nodes (states) in this SDFG. """ - return self.nodes() - - def all_nodes_recursive(self) -> Iterator[Tuple[nd.Node, Union['SDFG', 'SDFGState']]]: - """ Iterate over all nodes in this SDFG, including states, nodes in - states, and recursive states and nodes within nested SDFGs, - returning tuples on the form (node, parent), where the parent is - either the SDFG (for states) or a DFG (nodes). """ - for node in self.nodes(): - yield node, self - yield from node.all_nodes_recursive() - - def all_sdfgs_recursive(self): - """ Iterate over this and all nested SDFGs. """ - yield self - for state in self.nodes(): - for node in state.nodes(): - if isinstance(node, nd.NestedSDFG): - yield from node.sdfg.all_sdfgs_recursive() - - def all_edges_recursive(self): - """ Iterate over all edges in this SDFG, including state edges, - inter-state edges, and recursively edges within nested SDFGs, - returning tuples on the form (edge, parent), where the parent is - either the SDFG (for states) or a DFG (nodes). """ - for e in self.edges(): - yield e, self - for node in self.nodes(): - yield from node.all_edges_recursive() + """ Returns the states in this SDFG, recursing into state scope blocks. """ + return list(self.all_states()) def arrays_recursive(self): """ Iterate over all arrays in this SDFG, including arrays within @@ -1323,19 +1230,15 @@ def arrays_recursive(self): if isinstance(node, nd.NestedSDFG): yield from node.sdfg.arrays_recursive() - def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) -> Set[str]: - """ - Returns a set of symbol names that are used by the SDFG, but not - defined within it. This property is used to determine the symbolic - parameters of the SDFG. - - :param all_symbols: If False, only returns the set of symbols that will be used - in the generated code and are needed as arguments. - :param keep_defined_in_mapping: If True, symbols defined in inter-state edges that are in the symbol mapping - will be removed from the set of defined symbols. - """ - defined_syms = set() - free_syms = set() + def _used_symbols_internal(self, + all_symbols: bool, + defined_syms: Optional[Set]=None, + free_syms: Optional[Set]=None, + used_before_assignment: Optional[Set]=None, + keep_defined_in_mapping: bool=False) -> Tuple[Set[str], Set[str], Set[str]]: + defined_syms = set() if defined_syms is None else defined_syms + free_syms = set() if free_syms is None else free_syms + used_before_assignment = set() if used_before_assignment is None else used_before_assignment # Exclude data descriptor names and constants for name in self.arrays.keys(): @@ -1349,54 +1252,10 @@ def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) - for code in self.exit_code.values(): free_syms |= symbolic.symbols_in_code(code.as_string, self.symbols.keys()) - # Add free state symbols - used_before_assignment = set() - - try: - ordered_states = self.topological_sort(self.start_state) - except ValueError: # Failsafe (e.g., for invalid or empty SDFGs) - ordered_states = self.nodes() - - for state in ordered_states: - state_fsyms = state.used_symbols(all_symbols) - free_syms |= state_fsyms - - # Add free inter-state symbols - for e in self.out_edges(state): - # NOTE: First we get the true InterstateEdge free symbols, then we compute the newly defined symbols by - # subracting the (true) free symbols from the edge's assignment keys. This way we can correctly - # compute the symbols that are used before being assigned. - efsyms = e.data.used_symbols(all_symbols) - defined_syms |= set(e.data.assignments.keys()) - (efsyms | state_fsyms) - used_before_assignment.update(efsyms - defined_syms) - free_syms |= efsyms - - # Remove symbols that were used before they were assigned - defined_syms -= used_before_assignment - - # Remove from defined symbols those that are in the symbol mapping - if self.parent_nsdfg_node is not None and keep_defined_in_mapping: - defined_syms -= set(self.parent_nsdfg_node.symbol_mapping.keys()) - - # Add the set of SDFG symbol parameters - # If all_symbols is False, those symbols would only be added in the case of non-Python tasklets - if all_symbols: - free_syms |= set(self.symbols.keys()) - - # Subtract symbols defined in inter-state edges and constants - return free_syms - defined_syms - - @property - def free_symbols(self) -> Set[str]: - """ - Returns a set of symbol names that are used by the SDFG, but not - defined within it. This property is used to determine the symbolic - parameters of the SDFG and verify that ``SDFG.symbols`` is complete. - - :note: Assumes that the graph is valid (i.e., without undefined or - overlapping symbols). - """ - return self.used_symbols(all_symbols=True) + return super()._used_symbols_internal( + all_symbols=all_symbols, keep_defined_in_mapping=keep_defined_in_mapping, + defined_syms=defined_syms, free_syms=free_syms, used_before_assignment=used_before_assignment + ) def get_all_toplevel_symbols(self) -> Set[str]: """ @@ -1608,16 +1467,16 @@ def shared_transients(self, check_toplevel=True) -> List[str]: shared = [] # If a transient is present in an inter-state edge, it is shared - for interstate_edge in self.edges(): + for interstate_edge in self.all_interstate_edges(): for sym in interstate_edge.data.free_symbols: if sym in self.arrays and self.arrays[sym].transient: seen[sym] = interstate_edge shared.append(sym) # If transient is accessed in more than one state, it is shared - for state in self.nodes(): - for node in state.nodes(): - if isinstance(node, nd.AccessNode) and node.desc(self).transient: + for state in self.states(): + for node in state.data_nodes(): + if node.desc(self).transient: if (check_toplevel and node.desc(self).toplevel) or (node.data in seen and seen[node.data] != state): shared.append(node.data) @@ -1706,62 +1565,6 @@ def from_file(filename: str) -> 'SDFG': # Dynamic SDFG creation API ############################## - def add_state(self, label=None, is_start_state=False) -> 'SDFGState': - """ Adds a new SDFG state to this graph and returns it. - - :param label: State label. - :param is_start_state: If True, resets SDFG starting state to this - state. - :return: A new SDFGState object. - """ - if self._labels is None or len(self._labels) != self.number_of_nodes(): - self._labels = set(s.label for s in self.nodes()) - label = label or 'state' - existing_labels = self._labels - label = dt.find_new_name(label, existing_labels) - state = SDFGState(label, self) - self._labels.add(label) - - self.add_node(state, is_start_state=is_start_state) - return state - - def add_state_before(self, state: 'SDFGState', label=None, is_start_state=False) -> 'SDFGState': - """ Adds a new SDFG state before an existing state, reconnecting - predecessors to it instead. - - :param state: The state to prepend the new state before. - :param label: State label. - :param is_start_state: If True, resets SDFG starting state to this - state. - :return: A new SDFGState object. - """ - new_state = self.add_state(label, is_start_state) - # Reconnect - for e in self.in_edges(state): - self.remove_edge(e) - self.add_edge(e.src, new_state, e.data) - # Add unconditional connection between the new state and the current - self.add_edge(new_state, state, InterstateEdge()) - return new_state - - def add_state_after(self, state: 'SDFGState', label=None, is_start_state=False) -> 'SDFGState': - """ Adds a new SDFG state after an existing state, reconnecting - it to the successors instead. - - :param state: The state to append the new state after. - :param label: State label. - :param is_start_state: If True, resets SDFG starting state to this - state. - :return: A new SDFGState object. - """ - new_state = self.add_state(label, is_start_state) - # Reconnect - for e in self.out_edges(state): - self.remove_edge(e) - self.add_edge(new_state, e.dst, e.data) - # Add unconditional connection between the current and the new state - self.add_edge(state, new_state, InterstateEdge()) - return new_state def _find_new_name(self, name: str): """ Tries to find a new name by adding an underscore and a number. """ @@ -2482,7 +2285,7 @@ def __call__(self, *args, **kwargs): def fill_scope_connectors(self): """ Fills missing scope connectors (i.e., "IN_#"/"OUT_#" on entry/exit nodes) according to data on the memlets. """ - for state in self.nodes(): + for state in self.states(): state.fill_scope_connectors() def predecessor_state_transitions(self, state): diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 69fccfdabd..c8e67a7c20 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2,6 +2,7 @@ """ Contains classes of a single SDFG state and dataflow subgraphs. """ import ast +import abc import collections import copy import inspect @@ -19,7 +20,7 @@ from dace.properties import (CodeBlock, DictProperty, EnumProperty, Property, SubsetProperty, SymbolicProperty, CodeProperty, make_properties) from dace.sdfg import nodes as nd -from dace.sdfg.graph import MultiConnectorEdge, OrderedMultiDiConnectorGraph, SubgraphView +from dace.sdfg.graph import MultiConnectorEdge, OrderedMultiDiConnectorGraph, SubgraphView, OrderedDiGraph, Edge from dace.sdfg.propagation import propagate_memlet from dace.sdfg.validation import validate_state from dace.subsets import Range, Subset @@ -28,6 +29,11 @@ import dace.sdfg.scope +NodeT = Union[nd.Node, 'ControlFlowBlock'] +EdgeT = Union[MultiConnectorEdge[mm.Memlet], Edge['dace.sdfg.InterstateEdge']] +GraphT = Union['ControlFlowRegion', 'SDFGState'] + + def _getdebuginfo(old_dinfo=None) -> dtypes.DebugInfo: """ Returns a DebugInfo object for the position that called this function. @@ -66,13 +72,248 @@ def _make_iterators(ndrange): return params, map_range -class StateGraphView(object): +class BlockGraphView(object): """ - Read-only view interface of an SDFG state, containing methods for memlet - tracking, traversal, subgraph creation, queries, and replacements. - ``SDFGState`` and ``StateSubgraphView`` inherit from this class to share + Read-only view interface of an SDFG control flow block, containing methods for memlet tracking, traversal, subgraph + creation, queries, and replacements. ``ControlFlowBlock`` and ``StateSubgraphView`` inherit from this class to share methods. """ + + + ################################################################### + # Typing overrides + + @overload + def nodes(self) -> List[NodeT]: + ... + + @overload + def edges(self) -> List[EdgeT]: + ... + + @overload + def in_degree(self, node: NodeT) -> int: + ... + + @overload + def out_degree(self, node: NodeT) -> int: + ... + + ################################################################### + # Traversal methods + + @abc.abstractmethod + def all_nodes_recursive(self) -> Iterator[Tuple[NodeT, GraphT]]: + """ + Iterate over all nodes in this graph or subgraph. + This includes control flow blocks, nodes in those blocks, and recursive control flow blocks and nodes within + nested SDFGs. It returns tuples of the form (node, parent), where the node is either a dataflow node, in which + case the parent is an SDFG state, or a control flow block, in which case the parent is a control flow graph + (i.e., an SDFG or a scope block). + """ + raise NotImplementedError() + + @abc.abstractmethod + def all_edges_recursive(self) -> Iterator[Tuple[EdgeT, GraphT]]: + """ + Iterate over all edges in this graph or subgraph. + This includes dataflow edges, inter-state edges, and recursive edges within nested SDFGs. It returns tuples of + the form (edge, parent), where the edge is either a dataflow edge, in which case the parent is an SDFG state, or + an inter-stte edge, in which case the parent is a control flow graph (i.e., an SDFG or a scope block). + """ + raise NotImplementedError() + + @abc.abstractmethod + def data_nodes(self) -> List[nd.AccessNode]: + """ + Returns all data nodes (i.e., AccessNodes, arrays) present in this graph or subgraph. + Note: This does not recurse into nested SDFGs. + """ + raise NotImplementedError() + + @abc.abstractmethod + def entry_node(self, node: nd.Node) -> nd.EntryNode: + """ Returns the entry node that wraps the current node, or None if it is top-level in a state. """ + raise NotImplementedError() + + @abc.abstractmethod + def exit_node(self, entry_node: nd.EntryNode) -> nd.ExitNode: + """ Returns the exit node leaving the context opened by the given entry node. """ + raise NotImplementedError() + + ################################################################### + # Memlet-tracking methods + + @abc.abstractmethod + def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnectorEdge[mm.Memlet]]: + """ + Given one edge, returns a list of edges representing a path between its source and sink nodes. + Used for memlet tracking. + + :note: Behavior is undefined when there is more than one path involving this edge. + :param edge: An edge within a state (memlet). + :return: A list of edges from a source node to a destination node. + """ + raise NotImplementedError() + + @abc.abstractmethod + def memlet_tree(self, edge: MultiConnectorEdge) -> mm.MemletTree: + """ + Given one edge, returns a tree of edges between its node source(s) and sink(s). + Used for memlet tracking. + + :param edge: An edge within a state (memlet). + :return: A tree of edges whose root is the source/sink node (depending on direction) and associated children + edges. + """ + raise NotImplementedError() + + @abc.abstractmethod + def in_edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: + """ + Returns a generator over edges entering the given connector of the given node. + + :param node: Destination node of edges. + :param connector: Destination connector of edges. + """ + raise NotImplementedError() + + @abc.abstractmethod + def out_edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: + """ + Returns a generator over edges exiting the given connector of the given node. + + :param node: Source node of edges. + :param connector: Source connector of edges. + """ + raise NotImplementedError() + + @abc.abstractmethod + def edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: + """ + Returns a generator over edges entering or exiting the given connector of the given node. + + :param node: Source/destination node of edges. + :param connector: Source/destination connector of edges. + """ + raise NotImplementedError() + + ################################################################### + # Query, subgraph, and replacement methods + + @abc.abstractmethod + def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) -> Set[str]: + """ + Returns a set of symbol names that are used in the graph. + + :param all_symbols: If False, only returns symbols that are needed as arguments (only used in generated code). + :param keep_defined_in_mapping: If True, symbols defined in inter-state edges that are in the symbol mapping + will be removed from the set of defined symbols. + """ + raise NotImplementedError() + + @property + def free_symbols(self) -> Set[str]: + """ + Returns a set of symbol names that are used, but not defined, in this graph view. + In the case of an SDFG, this property is used to determine the symbolic parameters of the SDFG and + verify that ``SDFG.symbols`` is complete. + + :note: Assumes that the graph is valid (i.e., without undefined or overlapping symbols). + """ + return self.used_symbols(all_symbols=True) + + @abc.abstractmethod + def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: + """ + Determines what data is read and written in this graph. + Does not include reads to subsets of containers that have previously been written within the same state. + + :return: A two-tuple of sets of things denoting ({data read}, {data written}). + """ + raise NotImplementedError() + + @abc.abstractmethod + def unordered_arglist(self, + defined_syms=None, + shared_transients=None) -> Tuple[Dict[str, dt.Data], Dict[str, dt.Data]]: + raise NotImplementedError() + + def arglist(self, defined_syms=None, shared_transients=None) -> Dict[str, dt.Data]: + """ + Returns an ordered dictionary of arguments (names and types) required to invoke this subgraph. + + The arguments differ from SDFG.arglist, but follow the same order, + namely: , . + + Data arguments contain: + * All used non-transient data containers in the subgraph + * All used transient data containers that were allocated outside. + This includes data from memlets, transients shared across multiple states, and transients that could not + be allocated within the subgraph (due to their ``AllocationLifetime`` or according to the + ``dtypes.can_allocate`` function). + + Scalar arguments contain: + * Free symbols in this state/subgraph. + * All transient and non-transient scalar data containers used in this subgraph. + + This structure will create a sorted list of pointers followed by a sorted list of PoDs and structs. + + :return: An ordered dictionary of (name, data descriptor type) of all the arguments, sorted as defined here. + """ + data_args, scalar_args = self.unordered_arglist(defined_syms, shared_transients) + + # Fill up ordered dictionary + result = collections.OrderedDict() + for k, v in itertools.chain(sorted(data_args.items()), sorted(scalar_args.items())): + result[k] = v + + return result + + def signature_arglist(self, with_types=True, for_call=False): + """ Returns a list of arguments necessary to call this state or subgraph, formatted as a list of C definitions. + + :param with_types: If True, includes argument types in the result. + :param for_call: If True, returns arguments that can be used when calling the SDFG. + :return: A list of strings. For example: `['float *A', 'int b']`. + """ + return [v.as_arg(name=k, with_types=with_types, for_call=for_call) for k, v in self.arglist().items()] + + @abc.abstractmethod + def top_level_transients(self) -> Set[str]: + """Iterate over top-level transients of this graph.""" + raise NotImplementedError() + + @abc.abstractmethod + def all_transients(self) -> List[str]: + """Iterate over all transients in this graph.""" + raise NotImplementedError() + + @abc.abstractmethod + def replace(self, name: str, new_name: str): + """ + Finds and replaces all occurrences of a symbol or array in this graph. + + :param name: Name to find. + :param new_name: Name to replace. + """ + raise NotImplementedError() + + @abc.abstractmethod + def replace_dict(self, + repl: Dict[str, str], + symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None): + """ + Finds and replaces all occurrences of a set of symbols or arrays in this graph. + + :param repl: Mapping from names to replacements. + :param symrepl: Optional symbolic version of ``repl``. + """ + raise NotImplementedError() + + +@make_properties +class DataflowGraphView(BlockGraphView, abc.ABC): def __init__(self, *args, **kwargs): self._clear_scopedict_cache() @@ -91,29 +332,29 @@ def edges(self) -> List[MultiConnectorEdge[mm.Memlet]]: ################################################################### # Traversal methods - def all_nodes_recursive(self): + def all_nodes_recursive(self) -> Iterator[Tuple[NodeT, GraphT]]: for node in self.nodes(): yield node, self if isinstance(node, nd.NestedSDFG): yield from node.sdfg.all_nodes_recursive() - def all_edges_recursive(self): + def all_edges_recursive(self) -> Iterator[Tuple[EdgeT, GraphT]]: for e in self.edges(): yield e, self for node in self.nodes(): if isinstance(node, nd.NestedSDFG): yield from node.sdfg.all_edges_recursive() - def data_nodes(self): + def data_nodes(self) -> List[nd.AccessNode]: """ Returns all data_nodes (arrays) present in this state. """ return [n for n in self.nodes() if isinstance(n, nd.AccessNode)] - def entry_node(self, node: nd.Node) -> nd.EntryNode: + def entry_node(self, node: nd.Node) -> Optional[nd.EntryNode]: """ Returns the entry node that wraps the current node, or None if it is top-level in a state. """ return self.scope_dict()[node] - def exit_node(self, entry_node: nd.EntryNode) -> nd.ExitNode: + def exit_node(self, entry_node: nd.EntryNode) -> Optional[nd.ExitNode]: """ Returns the exit node leaving the context opened by the given entry node. """ node_to_children = self.scope_children() @@ -152,7 +393,7 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto result.insert(0, next_edge) curedge = next_edge - # Prepend outgoing edges until reaching the sink node + # Append outgoing edges until reaching the sink node curedge = edge while not isinstance(curedge.dst, (nd.CodeNode, nd.AccessNode)): # Trace through scope entry using IN_# -> OUT_# @@ -168,13 +409,6 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto return result def memlet_tree(self, edge: MultiConnectorEdge) -> mm.MemletTree: - """ Given one edge, returns a tree of edges between its node source(s) - and sink(s). Used for memlet tracking. - - :param edge: An edge within this state. - :return: A tree of edges whose root is the source/sink node - (depending on direction) and associated children edges. - """ propagate_forward = False propagate_backward = False if ((isinstance(edge.src, nd.EntryNode) and edge.src_conn is not None) or @@ -246,30 +480,12 @@ def traverse(node): return traverse(tree_root) def in_edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: - """ Returns a generator over edges entering the given connector of the - given node. - - :param node: Destination node of edges. - :param connector: Destination connector of edges. - """ return (e for e in self.in_edges(node) if e.dst_conn == connector) def out_edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: - """ Returns a generator over edges exiting the given connector of the - given node. - - :param node: Source node of edges. - :param connector: Source connector of edges. - """ return (e for e in self.out_edges(node) if e.src_conn == connector) def edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: - """ Returns a generator over edges entering or exiting the given - connector of the given node. - - :param node: Source/destination node of edges. - :param connector: Source/destination connector of edges. - """ return itertools.chain(self.in_edges_by_connector(node, connector), self.out_edges_by_connector(node, connector)) @@ -297,8 +513,6 @@ def scope_tree(self) -> 'dace.sdfg.scope.ScopeTree': result = {} - sdfg_symbols = self.parent.symbols.keys() - # Get scopes for node, scopenodes in sdc.items(): if node is None: @@ -325,15 +539,7 @@ def scope_leaves(self) -> List['dace.sdfg.scope.ScopeTree']: self._scope_leaves_cached = [scope for scope in st.values() if len(scope.children) == 0] return copy.copy(self._scope_leaves_cached) - def scope_dict(self, return_ids: bool = False, validate: bool = True) -> Dict[nd.Node, Optional[nd.Node]]: - """ Returns a dictionary that maps each SDFG node to its parent entry - node, or to None if the node is not in any scope. - - :param return_ids: Return node ID numbers instead of node objects. - :param validate: Ensure that the graph is not malformed when - computing dictionary. - :return: The mapping from a node to its parent scope entry node. - """ + def scope_dict(self, return_ids: bool = False, validate: bool = True) -> Dict[nd.Node, Union['SDFGState', nd.Node]]: from dace.sdfg.scope import _scope_dict_inner, _scope_dict_to_ids result = None result = copy.copy(self._scope_dict_toparent_cached) @@ -367,16 +573,7 @@ def scope_dict(self, return_ids: bool = False, validate: bool = True) -> Dict[nd def scope_children(self, return_ids: bool = False, - validate: bool = True) -> Dict[Optional[nd.EntryNode], List[nd.Node]]: - """ Returns a dictionary that maps each SDFG entry node to its children, - not including the children of children entry nodes. The key `None` - contains a list of top-level nodes (i.e., not in any scope). - - :param return_ids: Return node ID numbers instead of node objects. - :param validate: Ensure that the graph is not malformed when - computing dictionary. - :return: The mapping from a node to a list of children nodes. - """ + validate: bool = True) -> Dict[Union[nd.Node, 'SDFGState'], List[nd.Node]]: from dace.sdfg.scope import _scope_dict_inner, _scope_dict_to_ids result = None if self._scope_dict_tochildren_cached is not None: @@ -419,13 +616,7 @@ def is_leaf_memlet(self, e): return False return True - def used_symbols(self, all_symbols: bool) -> Set[str]: - """ - Returns a set of symbol names that are used in the state. - - :param all_symbols: If False, only returns the set of symbols that will be used - in the generated code and are needed as arguments. - """ + def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) -> Set[str]: state = self.graph if isinstance(self, SubgraphView) else self sdfg = state.parent new_symbols = set() @@ -579,33 +770,9 @@ def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: read_set, write_set = self._read_and_write_sets() return set(read_set.keys()), set(write_set.keys()) - def arglist(self, defined_syms=None, shared_transients=None) -> Dict[str, dt.Data]: - """ - Returns an ordered dictionary of arguments (names and types) required - to invoke this SDFG state or subgraph thereof. - - The arguments differ from SDFG.arglist, but follow the same order, - namely: , . - - Data arguments contain: - * All used non-transient data containers in the subgraph - * All used transient data containers that were allocated outside. - This includes data from memlets, transients shared across multiple - states, and transients that could not be allocated within the - subgraph (due to their ``AllocationLifetime`` or according to the - ``dtypes.can_allocate`` function). - - Scalar arguments contain: - * Free symbols in this state/subgraph. - * All transient and non-transient scalar data containers used in - this subgraph. - - This structure will create a sorted list of pointers followed by a - sorted list of PoDs and structs. - - :return: An ordered dictionary of (name, data descriptor type) of all - the arguments, sorted as defined here. - """ + def unordered_arglist(self, + defined_syms=None, + shared_transients=None) -> Tuple[Dict[str, dt.Data], Dict[str, dt.Data]]: sdfg: 'dace.sdfg.SDFG' = self.parent shared_transients = shared_transients or sdfg.shared_transients() sdict = self.scope_dict() @@ -699,12 +866,7 @@ def arglist(self, defined_syms=None, shared_transients=None) -> Dict[str, dt.Dat if not str(k).startswith('__dace') and str(k) not in sdfg.constants }) - # Fill up ordered dictionary - result = collections.OrderedDict() - for k, v in itertools.chain(sorted(data_args.items()), sorted(scalar_args.items())): - result[k] = v - - return result + return data_args, scalar_args def signature_arglist(self, with_types=True, for_call=False): """ Returns a list of arguments necessary to call this state or @@ -749,22 +911,212 @@ def replace(self, name: str, new_name: str): def replace_dict(self, repl: Dict[str, str], symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None): - """ Finds and replaces all occurrences of a set of symbols or arrays in this state. - - :param repl: Mapping from names to replacements. - :param symrepl: Optional symbolic version of ``repl``. - """ from dace.sdfg.replace import replace_dict replace_dict(self, repl, symrepl) @make_properties -class SDFGState(OrderedMultiDiConnectorGraph[nd.Node, mm.Memlet], StateGraphView): +class ControlGraphView(BlockGraphView, abc.ABC): + + ################################################################### + # Typing overrides + + @overload + def nodes(self) -> List['ControlFlowBlock']: + ... + + @overload + def edges(self) -> List[Edge['dace.sdfg.InterstateEdge']]: + ... + + ################################################################### + # Traversal methods + + def all_nodes_recursive(self) -> Iterator[Tuple[NodeT, GraphT]]: + for node in self.nodes(): + yield node, self + yield from node.all_nodes_recursive() + + def all_edges_recursive(self) -> Iterator[Tuple[EdgeT, GraphT]]: + for e in self.edges(): + yield e, self + for node in self.nodes(): + yield from node.all_edges_recursive() + + def data_nodes(self) -> List[nd.AccessNode]: + data_nodes = [] + for node in self.nodes(): + data_nodes.extend(node.data_nodes()) + return data_nodes + + def entry_node(self, node: nd.Node) -> Optional[nd.EntryNode]: + for block in self.nodes(): + if node in block.nodes(): + return block.exit_node(node) + return None + + def exit_node(self, entry_node: nd.EntryNode) -> Optional[nd.ExitNode]: + for block in self.nodes(): + if entry_node in block.nodes(): + return block.exit_node(entry_node) + return None + + ################################################################### + # Memlet-tracking methods + + def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnectorEdge[mm.Memlet]]: + for block in self.nodes(): + if edge in block.edges(): + return block.memlet_path(edge) + return [] + + def memlet_tree(self, edge: MultiConnectorEdge) -> mm.MemletTree: + for block in self.nodes(): + if edge in block.edges(): + return block.memlet_tree(edge) + return mm.MemletTree(edge) + + def in_edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: + for block in self.nodes(): + if node in block.nodes(): + return block.in_edges_by_connector(node, connector) + return [] + + def out_edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: + for block in self.nodes(): + if node in block.nodes(): + return block.out_edges_by_connector(node, connector) + return [] + + def edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: + for block in self.nodes(): + if node in block.nodes(): + return block.edges_by_connector(node, connector) + + ################################################################### + # Query, subgraph, and replacement methods + + @abc.abstractmethod + def _used_symbols_internal(self, + all_symbols: bool, + defined_syms: Optional[Set] = None, + free_syms: Optional[Set] = None, + used_before_assignment: Optional[Set] = None, + keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]: + raise NotImplementedError() + + def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) -> Set[str]: + return self._used_symbols_internal(all_symbols, keep_defined_in_mapping=keep_defined_in_mapping)[0] + + def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: + read_set = set() + write_set = set() + for block in self.nodes(): + for edge in self.in_edges(block): + read_set |= edge.data.free_symbols & self.sdfg.arrays.keys() + rs, ws = block.read_and_write_sets() + read_set.update(rs) + write_set.update(ws) + return read_set, write_set + + def unordered_arglist(self, + defined_syms=None, + shared_transients=None) -> Tuple[Dict[str, dt.Data], Dict[str, dt.Data]]: + data_args = {} + scalar_args = {} + for block in self.nodes(): + n_data_args, n_scalar_args = block.unordered_arglist(defined_syms, shared_transients) + data_args.update(n_data_args) + scalar_args.update(n_scalar_args) + return data_args, scalar_args + + def top_level_transients(self) -> Set[str]: + res = set() + for block in self.nodes(): + res.update(block.top_level_transients()) + return res + + def all_transients(self) -> List[str]: + res = [] + for block in self.nodes(): + res.extend(block.all_transients()) + return dtypes.deduplicate(res) + + def replace(self, name: str, new_name: str): + for n in self.nodes(): + n.replace(name, new_name) + + def replace_dict(self, + repl: Dict[str, str], + symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None, + replace_in_graph: bool = True, replace_keys: bool = False): + symrepl = symrepl or { + symbolic.symbol(k): symbolic.pystr_to_symbolic(v) if isinstance(k, str) else v + for k, v in repl.items() + } + + if replace_in_graph: + # Replace in inter-state edges + for edge in self.edges(): + edge.data.replace_dict(repl, replace_keys=replace_keys) + + # Replace in states + for state in self.nodes(): + state.replace_dict(repl, symrepl) + +@make_properties +class ControlFlowBlock(BlockGraphView, abc.ABC): + + is_collapsed = Property(dtype=bool, desc='Show this block as collapsed', default=False) + + _label: str + + def __init__(self, label: str=''): + super(ControlFlowBlock, self).__init__() + self._label = label + self._default_lineinfo = None + self.is_collapsed = False + + def set_default_lineinfo(self, lineinfo: dace.dtypes.DebugInfo): + """ + Sets the default source line information to be lineinfo, or None to + revert to default mode. + """ + self._default_lineinfo = lineinfo + + def to_json(self, parent=None): + tmp = { + 'type': self.__class__.__name__, + 'collapsed': self.is_collapsed, + 'label': self._label, + 'id': parent.node_id(self) if parent is not None else None, + } + return tmp + + def __str__(self): + return self._label + + def __repr__(self) -> str: + return f'ControlFlowBlock ({self.label})' + + @property + def label(self) -> str: + return self._label + + @label.setter + def label(self, label: str): + self._label = label + + @property + def name(self) -> str: + return self._label + + +@make_properties +class SDFGState(OrderedMultiDiConnectorGraph[nd.Node, mm.Memlet], ControlFlowBlock, DataflowGraphView): """ An acyclic dataflow multigraph in an SDFG, corresponding to a single state in the SDFG state machine. """ - is_collapsed = Property(dtype=bool, desc="Show this node/scope/state as collapsed", default=False) - nosync = Property(dtype=bool, default=False, desc="Do not synchronize at the end of the state") instrument = EnumProperty(dtype=dtypes.InstrumentationType, @@ -803,13 +1155,14 @@ def __init__(self, label=None, sdfg=None, debuginfo=None, location=None): :param debuginfo: Source code locator for debugging. """ from dace.sdfg.sdfg import SDFG # Avoid import loop + OrderedMultiDiConnectorGraph.__init__(self) + ControlFlowBlock.__init__(self, label) super(SDFGState, self).__init__() self._label = label self._parent: SDFG = sdfg self._graph = self # Allowing MemletTrackingView mixin to work self._clear_scopedict_cache() self._debuginfo = debuginfo - self.is_collapsed = False self.nosync = False self.location = location if location is not None else {} self._default_lineinfo = None @@ -839,33 +1192,12 @@ def parent(self): def parent(self, value): self._parent = value - def __str__(self): - return self._label - - @property - def label(self): - return self._label - - @property - def name(self): - return self._label - - def set_label(self, label): - self._label = label - def is_empty(self): return self.number_of_nodes() == 0 def validate(self) -> None: validate_state(self) - def set_default_lineinfo(self, lineinfo: dtypes.DebugInfo): - """ - Sets the default source line information to be lineinfo, or None to - revert to default mode. - """ - self._default_lineinfo = lineinfo - def nodes(self) -> List[nd.Node]: # Added for type hints return super().nodes() @@ -1984,8 +2316,244 @@ def fill_scope_connectors(self): node.add_in_connector(edge.dst_conn) -class StateSubgraphView(SubgraphView, StateGraphView): +class StateSubgraphView(SubgraphView, DataflowGraphView): """ A read-only subgraph view of an SDFG state. """ def __init__(self, graph, subgraph_nodes): super().__init__(graph, subgraph_nodes) + + +@make_properties +class ControlFlowRegion(OrderedDiGraph[ControlFlowBlock, 'dace.sdfg.InterstateEdge'], ControlGraphView, + ControlFlowBlock): + + def __init__(self, + label: str=''): + OrderedDiGraph.__init__(self) + ControlGraphView.__init__(self) + ControlFlowBlock.__init__(self, label) + + self._labels: Set[str] = set() + self._start_block: Optional[int] = None + self._cached_start_block: Optional[ControlFlowBlock] = None + + def add_edge(self, src: ControlFlowBlock, dst: ControlFlowBlock, data: 'dace.sdfg.InterstateEdge'): + """ Adds a new edge to the graph. Must be an InterstateEdge or a subclass thereof. + + :param u: Source node. + :param v: Destination node. + :param edge: The edge to add. + """ + if not isinstance(src, ControlFlowBlock): + raise TypeError('Expected ControlFlowBlock, got ' + str(type(src))) + if not isinstance(dst, ControlFlowBlock): + raise TypeError('Expected ControlFlowBlock, got ' + str(type(dst))) + if not isinstance(data, dace.sdfg.InterstateEdge): + raise TypeError('Expected InterstateEdge, got ' + str(type(data))) + if dst is self._cached_start_block: + self._cached_start_block = None + return super().add_edge(src, dst, data) + + def add_node(self, node, is_start_block=False, *, is_start_state: bool=None): + if not isinstance(node, ControlFlowBlock): + raise TypeError('Expected ControlFlowBlock, got ' + str(type(node))) + super().add_node(node) + self._cached_start_block = None + start_block = is_start_block + if is_start_state is not None: + warnings.warn('is_start_state is deprecated, use is_start_block instead', DeprecationWarning) + start_block = is_start_state + + if start_block: + self.start_block = len(self.nodes()) - 1 + self._cached_start_block = node + + def add_state(self, label=None, is_start_block=False, *, is_start_state: bool=None) -> SDFGState: + if self._labels is None or len(self._labels) != self.number_of_nodes(): + self._labels = set(s.label for s in self.nodes()) + label = label or 'state' + existing_labels = self._labels + label = dt.find_new_name(label, existing_labels) + state = SDFGState(label) + state.parent = self + self._labels.add(label) + start_block = is_start_block + if is_start_state is not None: + warnings.warn('is_start_state is deprecated, use is_start_block instead', DeprecationWarning) + start_block = is_start_state + self.add_node(state, is_start_block=start_block) + return state + + def add_state_before(self, state: SDFGState, label=None, is_start_state=False) -> SDFGState: + """ Adds a new SDFG state before an existing state, reconnecting predecessors to it instead. + + :param state: The state to prepend the new state before. + :param label: State label. + :param is_start_state: If True, resets scope block starting state to this state. + :return: A new SDFGState object. + """ + new_state = self.add_state(label, is_start_state) + # Reconnect + for e in self.in_edges(state): + self.remove_edge(e) + self.add_edge(e.src, new_state, e.data) + # Add unconditional connection between the new state and the current + self.add_edge(new_state, state, dace.sdfg.InterstateEdge()) + return new_state + + def add_state_after(self, state: SDFGState, label=None, is_start_state=False) -> SDFGState: + """ Adds a new SDFG state after an existing state, reconnecting it to the successors instead. + + :param state: The state to append the new state after. + :param label: State label. + :param is_start_state: If True, resets SDFG starting state to this state. + :return: A new SDFGState object. + """ + new_state = self.add_state(label, is_start_state) + # Reconnect + for e in self.out_edges(state): + self.remove_edge(e) + self.add_edge(new_state, e.dst, e.data) + # Add unconditional connection between the current and the new state + self.add_edge(state, new_state, dace.sdfg.InterstateEdge()) + return new_state + + @abc.abstractmethod + def _used_symbols_internal(self, + all_symbols: bool, + defined_syms: Optional[Set] = None, + free_syms: Optional[Set] = None, + used_before_assignment: Optional[Set] = None, + keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]: + defined_syms = set() if defined_syms is None else defined_syms + free_syms = set() if free_syms is None else free_syms + used_before_assignment = set() if used_before_assignment is None else used_before_assignment + + try: + ordered_blocks = self.topological_sort(self.start_block) + except ValueError: # Failsafe (e.g., for invalid or empty SDFGs) + ordered_blocks = self.nodes() + + for block in ordered_blocks: + state_symbols = set() + if isinstance(block, ControlFlowRegion): + b_free_syms, b_defined_syms, b_used_before_syms = block._used_symbols_internal(all_symbols) + free_syms |= b_free_syms + defined_syms |= b_defined_syms + used_before_assignment |= b_used_before_syms + state_symbols = b_free_syms + else: + state_symbols = block.used_symbols(all_symbols) + free_syms |= state_symbols + + # Add free inter-state symbols + for e in self.out_edges(block): + # NOTE: First we get the true InterstateEdge free symbols, then we compute the newly defined symbols by + # subracting the (true) free symbols from the edge's assignment keys. This way we can correctly + # compute the symbols that are used before being assigned. + efsyms = e.data.used_symbols(all_symbols) + defined_syms |= set(e.data.assignments.keys()) - (efsyms | state_symbols) + used_before_assignment.update(efsyms - defined_syms) + free_syms |= efsyms + + # Remove symbols that were used before they were assigned. + defined_syms -= used_before_assignment + + if isinstance(self, dace.SDFG): + # Remove from defined symbols those that are in the symbol mapping + if self.parent_nsdfg_node is not None and keep_defined_in_mapping: + defined_syms -= set(self.parent_nsdfg_node.symbol_mapping.keys()) + + # Add the set of SDFG symbol parameters + # If all_symbols is False, those symbols would only be added in the case of non-Python tasklets + if all_symbols: + free_syms |= set(self.symbols.keys()) + + # Subtract symbols defined in inter-state edges and constants from the list of free symbols. + free_syms -= defined_syms + + return free_syms, defined_syms, used_before_assignment + + def to_json(self, parent=None): + graph_json = OrderedDiGraph.to_json(self) + block_json = ControlFlowBlock.to_json(self, parent) + graph_json.update(block_json) + return graph_json + + ################################################################### + # Traversal methods + + def all_control_flow_regions(self, recursive=False) -> Iterator['ControlFlowRegion']: + """ Iterate over this and all nested control flow regions. """ + yield self + for block in self.nodes(): + if isinstance(block, SDFGState) and recursive: + for node in block.nodes(): + if isinstance(node, nd.NestedSDFG): + yield from node.sdfg.all_control_flow_regions(recursive=recursive) + elif isinstance(block, ControlFlowRegion): + yield from block.all_control_flow_regions(recursive=recursive) + + def all_sdfgs_recursive(self) -> Iterator['dace.SDFG']: + """ Iterate over this and all nested SDFGs. """ + for cfg in self.all_control_flow_regions(recursive=True): + if isinstance(cfg, dace.SDFG): + yield cfg + + def all_states(self) -> Iterator[SDFGState]: + """ Iterate over all states in this control flow graph. """ + for block in self.nodes(): + if isinstance(block, SDFGState): + yield block + elif isinstance(block, ControlFlowRegion): + yield from block.all_states() + + def all_control_flow_blocks(self, recursive=False) -> Iterator[ControlFlowBlock]: + """ Iterate over all control flow blocks in this control flow graph. """ + for cfg in self.all_control_flow_regions(recursive=recursive): + for block in cfg.nodes(): + yield block + + def all_interstate_edges(self, recursive=False) -> Iterator[Edge['dace.sdfg.InterstateEdge']]: + """ Iterate over all interstate edges in this control flow graph. """ + for cfg in self.all_control_flow_regions(recursive=recursive): + for edge in cfg.edges(): + yield edge + + ################################################################### + # Getters & setters, overrides + + def __str__(self): + return ControlFlowBlock.__str__(self) + + def __repr__(self) -> str: + return f'{self.__class__.__name__} ({self.label})' + + @property + def start_block(self): + """ Returns the starting block of this ControlFlowGraph. """ + if self._cached_start_block is not None: + return self._cached_start_block + + source_nodes = self.source_nodes() + if len(source_nodes) == 1: + self._cached_start_block = source_nodes[0] + return source_nodes[0] + # If the starting block is ambiguous allow manual override. + if self._start_block is not None: + self._cached_start_block = self.node(self._start_block) + return self._cached_start_block + raise ValueError('Ambiguous or undefined starting block for ControlFlowGraph, ' + 'please use "is_start_block=True" when adding the ' + 'starting block with "add_state" or "add_node"') + + @start_block.setter + def start_block(self, block_id): + """ Manually sets the starting block of this ControlFlowGraph. + + :param block_id: The node ID (use `node_id(block)`) of the block to set. + """ + if block_id < 0 or block_id >= self.number_of_nodes(): + raise ValueError('Invalid state ID') + self._start_block = block_id + self._cached_start_block = self.node(block_id) diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 1078414161..621f8a9e16 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -668,7 +668,7 @@ def consolidate_edges(sdfg: SDFG, starting_scope=None) -> int: from dace.sdfg.propagation import propagate_memlets_scope total_consolidated = 0 - for state in sdfg.nodes(): + for state in sdfg.states(): # Start bottom-up if starting_scope and starting_scope.entry not in state.nodes(): continue @@ -1206,8 +1206,8 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> counter = 0 if progress is True or progress is None: fusible_states = 0 - for sd in sdfg.all_sdfgs_recursive(): - fusible_states += sd.number_of_edges() + for cfg in sdfg.all_control_flow_regions(): + fusible_states += cfg.number_of_edges() if progress is True: pbar = tqdm(total=fusible_states, desc='Fusing states') @@ -1217,30 +1217,32 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> for sd in sdfg.all_sdfgs_recursive(): id = sd.sdfg_id - while True: - edges = list(sd.nx.edges) - applied = 0 - skip_nodes = set() - for u, v in edges: - if (progress is None and tqdm is not None and (time.time() - start) > 5): - progress = True - pbar = tqdm(total=fusible_states, desc='Fusing states', initial=counter) - - if u in skip_nodes or v in skip_nodes: - continue - candidate = {StateFusion.first_state: u, StateFusion.second_state: v} - sf = StateFusion() - sf.setup_match(sd, id, -1, candidate, 0, override=True) - if sf.can_be_applied(sd, 0, sd, permissive=permissive): - sf.apply(sd, sd) - applied += 1 - counter += 1 - if progress: - pbar.update(1) - skip_nodes.add(u) - skip_nodes.add(v) - if applied == 0: - break + for cfg in sd.all_control_flow_regions(): + while True: + edges = list(cfg.nx.edges) + applied = 0 + skip_nodes = set() + for u, v in edges: + if (progress is None and tqdm is not None and (time.time() - start) > 5): + progress = True + pbar = tqdm(total=fusible_states, desc='Fusing states', initial=counter) + + if (u in skip_nodes or v in skip_nodes or not isinstance(v, SDFGState) or + not isinstance(u, SDFGState)): + continue + candidate = {StateFusion.first_state: u, StateFusion.second_state: v} + sf = StateFusion() + sf.setup_match(cfg, id, -1, candidate, 0, override=True) + if sf.can_be_applied(cfg, 0, sd, permissive=permissive): + sf.apply(cfg, sd) + applied += 1 + counter += 1 + if progress: + pbar.update(1) + skip_nodes.add(u) + skip_nodes.add(v) + if applied == 0: + break if progress: pbar.close() return counter diff --git a/dace/subsets.py b/dace/subsets.py index f8b66a565d..068b330a07 100644 --- a/dace/subsets.py +++ b/dace/subsets.py @@ -10,21 +10,52 @@ from dace.config import Config +def nng(expr): + # When dealing with set sizes, assume symbols are non-negative + try: + # TODO: Fix in symbol definition, not here + for sym in list(expr.free_symbols): + expr = expr.subs({sym: sp.Symbol(sym.name, nonnegative=True)}) + return expr + except AttributeError: # No free_symbols in expr + return expr + +def bounding_box_cover_exact(subset_a, subset_b) -> bool: + return all([(symbolic.simplify_ext(nng(rb)) <= symbolic.simplify_ext(nng(orb))) == True + and (symbolic.simplify_ext(nng(re)) >= symbolic.simplify_ext(nng(ore))) == True + for rb, re, orb, ore in zip(subset_a.min_element(), subset_a.max_element(), + subset_b.min_element(), subset_b.max_element())]) + +def bounding_box_symbolic_positive(subset_a, subset_b, approximation = False)-> bool: + min_elements_a = subset_a.min_element_approx() if approximation else subset_a.min_element() + max_elements_a = subset_a.max_element_approx() if approximation else subset_a.max_element() + min_elements_b = subset_b.min_element_approx() if approximation else subset_b.min_element() + max_elements_b = subset_b.max_element_approx() if approximation else subset_b.max_element() + + for rb, re, orb, ore in zip(min_elements_a, max_elements_a, + min_elements_b, max_elements_b): + # NOTE: We first test for equality, which always returns True or False. If the equality test returns + # False, then we test for less-equal and greater-equal, which may return an expression, leading to + # TypeError. This is a workaround for the case where two expressions are the same or equal and + # SymPy confirms this but fails to return True when testing less-equal and greater-equal. + + # lower bound: first check whether symbolic positive condition applies + if not (len(rb.free_symbols) == 0 and len(orb.free_symbols) == 1): + if not (symbolic.simplify_ext(nng(rb)) == symbolic.simplify_ext(nng(orb)) or + symbolic.simplify_ext(nng(rb)) <= symbolic.simplify_ext(nng(orb))): + return False + # upper bound: first check whether symbolic positive condition applies + if not (len(re.free_symbols) == 1 and len(ore.free_symbols) == 0): + if not (symbolic.simplify_ext(nng(re)) == symbolic.simplify_ext(nng(ore)) or + symbolic.simplify_ext(nng(re)) >= symbolic.simplify_ext(nng(ore))): + return False + return True + class Subset(object): """ Defines a subset of a data descriptor. """ def covers(self, other): """ Returns True if this subset covers (using a bounding box) another subset. """ - def nng(expr): - # When dealing with set sizes, assume symbols are non-negative - try: - # TODO: Fix in symbol definition, not here - for sym in list(expr.free_symbols): - expr = expr.subs({sym: sp.Symbol(sym.name, nonnegative=True)}) - return expr - except AttributeError: # No free_symbols in expr - return expr - symbolic_positive = Config.get('optimizer', 'symbolic_positive') if not symbolic_positive: @@ -38,28 +69,65 @@ def nng(expr): else: try: - for rb, re, orb, ore in zip(self.min_element_approx(), self.max_element_approx(), - other.min_element_approx(), other.max_element_approx()): - # NOTE: We first test for equality, which always returns True or False. If the equality test returns - # False, then we test for less-equal and greater-equal, which may return an expression, leading to - # TypeError. This is a workaround for the case where two expressions are the same or equal and - # SymPy confirms this but fails to return True when testing less-equal and greater-equal. - - # lower bound: first check whether symbolic positive condition applies - if not (len(rb.free_symbols) == 0 and len(orb.free_symbols) == 1): - if not (symbolic.simplify_ext(nng(rb)) == symbolic.simplify_ext(nng(orb)) or - symbolic.simplify_ext(nng(rb)) <= symbolic.simplify_ext(nng(orb))): - return False - - # upper bound: first check whether symbolic positive condition applies - if not (len(re.free_symbols) == 1 and len(ore.free_symbols) == 0): - if not (symbolic.simplify_ext(nng(re)) == symbolic.simplify_ext(nng(ore)) or - symbolic.simplify_ext(nng(re)) >= symbolic.simplify_ext(nng(ore))): - return False + if not bounding_box_symbolic_positive(self, other, True): + return False except TypeError: return False return True + + def covers_precise(self, other): + """ Returns True if self contains all the elements in other. """ + + # If self does not cover other with a bounding box union, return false. + symbolic_positive = Config.get('optimizer', 'symbolic_positive') + try: + bounding_box_cover = bounding_box_cover_exact(self, other) if symbolic_positive else bounding_box_symbolic_positive(self, other) + if not bounding_box_cover: + return False + except TypeError: + return False + + try: + # if self is an index no further distinction is needed + if isinstance(self, Indices): + return True + + elif isinstance(self, Range): + # other is an index so we need to check if the step of self is such that other is covered + # self.start % self.step == other.index % self.step + if isinstance(other, Indices): + try: + return all( + [(symbolic.simplify_ext(nng(start)) % symbolic.simplify_ext(nng(step)) == + symbolic.simplify_ext(nng(i)) % symbolic.simplify_ext(nng(step))) == True + for (start, _, step), i in zip(self.ranges, other.indices)]) + except: + return False + if isinstance(other, Range): + # other is a range so in every dimension self.step has to divide other.step and + # self.start % self.step = other.start % other.step + try: + self_steps = [r[2] for r in self.ranges] + other_steps = [r[2] for r in other.ranges] + for start, step, ostart, ostep in zip(self.min_element(), self_steps, other.min_element(), + other_steps): + if not (ostep % step == 0 and + ((symbolic.simplify_ext(nng(start)) == symbolic.simplify_ext(nng(ostart))) or + (symbolic.simplify_ext(nng(start)) % symbolic.simplify_ext( + nng(step)) == symbolic.simplify_ext(nng(ostart)) % symbolic.simplify_ext( + nng(ostep))) == True)): + return False + except: + return False + return True + # unknown type + else: + raise TypeError + + except TypeError: + return False + def __repr__(self): return '%s (%s)' % (type(self).__name__, self.__str__()) @@ -973,6 +1041,111 @@ def intersection(self, other: 'Indices'): return self return None +class SubsetUnion(Subset): + """ + Wrapper subset type that stores multiple Subsets in a list. + """ + + def __init__(self, subset): + self.subset_list: list[Subset] = [] + if isinstance(subset, SubsetUnion): + self.subset_list = subset.subset_list + elif isinstance(subset, list): + for subset in subset: + if not subset: + break + if isinstance(subset, (Range, Indices)): + self.subset_list.append(subset) + else: + raise NotImplementedError + elif isinstance(subset, (Range, Indices)): + self.subset_list = [subset] + + def covers(self, other): + """ + Returns True if this SubsetUnion covers another subset (using a bounding box). + If other is another SubsetUnion then self and other will + only return true if self is other. If other is a different type of subset + true is returned when one of the subsets in self is equal to other. + """ + + if isinstance(other, SubsetUnion): + for subset in self.subset_list: + # check if ther is a subset in self that covers every subset in other + if all(subset.covers(s) for s in other.subset_list): + return True + # return False if that's not the case for any of the subsets in self + return False + else: + return any(s.covers(other) for s in self.subset_list) + + def covers_precise(self, other): + """ + Returns True if this SubsetUnion covers another + subset. If other is another SubsetUnion then self and other will + only return true if self is other. If other is a different type of subset + true is returned when one of the subsets in self is equal to other + """ + + if isinstance(other, SubsetUnion): + for subset in self.subset_list: + # check if ther is a subset in self that covers every subset in other + if all(subset.covers_precise(s) for s in other.subset_list): + return True + # return False if that's not the case for any of the subsets in self + return False + else: + return any(s.covers_precise(other) for s in self.subset_list) + + def __str__(self): + string = '' + for subset in self.subset_list: + if not string == '': + string += " " + string += subset.__str__() + return string + + def dims(self): + if not self.subset_list: + return 0 + return next(iter(self.subset_list)).dims() + + def union(self, other: Subset): + """In place union of self with another Subset""" + try: + if isinstance(other, SubsetUnion): + self.subset_list += other.subset_list + elif isinstance(other, Indices) or isinstance(other, Range): + self.subset_list.append(other) + else: + raise TypeError + except TypeError: # cannot determine truth value of Relational + return None + + @property + def free_symbols(self) -> Set[str]: + result = set() + for subset in self.subset_list: + result |= subset.free_symbols + return result + + def replace(self, repl_dict): + for subset in self.subset_list: + subset.replace(repl_dict) + + def num_elements(self): + # TODO: write something more meaningful here + min = 0 + for subset in self.subset_list: + try: + if subset.num_elements() < min or min ==0: + min = subset.num_elements() + except: + continue + + return min + + def _union_special_cases(arb: symbolic.SymbolicType, brb: symbolic.SymbolicType, are: symbolic.SymbolicType, bre: symbolic.SymbolicType): @@ -1038,6 +1211,8 @@ def bounding_box_union(subset_a: Subset, subset_b: Subset) -> Range: return Range(result) + + def union(subset_a: Subset, subset_b: Subset) -> Subset: """ Compute the union of two Subset objects. If the subsets are not of the same type, degenerates to bounding-box @@ -1056,6 +1231,9 @@ def union(subset_a: Subset, subset_b: Subset) -> Subset: return subset_b elif subset_a is None and subset_b is None: raise TypeError('Both subsets cannot be None') + elif isinstance(subset_a, SubsetUnion) or isinstance( + subset_b, SubsetUnion): + return list_union(subset_a, subset_b) elif type(subset_a) != type(subset_b): return bounding_box_union(subset_a, subset_b) elif isinstance(subset_a, Indices): @@ -1066,13 +1244,43 @@ def union(subset_a: Subset, subset_b: Subset) -> Subset: # TODO(later): More involved Strided-Tiled Range union return bounding_box_union(subset_a, subset_b) else: - warnings.warn('Unrecognized Subset type %s in union, degenerating to' - ' bounding box' % type(subset_a).__name__) + warnings.warn( + 'Unrecognized Subset type %s in union, degenerating to' + ' bounding box' % type(subset_a).__name__) return bounding_box_union(subset_a, subset_b) except TypeError: # cannot determine truth value of Relational return None +def list_union(subset_a: Subset, subset_b: Subset) -> Subset: + """ + Returns the union of two Subset lists. + + :param subset_a: The first subset. + :param subset_b: The second subset. + :return: A SubsetUnion object that contains all elements of subset_a and subset_b. + """ + # TODO(later): Merge subsets in both lists if possible + try: + if subset_a is not None and subset_b is None: + return subset_a + elif subset_b is not None and subset_a is None: + return subset_b + elif subset_a is None and subset_b is None: + raise TypeError('Both subsets cannot be None') + elif type(subset_a) != type(subset_b): + if isinstance(subset_b, SubsetUnion): + return SubsetUnion(subset_b.subset_list.append(subset_a)) + else: + return SubsetUnion(subset_a.subset_list.append(subset_b)) + elif isinstance(subset_a, SubsetUnion): + return SubsetUnion(subset_a.subset_list + subset_b.subset_list) + else: + return SubsetUnion([subset_a, subset_b]) + + except TypeError: + return None + def intersects(subset_a: Subset, subset_b: Subset) -> Union[bool, None]: """ Returns True if two subsets intersect, False if they do not, or diff --git a/dace/transformation/auto/auto_optimize.py b/dace/transformation/auto/auto_optimize.py index 54dbc8d4ac..bb384cfd9a 100644 --- a/dace/transformation/auto/auto_optimize.py +++ b/dace/transformation/auto/auto_optimize.py @@ -515,11 +515,29 @@ def make_transients_persistent(sdfg: SDFG, return result +def apply_gpu_storage(sdfg: SDFG) -> None: + """ Changes the storage of the SDFG's input and output data to GPU global memory. """ + + written_scalars = set() + for state in sdfg.nodes(): + for node in state.data_nodes(): + desc = node.desc(sdfg) + if isinstance(desc, dt.Scalar) and not desc.transient and state.in_degree(node) > 0: + written_scalars.add(node.data) + + for name, desc in sdfg.arrays.items(): + if not desc.transient and desc.storage == dtypes.StorageType.Default: + if isinstance(desc, dt.Scalar) and not name in written_scalars: + continue + desc.storage = dtypes.StorageType.GPU_Global + + def auto_optimize(sdfg: SDFG, device: dtypes.DeviceType, validate: bool = True, validate_all: bool = False, - symbols: Dict[str, int] = None) -> SDFG: + symbols: Dict[str, int] = None, + use_gpu_storage: bool = False) -> SDFG: """ Runs a basic sequence of transformations to optimize a given SDFG to decent performance. In particular, performs the following: @@ -539,6 +557,7 @@ def auto_optimize(sdfg: SDFG, have been applied. :param validate_all: If True, validates the SDFG after every step. :param symbols: Optional dict that maps symbols (str/symbolic) to int/float + :param use_gpu_storage: If True, changes the storage of non-transient data to GPU global memory. :return: The optimized SDFG. :note: Operates in-place on the given SDFG. :note: This function is still experimental and may harm correctness in @@ -565,6 +584,8 @@ def auto_optimize(sdfg: SDFG, # Apply GPU transformations and set library node implementations if device == dtypes.DeviceType.GPU: + if use_gpu_storage: + apply_gpu_storage(sdfg) sdfg.apply_gpu_transformations() sdfg.simplify() @@ -625,7 +646,6 @@ def auto_optimize(sdfg: SDFG, if symbols: # Specialize for all known symbols - known_symbols = {s: v for (s, v) in symbols.items() if s in sdfg.free_symbols} known_symbols = {} for (s, v) in symbols.items(): if s in sdfg.free_symbols: diff --git a/dace/transformation/change_strides.py b/dace/transformation/change_strides.py new file mode 100644 index 0000000000..001cd4aa63 --- /dev/null +++ b/dace/transformation/change_strides.py @@ -0,0 +1,210 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" This module provides a function to change the stride in a given SDFG """ +from typing import List, Union, Tuple +import sympy + +import dace +from dace.dtypes import ScheduleType +from dace.sdfg import SDFG, nodes, SDFGState +from dace.data import Array, Scalar +from dace.memlet import Memlet + + +def list_access_nodes( + sdfg: dace.SDFG, + array_name: str) -> List[Tuple[nodes.AccessNode, Union[SDFGState, dace.SDFG]]]: + """ + Find all access nodes in the SDFG of the given array name. Does not recourse into nested SDFGs. + + :param sdfg: The SDFG to search through + :type sdfg: dace.SDFG + :param array_name: The name of the wanted array + :type array_name: str + :return: List of the found access nodes together with their state + :rtype: List[Tuple[nodes.AccessNode, Union[dace.SDFGState, dace.SDFG]]] + """ + found_nodes = [] + for state in sdfg.states(): + for node in state.nodes(): + if isinstance(node, nodes.AccessNode) and node.data == array_name: + found_nodes.append((node, state)) + return found_nodes + + +def change_strides( + sdfg: dace.SDFG, + stride_one_values: List[str], + schedule: ScheduleType) -> SDFG: + """ + Change the strides of the arrays on the given SDFG such that the given dimension has stride 1. Returns a new SDFG. + + :param sdfg: The input SDFG + :type sdfg: dace.SDFG + :param stride_one_values: Length of the dimension whose stride should be set to one. Expects that each array has + only one dimension whose length is in this list. Expects that list contains name of symbols + :type stride_one_values: List[str] + :param schedule: Schedule to use to copy the arrays + :type schedule: ScheduleType + :return: SDFG with changed strides + :rtype: SDFG + """ + # Create new SDFG and copy constants and symbols + original_name = sdfg.name + sdfg.name = "changed_strides" + new_sdfg = SDFG(original_name) + for dname, value in sdfg.constants.items(): + new_sdfg.add_constant(dname, value) + for dname, stype in sdfg.symbols.items(): + new_sdfg.add_symbol(dname, stype) + + changed_stride_state = new_sdfg.add_state("with_changed_strides", is_start_state=True) + inputs, outputs = sdfg.read_and_write_sets() + # Get all arrays which are persistent == not transient + persistent_arrays = {name: desc for name, desc in sdfg.arrays.items() if not desc.transient} + + # Get the persistent arrays of all the transient arrays which get copied to GPU + for dname in persistent_arrays: + for access, state in list_access_nodes(sdfg, dname): + if len(state.out_edges(access)) == 1: + edge = state.out_edges(access)[0] + if isinstance(edge.dst, nodes.AccessNode): + if edge.dst.data in inputs: + inputs.remove(edge.dst.data) + inputs.add(dname) + if len(state.in_edges(access)) == 1: + edge = state.in_edges(access)[0] + if isinstance(edge.src, nodes.AccessNode): + if edge.src.data in inputs: + outputs.remove(edge.src.data) + outputs.add(dname) + + # Only keep inputs and outputs which are persistent + inputs.intersection_update(persistent_arrays.keys()) + outputs.intersection_update(persistent_arrays.keys()) + nsdfg = changed_stride_state.add_nested_sdfg(sdfg, new_sdfg, inputs=inputs, outputs=outputs) + transform_state = new_sdfg.add_state_before(changed_stride_state, label="transform_data", is_start_state=True) + transform_state_back = new_sdfg.add_state_after(changed_stride_state, "transform_data_back", is_start_state=False) + + # copy arrays + for dname, desc in sdfg.arrays.items(): + if not desc.transient: + if isinstance(desc, Array): + new_sdfg.add_array(dname, desc.shape, desc.dtype, desc.storage, + desc.location, desc.transient, desc.strides, + desc.offset) + elif isinstance(desc, Scalar): + new_sdfg.add_scalar(dname, desc.dtype, desc.storage, desc.transient, desc.lifetime, desc.debuginfo) + + new_order = {} + new_strides_map = {} + + # Map of array names in the nested sdfg: key: array name in parent sdfg (this sdfg), value: name in the nsdfg + # Assumes that name changes only appear in the first level of nsdfg nesting + array_names_map = {} + for graph in sdfg.sdfg_list: + if graph.parent_nsdfg_node is not None: + if graph.parent_sdfg == sdfg: + for connector in graph.parent_nsdfg_node.in_connectors: + for in_edge in graph.parent.in_edges_by_connector(graph.parent_nsdfg_node, connector): + array_names_map[str(connector)] = in_edge.data.data + + for containing_sdfg, dname, desc in sdfg.arrays_recursive(): + shape_str = [str(s) for s in desc.shape] + # Get index of the dimension we want to have stride 1 + stride_one_idx = None + this_stride_one_value = None + for dim in stride_one_values: + if str(dim) in shape_str: + stride_one_idx = shape_str.index(str(dim)) + this_stride_one_value = dim + break + + if stride_one_idx is not None: + new_order[dname] = [stride_one_idx] + + new_strides = list(desc.strides) + new_strides[stride_one_idx] = sympy.S.One + + previous_size = dace.symbolic.symbol(this_stride_one_value) + previous_stride = sympy.S.One + for i in range(len(new_strides)): + if i != stride_one_idx: + new_order[dname].append(i) + new_strides[i] = previous_size * previous_stride + previous_size = desc.shape[i] + previous_stride = new_strides[i] + + new_strides_map[dname] = {} + # Create a map entry for this data linking old strides to new strides. This assumes that each entry in + # strides is unique which is given as otherwise there would be two dimension i, j where a[i, j] would point + # to the same address as a[j, i] + for new_stride, old_stride in zip(new_strides, desc.strides): + new_strides_map[dname][old_stride] = new_stride + desc.strides = tuple(new_strides) + else: + parent_name = array_names_map[dname] if dname in array_names_map else dname + if parent_name in new_strides_map: + new_strides = [] + for stride in desc.strides: + new_strides.append(new_strides_map[parent_name][stride]) + desc.strides = new_strides + + # Add new flipped arrays for every non-transient array + flipped_names_map = {} + for dname, desc in sdfg.arrays.items(): + if not desc.transient: + flipped_name = f"{dname}_flipped" + flipped_names_map[dname] = flipped_name + new_sdfg.add_array(flipped_name, desc.shape, desc.dtype, + desc.storage, desc.location, True, + desc.strides, desc.offset) + + # Deal with the inputs: Create tasklet to flip them and connect via memlets + # for input in inputs: + for input in set([*inputs, *outputs]): + if input in new_order: + flipped_data = flipped_names_map[input] + if input in inputs: + changed_stride_state.add_memlet_path(changed_stride_state.add_access(flipped_data), nsdfg, + dst_conn=input, memlet=Memlet(data=flipped_data)) + # Simply need to copy the data, the different strides take care of the transposing + arr = sdfg.arrays[input] + tasklet, map_entry, map_exit = transform_state.add_mapped_tasklet( + name=f"transpose_{input}", + map_ranges={f"_i{i}": f"0:{s}" for i, s in enumerate(arr.shape)}, + inputs={'_in': Memlet(data=input, subset=", ".join(f"_i{i}" for i, _ in enumerate(arr.shape)))}, + code='_out = _in', + outputs={'_out': Memlet(data=flipped_data, + subset=", ".join(f"_i{i}" for i, _ in enumerate(arr.shape)))}, + external_edges=True, + schedule=schedule, + ) + # Do the same for the outputs + for output in outputs: + if output in new_order: + flipped_data = flipped_names_map[output] + changed_stride_state.add_memlet_path(nsdfg, changed_stride_state.add_access(flipped_data), + src_conn=output, memlet=Memlet(data=flipped_data)) + # Simply need to copy the data, the different strides take care of the transposing + arr = sdfg.arrays[output] + tasklet, map_entry, map_exit = transform_state_back.add_mapped_tasklet( + name=f"transpose_{output}", + map_ranges={f"_i{i}": f"0:{s}" for i, s in enumerate(arr.shape)}, + inputs={'_in': Memlet(data=flipped_data, + subset=", ".join(f"_i{i}" for i, _ in enumerate(arr.shape)))}, + code='_out = _in', + outputs={'_out': Memlet(data=output, subset=", ".join(f"_i{i}" for i, _ in enumerate(arr.shape)))}, + external_edges=True, + schedule=schedule, + ) + # Deal with any arrays which have not been flipped (should only be scalars). Connect them directly + for dname, desc in sdfg.arrays.items(): + if not desc.transient and dname not in new_order: + if dname in inputs: + changed_stride_state.add_memlet_path(changed_stride_state.add_access(dname), nsdfg, dst_conn=dname, + memlet=Memlet(data=dname)) + if dname in outputs: + changed_stride_state.add_memlet_path(nsdfg, changed_stride_state.add_access(dname), src_conn=dname, + memlet=Memlet(data=dname)) + + return new_sdfg diff --git a/dace/transformation/dataflow/double_buffering.py b/dace/transformation/dataflow/double_buffering.py index 8ff70a6355..6efe6543ca 100644 --- a/dace/transformation/dataflow/double_buffering.py +++ b/dace/transformation/dataflow/double_buffering.py @@ -128,7 +128,7 @@ def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG): ############################## # Add initial reads to initial nested state initial_state: sd.SDFGState = nsdfg_node.sdfg.start_state - initial_state.set_label('%s_init' % map_entry.map.label) + initial_state.label = '%s_init' % map_entry.map.label for edge in edges_to_replace: initial_state.add_node(edge.src) rnode = edge.src @@ -152,7 +152,7 @@ def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG): # Add the main state's contents to the last state, modifying # memlets appropriately. final_state: sd.SDFGState = nsdfg_node.sdfg.sink_nodes()[0] - final_state.set_label('%s_final_computation' % map_entry.map.label) + final_state.label = '%s_final_computation' % map_entry.map.label dup_nstate = copy.deepcopy(nstate) final_state.add_nodes_from(dup_nstate.nodes()) for e in dup_nstate.edges(): @@ -183,7 +183,7 @@ def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG): nstate.add_edge(rnode, edge.src_conn, wnode, edge.dst_conn, new_memlet) - nstate.set_label('%s_double_buffered' % map_entry.map.label) + nstate.label = '%s_double_buffered' % map_entry.map.label # Divide by loop stride new_expr = symbolic.pystr_to_symbolic('((%s / %s) + 1) %% 2' % (map_param, map_rstride)) sd.replace(nstate, '__dace_db_param', new_expr) diff --git a/dace/transformation/dataflow/map_expansion.py b/dace/transformation/dataflow/map_expansion.py index 275b99c7e8..60f1f13f32 100644 --- a/dace/transformation/dataflow/map_expansion.py +++ b/dace/transformation/dataflow/map_expansion.py @@ -1,16 +1,18 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ Contains classes that implement the map-expansion transformation. """ from dace.sdfg.utils import consolidate_edges from typing import Dict, List import dace from dace import dtypes, subsets, symbolic +from dace.properties import EnumProperty, make_properties from dace.sdfg import nodes from dace.sdfg import utils as sdutil from dace.sdfg.graph import OrderedMultiDiConnectorGraph from dace.transformation import transformation as pm +@make_properties class MapExpansion(pm.SingleStateTransformation): """ Implements the map-expansion pattern. @@ -25,14 +27,16 @@ class MapExpansion(pm.SingleStateTransformation): map_entry = pm.PatternNode(nodes.MapEntry) + inner_schedule = EnumProperty(desc="Schedule for inner maps", + dtype=dtypes.ScheduleType, + default=dtypes.ScheduleType.Sequential, + allow_none=True) + @classmethod def expressions(cls): return [sdutil.node_path_graph(cls.map_entry)] - def can_be_applied(self, graph: dace.SDFGState, - expr_index: int, - sdfg: dace.SDFG, - permissive: bool = False): + def can_be_applied(self, graph: dace.SDFGState, expr_index: int, sdfg: dace.SDFG, permissive: bool = False): # A candidate subgraph matches the map-expansion pattern when it # includes an N-dimensional map, with N greater than one. return self.map_entry.map.get_param_num() > 1 @@ -44,10 +48,11 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG): current_map = map_entry.map # Create new maps + inner_schedule = self.inner_schedule or current_map.schedule new_maps = [ nodes.Map(current_map.label + '_' + str(param), [param], subsets.Range([param_range]), - schedule=dtypes.ScheduleType.Sequential) + schedule=inner_schedule) for param, param_range in zip(current_map.params[1:], current_map.range[1:]) ] current_map.params = [current_map.params[0]] diff --git a/dace/transformation/dataflow/otf_map_fusion.py b/dace/transformation/dataflow/otf_map_fusion.py index b2e5710942..f41e3b4e0b 100644 --- a/dace/transformation/dataflow/otf_map_fusion.py +++ b/dace/transformation/dataflow/otf_map_fusion.py @@ -289,14 +289,17 @@ def apply(self, graph: SDFGState, sdfg: SDFG): for edge in graph.edges_between(first_map_entry, node): memlet = copy.deepcopy(edge.data) - in_connector = edge.src_conn.replace("OUT", "IN") - if in_connector in connector_mapping: - out_connector = connector_mapping[in_connector].replace("IN", "OUT") + if edge.src_conn is not None: + in_connector = edge.src_conn.replace("OUT", "IN") + if in_connector in connector_mapping: + out_connector = connector_mapping[in_connector].replace("IN", "OUT") + else: + out_connector = edge.src_conn + + if out_connector not in self.second_map_entry.out_connectors: + self.second_map_entry.add_out_connector(out_connector) else: - out_connector = edge.src_conn - - if out_connector not in self.second_map_entry.out_connectors: - self.second_map_entry.add_out_connector(out_connector) + out_connector = None graph.add_edge(self.second_map_entry, out_connector, node, edge.dst_conn, memlet) graph.remove_edge(edge) diff --git a/dace/transformation/dataflow/trivial_map_elimination.py b/dace/transformation/dataflow/trivial_map_elimination.py index 327d5d8c9a..9387cfce23 100644 --- a/dace/transformation/dataflow/trivial_map_elimination.py +++ b/dace/transformation/dataflow/trivial_map_elimination.py @@ -5,6 +5,7 @@ from dace.sdfg import utils as sdutil from dace.transformation import transformation from dace.properties import make_properties +from dace.memlet import Memlet @make_properties @@ -48,12 +49,15 @@ def apply(self, graph, sdfg): if len(remaining_ranges) == 0: # Redirect map entry's out edges + write_only_map = True for edge in graph.out_edges(map_entry): path = graph.memlet_path(edge) index = path.index(edge) - # Add an edge directly from the previous source connector to the destination - graph.add_edge(path[index - 1].src, path[index - 1].src_conn, edge.dst, edge.dst_conn, edge.data) + if not edge.data.is_empty(): + # Add an edge directly from the previous source connector to the destination + graph.add_edge(path[index - 1].src, path[index - 1].src_conn, edge.dst, edge.dst_conn, edge.data) + write_only_map = False # Redirect map exit's in edges. for edge in graph.in_edges(map_exit): @@ -63,6 +67,11 @@ def apply(self, graph, sdfg): # Add an edge directly from the source to the next destination connector if len(path) > index + 1: graph.add_edge(edge.src, edge.src_conn, path[index + 1].dst, path[index + 1].dst_conn, edge.data) + if write_only_map: + outer_exit = path[index+1].dst + outer_entry = graph.entry_node(outer_exit) + if outer_entry is not None: + graph.add_edge(outer_entry, None, edge.src, None, Memlet()) # Remove map graph.remove_nodes_from([map_entry, map_exit]) diff --git a/dace/transformation/dataflow/wcr_conversion.py b/dace/transformation/dataflow/wcr_conversion.py index e95674adc1..7f4fbc654d 100644 --- a/dace/transformation/dataflow/wcr_conversion.py +++ b/dace/transformation/dataflow/wcr_conversion.py @@ -2,10 +2,14 @@ """ Transformations to convert subgraphs to write-conflict resolutions. """ import ast import re -from dace import registry, nodes, dtypes +import copy +from dace import registry, nodes, dtypes, Memlet from dace.transformation import transformation, helpers as xfh from dace.sdfg import graph as gr, utils as sdutil from dace import SDFG, SDFGState +from dace.sdfg.state import StateSubgraphView +from dace.transformation import helpers +from dace.sdfg.propagation import propagate_memlets_state class AugAssignToWCR(transformation.SingleStateTransformation): @@ -20,6 +24,7 @@ class AugAssignToWCR(transformation.SingleStateTransformation): map_exit = transformation.PatternNode(nodes.MapExit) _EXPRESSIONS = ['+', '-', '*', '^', '%'] #, '/'] + _FUNCTIONS = ['min', 'max'] _EXPR_MAP = {'-': ('+', '-({expr})'), '/': ('*', '((decltype({expr}))1)/({expr})')} _PYOP_MAP = {ast.Add: '+', ast.Sub: '-', ast.Mult: '*', ast.BitXor: '^', ast.Mod: '%', ast.Div: '/'} @@ -27,6 +32,7 @@ class AugAssignToWCR(transformation.SingleStateTransformation): def expressions(cls): return [ sdutil.node_path_graph(cls.input, cls.tasklet, cls.output), + sdutil.node_path_graph(cls.input, cls.map_entry, cls.tasklet, cls.map_exit, cls.output) ] def can_be_applied(self, graph, expr_index, sdfg, permissive=False): @@ -38,7 +44,6 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Free tasklet if expr_index == 0: - # Only free tasklets supported for now if graph.entry_node(tasklet) is not None: return False @@ -49,8 +54,6 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Make sure augmented assignment can be fissioned as necessary if any(not isinstance(e.src, nodes.AccessNode) for e in graph.in_edges(tasklet)): return False - if graph.in_degree(inarr) > 0 and graph.out_degree(outarr) > 0: - return False outedge = graph.edges_between(tasklet, outarr)[0] else: # Free map @@ -65,12 +68,10 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): if len(graph.edges_between(tasklet, mx)) > 1: return False - # Currently no fission is supported + # Make sure augmented assignment can be fissioned as necessary if any(e.src is not me and not isinstance(e.src, nodes.AccessNode) for e in graph.in_edges(me) + graph.in_edges(tasklet)): return False - if graph.in_degree(inarr) > 0: - return False outedge = graph.edges_between(tasklet, mx)[0] @@ -78,6 +79,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): outconn = outedge.src_conn ops = '[%s]' % ''.join(re.escape(o) for o in AugAssignToWCR._EXPRESSIONS) + funcs = '|'.join(re.escape(o) for o in AugAssignToWCR._FUNCTIONS) if tasklet.language is dtypes.Language.Python: # Match a single assignment with a binary operation as RHS @@ -108,18 +110,33 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Try to match a single C assignment that can be converted to WCR inconn = edge.dst_conn lhs = r'^\s*%s\s*=\s*%s\s*%s.*;$' % (re.escape(outconn), re.escape(inconn), ops) - rhs = r'^\s*%s\s*=\s*.*%s\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn)) - if re.match(lhs, cstr) is None: - continue + # rhs: a = (...) op b + rhs = r'^\s*%s\s*=\s*\(.*\)\s*%s\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn)) + func_lhs = r'^\s*%s\s*=\s*(%s)\(\s*%s\s*,.*\)\s*;$' % (re.escape(outconn), funcs, re.escape(inconn)) + func_rhs = r'^\s*%s\s*=\s*(%s)\(.*,\s*%s\s*\)\s*;$' % (re.escape(outconn), funcs, re.escape(inconn)) + if re.match(lhs, cstr) is None and re.match(rhs, cstr) is None: + if re.match(func_lhs, cstr) is None and re.match(func_rhs, cstr) is None: + inconns = list(self.tasklet.in_connectors) + if len(inconns) != 2: + continue + + # Special case: a = op b + other_inconn = inconns[0] if inconns[0] != inconn else inconns[1] + rhs2 = r'^\s*%s\s*=\s*%s\s*%s\s*%s;$' % (re.escape(outconn), re.escape(other_inconn), ops, + re.escape(inconn)) + if re.match(rhs2, cstr) is None: + continue + # Same memlet if edge.data.subset != outedge.data.subset: continue # If in map, only match if the subset is independent of any # map indices (otherwise no conflict) - if (expr_index == 1 and len(outedge.data.subset.free_symbols - & set(me.map.params)) == len(me.map.params)): - continue + if expr_index == 1: + if not permissive and len(outedge.data.subset.free_symbols & set(me.map.params)) == len( + me.map.params): + continue return True else: @@ -132,50 +149,22 @@ def apply(self, state: SDFGState, sdfg: SDFG): input: nodes.AccessNode = self.input tasklet: nodes.Tasklet = self.tasklet output: nodes.AccessNode = self.output + if self.expr_index == 1: + me = self.map_entry + mx = self.map_exit # If state fission is necessary to keep semantics, do it first - if (self.expr_index == 0 and state.in_degree(input) > 0 and state.out_degree(output) == 0): - newstate = sdfg.add_state_after(state) - newstate.add_node(tasklet) - new_input, new_output = None, None - - # Keep old edges for after we remove tasklet from the original state - in_edges = list(state.in_edges(tasklet)) - out_edges = list(state.out_edges(tasklet)) - - for e in in_edges: - r = newstate.add_read(e.src.data) - newstate.add_edge(r, e.src_conn, e.dst, e.dst_conn, e.data) - if e.src is input: - new_input = r - for e in out_edges: - w = newstate.add_write(e.dst.data) - newstate.add_edge(e.src, e.src_conn, w, e.dst_conn, e.data) - if e.dst is output: - new_output = w - - # Remove tasklet and resulting isolated nodes - state.remove_node(tasklet) - for e in in_edges: - if state.degree(e.src) == 0: - state.remove_node(e.src) - for e in out_edges: - if state.degree(e.dst) == 0: - state.remove_node(e.dst) - - # Reset state and nodes for rest of transformation - input = new_input - output = new_output - state = newstate - # End of state fission + if state.in_degree(input) > 0: + subgraph_nodes = set([e.src for e in state.bfs_edges(input, reverse=True)]) + subgraph_nodes.add(input) + + subgraph = StateSubgraphView(state, subgraph_nodes) + helpers.state_fission(sdfg, subgraph) if self.expr_index == 0: inedges = state.edges_between(input, tasklet) outedge = state.edges_between(tasklet, output)[0] else: - me = self.map_entry - mx = self.map_exit - inedges = state.edges_between(me, tasklet) outedge = state.edges_between(tasklet, mx)[0] @@ -183,6 +172,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): outconn = outedge.src_conn ops = '[%s]' % ''.join(re.escape(o) for o in AugAssignToWCR._EXPRESSIONS) + funcs = '|'.join(re.escape(o) for o in AugAssignToWCR._FUNCTIONS) # Change tasklet code if tasklet.language is dtypes.Language.Python: @@ -206,13 +196,40 @@ def apply(self, state: SDFGState, sdfg: SDFG): inconn = edge.dst_conn match = re.match(r'^\s*%s\s*=\s*%s\s*(%s)(.*);$' % (re.escape(outconn), re.escape(inconn), ops), cstr) if match is None: - # match = re.match( - # r'^\s*%s\s*=\s*(.*)\s*(%s)\s*%s;$' % - # (re.escape(outconn), ops, re.escape(inconn)), cstr) - # if match is None: - continue - # op = match.group(2) - # expr = match.group(1) + match = re.match( + r'^\s*%s\s*=\s*\((.*)\)\s*(%s)\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn)), cstr) + if match is None: + func_rhs = r'^\s*%s\s*=\s*(%s)\((.*),\s*%s\s*\)\s*;$' % (re.escape(outconn), funcs, + re.escape(inconn)) + match = re.match(func_rhs, cstr) + if match is None: + func_lhs = r'^\s*%s\s*=\s*(%s)\(\s*%s\s*,(.*)\)\s*;$' % (re.escape(outconn), funcs, + re.escape(inconn)) + match = re.match(func_lhs, cstr) + if match is None: + inconns = list(self.tasklet.in_connectors) + if len(inconns) != 2: + continue + + # Special case: a = op b + other_inconn = inconns[0] if inconns[0] != inconn else inconns[1] + rhs2 = r'^\s*%s\s*=\s*(%s)\s*(%s)\s*%s;$' % ( + re.escape(outconn), re.escape(other_inconn), ops, re.escape(inconn)) + match = re.match(rhs2, cstr) + if match is None: + continue + else: + op = match.group(2) + expr = match.group(1) + else: + op = match.group(1) + expr = match.group(2) + else: + op = match.group(1) + expr = match.group(2) + else: + op = match.group(2) + expr = match.group(1) else: op = match.group(1) expr = match.group(2) @@ -232,16 +249,14 @@ def apply(self, state: SDFGState, sdfg: SDFG): raise NotImplementedError # Change output edge - outedge.data.wcr = f'lambda a,b: a {op} b' - - if self.expr_index == 0: - # Remove input node and connector - state.remove_edge_and_connectors(inedge) - if state.degree(input) == 0: - state.remove_node(input) + if op in AugAssignToWCR._FUNCTIONS: + outedge.data.wcr = f'lambda a,b: {op}(a, b)' else: - # Remove input edge and dst connector, but not necessarily src - state.remove_memlet_path(inedge) + outedge.data.wcr = f'lambda a,b: a {op} b' + + # Remove input node and connector + state.remove_memlet_path(inedge) + propagate_memlets_state(sdfg, state) # If outedge leads to non-transient, and this is a nested SDFG, # propagate outwards @@ -252,6 +267,9 @@ def apply(self, state: SDFGState, sdfg: SDFG): sd = sd.parent_sdfg outedge = next(iter(nstate.out_edges_by_connector(nsdfg, outedge.data.data))) for outedge in nstate.memlet_path(outedge): - outedge.data.wcr = f'lambda a,b: a {op} b' + if op in AugAssignToWCR._FUNCTIONS: + outedge.data.wcr = f'lambda a,b: {op}(a, b)' + else: + outedge.data.wcr = f'lambda a,b: a {op} b' # At this point we are leading to an access node again and can # traverse further up diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index 8986c4e37f..9c41e4dec4 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -1137,7 +1137,8 @@ def traverse(state: SDFGState, treenode: ScopeTree): ntree.state = nstate treenode.children.append(ntree) for child in treenode.children: - traverse(getattr(child, 'state', state), child) + if hasattr(child, 'state') and child.state != state: + traverse(getattr(child, 'state', state), child) traverse(state, stree) return stree diff --git a/dace/transformation/interstate/__init__.py b/dace/transformation/interstate/__init__.py index 0bd168751c..b8bcc716e6 100644 --- a/dace/transformation/interstate/__init__.py +++ b/dace/transformation/interstate/__init__.py @@ -15,3 +15,4 @@ from .move_loop_into_map import MoveLoopIntoMap from .trivial_loop_elimination import TrivialLoopElimination from .multistate_inline import InlineMultistateSDFG +from .move_assignment_outside_if import MoveAssignmentOutsideIf diff --git a/dace/transformation/interstate/loop_unroll.py b/dace/transformation/interstate/loop_unroll.py index 47d438a2fc..b1dbfdd5c9 100644 --- a/dace/transformation/interstate/loop_unroll.py +++ b/dace/transformation/interstate/loop_unroll.py @@ -116,8 +116,7 @@ def instantiate_loop( # Replace iterate with value in each state for state in new_states: - state.set_label(state.label + '_' + itervar + '_' + - (state_suffix if state_suffix is not None else str(value))) + state.label = state.label + '_' + itervar + '_' + (state_suffix if state_suffix is not None else str(value)) state.replace(itervar, value) # Add subgraph to original SDFG diff --git a/dace/transformation/interstate/move_assignment_outside_if.py b/dace/transformation/interstate/move_assignment_outside_if.py new file mode 100644 index 0000000000..3d4db9ae25 --- /dev/null +++ b/dace/transformation/interstate/move_assignment_outside_if.py @@ -0,0 +1,113 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" +Transformation to move assignments outside if statements to potentially avoid warp divergence. Speedup gained is +questionable. +""" + +import ast +import sympy as sp + +from dace import sdfg as sd +from dace.sdfg import graph as gr +from dace.sdfg.nodes import Tasklet, AccessNode +from dace.transformation import transformation + + +class MoveAssignmentOutsideIf(transformation.MultiStateTransformation): + + if_guard = transformation.PatternNode(sd.SDFGState) + if_stmt = transformation.PatternNode(sd.SDFGState) + else_stmt = transformation.PatternNode(sd.SDFGState) + + @classmethod + def expressions(cls): + sdfg = gr.OrderedDiGraph() + sdfg.add_nodes_from([cls.if_guard, cls.if_stmt, cls.else_stmt]) + sdfg.add_edge(cls.if_guard, cls.if_stmt, sd.InterstateEdge()) + sdfg.add_edge(cls.if_guard, cls.else_stmt, sd.InterstateEdge()) + return [sdfg] + + def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + # The if-guard can only have two outgoing edges: to the if and to the else part + guard_outedges = graph.out_edges(self.if_guard) + if len(guard_outedges) != 2: + return False + + # Outgoing edges must be a negation of each other + if guard_outedges[0].data.condition_sympy() != (sp.Not(guard_outedges[1].data.condition_sympy())): + return False + + # The if guard should either have zero or one incoming edge + if len(sdfg.in_edges(self.if_guard)) > 1: + return False + + # set of the variables which get a const value assigned + assigned_const = set() + # Dict which collects all AccessNodes for each variable together with its state + access_nodes = {} + # set of the variables which are only written to + self.write_only_values = set() + # Dictionary which stores additional information for the variables which are written only + self.assign_context = {} + for state in [self.if_stmt, self.else_stmt]: + for node in state.nodes(): + if isinstance(node, Tasklet): + # If node is a tasklet, check if assigns a constant value + assigns_const = True + for code_stmt in node.code.code: + if not (isinstance(code_stmt, ast.Assign) and isinstance(code_stmt.value, ast.Constant)): + assigns_const = False + if assigns_const: + for edge in state.out_edges(node): + if isinstance(edge.dst, AccessNode): + assigned_const.add(edge.dst.data) + self.assign_context[edge.dst.data] = {"state": state, "tasklet": node} + elif isinstance(node, AccessNode): + if node.data not in access_nodes: + access_nodes[node.data] = [] + access_nodes[node.data].append((node, state)) + + # check that the found access nodes only get written to + for data, nodes in access_nodes.items(): + write_only = True + for node, state in nodes: + if node.has_reads(state): + # The read is only a problem if it is not written before -> the access node has no incoming edge + if state.in_degree(node) == 0: + write_only = False + else: + # There is also a problem if any edge is an update instead of write + for edge in [*state.out_edges(node), *state.out_edges(node)]: + if edge.data.wcr is not None: + write_only = False + + if write_only: + self.write_only_values.add(data) + + # Want only the values which are only written to and one option uses a constant value + self.write_only_values = assigned_const.intersection(self.write_only_values) + + if len(self.write_only_values) == 0: + return False + return True + + def apply(self, _, sdfg: sd.SDFG): + # create a new state before the guard state where the zero assignment happens + new_assign_state = sdfg.add_state_before(self.if_guard, label="const_assignment_state") + + # Move all the Tasklets together with the AccessNode + for value in self.write_only_values: + state = self.assign_context[value]["state"] + tasklet = self.assign_context[value]["tasklet"] + new_assign_state.add_node(tasklet) + for edge in state.out_edges(tasklet): + state.remove_edge(edge) + state.remove_node(edge.dst) + new_assign_state.add_node(edge.dst) + new_assign_state.add_edge(tasklet, edge.src_conn, edge.dst, edge.dst_conn, edge.data) + + state.remove_node(tasklet) + # Remove the state if it was emptied + if state.is_empty(): + sdfg.remove_node(state) + return sdfg diff --git a/dace/transformation/interstate/multistate_inline.py b/dace/transformation/interstate/multistate_inline.py index 17d006921e..3712916a91 100644 --- a/dace/transformation/interstate/multistate_inline.py +++ b/dace/transformation/interstate/multistate_inline.py @@ -334,7 +334,7 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): if nstate.label in statenames: newname = data.find_new_name(nstate.label, statenames) statenames.add(newname) - nstate.set_label(newname) + nstate.label = newname ####################################################### # Add nested SDFG states into top-level SDFG diff --git a/doc/sdfg/images/elements.svg b/doc/sdfg/images/elements.svg index 80d35e39f0..6402de8e1d 100644 --- a/doc/sdfg/images/elements.svg +++ b/doc/sdfg/images/elements.svg @@ -1,90 +1,506 @@ - + - - - -Access Nodes - -T -ransient -Global - -Stream - -V -iew - -Reference - -T -asklet - - - - - - - - - -Nested SDFG - -Consume - - -Map - -... - - - -Library Node - - -... - -A[0] -CR: Sum -V -olume: 1 - -B[i, j] -V -olume: 1 -Memlet -W -rite-Conflict -Resolution - -State -State -T -ransition - + + + + +Access Nodes + +T +ransient +Global + +Stream + +V +iew + +Reference + +T +asklet + + + + + + + + + +Nested SDFG + +Consume + + +Map + +... + + + +Library Node + + +... + +A[0] +CR: Sum +V +olume: 1 + +B[i, j] +V +olume: 1 +Memlet +W +rite-Conflict +Resolution + +State +State +T +ransition +Control FlowRegion diff --git a/doc/sdfg/ir.rst b/doc/sdfg/ir.rst index 3c651fab19..f7bbb0ff79 100644 --- a/doc/sdfg/ir.rst +++ b/doc/sdfg/ir.rst @@ -29,7 +29,7 @@ Some of the main differences between SDFGs and other representations are: The Language ------------ -In a nutshell, an SDFG is a state machine of acyclic dataflow multigraphs. Here is an example graph: +In a nutshell, an SDFG is a hierarchical state machine of acyclic dataflow multigraphs. Here is an example graph: .. raw:: html @@ -43,7 +43,7 @@ In a nutshell, an SDFG is a state machine of acyclic dataflow multigraphs. Here The cyan rectangles are called **states** and together they form a state machine, executing the code from the starting state and following the blue edge that matches the conditions. In each state, an acyclic multigraph controls execution -through dataflow. There are four elements in the above state: +through dataflow. There are four elements in the above states: * **Access nodes** (ovals) that give access to data containers * **Memlets** (edges/dotted arrows) that represent units of data movement @@ -58,7 +58,14 @@ The state machine shown in the example is a for-loop (``for _ in range(5)``). Th the guard state controls the loop, and at the end the result is copied to the special ``__return`` data container, which designates the return value of the function. -There are other kinds of elements in an SDFG, as detailed below. +The state machine is analogous to a control flow graph, where states represent basic blocks. Multiple such basic blocks, +such as with the described loop, can be put together to form a **control flow region**. This allows them to be +represented with a single graph node in the SDFG's state machine, which is useful for optimization and analysis. +The SDFG itself can be thought of as one big control flow region. This means that control flow regions are directed +graphs, where nodes are states or other control flow regions, and edges are state transitions. + +In addition to the elements seen in the example above, there are other kinds of elements in an SDFG, which are detailed +below. .. _sdfg-lang: @@ -142,6 +149,12 @@ new value, and specifies how the update is performed. In the summation example, end of each state there is an implicit synchronization point, so it will not finish executing until all the last nodes have been reached (this assumption can be removed in extreme cases, see :class:`~dace.sdfg.state.SDFGState.nosync`). +**Control Flow Region**: Forms a directed graph of states and other control flow regions, where edges are state +transitions. This allows representing complex control flow in a single graph node, which is useful for analysis and +optimization. The SDFG itself is a control flow region, which means that control flow regions are recursive / +hierarchical. Similar to the SDFG, each control flow region has a unique starting state, which is the entry point to +the region and is executed first. + **State Transition**: Transitions, internally referred to as *inter-state edges*, specify how execution proceeds after the end of a State. Inter-state edges optionally contain a symbolic *condition* that is checked at the end of the preceding state. If any of the conditions are true, execution will continue to the destination of this edge (the @@ -783,5 +796,7 @@ file uses the :func:`~dace.sdfg.sdfg.SDFG.from_file` static method. For example, The ``compress`` argument can be used to save a smaller (``gzip`` compressed) file. It can keep the same extension, but it is customary to use ``.sdfg.gz`` or ``.sdfgz`` to let others know it is compressed. +It is recommended to use this option for large SDFGs, as it not only saves space, but also speeds up loading and +editing of the SDFG in visualization tools and the VSCode extension. diff --git a/requirements.txt b/requirements.txt index 996449dbef..266b3368c8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,13 +14,13 @@ Jinja2==3.1.2 MarkupSafe==2.1.3 mpmath==1.3.0 networkx==3.1 -numpy==1.24.3 +numpy==1.26.1 ply==3.11 -PyYAML==6.0 +PyYAML==6.0.1 requests==2.31.0 six==1.16.0 sympy==1.9 -urllib3==2.0.6 +urllib3==2.0.7 websockets==11.0.3 -Werkzeug==2.3.5 +Werkzeug==3.0.1 zipp==3.15.0 diff --git a/samples/fpga/rtl/add_fortytwo.py b/samples/fpga/rtl/add_fortytwo.py index 9c14ad098b..5abcd76a5b 100644 --- a/samples/fpga/rtl/add_fortytwo.py +++ b/samples/fpga/rtl/add_fortytwo.py @@ -1,8 +1,9 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -# -# This sample shows adding a constant integer value to a stream of integers. -# -# It is intended for running hardware_emulation or hardware xilinx targets. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" + This sample shows adding a constant integer value to a stream of integers. + + It is intended for running hardware_emulation or hardware xilinx targets. +""" import dace import numpy as np @@ -116,21 +117,21 @@ ###################################################################### if __name__ == '__main__': + with dace.config.set_temporary('compiler', 'xilinx', 'mode', value='hardware_emulation'): + # init data structures + N.set(8192) + a = np.random.randint(0, 100, N.get()).astype(np.int32) + b = np.zeros((N.get(), )).astype(np.int32) - # init data structures - N.set(8192) - a = np.random.randint(0, 100, N.get()).astype(np.int32) - b = np.zeros((N.get(), )).astype(np.int32) - - # show initial values - print("a={}, b={}".format(a, b)) + # show initial values + print("a={}, b={}".format(a, b)) - # call program - sdfg(A=a, B=b, N=N) + # call program + sdfg(A=a, B=b, N=N) - # show result - print("a={}, b={}".format(a, b)) + # show result + print("a={}, b={}".format(a, b)) - # check result - for i in range(N.get()): - assert b[i] == a[i] + 42 + # check result + for i in range(N.get()): + assert b[i] == a[i] + 42 diff --git a/samples/fpga/rtl/axpy.py b/samples/fpga/rtl/axpy.py index 8b720aaa1e..4f386c82a4 100644 --- a/samples/fpga/rtl/axpy.py +++ b/samples/fpga/rtl/axpy.py @@ -1,7 +1,10 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. -# -# This sample shows the AXPY BLAS routine. It is implemented through Xilinx IPs in order to utilize floating point -# operations. It is intended for running hardware_emulation or hardware xilinx targets. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" + This sample shows the AXPY BLAS routine. It is implemented through Xilinx IPs in order to utilize floating point + operations. + + It is intended for running hardware_emulation or hardware xilinx targets. +""" import dace import numpy as np @@ -259,4 +262,4 @@ def make_sdfg(veclen=2): expected = a * x + y diff = np.linalg.norm(expected - result) / N.get() print("Difference:", diff) - exit(0 if diff <= 1e-5 else 1) + assert diff <= 1e-5 diff --git a/samples/fpga/rtl/axpy_double_pump.py b/samples/fpga/rtl/axpy_double_pump.py index 2d44ab7689..c79948007b 100644 --- a/samples/fpga/rtl/axpy_double_pump.py +++ b/samples/fpga/rtl/axpy_double_pump.py @@ -1,73 +1,74 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. -# -# This sample shows the AXPY BLAS routine. It is implemented through Xilinx -# IPs in order to utilize double pumping, which doubles the performance per -# consumed FPGA resource. The double pumping operation is "inwards", which -# means that the internal vectorization width of the core computation is half -# that of the external vectorization width. This translates into utilizing half -# the amount of internal computing resources, compared to a regular vectorized -# implementetation. The block diagram of the design for a 32-bit floating-point -# implementation using vectorization width 2 is: -# -# ap_aclk s_axis_y_in s_axis_x_in a -# │ │ │ │ -# │ │ │ │ -# │ │ │ │ -# ┌───────┼─────────┬────────┼─────────┐ │ │ -# │ │ │ │ │ │ │ -# │ │ │ ▼ │ ▼ │ -# │ │ │ ┌────────────┐ │ ┌────────────┐ │ -# │ │ └─►│ │ └─►│ │ │ -# │ │ │ Clock sync │ │ Clock sync │ │ -# │ │ ┌─►│ │ ┌─►│ │ │ -# │ ▼ 300 MHz │ └─────┬──────┘ │ └─────┬──────┘ │ -# │ ┌────────────┐ │ │ │ │ │ -# │ │ Clock │ │ │ │ │ │ -# │ │ │ ├────────┼─────────┤ │ │ -# │ │ Multiplier │ │ │ │ │ │ -# │ └─────┬──────┘ │ ▼ 64 bit │ ▼ 64 bit │ -# │ │ 600 MHz │ ┌────────────┐ │ ┌────────────┐ │ -# │ │ │ │ │ │ │ │ │ -# │ └─────────┼─►│ Data issue │ └─►│ Data issue │ │ -# │ │ │ │ │ │ │ -# │ │ └─────┬──────┘ └─────┬──────┘ │ -# │ │ │ 32 bit │ 32 bit │ -# │ │ │ │ │ -# │ │ │ │ │ -# │ │ │ ▼ ▼ -# │ │ │ ┌────────────┐ -# │ │ │ │ │ -# │ ├────────┼────────────────►│ Multiplier │ -# │ │ │ │ │ -# │ │ │ └─────┬──────┘ -# │ │ │ │ -# │ │ │ ┌──────────────┘ -# │ │ │ │ -# │ │ ▼ ▼ -# │ │ ┌────────────┐ -# │ │ │ │ -# │ ├─────►│ Adder │ -# │ │ │ │ -# │ │ └─────┬──────┘ -# │ │ │ -# │ │ ▼ 32 bit -# │ │ ┌─────────────┐ -# │ │ │ │ -# │ ├─────►│ Data packer │ -# │ │ │ │ -# │ │ └─────┬───────┘ -# │ │ │ 64 bit -# │ │ ▼ -# │ │ ┌────────────┐ -# │ └─────►│ │ -# │ │ Clock sync │ -# └───────────────────────►│ │ -# └─────┬──────┘ -# │ -# ▼ -# m_axis_result_out -# -# It is intended for running hardware_emulation or hardware xilinx targets. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" + This sample shows the AXPY BLAS routine. It is implemented through Xilinx + IPs in order to utilize double pumping, which doubles the performance per + consumed FPGA resource. The double pumping operation is "inwards", which + means that the internal vectorization width of the core computation is half + that of the external vectorization width. This translates into utilizing half + the amount of internal computing resources, compared to a regular vectorized + implementetation. The block diagram of the design for a 32-bit floating-point + implementation using vectorization width 2 is: + + ap_aclk s_axis_y_in s_axis_x_in a + │ │ │ │ + │ │ │ │ + │ │ │ │ + ┌───────┼─────────┬────────┼─────────┐ │ │ + │ │ │ │ │ │ │ + │ │ │ ▼ │ ▼ │ + │ │ │ ┌────────────┐ │ ┌────────────┐ │ + │ │ └─►│ │ └─►│ │ │ + │ │ │ Clock sync │ │ Clock sync │ │ + │ │ ┌─►│ │ ┌─►│ │ │ + │ ▼ 300 MHz │ └─────┬──────┘ │ └─────┬──────┘ │ + │ ┌────────────┐ │ │ │ │ │ + │ │ Clock │ │ │ │ │ │ + │ │ │ ├────────┼─────────┤ │ │ + │ │ Multiplier │ │ │ │ │ │ + │ └─────┬──────┘ │ ▼ 64 bit │ ▼ 64 bit │ + │ │ 600 MHz │ ┌────────────┐ │ ┌────────────┐ │ + │ │ │ │ │ │ │ │ │ + │ └─────────┼─►│ Data issue │ └─►│ Data issue │ │ + │ │ │ │ │ │ │ + │ │ └─────┬──────┘ └─────┬──────┘ │ + │ │ │ 32 bit │ 32 bit │ + │ │ │ │ │ + │ │ │ │ │ + │ │ │ ▼ ▼ + │ │ │ ┌────────────┐ + │ │ │ │ │ + │ ├────────┼────────────────►│ Multiplier │ + │ │ │ │ │ + │ │ │ └─────┬──────┘ + │ │ │ │ + │ │ │ ┌──────────────┘ + │ │ │ │ + │ │ ▼ ▼ + │ │ ┌────────────┐ + │ │ │ │ + │ ├─────►│ Adder │ + │ │ │ │ + │ │ └─────┬──────┘ + │ │ │ + │ │ ▼ 32 bit + │ │ ┌─────────────┐ + │ │ │ │ + │ ├─────►│ Data packer │ + │ │ │ │ + │ │ └─────┬───────┘ + │ │ │ 64 bit + │ │ ▼ + │ │ ┌────────────┐ + │ └─────►│ │ + │ │ Clock sync │ + └───────────────────────►│ │ + └─────┬──────┘ + │ + ▼ + m_axis_result_out + + It is intended for running hardware_emulation or hardware xilinx targets. +""" import dace import numpy as np @@ -452,4 +453,4 @@ def make_sdfg(veclen=2): diff = np.linalg.norm(expected - result) / N.get() print("Difference:", diff) - exit(0 if diff <= 1e-5 else 1) + assert diff <= 1e-5 diff --git a/samples/fpga/rtl/fladd.py b/samples/fpga/rtl/fladd.py index f22d419cbc..daf1ed269b 100644 --- a/samples/fpga/rtl/fladd.py +++ b/samples/fpga/rtl/fladd.py @@ -1,10 +1,11 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -# -# This sample shows how to utilize an IP core in an RTL tasklet. This is done -# through the vector add problem, which adds two floating point vectors -# together. -# -# It is intended for running hardware_emulation or hardware xilinx targets. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" + This sample shows how to utilize an IP core in an RTL tasklet. This is done + through the vector add problem, which adds two floating point vectors + together. + + It is intended for running hardware_emulation or hardware xilinx targets. +""" import dace import numpy as np @@ -190,4 +191,4 @@ expected = a + b diff = np.linalg.norm(expected - c) / N.get() print("Difference:", diff) - exit(0 if diff <= 1e-5 else 1) + assert diff <= 1e-5 diff --git a/samples/fpga/rtl/pipeline.py b/samples/fpga/rtl/pipeline.py index b487da91ce..dbd0460fb0 100644 --- a/samples/fpga/rtl/pipeline.py +++ b/samples/fpga/rtl/pipeline.py @@ -1,9 +1,10 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -# -# This sample shows a DEPTH deep pipeline, where each stage adds 1 to the -# integer input stream. -# -# It is intended for running hardware_emulation or hardware xilinx targets. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" + This sample shows a DEPTH deep pipeline, where each stage adds 1 to the + integer input stream. + + It is intended for running hardware_emulation or hardware xilinx targets. +""" import dace import numpy as np @@ -151,21 +152,21 @@ ###################################################################### if __name__ == '__main__': + with dace.config.set_temporary('compiler', 'xilinx', 'mode', value='hardware_emulation'): + # init data structures + N.set(8192) + a = np.random.randint(0, 100, N.get()).astype(np.int32) + b = np.zeros((N.get(), )).astype(np.int32) - # init data structures - N.set(8192) - a = np.random.randint(0, 100, N.get()).astype(np.int32) - b = np.zeros((N.get(), )).astype(np.int32) - - # show initial values - print("a={}, b={}".format(a, b)) + # show initial values + print("a={}, b={}".format(a, b)) - # call program - sdfg(A=a, B=b, N=N) + # call program + sdfg(A=a, B=b, N=N) - # show result - print("a={}, b={}".format(a, b)) + # show result + print("a={}, b={}".format(a, b)) - # check result - for i in range(N.get()): - assert b[i] == a[i] + depth + # check result + for i in range(N.get()): + assert b[i] == a[i] + depth diff --git a/samples/fpga/rtl/rtl_multi_tasklet.py b/samples/fpga/rtl/rtl_multi_tasklet.py index a646eb6be9..4a4a09deec 100644 --- a/samples/fpga/rtl/rtl_multi_tasklet.py +++ b/samples/fpga/rtl/rtl_multi_tasklet.py @@ -1,11 +1,11 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ Two sequential RTL tasklets connected through a memlet. + + It is intended for running simulation xilinx targets. """ import dace -import argparse - import numpy as np # add sdfg @@ -32,7 +32,7 @@ m_axis_b_tdata <= 0; s_axis_a_tready <= 1'b1; state <= READY; - end else if (s_axis_a_tvalid && state == READY) begin // case: load a + end else if (s_axis_a_tvalid && state == READY) begin // case: load a m_axis_b_tdata <= s_axis_a_tdata; s_axis_a_tready <= 1'b0; state <= BUSY; @@ -41,7 +41,7 @@ else m_axis_b_tdata <= m_axis_b_tdata; state <= DONE; -end +end assign m_axis_b_tvalid = (m_axis_b_tdata >= 80) ? 1'b1:1'b0; """, @@ -59,7 +59,7 @@ m_axis_c_tdata <= 0; s_axis_b_tready <= 1'b1; state <= READY; - end else if (s_axis_b_tvalid && state == READY) begin // case: load a + end else if (s_axis_b_tvalid && state == READY) begin // case: load a m_axis_c_tdata <= s_axis_b_tdata; s_axis_b_tready <= 1'b0; state <= BUSY; @@ -68,9 +68,9 @@ else m_axis_c_tdata <= m_axis_c_tdata; state <= DONE; -end +end -assign m_axis_c_tvalid = (m_axis_c_tdata >= 100) ? 1'b1:1'b0; +assign m_axis_c_tvalid = (m_axis_c_tdata >= 100) ? 1'b1:1'b0; """, language=dace.Language.SystemVerilog) @@ -92,21 +92,21 @@ ###################################################################### if __name__ == '__main__': + with dace.config.set_temporary('compiler', 'xilinx', 'mode', value='simulation'): + # init data structures + a = np.random.randint(0, 80, 1).astype(np.int32) + b = np.array([0]).astype(np.int32) + c = np.array([0]).astype(np.int32) - # init data structures - a = np.random.randint(0, 80, 1).astype(np.int32) - b = np.array([0]).astype(np.int32) - c = np.array([0]).astype(np.int32) - - # show initial values - print("a={}, b={}, c={}".format(a, b, c)) + # show initial values + print("a={}, b={}, c={}".format(a, b, c)) - # call program - sdfg(A=a, B=b, C=c) + # call program + sdfg(A=a, B=b, C=c) - # show result - print("a={}, b={}, c={}".format(a, b, c)) + # show result + print("a={}, b={}, c={}".format(a, b, c)) - # check result - assert b == 80 - assert c == 100 + # check result + assert b == 80 + assert c == 100 diff --git a/samples/fpga/rtl/rtl_tasklet_parameter.py b/samples/fpga/rtl/rtl_tasklet_parameter.py index d20688b385..112e88a6bf 100644 --- a/samples/fpga/rtl/rtl_tasklet_parameter.py +++ b/samples/fpga/rtl/rtl_tasklet_parameter.py @@ -1,11 +1,11 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ Simple RTL tasklet with a single scalar input and a single scalar output. It increments b from a up to 100. + + It is intended for running simulation xilinx targets. """ import dace -import argparse - import numpy as np # add sdfg @@ -47,7 +47,7 @@ m_axis_b_tdata <= 0; s_axis_a_tready <= 1'b1; state <= READY; - end else if (s_axis_a_tvalid && state == READY) begin // case: load a + end else if (s_axis_a_tvalid && state == READY) begin // case: load a m_axis_b_tdata <= s_axis_a_tdata; s_axis_a_tready <= 1'b0; state <= BUSY; @@ -56,9 +56,9 @@ else m_axis_b_tdata <= m_axis_b_tdata; state <= DONE; - end + end - assign m_axis_b_tvalid = (m_axis_b_tdata >= MAX_VAL) ? 1'b1:1'b0; + assign m_axis_b_tvalid = (m_axis_b_tdata >= MAX_VAL) ? 1'b1:1'b0; ''', language=dace.Language.SystemVerilog) @@ -76,19 +76,19 @@ ###################################################################### if __name__ == '__main__': + with dace.config.set_temporary('compiler', 'xilinx', 'mode', value='simulation'): + # init data structures + a = np.random.randint(0, 100, 1).astype(np.int32) + b = np.array([0]).astype(np.int32) - # init data structures - a = np.random.randint(0, 100, 1).astype(np.int32) - b = np.array([0]).astype(np.int32) - - # show initial values - print("a={}, b={}".format(a, b)) + # show initial values + print("a={}, b={}".format(a, b)) - # call program - sdfg(A=a, B=b) + # call program + sdfg(A=a, B=b) - # show result - print("a={}, b={}".format(a, b)) + # show result + print("a={}, b={}".format(a, b)) - # check result - assert b == sdfg.constants["MAX_VAL"] + # check result + assert b == sdfg.constants["MAX_VAL"] diff --git a/samples/fpga/rtl/rtl_tasklet_pipeline.py b/samples/fpga/rtl/rtl_tasklet_pipeline.py index 9166806c63..3ef20cd03f 100644 --- a/samples/fpga/rtl/rtl_tasklet_pipeline.py +++ b/samples/fpga/rtl/rtl_tasklet_pipeline.py @@ -1,11 +1,11 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ Pipelined, AXI-handshake compliant example that increments b from a up to 100. + + It is intended for running simulation xilinx targets. """ import dace -import argparse - import numpy as np # add symbol @@ -59,7 +59,7 @@ state <= state_next; end - always_comb + always_comb begin state_next = state; case(state) @@ -132,21 +132,21 @@ ###################################################################### if __name__ == '__main__': + with dace.config.set_temporary('compiler', 'xilinx', 'mode', value='simulation'): + # init data structures + num_elements = dace.symbolic.evaluate(N, sdfg.constants) + a = np.random.randint(0, 100, num_elements).astype(np.int32) + b = np.array([0] * num_elements).astype(np.int32) - # init data structures - num_elements = dace.symbolic.evaluate(N, sdfg.constants) - a = np.random.randint(0, 100, num_elements).astype(np.int32) - b = np.array([0] * num_elements).astype(np.int32) - - # show initial values - print("a={}, b={}".format(a, b)) + # show initial values + print("a={}, b={}".format(a, b)) - # call program - sdfg(A=a, B=b) + # call program + sdfg(A=a, B=b) - # show result - print("a={}, b={}".format(a, b)) + # show result + print("a={}, b={}".format(a, b)) - assert b[ - 0] == 100 # TODO: implement detection of #elements to process, s.t. we can extend the assertion to the whole array - assert np.all(map((lambda x: x == 0), b[1:-1])) # should still be at the init value (for the moment) + assert b[ + 0] == 100 # TODO: implement detection of #elements to process, s.t. we can extend the assertion to the whole array + assert np.all(map((lambda x: x == 0), b[1:-1])) # should still be at the init value (for the moment) diff --git a/samples/fpga/rtl/rtl_tasklet_scalar.py b/samples/fpga/rtl/rtl_tasklet_scalar.py index c9f6380a2b..cf8d53ec91 100644 --- a/samples/fpga/rtl/rtl_tasklet_scalar.py +++ b/samples/fpga/rtl/rtl_tasklet_scalar.py @@ -1,11 +1,11 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ Simple RTL tasklet with a single scalar input and a single scalar output. It increments b from a up to 100. + + It is intended for running simulation xilinx targets. """ import dace -import argparse - import numpy as np # add sdfg @@ -79,19 +79,19 @@ ###################################################################### if __name__ == '__main__': + with dace.config.set_temporary('compiler', 'xilinx', 'mode', value='simulation'): + # init data structures + a = np.random.randint(0, 100, 1).astype(np.int32) + b = np.array([0]).astype(np.int32) - # init data structures - a = np.random.randint(0, 100, 1).astype(np.int32) - b = np.array([0]).astype(np.int32) - - # show initial values - print("a={}, b={}".format(a, b)) + # show initial values + print("a={}, b={}".format(a, b)) - # call program - sdfg(A=a, B=b) + # call program + sdfg(A=a, B=b) - # show result - print("a={}, b={}".format(a, b)) + # show result + print("a={}, b={}".format(a, b)) - # check result - assert b == 100 + # check result + assert b == 100 diff --git a/samples/fpga/rtl/rtl_tasklet_vector.py b/samples/fpga/rtl/rtl_tasklet_vector.py index c099a6a38d..9015b4f35e 100644 --- a/samples/fpga/rtl/rtl_tasklet_vector.py +++ b/samples/fpga/rtl/rtl_tasklet_vector.py @@ -1,11 +1,11 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ RTL tasklet with a vector input of 4 int32 (width=128bits) and a single scalar output. It increments b from a[31:0] up to 100. + + It is intended for running simulation xilinx targets. """ import dace -import argparse - import numpy as np # add symbol @@ -44,13 +44,13 @@ typedef enum [1:0] {READY, BUSY, DONE} state_e; state_e state; - + always@(posedge ap_aclk) begin if (ap_areset) begin // case: reset m_axis_b_tdata <= 0; s_axis_a_tready <= 1'b1; state <= READY; - end else if (s_axis_a_tvalid && state == READY) begin // case: load a + end else if (s_axis_a_tvalid && state == READY) begin // case: load a m_axis_b_tdata <= s_axis_a_tdata[0]; s_axis_a_tready <= 1'b0; state <= BUSY; @@ -60,9 +60,9 @@ m_axis_b_tdata <= m_axis_b_tdata; state <= DONE; end - end - - assign m_axis_b_tvalid = (m_axis_b_tdata >= s_axis_a_tdata[0] + s_axis_a_tdata[1] && (state == BUSY || state == DONE)) ? 1'b1:1'b0; + end + + assign m_axis_b_tvalid = (m_axis_b_tdata >= s_axis_a_tdata[0] + s_axis_a_tdata[1] && (state == BUSY || state == DONE)) ? 1'b1:1'b0; ''', language=dace.Language.SystemVerilog) @@ -80,19 +80,19 @@ ###################################################################### if __name__ == '__main__': + with dace.config.set_temporary('compiler', 'xilinx', 'mode', value='simulation'): + # init data structures + a = np.random.randint(0, 100, dace.symbolic.evaluate(WIDTH, sdfg.constants)).astype(np.int32) + b = np.array([0]).astype(np.int32) - # init data structures - a = np.random.randint(0, 100, dace.symbolic.evaluate(WIDTH, sdfg.constants)).astype(np.int32) - b = np.array([0]).astype(np.int32) - - # show initial values - print("a={}, b={}".format(a, b)) + # show initial values + print("a={}, b={}".format(a, b)) - # call program - sdfg(A=a, B=b) + # call program + sdfg(A=a, B=b) - # show result - print("a={}, b={}".format(a, b)) + # show result + print("a={}, b={}".format(a, b)) - # check result - assert b == a[0] + a[1] + # check result + assert b == a[0] + a[1] diff --git a/test_all.sh b/test_all.sh index c4240fa820..cc34b74b36 100755 --- a/test_all.sh +++ b/test_all.sh @@ -3,6 +3,12 @@ set -a +if [[ -z "${CXX}" ]]; then + CXX="g++" # I don't think that is a good default, but it was the hardcoded compiler before I made changes... +else + CXX="${CXX}" +fi + SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" PYTHONPATH=$SCRIPTPATH @@ -53,7 +59,7 @@ bail_skip() { test_start() { TESTS=`expr $TESTS + 1` CURTEST="$TESTPREFIX$1" - echo "---------- TEST: $TESTPREFIX$1 ----------" + echo "---------- TEST: $TESTPREFIX$1 ---------- [ This is test $TESTS of $TOTAL_TESTS ]" } testcmd() { @@ -64,14 +70,14 @@ testcmd() { #$* | tee -a test.log TESTCNT=`expr $TESTS - 1` MSG="($TESTCNT / $TOTAL_TESTS) $CURTEST (Fails: $ERRORS)" - ($* || echo "_TFAIL_ $?") |& awk "BEGIN{printf \"$MSG\r\"} /_TFAIL_/{printf \"$TGAP\r\"; exit \$NF} {printf \"$TGAP\r\"; print; printf \"$MSG\r\";} END{printf \"$TGAP\r\"}" + ($* || echo "_TFAIL_ $?") 2>&1 | awk "BEGIN{printf \"$MSG\r\"} /_TFAIL_/{printf \"$TGAP\r\"; exit \$NF} {printf \"$TGAP\r\"; print; printf \"$MSG\r\";} END{printf \"$TGAP\r\"}" } ################################################ runtest_cpp() { test_start $1 - testcmd g++ -std=c++14 -Wall -Wextra -O3 -march=native -ffast-math -fopenmp -fPIC \ + testcmd $CXX -std=c++14 -Wall -Wextra -O3 -march=native -ffast-math -fopenmp -fPIC \ -I $SCRIPTPATH/dace/runtime/include $1 -o ./$1.out if [ $? -ne 0 ]; then bail "$1 (compilation)"; fi testcmd ./$1.out diff --git a/tests/cuda_block_test.py b/tests/cuda_block_test.py index f77e80673f..676785e0e5 100644 --- a/tests/cuda_block_test.py +++ b/tests/cuda_block_test.py @@ -10,8 +10,10 @@ @dace.program(dace.float64[N], dace.float64[N]) def cudahello(V, Vout): + @dace.mapscope(_[0:N:32]) def multiplication(i): + @dace.map(_[0:32]) def mult_block(bi): in_V << V[i + bi] @@ -55,6 +57,7 @@ def test_gpu(): @pytest.mark.gpu def test_different_block_sizes_nesting(): + @dace.program def nested(V: dace.float64[34], v1: dace.float64[1]): with dace.tasklet: @@ -105,6 +108,7 @@ def diffblocks(V: dace.float64[130], v1: dace.float64[4], v2: dace.float64[128]) @pytest.mark.gpu def test_custom_block_size_onemap(): + @dace.program def tester(A: dace.float64[400, 300]): for i, j in dace.map[0:400, 0:300]: @@ -132,6 +136,7 @@ def tester(A: dace.float64[400, 300]): @pytest.mark.gpu def test_custom_block_size_twomaps(): + @dace.program def tester(A: dace.float64[400, 300, 2, 32]): for i, j in dace.map[0:400, 0:300]: @@ -154,9 +159,42 @@ def tester(A: dace.float64[400, 300, 2, 32]): sdfg.compile() +@pytest.mark.gpu +def test_block_thread_specialization(): + + @dace.program + def tester(A: dace.float64[200]): + for i in dace.map[0:200:32]: + for bi in dace.map[0:32]: + with dace.tasklet: + a >> A[i + bi] + a = 1 + with dace.tasklet: # Tasklet to be specialized + a >> A[i + bi] + a = 2 + + sdfg = tester.to_sdfg() + sdfg.apply_gpu_transformations(sequential_innermaps=False) + tasklet = next(n for n, _ in sdfg.all_nodes_recursive() + if isinstance(n, dace.nodes.Tasklet) and '2' in n.code.as_string) + tasklet.location['gpu_thread'] = dace.subsets.Range.from_string('2:9:3') + tasklet.location['gpu_block'] = 1 + + code = sdfg.generate_code()[1].clean_code # Get GPU code (second file) + assert '>= 2' in code and '<= 8' in code + assert ' == 1' in code + + a = np.random.rand(200) + ref = np.ones_like(a) + ref[32:64][2:9:3] = 2 + sdfg(a) + assert np.allclose(a, ref) + + if __name__ == "__main__": test_cpu() test_gpu() test_different_block_sizes_nesting() test_custom_block_size_onemap() test_custom_block_size_twomaps() + test_block_thread_specialization() diff --git a/tests/fortran/array_to_loop_offset.py b/tests/fortran/array_to_loop_offset.py index 43d01d9b6b..5042859f8c 100644 --- a/tests/fortran/array_to_loop_offset.py +++ b/tests/fortran/array_to_loop_offset.py @@ -112,8 +112,112 @@ def test_fortran_frontend_arr2loop_2d_offset(): for j in range(7,10): assert a[i-1, j-1] == i * 2 +def test_fortran_frontend_arr2loop_2d_offset2(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(5,7:9) :: d + CALL index_test_function(d) + end + + SUBROUTINE index_test_function(d) + double precision, dimension(5,7:9) :: d + + d(:,:) = 43 + + END SUBROUTINE index_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + assert len(sdfg.data('d').shape) == 2 + assert sdfg.data('d').shape[0] == 5 + assert sdfg.data('d').shape[1] == 3 + + a = np.full([5,9], 42, order="F", dtype=np.float64) + sdfg(d=a) + for i in range(1,6): + for j in range(7,10): + assert a[i-1, j-1] == 43 + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + a = np.full([5,3], 42, order="F", dtype=np.float64) + sdfg(d=a) + for i in range(0,5): + for j in range(0,3): + assert a[i, j] == 43 + +def test_fortran_frontend_arr2loop_2d_offset3(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(5,7:9) :: d + CALL index_test_function(d) + end + + SUBROUTINE index_test_function(d) + double precision, dimension(5,7:9) :: d + + d(2:4, 7:8) = 43 + + END SUBROUTINE index_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + assert len(sdfg.data('d').shape) == 2 + assert sdfg.data('d').shape[0] == 5 + assert sdfg.data('d').shape[1] == 3 + + a = np.full([5,9], 42, order="F", dtype=np.float64) + sdfg(d=a) + for i in range(2,4): + for j in range(7,9): + assert a[i-1, j-1] == 43 + for j in range(9,10): + assert a[i-1, j-1] == 42 + + for i in [1, 5]: + for j in range(7,10): + assert a[i-1, j-1] == 42 + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + a = np.full([5,3], 42, order="F", dtype=np.float64) + sdfg(d=a) + for i in range(1,4): + for j in range(0,2): + assert a[i, j] == 43 + for j in range(2,3): + assert a[i, j] == 42 + + for i in [0, 4]: + for j in range(0,3): + assert a[i, j] == 42 + if __name__ == "__main__": test_fortran_frontend_arr2loop_1d_offset() test_fortran_frontend_arr2loop_2d_offset() + test_fortran_frontend_arr2loop_2d_offset2() + test_fortran_frontend_arr2loop_2d_offset3() test_fortran_frontend_arr2loop_without_offset() diff --git a/tests/fortran/intrinsic_all.py b/tests/fortran/intrinsic_all.py new file mode 100644 index 0000000000..4a368aff2c --- /dev/null +++ b/tests/fortran/intrinsic_all.py @@ -0,0 +1,361 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import numpy as np +import pytest + +from dace.frontend.fortran import fortran_parser + + +def test_fortran_frontend_all_array(): + test_string = """ + PROGRAM intrinsic_all_test + implicit none + logical, dimension(5) :: d + logical, dimension(2) :: res + CALL intrinsic_all_test_function(d, res) + end + + SUBROUTINE intrinsic_all_test_function(d, res) + logical, dimension(5) :: d + logical, dimension(2) :: res + + res(1) = ALL(d) + + END SUBROUTINE intrinsic_all_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 5 + d = np.full([size], False, order="F", dtype=np.int32) + res = np.full([2], 42, order="F", dtype=np.int32) + + d[2] = True + sdfg(d=d, res=res) + assert res[0] == False + + d = np.full([size], True, order="F", dtype=np.int32) + sdfg(d=d, res=res) + assert res[0] == True + + +def test_fortran_frontend_all_array_dim(): + test_string = """ + PROGRAM intrinsic_all_test + implicit none + logical, dimension(5) :: d + logical, dimension(2) :: res + CALL intrinsic_all_test_function(d, res) + end + + SUBROUTINE intrinsic_all_test_function(d, res) + logical, dimension(5) :: d + logical, dimension(2) :: res + + res(1) = ALL(d, 1) + + END SUBROUTINE intrinsic_all_test_function + """ + + with pytest.raises(NotImplementedError): + fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test", False) + + +def test_fortran_frontend_all_array_comparison(): + test_string = """ + PROGRAM intrinsic_all_test + implicit none + integer, dimension(5) :: first + integer, dimension(5) :: second + logical, dimension(7) :: res + CALL intrinsic_all_test_function(first, second, res) + end + + SUBROUTINE intrinsic_all_test_function(first, second, res) + integer, dimension(5) :: first + integer, dimension(5) :: second + logical, dimension(7) :: res + + res(1) = ALL(first .eq. second) + res(2) = ALL(first(:) .eq. second) + res(3) = ALL(first .eq. second(:)) + res(4) = ALL(first(:) .eq. second(:)) + res(5) = ALL(first(1:5) .eq. second(1:5)) + ! This will also be true - the only same + ! element is at position 3. + res(6) = ALL(first(1:3) .eq. second(3:5)) + res(7) = ALL(first(1:2) .eq. second(4:5)) + + END SUBROUTINE intrinsic_all_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 5 + first = np.full([size], 1, order="F", dtype=np.int32) + second = np.full([size], 1, order="F", dtype=np.int32) + second[2] = 2 + res = np.full([7], 0, order="F", dtype=np.int32) + + sdfg(first=first, second=second, res=res) + assert list(res) == [0, 0, 0, 0, 0, 0, 1] + + second = np.full([size], 2, order="F", dtype=np.int32) + res = np.full([7], 0, order="F", dtype=np.int32) + sdfg(first=first, second=second, res=res) + for val in res: + assert val == False + +def test_fortran_frontend_all_array_scalar_comparison(): + test_string = """ + PROGRAM intrinsic_all_test + implicit none + integer, dimension(5) :: first + logical, dimension(7) :: res + CALL intrinsic_all_test_function(first, res) + end + + SUBROUTINE intrinsic_all_test_function(first, res) + integer, dimension(5) :: first + logical, dimension(7) :: res + + res(1) = ALL(first .eq. 42) + res(2) = ALL(first(:) .eq. 42) + res(3) = ALL(first(1:2) .eq. 42) + res(4) = ALL(first(3) .eq. 42) + res(5) = ALL(first(3:5) .eq. 42) + res(6) = ALL(42 .eq. first) + res(7) = ALL(42 .ne. first) + + END SUBROUTINE intrinsic_all_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 5 + first = np.full([size], 42, order="F", dtype=np.int32) + res = np.full([7], 0, order="F", dtype=np.int32) + + sdfg(first=first, res=res) + for val in res[0:-1]: + assert val == True + assert res[-1] == False + + first[1] = 5 + sdfg(first=first, res=res) + assert list(res) == [0, 0, 0, 1, 1, 0, 0] + + first[1] = 42 + first[3] = 7 + sdfg(first=first, res=res) + assert list(res) == [0, 0, 1, 1, 0, 0, 0] + + first = np.full([size], 41, order="F", dtype=np.int32) + sdfg(first=first, res=res) + assert list(res) == [0, 0, 0, 0, 0, 0, 1] + +def test_fortran_frontend_all_array_comparison_wrong_subset(): + test_string = """ + PROGRAM intrinsic_all_test + implicit none + logical, dimension(5) :: first + logical, dimension(5) :: second + logical, dimension(2) :: res + CALL intrinsic_all_test_function(first, second, res) + end + + SUBROUTINE intrinsic_all_test_function(first, second, res) + logical, dimension(5) :: first + logical, dimension(5) :: second + logical, dimension(2) :: res + + res(1) = ALL(first(1:2) .eq. second(2:5)) + + END SUBROUTINE intrinsic_all_test_function + """ + + with pytest.raises(TypeError): + fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test", False) + +def test_fortran_frontend_all_array_2d(): + test_string = """ + PROGRAM intrinsic_all_test + implicit none + logical, dimension(5,7) :: d + logical, dimension(2) :: res + CALL intrinsic_all_test_function(d, res) + end + + SUBROUTINE intrinsic_all_test_function(d, res) + logical, dimension(5,7) :: d + logical, dimension(2) :: res + + res(1) = ALL(d) + + END SUBROUTINE intrinsic_all_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + sizes = [5, 7] + d = np.full(sizes, True, order="F", dtype=np.int32) + res = np.full([2], 42, order="F", dtype=np.int32) + + d[2,2] = False + sdfg(d=d, res=res) + assert res[0] == False + + d[2,2] = True + sdfg(d=d, res=res) + assert res[0] == True + +def test_fortran_frontend_all_array_comparison_2d(): + test_string = """ + PROGRAM intrinsic_all_test + implicit none + integer, dimension(5,4) :: first + integer, dimension(5,4) :: second + logical, dimension(7) :: res + CALL intrinsic_all_test_function(first, second, res) + end + + SUBROUTINE intrinsic_all_test_function(first, second, res) + integer, dimension(5,4) :: first + integer, dimension(5,4) :: second + logical, dimension(7) :: res + + res(1) = ALL(first .eq. second) + res(2) = ALL(first(:,:) .eq. second) + res(3) = ALL(first .eq. second(:,:)) + res(4) = ALL(first(:,:) .eq. second(:,:)) + res(5) = ALL(first(1:5,:) .eq. second(1:5,:)) + res(6) = ALL(first(:,1:4) .eq. second(:,1:4)) + ! Now test subsets. + res(7) = ALL(first(2:3, 3:4) .eq. second(2:3, 3:4)) + + END SUBROUTINE intrinsic_all_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + sizes = [5, 4] + first = np.full(sizes, 1, order="F", dtype=np.int32) + second = np.full(sizes, 1, order="F", dtype=np.int32) + second[2,2] = 2 + res = np.full([7], 0, order="F", dtype=np.int32) + + sdfg(first=first, second=second, res=res) + for val in res: + assert val == False + + second = np.full(sizes, 1, order="F", dtype=np.int32) + res = np.full([7], 0, order="F", dtype=np.int32) + sdfg(first=first, second=second, res=res) + for val in res: + assert val == True + +def test_fortran_frontend_all_array_comparison_2d_subset(): + test_string = """ + PROGRAM intrinsic_all_test + implicit none + integer, dimension(5,4) :: first + integer, dimension(5,4) :: second + logical, dimension(2) :: res + CALL intrinsic_all_test_function(first, second, res) + end + + SUBROUTINE intrinsic_all_test_function(first, second, res) + integer, dimension(5,4) :: first + integer, dimension(5,4) :: second + logical, dimension(2) :: res + + ! Now test subsets - make sure the equal values are only + ! in the tested area. + res(1) = ALL(first(1:2, 3:4) .ne. second(4:5, 2:3)) + res(2) = ALL(first(1:2, 3:4) .eq. second(4:5, 2:3)) + + END SUBROUTINE intrinsic_all_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + sizes = [5, 4] + first = np.full(sizes, 1, order="F", dtype=np.int32) + first[2:5, :] = 2 + first[0:2, 0:2] = 2 + + second = np.full(sizes, 1, order="F", dtype=np.int32) + second[0:3, :] = 3 + second[3:5, 0] = 3 + second[3:5, 3:5] = 3 + + res = np.full([2], 0, order="F", dtype=np.int32) + + sdfg(first=first, second=second, res=res) + assert list(res) == [0, 1] + +def test_fortran_frontend_all_array_comparison_2d_subset_offset(): + test_string = """ + PROGRAM intrinsic_all_test + implicit none + integer, dimension(20:24,4) :: first + integer, dimension(5,7:10) :: second + logical, dimension(2) :: res + CALL intrinsic_all_test_function(first, second, res) + end + + SUBROUTINE intrinsic_all_test_function(first, second, res) + integer, dimension(20:24,4) :: first + integer, dimension(5,7:10) :: second + logical, dimension(2) :: res + + ! Now test subsets - make sure the equal values are only + ! in the tested area. + res(1) = ALL(first(20:21, 3:4) .ne. second(4:5, 8:9)) + res(2) = ALL(first(20:21, 3:4) .eq. second(4:5, 8:9)) + + END SUBROUTINE intrinsic_all_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_all_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + sizes = [5, 4] + first = np.full(sizes, 1, order="F", dtype=np.int32) + first[2:5, :] = 2 + first[0:2, 0:2] = 2 + + second = np.full(sizes, 1, order="F", dtype=np.int32) + second[0:3, :] = 3 + second[3:5, 0] = 3 + second[3:5, 3:5] = 3 + + res = np.full([2], 0, order="F", dtype=np.int32) + + sdfg(first=first, second=second, res=res) + assert list(res) == [0, 1] + +if __name__ == "__main__": + + test_fortran_frontend_all_array() + test_fortran_frontend_all_array_dim() + test_fortran_frontend_all_array_comparison() + test_fortran_frontend_all_array_scalar_comparison() + test_fortran_frontend_all_array_comparison_wrong_subset() + test_fortran_frontend_all_array_2d() + test_fortran_frontend_all_array_comparison_2d() + test_fortran_frontend_all_array_comparison_2d_subset() + test_fortran_frontend_all_array_comparison_2d_subset_offset() diff --git a/tests/fortran/intrinsic_any.py b/tests/fortran/intrinsic_any.py new file mode 100644 index 0000000000..c1d82cd2e0 --- /dev/null +++ b/tests/fortran/intrinsic_any.py @@ -0,0 +1,364 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import numpy as np +import pytest + +from dace.frontend.fortran import fortran_parser + + +def test_fortran_frontend_any_array(): + test_string = """ + PROGRAM intrinsic_any_test + implicit none + logical, dimension(5) :: d + logical, dimension(2) :: res + CALL intrinsic_any_test_function(d, res) + end + + SUBROUTINE intrinsic_any_test_function(d, res) + logical, dimension(5) :: d + logical, dimension(2) :: res + + res(1) = ANY(d) + + END SUBROUTINE intrinsic_any_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_any_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 5 + d = np.full([size], False, order="F", dtype=np.int32) + res = np.full([2], 42, order="F", dtype=np.int32) + + d[2] = True + sdfg(d=d, res=res) + assert res[0] == True + + d[2] = False + sdfg(d=d, res=res) + assert res[0] == False + + +def test_fortran_frontend_any_array_dim(): + test_string = """ + PROGRAM intrinsic_any_test + implicit none + logical, dimension(5) :: d + logical, dimension(2) :: res + CALL intrinsic_any_test_function(d, res) + end + + SUBROUTINE intrinsic_any_test_function(d, res) + logical, dimension(5) :: d + logical, dimension(2) :: res + + res(1) = ANY(d, 1) + + END SUBROUTINE intrinsic_any_test_function + """ + + with pytest.raises(NotImplementedError): + fortran_parser.create_sdfg_from_string(test_string, "intrinsic_any_test", False) + + +def test_fortran_frontend_any_array_comparison(): + test_string = """ + PROGRAM intrinsic_any_test + implicit none + integer, dimension(5) :: first + integer, dimension(5) :: second + logical, dimension(7) :: res + CALL intrinsic_any_test_function(first, second, res) + end + + SUBROUTINE intrinsic_any_test_function(first, second, res) + integer, dimension(5) :: first + integer, dimension(5) :: second + logical, dimension(7) :: res + + res(1) = ANY(first .eq. second) + res(2) = ANY(first(:) .eq. second) + res(3) = ANY(first .eq. second(:)) + res(4) = ANY(first(:) .eq. second(:)) + res(5) = any(first(1:5) .eq. second(1:5)) + ! This will also be true - the only same + ! element is at position 3. + res(6) = any(first(1:3) .eq. second(3:5)) + res(7) = any(first(1:2) .eq. second(4:5)) + + END SUBROUTINE intrinsic_any_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_any_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 5 + first = np.full([size], 1, order="F", dtype=np.int32) + second = np.full([size], 2, order="F", dtype=np.int32) + second[2] = 1 + res = np.full([7], 0, order="F", dtype=np.int32) + + sdfg(first=first, second=second, res=res) + for val in res[0:-1]: + assert val == True + assert res[-1] == False + + second = np.full([size], 2, order="F", dtype=np.int32) + res = np.full([7], 0, order="F", dtype=np.int32) + sdfg(first=first, second=second, res=res) + for val in res: + assert val == False + +def test_fortran_frontend_any_array_scalar_comparison(): + test_string = """ + PROGRAM intrinsic_any_test + implicit none + integer, dimension(5) :: first + logical, dimension(7) :: res + CALL intrinsic_any_test_function(first, res) + end + + SUBROUTINE intrinsic_any_test_function(first, res) + integer, dimension(5) :: first + logical, dimension(7) :: res + + res(1) = ANY(first .eq. 42) + res(2) = ANY(first(:) .eq. 42) + res(3) = ANY(first(1:2) .eq. 42) + res(4) = ANY(first(3) .eq. 42) + res(5) = ANY(first(3:5) .eq. 42) + res(6) = ANY(42 .eq. first) + res(7) = ANY(42 .ne. first) + + END SUBROUTINE intrinsic_any_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_any_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 5 + first = np.full([size], 1, order="F", dtype=np.int32) + res = np.full([7], 0, order="F", dtype=np.int32) + + sdfg(first=first, res=res) + for val in res[0:-1]: + assert val == False + assert res[-1] == True + + first[1] = 42 + sdfg(first=first, res=res) + assert list(res) == [1, 1, 1, 0, 0, 1, 1] + + first[1] = 5 + first[3] = 42 + sdfg(first=first, res=res) + assert list(res) == [1, 1, 0, 0, 1, 1, 1] + + first[3] = 7 + first[2] = 42 + sdfg(first=first, res=res) + assert list(res) == [1, 1, 0, 1, 1, 1, 1] + +def test_fortran_frontend_any_array_comparison_wrong_subset(): + test_string = """ + PROGRAM intrinsic_any_test + implicit none + logical, dimension(5) :: first + logical, dimension(5) :: second + logical, dimension(2) :: res + CALL intrinsic_any_test_function(first, second, res) + end + + SUBROUTINE intrinsic_any_test_function(first, second, res) + logical, dimension(5) :: first + logical, dimension(5) :: second + logical, dimension(2) :: res + + res(1) = ANY(first(1:2) .eq. second(2:5)) + + END SUBROUTINE intrinsic_any_test_function + """ + + with pytest.raises(TypeError): + fortran_parser.create_sdfg_from_string(test_string, "intrinsic_any_test", False) + +def test_fortran_frontend_any_array_2d(): + test_string = """ + PROGRAM intrinsic_any_test + implicit none + logical, dimension(5,7) :: d + logical, dimension(2) :: res + CALL intrinsic_any_test_function(d, res) + end + + SUBROUTINE intrinsic_any_test_function(d, res) + logical, dimension(5,7) :: d + logical, dimension(2) :: res + + res(1) = ANY(d) + + END SUBROUTINE intrinsic_any_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_any_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + sizes = [5, 7] + d = np.full(sizes, False, order="F", dtype=np.int32) + res = np.full([2], 42, order="F", dtype=np.int32) + + d[2,2] = True + sdfg(d=d, res=res) + assert res[0] == True + + d[2,2] = False + sdfg(d=d, res=res) + assert res[0] == False + +def test_fortran_frontend_any_array_comparison_2d(): + test_string = """ + PROGRAM intrinsic_any_test + implicit none + integer, dimension(5,4) :: first + integer, dimension(5,4) :: second + logical, dimension(7) :: res + CALL intrinsic_any_test_function(first, second, res) + end + + SUBROUTINE intrinsic_any_test_function(first, second, res) + integer, dimension(5,4) :: first + integer, dimension(5,4) :: second + logical, dimension(7) :: res + + res(1) = ANY(first .eq. second) + res(2) = ANY(first(:,:) .eq. second) + res(3) = ANY(first .eq. second(:,:)) + res(4) = ANY(first(:,:) .eq. second(:,:)) + res(5) = any(first(1:5,:) .eq. second(1:5,:)) + res(6) = any(first(:,1:4) .eq. second(:,1:4)) + ! Now test subsets. + res(7) = any(first(2:3, 3:4) .eq. second(2:3, 3:4)) + + END SUBROUTINE intrinsic_any_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_any_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + sizes = [5, 4] + first = np.full(sizes, 1, order="F", dtype=np.int32) + second = np.full(sizes, 2, order="F", dtype=np.int32) + second[2,2] = 1 + res = np.full([7], 0, order="F", dtype=np.int32) + + sdfg(first=first, second=second, res=res) + for val in res: + assert val == True + + second = np.full(sizes, 2, order="F", dtype=np.int32) + res = np.full([7], 0, order="F", dtype=np.int32) + sdfg(first=first, second=second, res=res) + for val in res: + assert val == False + +def test_fortran_frontend_any_array_comparison_2d_subset(): + test_string = """ + PROGRAM intrinsic_any_test + implicit none + integer, dimension(5,4) :: first + integer, dimension(5,4) :: second + logical, dimension(2) :: res + CALL intrinsic_any_test_function(first, second, res) + end + + SUBROUTINE intrinsic_any_test_function(first, second, res) + integer, dimension(5,4) :: first + integer, dimension(5,4) :: second + logical, dimension(2) :: res + + ! Now test subsets - make sure the equal values are only + ! in the tested area. + res(1) = any(first(1:2, 3:4) .ne. second(4:5, 2:3)) + res(2) = any(first(1:2, 3:4) .eq. second(4:5, 2:3)) + + END SUBROUTINE intrinsic_any_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_any_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + sizes = [5, 4] + first = np.full(sizes, 1, order="F", dtype=np.int32) + first[2:5, :] = 2 + first[0:2, 0:2] = 2 + + second = np.full(sizes, 1, order="F", dtype=np.int32) + second[0:3, :] = 3 + second[3:5, 0] = 3 + second[3:5, 3:5] = 3 + + res = np.full([2], 0, order="F", dtype=np.int32) + + sdfg(first=first, second=second, res=res) + assert list(res) == [0, 1] + +def test_fortran_frontend_any_array_comparison_2d_subset_offset(): + test_string = """ + PROGRAM intrinsic_any_test + implicit none + integer, dimension(20:24,4) :: first + integer, dimension(5,7:10) :: second + logical, dimension(2) :: res + CALL intrinsic_any_test_function(first, second, res) + end + + SUBROUTINE intrinsic_any_test_function(first, second, res) + integer, dimension(20:24,4) :: first + integer, dimension(5,7:10) :: second + logical, dimension(2) :: res + + ! Now test subsets - make sure the equal values are only + ! in the tested area. + res(1) = any(first(20:21, 3:4) .ne. second(4:5, 8:9)) + res(2) = any(first(20:21, 3:4) .eq. second(4:5, 8:9)) + + END SUBROUTINE intrinsic_any_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_any_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + sizes = [5, 4] + first = np.full(sizes, 1, order="F", dtype=np.int32) + first[2:5, :] = 2 + first[0:2, 0:2] = 2 + + second = np.full(sizes, 1, order="F", dtype=np.int32) + second[0:3, :] = 3 + second[3:5, 0] = 3 + second[3:5, 3:5] = 3 + + res = np.full([2], 0, order="F", dtype=np.int32) + + sdfg(first=first, second=second, res=res) + assert list(res) == [0, 1] + +if __name__ == "__main__": + + test_fortran_frontend_any_array() + test_fortran_frontend_any_array_dim() + test_fortran_frontend_any_array_comparison() + test_fortran_frontend_any_array_scalar_comparison() + test_fortran_frontend_any_array_comparison_wrong_subset() + test_fortran_frontend_any_array_2d() + test_fortran_frontend_any_array_comparison_2d() + test_fortran_frontend_any_array_comparison_2d_subset() + test_fortran_frontend_any_array_comparison_2d_subset_offset() diff --git a/tests/fortran/intrinsic_count.py b/tests/fortran/intrinsic_count.py new file mode 100644 index 0000000000..ef55f9dd55 --- /dev/null +++ b/tests/fortran/intrinsic_count.py @@ -0,0 +1,371 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import numpy as np +import pytest + +from dace.frontend.fortran import fortran_parser + + +def test_fortran_frontend_count_array(): + test_string = """ + PROGRAM intrinsic_count_test + implicit none + logical, dimension(5) :: d + integer, dimension(2) :: res + CALL intrinsic_count_test_function(d, res) + end + + SUBROUTINE intrinsic_count_test_function(d, res) + logical, dimension(5) :: d + integer, dimension(2) :: res + + res(1) = COUNT(d) + + END SUBROUTINE intrinsic_count_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 5 + d = np.full([size], False, order="F", dtype=np.int32) + res = np.full([2], 42, order="F", dtype=np.int32) + + d[2] = True + sdfg(d=d, res=res) + assert res[0] == 1 + + d[2] = False + sdfg(d=d, res=res) + assert res[0] == 0 + + +def test_fortran_frontend_count_array_dim(): + test_string = """ + PROGRAM intrinsic_count_test + implicit none + logical, dimension(5) :: d + logical, dimension(2) :: res + CALL intrinsic_count_test_function(d, res) + end + + SUBROUTINE intrinsic_count_test_function(d, res) + logical, dimension(5) :: d + logical, dimension(2) :: res + + res(1) = COUNT(d, 1) + + END SUBROUTINE intrinsic_count_test_function + """ + + with pytest.raises(NotImplementedError): + fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test", False) + + +def test_fortran_frontend_count_array_comparison(): + test_string = """ + PROGRAM intrinsic_count_test + implicit none + integer, dimension(5) :: first + integer, dimension(5) :: second + logical, dimension(7) :: res + CALL intrinsic_count_test_function(first, second, res) + end + + SUBROUTINE intrinsic_count_test_function(first, second, res) + integer, dimension(5) :: first + integer, dimension(5) :: second + logical, dimension(7) :: res + + res(1) = COUNT(first .eq. second) + res(2) = COUNT(first(:) .eq. second) + res(3) = COUNT(first .eq. second(:)) + res(4) = COUNT(first(:) .eq. second(:)) + res(5) = COUNT(first(1:5) .eq. second(1:5)) + ! This will also be true - the only same + ! element is at position 3. + res(6) = COUNT(first(1:3) .eq. second(3:5)) + res(7) = COUNT(first(1:2) .eq. second(4:5)) + + END SUBROUTINE intrinsic_count_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 5 + first = np.full([size], 1, order="F", dtype=np.int32) + second = np.full([size], 1, order="F", dtype=np.int32) + second[2] = 2 + res = np.full([7], 0, order="F", dtype=np.int32) + + sdfg(first=first, second=second, res=res) + assert list(res) == [4, 4, 4, 4, 4, 2, 2] + + second = np.full([size], 2, order="F", dtype=np.int32) + res = np.full([7], 0, order="F", dtype=np.int32) + sdfg(first=first, second=second, res=res) + for val in res: + assert val == 0 + + second = np.full([size], 1, order="F", dtype=np.int32) + res = np.full([7], 0, order="F", dtype=np.int32) + sdfg(first=first, second=second, res=res) + assert list(res) == [5, 5, 5, 5, 5, 3, 2] + +def test_fortran_frontend_count_array_scalar_comparison(): + test_string = """ + PROGRAM intrinsic_count_test + implicit none + integer, dimension(5) :: first + logical, dimension(9) :: res + CALL intrinsic_count_test_function(first, res) + end + + SUBROUTINE intrinsic_count_test_function(first, res) + integer, dimension(5) :: first + logical, dimension(9) :: res + + res(1) = COUNT(first .eq. 42) + res(2) = COUNT(first(:) .eq. 42) + res(3) = COUNT(first(1:2) .eq. 42) + res(4) = COUNT(first(3) .eq. 42) + res(5) = COUNT(first(3:5) .eq. 42) + res(6) = COUNT(42 .eq. first) + res(7) = COUNT(42 .ne. first) + res(8) = COUNT(6 .lt. first) + res(9) = COUNT(6 .gt. first) + + END SUBROUTINE intrinsic_count_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 5 + first = np.full([size], 1, order="F", dtype=np.int32) + res = np.full([9], 0, order="F", dtype=np.int32) + + sdfg(first=first, res=res) + assert list(res) == [0, 0, 0, 0, 0, 0, 5, 0, size] + + first[1] = 42 + sdfg(first=first, res=res) + assert list(res) == [1, 1, 1, 0, 0, 1, 4, 1, size - 1] + + first[1] = 5 + first[2] = 42 + sdfg(first=first, res=res) + assert list(res) == [1, 1, 0, 1, 1, 1, 4, 1, size - 1] + + first[2] = 7 + first[3] = 42 + sdfg(first=first, res=res) + assert list(res) == [1, 1, 0, 0, 1, 1, 4, 2, size - 2] + +def test_fortran_frontend_count_array_comparison_wrong_subset(): + test_string = """ + PROGRAM intrinsic_count_test + implicit none + logical, dimension(5) :: first + logical, dimension(5) :: second + logical, dimension(2) :: res + CALL intrinsic_count_test_function(first, second, res) + end + + SUBROUTINE intrinsic_count_test_function(first, second, res) + logical, dimension(5) :: first + logical, dimension(5) :: second + logical, dimension(2) :: res + + res(1) = COUNT(first(1:2) .eq. second(2:5)) + + END SUBROUTINE intrinsic_count_test_function + """ + + with pytest.raises(TypeError): + fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test", False) + +def test_fortran_frontend_count_array_2d(): + test_string = """ + PROGRAM intrinsic_count_test + implicit none + logical, dimension(5,7) :: d + logical, dimension(2) :: res + CALL intrinsic_count_test_function(d, res) + end + + SUBROUTINE intrinsic_count_test_function(d, res) + logical, dimension(5,7) :: d + logical, dimension(2) :: res + + res(1) = COUNT(d) + + END SUBROUTINE intrinsic_count_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + sizes = [5, 7] + d = np.full(sizes, True, order="F", dtype=np.int32) + res = np.full([2], 42, order="F", dtype=np.int32) + sdfg(d=d, res=res) + assert res[0] == 35 + + d[2,2] = False + sdfg(d=d, res=res) + assert res[0] == 34 + + d = np.full(sizes, False, order="F", dtype=np.int32) + sdfg(d=d, res=res) + assert res[0] == 0 + + d[2,2] = True + sdfg(d=d, res=res) + assert res[0] == 1 + +def test_fortran_frontend_count_array_comparison_2d(): + test_string = """ + PROGRAM intrinsic_count_test + implicit none + integer, dimension(5,4) :: first + integer, dimension(5,4) :: second + logical, dimension(7) :: res + CALL intrinsic_count_test_function(first, second, res) + end + + SUBROUTINE intrinsic_count_test_function(first, second, res) + integer, dimension(5,4) :: first + integer, dimension(5,4) :: second + logical, dimension(7) :: res + + res(1) = COUNT(first .eq. second) + res(2) = COUNT(first(:,:) .eq. second) + res(3) = COUNT(first .eq. second(:,:)) + res(4) = COUNT(first(:,:) .eq. second(:,:)) + res(5) = COUNT(first(1:5,:) .eq. second(1:5,:)) + res(6) = COUNT(first(:,1:4) .eq. second(:,1:4)) + ! Now test subsets. + res(7) = COUNT(first(2:3, 3:4) .eq. second(2:3, 3:4)) + + END SUBROUTINE intrinsic_count_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + sizes = [5, 4] + first = np.full(sizes, 1, order="F", dtype=np.int32) + second = np.full(sizes, 2, order="F", dtype=np.int32) + second[1, 1] = 1 + res = np.full([7], 0, order="F", dtype=np.int32) + + sdfg(first=first, second=second, res=res) + assert list(res) == [1, 1, 1, 1, 1, 1, 0] + + second = np.full(sizes, 1, order="F", dtype=np.int32) + res = np.full([7], 0, order="F", dtype=np.int32) + sdfg(first=first, second=second, res=res) + assert list(res) == [20, 20, 20, 20, 20, 20, 4] + +def test_fortran_frontend_count_array_comparison_2d_subset(): + test_string = """ + PROGRAM intrinsic_count_test + implicit none + integer, dimension(5,4) :: first + integer, dimension(5,4) :: second + logical, dimension(2) :: res + CALL intrinsic_count_test_function(first, second, res) + end + + SUBROUTINE intrinsic_count_test_function(first, second, res) + integer, dimension(5,4) :: first + integer, dimension(5,4) :: second + logical, dimension(2) :: res + + ! Now test subsets - make sure the equal values are only + ! in the tested area. + res(1) = COUNT(first(1:2, 3:4) .ne. second(4:5, 2:3)) + res(2) = COUNT(first(1:2, 3:4) .eq. second(4:5, 2:3)) + + END SUBROUTINE intrinsic_count_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + sizes = [5, 4] + first = np.full(sizes, 1, order="F", dtype=np.int32) + first[2:5, :] = 2 + first[0:2, 0:2] = 2 + + second = np.full(sizes, 1, order="F", dtype=np.int32) + second[0:3, :] = 2 + second[3:5, 0] = 2 + second[3:5, 3:5] = 2 + + res = np.full([2], 0, order="F", dtype=np.int32) + + sdfg(first=first, second=second, res=res) + assert list(res) == [0, 4] + +def test_fortran_frontend_count_array_comparison_2d_subset_offset(): + test_string = """ + PROGRAM intrinsic_count_test + implicit none + integer, dimension(20:24,4) :: first + integer, dimension(5,7:10) :: second + logical, dimension(2) :: res + CALL intrinsic_count_test_function(first, second, res) + end + + SUBROUTINE intrinsic_count_test_function(first, second, res) + integer, dimension(20:24,4) :: first + integer, dimension(5,7:10) :: second + logical, dimension(2) :: res + + ! Now test subsets - make sure the equal values are only + ! in the tested area. + res(1) = COUNT(first(20:21, 3:4) .ne. second(4:5, 8:9)) + res(2) = COUNT(first(20:21, 3:4) .eq. second(4:5, 8:9)) + + END SUBROUTINE intrinsic_count_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + sizes = [5, 4] + first = np.full(sizes, 1, order="F", dtype=np.int32) + first[2:5, :] = 2 + first[0:2, 0:2] = 2 + + second = np.full(sizes, 1, order="F", dtype=np.int32) + second[0:3, :] = 2 + second[3:5, 0] = 2 + second[3:5, 3:5] = 2 + + res = np.full([2], 0, order="F", dtype=np.int32) + + sdfg(first=first, second=second, res=res) + assert list(res) == [0, 4] + +if __name__ == "__main__": + + test_fortran_frontend_count_array() + test_fortran_frontend_count_array_dim() + test_fortran_frontend_count_array_comparison() + test_fortran_frontend_count_array_scalar_comparison() + test_fortran_frontend_count_array_comparison_wrong_subset() + test_fortran_frontend_count_array_2d() + test_fortran_frontend_count_array_comparison_2d() + test_fortran_frontend_count_array_comparison_2d_subset() + test_fortran_frontend_count_array_comparison_2d_subset_offset() diff --git a/tests/fortran/intrinsic_merge.py b/tests/fortran/intrinsic_merge.py new file mode 100644 index 0000000000..1778b9c2fb --- /dev/null +++ b/tests/fortran/intrinsic_merge.py @@ -0,0 +1,283 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import numpy as np + +from dace.frontend.fortran import ast_transforms, fortran_parser + +def test_fortran_frontend_merge_1d(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM merge_test + implicit none + double precision, dimension(7) :: input1 + double precision, dimension(7) :: input2 + integer, dimension(7) :: mask + double precision, dimension(7) :: res + CALL merge_test_function(input1, input2, mask, res) + end + + SUBROUTINE merge_test_function(input1, input2, mask, res) + double precision, dimension(7) :: input1 + double precision, dimension(7) :: input2 + integer, dimension(7) :: mask + double precision, dimension(7) :: res + + res = MERGE(input1, input2, mask) + + END SUBROUTINE merge_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + sdfg = fortran_parser.create_sdfg_from_string(test_string, "merge_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + size = 7 + + # Minimum is in the beginning + first = np.full([size], 13, order="F", dtype=np.float64) + second = np.full([size], 42, order="F", dtype=np.float64) + mask = np.full([size], 0, order="F", dtype=np.int32) + res = np.full([size], 40, order="F", dtype=np.float64) + + sdfg(input1=first, input2=second, mask=mask, res=res) + for val in res: + assert val == 42 + + for i in range(int(size/2)): + mask[i] = 1 + sdfg(input1=first, input2=second, mask=mask, res=res) + for i in range(int(size/2)): + assert res[i] == 13 + for i in range(int(size/2), size): + assert res[i] == 42 + + mask[:] = 0 + for i in range(size): + if i % 2 == 1: + mask[i] = 1 + sdfg(input1=first, input2=second, mask=mask, res=res) + for i in range(size): + if i % 2 == 1: + assert res[i] == 13 + else: + assert res[i] == 42 + +def test_fortran_frontend_merge_comparison_scalar(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM merge_test + implicit none + double precision, dimension(7) :: input1 + double precision, dimension(7) :: input2 + double precision, dimension(7) :: res + CALL merge_test_function(input1, input2, res) + end + + SUBROUTINE merge_test_function(input1, input2, res) + double precision, dimension(7) :: input1 + double precision, dimension(7) :: input2 + double precision, dimension(7) :: res + + res = MERGE(input1, input2, input1 .eq. 3) + + END SUBROUTINE merge_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + sdfg = fortran_parser.create_sdfg_from_string(test_string, "merge_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + size = 7 + + # Minimum is in the beginning + first = np.full([size], 13, order="F", dtype=np.float64) + second = np.full([size], 42, order="F", dtype=np.float64) + res = np.full([size], 40, order="F", dtype=np.float64) + + sdfg(input1=first, input2=second, res=res) + for val in res: + assert val == 42 + + for i in range(int(size/2)): + first[i] = 3 + sdfg(input1=first, input2=second, res=res) + for i in range(int(size/2)): + assert res[i] == 3 + for i in range(int(size/2), size): + assert res[i] == 42 + + first[:] = 13 + for i in range(size): + if i % 2 == 1: + first[i] = 3 + sdfg(input1=first, input2=second, res=res) + for i in range(size): + if i % 2 == 1: + assert res[i] == 3 + else: + assert res[i] == 42 + +def test_fortran_frontend_merge_comparison_arrays(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM merge_test + implicit none + double precision, dimension(7) :: input1 + double precision, dimension(7) :: input2 + double precision, dimension(7) :: res + CALL merge_test_function(input1, input2, res) + end + + SUBROUTINE merge_test_function(input1, input2, res) + double precision, dimension(7) :: input1 + double precision, dimension(7) :: input2 + double precision, dimension(7) :: res + + res = MERGE(input1, input2, input1 .lt. input2) + + END SUBROUTINE merge_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + sdfg = fortran_parser.create_sdfg_from_string(test_string, "merge_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + size = 7 + + # Minimum is in the beginning + first = np.full([size], 13, order="F", dtype=np.float64) + second = np.full([size], 42, order="F", dtype=np.float64) + res = np.full([size], 40, order="F", dtype=np.float64) + + sdfg(input1=first, input2=second, res=res) + for val in res: + assert val == 13 + + for i in range(int(size/2)): + first[i] = 45 + sdfg(input1=first, input2=second, res=res) + for i in range(int(size/2)): + assert res[i] == 42 + for i in range(int(size/2), size): + assert res[i] == 13 + + first[:] = 13 + for i in range(size): + if i % 2 == 1: + first[i] = 45 + sdfg(input1=first, input2=second, res=res) + for i in range(size): + if i % 2 == 1: + assert res[i] == 42 + else: + assert res[i] == 13 + + +def test_fortran_frontend_merge_comparison_arrays_offset(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM merge_test + implicit none + double precision, dimension(7) :: input1 + double precision, dimension(7) :: input2 + double precision, dimension(14) :: mask1 + double precision, dimension(14) :: mask2 + double precision, dimension(7) :: res + CALL merge_test_function(input1, input2, mask1, mask2, res) + end + + SUBROUTINE merge_test_function(input1, input2, mask1, mask2, res) + double precision, dimension(7) :: input1 + double precision, dimension(7) :: input2 + double precision, dimension(14) :: mask1 + double precision, dimension(14) :: mask2 + double precision, dimension(7) :: res + + res = MERGE(input1, input2, mask1(3:9) .lt. mask2(5:11)) + + END SUBROUTINE merge_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + sdfg = fortran_parser.create_sdfg_from_string(test_string, "merge_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + size = 7 + + # Minimum is in the beginning + first = np.full([size], 13, order="F", dtype=np.float64) + second = np.full([size], 42, order="F", dtype=np.float64) + mask1 = np.full([size*2], 30, order="F", dtype=np.float64) + mask2 = np.full([size*2], 0, order="F", dtype=np.float64) + res = np.full([size], 40, order="F", dtype=np.float64) + + mask1[2:9] = 3 + mask2[4:11] = 4 + sdfg(input1=first, input2=second, mask1=mask1, mask2=mask2, res=res) + for val in res: + assert val == 13 + + +def test_fortran_frontend_merge_array_shift(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM merge_test + implicit none + double precision, dimension(7) :: input1 + double precision, dimension(21) :: input2 + double precision, dimension(14) :: mask1 + double precision, dimension(14) :: mask2 + double precision, dimension(7) :: res + CALL merge_test_function(input1, input2, mask1, mask2, res) + end + + SUBROUTINE merge_test_function(input1, input2, mask1, mask2, res) + double precision, dimension(7) :: input1 + double precision, dimension(21) :: input2 + double precision, dimension(14) :: mask1 + double precision, dimension(14) :: mask2 + double precision, dimension(7) :: res + + res = MERGE(input1, input2(13:19), mask1(3:9) .gt. mask2(5:11)) + + END SUBROUTINE merge_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + sdfg = fortran_parser.create_sdfg_from_string(test_string, "merge_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + size = 7 + + # Minimum is in the beginning + first = np.full([size], 13, order="F", dtype=np.float64) + second = np.full([size*3], 42, order="F", dtype=np.float64) + mask1 = np.full([size*2], 30, order="F", dtype=np.float64) + mask2 = np.full([size*2], 0, order="F", dtype=np.float64) + res = np.full([size], 40, order="F", dtype=np.float64) + + second[12:19] = 100 + mask1[2:9] = 3 + mask2[4:11] = 4 + sdfg(input1=first, input2=second, mask1=mask1, mask2=mask2, res=res) + for val in res: + assert val == 100 + + +if __name__ == "__main__": + + test_fortran_frontend_merge_1d() + test_fortran_frontend_merge_comparison_scalar() + test_fortran_frontend_merge_comparison_arrays() + test_fortran_frontend_merge_comparison_arrays_offset() + test_fortran_frontend_merge_array_shift() diff --git a/tests/fortran/intrinsic_minmaxval.py b/tests/fortran/intrinsic_minmaxval.py new file mode 100644 index 0000000000..6a32237d37 --- /dev/null +++ b/tests/fortran/intrinsic_minmaxval.py @@ -0,0 +1,252 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import numpy as np + +from dace.frontend.fortran import ast_transforms, fortran_parser + +def test_fortran_frontend_minval_double(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM minval_test + implicit none + double precision, dimension(7) :: d + double precision, dimension(4) :: res + CALL minval_test_function(d, res) + end + + SUBROUTINE minval_test_function(d, res) + double precision, dimension(7) :: d + double precision, dimension(0) :: dt + double precision, dimension(4) :: res + + res(1) = MINVAL(d) + res(2) = MINVAL(d(:)) + res(3) = MINVAL(d(3:6)) + res(4) = MINVAL(dt) + + END SUBROUTINE minval_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + sdfg = fortran_parser.create_sdfg_from_string(test_string, "minval_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + size = 7 + + # Minimum is in the beginning + d = np.full([size], 0, order="F", dtype=np.float64) + for i in range(size): + d[i] = i + 1 + res = np.full([4], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + + assert res[0] == d[0] + assert res[1] == d[0] + assert res[2] == d[2] + # It should be the dace max for integer + assert res[3] == np.finfo(np.float64).max + + # Minimum is in the beginning + for i in range(size): + d[i] = 10 - i + sdfg(d=d, res=res) + assert res[0] == d[-1] + assert res[1] == d[-1] + assert res[2] == d[5] + # It should be the dace max for integer + assert res[3] == np.finfo(np.float64).max + +def test_fortran_frontend_minval_int(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM minval_test + implicit none + integer, dimension(7) :: d + integer, dimension(4) :: res + CALL minval_test_function(d, res) + end + + SUBROUTINE minval_test_function(d, res) + integer, dimension(7) :: d + integer, dimension(0) :: dt + integer, dimension(4) :: res + + res(1) = MINVAL(d) + res(2) = MINVAL(d(:)) + res(3) = MINVAL(d(3:6)) + res(4) = MINVAL(dt) + + END SUBROUTINE minval_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + sdfg = fortran_parser.create_sdfg_from_string(test_string, "minval_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + size = 7 + + # Minimum is in the beginning + d = np.full([size], 0, order="F", dtype=np.int32) + for i in range(size): + d[i] = i + 1 + res = np.full([4], 42, order="F", dtype=np.int32) + sdfg(d=d, res=res) + + assert res[0] == d[0] + assert res[1] == d[0] + assert res[2] == d[2] + # It should be the dace max for integer + assert res[3] == np.iinfo(np.int32).max + + # Minimum is in the beginning + for i in range(size): + d[i] = 10 - i + sdfg(d=d, res=res) + assert res[0] == d[-1] + assert res[1] == d[-1] + assert res[2] == d[5] + # It should be the dace max for integer + assert res[3] == np.iinfo(np.int32).max + + # Minimum is in the middle + d = np.full([size], 0, order="F", dtype=np.int32) + d[:] = [-5, 10, -6, 4, 32, 42, -1] + res = np.full([4], 42, order="F", dtype=np.int32) + sdfg(d=d, res=res) + + assert res[0] == d[2] + assert res[1] == d[2] + assert res[2] == d[2] + # It should be the dace max for integer + assert res[3] == np.iinfo(np.int32).max + +def test_fortran_frontend_maxval_double(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM minval_test + implicit none + double precision, dimension(7) :: d + double precision, dimension(4) :: res + CALL minval_test_function(d, res) + end + + SUBROUTINE minval_test_function(d, res) + double precision, dimension(7) :: d + double precision, dimension(0) :: dt + double precision, dimension(4) :: res + + res(1) = MAXVAL(d) + res(2) = MAXVAL(d(:)) + res(3) = MAXVAL(d(3:6)) + res(4) = MAXVAL(dt) + + END SUBROUTINE minval_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + sdfg = fortran_parser.create_sdfg_from_string(test_string, "minval_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + size = 7 + + # Minimum is in the beginning + d = np.full([size], 0, order="F", dtype=np.float64) + for i in range(size): + d[i] = i + 1 + res = np.full([4], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + + assert res[0] == d[-1] + assert res[1] == d[-1] + assert res[2] == d[5] + # It should be the dace max for integer + assert res[3] == np.finfo(np.float64).min + + # Minimum is in the beginning + for i in range(size): + d[i] = 10 - i + sdfg(d=d, res=res) + assert res[0] == d[0] + assert res[1] == d[0] + assert res[2] == d[2] + # It should be the dace max for integer + assert res[3] == np.finfo(np.float64).min + +def test_fortran_frontend_maxval_int(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM minval_test + implicit none + integer, dimension(7) :: d + integer, dimension(4) :: res + CALL minval_test_function(d, res) + end + + SUBROUTINE minval_test_function(d, res) + integer, dimension(7) :: d + integer, dimension(0) :: dt + integer, dimension(4) :: res + + res(1) = MAXVAL(d) + res(2) = MAXVAL(d(:)) + res(3) = MAXVAL(d(3:6)) + res(4) = MAXVAL(dt) + + END SUBROUTINE minval_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + sdfg = fortran_parser.create_sdfg_from_string(test_string, "minval_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + size = 7 + + # Minimum is in the beginning + d = np.full([size], 0, order="F", dtype=np.int32) + for i in range(size): + d[i] = i + 1 + res = np.full([4], 42, order="F", dtype=np.int32) + sdfg(d=d, res=res) + + assert res[0] == d[-1] + assert res[1] == d[-1] + assert res[2] == d[5] + # It should be the dace max for integer + assert res[3] == np.iinfo(np.int32).min + + # Minimum is in the beginning + for i in range(size): + d[i] = 10 - i + sdfg(d=d, res=res) + assert res[0] == d[0] + assert res[1] == d[0] + assert res[2] == d[2] + # It should be the dace max for integer + assert res[3] == np.iinfo(np.int32).min + + # Minimum is in the middle + d = np.full([size], 0, order="F", dtype=np.int32) + d[:] = [41, 10, 42, -5, 32, 41, 40] + res = np.full([4], 42, order="F", dtype=np.int32) + sdfg(d=d, res=res) + + assert res[0] == d[2] + assert res[1] == d[2] + assert res[2] == d[2] + # It should be the dace max for integer + assert res[3] == np.iinfo(np.int32).min + +if __name__ == "__main__": + + test_fortran_frontend_minval_double() + test_fortran_frontend_minval_int() + test_fortran_frontend_maxval_double() + test_fortran_frontend_maxval_int() diff --git a/tests/fortran/intrinsic_product.py b/tests/fortran/intrinsic_product.py new file mode 100644 index 0000000000..fcf9dc8057 --- /dev/null +++ b/tests/fortran/intrinsic_product.py @@ -0,0 +1,116 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import numpy as np +import pytest + +from dace.frontend.fortran import ast_transforms, fortran_parser + +def test_fortran_frontend_product_array(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(7) :: d + double precision, dimension(3) :: res + CALL index_test_function(d, res) + end + + SUBROUTINE index_test_function(d, res) + double precision, dimension(7) :: d + double precision, dimension(3) :: res + + res(1) = PRODUCT(d) + res(2) = PRODUCT(d(:)) + res(3) = PRODUCT(d(2:5)) + + END SUBROUTINE index_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 7 + d = np.full([size], 0, order="F", dtype=np.float64) + for i in range(size): + d[i] = i + 1 + res = np.full([3], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + assert res[0] == np.prod(d) + assert res[1] == np.prod(d) + assert res[2] == np.prod(d[1:5]) + +def test_fortran_frontend_product_array_dim(): + test_string = """ + PROGRAM intrinsic_count_test + implicit none + logical, dimension(5) :: d + logical, dimension(2) :: res + CALL intrinsic_count_test_function(d, res) + end + + SUBROUTINE intrinsic_count_test_function(d, res) + logical, dimension(5) :: d + logical, dimension(2) :: res + + res(1) = PRODUCT(d, 1) + + END SUBROUTINE intrinsic_count_test_function + """ + + with pytest.raises(NotImplementedError): + fortran_parser.create_sdfg_from_string(test_string, "intrinsic_count_test", False) + +def test_fortran_frontend_product_2d(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(5,3) :: d + double precision, dimension(4) :: res + CALL index_test_function(d,res) + end + + SUBROUTINE index_test_function(d, res) + double precision, dimension(5,3) :: d + double precision, dimension(4) :: res + + res(1) = PRODUCT(d) + res(2) = PRODUCT(d(:,:)) + res(3) = PRODUCT(d(2:4, 2)) + res(4) = PRODUCT(d(2:4, 2:3)) + + END SUBROUTINE index_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + sizes = [5, 3] + d = np.full(sizes, 42, order="F", dtype=np.float64) + cnt = 1 + for i in range(sizes[0]): + for j in range(sizes[1]): + d[i, j] = cnt + cnt += 1 + res = np.full([4], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + assert res[0] == np.prod(d) + assert res[1] == np.prod(d) + assert res[2] == np.prod(d[1:4, 1]) + assert res[3] == np.prod(d[1:4, 1:3]) + +if __name__ == "__main__": + + test_fortran_frontend_product_array() + test_fortran_frontend_product_array_dim() + test_fortran_frontend_product_2d() diff --git a/tests/fortran/intrinsic_sum.py b/tests/fortran/intrinsic_sum.py new file mode 100644 index 0000000000..e933589e0f --- /dev/null +++ b/tests/fortran/intrinsic_sum.py @@ -0,0 +1,176 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import numpy as np + +from dace.frontend.fortran import ast_transforms, fortran_parser + +def test_fortran_frontend_sum2loop_1d_without_offset(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(7) :: d + double precision, dimension(3) :: res + CALL index_test_function(d, res) + end + + SUBROUTINE index_test_function(d, res) + double precision, dimension(7) :: d + double precision, dimension(3) :: res + + res(1) = SUM(d(:)) + res(2) = SUM(d) + res(3) = SUM(d(2:6)) + + END SUBROUTINE index_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 7 + d = np.full([size], 0, order="F", dtype=np.float64) + for i in range(size): + d[i] = i + 1 + res = np.full([3], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + assert res[0] == (1 + size) * size / 2 + assert res[1] == (1 + size) * size / 2 + assert res[2] == (2 + size - 1) * (size - 2)/ 2 + +def test_fortran_frontend_sum2loop_1d_offset(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(2:6) :: d + double precision, dimension(3) :: res + CALL index_test_function(d,res) + end + + SUBROUTINE index_test_function(d, res) + double precision, dimension(2:6) :: d + double precision, dimension(3) :: res + + res(1) = SUM(d) + res(2) = SUM(d(:)) + res(3) = SUM(d(3:5)) + + END SUBROUTINE index_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 5 + d = np.full([size], 0, order="F", dtype=np.float64) + for i in range(size): + d[i] = i + 1 + res = np.full([3], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + assert res[0] == (1 + size) * size / 2 + assert res[1] == (1 + size) * size / 2 + assert res[2] == (2 + size - 1) * (size - 2) / 2 + +def test_fortran_frontend_arr2loop_2d(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(5,3) :: d + double precision, dimension(4) :: res + CALL index_test_function(d,res) + end + + SUBROUTINE index_test_function(d, res) + double precision, dimension(5,3) :: d + double precision, dimension(4) :: res + + res(1) = SUM(d) + res(2) = SUM(d(:,:)) + res(3) = SUM(d(2:4, 2)) + res(4) = SUM(d(2:4, 2:3)) + + END SUBROUTINE index_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + sizes = [5, 3] + d = np.full(sizes, 42, order="F", dtype=np.float64) + cnt = 0 + for i in range(sizes[0]): + for j in range(sizes[1]): + d[i, j] = cnt + cnt += 1 + res = np.full([4], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + assert res[0] == 105 + assert res[1] == 105 + assert res[2] == 21 + assert res[3] == 45 + +def test_fortran_frontend_arr2loop_2d_offset(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(2:6,7:10) :: d + double precision, dimension(3) :: res + CALL index_test_function(d,res) + end + + SUBROUTINE index_test_function(d, res) + double precision, dimension(2:6,7:10) :: d + double precision, dimension(3) :: res + + res(1) = SUM(d) + res(2) = SUM(d(:,:)) + res(3) = SUM(d(3:5, 8:9)) + + END SUBROUTINE index_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + sizes = [5, 4] + d = np.full(sizes, 42, order="F", dtype=np.float64) + cnt = 0 + for i in range(sizes[0]): + for j in range(sizes[1]): + d[i, j] = cnt + cnt += 1 + res = np.full([3], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + assert res[0] == 190 + assert res[1] == 190 + assert res[2] == 57 + +if __name__ == "__main__": + + test_fortran_frontend_sum2loop_1d_without_offset() + test_fortran_frontend_sum2loop_1d_offset() + test_fortran_frontend_arr2loop_2d() + test_fortran_frontend_arr2loop_2d_offset() diff --git a/tests/fortran/sum_to_loop_offset.py b/tests/fortran/sum_to_loop_offset.py new file mode 100644 index 0000000000..e933589e0f --- /dev/null +++ b/tests/fortran/sum_to_loop_offset.py @@ -0,0 +1,176 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import numpy as np + +from dace.frontend.fortran import ast_transforms, fortran_parser + +def test_fortran_frontend_sum2loop_1d_without_offset(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(7) :: d + double precision, dimension(3) :: res + CALL index_test_function(d, res) + end + + SUBROUTINE index_test_function(d, res) + double precision, dimension(7) :: d + double precision, dimension(3) :: res + + res(1) = SUM(d(:)) + res(2) = SUM(d) + res(3) = SUM(d(2:6)) + + END SUBROUTINE index_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 7 + d = np.full([size], 0, order="F", dtype=np.float64) + for i in range(size): + d[i] = i + 1 + res = np.full([3], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + assert res[0] == (1 + size) * size / 2 + assert res[1] == (1 + size) * size / 2 + assert res[2] == (2 + size - 1) * (size - 2)/ 2 + +def test_fortran_frontend_sum2loop_1d_offset(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(2:6) :: d + double precision, dimension(3) :: res + CALL index_test_function(d,res) + end + + SUBROUTINE index_test_function(d, res) + double precision, dimension(2:6) :: d + double precision, dimension(3) :: res + + res(1) = SUM(d) + res(2) = SUM(d(:)) + res(3) = SUM(d(3:5)) + + END SUBROUTINE index_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 5 + d = np.full([size], 0, order="F", dtype=np.float64) + for i in range(size): + d[i] = i + 1 + res = np.full([3], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + assert res[0] == (1 + size) * size / 2 + assert res[1] == (1 + size) * size / 2 + assert res[2] == (2 + size - 1) * (size - 2) / 2 + +def test_fortran_frontend_arr2loop_2d(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(5,3) :: d + double precision, dimension(4) :: res + CALL index_test_function(d,res) + end + + SUBROUTINE index_test_function(d, res) + double precision, dimension(5,3) :: d + double precision, dimension(4) :: res + + res(1) = SUM(d) + res(2) = SUM(d(:,:)) + res(3) = SUM(d(2:4, 2)) + res(4) = SUM(d(2:4, 2:3)) + + END SUBROUTINE index_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + sizes = [5, 3] + d = np.full(sizes, 42, order="F", dtype=np.float64) + cnt = 0 + for i in range(sizes[0]): + for j in range(sizes[1]): + d[i, j] = cnt + cnt += 1 + res = np.full([4], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + assert res[0] == 105 + assert res[1] == 105 + assert res[2] == 21 + assert res[3] == 45 + +def test_fortran_frontend_arr2loop_2d_offset(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(2:6,7:10) :: d + double precision, dimension(3) :: res + CALL index_test_function(d,res) + end + + SUBROUTINE index_test_function(d, res) + double precision, dimension(2:6,7:10) :: d + double precision, dimension(3) :: res + + res(1) = SUM(d) + res(2) = SUM(d(:,:)) + res(3) = SUM(d(3:5, 8:9)) + + END SUBROUTINE index_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + sizes = [5, 4] + d = np.full(sizes, 42, order="F", dtype=np.float64) + cnt = 0 + for i in range(sizes[0]): + for j in range(sizes[1]): + d[i, j] = cnt + cnt += 1 + res = np.full([3], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + assert res[0] == 190 + assert res[1] == 190 + assert res[2] == 57 + +if __name__ == "__main__": + + test_fortran_frontend_sum2loop_1d_without_offset() + test_fortran_frontend_sum2loop_1d_offset() + test_fortran_frontend_arr2loop_2d() + test_fortran_frontend_arr2loop_2d_offset() diff --git a/tests/rtl/hardware_test.py b/tests/rtl/hardware_test.py index 821688f481..727dc7362b 100644 --- a/tests/rtl/hardware_test.py +++ b/tests/rtl/hardware_test.py @@ -1,4 +1,7 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" + Test suite for testing RTL integration with DaCe targeting Xilinx FPGAs. +""" import dace from dace.fpga_testing import rtl_test import numpy as np @@ -13,7 +16,7 @@ def make_vadd_sdfg(N: dace.symbol, veclen: int = 8): ''' Function for generating a simple vector addition SDFG that adds a vector `A` of `N` elements to a scalar `B` into a vector `C` of `N` elements, all using SystemVerilog. - The tasklet creates `veclen` instances of a floating point adder that operates on `N` elements. + The tasklet creates `veclen` instances of a floating point adder that operates on `N` elements. :param N: The number of elements the SDFG takes as input and output. :param veclen: The number of floating point adders to instantiate. @@ -197,7 +200,7 @@ def make_vadd_multi_sdfg(N, M): :param N: The number of elements to compute on. :param M: The number of compute PEs to initialize. - :return: An SDFG that has arguments `A` and `B`. + :return: An SDFG that has arguments `A` and `B`. ''' # add sdfg sdfg = dace.SDFG(f'integer_vector_plus_42_multiple_kernels_{N.get() // M.get()}') @@ -321,7 +324,7 @@ def make_vadd_multi_sdfg(N, M): @rtl_test() def test_hardware_vadd(): ''' - Test for the simple vector addition. + Test for the simple vector addition. ''' # add symbol @@ -346,7 +349,7 @@ def test_hardware_vadd(): @rtl_test() def test_hardware_add42_single(): ''' - Test for adding a constant using a single PE. + Test for adding a constant using a single PE. ''' N = dace.symbol('N') M = dace.symbol('M') @@ -428,10 +431,11 @@ def test_hardware_vadd_temporal_vectorization(): ''' Tests whether the multi-pumping optimization can be applied automatically by applying the temporal vectorization transformation. It starts from a numpy vector addition for generating the SDFG. This SDFG is then optimized by applying the vectorization, streaming memory, fpga and temporal vectorization transformations in that order. ''' - # TODO !!!!! THIS TEST STALLS IN HARDWARE EMULATION WITH VITIS 2021.2 !!!!! - # But it works fine for 2020.2 and 2022.2. It seems like everything but the - # last transaction correctly goes through just fine. The last transaction - # is never output by the floating point adder, but the inputs are consumed. + # TODO !!!!! THIS TEST STALLS IN HARDWARE EMULATION WITH VITIS 2021.2 and 2022.1 !!!!! + # But it works fine for 2020.2, 2022.2, and 2023.1. It seems like + # everything but the last transaction correctly goes through just fine. The + # last transaction is never output by the floating point adder, but the + # inputs are consumed. with dace.config.set_temporary('compiler', 'xilinx', 'frequency', value='"0:300\\|1:600"'): # Generate the test data and expected results size_n = 1024 diff --git a/tests/rtl/simulation_test.py b/tests/rtl/simulation_test.py index f20ff6133a..6b7ac2cd15 100644 --- a/tests/rtl/simulation_test.py +++ b/tests/rtl/simulation_test.py @@ -1,5 +1,7 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. - +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" + Test suite for testing RTL tasklets in DaCe with Verilator as a backend for simulation. +""" import dace import numpy as np import pytest diff --git a/tests/sdfg/data/tensor_test.py b/tests/sdfg/data/tensor_test.py new file mode 100644 index 0000000000..3057539f70 --- /dev/null +++ b/tests/sdfg/data/tensor_test.py @@ -0,0 +1,129 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import dace +import numpy as np +import pytest + +from scipy import sparse + + +def test_read_csr_tensor(): + + M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) + csr_obj = dace.data.Tensor( + dace.float32, + (M, N), + [(dace.data.TensorIndexDense(), 0), (dace.data.TensorIndexCompressed(), 1)], + nnz, + "CSR_Tensor") + + sdfg = dace.SDFG('tensor_csr_to_dense') + + sdfg.add_datadesc('A', csr_obj) + sdfg.add_array('B', [M, N], dace.float32) + + sdfg.add_view('vindptr', csr_obj.members['idx1_pos'].shape, csr_obj.members['idx1_pos'].dtype) + sdfg.add_view('vindices', csr_obj.members['idx1_crd'].shape, csr_obj.members['idx1_crd'].dtype) + sdfg.add_view('vdata', csr_obj.members['values'].shape, csr_obj.members['values'].dtype) + + state = sdfg.add_state() + + A = state.add_access('A') + B = state.add_access('B') + + indptr = state.add_access('vindptr') + indices = state.add_access('vindices') + data = state.add_access('vdata') + + state.add_edge(A, None, indptr, 'views', dace.Memlet.from_array('A.idx1_pos', csr_obj.members['idx1_pos'])) + state.add_edge(A, None, indices, 'views', dace.Memlet.from_array('A.idx1_crd', csr_obj.members['idx1_crd'])) + state.add_edge(A, None, data, 'views', dace.Memlet.from_array('A.values', csr_obj.members['values'])) + + ime, imx = state.add_map('i', dict(i='0:M')) + jme, jmx = state.add_map('idx', dict(idx='start:stop')) + jme.add_in_connector('start') + jme.add_in_connector('stop') + t = state.add_tasklet('indirection', {'j', '__val'}, {'__out'}, '__out[i, j] = __val') + + state.add_memlet_path(indptr, ime, jme, memlet=dace.Memlet(data='vindptr', subset='i'), dst_conn='start') + state.add_memlet_path(indptr, ime, jme, memlet=dace.Memlet(data='vindptr', subset='i+1'), dst_conn='stop') + state.add_memlet_path(indices, ime, jme, t, memlet=dace.Memlet(data='vindices', subset='idx'), dst_conn='j') + state.add_memlet_path(data, ime, jme, t, memlet=dace.Memlet(data='vdata', subset='idx'), dst_conn='__val') + state.add_memlet_path(t, jmx, imx, B, memlet=dace.Memlet(data='B', subset='0:M, 0:N', volume=1), src_conn='__out') + + func = sdfg.compile() + + rng = np.random.default_rng(42) + A = sparse.random(20, 20, density=0.1, format='csr', dtype=np.float32, random_state=rng) + B = np.zeros((20, 20), dtype=np.float32) + + inpA = csr_obj.dtype._typeclass.as_ctypes()(idx1_pos=A.indptr.__array_interface__['data'][0], + idx1_crd=A.indices.__array_interface__['data'][0], + values=A.data.__array_interface__['data'][0]) + + func(A=inpA, B=B, M=A.shape[0], N=A.shape[1], nnz=A.nnz) + ref = A.toarray() + + assert np.allclose(B, ref) + + +def test_csr_fields(): + + M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) + + csr = dace.data.Tensor( + dace.float32, + (M, N), + [(dace.data.TensorIndexDense(), 0), (dace.data.TensorIndexCompressed(), 1)], + nnz, + "CSR_Matrix", + ) + + expected_fields = ["idx1_pos", "idx1_crd"] + assert all(key in csr.members.keys() for key in expected_fields) + + +def test_dia_fields(): + + M, N, nnz, num_diags = (dace.symbol(s) for s in ('M', 'N', 'nnz', 'num_diags')) + + diag = dace.data.Tensor( + dace.float32, + (M, N), + [ + (dace.data.TensorIndexDense(), num_diags), + (dace.data.TensorIndexRange(), 0), + (dace.data.TensorIndexOffset(), 1), + ], + nnz, + "DIA_Matrix", + ) + + expected_fields = ["idx1_offset", "idx2_offset"] + assert all(key in diag.members.keys() for key in expected_fields) + + +def test_coo_fields(): + + I, J, K, nnz = (dace.symbol(s) for s in ('I', 'J', 'K', 'nnz')) + + coo = dace.data.Tensor( + dace.float32, + (I, J, K), + [ + (dace.data.TensorIndexCompressed(unique=False), 0), + (dace.data.TensorIndexSingleton(unique=False), 1), + (dace.data.TensorIndexSingleton(), 2), + ], + nnz, + "COO_3D_Tensor", + ) + + expected_fields = ["idx0_pos", "idx0_crd", "idx1_crd", "idx2_crd"] + assert all(key in coo.members.keys() for key in expected_fields) + + +if __name__ == "__main__": + test_read_csr_tensor() + test_csr_fields() + test_dia_fields() + test_coo_fields() diff --git a/tests/sdfg/nested_control_flow_regions_test.py b/tests/sdfg/nested_control_flow_regions_test.py new file mode 100644 index 0000000000..f29c093dad --- /dev/null +++ b/tests/sdfg/nested_control_flow_regions_test.py @@ -0,0 +1,18 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +import dace + + +def test_is_start_state_deprecation(): + sdfg = dace.SDFG('deprecation_test') + with pytest.deprecated_call(): + sdfg.add_state('state1', is_start_state=True) + sdfg2 = dace.SDFG('deprecation_test2') + state = dace.SDFGState('state2') + with pytest.deprecated_call(): + sdfg2.add_node(state, is_start_state=True) + + +if __name__ == '__main__': + test_is_start_state_deprecation() diff --git a/tests/sdfg_validate_names_test.py b/tests/sdfg_validate_names_test.py index dad79c8950..1650a4e4b1 100644 --- a/tests/sdfg_validate_names_test.py +++ b/tests/sdfg_validate_names_test.py @@ -28,7 +28,7 @@ def test_state_duplication(self): sdfg = dace.SDFG('ok') s1 = sdfg.add_state('also_ok') s2 = sdfg.add_state('also_ok') - s2.set_label('also_ok') + s2.label = 'also_ok' sdfg.add_edge(s1, s2, dace.InterstateEdge()) sdfg.validate() self.fail('Failed to detect duplicate state') diff --git a/tests/subset_covers_precise_test.py b/tests/subset_covers_precise_test.py new file mode 100644 index 0000000000..8c688ea6c1 --- /dev/null +++ b/tests/subset_covers_precise_test.py @@ -0,0 +1,212 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import pytest + +import dace +from dace.config import Config +from dace.subsets import Indices, Range + + +def test_integer_overlap_same_step_no_cover(): + """ + Tests ranges with overlapping bounding boxes neither of them covering the other. + The ranges have the same step size. Covers_precise should return false. + """ + subset1 = Range.from_string("0:10:1") + subset2 = Range.from_string("5:11:1") + + assert (subset1.covers_precise(subset2) is False) + assert (subset2.covers_precise(subset1) is False) + + subset1 = Range.from_string("0:10:2") + subset2 = Range.from_string("2:11:1") + assert (subset1.covers_precise(subset2) is False) + assert (subset2.covers_precise(subset1) is False) + + +def test_integer_bounding_box_cover_coprime_step(): + """ + Tests ranges where the boundingbox of subset1 covers the boundingbox of subset2 but + step sizes of the subsets are coprime so subset1 does not cover subset2. + """ + subset1 = Range.from_string("0:10:3") + subset2 = Range.from_string("0:10:2") + + assert (subset1.covers_precise(subset2) is False) + assert (subset2.covers_precise(subset1) is False) + + subset1 = Range.from_string("0:10:3, 5:10:2") + subset2 = Range.from_string("0:10:2, 5:10:4") + assert (subset1.covers_precise(subset2) is False) + assert (subset2.covers_precise(subset1) is False) + + subset1 = Range.from_string("0:10:3, 6:10:2") + subset2 = Range.from_string("0:10:2, 5:10:4") + assert (subset1.covers_precise(subset2) is False) + assert (subset2.covers_precise(subset1) is False) + + +def test_integer_same_step_different_start(): + """ + Tests range where the bounding box of subset1 covers the bounding box of subset2 + but since subset2 starts at an offset that is not a multiple subset1's stepsize it + is not contained in subset1. + """ + subset1 = Range.from_string("0:10:3") + subset2 = Range.from_string("1:10:3") + + assert (subset1.covers_precise(subset2) is False) + + +def test_integer_bounding_box_symbolic_step(): + """ + Tests ranges where the step is symbolic but the start and end are not. + For 2 subsets s1 and s2 where s1's start is equal to s2's start and both subsets' step + sizes are symbolic s1.covers_precise(s2) should only return true iff s2's step size is + a multiple of s1's step size. + For 2 subsets s1 and s2 where s1's start is not equal to s2's start and both subsets' step + sizes are symbolic, s1.covers_precise(s2) should return false. + """ + subset1 = Range.from_string("0:20:s") + subset2 = Range.from_string("0:10:s") + subset3 = Range.from_string("0:10:2 * s") + + assert (subset1.covers_precise(subset2)) + assert (subset1.covers_precise(subset3)) + assert (subset3.covers_precise(subset1) is False) + assert (subset3.covers_precise(subset2) is False) + + subset1 = Range.from_string("30:50:k") + subset2 = Range.from_string("40:50:k") + assert (subset1.covers_precise(subset2) is False) + + +def test_ranges_symbolic_boundaries(): + """ + Tests where the boundaries of ranges are symbolic. + The function subset1.covers_precise(subset2) should return true only when the + start, end, and step size of subset1 are multiples of those in subset2 + """ + subset1 = Range.from_string("N:M:1") + subset2 = Range.from_string("N:M:2") + assert (subset1.covers_precise(subset2)) + assert (subset2.covers_precise(subset1) is False) + + subset1 = Range.from_string("N + 1:M:1") + subset2 = Range.from_string("N:M:2") + assert (subset1.covers_precise(subset2) is False) + assert (subset2.covers_precise(subset1) is False) + + subset1 = Range.from_string("-N:M:1") + subset2 = Range.from_string("N:M:2") + assert (subset1.covers_precise(subset2) is False) + assert (subset2.covers_precise(subset1) is False) + + +def test_symbolic_boundaries_not_symbolic_positive(): + """ + Tests from test_symbolic_boundaries with symbolic_positive flag deactivated. + """ + symbolic_positive = Config.get('optimizer', 'symbolic_positive') + Config.set('optimizer', 'symbolic_positive', value=False) + + subset1 = Range.from_string("N:M:1") + subset2 = Range.from_string("N:M:2") + assert (subset1.covers_precise(subset2)) + assert (subset2.covers_precise(subset1) is False) + + subset1 = Range.from_string("N + 1:M:1") + subset2 = Range.from_string("N:M:2") + assert (subset1.covers_precise(subset2) is False) + assert (subset2.covers_precise(subset1) is False) + + subset1 = Range.from_string("-N:M:1") + subset2 = Range.from_string("N:M:2") + assert (subset1.covers_precise(subset2) is False) + assert (subset2.covers_precise(subset1) is False) + + Config.set('optimizer', 'symbolic_positive', value=symbolic_positive) + + +def test_range_indices(): + """ + Tests the handling of indices covering ranges and vice versa. + Given a range r and indices i: + If r's bounding box covers i r.covers_precise(i) should return true iff + i is covered by the step of r. + i.covers_precise(r) should only return true iff r.start == r.end == i. + If i is not in r's bounding box i.covers_precise(r) and r.covers_precise(i) + should return false + """ + subset1 = Indices.from_string('1') + subset2 = Range.from_string('0:2:1') + assert (subset2.covers_precise(subset1)) + assert (subset1.covers_precise(subset2) is False) + + subset1 = Indices.from_string('3') + subset2 = Range.from_string('0:4:2') + assert (subset2.covers_precise(subset1) is False) + assert (subset2.covers_precise(subset1) is False) + + subset1 = Indices.from_string('3') + subset2 = Range.from_string('0:2:1') + assert (subset2.covers_precise(subset1) is False) + assert (subset1.covers_precise(subset2) is False) + +def test_index_index(): + """ + Tests the handling of indices covering indices. + Given two indices i1 and i2 i1.covers_precise should only return true iff i1 = i2 + """ + subset1 = Indices.from_string('1') + subset2 = Indices.from_string('1') + assert (subset2.covers_precise(subset1)) + assert (subset1.covers_precise(subset2)) + + subset1 = Indices.from_string('1') + subset2 = Indices.from_string('2') + assert (subset2.covers_precise(subset1) is False) + assert (subset1.covers_precise(subset2) is False) + + subset1 = Indices.from_string('1, 2') + subset2 = Indices.from_string('1, 2') + assert (subset2.covers_precise(subset1)) + assert (subset1.covers_precise(subset2)) + + subset1 = Indices.from_string('2, 1') + subset2 = Indices.from_string('1, 2') + assert (subset2.covers_precise(subset1) is False) + assert (subset1.covers_precise(subset2) is False) + + subset1 = Indices.from_string('i') + subset2 = Indices.from_string('j') + assert (subset2.covers_precise(subset1) is False) + assert (subset1.covers_precise(subset2) is False) + + subset1 = Indices.from_string('i') + subset2 = Indices.from_string('i') + assert (subset2.covers_precise(subset1)) + assert (subset1.covers_precise(subset2)) + + subset1 = Indices.from_string('i, j') + subset2 = Indices.from_string('i, k') + assert (subset2.covers_precise(subset1) is False) + assert (subset1.covers_precise(subset2) is False) + + subset1 = Indices.from_string('i, j') + subset2 = Indices.from_string('i, j') + assert (subset2.covers_precise(subset1)) + assert (subset1.covers_precise(subset2)) + + + + +if __name__ == "__main__": + test_integer_overlap_same_step_no_cover() + test_integer_bounding_box_cover_coprime_step() + test_integer_same_step_different_start() + test_integer_bounding_box_symbolic_step() + test_ranges_symbolic_boundaries() + test_symbolic_boundaries_not_symbolic_positive() + test_range_indices() + test_index_index() diff --git a/tests/transformations/change_strides_test.py b/tests/transformations/change_strides_test.py new file mode 100644 index 0000000000..3975761fd5 --- /dev/null +++ b/tests/transformations/change_strides_test.py @@ -0,0 +1,48 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import dace +from dace import nodes +from dace.dtypes import ScheduleType +from dace.memlet import Memlet +from dace.transformation.change_strides import change_strides + + +def change_strides_test(): + sdfg = dace.SDFG('change_strides_test') + N = dace.symbol('N') + M = dace.symbol('M') + sdfg.add_array('A', [N, M], dace.float64) + sdfg.add_array('B', [N, M, 3], dace.float64) + state = sdfg.add_state() + + task1, mentry1, mexit1 = state.add_mapped_tasklet( + name="map1", + map_ranges={'i': '0:N', 'j': '0:M'}, + inputs={'a': Memlet(data='A', subset='i, j')}, + outputs={'b': Memlet(data='B', subset='i, j, 0')}, + code='b = a + 1', + external_edges=True, + propagate=True) + + # Check that states are as expected + changed_sdfg = change_strides(sdfg, ['N'], ScheduleType.Sequential) + assert len(changed_sdfg.states()) == 3 + assert len(changed_sdfg.out_edges(changed_sdfg.start_state)) == 1 + work_state = changed_sdfg.out_edges(changed_sdfg.start_state)[0].dst + nsdfg = None + for node in work_state.nodes(): + if isinstance(node, nodes.NestedSDFG): + nsdfg = node + # Check shape and strides of data inside nested SDFG + assert nsdfg is not None + assert nsdfg.sdfg.data('A').shape == (N, M) + assert nsdfg.sdfg.data('B').shape == (N, M, 3) + assert nsdfg.sdfg.data('A').strides == (1, N) + assert nsdfg.sdfg.data('B').strides == (1, N, M*N) + + +def main(): + change_strides_test() + + +if __name__ == '__main__': + main() diff --git a/tests/transformations/move_assignment_outside_if_test.py b/tests/transformations/move_assignment_outside_if_test.py new file mode 100644 index 0000000000..323e83cf61 --- /dev/null +++ b/tests/transformations/move_assignment_outside_if_test.py @@ -0,0 +1,161 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import dace +from dace.transformation.interstate import MoveAssignmentOutsideIf +from dace.sdfg import InterstateEdge +from dace.memlet import Memlet +from dace.sdfg.nodes import Tasklet + + +def one_variable_simple_test(const_value: int = 0): + """ Test with one variable which has formula and const branch. Uses the given const value """ + sdfg = dace.SDFG('one_variable_simple_test') + # Create guard state and one state where A is set to 0 and another where it is set using B and some formula + guard = sdfg.add_state('guard', is_start_state=True) + formula_state = sdfg.add_state('formula', is_start_state=False) + const_state = sdfg.add_state('const', is_start_state=False) + sdfg.add_array('A', [1], dace.float64) + sdfg.add_array('B', [1], dace.float64) + + # Add tasklet inside states + formula_tasklet = formula_state.add_tasklet('formula_assign', {'b'}, {'a'}, 'a = 2*b') + formula_state.add_memlet_path(formula_state.add_read('B'), formula_tasklet, memlet=Memlet(data='B', subset='0'), + dst_conn='b') + formula_state.add_memlet_path(formula_tasklet, formula_state.add_write('A'), memlet=Memlet(data='A', subset='0'), + src_conn='a') + const_tasklet = const_state.add_tasklet('const_assign', {}, {'a'}, f"a = {const_value}") + const_state.add_memlet_path(const_tasklet, const_state.add_write('A'), memlet=Memlet(data='A', subset='0'), + src_conn='a') + + # Create if-else condition such that either the formula state or the const state is executed + sdfg.add_edge(guard, formula_state, InterstateEdge(condition='B[0] < 0.5')) + sdfg.add_edge(guard, const_state, InterstateEdge(condition='B[0] >= 0.5')) + sdfg.validate() + + # Assure transformation is applied + assert sdfg.apply_transformations_repeated([MoveAssignmentOutsideIf]) == 1 + # SDFG now starts with a state containing the const_tasklet + assert const_tasklet in sdfg.start_state.nodes() + # The formula state has only one in_edge with the condition + assert len(sdfg.in_edges(formula_state)) == 1 + assert sdfg.in_edges(formula_state)[0].data.condition.as_string == '(B[0] < 0.5)' + # All state have at most one out_edge -> there is no if-else branching anymore + for state in sdfg.states(): + assert len(sdfg.out_edges(state)) <= 1 + + +def multiple_variable_test(): + """ Test with multiple variables where not all appear in the const branch """ + sdfg = dace.SDFG('one_variable_simple_test') + # Create guard state and one state where A is set to 0 and another where it is set using B and some formula + guard = sdfg.add_state('guard', is_start_state=True) + formula_state = sdfg.add_state('formula', is_start_state=False) + const_state = sdfg.add_state('const', is_start_state=False) + sdfg.add_array('A', [1], dace.float64) + sdfg.add_array('B', [1], dace.float64) + sdfg.add_array('C', [1], dace.float64) + sdfg.add_array('D', [1], dace.float64) + + A = formula_state.add_access('A') + B = formula_state.add_access('B') + C = formula_state.add_access('C') + D = formula_state.add_access('D') + formula_tasklet_a = formula_state.add_tasklet('formula_assign', {'b'}, {'a'}, 'a = 2*b') + formula_state.add_memlet_path(B, formula_tasklet_a, memlet=Memlet(data='B', subset='0'), dst_conn='b') + formula_state.add_memlet_path(formula_tasklet_a, A, memlet=Memlet(data='A', subset='0'), src_conn='a') + formula_tasklet_b = formula_state.add_tasklet('formula_assign', {'c'}, {'b'}, 'a = 2*c') + formula_state.add_memlet_path(C, formula_tasklet_b, memlet=Memlet(data='C', subset='0'), dst_conn='c') + formula_state.add_memlet_path(formula_tasklet_b, B, memlet=Memlet(data='B', subset='0'), src_conn='b') + formula_tasklet_c = formula_state.add_tasklet('formula_assign', {'d'}, {'c'}, 'a = 2*d') + formula_state.add_memlet_path(D, formula_tasklet_c, memlet=Memlet(data='D', subset='0'), dst_conn='d') + formula_state.add_memlet_path(formula_tasklet_c, C, memlet=Memlet(data='C', subset='0'), src_conn='c') + + const_tasklet_a = const_state.add_tasklet('const_assign', {}, {'a'}, 'a = 0') + const_state.add_memlet_path(const_tasklet_a, const_state.add_write('A'), memlet=Memlet(data='A', subset='0'), + src_conn='a') + const_tasklet_b = const_state.add_tasklet('const_assign', {}, {'b'}, 'b = 0') + const_state.add_memlet_path(const_tasklet_b, const_state.add_write('B'), memlet=Memlet(data='B', subset='0'), + src_conn='b') + + # Create if-else condition such that either the formula state or the const state is executed + sdfg.add_edge(guard, formula_state, InterstateEdge(condition='D[0] < 0.5')) + sdfg.add_edge(guard, const_state, InterstateEdge(condition='D[0] >= 0.5')) + sdfg.validate() + + # Assure transformation is applied + assert sdfg.apply_transformations_repeated([MoveAssignmentOutsideIf]) == 1 + # There are no other tasklets in the start state beside the const assignment tasklet as there are no other const + # assignments + for node in sdfg.start_state.nodes(): + if isinstance(node, Tasklet): + assert node == const_tasklet_a or node == const_tasklet_b + # The formula state has only one in_edge with the condition + assert len(sdfg.in_edges(formula_state)) == 1 + assert sdfg.in_edges(formula_state)[0].data.condition.as_string == '(D[0] < 0.5)' + # All state have at most one out_edge -> there is no if-else branching anymore + for state in sdfg.states(): + assert len(sdfg.out_edges(state)) <= 1 + + +def multiple_variable_not_all_const_test(): + """ Test with multiple variables where not all get const-assigned in const branch """ + sdfg = dace.SDFG('one_variable_simple_test') + # Create guard state and one state where A is set to 0 and another where it is set using B and some formula + guard = sdfg.add_state('guard', is_start_state=True) + formula_state = sdfg.add_state('formula', is_start_state=False) + const_state = sdfg.add_state('const', is_start_state=False) + sdfg.add_array('A', [1], dace.float64) + sdfg.add_array('B', [1], dace.float64) + sdfg.add_array('C', [1], dace.float64) + + A = formula_state.add_access('A') + B = formula_state.add_access('B') + C = formula_state.add_access('C') + formula_tasklet_a = formula_state.add_tasklet('formula_assign', {'b'}, {'a'}, 'a = 2*b') + formula_state.add_memlet_path(B, formula_tasklet_a, memlet=Memlet(data='B', subset='0'), dst_conn='b') + formula_state.add_memlet_path(formula_tasklet_a, A, memlet=Memlet(data='A', subset='0'), src_conn='a') + formula_tasklet_b = formula_state.add_tasklet('formula_assign', {'c'}, {'b'}, 'a = 2*c') + formula_state.add_memlet_path(C, formula_tasklet_b, memlet=Memlet(data='C', subset='0'), dst_conn='c') + formula_state.add_memlet_path(formula_tasklet_b, B, memlet=Memlet(data='B', subset='0'), src_conn='b') + + const_tasklet_a = const_state.add_tasklet('const_assign', {}, {'a'}, 'a = 0') + const_state.add_memlet_path(const_tasklet_a, const_state.add_write('A'), memlet=Memlet(data='A', subset='0'), + src_conn='a') + const_tasklet_b = const_state.add_tasklet('const_assign', {'c'}, {'b'}, 'b = 1.5 * c') + const_state.add_memlet_path(const_state.add_read('C'), const_tasklet_b, memlet=Memlet(data='C', subset='0'), + dst_conn='c') + const_state.add_memlet_path(const_tasklet_b, const_state.add_write('B'), memlet=Memlet(data='B', subset='0'), + src_conn='b') + + # Create if-else condition such that either the formula state or the const state is executed + sdfg.add_edge(guard, formula_state, InterstateEdge(condition='C[0] < 0.5')) + sdfg.add_edge(guard, const_state, InterstateEdge(condition='C[0] >= 0.5')) + sdfg.validate() + + # Assure transformation is applied + assert sdfg.apply_transformations_repeated([MoveAssignmentOutsideIf]) == 1 + # There are no other tasklets in the start state beside the const assignment tasklet as there are no other const + # assignments + for node in sdfg.start_state.nodes(): + if isinstance(node, Tasklet): + assert node == const_tasklet_a or node == const_tasklet_b + # The formula state has only one in_edge with the condition + assert len(sdfg.in_edges(formula_state)) == 1 + assert sdfg.in_edges(formula_state)[0].data.condition.as_string == '(C[0] < 0.5)' + # Guard still has two outgoing edges as if-else pattern still exists + assert len(sdfg.out_edges(guard)) == 2 + # const state now has only const_tasklet_b left plus two access nodes + assert len(const_state.nodes()) == 3 + for node in const_state.nodes(): + if isinstance(node, Tasklet): + assert node == const_tasklet_b + + +def main(): + one_variable_simple_test(0) + one_variable_simple_test(2) + multiple_variable_test() + multiple_variable_not_all_const_test() + + +if __name__ == '__main__': + main() diff --git a/tests/transformations/otf_map_fusion_test.py b/tests/transformations/otf_map_fusion_test.py index eb871566d1..4786901887 100644 --- a/tests/transformations/otf_map_fusion_test.py +++ b/tests/transformations/otf_map_fusion_test.py @@ -330,6 +330,36 @@ def test_trivial_fusion_nested_sdfg(): assert (res == res_fused).all() +@dace.program +def trivial_fusion_none_connectors(B: dace.float64[10, 20]): + tmp = dace.define_local([10, 20], dtype=B.dtype) + for i, j in dace.map[0:10, 0:20]: + with dace.tasklet: + b >> tmp[i, j] + b = 0 + + for i, j in dace.map[0:10, 0:20]: + with dace.tasklet: + a << tmp[i, j] + b >> B[i, j] + b = a + 2 + + +def test_trivial_fusion_none_connectors(): + sdfg = trivial_fusion_none_connectors.to_sdfg() + sdfg.simplify() + assert count_maps(sdfg) == 2 + + sdfg.apply_transformations(OTFMapFusion) + assert count_maps(sdfg) == 1 + + B = np.zeros((10, 20)) + ref = np.zeros((10, 20)) + 2 + + sdfg(B=B) + assert np.allclose(B, ref) + + @dace.program def undefined_subset(A: dace.float64[10], B: dace.float64[10]): tmp = dace.define_local([10], dtype=A.dtype) @@ -703,6 +733,7 @@ def test_hdiff(): test_trivial_fusion_permute() test_trivial_fusion_not_remove_map() test_trivial_fusion_nested_sdfg() + test_trivial_fusion_none_connectors() # Defined subsets test_undefined_subset() diff --git a/tests/transformations/wcr_conversion_test.py b/tests/transformations/wcr_conversion_test.py new file mode 100644 index 0000000000..091b2a9db8 --- /dev/null +++ b/tests/transformations/wcr_conversion_test.py @@ -0,0 +1,247 @@ +import dace + +from dace.transformation.dataflow import AugAssignToWCR + + +def test_aug_assign_tasklet_lhs(): + + @dace.program + def sdfg_aug_assign_tasklet_lhs(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet: + a << A[i] + k << B[i] + b >> A[i] + b = a + k + + sdfg = sdfg_aug_assign_tasklet_lhs.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_lhs_brackets(): + + @dace.program + def sdfg_aug_assign_tasklet_lhs_brackets(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet: + a << A[i] + k << B[i] + b >> A[i] + b = a + (k + 1) + + sdfg = sdfg_aug_assign_tasklet_lhs_brackets.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_rhs(): + + @dace.program + def sdfg_aug_assign_tasklet_rhs(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet: + a << A[i] + k << B[i] + b >> A[i] + b = k + a + + sdfg = sdfg_aug_assign_tasklet_rhs.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_rhs_brackets(): + + @dace.program + def sdfg_aug_assign_tasklet_rhs_brackets(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet: + a << A[i] + k << B[i] + b >> A[i] + b = (k + 1) + a + + sdfg = sdfg_aug_assign_tasklet_rhs_brackets.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_lhs_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_lhs_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + k << B[i] + b >> A[i] + """ + b = a + k; + """ + + sdfg = sdfg_aug_assign_tasklet_lhs_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_lhs_brackets_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_lhs_brackets_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + k << B[i] + b >> A[i] + """ + b = a + (k + 1); + """ + + sdfg = sdfg_aug_assign_tasklet_lhs_brackets_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_rhs_brackets_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_rhs_brackets_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + k << B[i] + b >> A[i] + """ + b = (k + 1) + a; + """ + + sdfg = sdfg_aug_assign_tasklet_rhs_brackets_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_func_lhs_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_func_lhs_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + c << B[i] + b >> A[i] + """ + b = min(a, c); + """ + + sdfg = sdfg_aug_assign_tasklet_func_lhs_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_func_rhs_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_func_rhs_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + c << B[i] + b >> A[i] + """ + b = min(c, a); + """ + + sdfg = sdfg_aug_assign_tasklet_func_rhs_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_free_map(): + + @dace.program + def sdfg_aug_assign_free_map(A: dace.float64[32], B: dace.float64[32]): + for i in dace.map[0:32]: + with dace.tasklet(language=dace.Language.CPP): + a << A[0] + k << B[i] + b >> A[0] + """ + b = k * a; + """ + + sdfg = sdfg_aug_assign_free_map.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_state_fission_map(): + + @dace.program + def sdfg_aug_assign_state_fission(A: dace.float64[32], B: dace.float64[32]): + for i in dace.map[0:32]: + with dace.tasklet: + a << B[i] + b >> A[i] + b = a + + for i in dace.map[0:32]: + with dace.tasklet: + a << A[0] + b >> A[0] + b = a * 2 + + for i in dace.map[0:32]: + with dace.tasklet: + a << A[0] + b >> A[0] + b = a * 2 + + sdfg = sdfg_aug_assign_state_fission.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 2 + + +def test_free_map_permissive(): + + @dace.program + def sdfg_free_map_permissive(A: dace.float64[32], B: dace.float64[32]): + for i in dace.map[0:32]: + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + k << B[i] + b >> A[i] + """ + b = k * a; + """ + + sdfg = sdfg_free_map_permissive.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR, permissive=False) + assert applied == 0 + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR, permissive=True) + assert applied == 1 diff --git a/tests/trivial_map_elimination_test.py b/tests/trivial_map_elimination_test.py index 44b1f77652..9600dad640 100644 --- a/tests/trivial_map_elimination_test.py +++ b/tests/trivial_map_elimination_test.py @@ -25,7 +25,69 @@ def trivial_map_sdfg(): return sdfg +def trivial_map_init_sdfg(): + sdfg = dace.SDFG('trivial_map_range_expanded') + sdfg.add_array('B', [5, 1], dace.float64) + state = sdfg.add_state() + + # Nodes + map_entry_outer, map_exit_outer = state.add_map('map_outer', dict(j='0:5')) + map_entry_inner, map_exit_inner = state.add_map('map_inner', dict(i='0:1')) + + tasklet = state.add_tasklet('tasklet', {}, {'b'}, 'b = 1') + write = state.add_write('B') + + # Edges + state.add_memlet_path(map_entry_outer, map_entry_inner, memlet=dace.Memlet()) + state.add_memlet_path(map_entry_inner, tasklet, memlet=dace.Memlet()) + + state.add_memlet_path(tasklet, map_exit_inner, memlet=dace.Memlet.simple('B', 'j, i'), src_conn='b', + dst_conn='IN_B') + state.add_memlet_path(map_exit_inner, map_exit_outer, memlet=dace.Memlet.simple('B', 'j, 0'), src_conn='OUT_B', + dst_conn='IN_B') + state.add_memlet_path(map_exit_outer, write, memlet=dace.Memlet.simple('B', '0:5, 0'), + src_conn='OUT_B') + + sdfg.validate() + return sdfg + + +def trivial_map_pseudo_init_sdfg(): + sdfg = dace.SDFG('trivial_map_range_expanded') + sdfg.add_array('A', [5, 1], dace.float64) + sdfg.add_array('B', [5, 1], dace.float64) + state = sdfg.add_state() + + # Nodes + map_entry_outer, map_exit_outer = state.add_map('map_outer', dict(j='0:5')) + map_entry_inner, map_exit_inner = state.add_map('map_inner', dict(i='0:1')) + + read = state.add_read('A') + tasklet = state.add_tasklet('tasklet', {'a'}, {'b'}, 'b = a') + write = state.add_write('B') + + # Edges + state.add_memlet_path(map_entry_outer, map_entry_inner, memlet=dace.Memlet()) + state.add_memlet_path(read, map_entry_outer, map_entry_inner, memlet=dace.Memlet.simple('A', '0:5, 0'), + dst_conn='IN_A') + state.add_memlet_path(map_entry_inner, tasklet, memlet=dace.Memlet()) + state.add_memlet_path(map_entry_inner, tasklet, memlet=dace.Memlet.simple('A', 'j, 0'), src_conn='OUT_A', dst_conn='a') + + state.add_memlet_path(tasklet, map_exit_inner, memlet=dace.Memlet.simple('B', 'j, i'), src_conn='b', + dst_conn='IN_B') + state.add_memlet_path(map_exit_inner, map_exit_outer, memlet=dace.Memlet.simple('B', 'j, 0'), src_conn='OUT_B', + dst_conn='IN_B') + state.add_memlet_path(map_exit_outer, write, memlet=dace.Memlet.simple('B', '0:5, 0'), + src_conn='OUT_B') + + sdfg.validate() + return sdfg + + class TrivialMapEliminationTest(unittest.TestCase): + """ + Tests the case where the map has an empty input edge + """ def test_can_be_applied(self): graph = trivial_map_sdfg() @@ -56,5 +118,75 @@ def test_raplaces_map_params_in_scope(self): self.assertEqual(out_memlet.data.subset, dace.subsets.Range([(0, 0, 1)])) +class TrivialMapInitEliminationTest(unittest.TestCase): + def test_can_be_applied(self): + graph = trivial_map_init_sdfg() + + count = graph.apply_transformations(TrivialMapElimination, validate=False, validate_all=False) + graph.validate() + + self.assertGreater(count, 0) + + def test_removes_map(self): + graph = trivial_map_init_sdfg() + + state = graph.nodes()[0] + map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)] + self.assertEqual(len(map_entries), 2) + + graph.apply_transformations(TrivialMapElimination) + + state = graph.nodes()[0] + map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)] + self.assertEqual(len(map_entries), 1) + + def test_reconnects_edges(self): + graph = trivial_map_init_sdfg() + + graph.apply_transformations(TrivialMapElimination) + state = graph.nodes()[0] + map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)] + self.assertEqual(len(map_entries), 1) + # Check that there is an outgoing edge from the map entry + self.assertEqual(len(state.out_edges(map_entries[0])), 1) + + +class TrivialMapPseudoInitEliminationTest(unittest.TestCase): + """ + Test cases where the map has an empty input and a non empty input + """ + def test_can_be_applied(self): + graph = trivial_map_pseudo_init_sdfg() + + count = graph.apply_transformations(TrivialMapElimination, validate=False, validate_all=False) + graph.validate() + graph.view() + + self.assertGreater(count, 0) + + def test_removes_map(self): + graph = trivial_map_pseudo_init_sdfg() + + state = graph.nodes()[0] + map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)] + self.assertEqual(len(map_entries), 2) + + graph.apply_transformations(TrivialMapElimination) + + state = graph.nodes()[0] + map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)] + self.assertEqual(len(map_entries), 1) + + def test_reconnects_edges(self): + graph = trivial_map_pseudo_init_sdfg() + + graph.apply_transformations(TrivialMapElimination) + state = graph.nodes()[0] + map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)] + self.assertEqual(len(map_entries), 1) + # Check that there is an outgoing edge from the map entry + self.assertEqual(len(state.out_edges(map_entries[0])), 1) + + if __name__ == '__main__': unittest.main()