diff --git a/src/xlora/xlora.py b/src/xlora/xlora.py
index 527cc6b..5e7fe55 100644
--- a/src/xlora/xlora.py
+++ b/src/xlora/xlora.py
@@ -137,8 +137,8 @@ def hook(module, *args, **kwargs) -> None:
 
             model_peft.internal_xlora_scalings = torch.full(  # type: ignore
                 (payload.batch_size, payload.seq_len, xlora_classifier.n_layers, xlora_classifier.n_classes),
-                payload.override_scaling_pass_value,  # requires_grad=True
-            )  # TODO(EricLBuehler): is the requires_grad=True necessary?
+                payload.override_scaling_pass_value,
+            )
 
             return
 
diff --git a/src/xlora/xlora_classifier.py b/src/xlora/xlora_classifier.py
index 1a174e9..c108ef6 100644
--- a/src/xlora/xlora_classifier.py
+++ b/src/xlora/xlora_classifier.py
@@ -139,14 +139,6 @@ def forward(
         model: PeftModel = self.model  # type: ignore
         with torch.no_grad():
             with model.disable_adapter():
-                # TODO(EricLBuehler): Pending removal following analysis
-                """
-                for module in model.base_model.modules():
-                    if isinstance(module.forward.__self__, xLoRALayer):
-                        inst = module.forward.__self__
-                        inst.disabled = True  # Disable it
-                """
-
                 kwargs["output_hidden_states"] = True
                 kwargs["return_dict"] = True
 
@@ -162,64 +154,11 @@ def forward(
                     **kwargs,
                 )
 
-        # TODO(EricLBuehler): Pending removal following analysis
-        """
-        # Enable the xLoRALayers
-        for module in model.base_model.modules():
-            if isinstance(module.forward.__self__, xLoRALayer):
-                inst = module.forward.__self__
-                inst.disabled = False  # Disable it
-        """
-
         hidden_states = result.hidden_states  # type: ignore
 
         assert hidden_states is not None
         hidden_state = hidden_states[-1]  # Get the last hidden state
 
-        ### Calculate the sequence lengths
-
-        # TODO(all): Pending removal following analysis
-        """
-        # hidden_state=[batch_size, seq_len, hidden_size]
-        if self.config.stop_token_id is None:  # Calculate via attention mask
-            if input_ids is not None:
-                assert attention_mask is not None, (
-                    "Stop token id was not provided, so sequence length calculation via attention mask was attempted"
-                    + "but the attention mask was not given"
-                )
-                sequence_lengths: Union[int, torch.Tensor] = torch.eq(attention_mask, 0).int().argmax(-1) - 1
-                sequence_lengths = sequence_lengths % input_ids.shape[-1]
-                sequence_lengths = sequence_lengths.to(hidden_state.device)  # type: ignore
-            else:
-                sequence_lengths = -1
-        else:  # Calculate via stop token id
-            if input_ids is not None:
-                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
-                sequence_lengths = torch.eq(input_ids, self.config.stop_token_id).int().argmax(-1) - 1
-                sequence_lengths = sequence_lengths % input_ids.shape[-1]
-                sequence_lengths = sequence_lengths.to(hidden_state.device)  # type: ignore
-            else:
-                sequence_lengths = -1
-
-        # AFTER THIS: hidden_state=[batch_size, hidden_size]
-        if self.config.use_mean_pool:
-            assert isinstance(sequence_lengths, torch.Tensor)
-            max_length = hidden_state.shape[1]
-            mask = torch.arange(max_length).expand(len(sequence_lengths), max_length).to(
-                hidden_state.device
-            ) < sequence_lengths.unsqueeze(1)
-
-            # Mask the hidden_states
-            masked_hidden_state = hidden_state * mask.unsqueeze(-1)
-
-            # Sum across the sequence length and divide by actual sequence length
-            summed = torch.sum(masked_hidden_state, dim=1)
-            hidden_state = summed / sequence_lengths.unsqueeze(1)
-        else:
-            # Get it for the last token
-            hidden_state = hidden_state[torch.arange(batch_size, device=hidden_state.device), sequence_lengths]
-        """
-
         ### Classifier run
         # hidden_state=[batch_size, seq_len, hidden_size]
         for layer in self.inner:
diff --git a/src/xlora/xlora_insertion.py b/src/xlora/xlora_insertion.py
index f75859d..40cbd69 100644
--- a/src/xlora/xlora_insertion.py
+++ b/src/xlora/xlora_insertion.py
@@ -20,7 +20,7 @@ class xLoRALayer:
     xLoRA algorithm.
     """
 
-    __slots__ = {"model", "target_forward", "target", "layer_number", "disabled", "config"}
+    __slots__ = {"model", "target_forward", "target", "layer_number", "config"}
 
     def __init__(
        self,
@@ -34,7 +34,6 @@ def __init__(
         self.target_forward = target_forward
         self.target = target
         self.layer_number = layer_number
-        self.disabled = False  # TODO(EricLBuehler): Pending removal following analysis
         self.config = config
 
     @staticmethod
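
For context on the first hunk: the scalings placeholder built in the hook is a plain tensor created with torch.full, which has requires_grad=False by default, so no explicit flag is needed there. Below is a minimal standalone sketch of that allocation only; the dimension and fill values are hypothetical stand-ins for the payload and classifier attributes named in the hunk (payload.batch_size, payload.seq_len, xlora_classifier.n_layers, xlora_classifier.n_classes, payload.override_scaling_pass_value), not the library's defaults.

    import torch

    # Hypothetical stand-ins for the values the hook reads from the payload
    # and the classifier objects referenced in the diff above.
    batch_size, seq_len, n_layers, n_classes = 2, 16, 32, 4
    override_scaling_pass_value = 1.0 / n_classes  # assumed fill value for illustration

    # torch.full returns a tensor with requires_grad=False by default,
    # matching the simplified allocation in the hook.
    scalings = torch.full(
        (batch_size, seq_len, n_layers, n_classes),
        override_scaling_pass_value,
    )

    assert scalings.shape == (batch_size, seq_len, n_layers, n_classes)
    assert not scalings.requires_grad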