Skip to content

Commit

Permalink
Merge branch 'main' into dev_release
Browse files Browse the repository at this point in the history
  • Loading branch information
Luodian committed Nov 6, 2023
2 parents 4307676 + 16d73b3 commit fc32f6d
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/otter_ai/models/otter/modeling_otter.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,16 @@ def init_weights(self):
for name, param in self.lang_encoder.named_parameters():
param.requires_grad = True

# Freeze all parameters in vision encoder
if "train_vision_encoder" in self.config.__dict__ and self.config.train_vision_encoder is True:
for param in self.vision_encoder.parameters():
param.requires_grad = True

# Freeze all parameters in lang encoders except gated_cross_attn_layers
if "train_lang_encoder" in self.config.__dict__ and self.config.train_lang_encoder is True:
for name, param in self.lang_encoder.named_parameters():
param.requires_grad = True

if "lora_config" in self.config.__dict__:
# Use another logic to unfreeze gated_cross_attn_layers and perceivers
master_print(f"LoRA trainable param: {(sum(param.numel() for name, param in self.lang_encoder.named_parameters() if 'lora' in name)) / 1e6:.3f} M")
Expand Down

0 comments on commit fc32f6d

Please sign in to comment.