diff --git a/python/src/ir.cc b/python/src/ir.cc index bb09dae7171a..1bd1d31a5b58 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -607,6 +607,7 @@ void init_triton_ir(py::module &&m) { "Function argument index out of range"); return self.getArgument(idx); }) + .def("get_num_args", &FuncOp::getNumArguments) .def( "add_entry_block", [](FuncOp &self) -> Block * { return self.addEntryBlock(); }, diff --git a/python/test/unit/language/test_tuple.py b/python/test/unit/language/test_tuple.py new file mode 100644 index 000000000000..be6da3be3b18 --- /dev/null +++ b/python/test/unit/language/test_tuple.py @@ -0,0 +1,98 @@ +import pytest +import triton +import triton.language as tl +import torch + + +@triton.jit +def _tuple_increment(values): + for i in tl.static_range(len(values)): + values[i] = values[i] + 1 + return values + + +@triton.jit +def _tuple_index_func(Ptrs, values): + for i in tl.static_range(len(values)): + tl.store(Ptrs[i], values[i]) + + +@triton.jit +def _tuple_index(_0, Ptrs, _1: tl.constexpr, values, _2, _3: tl.constexpr, _4): + values = _tuple_increment(values) + _tuple_index_func(Ptrs, values) + + +@pytest.mark.parametrize("size", [0, 1, 2, 3, 4]) +def test_index(size, device="cuda"): + vals = tuple([i + 1 for i in range(size)]) + rets = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in vals]) + _tuple_index[(1, )](0, rets, 0, vals, 0, 0, 0) + assert vals == tuple([x.item() - 1 for x in rets]) + + +# ---- + + +@triton.jit +def _tuple_assign(XPtrs, YPtrs, values): + # assign from tuple + X0, X1 = XPtrs + x0, x1 = values + tl.store(X0, x0) + tl.store(X1, x1) + # assign to tuple + Y0, Y1, Y2 = YPtrs + Y = Y0, Y1, Y2 + y = x0, 10, x1 + tl.store(Y[0], y[0]) + tl.store(Y[1], y[1]) + tl.store(Y[2], y[2]) + + +def test_assign(device="cuda"): + vals = (2., 3.) + x = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(2)]) + y = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(3)]) + _tuple_assign[(1, )](x, y, vals) + assert x[0] == vals[0] + assert x[1] == vals[1] + assert y[0] == vals[0] + assert y[1] == 10 + assert y[2] == vals[1] + +# ------- + +@triton.jit +def _tuple_fn0(Ptr, cst2: tl.constexpr, tuple1): + tl.store(Ptr + 5, cst2) + tl.store(Ptr + 6, tuple1[0]) + tl.store(Ptr + 7, tl.load(tuple1[1][0])) + tl.store(Ptr + 8, tuple1[1][1][0]) + tl.store(Ptr + 9, tl.load(tuple1[1][1][1])) + +# test serialization/deserialization of tuple arguments in +# the frontend. +@triton.jit +def _tuple_serdes(Ptr, tuple1, cst1: tl.constexpr, val1, tuple2): + tl.store(Ptr + 0, tl.load(tuple1[0])) + tl.store(Ptr + 1, tuple1[1][0]) + tl.store(Ptr + 2, tl.load(tuple1[1][1])) + tl.store(Ptr + 3, cst1 + val1) + tl.store(Ptr + 4, tl.load(tuple2[0])) + _tuple_fn0(Ptr, 15, (-1, tuple1)) + +def test_serdes(device="cuda"): + x0 = torch.tensor([8], dtype=torch.int32, device=device) + x1 = torch.tensor([12], dtype=torch.int32, device=device) + y0 = torch.tensor([10], dtype=torch.int32, device=device) + z = torch.empty((10,), dtype=torch.int32, device=device) + # we want to check that JIT specialization propagates to tuples: + _tuple_serdes[(1,)](z, (x0, (1, x1)), 20, 1, (y0,)) + print(z) + + +# function call (tuple argument) +# function call (tuple return value) +# __getitem__ and __setitem__ +# assignment (into a tuple, from a tuple) diff --git a/python/triton/backends/compiler.py b/python/triton/backends/compiler.py index 6d33dbd6fa9b..308e62ebf136 100644 --- a/python/triton/backends/compiler.py +++ b/python/triton/backends/compiler.py @@ -3,12 +3,28 @@ import hashlib import subprocess import sysconfig - from abc import ABCMeta, abstractmethod from dataclasses import dataclass from typing import Dict, List, Tuple, Union from types import ModuleType +def find_paths_if(iterable, pred): + is_iterable = lambda x: isinstance(x, (list, tuple)) + ret = [] + def _impl(current, path): + if pred(current): + if len(path) == 1: + ret.append((path[0],)) + else: + ret.append(tuple(path)) + elif is_iterable(current): + for idx, item in enumerate(current): + _impl(item, path + [idx]) + if is_iterable(iterable): + _impl(iterable, []) + else: + ret = [tuple()] if pred(iterable) else [] + return ret # Table that associates strings to AttrsDescriptor (sub)classes. # In this way we can dynamically select the correct class # constructor @@ -86,17 +102,22 @@ def _add_common_properties(self, params, values): assert (len(params) == len(values)) # Divisibility property - self.arg_properties["tt.divisibility"] = [ - param.num for param, arg in zip(params, values) if AttrsDescriptor.is_divisible_by_16(arg) - and not param.do_not_specialize and not param.do_not_specialize_on_alignment - ] + divisibility_16 = [] + for param, arg in zip(params, values): + if param.do_not_specialize or param.do_not_specialize_on_alignment: + continue + paths = find_paths_if(arg, AttrsDescriptor.is_divisible_by_16) + divisibility_16 += [(param.num,) + x for x in paths] + self.arg_properties["tt.divisibility"] = divisibility_16 # Equal to 1 property - self.arg_properties["tt.equal_to"] = [ - param.num - for param, arg in zip(params, values) - if AttrsDescriptor.is_equal_to_1(arg) and not param.do_not_specialize - ] + equal_to_1 = [] + for param, arg in zip(params, values): + if param.do_not_specialize: + continue + paths = find_paths_if(arg, AttrsDescriptor.is_equal_to_1) + equal_to_1 += [(param.num,) + x for x in paths] + self.arg_properties["tt.equal_to"] = equal_to_1 def _add_backend_properties(self, params=None, values=None): """ This method is for different subclasses to implement their own compile-time properties """ diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 85b15e8c1b24..b211eb769341 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -16,9 +16,12 @@ from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) from types import ModuleType from triton._utils import list_list_flatten, list_list_unflatten +from functools import reduce def mangle_ty(ty): + if ty.is_tuple(): + return 'T' + '_'.join(map(mangle_ty, ty.types)) + 'T' if ty.is_ptr(): return 'P' + mangle_ty(ty.element_ty) if ty.is_int(): @@ -57,8 +60,7 @@ def _is_triton_tensor(o: Any) -> bool: def _is_constexpr(o: Any) -> bool: - return isinstance(o, constexpr) - + return o is None or isinstance(o, (constexpr, language.core.dtype, int, bool)) def _is_triton_scalar(o: Any) -> bool: return _is_triton_tensor(o) and (not o.type.is_block() or o.type.numel == 1) @@ -192,7 +194,7 @@ def visit_Call(self, node: ast.Call) -> bool: class CodeGenerator(ast.NodeVisitor): - def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options, + def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, options, codegen_fns, module_map, module=None, is_kernel=False, function_types: Optional[Dict] = None, noinline=False, file_name: Optional[str] = None, begin_line=0): self.context = context @@ -224,8 +226,6 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n self.gscope[k] = v self.lscope = {} - self.attributes = attributes - self.constants = constants self.jit_fn = jit_fn self.function_name = function_name self.is_kernel = is_kernel @@ -355,7 +355,7 @@ def visit_Module(self, node): def visit_List(self, node): ctx = self.visit(node.ctx) assert ctx is None - elts = [self.visit(elt) for elt in node.elts] + elts = language.tuple([self.visit(elt) for elt in node.elts]) return elts # By design, only non-kernel functions can return @@ -364,16 +364,15 @@ def visit_Return(self, node): if ret_value is None: self.builder.ret([]) ret_ty = language.void - elif isinstance(ret_value, tuple): - ret_values = [language.semantic.to_tensor(v, self.builder) for v in ret_value] + elif isinstance(ret_value, language.tuple): + ret_values = [language.semantic.to_tensor(v, self.builder) for v in ret_value.values] ret_types = [v.type for v in ret_values] self.builder.ret([v.handle for v in ret_values]) - ret_ty = tuple(ret_types) + ret_ty = language.tuple_type(ret_types) else: ret = language.semantic.to_tensor(ret_value, self.builder) self.builder.ret([ret.handle]) ret_ty = ret.type - if self.ret_type is None: self.ret_type = ret_ty elif self.ret_type != ret_ty: @@ -398,7 +397,6 @@ def visit_FunctionDef(self, node): init_node = ast.Assign(targets=[st_target], value=default_value) else: init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) - try: assert not self.visiting_arg_default_value self.visiting_arg_default_value = True @@ -408,34 +406,16 @@ def visit_FunctionDef(self, node): # initialize function visibility = "public" if self.is_kernel else "private" + fn_ty = self.prototype.serialize(self.builder) self.fn = self.builder.get_or_insert_function(self.module, self.function_name, - self.prototype.to_ir(self.builder), visibility, self.noinline) + fn_ty, visibility, self.noinline) self.module.push_back(self.fn) entry = self.fn.add_entry_block() - arg_values = [] - idx = 0 - for i in range(len(arg_names)): - if i in self.constants: - cst = self.constants[i] - if not _is_constexpr(cst): - cst = constexpr(self.constants[i]) - arg_values.append(cst) - continue - else: - if i in self.attributes: - for name, value in self.attributes[i]: - self.fn.set_arg_attr(idx, name, value) - - # Mark this argument as a pass-by-value TMA descriptor (nvidia) - if isinstance(self.prototype.param_types[idx], nv_tma_desc_type): - self.fn.set_arg_attr(idx, "tt.nv_tma_desc", 1) - - arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx])) - idx += 1 - - insert_pt = self.builder.get_insertion_block() + arg_values = self.prototype.deserialize(self.fn) + # bind arguments to symbols for arg_name, arg_value in zip(arg_names, arg_values): self.set_value(arg_name, arg_value) + insert_pt = self.builder.get_insertion_block() self.builder.set_insertion_point_to_start(entry) # visit function body self.visit_compound_statement(node.body) @@ -446,8 +426,11 @@ def visit_FunctionDef(self, node): self.ret_type = language.void self.builder.ret([]) else: - self.prototype.ret_types = list(self.ret_type) if isinstance(self.ret_type, tuple) else [self.ret_type] - self.fn.reset_type(self.prototype.to_ir(self.builder)) + if isinstance(self.ret_type, language.tuple_type): + self.prototype.ret_types = self.ret_type.types + else: + self.prototype.ret_types = [self.ret_type] + self.fn.reset_type(self.prototype.serialize(self.builder)) self.builder.ret([ self.builder.create_poison(ty.to_ir(self.builder)) for ty in self.prototype.ret_types @@ -486,30 +469,32 @@ def visit_AnnAssign(self, node): # default: call visit_Assign return self.visit_Assign(node) + def assignTarget(self, target, value): + if isinstance(target, ast.Subscript): + assert target.ctx.__class__.__name__ == "Store" + return self.visit_Subscript_Store(target, value) + if isinstance(target, ast.Tuple): + assert target.ctx.__class__.__name__ == "Store" + for i, name in enumerate(target.elts): + self.set_value(self.visit(name), value.values[i]) + return + assert isinstance(target, ast.Name) + self.set_value(self.visit(target), value) + def visit_Assign(self, node): - _names = [] - if isinstance(node, ast.AnnAssign): - _names += [self.visit(node.target)] - else: - for target in node.targets: - _names += [self.visit(target)] - if len(_names) > 1: - raise self._unsupported(node, "simultaneous multiple assignment is not supported.") - names = _names[0] - values = self.visit(node.value) - if not _is_list_like(names): - names = [names] - if not _is_list_like(values): - values = [values] - native_nontensor_types = (language.dtype, ) - for name, value in zip(names, values): - # by default, constexpr are assigned into python variable + # construct values to assign + def _sanitize_value(value): + native_nontensor_types = (language.dtype, language.tuple) value = _unwrap_if_constexpr(value) if value is not None and \ - not _is_triton_value(value) and \ - not isinstance(value, native_nontensor_types): + not _is_triton_tensor(value) and \ + not isinstance(value, native_nontensor_types): value = language.semantic.to_tensor(value, self.builder) - self.set_value(name, value) + return value + + values = _sanitize_value(self.visit(node.value)) + assert len(node.targets) == 1 + self.assignTarget(node.targets[0], values) def visit_AugAssign(self, node): name = node.target.id @@ -532,7 +517,7 @@ def visit_Load(self, node): def visit_Tuple(self, node): args = [self.visit(x) for x in node.elts] - return tuple(args) + return language.tuple(args) def _apply_binary_method(self, method_name, lhs, rhs): # TODO: raise something meaningful if getattr fails below, esp for reverse method @@ -904,7 +889,7 @@ def visit_While(self, node): assert False, "Not implemented" ast.NodeVisitor.generic_visit(self, stmt) - def visit_Subscript(self, node): + def visit_Subscript_Load(self, node): assert node.ctx.__class__.__name__ == "Load" lhs = self.visit(node.value) slices = self.visit(node.slice) @@ -912,6 +897,16 @@ def visit_Subscript(self, node): return lhs.__getitem__(slices, _builder=self.builder) return lhs[slices] + def visit_Subscript_Store(self, node, value): + assert node.ctx.__class__.__name__ == "Store" + lhs = self.visit(node.value) + slices = self.visit(node.slice) + assert isinstance(lhs, language.tuple) + lhs.__setitem__(slices, value) + + def visit_Subscript(self, node): + return self.visit_Subscript_Load(node) + def visit_ExtSlice(self, node): return [self.visit(dim) for dim in node.dims] @@ -1068,7 +1063,7 @@ def visit_Slice(self, node): lower = self.visit(node.lower) upper = self.visit(node.upper) step = self.visit(node.step) - return slice(lower, upper, step) + return language.slice(lower, upper, step) def visit_Index(self, node): return self.visit(node.value) @@ -1084,23 +1079,18 @@ def visit_Assert(self, node) -> Any: def call_JitFunction(self, fn: JITFunction, args, kwargs): args = inspect.getcallargs(fn.fn, *args, **kwargs) args = [args[name] for name in fn.arg_names] - args = [arg if _is_triton_value(arg) else constexpr(arg) for arg in args] - # generate function def - attributes = {} - constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)] - constants = {i: args[i] for i in constexprs} - # generate call - args = [None if i in constexprs else arg for i, arg in enumerate(args)] - arg_vals = [arg.handle for arg in args if arg is not None] - arg_types = [arg.type for arg in args if arg is not None] - fn_name = mangle_fn(fn.__name__, arg_types, constants) + args_cst = find_paths_if(args, lambda _, x: _is_constexpr(x)) + args_val = find_paths_if(args, lambda _, x: not _is_constexpr(x)).values() + # mangle + fn_name = mangle_fn(fn.__name__, [arg.type for arg in args_val], args_cst) # generate function def if necessary if not self.module.has_function(fn_name): - prototype = language.function_type([], arg_types) gscope = fn.__globals__ # If the callee is not set, we use the same debug setting as the caller file_name, begin_line = get_jit_fn_file_line(fn) - generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, + arg_types = [language.core.constexpr if arg is None or isinstance(arg, (bool, int, language.core.dtype)) else arg.type for arg in args] + prototype = ASTFunction([], arg_types, args_cst, dict(), dict()) + generator = CodeGenerator(self.context, prototype, gscope, module=self.module, jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types, noinline=fn.noinline, file_name=file_name, begin_line=begin_line, options=self.builder.options, codegen_fns=self.builder.codegen_fns, @@ -1116,8 +1106,9 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): else: callee_ret_type = self.function_ret_types[fn_name] symbol = self.module.get_function(fn_name) - call_op = self.builder.call(symbol, arg_vals) - if call_op.get_num_results() == 0 or callee_ret_type is None: + args_val = [arg.handle for arg in args_val] + call_op = self.builder.call(symbol, args_val) + if callee_ret_type is None: return None elif call_op.get_num_results() == 1: return tensor(call_op.get_result(0), callee_ret_type) @@ -1125,8 +1116,8 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): # should return a tuple of tl.tensor results = [] for i in range(call_op.get_num_results()): - results.append(tensor(call_op.get_result(i), callee_ret_type[i])) - return tuple(results) + results.append(tensor(call_op.get_result(i), callee_ret_type.types[i])) + return language.tuple(results) def visit_Call(self, node): fn = _unwrap_if_constexpr(self.visit(node.func)) @@ -1145,7 +1136,11 @@ def visit_Call(self, node): if '_generator' in sig.parameters: extra_kwargs['_generator'] = self try: - return fn(*args, **extra_kwargs, **kws) + ret = fn(*args, **extra_kwargs, **kws) + # builtin functions return plain tuples for readability + if isinstance(ret, tuple): + ret = language.tuple(ret) + return ret except Exception as e: # Normally when we raise a CompilationError, we raise it as # `from None`, because the original fileline from the exception @@ -1304,30 +1299,96 @@ def kernel_suffix(signature, specialization): return suffix +def find_paths_if(iterable, pred): + is_iterable = lambda x: isinstance(x, (list, tuple, language.tuple, language.tuple_type)) + ret = dict() + def _impl(current, path): + path = (path[0], ) if len(path) == 1 else tuple(path) + if is_iterable(current): + for idx, item in enumerate(current): + _impl(item, path + (idx,)) + elif pred(path, current): + if len(path) == 1: + ret[(path[0],)] = current + else: + ret[tuple(path)] = current + if is_iterable(iterable): + _impl(iterable, []) + else: + ret = dict() + return ret + + +class ASTFunction: + + def get_path(self, x, path): + return reduce(lambda a, idx: a[idx], path, x) + + def set_path(self, x, path, val): + prev = x if len(path) == 1 else self.get_path(x, path[:-1]) + prev[path[-1]] = val + + def __init__(self, ret_types, arg_types, constexprs, constants, attrs): + self.ret_types = ret_types + self.arg_types = arg_types + self.constexprs = constexprs + self.constants = constants + self.attrs = attrs + + def serialize(self, builder: ir.builder): + # fill up IR values in template + # > build function + is_val = lambda path, _: path not in self.constexprs + val_paths = list(find_paths_if(self.arg_types, is_val).keys()) + arg_types = [self.get_path(self.arg_types, path).to_ir(builder) for path in val_paths] + ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types] + return builder.get_function_ty(arg_types, ret_types) + + def deserialize(self, fn): + # create "template" + def make_template(val): + if isinstance(val, (list, tuple, language.tuple_type)): + return language.tuple([make_template(x) for x in val]) + return language.constexpr(None) + vals = make_template(self.arg_types) + is_val = lambda path, _: path not in self.constexprs + val_paths = list(find_paths_if(self.arg_types, is_val).keys()) + # > set attributes + for attr_path, attr_specs in self.attrs.items(): + for attr_name, attr_val in attr_specs: + if attr_path in val_paths: + fn.set_arg_attr(val_paths.index(attr_path), attr_name, attr_val) + for i, path in enumerate(val_paths): + ty = self.get_path(self.arg_types, path) + if isinstance(ty, nv_tma_desc_type): + fn.set_arg_attr(i, "tt.nv_tma_desc", 1) + # > add IR values to the template + for i, path in enumerate(val_paths): + ty = self.get_path(self.arg_types, path) + self.set_path(vals, path, language.tensor(fn.args(i), ty)) + # > add constexpr values to the template + constants = self.constants | self.constexprs + for path, val in constants.items(): + self.set_path(vals, path, language.constexpr(val)) + return vals + + + def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map): + constexprs = specialization.constants + constants = specialization.attrs.get_constants() + arg_types = [str_to_ty(ty) for ty in specialization.signature.values()] + # find index of constants in serialized order attrs = specialization.attrs - # create kernel prototype - cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i - constants = {cst_key(key): value for key, value in specialization.constants.items()} - # visit kernel AST - gscope = fn.__globals__.copy() - function_name = fn.repr(specialization) - tys = list(specialization.signature.values()) - new_constants = attrs.get_constants() - for k in new_constants: - if k in tys and tys[k] == "i1" and new_constants[k] == 1: - new_constants[k] = True - new_attrs = attrs.filter_out_constants() fn_attrs = new_attrs.get_fn_attrs() - all_constants = constants.copy() - all_constants.update(new_constants) - arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants] + fn_attrs = {k: v for k, v in fn_attrs.items() if k not in constants} file_name, begin_line = get_jit_fn_file_line(fn) - - prototype = language.function_type([], arg_types) - generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, - jit_fn=fn, attributes=fn_attrs, is_kernel=True, file_name=file_name, + prototype = ASTFunction([], arg_types, constexprs, constants, fn_attrs) + generator = CodeGenerator(context, prototype, + gscope=fn.__globals__.copy(), + function_name=fn.repr(specialization), + jit_fn=fn, is_kernel=True, file_name=file_name, begin_line=begin_line, options=options, codegen_fns=codegen_fns, module_map=module_map) generator.visit(fn.parse()) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index a76cb132ce47..369bdff8c88e 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -66,10 +66,6 @@ def __init__(self, fn, signature, constants=None, attrs=None) -> None: raise TypeError("Signature keys must be string") if self.constants is None: self.constants = {} - else: - for k in self.constants.keys(): - if not isinstance(k, str): - raise TypeError("Constants keys must be string") if self.attrs is None: self.attrs = AttrsDescriptor() diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 737ff06e6aed..7db22fde3a97 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -69,7 +69,6 @@ float8e5, float8e5b16, full, - function_type, histogram, inline_asm_elementwise, int1, @@ -94,6 +93,7 @@ range, reduce, reshape, + slice, split, static_assert, static_print, @@ -101,6 +101,8 @@ store, tensor, trans, + tuple, + tuple_type, uint16, uint32, uint64, @@ -187,7 +189,6 @@ "floor", "fma", "full", - "function_type", "histogram", "inline_asm_elementwise", "interleave", @@ -230,6 +231,7 @@ "reduce", "reshape", "rsqrt", + "slice", "sigmoid", "sin", "softmax", @@ -246,6 +248,7 @@ "tensor", "trans", "triton", + "tuple", "uint16", "uint32", "uint64", @@ -261,6 +264,31 @@ ] +def parse_list_string(s): + s = s.strip() + if s.startswith('[') and s.endswith(']'): + s = s[1:-1] + result = [] + current = '' + depth = 0 + for c in s: + if c == '[': + depth += 1 + current += c + elif c == ']': + depth -= 1 + current += c + elif c == ',' and depth == 0: + result.append(current.strip()) + current = '' + else: + current += c + if current.strip(): + result.append(current.strip()) + return result + + + def str_to_ty(name): if name[0] == "*": name = name[1:] @@ -271,8 +299,16 @@ def str_to_ty(name): ty = str_to_ty(name) return pointer_type(element_ty=ty, const=const) + if name[0] == "[": + names = parse_list_string(name) + tys = [str_to_ty(x) for x in names] + return tuple_type(types=tys) + if name == "nvTmaDesc": return nv_tma_desc_type() + + if name == "constexpr": + return constexpr tys = { "fp8e4nv": float8e4nv, diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 145c9648298d..191485861f2d 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -38,6 +38,14 @@ def wrapper(*args, **kwargs): return wrapper +def _flatten_list(lst): + for item in lst: + if isinstance(item, (list, tuple_type, tuple)): + yield from _flatten_list(item) + else: + yield item + + def _tensor_member_fn(fn: T) -> T: """Decorator that adds this free function as a member fn on class tensor. @@ -140,6 +148,7 @@ def __init__(self, value): self.value = value.value else: self.value = value + self.type = constexpr def __repr__(self) -> str: return f"constexpr[{self.value}]" @@ -303,6 +312,7 @@ class KIND(Enum): def __init__(self, name): name = _unwrap_if_constexpr(name) self.name = name + self.num_composite_types = 1 assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name if name in dtype.SINT_TYPES: self.int_signedness = dtype.SIGNEDNESS.SIGNED @@ -473,6 +483,10 @@ def is_ptr(): def is_const(): return False + @staticmethod + def is_tuple(): + return False + def __eq__(self, other: dtype): if not isinstance(other, dtype): return False @@ -564,6 +578,7 @@ def __init__(self, element_ty: dtype, address_space: int = 1, const: bool = Fals self.element_ty = element_ty self.address_space = address_space self.const = const + self.num_composite_types = 1 self.name = f'pointer<{element_ty}>' if not const else f'const_pointer<{element_ty}>' def to_ir(self, builder: ir.builder) -> ir.pointer_type: @@ -605,11 +620,11 @@ class block_type(dtype): def __init__(self, element_ty: dtype, shape: List): self.element_ty = element_ty + self.num_composite_types = 1 # Note that block_type's shape is a list of int # while tensor's shape is a list of constexpr. - - assert (isinstance(shape, list)) + assert (isinstance(shape, (list, tuple))) # shape can be empty ([]) when an input is a 0D tensor. self.shape = _unwrap_shape(shape) @@ -647,19 +662,33 @@ def scalar(self): return self.element_ty -class function_type(dtype): +class tuple_type(dtype): - def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None: - self.ret_types = ret_types - self.param_types = param_types + def __init__(self, types): + self.types = types + self.name = f"[{','.join(map(str, self.types))}]" + self.num_composite_types = len(self.types) def __str__(self): - return f'fn ({self.param_types}) -> {self.ret_types}' + return self.name + + def __iter__(self): + return iter(self.types) def to_ir(self, builder: ir.builder): - ir_param_types = [ty.to_ir(builder) for ty in self.param_types] - ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types] - return builder.get_function_ty(ir_param_types, ret_types) + return [ty.to_ir(builder) for ty in self.types] + + def __getitem__(self, index: int) -> dtype: + return self.types[index] + + def is_tuple(self): + return True + + +class slice_type(dtype): + + def __init__(self): + self.name = 'slice_type' # scalar types @@ -761,7 +790,10 @@ def __init__(self, handle, type: dtype): self.type = type # Tensor type (can be block_type) # Following the practice in pytorch, dtype is scalar type self.dtype = type.scalar - self.shape = [constexpr(s) for s in self.shape] + self.shape = tuple([constexpr(s) for s in self.shape]) + + def serialize(self): + return [self.handle] def _flatten_ir(self): return [self.handle] @@ -984,6 +1016,8 @@ def __not__(self, _builder=None): def __getitem__(self, slices, _builder=None): if isinstance(slices, (slice, constexpr)) or slices is None: slices = [slices] + if isinstance(slices, tuple): + slices = slices.values ret = self for dim, sl in enumerate(slices): if sl is None or isinstance(sl, constexpr) and sl.value is None: @@ -1144,6 +1178,70 @@ def flip(self, dim=None) -> tensor: ... +class tuple: + + def __init__(self, args: list): + self.values = [i for i in args] + + @property + def type(self): + def get_type(x): + if isinstance(x, dtype): + return dtype + return x.type + return tuple_type([get_type(x) for x in self.values]) + + def serialize(self): + return list(_flatten_list([x.serialize() for x in self.values])) + + def __getitem__(self, idx: constexpr): + if isinstance(idx, int): + idx = constexpr(idx) + if isinstance(idx, constexpr): + return self.values[idx] + else: + return tuple(self.values[idx.start:idx.stop:idx.step]) + + # TODO: remove + def __setitem__(self, idx: constexpr, value): + if isinstance(idx, int): + idx = constexpr(idx) + assert isinstance(idx, constexpr) + self.values[idx] = value + + def __add__(self, other): + if isinstance(other, list): + other = tuple(other) + return tuple(self.values + other.values) + # return tuple(a + b for a, b in zip(self.values, other.values)) + + def __eq__(self, other): + import builtins + if isinstance(other, (list, builtins.tuple)): + other = tuple(other) + return constexpr(self.values == other.values) + + def __hash__(self): + import builtins + return hash(builtins.tuple(self.values)) + + def __str__(self): + return str([str(x) for x in self.values]) + + def __iter__(self): + return iter(self.values) + + def __len__(self): + return len(self.values) + + +class slice: + + def __init__(self, start, stop, step): + self.start = start + self.stop = stop + self.step = step + self.type = slice_type() class _experimental_tensor_descriptor_base(_value): """" A tensor descriptor with unknown shape and strides @@ -1559,7 +1657,7 @@ def expand_dims(input, axis, _builder=None): """ input = semantic.to_tensor(input, _builder) axis = _constexpr_to_value(axis) - axes = list(axis) if isinstance(axis, Sequence) else [axis] + axes = list(axis) if isinstance(axis, (Sequence, tuple)) else [axis] new_ndim = len(input.shape) + len(axes) axes = [_wrap_axis(_constexpr_to_value(d), new_ndim) for d in axes] @@ -2210,14 +2308,12 @@ def reduce(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=N return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, _generator=_generator)[0] def make_combine_region(reduce_op): - in_scalar_tys = [t.type.scalar for t in input] - prototype = function_type(in_scalar_tys, in_scalar_tys * 2) - + param_types = [t.type.scalar for t in input] * 2 region = reduce_op.get_region(0) with _insertion_guard(_builder): - param_types = [ty.to_ir(_builder) for ty in prototype.param_types] - block = _builder.create_block_with_parent(region, param_types) - args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] + to_ir = lambda T: T.to_ir(_builder) + block = _builder.create_block_with_parent(region, list(map(to_ir, param_types))) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)] results = _generator.call_JitFunction(combine_fn, args, kwargs={}) if isinstance(results, tensor): handles = [results.handle] @@ -2311,14 +2407,12 @@ def associative_scan(input, axis, combine_fn, reverse=False, _builder=None, _gen return associative_scan((input, ), axis, combine_fn, reverse, _builder=_builder, _generator=_generator)[0] def make_combine_region(scan_op): - in_scalar_tys = [t.type.scalar for t in input] - prototype = function_type(in_scalar_tys, in_scalar_tys * 2) - + param_types = [t.type.scalar for t in input] * 2 region = scan_op.get_region(0) with _insertion_guard(_builder): - param_types = [ty.to_ir(_builder) for ty in prototype.param_types] - block = _builder.create_block_with_parent(region, param_types) - args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] + to_ir = lambda T: T.to_ir(_builder) + block = _builder.create_block_with_parent(region, list(map(to_ir, param_types))) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)] results = _generator.call_JitFunction(combine_fn, args, kwargs={}) if isinstance(results, tensor): handles = [results.handle] diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index d04f516e8152..6ac7c09a6d0d 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -308,6 +308,8 @@ def mangle_type(arg, is_const=False): return "fp32" elif hasattr(arg, "tma_desc_cpu_ptr"): return "nvTmaDesc" + elif isinstance(arg, tuple): + return "[" + ",".join(map(mangle_type, arg)) + "]" else: # dtypes are hashable so we can memoize this mapping: dsk = (arg.dtype, is_const) @@ -368,6 +370,7 @@ def create_function_from_signature(sig, kparams, backend): func_args.append(f"{name}=default_{name}") dict_entries.append(f"'{name}': {name}") if kp.is_constexpr: + signature_types.append('"constexpr"') constexpr_vals.append(name) else: non_constexpr_vals.append(name) @@ -601,32 +604,27 @@ def run(self, *args, grid, warmup, **kwargs): # done here rather than when we build the signature as otherwise # the kernel cache key could not distinguish between byte pointers # and None arguments, resulting in a downstream mismatch: - sigkeys = [self.params[i].name for i in self.non_constexpr_indices] + sigkeys = [param.name for param in self.params] sigvals = sig_and_spec[:len(sigkeys)] signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)} - configs = (backend.get_attrs_descriptor(self.params, bound_vals), ) - constant_params = configs[0].get_constants() + attrs = backend.get_attrs_descriptor(self.params, bound_vals) constants = { - p.name: v + (p.num,): v for (v, p) in zip(bound_vals, self.params) - if p.is_constexpr or (p.num in constant_params) or v is None + if p.is_constexpr or v is None } for i, arg in constants.items(): if callable(arg): raise TypeError(f"Callable constexpr at index {i} is not supported") - if self._call_hook(key, signature, device, constants, options, configs, warmup, before=True): + if self._call_hook(key, signature, device, constants, options, [attrs], warmup, before=True): return None # compile the kernel - src = self.ASTSource(self, signature, constants, configs[0]) - kernel = self.compile( - src, - target=target, - options=options.__dict__, - ) + src = self.ASTSource(self, signature, constants, attrs) + kernel = self.compile(src, target=target, options=options.__dict__) self.cache[device][key] = kernel - self._call_hook(key, signature, device, constants, options, configs, warmup, before=False) + self._call_hook(key, signature, device, constants, options, [attrs], warmup, before=False) # Check that used global values have not changed. not_present = object() @@ -639,15 +637,11 @@ def run(self, *args, grid, warmup, **kwargs): # canonicalize grid assert grid is not None if callable(grid): - # Arguments are passed as a dict to `grid`, by contract. - # TODO(jlebar): In the new launch API, pass the compiler flags as a - # second parameter to `grid`. grid = grid(bound_args) grid_size = len(grid) grid_0 = grid[0] grid_1 = grid[1] if grid_size > 1 else 1 grid_2 = grid[2] if grid_size > 2 else 1 - # launch kernel launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals) kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata, diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index 827ce61cbaf2..6e2a754230d9 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -17,6 +17,29 @@ libraries = ['cuda'] +def parse_list_string(s): + s = s.strip() + if s.startswith('[') and s.endswith(']'): + s = s[1:-1] + result = [] + current = '' + depth = 0 + for c in s: + if c == '[': + depth += 1 + current += c + elif c == ']': + depth -= 1 + current += c + elif c == ',' and depth == 0: + result.append(current.strip()) + current = '' + else: + current += c + if current.strip(): + result.append(current.strip()) + return result + @functools.lru_cache() def libcuda_dirs(): env_libcuda_path = os.getenv("TRITON_LIBCUDA_PATH") @@ -117,20 +140,47 @@ def ty_to_cpp(ty): }[ty] +def _serialize_type(type): + if type[0] == '[': + return [] if type == "[]" else type[1:-1].split(',') + return [type] + + +def _serialize_signature(signature): + ret = dict() + map = dict() + i = 0 + for _, type in signature.items(): + types = _serialize_type(type) + ret.update({i + j: ty for j, ty in enumerate(types)}) + for ii in range(len(types)): + map[i + ii] = _ + i += len(types) + return ret, map + + def make_launcher(constants, signature, ids): - # Record the end of regular arguments; - # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. - arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) def _extracted_type(ty): if ty[0] == '*': return "PyObject*" if ty == "nvTmaDesc": return "PyObject*" - + if ty[0] == '[': + if ty == "[]": + return "[]" + tys = parse_list_string(ty) + val = ','.join(map(_extracted_type, tys)) + return f"[{val}]" return ty_to_cpp(ty) def format_of(ty): + if ty[0] == "[": + if ty == "[]": + return "()" + tys = parse_list_string(ty) + val = ''.join(map(format_of, tys)) + return f"({val})" return { "PyObject*": "O", "float": "f", @@ -146,10 +196,17 @@ def format_of(ty): "uint64_t": "K", }[ty] + + signature = {k: v for k, v in signature.items() if v != 'constexpr'} args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) format = "iiiKKOOOOO" + args_format + signature = ','.join(signature.values()).replace('[','').replace(']','') + signature = signature.split(',') if signature else dict() + signature = {i: s for i, s in enumerate(signature)} args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' - + # Record the end of regular arguments; + # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. + arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) internal_args_list = [] for i, ty in signature.items(): if ty[0] == "*": @@ -159,6 +216,7 @@ def format_of(ty): internal_args_list.append(f"*tma_ptr{i}") else: internal_args_list.append(f"_arg{i}") + params = range(len(signature)) # generate glue code params = [f"&arg{i}" for i in signature.keys() if i not in constants] @@ -428,9 +486,8 @@ class CudaLauncher(object): def __init__(self, src, metadata): ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} constants = src.constants if hasattr(src, "constants") else dict() - cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i - constants = {cst_key(key): value for key, value in constants.items()} - signature = {cst_key(key): value for key, value in src.signature.items()} + constants = {idx: value for idx, value in constants.items()} + signature = {idx: value for idx, value in src.signature.items()} src = make_launcher(constants, signature, ids) mod = compile_module_from_src(src, "__triton_launcher") self.launch = mod.launch