[Bug]: Failed to lower tt.bitcast of !tt.ptr<i1> #192

Open
Nullkooland opened this issue Nov 21, 2024 · 5 comments
Labels
bug Something isn't working

Comments

Nullkooland (Contributor) commented Nov 21, 2024

Triton python code

@triton_heuristics.pointwise(
    size_hints=[16384], 
    filename=__file__,
    triton_meta={'signature': {0: '*i16', 1: '*i16', 2: '*i1', 3: 'i32'}, 'device': DeviceProperties(type='cpu', index=None, cc='', major=None, regs_per_multiprocessor=None, max_threads_per_multi_processor=None, multi_processor_count=None), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_isin_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, ...}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel: int, XBLOCK : tl.constexpr):
    xnumel = 16384
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[16384], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]))
    tmp1 = tl.load(in_ptr1 + (0))
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK])
    tmp4 = tl.load(in_ptr1 + (1))
    tmp5 = tl.broadcast_to(tmp4, [XBLOCK])
    tmp8 = tl.load(in_ptr1 + (2))
    tmp9 = tl.broadcast_to(tmp8, [XBLOCK])
    tmp3 = tmp0 == tmp2
    tmp6 = tmp0 == tmp5
    tmp7 = tmp3 | tmp6
    tmp10 = tmp0 == tmp9
    tmp11 = tmp7 | tmp10
    tl.store(tl.make_block_ptr(out_ptr0, shape=[16384], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.broadcast_to(tmp11, [XBLOCK]).to(tl.int8))

Triton IR

module {
  tt.func public @triton_(%arg0: !tt.ptr<i16>, %arg1: !tt.ptr<i16>, %arg2: !tt.ptr<i1> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %c2_i32 = arith.constant 2 : i32
    %c1_i32 = arith.constant 1 : i32
    %c0_i32 = arith.constant 0 : i32
    %c1_i64 = arith.constant 1 : i64
    %c16384_i64 = arith.constant 16384 : i64
    %c256_i32 = arith.constant 256 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c256_i32 : i32
    %2 = tt.make_tensor_ptr %arg0, [%c16384_i64], [%c1_i64], [%1] {order = array<i32: 0>} : <tensor<256xi16>>
    %3 = tt.load %2 : !tt.ptr<tensor<256xi16>>
    %4 = tt.addptr %arg1, %c0_i32 : !tt.ptr<i16>, i32
    %5 = tt.load %4 : !tt.ptr<i16>
    %6 = tt.splat %5 : i16 -> tensor<256xi16>
    %7 = tt.addptr %arg1, %c1_i32 : !tt.ptr<i16>, i32
    %8 = tt.load %7 : !tt.ptr<i16>
    %9 = tt.splat %8 : i16 -> tensor<256xi16>
    %10 = tt.addptr %arg1, %c2_i32 : !tt.ptr<i16>, i32
    %11 = tt.load %10 : !tt.ptr<i16>
    %12 = tt.splat %11 : i16 -> tensor<256xi16>
    %13 = arith.cmpi eq, %3, %6 : tensor<256xi16>
    %14 = arith.cmpi eq, %3, %9 : tensor<256xi16>
    %15 = arith.ori %13, %14 : tensor<256xi1>
    %16 = arith.cmpi eq, %3, %12 : tensor<256xi16>
    %17 = arith.ori %15, %16 : tensor<256xi1>
    %18 = tt.bitcast %arg2 : !tt.ptr<i1> -> !tt.ptr<i8>
    %19 = tt.make_tensor_ptr %18, [%c16384_i64], [%c1_i64], [%1] {order = array<i32: 0>} : <tensor<256xi8>>
    %20 = arith.extui %17 : tensor<256xi1> to tensor<256xi8>
    tt.store %19, %20 : !tt.ptr<tensor<256xi8>>
    tt.return
  }
}

Crash log

❯ triton-shared-opt --triton-to-linalg-experimental /.../triton_.ttir
gen_triton_kernel.py:25:19: remark: PtrAnalysis: scalar loadOp will not be rewritten
    tmp1 = tl.load(in_ptr1 + (0))
                  ^
gen_triton_kernel.py:25:19: note: see current operation: %11 = tt.load %10 : !tt.ptr<i16>
gen_triton_kernel.py:25:19: remark: PtrAnalysis: Failed to rewrite LoadOp
    tmp1 = tl.load(in_ptr1 + (0))
                  ^
gen_triton_kernel.py:25:19: note: see current operation: %11 = tt.load %10 : !tt.ptr<i16>
gen_triton_kernel.py:27:19: remark: PtrAnalysis: scalar loadOp will not be rewritten
    tmp4 = tl.load(in_ptr1 + (1))
                  ^
gen_triton_kernel.py:27:19: note: see current operation: %15 = tt.load %14 : !tt.ptr<i16>
gen_triton_kernel.py:27:19: remark: PtrAnalysis: Failed to rewrite LoadOp
    tmp4 = tl.load(in_ptr1 + (1))
                  ^
gen_triton_kernel.py:27:19: note: see current operation: %15 = tt.load %14 : !tt.ptr<i16>
gen_triton_kernel.py:29:19: remark: PtrAnalysis: scalar loadOp will not be rewritten
    tmp8 = tl.load(in_ptr1 + (2))
                  ^
gen_triton_kernel.py:29:19: note: see current operation: %19 = tt.load %18 : !tt.ptr<i16>
gen_triton_kernel.py:29:19: remark: PtrAnalysis: Failed to rewrite LoadOp
    tmp8 = tl.load(in_ptr1 + (2))
                  ^
gen_triton_kernel.py:29:19: note: see current operation: %19 = tt.load %18 : !tt.ptr<i16>
gen_triton_kernel.py:36:31: error: 'arith.bitcast' op operand #0 must be signless-integer-or-float-like or memref of signless-integer or float, but got '!tt.ptr<i1>'
    tl.store(tl.make_block_ptr(out_ptr0, shape=[16384], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.broadcast_to(tmp11, [XBLOCK]).to(tl.int8))
                              ^
gen_triton_kernel.py:36:31: note: see current operation: %30 = "arith.bitcast"(%arg2) : (!tt.ptr<i1>) -> !tt.ptr<i8>

Additional information

The Triton kernel above is the codegen result of TorchInductor; the source PyTorch program is:

import torch

device = "cpu"  # assumed; the kernel meta above reports DeviceProperties(type='cpu')

a = torch.randint(low=-100, high=100, size=(128, 128), dtype=torch.int16, device=device)
b = torch.tensor(data=[-1, 0, 1], dtype=torch.int16, device=device)

@torch.compile(fullgraph=True)
def test_func(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    out = torch.isin(a, b)
    return out

out = test_func(a, b)
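
(The generated kernel above can be recovered from this program by running it with, e.g., TORCH_LOGS="output_code" set under torch.compile; that is presumably how the kernel was obtained.)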

Looks like triton-shared's triton-to-linalg lowering pass cannot properly handle the i1 (boolean) type?
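
For reference, here is a hand-reduced TTIR snippet (untested; trimmed from the IR above, with made-up function and value names) that should isolate the scalar pointer bitcast when run through triton-shared-opt --triton-to-linalg-experimental:

// Hypothetical reduced test case: scalar !tt.ptr<i1> bitcast feeding a block-pointer store.
module {
  tt.func public @bitcast_i1_ptr_repro(%arg0: !tt.ptr<i1> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %c0_i32 = arith.constant 0 : i32
    %c1_i64 = arith.constant 1 : i64
    %c256_i64 = arith.constant 256 : i64
    %cst = arith.constant dense<1> : tensor<256xi8>
    %0 = tt.bitcast %arg0 : !tt.ptr<i1> -> !tt.ptr<i8>
    %1 = tt.make_tensor_ptr %0, [%c256_i64], [%c1_i64], [%c0_i32] {order = array<i32: 0>} : <tensor<256xi8>>
    tt.store %1, %cst : !tt.ptr<tensor<256xi8>>
    tt.return
  }
}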

Nullkooland added the bug label Nov 21, 2024
parsifal-47 (Contributor) commented:

I think this PR should address the issue: #171

Nullkooland (Contributor, Author) commented:

> I think this PR should address the issue: #171

Thanks, when will this PR be upstreamed?

parsifal-47 (Contributor) commented:

> I think this PR should address the issue: #171
>
> Thanks, when will this PR be upstreamed?

It needs a review. I pinged @nhat-nguyen, but he could be busy; if you know somebody else with write permissions, let me know.

Nullkooland (Contributor, Author) commented:

> I think this PR should address the issue: #171

Hi, I tried your PR, but it still does not handle my case, which has tt.bitcast %arg2 : !tt.ptr<i1> -> !tt.ptr<i8>; could you verify using the IR attached in this issue?

parsifal-47 (Contributor) commented:

> I think this PR should address the issue: #171
>
> Hi, I tried your PR, but it still does not handle my case, which has tt.bitcast %arg2 : !tt.ptr<i1> -> !tt.ptr<i8>; could you verify using the IR attached in this issue?

You are right, it handles tensor conversions but not individual pointers to integers. This is a relatively easy fix, but I see that it also crashes in pointer analysis. I can take a look once PR #171 is in.
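
Illustratively (my reading of that distinction; this IR is not taken from #171 or its tests):

// A bitcast over a tensor of pointers -- the "tensor conversion" shape mentioned above:
%a = tt.bitcast %ptrs : tensor<256x!tt.ptr<i1>> -> tensor<256x!tt.ptr<i8>>
// A bitcast of an individual kernel-argument pointer -- the op this issue crashes on:
%b = tt.bitcast %arg2 : !tt.ptr<i1> -> !tt.ptr<i8>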
