diff --git a/TrainingExtensions/torch/src/python/aimet_torch/_base/adaround/activation_sampler.py b/TrainingExtensions/torch/src/python/aimet_torch/_base/adaround/activation_sampler.py
index bc456229ca1..f227747d970 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/_base/adaround/activation_sampler.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/_base/adaround/activation_sampler.py
@@ -42,7 +42,7 @@
 
 # Import AIMET specific modules
 from aimet_common.utils import AimetLogger
-from aimet_torch.utils import CachedDataset, ModuleData, get_named_module, cache_intermediate_datasets,\
+from aimet_torch.utils import CachedDataset, ModuleData, cache_intermediate_datasets,\
     change_tensor_device_placement, in_eval_mode, save_to_cache, get_ordered_list_of_modules
 from aimet_torch._base.quantsim import _QuantizationSimModelInterface, _QuantizedModuleProtocol
 
@@ -65,8 +65,8 @@ def create_modulelist_for_group_modules(model: torch.nn.Module, sim: _Quantizati
         fp_modulelist = torch.nn.ModuleList()
         quant_modulelist = torch.nn.ModuleList()
         for name in modules:
-            fp_modulelist.append(get_named_module(model, name))
-            quant_modulelist.append(get_named_module(sim.model, name))
+            fp_modulelist.append(model.get_submodule(name))
+            quant_modulelist.append(sim.model.get_submodule(name))
         sub_fp_models.append(fp_modulelist)
         sub_sim_models.append(quant_modulelist)
 
@@ -282,7 +282,7 @@ def create_cached_block_schedule_list(model: torch.nn.Module, dummy_input, block
     caching_modules = {module: {'block': None, 'name': name} for name, module in modules}
 
     for name in block_names:
-        parent_module: torch.nn.Module = get_named_module(model, name)
+        parent_module: torch.nn.Module = model.get_submodule(name)
         for _, module in parent_module.named_modules():
             if module in caching_modules:
                 module_name = caching_modules[module]['name']
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/_base/adaround/adaround_weight.py b/TrainingExtensions/torch/src/python/aimet_torch/_base/adaround/adaround_weight.py
index 24d813e20d5..11da9f8b01d 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/_base/adaround/adaround_weight.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/_base/adaround/adaround_weight.py
@@ -57,7 +57,6 @@
 from aimet_torch._base.adaround.adaround_loss import AdaroundHyperParameters
 from aimet_torch._base.adaround.activation_sampler import create_modulelist_for_group_modules, get_block_inputs, \
     get_block_outputs, create_cached_block_schedule_list
-from aimet_torch.utils import get_named_module
 
 logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)
 
@@ -315,7 +314,7 @@ def fwd_mod_ls(mod_ls, *args, **kwargs):
                                           params.forward_fn, cached_dataset)
             else:
                 block_name, fp_block = block_cfg
-                quant_sim_block: torch.nn.Module = get_named_module(quant_sim.model, block_name)
+                quant_sim_block: torch.nn.Module = quant_sim.model.get_submodule(block_name)
 
                 cached_fp_dataset, cached_quant_dataset = get_block_inputs(model, quant_sim,
                                                                            block_name,
@@ -442,7 +441,7 @@ def _replace_quantization_layer(cls, quant_sim_model: torch.nn.Module, module_na
         Replace the quantized module's weight tensor quantizer with the Adaround tensor quantizer
         :param quant_module: quant module
         """
-        quant_module = utils.get_named_module(quant_sim_model, module_name)
+        quant_module = quant_sim_model.get_submodule(module_name)
         cls._validate_quant_module_for_adaround(quant_module)
         adaround_layer = cls._get_adaround_wrapper(quant_module)
 
@@ -450,7 +449,7 @@ def _replace_quantization_layer(cls, quant_sim_model: torch.nn.Module, module_na
         upper_module = quant_sim_model
         upper_module_name, _, target_module_name = module_name.rpartition('.')
         if upper_module_name:
-            upper_module = utils.get_named_module(quant_sim_model, upper_module_name)
+            upper_module = quant_sim_model.get_submodule(upper_module_name)
 
         # Temporarily replace quant module with wrapped module
         with cls._patch_module_layer(upper_module, target_module_name, adaround_layer):
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/_base/quantsim.py b/TrainingExtensions/torch/src/python/aimet_torch/_base/quantsim.py
index 2603f31fae9..1f8f33efd1a 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/_base/quantsim.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/_base/quantsim.py
@@ -692,7 +692,7 @@ def _get_qmodule(self, op: Op) -> Optional[_QuantizedModuleProtocol]:
             return None
 
         module_name = '.'.join(module_names)
-        return utils.get_named_module(self.model, module_name)
+        return self.model.get_submodule(module_name)
 
     @torch.no_grad()
     def export(self, path: str, filename_prefix: str, dummy_input: Union[torch.Tensor, Tuple], # pylint: disable=arguments-differ
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/bias_correction.py b/TrainingExtensions/torch/src/python/aimet_torch/bias_correction.py
index edfa9fe4e54..defe799b741 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/bias_correction.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/bias_correction.py
@@ -314,7 +314,7 @@ def disable_output_quantizers(qmodule):
         module_name, module = ordered_conv_linear_nodes[0]
         if module not in layers_to_ignore:
             logger.info('Correcting layer %s using Analytical Bias Correction', module_name)
-            quantize_layer = utils.get_named_module(model, module_name)
+            quantize_layer = model.get_submodule(module_name)
             call_analytical_correct_bias(quantize_layer, None, None)
             logger.info('Corrected bias for the layer')
             ordered_conv_linear_nodes.pop(0)
@@ -325,8 +325,8 @@ def disable_output_quantizers(qmodule):
                 continue
             else:
                 # Analytical Bias Correction is only done for Conv layers
-                reference_layer = utils.get_named_module(model_copy, module_name)
-                quantize_layer = utils.get_named_module(model, module_name)
+                reference_layer = model_copy.get_submodule(module_name)
+                quantize_layer = model.get_submodule(module_name)
 
                 if module in conv_bn_dict.keys():
                     bn_layer_info = conv_bn_dict[module]
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/gptvq/activation_sampler.py b/TrainingExtensions/torch/src/python/aimet_torch/gptvq/activation_sampler.py
index e3e0b226de9..00cae79f696 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/gptvq/activation_sampler.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/gptvq/activation_sampler.py
@@ -45,7 +45,6 @@
     StopForwardException,
     change_tensor_device_placement,
     get_device,
-    get_named_module,
     in_eval_mode,
     get_module_to_name_dict,
 )
@@ -91,7 +90,7 @@ def _hook_to_collect_inp_data(module, inp, _):
         handles = []
         for module_name in self._module_names:
             handles.append(
-                get_named_module(self._model, module_name).register_forward_hook(
+                self._model.get_submodule(module_name).register_forward_hook(
                     _hook_to_collect_inp_data
                 )
             )
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/gptvq/gptvq_weight.py b/TrainingExtensions/torch/src/python/aimet_torch/gptvq/gptvq_weight.py
index 39916a178aa..6b4056acf8a 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/gptvq/gptvq_weight.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/gptvq/gptvq_weight.py
@@ -54,7 +54,6 @@
 from aimet_torch.gptvq.utils import get_module_name_to_hessian_tensor
 from aimet_torch._base.quantsim import _QuantizedModuleProtocol
 from aimet_torch.save_utils import SaveUtils
-from aimet_torch.utils import get_named_module
 from aimet_torch.v2.nn import BaseQuantizationMixin
 from aimet_torch.v2.quantization.affine.quantizer import QuantizeDequantize
 from aimet_torch.v2.quantization.tensor import QuantizedTensorBase
@@ -203,7 +202,7 @@ def _replace_param_quantizers(sim: QuantizationSimModel, rows_per_block: int, mo
     :param module_name_set: Module name set containing candidates of GPTVQ optimization
     """
     for module_name in module_name_set:
-        module = get_named_module(sim.model, module_name)
+        module = sim.model.get_submodule(module_name)
         assert isinstance(module, BaseQuantizationMixin)
 
         param_quantizer = module.param_quantizers["weight"]
@@ -353,7 +352,7 @@ def _get_applicable_name_to_module_dict(
     """
     name_to_quant_module = collections.OrderedDict()
     for name in module_names:
-        quant_module = get_named_module(sim.model, name)
+        quant_module = sim.model.get_submodule(name)
         if name not in module_names_to_exclude:
             name_to_quant_module[name] = quant_module
     return name_to_quant_module
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/gptvq/utils.py b/TrainingExtensions/torch/src/python/aimet_torch/gptvq/utils.py
index ea4cdbe9537..e9d5ce6cc45 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/gptvq/utils.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/gptvq/utils.py
@@ -45,7 +45,6 @@
 import aimet_torch.v2.quantization as Q
 from aimet_torch.gptvq.activation_sampler import ActivationSampler
 from aimet_torch.gptvq.defs import GPTVQParameters
-from aimet_torch.utils import get_named_module
 from aimet_torch.v2.nn import BaseQuantizationMixin
 from aimet_torch.v2.quantsim import QuantizationSimModel
 
@@ -386,7 +385,7 @@ def get_module_name_to_hessian_tensor(gptvq_params: GPTVQParameters,
     """
     name_to_hessian = {}
     for module_name in module_names:
-        quant_module = get_named_module(sim.model, module_name)
+        quant_module = sim.model.get_submodule(module_name)
         _, num_cols = get_2d_tensor_shape(quant_module)
         device = quant_module.weight.device
         name_to_hessian[module_name] = torch.zeros((num_cols, num_cols), device=device)
@@ -402,7 +401,7 @@ def get_module_name_to_hessian_tensor(gptvq_params: GPTVQParameters,
                 inp_data = inp_data.unsqueeze(0)
             curr_batch_size = inp_data.shape[0]
 
-            quant_module = get_named_module(sim.model, name)
+            quant_module = sim.model.get_submodule(name)
             update_hessian(quant_module, inp_data, n_samples, curr_batch_size, name_to_hessian[name])
             n_samples += curr_batch_size
 
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/utils.py b/TrainingExtensions/torch/src/python/aimet_torch/utils.py
index 6f11cb44447..e53d55f1284 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/utils.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/utils.py
@@ -44,7 +44,6 @@
 import os
 import pickle
 import sys
-import functools
 import logging
 import warnings
 
@@ -858,16 +857,6 @@ def save_to_cache(tensor, dir_path, idx):
         pickle.dump(tensor, cache)
 
 
-def get_named_module(model, name):
-    """
-    Given the name, get the target module in the model
-    :param model: Model that contains the target module
-    :param name: Name of the target module
-    :return:
-    """
-    return functools.reduce(getattr, name.split("."), model)
-
-
 def cache_intermediate_datasets(
         cached_dataset, cache_on_cpu, model, module_name, forward_fn, path=None, incl_kwargs: bool = False):
     """
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/visualize_model.py b/TrainingExtensions/torch/src/python/aimet_torch/visualize_model.py
index 0718f9b48a9..a1629bc1256 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/visualize_model.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/visualize_model.py
@@ -42,7 +42,6 @@
 from bokeh import plotting
 from bokeh.layouts import column
 from aimet_torch import plotting_utils
-from aimet_torch.utils import get_named_module
 
 
 def visualize_changes_after_optimization(
@@ -67,7 +66,7 @@ def visualize_changes_after_optimization(
     if selected_layers:
         for name, module in new_model.named_modules():
             if name in selected_layers and hasattr(module, "weight"):
-                old_model_module = get_named_module(old_model, name)
+                old_model_module = old_model.get_submodule(name)
                 new_model_module = module
                 subplots.append(
                     plotting_utils.visualize_changes_after_optimization_single_layer(
@@ -79,7 +78,7 @@ def visualize_changes_after_optimization(
         for name, module in new_model.named_modules():
             if hasattr(module, "weight") and\
                     isinstance(module, (torch.nn.modules.conv.Conv2d, torch.nn.modules.linear.Linear)):
-                old_model_module = get_named_module(old_model, name)
+                old_model_module = old_model.get_submodule(name)
                 new_model_module = module
                 subplots.append(
                     plotting_utils.visualize_changes_after_optimization_single_layer(
diff --git a/TrainingExtensions/torch/test/python/test_channel_pruning.py b/TrainingExtensions/torch/test/python/test_channel_pruning.py
index 0257f3085d8..8e6af92d577 100644
--- a/TrainingExtensions/torch/test/python/test_channel_pruning.py
+++ b/TrainingExtensions/torch/test/python/test_channel_pruning.py
@@ -56,7 +56,7 @@
 from aimet_torch.channel_pruning.weight_reconstruction import WeightReconstructor
 from aimet_torch.channel_pruning.channel_pruner import InputChannelPruner
 from .models.mnist_torch_model import Net as mnist_model
-from aimet_torch.utils import create_fake_data_loader, get_layer_name, get_named_module,\
+from aimet_torch.utils import create_fake_data_loader, get_layer_name,\
     create_rand_tensors_given_shapes, get_device
 from aimet_torch.layer_database import Layer, LayerDatabase
 
@@ -507,7 +507,7 @@ def test_data_sub_sampling_and_reconstruction_without_bias(self):
                                                             num_reconstruction_samples= num_reconstruction_samples)
 
 
-        conv_layer = get_named_module(comp_model, conv2_pr_layer_name)
+        conv_layer = comp_model.get_submodule(conv2_pr_layer_name)
         assert conv_layer == comp_model.conv2
 
         # original weight before reconstruction
diff --git a/TrainingExtensions/torch/test/python/v2/ab_test/test_quantsim_export.py b/TrainingExtensions/torch/test/python/v2/ab_test/test_quantsim_export.py
index c3a29c8d45b..92874c20011 100644
--- a/TrainingExtensions/torch/test/python/v2/ab_test/test_quantsim_export.py
+++ b/TrainingExtensions/torch/test/python/v2/ab_test/test_quantsim_export.py
@@ -56,7 +56,6 @@
 from aimet_torch.v1.quantsim import OnnxExportApiArgs
 from aimet_torch.v1.qc_quantize_op import QcQuantizeWrapper
 from aimet_common.defs import QuantScheme
-from aimet_torch.utils import get_named_module
 
 from ..models_.models_to_test import (
     SimpleConditional,
@@ -414,7 +413,7 @@ def test_json_interchangeable(self):
             if not isinstance(module, QcQuantizeWrapper):
                 continue
             wrapper = module
-            qmodule = get_named_module(sim_v2.model, name)
+            qmodule = sim_v2.model.get_submodule(name)
             assert wrapper.input_quantizers[0].enabled == (qmodule.input_quantizers[0] is not None)
             assert wrapper.output_quantizers[0].enabled == (qmodule.output_quantizers[0] is not None)
 
diff --git a/TrainingExtensions/torch/test/python/v2/ab_test/test_quantsim_qat.py b/TrainingExtensions/torch/test/python/v2/ab_test/test_quantsim_qat.py
index b55dc27951a..f562d9ba580 100644
--- a/TrainingExtensions/torch/test/python/v2/ab_test/test_quantsim_qat.py
+++ b/TrainingExtensions/torch/test/python/v2/ab_test/test_quantsim_qat.py
@@ -52,7 +52,7 @@
 from aimet_torch.v2.quantization.affine.backends import torch_builtins
 from aimet_torch.v2.quantsim import QuantizationSimModel
 from aimet_torch.v2.nn import BaseQuantizationMixin
-from aimet_torch.utils import get_named_module, is_leaf_module
+from aimet_torch.utils import is_leaf_module
 
 
 class STE(torch.autograd.Function):
@@ -180,10 +180,8 @@ def test_grad_correctness(self, model_cls, input_shape, quant_scheme, config_pat
         quantized_modules = get_quantized_modules(aimetgrad_qsim.model)
         assert len(quantized_modules) > 0
 
-        for param_name, _ in aimetgrad_qsim.model.named_parameters():
-            aimetgrad_param = get_named_module(aimetgrad_qsim.model, param_name)
-            autograd_param = get_named_module(autograd_qsim.model, param_name)
-
+        for aimetgrad_param, autograd_param in zip(aimetgrad_qsim.model.parameters(),
+                                                   autograd_qsim.model.parameters()):
             if not aimetgrad_param.requires_grad and not autograd_param.requires_grad:
                 continue
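
Note: every call site in this patch swaps the removed aimet_torch.utils.get_named_module helper for PyTorch's built-in torch.nn.Module.get_submodule, which resolves the same dotted submodule paths. A minimal sketch of the equivalence; the toy model below is illustrative only and not taken from the patch:

    import functools
    import torch

    # Toy nested model; any nn.Module hierarchy behaves the same way.
    model = torch.nn.Sequential(
        torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()),
        torch.nn.Linear(8, 2),
    )

    def legacy_get_named_module(model, name):
        # Same body as the helper removed from aimet_torch/utils.py above.
        return functools.reduce(getattr, name.split("."), model)

    # Both lookups walk the dotted path and return the same module object.
    assert legacy_get_named_module(model, "0.1") is model.get_submodule("0.1")

get_submodule additionally validates the path, raising AttributeError with a descriptive message when an intermediate name is missing or does not refer to an nn.Module, whereas the old helper simply surfaced whatever getattr raised.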