Feat (llm/export): rename brevitas quant custom op (#693)
jinchen62 authored Aug 17, 2023
1 parent 461a0e0 commit 993307a
Showing 3 changed files with 18 additions and 20 deletions.
18 changes: 9 additions & 9 deletions src/brevitas_examples/llm/llm_quant/mlir_custom_mm.py
@@ -56,14 +56,14 @@ def matmul_rhs_group_quant(
raise ValueError("Input shapes not supported.")


-brevitas_lib = torch.library.Library("brevitas", "DEF")
+brevitas_lib = torch.library.Library("quant", "DEF")
brevitas_lib.define(
"matmul_rhs_group_quant(Tensor lhs, Tensor rhs, Tensor rhs_scale, Tensor rhs_zero_point, int rhs_bit_width, int rhs_group_size) -> Tensor"
)
brevitas_lib.impl("matmul_rhs_group_quant", matmul_rhs_group_quant)


-def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
+def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
if len(lhs) == 3 and len(rhs) == 2:
return [lhs[0], lhs[1], rhs[0]]
elif len(lhs) == 2 and len(rhs) == 2:
@@ -72,20 +72,20 @@ def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
raise ValueError("Input shapes not supported.")


-def brevitas〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
+def quant〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
# output dtype is the dtype of the lhs float input
lhs_rank, lhs_dtype = lhs_rank_dtype
return lhs_dtype


-def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
+def quant〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
return
# yapf: enable

brevitas_matmul_rhs_group_quant_library = [
-    brevitas〇matmul_rhs_group_quant〡shape,
-    brevitas〇matmul_rhs_group_quant〡dtype,
-    brevitas〇matmul_rhs_group_quant〡has_value_semantics]
+    quant〇matmul_rhs_group_quant〡shape,
+    quant〇matmul_rhs_group_quant〡dtype,
+    quant〇matmul_rhs_group_quant〡has_value_semantics]

if __name__ == '__main__':

@@ -100,7 +100,7 @@ def forward(
rhs: torch.Tensor,
rhs_scale: torch.Tensor,
rhs_zero_point: torch.Tensor):
-        return torch.ops.brevitas.matmul_rhs_group_quant(
+        return torch.ops.quant.matmul_rhs_group_quant(
lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width=8, rhs_group_size=128)

mod = CustomOpExampleModule()
@@ -109,6 +109,6 @@ def forward(
module = torch_mlir.compile(
mod, (torch.ones(3, 4), torch.ones(5, 4), torch.ones(1), torch.ones(1)),
output_type="torch",
-        backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
+        backend_legal_ops=["quant.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library)
print(module)
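
Note: after this rename the op registers and resolves under the "quant" namespace. A minimal standalone sketch of the new call path, assuming only the torch.library API used in the file above; the _impl body is a stand-in for illustration, not the real grouped-quantization kernel:

import torch

# Register the custom op under the new "quant" namespace, mirroring the file above.
quant_lib = torch.library.Library("quant", "DEF")
quant_lib.define(
    "matmul_rhs_group_quant(Tensor lhs, Tensor rhs, Tensor rhs_scale, "
    "Tensor rhs_zero_point, int rhs_bit_width, int rhs_group_size) -> Tensor")

def _impl(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size):
    # Stand-in body: the real kernel dequantizes rhs group-wise before the matmul.
    return torch.matmul(lhs, rhs.t().to(lhs.dtype))

quant_lib.impl("matmul_rhs_group_quant", _impl)

out = torch.ops.quant.matmul_rhs_group_quant(
    torch.ones(3, 4), torch.ones(5, 4), torch.ones(1), torch.ones(1),
    rhs_bit_width=8, rhs_group_size=128)
print(out.shape)  # torch.Size([3, 5]), matching the shape function above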
12 changes: 5 additions & 7 deletions src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py
@@ -60,11 +60,11 @@

# Due to a tracing issue this annotation needs to be
# in the same module (== file) from which make_fx is called
-# We also can't directly annotate torch.ops.brevitas.matmul_rhs_group_quant
+# We also can't directly annotate torch.ops.quant.matmul_rhs_group_quant
# and so we trace a placeholder first and then replace it post tracing
@wrap(visible_to_make_fx=True)
def matmul_rhs_group_quant_placeholder(*args, **kwargs):
-    return torch.ops.brevitas.matmul_rhs_group_quant(*args, **kwargs)
+    return torch.ops.quant.matmul_rhs_group_quant(*args, **kwargs)


class LinearWeightBlockQuantHandlerFwd(LinearWeightBlockQuantHandler):
@@ -261,9 +261,7 @@ def transform_fx(fx_g):

transform_fx(fx_g)
replace_call_fn_target(
-        fx_g,
-        src=matmul_rhs_group_quant_placeholder,
-        target=torch.ops.brevitas.matmul_rhs_group_quant)
+        fx_g, src=matmul_rhs_group_quant_placeholder, target=torch.ops.quant.matmul_rhs_group_quant)

fx_g.recompile()
removed_none_indexes = _remove_nones(fx_g)
@@ -319,7 +317,7 @@ def compile_to_vmfb(inputs, layers, export_context_manager, export_class, is_fir
module = torch_mlir.compile(
ts_g, (hidden_states_placeholder, inputs[1], inputs[2]),
output_type="torch",
-        backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
+        backend_legal_ops=["quant.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False)
@@ -342,7 +340,7 @@ def compile_to_vmfb(inputs, layers, export_context_manager, export_class, is_fir
pkv0_placeholder,
pkv1_placeholder),
output_type="torch",
-        backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
+        backend_legal_ops=["quant.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False)
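
For reference, the placeholder-then-replace pattern this file relies on can be sketched with plain torch.fx; the project's wrap(visible_to_make_fx=True) and replace_call_fn_target helpers are assumed to do the equivalent for make_fx, and the op registration from mlir_custom_mm.py is assumed to be imported:

import torch
import torch.fx

def matmul_rhs_group_quant_placeholder(*args, **kwargs):
    return torch.ops.quant.matmul_rhs_group_quant(*args, **kwargs)

# Keep the placeholder opaque during tracing instead of tracing into its body.
torch.fx.wrap("matmul_rhs_group_quant_placeholder")

class QuantLinear(torch.nn.Module):
    def forward(self, lhs, rhs, rhs_scale, rhs_zero_point):
        return matmul_rhs_group_quant_placeholder(
            lhs, rhs, rhs_scale, rhs_zero_point,
            rhs_bit_width=8, rhs_group_size=128)

fx_g = torch.fx.symbolic_trace(QuantLinear())

# Post-tracing rewrite: retarget placeholder calls to the real custom op.
for node in fx_g.graph.nodes:
    if node.op == "call_function" and node.target is matmul_rhs_group_quant_placeholder:
        node.target = torch.ops.quant.matmul_rhs_group_quant
fx_g.recompile()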
8 changes: 4 additions & 4 deletions src/brevitas_examples/llm/test_linear_mlir_export.py
@@ -18,11 +18,11 @@

# Due to a tracing issue this annotation needs to be
# in the same module (== file) from which make_fx is called
-# We also can't directly annotate torch.ops.brevitas.matmul_rhs_group_quant
+# We also can't directly annotate torch.ops.quant.matmul_rhs_group_quant
# and so we trace a placeholder first and then replace it post tracing
@wrap(visible_to_make_fx=True)
def matmul_rhs_group_quant_placeholder(*args, **kwargs):
-    return torch.ops.brevitas.matmul_rhs_group_quant(*args, **kwargs)
+    return torch.ops.quant.matmul_rhs_group_quant(*args, **kwargs)


class LinearWeightBlockQuantHandlerFwd(LinearWeightBlockQuantHandler):
@@ -84,7 +84,7 @@ def quantize_and_export(args):
replace_call_fn_target(
traced_model,
src=matmul_rhs_group_quant_placeholder,
-        target=torch.ops.brevitas.matmul_rhs_group_quant)
+        target=torch.ops.quant.matmul_rhs_group_quant)

# print the output graph
print(traced_model.graph)
@@ -93,7 +93,7 @@ def quantize_and_export(args):
traced_model,
torch.randn(2, 128),
output_type="torch",
-        backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
+        backend_legal_ops=["quant.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=True,
verbose=False)
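
A hypothetical smoke test for the rename, not part of the commit: once the op is marked backend-legal, the new qualified name should survive into torch-mlir's textual output. Here module stands for the result of the torch_mlir.compile call above, and the substring check assumes the qualified op name is kept verbatim in the emitted IR:

ir_text = str(module)  # torch-mlir compile results print as textual MLIR
assert "quant.matmul_rhs_group_quant" in ir_text
assert "brevitas.matmul_rhs_group_quant" not in ir_text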