diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
index 7364415be6dad..9cd435e91a716 100644
--- a/torch/_inductor/kernel/flex_attention.py
+++ b/torch/_inductor/kernel/flex_attention.py
@@ -617,67 +617,95 @@ def _use_flex_decoding(query, kernel_options):
 }
 
 
-def _get_default_config_fwd(query) -> Tuple[int, int, int, int]:
+def _get_rocm_config(query, mode: str) -> Tuple[int, int, int, int]:
     dtype = query.get_dtype()
     head_dim = query.get_size()[-1]
-    default_config = None
-
-    if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0):  # H100
-        if dtype == torch.float32:
-            default_config = (64, 64, 4, 3)
-        else:
-            default_config = (128, 64, 4, 3)
-        default_config = _h100_default_config.get((dtype, head_dim), default_config)
-    elif head_dim <= 256 and torch.cuda.get_device_capability() >= (8, 0):  # A100
-        if dtype == torch.float32:
-            default_config = (64, 64, 4, 3)
-        else:
-            default_config = (128, 64, 4, 3)
-        default_config = _a100_default_config.get((dtype, head_dim), default_config)
-    elif head_dim <= 256 and torch.version.hip:
+    fwd_config = None
+
+    if mode == "fwd":
+        if head_dim <= 256:
+            if dtype == torch.float32:
+                fwd_config = (64, 64, 4, 1)
+            else:
+                fwd_config = (128, 64, 8, 1)
+            fwd_config = _rocm_default_config.get((dtype, head_dim), fwd_config)
+        else:  # modest hardware or extremely large head_dim
+            if dtype == torch.float32:
+                fwd_config = (32, 16, 4, 1)
+            else:
+                fwd_config = (64, 32, 4, 1)
+        return fwd_config
+    else:  # bwd
         if dtype == torch.float32:
-            default_config = (64, 64, 4, 1)
-        else:
-            default_config = (128, 64, 8, 1)
-        default_config = _rocm_default_config.get((dtype, head_dim), default_config)
-    else:  # modest hardware or extremely large head_dim
+            return (16, 16, 4, 1)
+        elif head_dim <= 256:
+            if head_dim == 64:
+                return (64, 64, 4, 1)
+            elif head_dim == 128:
+                return (64, 128, 8, 1)
+            else:
+                return (64, 64, 4, 1)
+        else:  # modest hardware or extremely large head_dim
+            return (16, 16, 4, 1)
+
+
+def _get_nv_config(query, mode: str) -> Tuple[int, int, int, int]:
+    dtype = query.get_dtype()
+    head_dim = query.get_size()[-1]
+    fwd_config = None
+
+    capability = torch.cuda.get_device_capability()
+
+    if mode == "fwd":
+        if head_dim <= 256:
+            if dtype == torch.float32:
+                fwd_config = (64, 64, 4, 3)
+            else:
+                fwd_config = (128, 64, 4, 3)
+            if capability >= (9, 0):
+                fwd_config = _h100_default_config.get((dtype, head_dim), fwd_config)
+            elif capability >= (8, 0):
+                fwd_config = _a100_default_config.get((dtype, head_dim), fwd_config)
+        else:  # modest hardware or extremely large head_dim
+            if dtype == torch.float32:
+                fwd_config = (32, 16, 4, 3)
+            else:
+                fwd_config = (64, 32, 4, 3)
+        return fwd_config
+
+    else:  # bwd
         if dtype == torch.float32:
-            default_config = (32, 16, 4, 3)
-        else:
-            default_config = (64, 32, 4, 3)
+            return (16, 16, 4, 1)
+        elif head_dim <= 256 and capability >= (9, 0):  # H100
+            if head_dim == 64:
+                return (64, 64, 4, 3)
+            elif head_dim == 128:
+                return (64, 128, 8, 3)
+            else:
+                return (64, 64, 4, 2)
+        elif capability >= (8, 0):  # A100
+            if head_dim == 64:
+                return (32, 128, 4, 3)
+            elif head_dim == 128:
+                return (64, 128, 8, 3)
+            else:
+                return (64, 64, 4, 2)
+        else:  # modest hardware or extremely large head_dim
+            return (16, 16, 4, 1)
 
-    return default_config
 
+def _get_default_config_fwd(query) -> Tuple[int, int, int, int]:
+    if torch.version.hip is None:
+        return _get_nv_config(query, "fwd")
+    else:
+        return _get_rocm_config(query, "fwd")
+
 
-def _get_default_config_bwd(query) -> Tuple[int, int, int, int]:
-    head_dim = query.get_size()[-1]
-    dtype = query.get_dtype()
-
-    if dtype == torch.float32:
-        return (16, 16, 4, 1)
-    if head_dim <= 256 and torch.version.hip:
-        if head_dim == 64:
-            return (64, 64, 4, 1)
-        elif head_dim == 128:
-            return (64, 128, 4, 1)
-        else:
-            return (64, 64, 4, 1)
-    elif head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0):  # H100
-        if head_dim == 64:
-            return (64, 64, 4, 3)
-        elif head_dim == 128:
-            return (64, 128, 8, 3)
-        else:
-            return (64, 64, 4, 2)
-    elif torch.cuda.get_device_capability() >= (8, 0):  # A100
-        if head_dim == 64:
-            return (32, 128, 4, 3)
-        elif head_dim == 128:
-            return (64, 128, 8, 3)
-        else:
-            return (64, 64, 4, 2)
-    else:  # modest hardware or extremely large head_dim
-        return (16, 16, 4, 1)
+def _get_default_config_bwd(query) -> Tuple[int, int, int, int]:
+    if torch.version.hip is None:
+        return _get_nv_config(query, "bwd")
+    else:
+        return _get_rocm_config(query, "bwd")
 
 
 def create_num_blocks_fake_generator(sparse_indices):
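For reviewers who want to poke at the new dispatchers in isolation, here is a minimal sketch (not part of the patch). It assumes the returned tuple is consumed as (BLOCK_M, BLOCK_N, num_warps, num_stages), as elsewhere in flex_attention.py, and it uses a hypothetical `FakeQuery` shim in place of the Inductor IR node, since the helpers only call `get_dtype()` and `get_size()`. It needs a CUDA or ROCm build because the NVIDIA path queries `torch.cuda.get_device_capability()`.

```python
import torch

from torch._inductor.kernel.flex_attention import (
    _get_default_config_bwd,
    _get_default_config_fwd,
)


class FakeQuery:
    """Hypothetical stand-in for the Inductor IR buffer.

    The config helpers only call get_dtype() and get_size(), so this tiny
    shim is enough to exercise them outside a compiled graph.
    """

    def __init__(self, dtype: torch.dtype, shape):
        self._dtype = dtype
        self._shape = list(shape)

    def get_dtype(self) -> torch.dtype:
        return self._dtype

    def get_size(self):
        return self._shape


if __name__ == "__main__":
    # Shape is (batch, heads, seq_len, head_dim); only dtype and head_dim
    # influence the chosen config. torch.version.hip decides which backend
    # helper (_get_nv_config vs. _get_rocm_config) the dispatchers call.
    q = FakeQuery(torch.bfloat16, [2, 8, 4096, 128])
    block_m, block_n, num_warps, num_stages = _get_default_config_fwd(q)
    print("fwd:", block_m, block_n, num_warps, num_stages)
    print("bwd:", _get_default_config_bwd(q))
```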