Replace get_named_module with nn.Module.get_submodule #3821

Merged
1 commit merged on Feb 19, 2025
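This change removes AIMET's internal `get_named_module` helper and updates every call site to use PyTorch's built-in `nn.Module.get_submodule`, which resolves the same dotted module paths. The sketch below is illustrative only (the toy model and names are made up and are not part of the diff), but it shows why the built-in is a drop-in replacement for module lookups:

import functools
from torch import nn

def get_named_module(model, name):
    """The removed helper: resolve a dotted name by chained getattr."""
    return functools.reduce(getattr, name.split("."), model)

# Toy model for illustration only.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Sequential(nn.ReLU(), nn.Linear(8, 4)))

# Both calls resolve the dotted path "1.1" to the same Linear module.
assert get_named_module(model, "1.1") is model.get_submodule("1.1")

# get_submodule additionally validates the path: a missing or non-module
# intermediate raises an AttributeError naming the offending attribute.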
@@ -42,7 +42,7 @@
 
 # Import AIMET specific modules
 from aimet_common.utils import AimetLogger
-from aimet_torch.utils import CachedDataset, ModuleData, get_named_module, cache_intermediate_datasets,\
+from aimet_torch.utils import CachedDataset, ModuleData, cache_intermediate_datasets,\
     change_tensor_device_placement, in_eval_mode, save_to_cache, get_ordered_list_of_modules
 from aimet_torch._base.quantsim import _QuantizationSimModelInterface, _QuantizedModuleProtocol
 
@@ -65,8 +65,8 @@ def create_modulelist_for_group_modules(model: torch.nn.Module, sim: _Quantizati
         fp_modulelist = torch.nn.ModuleList()
         quant_modulelist = torch.nn.ModuleList()
         for name in modules:
-            fp_modulelist.append(get_named_module(model, name))
-            quant_modulelist.append(get_named_module(sim.model, name))
+            fp_modulelist.append(model.get_submodule(name))
+            quant_modulelist.append(sim.model.get_submodule(name))
         sub_fp_models.append(fp_modulelist)
         sub_sim_models.append(quant_modulelist)
 
@@ -282,7 +282,7 @@ def create_cached_block_schedule_list(model: torch.nn.Module, dummy_input, block
     caching_modules = {module: {'block': None, 'name': name} for name, module in modules}
 
     for name in block_names:
-        parent_module: torch.nn.Module = get_named_module(model, name)
+        parent_module: torch.nn.Module = model.get_submodule(name)
         for _, module in parent_module.named_modules():
             if module in caching_modules:
                 module_name = caching_modules[module]['name']
@@ -57,7 +57,6 @@
 from aimet_torch._base.adaround.adaround_loss import AdaroundHyperParameters
 from aimet_torch._base.adaround.activation_sampler import create_modulelist_for_group_modules, get_block_inputs, \
     get_block_outputs, create_cached_block_schedule_list
-from aimet_torch.utils import get_named_module
 
 logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)
 
@@ -315,7 +314,7 @@ def fwd_mod_ls(mod_ls, *args, **kwargs):
                                                  params.forward_fn, cached_dataset)
         else:
             block_name, fp_block = block_cfg
-            quant_sim_block: torch.nn.Module = get_named_module(quant_sim.model, block_name)
+            quant_sim_block: torch.nn.Module = quant_sim.model.get_submodule(block_name)
 
             cached_fp_dataset, cached_quant_dataset = get_block_inputs(model, quant_sim,
                                                                        block_name,
@@ -442,15 +441,15 @@ def _replace_quantization_layer(cls, quant_sim_model: torch.nn.Module, module_na
         Replace the quantized module's weight tensor quantizer with the Adaround tensor quantizer
         :param quant_module: quant module
         """
-        quant_module = utils.get_named_module(quant_sim_model, module_name)
+        quant_module = quant_sim_model.get_submodule(module_name)
         cls._validate_quant_module_for_adaround(quant_module)
         adaround_layer = cls._get_adaround_wrapper(quant_module)
 
         # We need to look for the container to patch for modules inside submodule
         upper_module = quant_sim_model
         upper_module_name, _, target_module_name = module_name.rpartition('.')
         if upper_module_name:
-            upper_module = utils.get_named_module(quant_sim_model, upper_module_name)
+            upper_module = quant_sim_model.get_submodule(upper_module_name)
 
         # Temporarily replace quant module with wrapped module
         with cls._patch_module_layer(upper_module, target_module_name, adaround_layer):
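A note on the container lookup in `_replace_quantization_layer` above: `module_name.rpartition('.')` splits a dotted name into its parent path and leaf name, and yields an empty parent when there is no dot, so `upper_module` stays the root model for top-level modules and is otherwise re-resolved with `get_submodule`. The snippet below is illustrative only (hypothetical names, not part of the diff):

parent, _, leaf = "encoder.layers.0.attn".rpartition('.')
assert (parent, leaf) == ("encoder.layers.0", "attn")   # patch "attn" on its parent container

parent, _, leaf = "conv1".rpartition('.')
assert (parent, leaf) == ("", "conv1")                   # no dot: patch directly on the model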
@@ -692,7 +692,7 @@ def _get_qmodule(self, op: Op) -> Optional[_QuantizedModuleProtocol]:
             return None
 
         module_name = '.'.join(module_names)
-        return utils.get_named_module(self.model, module_name)
+        return self.model.get_submodule(module_name)
 
     @torch.no_grad()
     def export(self, path: str, filename_prefix: str, dummy_input: Union[torch.Tensor, Tuple], # pylint: disable=arguments-differ
@@ -314,7 +314,7 @@ def disable_output_quantizers(qmodule):
         module_name, module = ordered_conv_linear_nodes[0]
         if module not in layers_to_ignore:
             logger.info('Correcting layer %s using Analytical Bias Correction', module_name)
-            quantize_layer = utils.get_named_module(model, module_name)
+            quantize_layer = model.get_submodule(module_name)
             call_analytical_correct_bias(quantize_layer, None, None)
             logger.info('Corrected bias for the layer')
             ordered_conv_linear_nodes.pop(0)
@@ -325,8 +325,8 @@ def disable_output_quantizers(qmodule):
             continue
         else:
             # Analytical Bias Correction is only done for Conv layers
-            reference_layer = utils.get_named_module(model_copy, module_name)
-            quantize_layer = utils.get_named_module(model, module_name)
+            reference_layer = model_copy.get_submodule(module_name)
+            quantize_layer = model.get_submodule(module_name)
 
             if module in conv_bn_dict.keys():
                 bn_layer_info = conv_bn_dict[module]
@@ -45,7 +45,6 @@
     StopForwardException,
     change_tensor_device_placement,
     get_device,
-    get_named_module,
     in_eval_mode,
     get_module_to_name_dict,
 )
@@ -91,7 +90,7 @@ def _hook_to_collect_inp_data(module, inp, _):
         handles = []
         for module_name in self._module_names:
             handles.append(
-                get_named_module(self._model, module_name).register_forward_hook(
+                self._model.get_submodule(module_name).register_forward_hook(
                     _hook_to_collect_inp_data
                 )
             )
@@ -54,7 +54,6 @@
 from aimet_torch.gptvq.utils import get_module_name_to_hessian_tensor
 from aimet_torch._base.quantsim import _QuantizedModuleProtocol
 from aimet_torch.save_utils import SaveUtils
-from aimet_torch.utils import get_named_module
 from aimet_torch.v2.nn import BaseQuantizationMixin
 from aimet_torch.v2.quantization.affine.quantizer import QuantizeDequantize
 from aimet_torch.v2.quantization.tensor import QuantizedTensorBase
@@ -203,7 +202,7 @@ def _replace_param_quantizers(sim: QuantizationSimModel, rows_per_block: int, mo
     :param module_name_set: Module name set containing candidates of GPTVQ optimization
     """
     for module_name in module_name_set:
-        module = get_named_module(sim.model, module_name)
+        module = sim.model.get_submodule(module_name)
         assert isinstance(module, BaseQuantizationMixin)
 
         param_quantizer = module.param_quantizers["weight"]
@@ -353,7 +352,7 @@ def _get_applicable_name_to_module_dict(
     """
     name_to_quant_module = collections.OrderedDict()
     for name in module_names:
-        quant_module = get_named_module(sim.model, name)
+        quant_module = sim.model.get_submodule(name)
         if name not in module_names_to_exclude:
             name_to_quant_module[name] = quant_module
     return name_to_quant_module
@@ -45,7 +45,6 @@
 import aimet_torch.v2.quantization as Q
 from aimet_torch.gptvq.activation_sampler import ActivationSampler
 from aimet_torch.gptvq.defs import GPTVQParameters
-from aimet_torch.utils import get_named_module
 from aimet_torch.v2.nn import BaseQuantizationMixin
 from aimet_torch.v2.quantsim import QuantizationSimModel
 
@@ -386,7 +385,7 @@ def get_module_name_to_hessian_tensor(gptvq_params: GPTVQParameters,
     """
     name_to_hessian = {}
     for module_name in module_names:
-        quant_module = get_named_module(sim.model, module_name)
+        quant_module = sim.model.get_submodule(module_name)
         _, num_cols = get_2d_tensor_shape(quant_module)
         device = quant_module.weight.device
         name_to_hessian[module_name] = torch.zeros((num_cols, num_cols), device=device)
@@ -402,7 +401,7 @@ def get_module_name_to_hessian_tensor(gptvq_params: GPTVQParameters,
                 inp_data = inp_data.unsqueeze(0)
             curr_batch_size = inp_data.shape[0]
 
-            quant_module = get_named_module(sim.model, name)
+            quant_module = sim.model.get_submodule(name)
             update_hessian(quant_module, inp_data, n_samples, curr_batch_size, name_to_hessian[name])
 
             n_samples += curr_batch_size
11 changes: 0 additions & 11 deletions TrainingExtensions/torch/src/python/aimet_torch/utils.py
@@ -44,7 +44,6 @@
 import os
 import pickle
 import sys
-import functools
 import logging
 import warnings
 
@@ -858,16 +857,6 @@ def save_to_cache(tensor, dir_path, idx):
         pickle.dump(tensor, cache)
 
 
-def get_named_module(model, name):
-    """
-    Given the name, get the target module in the model
-    :param model: Model that contains the target module
-    :param name: Name of the target module
-    :return:
-    """
-    return functools.reduce(getattr, name.split("."), model)
-
-
 def cache_intermediate_datasets(
         cached_dataset, cache_on_cpu, model, module_name, forward_fn, path=None, incl_kwargs: bool = False):
     """
@@ -42,7 +42,6 @@
 from bokeh import plotting
 from bokeh.layouts import column
 from aimet_torch import plotting_utils
-from aimet_torch.utils import get_named_module
 
 
 def visualize_changes_after_optimization(
@@ -67,7 +66,7 @@ def visualize_changes_after_optimization(
     if selected_layers:
         for name, module in new_model.named_modules():
             if name in selected_layers and hasattr(module, "weight"):
-                old_model_module = get_named_module(old_model, name)
+                old_model_module = old_model.get_submodule(name)
                 new_model_module = module
                 subplots.append(
                     plotting_utils.visualize_changes_after_optimization_single_layer(
@@ -79,7 +78,7 @@ def visualize_changes_after_optimization(
         for name, module in new_model.named_modules():
             if hasattr(module, "weight") and\
                     isinstance(module, (torch.nn.modules.conv.Conv2d, torch.nn.modules.linear.Linear)):
-                old_model_module = get_named_module(old_model, name)
+                old_model_module = old_model.get_submodule(name)
                 new_model_module = module
                 subplots.append(
                     plotting_utils.visualize_changes_after_optimization_single_layer(
4 changes: 2 additions & 2 deletions TrainingExtensions/torch/test/python/test_channel_pruning.py
@@ -56,7 +56,7 @@
 from aimet_torch.channel_pruning.weight_reconstruction import WeightReconstructor
 from aimet_torch.channel_pruning.channel_pruner import InputChannelPruner
 from .models.mnist_torch_model import Net as mnist_model
-from aimet_torch.utils import create_fake_data_loader, get_layer_name, get_named_module,\
+from aimet_torch.utils import create_fake_data_loader, get_layer_name,\
     create_rand_tensors_given_shapes, get_device
 from aimet_torch.layer_database import Layer, LayerDatabase
 
@@ -507,7 +507,7 @@ def test_data_sub_sampling_and_reconstruction_without_bias(self):
                                                            num_reconstruction_samples=
                                                            num_reconstruction_samples)
 
-        conv_layer = get_named_module(comp_model, conv2_pr_layer_name)
+        conv_layer = comp_model.get_submodule(conv2_pr_layer_name)
 
         assert conv_layer == comp_model.conv2
         # original weight before reconstruction
@@ -56,7 +56,6 @@
 from aimet_torch.v1.quantsim import OnnxExportApiArgs
 from aimet_torch.v1.qc_quantize_op import QcQuantizeWrapper
 from aimet_common.defs import QuantScheme
-from aimet_torch.utils import get_named_module
 
 from ..models_.models_to_test import (
     SimpleConditional,
@@ -414,7 +413,7 @@ def test_json_interchangeable(self):
             if not isinstance(module, QcQuantizeWrapper):
                 continue
             wrapper = module
-            qmodule = get_named_module(sim_v2.model, name)
+            qmodule = sim_v2.model.get_submodule(name)
 
             assert wrapper.input_quantizers[0].enabled == (qmodule.input_quantizers[0] is not None)
             assert wrapper.output_quantizers[0].enabled == (qmodule.output_quantizers[0] is not None)
@@ -52,7 +52,7 @@
 from aimet_torch.v2.quantization.affine.backends import torch_builtins
 from aimet_torch.v2.quantsim import QuantizationSimModel
 from aimet_torch.v2.nn import BaseQuantizationMixin
-from aimet_torch.utils import get_named_module, is_leaf_module
+from aimet_torch.utils import is_leaf_module
 
 
 class STE(torch.autograd.Function):
@@ -180,10 +180,8 @@ def test_grad_correctness(self, model_cls, input_shape, quant_scheme, config_pat
         quantized_modules = get_quantized_modules(aimetgrad_qsim.model)
         assert len(quantized_modules) > 0
 
-        for param_name, _ in aimetgrad_qsim.model.named_parameters():
-            aimetgrad_param = get_named_module(aimetgrad_qsim.model, param_name)
-            autograd_param = get_named_module(autograd_qsim.model, param_name)
-
+        for aimetgrad_param, autograd_param in zip(aimetgrad_qsim.model.parameters(),
+                                                   autograd_qsim.model.parameters()):
             if not aimetgrad_param.requires_grad and not autograd_param.requires_grad:
                 continue
 
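One behavioral difference worth noting for the test change above: the removed helper resolved arbitrary attribute chains, including dotted parameter paths such as those returned by `named_parameters()`, whereas `nn.Module.get_submodule` only resolves submodules; parameter paths go through `nn.Module.get_parameter` instead, which is why the test now zips `parameters()` from both models directly. The snippet below is illustrative only (hypothetical layer name, not part of the diff):

from torch import nn

# Hypothetical toy model, for illustration only.
model = nn.Sequential()
model.add_module("conv1", nn.Conv2d(3, 8, 3))

conv = model.get_submodule("conv1")             # OK: "conv1" names a module
weight = model.get_parameter("conv1.weight")    # parameters are resolved with get_parameter
# model.get_submodule("conv1.weight")           # would raise AttributeError: "weight" is not an nn.Module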