Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT][FRONTEND] added support for tuples #5220

Draft
wants to merge 33 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,7 @@ void init_triton_ir(py::module &&m) {
"Function argument index out of range");
return self.getArgument(idx);
})
.def("get_num_args", &FuncOp::getNumArguments)
.def(
"add_entry_block",
[](FuncOp &self) -> Block * { return self.addEntryBlock(); },
Expand Down
98 changes: 98 additions & 0 deletions python/test/unit/language/test_tuple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import pytest
import triton
import triton.language as tl
import torch


@triton.jit
def _tuple_increment(values):
    """Add 1 to every element of the tuple ``values`` and return the tuple."""
    # static_range drives the loop with the tuple's length.
    for idx in tl.static_range(len(values)):
        values[idx] = values[idx] + 1
    return values


@triton.jit
def _tuple_index_func(Ptrs, values):
    """Store ``values[k]`` through ``Ptrs[k]`` for every index ``k`` of ``values``."""
    for idx in tl.static_range(len(values)):
        tl.store(Ptrs[idx], values[idx])


@triton.jit
def _tuple_index(_0, Ptrs, _1: tl.constexpr, values, _2, _3: tl.constexpr, _4):
    """Increment every element of ``values`` and store the results through ``Ptrs``.

    The ``_0`` .. ``_4`` parameters are never read in the body; they only sit
    between the tuple parameters in the signature.
    """
    incremented = _tuple_increment(values)
    _tuple_index_func(Ptrs, incremented)


@pytest.mark.parametrize("size", [0, 1, 2, 3, 4])
def test_index(size, device="cuda"):
    """Check tuple indexing in a kernel for tuples of ``size`` elements.

    The kernel receives the values ``1..size``, increments each one and stores
    it into the matching one-element output tensor; subtracting 1 from every
    stored result must give back the inputs.
    """
    vals = tuple(range(1, size + 1))
    # One single-element float32 output tensor per input value.
    rets = tuple(torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(size))
    _tuple_index[(1, )](0, rets, 0, vals, 0, 0, 0)
    recovered = tuple(ret.item() - 1 for ret in rets)
    assert vals == recovered


# ----


# Kernel exercising tuple assignment in both directions: unpacking a tuple
# into names, and packing names/literals into a new tuple.
@triton.jit
def _tuple_assign(XPtrs, YPtrs, values):
    # assign from tuple: unpack two pointers and two values, then store each
    # value through the pointer at the same position
    X0, X1 = XPtrs
    x0, x1 = values
    tl.store(X0, x0)
    tl.store(X1, x1)
    # assign to tuple: rebuild a pointer tuple from the unpacked names and
    # build a value tuple mixing the unpacked values with the literal 10
    Y0, Y1, Y2 = YPtrs
    Y = Y0, Y1, Y2
    y = x0, 10, x1
    # store element-wise through the packed tuples
    tl.store(Y[0], y[0])
    tl.store(Y[1], y[1])
    tl.store(Y[2], y[2])


def test_assign(device="cuda"):
    """Check tuple unpacking and packing inside a kernel.

    ``_tuple_assign`` writes ``vals`` into the two ``x`` outputs and
    ``(vals[0], 10, vals[1])`` into the three ``y`` outputs.
    """
    vals = (2., 3.)

    def _outputs(count):
        # ``count`` single-element float32 tensors, zero-initialised.
        return tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(count)])

    x = _outputs(2)
    y = _outputs(3)
    _tuple_assign[(1, )](x, y, vals)
    expected = (vals[0], vals[1], vals[0], 10, vals[1])
    # Compare in the order x[0], x[1], y[0], y[1], y[2].
    for actual, reference in zip(x + y, expected):
        assert actual == reference

# -------
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: add more unit tests


# Callee used by _tuple_serdes: writes slots 5..9 of the buffer at ``Ptr``
# from a constexpr and a nested tuple argument.
@triton.jit
def _tuple_fn0(Ptr, cst2: tl.constexpr, tuple1):
    # slot 5: the constexpr argument
    tl.store(Ptr + 5, cst2)
    # slot 6: first element of the tuple, stored directly
    tl.store(Ptr + 6, tuple1[0])
    # slot 7: tuple1[1][0] is passed to tl.load, and the loaded value is stored
    tl.store(Ptr + 7, tl.load(tuple1[1][0]))
    # slot 8: element two levels deep, stored directly
    tl.store(Ptr + 8, tuple1[1][1][0])
    # slot 9: tuple1[1][1][1] is passed to tl.load, and the loaded value is stored
    tl.store(Ptr + 9, tl.load(tuple1[1][1][1]))

