Just some cosmetics #1291

Merged: 1 commit, Dec 10, 2024
model_compression_toolkit/core/common/fusion/graph_fuser.py (2 additions & 0 deletions)
@@ -26,6 +26,8 @@ class FusedLayerType:
     """
     def __init__(self):
         self.__name__ = 'FusedLayer'
+
+
 class GraphFuser:
 
     def create_fused_graph(self, graph: Graph) -> Dict[str, str]:
@@ -24,6 +24,7 @@
 
 SchedulerInfo = namedtuple('SchedulerInfo', [OPERATORS_SCHEDULING, MAX_CUT, CUTS, FUSED_NODES_MAPPING])
 
+
 def compute_graph_max_cut(memory_graph: MemoryGraph,
                           n_iter: int = 50,
                           astar_n_iter: int = 500,
@@ -48,9 +48,8 @@ def set_bit_widths(mixed_precision_enable: bool,
             node_name = node.name if not node.reuse else '_'.join(node.name.split('_')[:-2])
             if node_name in sorted_nodes_names:  # only configurable nodes are in this list
                 node_index_in_graph = sorted_nodes_names.index(node_name)
-                _set_node_final_qc(bit_widths_config,
+                _set_node_final_qc(bit_widths_config[node_index_in_graph],
                                    node,
-                                   node_index_in_graph,
                                    graph.fw_info)
             else:
                 if node.is_activation_quantization_enabled():
@@ -83,8 +82,7 @@ def set_bit_widths(mixed_precision_enable: bool,
 
 
 def _get_node_qc_by_bit_widths(node: BaseNode,
-                               bit_width_cfg: List[int],
-                               node_index_in_graph: int,
+                               node_bit_width_cfg: int,
                                fw_info) -> Any:
     """
     Get the node's quantization configuration that
@@ -93,8 +91,7 @@ def _get_node_qc_by_bit_widths(node: BaseNode,
 
     Args:
         node: Node to get its quantization configuration candidate.
-        bit_width_cfg: Configuration which determines the node's desired bit width.
-        node_index_in_graph: Index of the node in the bit_width_cfg.
+        node_bit_width_cfg: Configuration which determines the node's desired bit width.
         fw_info: Information relevant to a specific framework about how layers should be quantized.
 
     Returns:
@@ -104,24 +101,21 @@ def _get_node_qc_by_bit_widths(node: BaseNode,
     kernel_attr = fw_info.get_kernel_op_attributes(node.type)
 
     if node.is_activation_quantization_enabled():
-        bit_index_in_cfg = bit_width_cfg[node_index_in_graph]
-        qc = node.candidates_quantization_cfg[bit_index_in_cfg]
+        qc = node.candidates_quantization_cfg[node_bit_width_cfg]
 
         return qc
 
     elif kernel_attr is not None:
         if node.is_weights_quantization_enabled(kernel_attr[0]):
-            bit_index_in_cfg = bit_width_cfg[node_index_in_graph]
-            qc = node.candidates_quantization_cfg[bit_index_in_cfg]
+            qc = node.candidates_quantization_cfg[node_bit_width_cfg]
 
             return qc
 
     Logger.critical(f"Quantization configuration for node '{node.name}' not found in candidate configurations.")  # pragma: no cover
 
 
-def _set_node_final_qc(bit_width_cfg: List[int],
+def _set_node_final_qc(node_bit_width_cfg: int,
                        node: BaseNode,
-                       node_index_in_graph: int,
                        fw_info):
     """
     Get the node's quantization configuration that
@@ -130,15 +124,13 @@ def _set_node_final_qc(bit_width_cfg: List[int],
     If the node quantization config was not found, raise an exception.
 
     Args:
-        bit_width_cfg: Configuration which determines the node's desired bit width.
+        node_bit_width_cfg: Configuration which determines the node's desired bit width.
         node: Node to set its node quantization configuration.
-        node_index_in_graph: Index of the node in the bit_width_cfg.
         fw_info: Information relevant to a specific framework about how layers should be quantized.
 
     """
     node_qc = _get_node_qc_by_bit_widths(node,
-                                         bit_width_cfg,
-                                         node_index_in_graph,
+                                         node_bit_width_cfg,
                                          fw_info)
 
     if node_qc is None:
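The hunks above narrow the helper signatures from a full bit-width list plus a node index to the single candidate index each helper actually consumes. A minimal sketch of the resulting call pattern, using an illustrative stand-in class rather than MCT's BaseNode (NodeStub and the literal values below are not MCT code):

from typing import Any, List


class NodeStub:
    """Illustrative stand-in for BaseNode: it only holds candidate quantization configs."""
    def __init__(self, candidates: List[Any]):
        self.candidates_quantization_cfg = candidates


def get_node_qc(node: NodeStub, node_bit_width_cfg: int) -> Any:
    # The helper now receives the chosen candidate index directly,
    # instead of the whole mixed-precision config plus the node's position in it.
    return node.candidates_quantization_cfg[node_bit_width_cfg]


# Call site: the caller resolves the node's position in the mixed-precision
# configuration once and passes only that entry.
bit_widths_config = [2, 0, 1]  # one chosen candidate index per configurable node
node_index_in_graph = 1
node = NodeStub(candidates=['8-bit cfg', '4-bit cfg', '2-bit cfg'])
print(get_node_qc(node, bit_widths_config[node_index_in_graph]))  # '8-bit cfg'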
@@ -30,7 +30,7 @@ def filter_candidates_for_mixed_precision(graph: Graph,
     such that only a single candidate would remain, with the bitwidth equal to the one defined in the matching layer's
     base config in the TPC.
 
-    Note" This function modifies the graph inplace!
+    Note: This function modifies the graph inplace!
 
     Args:
         graph: A graph representation of the model to be quantized.
@@ -57,7 +57,6 @@ def compute_resource_utilization_data(in_model: Any,
     Returns:
         ResourceUtilization: An object encapsulating the calculated resource utilization computations.
 
-
     """
     core_config = _create_core_config_for_ru(core_config)
     # We assume that the resource_utilization_data API is used to compute the model resource utilization for
@@ -33,9 +33,10 @@ def sum_ru_values(ru_vector: np.ndarray, set_constraints: bool = True) -> List[A
     Returns: A list with an lpSum object for lp problem definition with the vector's sum.
 
     """
-    if not set_constraints:
-        return [0] if len(ru_vector) == 0 else [sum(ru_vector)]
-    return [lpSum(ru_vector)]
+    if set_constraints:
+        return [lpSum(ru_vector)]
+    return [0] if len(ru_vector) == 0 else [sum(ru_vector)]
 
 
 
 def max_ru_values(ru_vector: np.ndarray, set_constraints: bool = True) -> List[float]:
@@ -53,9 +54,10 @@ def max_ru_values(ru_vector: np.ndarray, set_constraints: bool = True) -> List[f
              in the linear programming problem formalization.
 
     """
-    if not set_constraints:
-        return [0] if len(ru_vector) == 0 else [max(ru_vector)]
-    return [ru for ru in ru_vector]
+    if set_constraints:
+        return [ru for ru in ru_vector]
+    return [0] if len(ru_vector) == 0 else [max(ru_vector)]
 
 
 
 def total_ru(ru_tensor: np.ndarray, set_constraints: bool = True) -> List[float]:
@@ -74,16 +76,14 @@ def total_ru(ru_tensor: np.ndarray, set_constraints: bool = True) -> List[float]
              in the linear programming problem formalization.
 
     """
-    if not set_constraints:
+    if set_constraints:
+        weights_ru = lpSum([ru[0] for ru in ru_tensor])
+        return [weights_ru + activation_ru for _, activation_ru in ru_tensor]
+    else:
         weights_ru = sum([ru[0] for ru in ru_tensor])
         activation_ru = max([ru[1] for ru in ru_tensor])
         return [weights_ru + activation_ru]
-
-    weights_ru = lpSum([ru[0] for ru in ru_tensor])
-    total_ru = [weights_ru + activation_ru for _, activation_ru in ru_tensor]
-
-    return total_ru
 
 
 class MpRuAggregation(Enum):
     """
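The flipped conditionals above put the constraint-building (lpSum) path first and keep the plain numeric fallback for the no-constraints case. A minimal, self-contained sketch of the two modes using PuLP directly rather than MCT's wrappers (the variable names and the toy problem are illustrative only):

from pulp import LpMinimize, LpProblem, LpVariable, lpSum


def sum_ru_values(ru_vector, set_constraints: bool = True):
    if set_constraints:
        return [lpSum(ru_vector)]  # symbolic affine expression, usable in an LP
    return [0] if len(ru_vector) == 0 else [sum(ru_vector)]  # plain number


# Numeric path: no LP variables involved.
print(sum_ru_values([1.0, 2.0, 3.0], set_constraints=False))  # [6.0]

# Symbolic path: the returned expression can become an objective or a constraint.
x = [LpVariable(f"x_{i}", lowBound=0) for i in range(3)]
problem = LpProblem("ru_demo", LpMinimize)
problem += sum_ru_values(x)[0]       # objective: x_0 + x_1 + x_2
problem += sum_ru_values(x)[0] >= 5  # constraint on the same aggregated quantity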
@@ -73,7 +73,7 @@ def mp_integer_programming_search(search_manager: MixedPrecisionSearchManager,
 
     assert lp_problem.status == LpStatusOptimal, Logger.critical(
         "No solution was found during solving the LP problem")
-    Logger.info(LpStatus[lp_problem.status])
+    Logger.info(f"ILP status: {LpStatus[lp_problem.status]}")
 
     # Take the bitwidth index only if its corresponding indicator is one.
     config = np.asarray(
@@ -82,7 +82,7 @@ def mp_integer_programming_search(search_manager: MixedPrecisionSearchManager,
          in layer_to_indicator_vars_mapping.values()]
     ).flatten()
 
-    if target_resource_utilization.bops < np.inf:
+    if target_resource_utilization.bops_restricted():
         return search_manager.config_reconstruction_helper.reconstruct_config_from_virtual_graph(config)
     else:
         return config
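For context on the new log line: PuLP's LpStatus maps the solver's integer status code to a readable string, so the formatted message prints e.g. "ILP status: Optimal". A small sketch of that pattern, using the standard logging module in place of MCT's Logger and a toy problem (not MCT code):

import logging

from pulp import LpMinimize, LpProblem, LpStatus, LpStatusOptimal, LpVariable

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ilp_demo")

x = LpVariable("x", lowBound=0)
lp_problem = LpProblem("demo", LpMinimize)
lp_problem += x        # objective: minimize x
lp_problem += x >= 3   # single constraint
lp_problem.solve()

assert lp_problem.status == LpStatusOptimal, "No solution was found during solving the LP problem"
logger.info(f"ILP status: {LpStatus[lp_problem.status]}")  # "ILP status: Optimal"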
@@ -47,7 +47,7 @@ def greedy_solution_refinement_procedure(mp_solution: List[int],
 
     """
     # Refinement is not supported for BOPs utilization for now...
-    if target_resource_utilization.bops < np.inf:
+    if target_resource_utilization.bops_restricted():
         Logger.info(f'Target resource utilization constraint BOPs - Skipping MP greedy solution refinement')
         return mp_solution
 
model_compression_toolkit/core/runner.py (1 addition & 2 deletions)
@@ -151,8 +151,7 @@ def core_runner(in_model: Any,
                 f'Mixed Precision has overwrite bit-width configuration{core_config.mixed_precision_config.configuration_overwrite}')
             bit_widths_config = core_config.mixed_precision_config.configuration_overwrite
 
-        if (target_resource_utilization.activation_memory < np.inf or
-                target_resource_utilization.total_memory < np.inf):
+        if target_resource_utilization.activation_restricted() or target_resource_utilization.total_mem_restricted():
             Logger.warning(
                 f"Running mixed precision for activation compression, please note this feature is experimental and is "
                 f"subject to future changes. If you encounter an issue, please open an issue in our GitHub "
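The new predicates (bops_restricted, activation_restricted, total_mem_restricted) replace inline comparisons against np.inf at the call sites above. The sketch below is a hypothetical reading of that API, assuming each predicate simply wraps the check its call site used to perform; MCT's real ResourceUtilization class may be implemented differently:

from dataclasses import dataclass

import numpy as np


@dataclass
class ResourceUtilizationSketch:
    # np.inf means "unrestricted" for a given resource.
    weights_memory: float = np.inf
    activation_memory: float = np.inf
    total_memory: float = np.inf
    bops: float = np.inf

    def bops_restricted(self) -> bool:
        return self.bops < np.inf

    def activation_restricted(self) -> bool:
        return self.activation_memory < np.inf

    def total_mem_restricted(self) -> bool:
        return self.total_memory < np.inf


target = ResourceUtilizationSketch(activation_memory=2 ** 20)
if target.activation_restricted() or target.total_mem_restricted():
    print("activation/total memory target is set; activation mixed precision applies")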
@@ -77,7 +77,8 @@ class BaseManualBitWidthSelectionTest(MixedPrecisionActivationBaseTest):
     def create_feature_network(self, input_shape):
         return NetForBitSelection(input_shape)
 
-    def get_mp_core_config(self):
+    @staticmethod
+    def get_mp_core_config():
         qc = mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.MSE, mct.core.QuantizationErrorMethod.MSE,
                                          relu_bound_to_power_of_2=False, weights_bias_correction=True,
                                          input_scaling=False, activation_channel_equalization=False)
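Since get_mp_core_config no longer reads instance state, @staticmethod makes that explicit and lets it be called without constructing the test object. A tiny illustration of the pattern (ConfigFactory is made up for this example, not MCT code):

class ConfigFactory:
    @staticmethod
    def build_config():
        # No self/cls needed: the config is assembled from constants only.
        return {"error_method": "MSE", "weights_bias_correction": True}


# Both call styles work, and neither touches instance state.
print(ConfigFactory.build_config())
print(ConfigFactory().build_config())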
@@ -92,6 +93,7 @@ def get_core_configs(self):
         core_config.bit_width_config.set_manual_activation_bit_width(self.filters, self.bit_widths)
         return {"mixed_precision_activation_model": core_config}
 
+
 class ManualBitWidthByLayerTypeTest(BaseManualBitWidthSelectionTest):
     """
     This test check the manual bit width configuration.
@@ -159,10 +161,8 @@ def __init__(self, unit_test, filters, bit_widths):
         for filter, bit_width in zip(filters, bit_widths):
             self.layer_names.update({filter.node_name: bit_width})
 
-
         super().__init__(unit_test)
 
-
     def compare(self, quantized_models, float_model, input_x=None, quantization_info=None):
         # in the compare we need bit_widths to be a list
         bit_widths = [self.bit_widths] if not isinstance(self.bit_widths, list) else self.bit_widths
@@ -69,7 +69,8 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info:
         raise NotImplementedError
 
     def verify_config(self, result_config, expected_config):
-        self.unit_test.assertTrue(all(result_config == expected_config))
+        self.unit_test.assertTrue(all(result_config == expected_config),
+                                  f"Configuration mismatch: expected {expected_config} but got {result_config}.")
 
 
 class MixedPrecisionActivationSearch8Bit(MixedPrecisionActivationBaseTest):
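The added message argument uses unittest's assertTrue(expr, msg) form, so a failing comparison reports both configurations instead of a bare "False is not true". A small self-contained illustration, with plain lists standing in for the real configs (ConfigCompareDemo is made up for this example):

import unittest


class ConfigCompareDemo(unittest.TestCase):
    def test_mismatch_message_is_reported(self):
        result_config, expected_config = [8, 8, 4], [8, 8, 8]
        with self.assertRaises(AssertionError) as ctx:
            self.assertTrue(result_config == expected_config,
                            f"Configuration mismatch: expected {expected_config} but got {result_config}.")
        # The custom message is part of the failure text unittest reports.
        self.assertIn("Configuration mismatch", str(ctx.exception))


if __name__ == "__main__":
    unittest.main()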