diff --git a/dace/transformation/auto/auto_optimize.py b/dace/transformation/auto/auto_optimize.py index 54dbc8d4ac..48d7641418 100644 --- a/dace/transformation/auto/auto_optimize.py +++ b/dace/transformation/auto/auto_optimize.py @@ -519,7 +519,8 @@ def auto_optimize(sdfg: SDFG, device: dtypes.DeviceType, validate: bool = True, validate_all: bool = False, - symbols: Dict[str, int] = None) -> SDFG: + symbols: Dict[str, int] = None, + gpu_global: bool = False) -> SDFG: """ Runs a basic sequence of transformations to optimize a given SDFG to decent performance. In particular, performs the following: @@ -565,6 +566,12 @@ def auto_optimize(sdfg: SDFG, # Apply GPU transformations and set library node implementations if device == dtypes.DeviceType.GPU: + def gpu_storage(sdfg: dace.SDFG): + for _, desc in sdfg.arrays.items(): + if not desc.transient and isinstance(desc, dace.data.Array): + desc.storage = dace.StorageType.GPU_Global + if gpu_global: + gpu_storage(sdfg) sdfg.apply_gpu_transformations() sdfg.simplify()