Merge branch 'master' into refactor_sdfg_list_to_cfg_list
alexnick83 authored Feb 20, 2024
2 parents 9e65ef0 + c92ecc5 commit b1da057
Showing 154 changed files with 1,462 additions and 1,796 deletions.
6 changes: 4 additions & 2 deletions dace/codegen/CMakeLists.txt
@@ -265,6 +265,8 @@ endif()

# Create HIP object files
if(DACE_ENABLE_HIP)
enable_language(HIP)

# Get local AMD architectures
if (NOT DEFINED LOCAL_HIP_ARCHITECTURES)
# Compile and run a test program
@@ -304,8 +306,8 @@ if(DACE_ENABLE_HIP)
set(DACE_LIBS ${DACE_LIBS} hip::host)

set_source_files_properties(${DACE_HIP_FILES} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
hip_prepare_target_commands(${DACE_PROGRAM_NAME} OBJ DACE_HIP_OBJECTS DACE_HIP_SOURCES ${DACE_HIP_FILES})
set(DACE_OBJECTS ${DACE_OBJECTS} ${DACE_HIP_OBJECTS})
set_source_files_properties(${DACE_HIP_FILES} PROPERTIES LANGUAGE HIP)
set(DACE_OBJECTS ${DACE_OBJECTS} ${DACE_HIP_FILES})
endif() # DACE_ENABLE_HIP

# create verilator RTL simulation objects
55 changes: 18 additions & 37 deletions dace/codegen/compiled_sdfg.py
@@ -31,9 +31,8 @@ def __init__(self, library_filename, program_name):
:param program_name: Name of the DaCe program (for use in finding
the stub library loader).
"""
self._stub_filename = os.path.join(
os.path.dirname(os.path.realpath(library_filename)),
f'libdacestub_{program_name}.{Config.get("compiler", "library_extension")}')
self._stub_filename = os.path.join(os.path.dirname(os.path.realpath(library_filename)),
f'libdacestub_{program_name}.{Config.get("compiler", "library_extension")}')
self._library_filename = os.path.realpath(library_filename)
self._stub = None
self._lib = None
@@ -219,7 +218,6 @@ def __init__(self, sdfg, lib: ReloadableDLL, argnames: List[str] = None):
self.has_gpu_code = True
break


def get_exported_function(self, name: str, restype=None) -> Optional[Callable[..., Any]]:
"""
Tries to find a symbol by name in the compiled SDFG, and convert it to a callable function
@@ -233,7 +231,6 @@ def get_exported_function(self, name: str, restype=None) -> Optional[Callable[..., Any]]:
except KeyError: # Function not found
return None


def get_state_struct(self) -> ctypes.Structure:
""" Attempt to parse the SDFG source code and extract the state struct. This method will parse the first
consecutive entries in the struct that are pointers. As soon as a non-pointer or other unparseable field is
@@ -247,7 +244,6 @@ def get_state_struct(self) -> ctypes.Structure:

return ctypes.cast(self._libhandle, ctypes.POINTER(self._try_parse_state_struct())).contents


def _try_parse_state_struct(self) -> Optional[Type[ctypes.Structure]]:
from dace.codegen.targets.cpp import mangle_dace_state_struct_name # Avoid import cycle
# the path of the main sdfg file containing the state struct
@@ -375,7 +371,6 @@ def _get_error_text(self, result: Union[str, int]) -> str:
else:
return result


def __call__(self, *args, **kwargs):
"""
Forwards the Python call to the compiled ``SDFG``.
@@ -400,13 +395,12 @@ def __call__(self, *args, **kwargs):
elif len(args) > 0 and self.argnames is not None:
kwargs.update(
# `_construct_args` will handle all of its arguments as kwargs.
{aname: arg for aname, arg in zip(self.argnames, args)}
)
argtuple, initargtuple = self._construct_args(kwargs) # Missing arguments will be detected here.
# Return values are cached in `self._lastargs`.
{aname: arg
for aname, arg in zip(self.argnames, args)})
argtuple, initargtuple = self._construct_args(kwargs) # Missing arguments will be detected here.
# Return values are cached in `self._lastargs`.
return self.fast_call(argtuple, initargtuple, do_gpu_check=True)


def fast_call(
self,
callargs: Tuple[Any, ...],
@@ -455,15 +449,13 @@ def fast_call(
self._lib.unload()
raise


def __del__(self):
if self._initialized is True:
self.finalize()
self._initialized = False
self._libhandle = ctypes.c_void_p(0)
self._lib.unload()


def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
"""
Main function that controls argument construction for calling
@@ -486,7 +478,7 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
typedict = self._typedict
if len(kwargs) > 0:
# Construct mapping from arguments to signature
arglist = []
arglist = []
argtypes = []
argnames = []
for a in sig:
@@ -536,10 +528,9 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
'you are doing, you can override this error in the '
'configuration by setting compiler.allow_view_arguments '
'to True.')
elif (not isinstance(atype, (dt.Array, dt.Structure)) and
not isinstance(atype.dtype, dtypes.callback) and
not isinstance(arg, (atype.dtype.type, sp.Basic)) and
not (isinstance(arg, symbolic.symbol) and arg.dtype == atype.dtype)):
elif (not isinstance(atype, (dt.Array, dt.Structure)) and not isinstance(atype.dtype, dtypes.callback)
and not isinstance(arg, (atype.dtype.type, sp.Basic))
and not (isinstance(arg, symbolic.symbol) and arg.dtype == atype.dtype)):
is_int = isinstance(arg, int)
if is_int and atype.dtype.type == np.int64:
pass
@@ -573,29 +564,23 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
# Retain only the element datatype for upcoming checks and casts
arg_ctypes = tuple(at.dtype.as_ctypes() for at in argtypes)

constants = self.sdfg.constants
callparams = tuple(
(actype(arg.get())
if isinstance(arg, symbolic.symbol)
else arg, actype, atype, aname
)
for arg, actype, atype, aname in zip(arglist, arg_ctypes, argtypes, argnames)
if not (symbolic.issymbolic(arg) and (hasattr(arg, 'name') and arg.name in constants))
)
constants = self.sdfg.constants
callparams = tuple((arg, actype, atype, aname)
for arg, actype, atype, aname in zip(arglist, arg_ctypes, argtypes, argnames)
if not (symbolic.issymbolic(arg) and (hasattr(arg, 'name') and arg.name in constants)))

symbols = self._free_symbols
initargs = tuple(
actype(arg) if not isinstance(arg, ctypes._SimpleCData) else arg
for arg, actype, atype, aname in callparams
if aname in symbols
)
actype(arg) if not isinstance(arg, ctypes._SimpleCData) else arg for arg, actype, atype, aname in callparams
if aname in symbols)

try:
# Replace arrays with their base host/device pointers
newargs = [None] * len(callparams)
for i, (arg, actype, atype, _) in enumerate(callparams):
if dtypes.is_array(arg):
newargs[i] = ctypes.c_void_p(_array_interface_ptr(arg, atype.storage)) # `c_void_p` is subclass of `ctypes._SimpleCData`.
newargs[i] = ctypes.c_void_p(_array_interface_ptr(
arg, atype.storage)) # `c_void_p` is subclass of `ctypes._SimpleCData`.
elif not isinstance(arg, (ctypes._SimpleCData)):
newargs[i] = actype(arg)
else:
@@ -607,11 +592,9 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
self._lastargs = newargs, initargs
return self._lastargs


def clear_return_values(self):
self._create_new_arrays = True


def _create_array(self, _: str, dtype: np.dtype, storage: dtypes.StorageType, shape: Tuple[int],
strides: Tuple[int], total_size: int):
ndarray = np.ndarray
@@ -636,7 +619,6 @@ def ndarray(*args, buffer=None, **kwargs):
# Create an array with the properties of the SDFG array
return ndarray(shape, dtype, buffer=zeros(total_size, dtype), strides=strides)


def _initialize_return_values(self, kwargs):
# Obtain symbol values from arguments and constants
syms = dict()
@@ -687,7 +669,6 @@ def _initialize_return_values(self, kwargs):
arr = self._create_array(*shape_desc)
self._return_arrays.append(arr)


def _convert_return_values(self):
# Return the values as they would be from a Python function
if self._return_arrays is None or len(self._return_arrays) == 0:
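
Note on the compiled_sdfg.py changes above: __call__ zips positional arguments with argnames into keyword arguments, _construct_args validates them and packs them into the ctypes tuples cached in _lastargs, and fast_call performs the actual library invocation. A minimal sketch of that flow from the caller's side (the program and argument names are hypothetical, not part of this diff):

    import numpy as np
    import dace

    @dace.program
    def add_one(A: dace.float64[20]):
        A += 1

    # SDFG.compile() returns a CompiledSDFG as defined in this file
    csdfg = add_one.to_sdfg().compile()
    A = np.zeros(20)

    # Regular call: arguments are checked and packed by _construct_args,
    # and the resulting ctypes tuples are cached in _lastargs
    csdfg(A=A)

    # With unchanged pointers and symbol values, the cached tuples can be
    # replayed through fast_call, skipping argument reconstruction
    callargs, initargs = csdfg._lastargs
    csdfg.fast_call(callargs, initargs, do_gpu_check=False)
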
64 changes: 32 additions & 32 deletions dace/data.py
@@ -37,6 +37,38 @@ def create_datadescriptor(obj, no_custom_desc=False):
return obj.__descriptor__()
elif not no_custom_desc and hasattr(obj, 'descriptor'):
return obj.descriptor
elif type(obj).__module__ == "torch" and type(obj).__name__ == "Tensor":
# special case for torch tensors. Maybe __array__ could be used here for a more
# general solution, but torch doesn't support __array__ for cuda tensors.
try:
# If torch is importable, define translations between typeclasses and torch types. These are reused by daceml.
# conversion happens here in pytorch:
# https://github.com/pytorch/pytorch/blob/143ef016ee1b6a39cf69140230d7c371de421186/torch/csrc/utils/tensor_numpy.cpp#L237
import torch
TYPECLASS_TO_TORCH_DTYPE = {
dtypes.bool_: torch.bool,
dtypes.int8: torch.int8,
dtypes.int16: torch.int16,
dtypes.int32: torch.int32,
dtypes.int64: torch.int64,
dtypes.uint8: torch.uint8,
dtypes.float16: torch.float16,
dtypes.float32: torch.float32,
dtypes.float64: torch.float64,
dtypes.complex64: torch.complex64,
dtypes.complex128: torch.complex128,
}

TORCH_DTYPE_TO_TYPECLASS = {v: k for k, v in TYPECLASS_TO_TORCH_DTYPE.items()}

storage = dtypes.StorageType.GPU_Global if obj.device.type == 'cuda' else dtypes.StorageType.Default

return Array(dtype=TORCH_DTYPE_TO_TYPECLASS[obj.dtype],
strides=obj.stride(),
shape=tuple(obj.shape),
storage=storage)
except ImportError:
raise ValueError("Attempted to convert a torch.Tensor, but torch could not be imported")
elif dtypes.is_array(obj) and (hasattr(obj, '__array_interface__') or hasattr(obj, '__cuda_array_interface__')):
if dtypes.is_gpu_array(obj):
interface = obj.__cuda_array_interface__
@@ -79,38 +111,6 @@ def create_datadescriptor(obj, no_custom_desc=False):
dtype = dtypes.typeclass(obj.dtype.type)
itemsize = obj.itemsize
return Array(dtype=dtype, shape=obj.shape, strides=tuple(s // itemsize for s in obj.strides), storage=storage)
elif type(obj).__module__ == "torch" and type(obj).__name__ == "Tensor":
# special case for torch tensors. Maybe __array__ could be used here for a more
# general solution, but torch doesn't support __array__ for cuda tensors.
try:
# If torch is importable, define translations between typeclasses and torch types. These are reused by daceml.
# conversion happens here in pytorch:
# https://github.com/pytorch/pytorch/blob/143ef016ee1b6a39cf69140230d7c371de421186/torch/csrc/utils/tensor_numpy.cpp#L237
import torch
TYPECLASS_TO_TORCH_DTYPE = {
dtypes.bool_: torch.bool,
dtypes.int8: torch.int8,
dtypes.int16: torch.int16,
dtypes.int32: torch.int32,
dtypes.int64: torch.int64,
dtypes.uint8: torch.uint8,
dtypes.float16: torch.float16,
dtypes.float32: torch.float32,
dtypes.float64: torch.float64,
dtypes.complex64: torch.complex64,
dtypes.complex128: torch.complex128,
}

TORCH_DTYPE_TO_TYPECLASS = {v: k for k, v in TYPECLASS_TO_TORCH_DTYPE.items()}

storage = dtypes.StorageType.GPU_Global if obj.device.type == 'cuda' else dtypes.StorageType.Default

return Array(dtype=TORCH_DTYPE_TO_TYPECLASS[obj.dtype],
strides=obj.stride(),
shape=tuple(obj.shape),
storage=storage)
except ImportError:
raise ValueError("Attempted to convert a torch.Tensor, but torch could not be imported")
elif symbolic.issymbolic(obj):
return Scalar(symbolic.symtype(obj))
elif isinstance(obj, dtypes.typeclass):
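
Note on the data.py change above: the torch special case now precedes the generic __array_interface__ branch (CUDA tensors do not support __array__), translating a tensor's dtype, shape, strides, and device into a DaCe Array descriptor. A small illustration of what the branch produces (assumes torch is installed; not part of this diff):

    import torch
    from dace.data import create_datadescriptor

    t = torch.zeros(4, 8, dtype=torch.float32)
    desc = create_datadescriptor(t)

    print(type(desc).__name__)  # Array
    print(desc.shape)           # (4, 8), from tuple(t.shape)
    print(desc.strides)         # (8, 1), from t.stride()
    # For a tensor on device 'cuda', storage would be
    # dtypes.StorageType.GPU_Global instead of Default
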
7 changes: 2 additions & 5 deletions dace/frontend/python/wrappers.py
@@ -12,12 +12,9 @@


def ndarray(shape, dtype=numpy.float64, *args, **kwargs):
""" Returns a numpy ndarray where all symbols have been evaluated to
numbers and types are converted to numpy types. """
repldict = {sym: sym.get() for sym in symbolic.symlist(shape).values()}
new_shape = [int(s.subs(repldict) if symbolic.issymbolic(s) else s) for s in shape]
""" Returns a numpy ndarray where all types are converted to numpy types. """
new_dtype = dtype.type if isinstance(dtype, dtypes.typeclass) else dtype
return numpy.ndarray(shape=new_shape, dtype=new_dtype, *args, **kwargs)
return numpy.ndarray(shape=shape, dtype=new_dtype, *args, **kwargs)


stream: Type[Deque[T]] = deque
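
Note on the wrappers.py change above: ndarray no longer evaluates symbolic shapes, consistent with the removal of the symbol .get() accessor elsewhere in this merge, so callers must pass concrete dimensions; only the typeclass-to-numpy dtype conversion remains. For example (a sketch, not part of this diff):

    import dace
    from dace.frontend.python.wrappers import ndarray

    # dace.float32 is a dtypes.typeclass, so its .type (numpy.float32) is used
    a = ndarray((10, 20), dtype=dace.float32)
    print(a.shape, a.dtype)  # (10, 20) float32
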
21 changes: 6 additions & 15 deletions dace/jupyter.py
@@ -8,7 +8,7 @@

def _connected():
try:
urllib.request.urlopen('https://spcl.github.io/dace/webclient2/dist/sdfv.js', timeout=1)
urllib.request.urlopen('https://spcl.github.io/dace-webclient/dist/sdfv.js', timeout=1)
return True
except urllib.error.URLError:
return False
@@ -31,31 +31,22 @@ def isnotebook():
def preamble():
# Emit javascript headers for SDFG renderer
sdfv_js_deps = ['sdfv.js']
sdfv_css_deps = ['sdfv.css']
offline_sdfv_js_deps = ['sdfv_jupyter.js']

result = ''

# Rely on internet connection for Material icons
result += '<link href="https://fonts.googleapis.com/icon?family=Material+Icons" rel="stylesheet">'

# Try to load dependencies from online sources
if _connected():
for dep in sdfv_js_deps:
result += '<script src="https://spcl.github.io/dace/webclient2/dist/%s"></script>\n' % dep
for dep in sdfv_css_deps:
result += '<link href="https://spcl.github.io/dace/webclient2/%s" rel="stylesheet">\n' % dep
result += '<script src="https://spcl.github.io/dace-webclient/dist/%s"></script>\n' % dep
return result

# Load local dependencies
root_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'dace', 'viewer', 'webclient')
for dep in sdfv_js_deps:
root_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'viewer', 'webclient')
for dep in offline_sdfv_js_deps:
file = os.path.join(root_path, 'dist', dep)
with open(file, 'r') as fp:
with open(file) as fp:
result += '<script>%s</script>\n' % fp.read()
for dep in sdfv_css_deps:
file = os.path.join(root_path, dep)
with open(file, 'r') as fp:
result += '<style>%s</style>\n' % fp.read()

# Run this code once
return result
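
Note on the jupyter.py changes above: the renderer is now fetched from the dace-webclient repository, and the offline path inlines a single bundled sdfv_jupyter.js instead of separate JS and CSS files. A condensed sketch of the resulting control flow (the helper name and module_dir parameter are illustrative only):

    import os
    import urllib.error
    import urllib.request

    SDFV_URL = 'https://spcl.github.io/dace-webclient/dist/sdfv.js'

    def preamble_sketch(module_dir: str) -> str:
        # Material icons always come from the network
        html = '<link href="https://fonts.googleapis.com/icon?family=Material+Icons" rel="stylesheet">'
        try:
            # Connectivity probe, as in _connected()
            urllib.request.urlopen(SDFV_URL, timeout=1)
            return html + '<script src="%s"></script>\n' % SDFV_URL
        except urllib.error.URLError:
            # Offline: inline the local bundle under dace/viewer/webclient/dist
            path = os.path.join(module_dir, 'viewer', 'webclient', 'dist', 'sdfv_jupyter.js')
            with open(path) as fp:
                return html + '<script>%s</script>\n' % fp.read()
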
3 changes: 2 additions & 1 deletion dace/runtime/include/dace/cuda/halfvec.cuh
@@ -530,12 +530,13 @@ namespace dace { namespace math {
HALF_VEC_UFUNC(exp)
HALF_VEC_UFUNC(tanh)
} }
#endif

// Vector comparison functions
DACE_DFI half2 max(half2 a, half2 b) {
return make_half2(max(a.x, b.x), max(a.y, b.y));
}
#endif


DACE_DFI half4 max(half4 a, half b) {
half2 bvec = __half2half2(b);
14 changes: 5 additions & 9 deletions dace/sdfg/sdfg.py
@@ -1427,9 +1427,13 @@ def _repr_html_(self):

# Create renderer canvas and load SDFG
result += """
<div class="sdfv">
<div id="contents_{uid}" style="position: relative; resize: vertical; overflow: auto"></div>
</div>
<script>
var sdfg_{uid} = {sdfg};
</script>
<script>
var sdfv_{uid} = new SDFV();
var renderer_{uid} = new SDFGRenderer(sdfv_{uid}, parse_sdfg(sdfg_{uid}),
document.getElementById('contents_{uid}'));
@@ -2119,16 +2123,8 @@ def specialize(self, symbols: Dict[str, Any]):
:param symbols: Values to specialize.
"""
# Set symbol values to add
syms = {
# If symbols are passed, extract the value. If constants are
# passed, use them directly.
name: val.get() if isinstance(val, dace.symbolic.symbol) else val
for name, val in symbols.items()
}

# Update constants
for k, v in syms.items():
for k, v in symbols.items():
self.add_constant(str(k), v)

def is_loaded(self) -> bool:
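
Note on the sdfg.py change above: specialize no longer unwraps dace.symbolic.symbol values through the removed .get() accessor; the passed values are registered directly as constants. Typical usage (a minimal sketch, not part of this diff):

    import numpy as np
    import dace

    N = dace.symbol('N')

    @dace.program
    def scale(A: dace.float64[N]):
        A *= 2

    sdfg = scale.to_sdfg()
    sdfg.specialize({'N': 64})  # N becomes a compile-time constant

    A = np.ones(64)
    sdfg(A=A)  # N no longer needs to be passed at call time
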
(Diffs for the remaining 147 changed files are not shown.)