# test serialization/deserialization of tuple arguments in
# the frontend.
# Writes slots 0..4 of the buffer at ``Ptr`` itself, then calls _tuple_fn0,
# which writes slots 5..9.
@triton.jit
def _tuple_serdes(Ptr, tuple1, cst1: tl.constexpr, val1, tuple2):
    # slot 0: tuple1[0] is passed to tl.load, and the loaded value is stored
    tl.store(Ptr + 0, tl.load(tuple1[0]))
    # slot 1: tuple1[1][0] is stored directly
    tl.store(Ptr + 1, tuple1[1][0])
    # slot 2: tuple1[1][1] is passed to tl.load, and the loaded value is stored
    tl.store(Ptr + 2, tl.load(tuple1[1][1]))
    # slot 3: sum of the constexpr argument and the runtime argument
    tl.store(Ptr + 3, cst1 + val1)
    # slot 4: tuple2[0] is passed to tl.load, and the loaded value is stored
    tl.store(Ptr + 4, tl.load(tuple2[0]))
    # forward a tuple literal that nests the incoming tuple one level deeper
    _tuple_fn0(Ptr, 15, (-1, tuple1))

def test_serdes(device="cuda"):
    """Check that nested tuple arguments reach the kernel intact.

    ``_tuple_serdes`` fills slots 0..4 of ``z`` and its callee ``_tuple_fn0``
    fills slots 5..9, so every slot of ``z`` is written and can be compared
    against the value implied by the inputs below.
    """
    x0 = torch.tensor([8], dtype=torch.int32, device=device)
    x1 = torch.tensor([12], dtype=torch.int32, device=device)
    y0 = torch.tensor([10], dtype=torch.int32, device=device)
    z = torch.empty((10,), dtype=torch.int32, device=device)
    # we want to check that JIT specialization propagates to tuples:
    _tuple_serdes[(1,)](z, (x0, (1, x1)), 20, 1, (y0,))
    # Slots 0..4 come from _tuple_serdes: x0, 1, x1, 20 + 1, y0.
    # Slots 5..9 come from _tuple_fn0(Ptr, 15, (-1, tuple1)): 15, -1, x0, 1, x1.
    expected = [8, 1, 12, 21, 10, 15, -1, 8, 1, 12]
    assert z.tolist() == expected


# function call (tuple argument)
# function call (tuple return value)
# __getitem__ and __setitem__
# assignment (into a tuple, from a tuple)
41 changes: 31 additions & 10 deletions python/triton/backends/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,28 @@
import hashlib
import subprocess
import sysconfig

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union
from types import ModuleType

def find_paths_if(iterable, pred):
    """Return the index path of every element of ``iterable`` that satisfies ``pred``.

    The search is depth-first and descends only into ``list`` and ``tuple``
    values. ``pred`` is tested before descending, so a container that itself
    satisfies ``pred`` is reported and its children are not visited.

    Args:
        iterable: A possibly nested list/tuple, or any single value.
        pred: One-argument callable; a truthy result marks a match.

    Returns:
        A list of tuples of integer indices in depth-first order. Each tuple
        is the sequence of indices leading from ``iterable`` to a match; the
        empty tuple ``()`` means ``iterable`` itself matched.
    """
    # TODO: remove duplicate (review note on this helper)
    matches = []

    def _is_container(value):
        return isinstance(value, (list, tuple))

    def _visit(current, path):
        if pred(current):
            matches.append(path)
        elif _is_container(current):
            for idx, item in enumerate(current):
                # Tuples are immutable, so every branch gets its own path.
                _visit(item, path + (idx,))

    # A non-container root is handled by the same visitor: it is either a
    # match with the empty path or contributes nothing.
    _visit(iterable, ())
    return matches
# Table that associates strings to AttrsDescriptor (sub)classes.
# In this way we can dynamically select the correct class
# constructor
Expand Down Expand Up @@ -86,17 +102,22 @@ def _add_common_properties(self, params, values):
assert (len(params) == len(values))

# Divisibility property
self.arg_properties["tt.divisibility"] = [
param.num for param, arg in zip(params, values) if AttrsDescriptor.is_divisible_by_16(arg)
and not param.do_not_specialize and not param.do_not_specialize_on_alignment
]
divisibility_16 = []
for param, arg in zip(params, values):
if param.do_not_specialize or param.do_not_specialize_on_alignment:
continue
paths = find_paths_if(arg, AttrsDescriptor.is_divisible_by_16)
divisibility_16 += [(param.num,) + x for x in paths]
self.arg_properties["tt.divisibility"] = divisibility_16

# Equal to 1 property
self.arg_properties["tt.equal_to"] = [
param.num
for param, arg in zip(params, values)
if AttrsDescriptor.is_equal_to_1(arg) and not param.do_not_specialize
]
equal_to_1 = []
for param, arg in zip(params, values):
if param.do_not_specialize:
continue
paths = find_paths_if(arg, AttrsDescriptor.is_equal_to_1)
equal_to_1 += [(param.num,) + x for x in paths]
self.arg_properties["tt.equal_to"] = equal_to_1

def _add_backend_properties(self, params=None, values=None):
""" This method is for different subclasses to implement their own compile-time properties """
Expand Down
Loading
Loading