
Commit

remove fair path
Artyom Kozhevnikov committed Aug 22, 2023
1 parent fea4d10 commit 8e9ed58
Showing 44 changed files with 0 additions and 475 deletions.
70 changes: 0 additions & 70 deletions sonar/models/sonar_speech/loader.py
@@ -26,76 +26,6 @@ class SonarSpeechEncoderLoader(
):
"""Loads sonar models."""

@finaloverride
def _upgrade_checkpoint(
self, checkpoint: Mapping[str, Any], config: SonarSpeechEncoderConfig
) -> Mapping[str, Any]:
state_dict = checkpoint["model"]

# Check if we have a fairseq2 checkpoint.
if "encoder_frontend.model_dim_proj" in state_dict:
return checkpoint

del state_dict["encoder.w2v_model.mask_emb"]

key_map = self._fairseq_key_map(config)

return upgrade_fairseq_checkpoint(checkpoint, key_map)

@staticmethod
def _fairseq_key_map(config: SonarSpeechEncoderConfig) -> Dict[str, str]:
key_map = {
# fmt: off
# encoder
r"^encoder.w2v_model.layer_norm\.": r"encoder_frontend.post_extract_layer_norm.",
r"^encoder.w2v_model.post_extract_proj\.": r"encoder_frontend.model_dim_proj.",
r"^encoder.w2v_model.encoder\.pos_conv\.0\.": r"encoder_frontend.pos_encoder.conv.",
r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.conv_module\.batch_norm\.": r"encoder.layers.\1.conv.batch_norm.",
r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.conv_module\.depthwise_conv\.": r"encoder.layers.\1.conv.depthwise_conv.",
r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.conv_module\.layer_norm\.": r"encoder.layers.\1.conv_layer_norm.",
r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv1\.": r"encoder.layers.\1.conv.pointwise_conv1.",
r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.conv_module\.pointwise_conv2\.": r"encoder.layers.\1.conv.pointwise_conv2.",
r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.ffn(1|2)\.layer_norm\.": r"encoder.layers.\1.ffn\2_layer_norm.",
r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_1\.": r"encoder.layers.\1.ffn\2.inner_proj.",
r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.ffn(1|2)\.w_2\.": r"encoder.layers.\1.ffn\2.output_proj.",
r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"encoder.layers.\1.self_attn_layer_norm.",
r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.self_attn\.linear_q\.": r"encoder.layers.\1.self_attn.q_proj.",
r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.self_attn\.linear_k\.": r"encoder.layers.\1.self_attn.k_proj.",
r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.self_attn\.linear_v\.": r"encoder.layers.\1.self_attn.v_proj.",
r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.self_attn\.linear_out\.": r"encoder.layers.\1.self_attn.output_proj.",
r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.self_attn\.linear_pos\.": r"encoder.layers.\1.self_attn.sdpa.r_proj.",
r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_u": r"encoder.layers.\1.self_attn.sdpa.u_bias",
r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.self_attn\.pos_bias_v": r"encoder.layers.\1.self_attn.sdpa.v_bias",
r"^encoder.w2v_model.encoder\.layers\.([0-9]+)\.final_layer_norm\.": r"encoder.layers.\1.layer_norm.",
r"^encoder.w2v_model.encoder\.layer_norm\.": r"encoder.layer_norm.",

r"^decoder\.embed_tokens\.": r"encoder_pooler.decoder_frontend.embed.",
r"^decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"encoder_pooler.decoder.layers.\1.self_attn_layer_norm.",
r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"encoder_pooler.decoder.layers.\1.self_attn.output_proj.",
r"^decoder\.layers\.([0-9]+)\.self_attn\.": r"encoder_pooler.decoder.layers.\1.self_attn.",
r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"encoder_pooler.decoder.layers.\1.encoder_decoder_attn_layer_norm.",
r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.": r"encoder_pooler.decoder.layers.\1.encoder_decoder_attn.output_proj.",
r"^decoder\.layers\.([0-9]+)\.encoder_attn\.": r"encoder_pooler.decoder.layers.\1.encoder_decoder_attn.",
r"^decoder\.layers\.([0-9]+)\.fc1\.": r"encoder_pooler.decoder.layers.\1.ffn.inner_proj.",
r"^decoder\.layers\.([0-9]+)\.fc2\.": r"encoder_pooler.decoder.layers.\1.ffn.output_proj.",
r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.": r"encoder_pooler.decoder.layers.\1.ffn_layer_norm.",

r"^decoder\.embed_out": r"encoder_pooler.projection_out.weight",
# fmt: on
}

# In normal circumstances, we should never encounter a `LayerNorm` when
# `use_conformer` is `True`. Unfortunately, the w2v-BERT pretraining in
# fairseq was accidentally run with a pre-LN encoder, and ended up with
# a redundant `LayerNorm` right after the Conformer blocks. We mitigate
# that issue here by moving that `LayerNorm` to the sonar block.
if config.w2v2_encoder_config.use_conformer:
key_map.update(
{r"^encoder.w2v_model.encoder\.layer_norm\.": r"layer_norm."}
)

return key_map


load_sonar_speech_model = SonarSpeechEncoderLoader(
asset_store, download_manager, create_sonar_speech_encoder_model, sonar_speech_archs
Expand Down
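For context, the deleted hooks only build the regex map; the renaming itself happens inside fairseq2's `upgrade_fairseq_checkpoint`. Below is a minimal sketch of the underlying idea, using a hypothetical `remap_state_dict` helper and assuming first-match-wins semantics; it is not the fairseq2 implementation:

import re
from typing import Any, Dict, Mapping


def remap_state_dict(
    state_dict: Mapping[str, Any], key_map: Mapping[str, str]
) -> Dict[str, Any]:
    """Rename parameters by applying the first matching regex in ``key_map``."""
    upgraded = {}
    for name, param in state_dict.items():
        for pattern, replacement in key_map.items():
            new_name, count = re.subn(pattern, replacement, name)
            if count > 0:
                name = new_name
                break  # first matching pattern wins
        upgraded[name] = param
    return upgraded

With the map above, a fairseq parameter such as `encoder.w2v_model.encoder.layers.3.self_attn.linear_q.weight` would come out as `encoder.layers.3.self_attn.q_proj.weight`.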
104 changes: 0 additions & 104 deletions sonar/models/sonar_text/loader.py
@@ -34,55 +34,6 @@ class SonarTextEncoderLoader(
):
"""Loads SonarEncoder models."""

@finaloverride
def _upgrade_checkpoint(
self, checkpoint: Mapping[str, Any], config: SonarTextEncoderConfig
) -> Mapping[str, Any]:
# Return directly if found fairseq2 attribute in state dict
if (
"model" in checkpoint.keys()
and "encoder_frontend.embed.weight" in checkpoint["model"].keys()
):
return checkpoint

state_dict = checkpoint["state_dict"]

try:
del state_dict["version"]
del state_dict["embed_positions._float_tensor"]
except:
pass
# del state_dict["decoder.version"]

out_checkpoint = {"model": state_dict}
out_checkpoint = upgrade_fairseq_checkpoint(
out_checkpoint, self._fairseq_key_map()
)
embeds = checkpoint["embed_tokens"].weight
# # The embedding positions of the control tokens do not match the
# # SentencePiece model of the tokenizer.
with torch.inference_mode():
# (BOS, PAD, EOS, UNK) -> (PAD, UNK, BOS, EOS)
embeds[[0, 1, 2, 3]] = embeds[[1, 3, 0, 2]]
out_checkpoint["encoder_frontend.embed.weight"] = embeds

return out_checkpoint

@staticmethod
def _fairseq_key_map() -> Dict[str, str]:
return {
r"layers\.([0-9]+)\.self_attn\.q_proj\.": r"encoder.layers.\1.self_attn.q_proj.",
r"layers\.([0-9]+)\.self_attn\.v_proj\.": r"encoder.layers.\1.self_attn.v_proj.",
r"layers\.([0-9]+)\.self_attn\.k_proj\.": r"encoder.layers.\1.self_attn.k_proj.",
r"layers\.([0-9]+)\.self_attn\.out_proj\.": r"encoder.layers.\1.self_attn.output_proj.",
r"layers\.([0-9]+)\.self_attn_layer_norm\.": r"encoder.layers.\1.self_attn_layer_norm.",
r"layers\.([0-9]+)\.fc1\.": r"encoder.layers.\1.ffn.inner_proj.",
r"layers\.([0-9]+)\.fc2\.": r"encoder.layers.\1.ffn.output_proj.",
r"layers\.([0-9]+)\.final_layer_norm\.": r"encoder.layers.\1.ffn_layer_norm.",
r"embed_tokens\.": r"encoder_frontend.embed.",
# fmt: on
}


load_sonar_text_encoder_model = SonarTextEncoderLoader(
asset_store,
@@ -98,61 +49,6 @@ class SonarTextDecoderLoader(
):
"""Loads SonarEncoder models."""

@finaloverride
def _upgrade_checkpoint(
self, checkpoint: Mapping[str, Any], config: SonarTextDecoderConfig
) -> Mapping[str, Any]:
# Return directly if found fairseq2 attribute in state dict
if (
"model" in checkpoint.keys()
and "decoder_frontend.embed.weight" in checkpoint["model"].keys()
):
return checkpoint

state_dict = checkpoint["state_dict"]
try:
del state_dict["version"]
del state_dict["embed_positions._float_tensor"]
except:
pass

out_checkpoint = {"model": state_dict}
out_checkpoint = upgrade_fairseq_checkpoint(
out_checkpoint, self._fairseq_key_map()
)
embeds = out_checkpoint["model"]["decoder_frontend.embed.weight"]
# # The embedding positions of the control tokens do not match the
# # SentencePiece model of the tokenizer.
with torch.inference_mode():
# (BOS, PAD, EOS, UNK) -> (PAD, UNK, BOS, EOS)
embeds[[0, 1, 2, 3]] = embeds[[1, 3, 0, 2]]
out_checkpoint["model"]["decoder_frontend.embed.weight"] = embeds
return out_checkpoint

@staticmethod
def _fairseq_key_map() -> Dict[str, str]:
return {
r"layers\.([0-9]+)\.self_attn\.k_proj\.": r"decoder.layers.\1.self_attn.k_proj.",
r"layers\.([0-9]+)\.self_attn\.v_proj\.": r"decoder.layers.\1.self_attn.v_proj.",
r"layers\.([0-9]+)\.self_attn\.q_proj\.": r"decoder.layers.\1.self_attn.q_proj.",
r"layers\.([0-9]+)\.self_attn.out_proj\.": r"decoder.layers.\1.self_attn.output_proj.",
r"layers\.([0-9]+)\.self_attn_layer_norm\.": r"decoder.layers.\1.self_attn_layer_norm.",
r"layers\.([0-9]+).ffn\.inner_proj\.": r"decoder.layers.\1.ffn.inner_proj.",
r"layers\.([0-9]+).ffn\.output_proj\.": r"decoder.layers.\1.ffn.output_proj.",
r"layers\.([0-9]+)\.ffn_layer_norm\.": r"decoder.layers.\1.ffn_layer_norm.",
r"layers\.([0-9]+).encoder_attn\.k_proj\.": r"decoder.layers.\1.encoder_decoder_attn.k_proj.",
r"layers\.([0-9]+).encoder_attn\.v_proj\.": r"decoder.layers.\1.encoder_decoder_attn.v_proj.",
r"layers\.([0-9]+).encoder_attn\.q_proj\.": r"decoder.layers.\1.encoder_decoder_attn.q_proj.",
r"layers\.([0-9]+).encoder_attn\.out_proj\.": r"decoder.layers.\1.encoder_decoder_attn.output_proj.",
r"layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"decoder.layers.\1.encoder_decoder_attn_layer_norm.",
r"layers\.([0-9]+)\.fc1\.": r"decoder.layers.\1.ffn.inner_proj.",
r"layers\.([0-9]+)\.fc2\.": r"decoder.layers.\1.ffn.output_proj.",
r"layers\.([0-9]+)\.final_layer_norm\.": r"decoder.layers.\1.ffn_layer_norm.",
r"output_projection.": r"final_proj.",
r"embed_tokens.": r"decoder_frontend.embed.",
r"layer_norm.": r"decoder.layer_norm.",
}


load_sonar_text_decoder_model = SonarTextDecoderLoader(
asset_store,
Expand Down
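Both text loaders end by permuting the first four rows of the embedding table so that the control tokens line up with the tokenizer's SentencePiece ordering. A self-contained sketch of that permutation on a toy table (the real checkpoints use the full SONAR vocabulary and model dimension):

import torch

# Toy stand-in for the embedding table (vocab_size x model_dim).
embeds = torch.arange(24, dtype=torch.float32).reshape(6, 4)

with torch.inference_mode():
    # fairseq orders the control tokens (BOS, PAD, EOS, UNK); the tokenizer
    # expects (PAD, UNK, BOS, EOS). The advanced index on the right-hand side
    # copies the selected rows first, so the rows are permuted simultaneously
    # rather than overwritten one by one.
    embeds[[0, 1, 2, 3]] = embeds[[1, 3, 0, 2]]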
9 changes: 0 additions & 9 deletions sonar/store/cards/[email protected]

This file was deleted.

9 changes: 0 additions & 9 deletions sonar/store/cards/[email protected]

This file was deleted.

7 changes: 0 additions & 7 deletions sonar/store/cards/[email protected]

This file was deleted.

7 changes: 0 additions & 7 deletions sonar/store/cards/[email protected]

This file was deleted.

7 changes: 0 additions & 7 deletions sonar/store/cards/[email protected]

This file was deleted.

7 changes: 0 additions & 7 deletions sonar/store/cards/[email protected]

This file was deleted.

7 changes: 0 additions & 7 deletions sonar/store/cards/[email protected]

This file was deleted.

7 changes: 0 additions & 7 deletions sonar/store/cards/[email protected]

This file was deleted.

7 changes: 0 additions & 7 deletions sonar/store/cards/[email protected]

This file was deleted.

7 changes: 0 additions & 7 deletions sonar/store/cards/[email protected]

This file was deleted.

7 changes: 0 additions & 7 deletions sonar/store/cards/[email protected]

This file was deleted.

7 changes: 0 additions & 7 deletions sonar/store/cards/[email protected]

This file was deleted.

7 changes: 0 additions & 7 deletions sonar/store/cards/[email protected]

This file was deleted.

7 changes: 0 additions & 7 deletions sonar/store/cards/[email protected]

This file was deleted.

7 changes: 0 additions & 7 deletions sonar/store/cards/[email protected]

This file was deleted.

7 changes: 0 additions & 7 deletions sonar/store/cards/[email protected]

This file was deleted.

7 changes: 0 additions & 7 deletions sonar/store/cards/[email protected]

This file was deleted.

7 changes: 0 additions & 7 deletions sonar/store/cards/[email protected]

This file was deleted.

7 changes: 0 additions & 7 deletions sonar/store/cards/[email protected]

This file was deleted.

7 changes: 0 additions & 7 deletions sonar/store/cards/[email protected]

This file was deleted.

7 changes: 0 additions & 7 deletions sonar/store/cards/[email protected]

This file was deleted.

7 changes: 0 additions & 7 deletions sonar/store/cards/[email protected]

This file was deleted.

