Commit d9667b1

up

metascroy committed Jan 29, 2025
1 parent 8bf91e2
Showing 3 changed files with 5 additions and 12 deletions.
torchao/experimental/q_dq_layout.py (3 changes: 0 additions & 3 deletions)

@@ -24,9 +24,6 @@
 logger.addHandler(handler)
 
-
-import torch
-
 from torchao.dtypes.affine_quantized_tensor import register_layout
 from torchao.dtypes.utils import PlainLayout
 
 
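The only change in this file is dropping a module-level `import torch` that nothing in the file used. As an illustrative sketch of how pyflakes-style linters (flake8/ruff rule F401) catch this, the core idea is to collect names bound by imports, collect names actually referenced, and flag the difference. This is the concept only, not how those tools are implemented:

import ast

# Toy source with an unused top-level import, mirroring the removed line.
src = "import torch\nfrom math import pi\nprint(pi)\n"
tree = ast.parse(src)

# Names bound by plain `import ...` statements.
imported = {
    alias.asname or alias.name.split(".")[0]
    for node in ast.walk(tree)
    if isinstance(node, ast.Import)
    for alias in node.names
}
# Names actually referenced anywhere in the module.
used = {node.id for node in ast.walk(tree) if isinstance(node, ast.Name)}

print(imported - used)  # {'torch'} -> reported as an unused import
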
torchao/experimental/quant_api.py (12 changes: 5 additions & 7 deletions)

@@ -492,10 +492,10 @@ def quantize(self, model: nn.Module) -> nn.Module:
     to_linear_activation_quantized,
 )
 from torchao.quantization.quant_api import (
-    _get_linear_subclass_inserter,
     MappingType,
-    to_affine_quantized_intx,
     ZeroPointDomain,
+    _get_linear_subclass_inserter,
+    to_affine_quantized_intx,
 )
 from torchao.quantization.utils import _get_per_token_block_size
 
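This hunk only reorders the imported names into plain ASCII string order, where uppercase names sort before `_`-prefixed and lowercase ones; that is the same order Python's sorted() produces, and the order import-sorting tools can enforce. A quick check:

names = [
    "_get_linear_subclass_inserter",
    "MappingType",
    "to_affine_quantized_intx",
    "ZeroPointDomain",
]
# ASCII order: 'M' (77) and 'Z' (90) sort before '_' (95) and 't' (116).
print(sorted(names))
# ['MappingType', 'ZeroPointDomain', '_get_linear_subclass_inserter', 'to_affine_quantized_intx']
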
@@ -568,18 +568,16 @@ def apply(weight):

         layout = layout_arg
         if isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout):
-            assert weight.device == torch.device(
-                "cpu"
+            assert (
+                weight.device == torch.device("cpu")
             ), "PackedLinearInt8DynamicActivationIntxWeightLayout requires weight.device=CPU"
             assert (
                 weight.dtype == torch.float32
             ), "PackedLinearInt8DynamicActivationIntxWeightLayout requires weight.dtype=float32"
             assert (
                 act_mapping_type == MappingType.ASYMMETRIC
             ), "PackedLinearInt8DynamicActivationIntxWeightLayout requires act_mapping_type=MappingType.ASYMMETRIC"
-            assert (
-                not layout.has_params_set()
-            ), "PackedLinearInt8DynamicActivationIntxWeightLayout params should not already be set"
+            assert not layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params should not already be set"
             layout = PackedLinearInt8DynamicActivationIntxWeightLayout(
                 bit_width=bit_width,
                 group_size=group_size,
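The assert changes in this hunk are pure restyling: wrapping a condition in parentheses across several lines does not change what the statement does. A minimal sketch demonstrating the equivalence:

import dis

def multi_line(x):
    assert (
        x == 1
    ), "x must be 1"

def single_line(x):
    assert x == 1, "x must be 1"

# Both spellings parse to the same AST, so the compiled opcodes match.
ops = lambda f: [ins.opname for ins in dis.get_instructions(f)]
print(ops(multi_line) == ops(single_line))  # True
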
torchao/experimental/tests/test_q_dq_layout.py (2 changes: 0 additions & 2 deletions)

@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import copy
-import tempfile
 import unittest
 
 import torch
@@ -116,7 +115,6 @@ def test_export(self):
         )
         eager_results = model(activations)
 
-        unwrapped_model = copy.deepcopy(model)
         unwrap_tensor_subclass(model)
 
         print("Exporting quantized model")