Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DeepSeek V2 config #1293

Merged
merged 1 commit into from
Feb 21, 2025
Merged

Add DeepSeek V2 config #1293

merged 1 commit into from
Feb 21, 2025

Conversation

RissyRan
Copy link
Collaborator

@RissyRan RissyRan commented Feb 21, 2025

Description

Add DeepSeek V2 config for fast development:

Tests

V2-16b:

python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=/tmp/ run_name=deepsee_training per_device_batch_size=4 enable_checkpointing=false model_name=deepseek2-16b ici_fsdp_parallelism=4 steps=5 async_checkpointing=false tokenizer_path=deepseek-ai/DeepSeek-V2-Lite attention=dot_product dtype=bfloat16 weight_dtype=bfloat16 dataset_type=synthetic sparse_matmul=True megablox=True 

Small config of V2-236b:

python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=/tmp/ run_name=deepsee_training per_device_batch_size=4 enable_checkpointing=false model_name=deepseek2-236b ici_fsdp_parallelism=4 steps=5 async_checkpointing=false tokenizer_path=deepseek-ai/DeepSeek-V2 attention=dot_product dtype=bfloat16 weight_dtype=bfloat16 dataset_type=synthetic sparse_matmul=True megablox=True 

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@RissyRan
Copy link
Collaborator Author

RissyRan commented Feb 21, 2025

Hi @gagika, I will meet error running v2-lite config as the q_lora_rank=0. It shows the embedding dim is different to H from here. The embedding_dim=64, while inputs's last dim is 0. Full logs. Did I miss anything here?

If this is expected in the current code, we may need a workaround, otherwise, this branch could be useless.

@gagika
Copy link
Collaborator

gagika commented Feb 21, 2025

Hi @gagika, I will meet error running v2-lite config as the q_lora_rank=0. It shows the embedding dim is different to H from here. The embedding_dim=64, while inputs's last dim is 0. Full logs. Did I miss anything here?

If this is expected in the current code, we may need a workaround, otherwise, this branch could be useless.

Hi Ran, that brunch wasn't tested, there is a bug, could you please change the features=(self.num_query_heads, self.head_dim), to features=(self.num_query_heads, self.qk_head_dim),

    if self.q_lora_rank == 0:
      # Standard Q projection (without LoRA).
      self.query_proj = DenseGeneral(
          features=(self.num_query_heads, self.qk_head_dim),
          axis=-1,
          kernel_init=self.kernel_init,
          kernel_axes=("embed", "q_heads", "kv"),
          dtype=self.dtype,
          weight_dtype=self.weight_dtype,
          name="query",
          quant=self.quant,
          matmul_precision=self.config.matmul_precision,
      )

Please add the fix in your PR or I can sent a PR with that one line fix, if you prefer that way.

@RissyRan
Copy link
Collaborator Author

Hi @gagika, I will meet error running v2-lite config as the q_lora_rank=0. It shows the embedding dim is different to H from here. The embedding_dim=64, while inputs's last dim is 0. Full logs. Did I miss anything here?
If this is expected in the current code, we may need a workaround, otherwise, this branch could be useless.

Hi Ran, that brunch wasn't tested, there is a bug, could you please change the features=(self.num_query_heads, self.head_dim), to features=(self.num_query_heads, self.qk_head_dim),

    if self.q_lora_rank == 0:
      # Standard Q projection (without LoRA).
      self.query_proj = DenseGeneral(
          features=(self.num_query_heads, self.qk_head_dim),
          axis=-1,
          kernel_init=self.kernel_init,
          kernel_axes=("embed", "q_heads", "kv"),
          dtype=self.dtype,
          weight_dtype=self.weight_dtype,
          name="query",
          quant=self.quant,
          matmul_precision=self.config.matmul_precision,
      )

Please add the fix in your PR or I can sent a PR with that one line fix, if you prefer that way.

Cool! Let me add this one line change here.

@RissyRan
Copy link
Collaborator Author

Both tests are working now. V2lite-16b with q_lora_rank=0, and v2-236b with q_lora_rank=1536.

Copy link
Collaborator

@gagika gagika left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SGTM, Thanks Ran!

@copybara-service copybara-service bot merged commit 8632dcb into main Feb 21, 2025
18 checks passed
@copybara-service copybara-service bot deleted the deepseek_v2 branch February 21, 2025 19:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants