-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DRAFT][FRONTEND] added support for tuples #5220
Draft
ptillet
wants to merge
33
commits into
main
Choose a base branch
from
phil/tuple-support-2
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
33 commits
Select commit
Hold shift + click to select a range
51ab367
progress
ptillet 65da71f
.
ptillet 746a2a3
prototype works
ptillet 630ec6c
added test
ptillet f2439f9
fixup
ptillet d9af0ba
cleanup
ptillet f758ef2
.
ptillet 1b58df2
progress
ptillet 1558d6c
bugfix
ptillet 812af43
progress
ptillet e226cfd
.
ptillet 98c526e
Merge remote-tracking branch 'origin/main' into phil/tuple-support
ptillet 8cee89d
.
ptillet 627bef2
.
ptillet 756d75a
.
ptillet 2a86fb4
.
ptillet d614226
.
ptillet 5d29bef
fails again?
ptillet a790867
more hacks
ptillet fa23bfc
giant mess; more tests pass
ptillet d88cca0
very hacky but tests pass; TO REFACTOR
ptillet e299bf2
.
ptillet fcae528
.
ptillet d0168c9
progress
ptillet 0ba41ff
more progress
ptillet e7289dc
more progress
ptillet b7d8117
.
ptillet 33505ac
more progress
ptillet 3c08877
more fixes
ptillet dba9b2d
all tests pass
ptillet a35e89a
Merge remote-tracking branch 'origin/main' into phil/tuple-support-2
ptillet ae2ebf6
Merge branch 'main' into phil/tuple-support-2
ptillet 18f24ef
fixed TMA descriptors
ptillet File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import pytest | ||
import triton | ||
import triton.language as tl | ||
import torch | ||
|
||
|
||
@triton.jit | ||
def _tuple_increment(values): | ||
for i in tl.static_range(len(values)): | ||
values[i] = values[i] + 1 | ||
return values | ||
|
||
|
||
@triton.jit | ||
def _tuple_index_func(Ptrs, values): | ||
for i in tl.static_range(len(values)): | ||
tl.store(Ptrs[i], values[i]) | ||
|
||
|
||
@triton.jit | ||
def _tuple_index(_0, Ptrs, _1: tl.constexpr, values, _2, _3: tl.constexpr, _4): | ||
values = _tuple_increment(values) | ||
_tuple_index_func(Ptrs, values) | ||
|
||
|
||
@pytest.mark.parametrize("size", [0, 1, 2, 3, 4]) | ||
def test_index(size, device="cuda"): | ||
vals = tuple([i + 1 for i in range(size)]) | ||
rets = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in vals]) | ||
_tuple_index[(1, )](0, rets, 0, vals, 0, 0, 0) | ||
assert vals == tuple([x.item() - 1 for x in rets]) | ||
|
||
|
||
# ---- | ||
|
||
|
||
@triton.jit | ||
def _tuple_assign(XPtrs, YPtrs, values): | ||
# assign from tuple | ||
X0, X1 = XPtrs | ||
x0, x1 = values | ||
tl.store(X0, x0) | ||
tl.store(X1, x1) | ||
# assign to tuple | ||
Y0, Y1, Y2 = YPtrs | ||
Y = Y0, Y1, Y2 | ||
y = x0, 10, x1 | ||
tl.store(Y[0], y[0]) | ||
tl.store(Y[1], y[1]) | ||
tl.store(Y[2], y[2]) | ||
|
||
|
||
def test_assign(device="cuda"): | ||
vals = (2., 3.) | ||
x = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(2)]) | ||
y = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(3)]) | ||
_tuple_assign[(1, )](x, y, vals) | ||
assert x[0] == vals[0] | ||
assert x[1] == vals[1] | ||
assert y[0] == vals[0] | ||
assert y[1] == 10 | ||
assert y[2] == vals[1] | ||
|
||
# ------- | ||
|
||
@triton.jit | ||
def _tuple_fn0(Ptr, cst2: tl.constexpr, tuple1): | ||
tl.store(Ptr + 5, cst2) | ||
tl.store(Ptr + 6, tuple1[0]) | ||
tl.store(Ptr + 7, tl.load(tuple1[1][0])) | ||
tl.store(Ptr + 8, tuple1[1][1][0]) | ||
tl.store(Ptr + 9, tl.load(tuple1[1][1][1])) | ||
|
||
# test serialization/deserialization of tuple arguments in | ||
# the frontend. | ||
@triton.jit | ||
def _tuple_serdes(Ptr, tuple1, cst1: tl.constexpr, val1, tuple2): | ||
tl.store(Ptr + 0, tl.load(tuple1[0])) | ||
tl.store(Ptr + 1, tuple1[1][0]) | ||
tl.store(Ptr + 2, tl.load(tuple1[1][1])) | ||
tl.store(Ptr + 3, cst1 + val1) | ||
tl.store(Ptr + 4, tl.load(tuple2[0])) | ||
_tuple_fn0(Ptr, 15, (-1, tuple1)) | ||
|
||
def test_serdes(device="cuda"): | ||
x0 = torch.tensor([8], dtype=torch.int32, device=device) | ||
x1 = torch.tensor([12], dtype=torch.int32, device=device) | ||
y0 = torch.tensor([10], dtype=torch.int32, device=device) | ||
z = torch.empty((10,), dtype=torch.int32, device=device) | ||
# we want to check that JIT specialization propagates to tuples: | ||
_tuple_serdes[(1,)](z, (x0, (1, x1)), 20, 1, (y0,)) | ||
print(z) | ||
|
||
|
||
# function call (tuple argument) | ||
# function call (tuple return value) | ||
# __getitem__ and __setitem__ | ||
# assignment (into a tuple, from a tuple) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,12 +3,28 @@ | |
import hashlib | ||
import subprocess | ||
import sysconfig | ||
|
||
from abc import ABCMeta, abstractmethod | ||
from dataclasses import dataclass | ||
from typing import Dict, List, Tuple, Union | ||
from types import ModuleType | ||
|
||
def find_paths_if(iterable, pred): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TODO: remove duplicate |
||
is_iterable = lambda x: isinstance(x, (list, tuple)) | ||
ret = [] | ||
def _impl(current, path): | ||
if pred(current): | ||
if len(path) == 1: | ||
ret.append((path[0],)) | ||
else: | ||
ret.append(tuple(path)) | ||
elif is_iterable(current): | ||
for idx, item in enumerate(current): | ||
_impl(item, path + [idx]) | ||
if is_iterable(iterable): | ||
_impl(iterable, []) | ||
else: | ||
ret = [tuple()] if pred(iterable) else [] | ||
return ret | ||
# Table that associates strings to AttrsDescriptor (sub)classes. | ||
# In this way we can dynamically select the correct class | ||
# constructor | ||
|
@@ -86,17 +102,22 @@ def _add_common_properties(self, params, values): | |
assert (len(params) == len(values)) | ||
|
||
# Divisibility property | ||
self.arg_properties["tt.divisibility"] = [ | ||
param.num for param, arg in zip(params, values) if AttrsDescriptor.is_divisible_by_16(arg) | ||
and not param.do_not_specialize and not param.do_not_specialize_on_alignment | ||
] | ||
divisibility_16 = [] | ||
for param, arg in zip(params, values): | ||
if param.do_not_specialize or param.do_not_specialize_on_alignment: | ||
continue | ||
paths = find_paths_if(arg, AttrsDescriptor.is_divisible_by_16) | ||
divisibility_16 += [(param.num,) + x for x in paths] | ||
self.arg_properties["tt.divisibility"] = divisibility_16 | ||
|
||
# Equal to 1 property | ||
self.arg_properties["tt.equal_to"] = [ | ||
param.num | ||
for param, arg in zip(params, values) | ||
if AttrsDescriptor.is_equal_to_1(arg) and not param.do_not_specialize | ||
] | ||
equal_to_1 = [] | ||
for param, arg in zip(params, values): | ||
if param.do_not_specialize: | ||
continue | ||
paths = find_paths_if(arg, AttrsDescriptor.is_equal_to_1) | ||
equal_to_1 += [(param.num,) + x for x in paths] | ||
self.arg_properties["tt.equal_to"] = equal_to_1 | ||
|
||
def _add_backend_properties(self, params=None, values=None): | ||
""" This method is for different subclasses to implement their own compile-time properties """ | ||
|
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO: add more unit tests