# Provenance: added to torchrec/modules/mc_modules.py by commit
# fd97af0bd63f89cbc1b0b150a8137024d3846e7f, "MCH LRU Eviction Policy (#1608)"
# (Geet Sethi, Wed, 3 Jan 2024 10:59:48 -0800).
#   Pull Request resolved: https://github.com/pytorch/torchrec/pull/1608
#   Pull Request resolved: https://github.com/pytorch/torchrec/pull/1598
#   Summary: LRU eviction policy with user-variable decay exponent
#   (e.g. decay_exponent=1 is LRU with linear distance).
#   Reviewed By: dstaay-fb / Differential Revision: D52100910
#   fbshipit-source-id: f9c3d7bbd0639566d12d9a71ecc59f385a3f3b16
# The patch hunk (@@ -440,6 +440,141 @@) places this class after the previous
# policy's `update_metadata_and_generate_eviction_scores` method and directly
# before `class DistanceLFU_EvictionPolicy(MCHEvictionPolicy)`.


class LRU_EvictionPolicy(MCHEvictionPolicy):
    """Least-recently-used eviction policy with a configurable decay exponent.

    The policy tracks a single metadata tensor, ``last_access_iter``, declared
    as both MCH metadata and history metadata. Eviction scores are computed as
    ``-(current_iter - last_access_iter + 1) ** decay_exponent`` and lower
    scores are evicted first, so (for a positive exponent) the entries that
    were accessed longest ago are replaced first. ``decay_exponent=1`` gives
    LRU with linear distance.

    Args:
        decay_exponent (float): exponent applied to the access distance when
            computing eviction scores.
        threshold_filtering_func (Optional[Callable]): experimental; passed
            through unchanged to ``MCHEvictionPolicy.__init__``.
    """

    def __init__(
        self,
        decay_exponent: float = 1.0,
        threshold_filtering_func: Optional[
            Callable[[torch.Tensor], Tuple[torch.Tensor, Union[float, torch.Tensor]]]
        ] = None,  # experimental
    ) -> None:
        # The only state this policy needs per id: the iteration at which the
        # id was last seen.
        tracked_metadata = [
            MCHEvictionPolicyMetadataInfo(
                metadata_name="last_access_iter",
                is_mch_metadata=True,
                is_history_metadata=True,
            ),
        ]
        super().__init__(
            metadata_info=tracked_metadata,
            threshold_filtering_func=threshold_filtering_func,
        )
        self._decay_exponent = decay_exponent

    @property
    def metadata_info(self) -> List[MCHEvictionPolicyMetadataInfo]:
        """Metadata descriptors for this policy.

        Returns ``self._metadata_info``; presumably populated by the base
        class constructor from the ``metadata_info`` argument (not visible
        in this block).
        """
        return self._metadata_info

    def record_history_metadata(
        self,
        current_iter: int,
        incoming_ids: torch.Tensor,
        history_metadata: Dict[str, torch.Tensor],
    ) -> None:
        """Stamp the history's ``last_access_iter`` with the current iteration.

        Every element of ``history_metadata["last_access_iter"]`` is
        overwritten in place with ``current_iter``. ``incoming_ids`` is not
        read by this policy.
        """
        history_metadata["last_access_iter"][:] = current_iter

    def coalesce_history_metadata(
        self,
        current_iter: int,
        history_metadata: Dict[str, torch.Tensor],
        unique_ids_counts: torch.Tensor,
        unique_inverse_mapping: torch.Tensor,
        additional_ids: Optional[torch.Tensor] = None,
        threshold_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Reduce per-occurrence access iterations to one value per unique id.

        Args:
            current_iter (int): iteration assigned to ``additional_ids``.
            history_metadata (Dict[str, torch.Tensor]): must contain
                ``"last_access_iter"``.
            unique_ids_counts (torch.Tensor): only its shape/dtype/device are
                used, as the template for the output buffer.
            unique_inverse_mapping (torch.Tensor): scatter index mapping each
                occurrence to its unique-id slot.
            additional_ids (Optional[torch.Tensor]): when given, one entry
                valued ``current_iter`` is appended per additional id before
                reducing.
            threshold_mask (Optional[torch.Tensor]): when given, used to index
                (filter) the coalesced result.

        Returns:
            Dict[str, torch.Tensor]: ``{"last_access_iter": ...}`` holding the
            maximum access iteration scattered into each unique-id slot.
        """
        access_iters = history_metadata["last_access_iter"]
        if additional_ids is not None:
            extra_iters = torch.full_like(additional_ids, current_iter)
            access_iters = torch.cat([access_iters, extra_iters])

        # Duplicated ids collapse to their most recent access ("amax").
        # include_self=False keeps the zero initial values out of the reduction.
        latest_access = torch.zeros_like(unique_ids_counts)
        latest_access.scatter_reduce_(
            0,
            unique_inverse_mapping,
            access_iters,
            reduce="amax",
            include_self=False,
        )

        if threshold_mask is not None:
            latest_access = latest_access[threshold_mask]

        return {"last_access_iter": latest_access}

    def update_metadata_and_generate_eviction_scores(
        self,
        current_iter: int,
        mch_size: int,
        coalesced_history_argsort_mapping: torch.Tensor,
        coalesced_history_sorted_unique_ids_counts: torch.Tensor,
        coalesced_history_mch_matching_elements_mask: torch.Tensor,
        coalesced_history_mch_matching_indices: torch.Tensor,
        mch_metadata: Dict[str, torch.Tensor],
        coalesced_history_metadata: Dict[str, torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Refresh MCH access metadata and choose which slots to replace.

        Both ``mch_metadata["last_access_iter"]`` and
        ``coalesced_history_metadata["last_access_iter"]`` are modified in
        place. ``coalesced_history_sorted_unique_ids_counts`` is not read by
        this policy.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: ``(evicted_indices,
            selected_new_indices)`` as produced by
            ``_compute_selected_eviction_and_replacement_indices``.
        """
        mch_access = mch_metadata["last_access_iter"]
        history_access = coalesced_history_metadata["last_access_iter"]

        # Reorder the coalesced history in place so it follows the sorted
        # unique-id order.
        history_access.copy_(history_access[coalesced_history_argsort_mapping])

        # Ids already present in the MCH: carry over their access iteration.
        is_match = coalesced_history_mch_matching_elements_mask
        mch_access[coalesced_history_mch_matching_indices] = history_access[is_match]

        # Ids not yet present in the MCH: candidates for insertion.
        incoming_access = history_access[~is_match]

        # TODO: find cleaner way to avoid last element of zch
        mch_access[mch_size - 1] = current_iter

        all_access = torch.cat([mch_access, incoming_access])

        # Lower scores are evicted first.
        access_distance = current_iter - all_access + 1
        eviction_scores = torch.pow(access_distance, self._decay_exponent).neg()

        # Pick the slots to evict and the incoming ids that take their place.
        selection = self._compute_selected_eviction_and_replacement_indices(
            mch_size,
            eviction_scores,
        )
        evicted_indices, selected_new_indices = selection

        # Replaced slots inherit the access iteration of their new occupant.
        mch_access[evicted_indices] = incoming_access[selected_new_indices]

        return evicted_indices, selected_new_indices