Add mark_step for encoder layers #650
Conversation
LGTM
Force-pushed from c2f7554 to 5efc637
```diff
@@ -135,7 +135,7 @@ def flatten(in_list):
     return list(itertools.chain(*in_list))


-def get_decoder_layer_suffix(model_type):
+def get_target_layer_suffix(model_type) -> list[str]:
```
Please change to something like "get_target_layer_suffix_list" to reflect the return type.
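For context, here is a minimal sketch of what the renamed helper could look like, consistent with the diff above. The contents of the lookup table are illustrative assumptions, not taken verbatim from this PR:

```python
def get_target_layer_suffix_list(model_type) -> list[str]:
    # Assumed mapping: models whose decoder layer class name does not
    # end in "DecoderLayer" need an explicit entry here.
    decoder_layer_table = {
        "gpt_bigcode": "BigCodeBlock",
    }
    # Encoder layers are now targeted as well, so mark_step can be
    # inserted between encoder layers too (e.g. llama3.2 vision / mllama).
    return [
        decoder_layer_table.get(model_type, "DecoderLayer"),
        "EncoderLayer",
    ]
```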
```python
def modify_model_layers(module: torch.nn.Module,
                        suffix: list[str],
                        n=1,
                        counter=None):
```
Please change to something like "suffix_list" to reflect the underlying type
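A hedged sketch of how such a helper could walk the module tree and register a hook on every matching layer. `htorch.core.mark_step()` is the Habana API that flushes the accumulated HPU graph; the traversal logic below is an illustration under that assumption, not the PR's verbatim implementation:

```python
import habana_frameworks.torch as htorch
import torch


def modify_model_layers(module: torch.nn.Module,
                        suffix_list: list[str],
                        n=1,
                        counter=None):
    """Register a forward hook that calls mark_step() after every n-th
    submodule whose class name ends with one of the given suffixes."""

    def forward_hook(module, args, output):
        # Split the accumulated HPU graph at this layer boundary.
        htorch.core.mark_step()
        return output

    if counter is None:
        counter = [0]
    for child_name, child_module in module.named_children():
        if any(
                child_module.__class__.__name__.endswith(suffix)
                for suffix in suffix_list):
            counter[0] += 1
            if counter[0] % n == 0:
                child_module.register_forward_hook(forward_hook)
        else:
            # Recurse into non-matching containers to find nested layers.
            modify_model_layers(child_module, suffix_list, n, counter)
```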
Hi @michalkuligowski, thanks for your review, but I need to create a new PR #669 to address your comments. Can you help close this one and take a look at the new one?
This is an updated version of #650. Coupled with [Use FusedSDPA for MllamaVisionSdpaAttention #620], these two issues that arise when running the llama3.2 vision model can be resolved:

- GC failure when batch size > 1 on Gaudi3.
- Increased device memory consumption with Torch 2.5 compared to Torch 2.4.

---------

Signed-off-by: yan ma <[email protected]>
Co-authored-by: yisonzhu <[email protected]>
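Putting the two pieces together, a hypothetical call site for an encoder-decoder model such as mllama; the `model` handle and the choice of `n` are placeholders:

```python
# Insert mark_step after every encoder and decoder layer of the model.
suffix_list = get_target_layer_suffix_list("mllama")
modify_model_layers(model, suffix_list, n=1)
```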