
fix bugs on database search #438

Open · wants to merge 1 commit into master
Conversation

@sgjzfzzf (Contributor) commented Feb 8, 2025

PR Category

Model Test

Type of Change

Bug Fix

Description

When reading the cache of a compiled kernel, the read index can sometimes be wrong. You can reproduce this bug with the following script.

import flag_gems
import torch


x = torch.randn((256, 256), device="cuda")
y = torch.randn((256, 256), device="cuda")
with flag_gems.use_gems():
    z0 = torch.matmul(x, y)
    z1 = torch.matmul(x, y)

This PR reads the cached entry by filtering instead of by index, so the lookup no longer depends on where the parameter happens to be stored.
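
For illustration, here is a minimal sketch of the difference, using a toy sqlite3 cache. The table and column names are made up and are not the actual flag_gems cache schema; the point is only that a lookup filtered on the key is independent of where the row is stored, while a positional read is not.

import sqlite3

# Toy cache: one row per tuned kernel, keyed by a string form of the entry key.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE cache (key TEXT PRIMARY KEY, num_warps INTEGER)")
conn.execute("INSERT INTO cache VALUES (?, ?)", ("matmul-256x256-fp32", 4))

# Fragile: relies on the row sitting at a fixed position in the result set.
by_index = conn.execute("SELECT num_warps FROM cache").fetchall()[0]

# Robust: filter on the key, so the storage position of the row is irrelevant.
by_filter = conn.execute(
    "SELECT num_warps FROM cache WHERE key = ?", ("matmul-256x256-fp32",)
).fetchone()

print(by_index, by_filter)  # (4,) (4,)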

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change is responded to an issue.
  • Change is fully covered by a UT.

Performance

@StrongSpoon (Collaborator) commented Feb 12, 2025

Could you please provide the environment information? Which version of Triton are you using?
I didn't reproduce the error successfully, and below is the entry_key in my program.
((torch.float32, True), (torch.float32, True), (torch.float32, True), (<class 'int'>, 256), (<class 'int'>, 256), (<class 'int'>, 256), (<class 'int'>, 256), (<class 'int'>, 1), (<class 'int'>, 256), (<class 'int'>, 1), (<class 'int'>, 256), (<class 'int'>, 1), triton.language.fp32, 8)

@sgjzfzzf (Contributor, Author) commented:

Sorry for the confusion. You need to run this script twice to reproduce the bug. It'll report an error like

Traceback (most recent call last):
  File "/tmp/flaggems/matmul.py", line 8, in <module>
    z0 = torch.matmul(x, y)
  File "/tmp/flaggems/.venv/lib/python3.10/site-packages/flag_gems/ops/mm.py", line 129, in mm
    mm_kernel[grid](
  File "/tmp/flaggems/.venv/lib/python3.10/site-packages/triton/runtime/jit.py", line 330, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/tmp/flaggems/.venv/lib/python3.10/site-packages/flag_gems/utils/libentry.py", line 242, in run
    kernel = self.fn.run(*args, **kwargs)
  File "/tmp/flaggems/.venv/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 203, in run
    ret = self.fn.run(
  File "/tmp/flaggems/.venv/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 385, in run
    return self.fn.run(*args, **kwargs)
  File "/tmp/flaggems/.venv/lib/python3.10/site-packages/triton/runtime/jit.py", line 588, in run
    options = backend.parse_options(kwargs)
  File "/tmp/flaggems/.venv/lib/python3.10/site-packages/triton/backends/nvidia/compiler.py", line 159, in parse_options
    return CUDAOptions(**args)
  File "<string>", line 23, in __init__
  File "/tmp/flaggems/.venv/lib/python3.10/site-packages/triton/backends/nvidia/compiler.py", line 122, in __post_init__
    assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
AssertionError: num_warps must be a power of 2

And the environment I'm working on is

uname -r
5.15.153.1-microsoft-standard-WSL2

python3
Python 3.10.12 (main, Jan 17 2025, 14:35:34) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import flag_gems
>>> import torch
>>> import triton
>>> flag_gems.__version__
'2.2'
>>> torch.__version__
'2.6.0+cu124'
>>> triton.__version__
'3.2.0'

@StrongSpoon (Collaborator) commented Feb 13, 2025

I reproduced it with Triton v3.2. The error is caused by a difference between the definitions of the Config class across Triton versions: in v3.2, developers added arguments such as num_buffers_warp_spec, num_consumer_groups, and so on. To adapt to this change, I suggest reworking the preload function so that these new arguments are loaded into the config as well.

# Config.__str__ in Triton v3.2; fields such as num_buffers_warp_spec and
# num_consumer_groups were added in this version.
class Config:
    def __str__(self):
        res = []
        for k, v in self.kwargs.items():
            res.append(f"{k}: {v}")
        res.append(f"num_warps: {self.num_warps}")
        res.append(f"num_ctas: {self.num_ctas}")
        res.append(f"num_stages: {self.num_stages}")
        res.append(f"num_buffers_warp_spec: {self.num_buffers_warp_spec}")
        res.append(f"num_consumer_groups: {self.num_consumer_groups}")
        res.append(f"reg_dec_producer: {self.reg_dec_producer}")
        res.append(f"reg_inc_consumer: {self.reg_inc_consumer}")
        res.append(f"maxnreg: {self.maxnreg}")
        return ", ".join(res)

@StrongSpoon self-requested a review on February 13, 2025, 09:52