
Commit

Merge pull request #20 from EricLBuehler/cleaning
Removal of commented sections
EricLBuehler authored Feb 21, 2024
2 parents ed2a806 + af6fb66 commit 8bf75b3
Showing 3 changed files with 3 additions and 65 deletions.
src/xlora/xlora.py (4 changes: 2 additions & 2 deletions)
@@ -137,8 +137,8 @@ def hook(module, *args, **kwargs) -> None:

model_peft.internal_xlora_scalings = torch.full( # type: ignore
(payload.batch_size, payload.seq_len, xlora_classifier.n_layers, xlora_classifier.n_classes),
- payload.override_scaling_pass_value, # requires_grad=True
- ) # TODO(EricLBuehler): is the requires_grad=True necessary?
+ payload.override_scaling_pass_value,
+ )

return

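For context, this hook fills the model's scalings cache with a constant: one value per (token, layer, adapter class). A minimal standalone sketch of that initialization, using made-up sizes in place of the payload and classifier fields:

```python
import torch

# Illustrative stand-ins for payload.batch_size, payload.seq_len,
# xlora_classifier.n_layers, and xlora_classifier.n_classes.
batch_size, seq_len, n_layers, n_classes = 2, 16, 32, 4
override_scaling_pass_value = 1.0 / n_classes  # example fill value

# Same call shape as the hook above: a dense [batch, seq, layers, classes]
# tensor filled with the override value, created without requires_grad.
scalings = torch.full(
    (batch_size, seq_len, n_layers, n_classes),
    override_scaling_pass_value,
)
assert scalings.shape == (batch_size, seq_len, n_layers, n_classes)
```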
src/xlora/xlora_classifier.py (61 changes: 0 additions & 61 deletions)
@@ -139,14 +139,6 @@ def forward(
model: PeftModel = self.model # type: ignore
with torch.no_grad():
with model.disable_adapter():
- # TODO(EricLBuehler): Pending removal following analysis
- """
- for module in model.base_model.modules():
-     if isinstance(module.forward.__self__, xLoRALayer):
-         inst = module.forward.__self__
-         inst.disabled = True # Disable it
- """

kwargs["output_hidden_states"] = True
kwargs["return_dict"] = True

@@ -162,64 +154,11 @@
**kwargs,
)

- # TODO(EricLBuehler): Pending removal following analysis
- """
- # Enable the xLoRALayers
- for module in model.base_model.modules():
-     if isinstance(module.forward.__self__, xLoRALayer):
-         inst = module.forward.__self__
-         inst.disabled = False # Disable it
- """

hidden_states = result.hidden_states # type: ignore

assert hidden_states is not None
hidden_state = hidden_states[-1] # Get the last hidden state

### Calculate the sequence lengths

- # TODO(all): Pending removal following analysis
- """
- # hidden_state=[batch_size, seq_len, hidden_size]
- if self.config.stop_token_id is None: # Calculate via attention mask
-     if input_ids is not None:
-         assert attention_mask is not None, (
-             "Stop token id was not provided, so sequence length calculation via attention mask was attempted"
-             + "but the attention mask was not given"
-         )
-         sequence_lengths: Union[int, torch.Tensor] = torch.eq(attention_mask, 0).int().argmax(-1) - 1
-         sequence_lengths = sequence_lengths % input_ids.shape[-1]
-         sequence_lengths = sequence_lengths.to(hidden_state.device) # type: ignore
-     else:
-         sequence_lengths = -1
- else: # Calculate via stop token id
-     if input_ids is not None:
-         # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
-         sequence_lengths = torch.eq(input_ids, self.config.stop_token_id).int().argmax(-1) - 1
-         sequence_lengths = sequence_lengths % input_ids.shape[-1]
-         sequence_lengths = sequence_lengths.to(hidden_state.device) # type: ignore
-     else:
-         sequence_lengths = -1
- # AFTER THIS: hidden_state=[batch_size, hidden_size]
- if self.config.use_mean_pool:
-     assert isinstance(sequence_lengths, torch.Tensor)
-     max_length = hidden_state.shape[1]
-     mask = torch.arange(max_length).expand(len(sequence_lengths), max_length).to(
-         hidden_state.device
-     ) < sequence_lengths.unsqueeze(1)
-     # Mask the hidden_states
-     masked_hidden_state = hidden_state * mask.unsqueeze(-1)
-     # Sum across the sequence length and divide by actual sequence length
-     summed = torch.sum(masked_hidden_state, dim=1)
-     hidden_state = summed / sequence_lengths.unsqueeze(1)
- else:
-     # Get it for the last token
-     hidden_state = hidden_state[torch.arange(batch_size, device=hidden_state.device), sequence_lengths]
- """

### Classifier run
# hidden_state=[batch_size, seq_len, hidden_size]
for layer in self.inner:
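For reference, the large block deleted above computed per-sequence lengths (from the attention mask or a stop token) and then mean-pooled the last hidden state over the non-padded tokens. A self-contained sketch of that masked mean-pooling idea, with illustrative shapes and a precomputed lengths tensor standing in for the removed length calculation:

```python
import torch

# Illustrative shapes; in the classifier these come from the base model output.
batch_size, seq_len, hidden_size = 2, 8, 16
hidden_state = torch.randn(batch_size, seq_len, hidden_size)

# Number of real (non-padded) tokens per sequence, e.g. derived from an
# attention mask or a stop-token position as the removed code did.
lengths = torch.tensor([5, 8])

# Boolean mask that is True only at real token positions: [batch_size, seq_len].
mask = torch.arange(seq_len).expand(batch_size, seq_len) < lengths.unsqueeze(1)

# Zero out padded positions, then average over the real tokens only.
masked = hidden_state * mask.unsqueeze(-1)
pooled = masked.sum(dim=1) / lengths.unsqueeze(1)  # [batch_size, hidden_size]
print(pooled.shape)  # torch.Size([2, 16])
```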
src/xlora/xlora_insertion.py (3 changes: 1 addition & 2 deletions)
@@ -20,7 +20,7 @@ class xLoRALayer:
xLoRA algorithm.
"""

__slots__ = {"model", "target_forward", "target", "layer_number", "disabled", "config"}
__slots__ = {"model", "target_forward", "target", "layer_number", "config"}

def __init__(
self,
@@ -34,7 +34,6 @@ def __init__(
self.target_forward = target_forward
self.target = target
self.layer_number = layer_number
- self.disabled = False # TODO(EricLBuehler): Pending removal following analysis
self.config = config

@staticmethod
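A side note on why the __slots__ entry and the assignment go away together: on a class that uses __slots__ and has no inherited __dict__, dropping a name from the slot set also removes the ability to set that attribute at all. A tiny illustration with a simplified stand-in class (not the real xLoRALayer, which has more fields):

```python
class SlottedLayer:
    """Simplified stand-in for a __slots__-only class such as xLoRALayer."""
    __slots__ = ("layer_number", "config")  # no "disabled" slot anymore

    def __init__(self, layer_number, config=None):
        self.layer_number = layer_number
        self.config = config

layer = SlottedLayer(0)
try:
    layer.disabled = False  # no slot and no __dict__, so this raises
except AttributeError as err:
    print(err)  # e.g. "'SlottedLayer' object has no attribute 'disabled'"
```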
