
Commit

Update attention template
This commit updates the attention template to match what iree-compile expects as of Nov 14 2024.
manupak committed Nov 14, 2024
1 parent 25b29d0 commit 0579316
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions attentionbench/attention_utils.py
@@ -76,14 +76,15 @@ def get_lowering_config(self) -> str:
+ f">"
)

def get_mma_schedule(self) -> str:
return (
f"#iree_gpu.mma_schedule<"
+ f"intrinsic = #iree_gpu.mma_layout<{self.intrinsic}>"
+ f", subgroup_m_count = {self.M_warp}"
+ f", subgroup_n_count = {self.N_warp}"
+ f">"
)
def get_lowering_config_for_mmt(self, extra_args) -> str:
base_str = (f"#iree_gpu.lowering_config<{{"
+ f"mma_kind = #iree_gpu.mma_layout<{self.intrinsic}>"
+ f", subgroup_m_count = {self.M_warp}"
+ f", subgroup_n_count = {self.N_warp}")
for arg in extra_args:
base_str += f", {arg}"
base_str += "}>"
return base_str

def get_translation_info(self) -> str:
llvm_func_attrs = []
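
As a quick illustration of the new helper (a minimal sketch, not part of the commit), assuming a tuning spec with intrinsic MFMA_F32_16x16x16_F16 and M_warp = N_warp = 2, the method would build a string like the one printed below:

# Minimal sketch: re-implements the string-building logic of
# get_lowering_config_for_mmt outside the class, with made-up example values.
def lowering_config_for_mmt(intrinsic, m_warp, n_warp, extra_args):
    base_str = (f"#iree_gpu.lowering_config<{{"
        + f"mma_kind = #iree_gpu.mma_layout<{intrinsic}>"
        + f", subgroup_m_count = {m_warp}"
        + f", subgroup_n_count = {n_warp}")
    for arg in extra_args:
        base_str += f", {arg}"
    base_str += "}>"
    return base_str

print(lowering_config_for_mmt("MFMA_F32_16x16x16_F16", 2, 2, ["promote_operands = [0, 1]"]))
# Prints (single line, wrapped here):
# #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
#   subgroup_m_count = 2, subgroup_n_count = 2, promote_operands = [0, 1]}>
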
@@ -93,11 +94,10 @@ def get_translation_info(self) -> str:
             llvm_func_attrs += [f'"denormal-fp-math-f32" = "preserve-sign"']
         return (
             f"#iree_codegen.translation_info<"
-            + f"LLVMGPUVectorDistribute"
+            + f"pipeline = LLVMGPUVectorDistribute"
             + f" workgroup_size = [{self.N_warp * self.M_warp * 64}]"
             + f" subgroup_size = 64"
-            + f" ,{{mma_schedule = {self.get_mma_schedule()}"
-            + f" , llvm_func_attrs = {{ {','.join(llvm_func_attrs)} }}"
+            + f" , {{llvm_func_attrs = {{ {','.join(llvm_func_attrs)} }}"
            + f"}}"
            + f">"
        )
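
For reference, a hedged sketch (not taken from the commit) of the translation_info string the updated method would now emit, assuming M_warp = N_warp = 2 and llvm_func_attrs holding only the denormal-fp-math-f32 entry:

# Illustrative expected output only; all values here are assumptions for the example.
# The pipeline name is now prefixed with `pipeline =`, and the mma_schedule entry
# is gone from the trailing dictionary.
expected_translation_info = (
    "#iree_codegen.translation_info<"
    "pipeline = LLVMGPUVectorDistribute"
    " workgroup_size = [256]"  # N_warp * M_warp * 64 = 2 * 2 * 64
    " subgroup_size = 64"
    ' , {llvm_func_attrs = { "denormal-fp-math-f32" = "preserve-sign" }'
    "}"
    ">"
)
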
@@ -139,8 +139,8 @@ def generate_mlir(config: AttentionConfig, tuning: Optional[TuningSpec] = None):
     %O = iree_linalg_ext.attention
            {{ indexing_maps = [#Q, #K, #V, #S, #O]
             ,decomposition_config = {{
-              qk_attrs = {{attention_qk_matmul, lowering_config = #iree_gpu.lowering_config<{{promote_operands = [0, 1]}}>}},
-              pv_attrs = {{attention_pv_matmul, lowering_config = #iree_gpu.lowering_config<{{promote_operands = [1]}}>}}
+              qk_attrs = {{attention_qk_matmul, lowering_config = {tuning.get_lowering_config_for_mmt(["promote_operands = [0, 1]"])}}},
+              pv_attrs = {{attention_pv_matmul, lowering_config = {tuning.get_lowering_config_for_mmt(["promote_operands = [1]"])}}}
             }}
             {",compilation_info = #tuning" if tuning and config.dtype == "f16" else ""}
     }}
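
And roughly how the decomposition_config block renders in the generated MLIR once the template substitutes the tuning spec (same illustrative values as above; a sketch, not output captured from iree-compile):

# Sketch of the rendered decomposition_config; real output depends on the
# actual TuningSpec values passed to generate_mlir.
qk_cfg = ("#iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, "
          "subgroup_m_count = 2, subgroup_n_count = 2, promote_operands = [0, 1]}>")
pv_cfg = ("#iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, "
          "subgroup_m_count = 2, subgroup_n_count = 2, promote_operands = [1]}>")
print("decomposition_config = {\n"
      f"  qk_attrs = {{attention_qk_matmul, lowering_config = {qk_cfg}}},\n"
      f"  pv_attrs = {{attention_pv_matmul, lowering_config = {pv_cfg}}}\n"
      "}")
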
