
OOM for Mistral-Nemo-Base-2407 with NeMo + ThunderFX for input sequence lengths working for NeMo Eager #1475

Open
mpatel31415 opened this issue Nov 26, 2024 · 6 comments

mpatel31415 commented Nov 26, 2024

🐛 Bug

When running Mistral-Nemo-Base-2407 with NeMo + ThunderFX, we get an OOM error even for small sequence lengths.

To Reproduce

The error is present on 1xH100.

Dockerfile used (I built it yesterday, and I'm not sure yet how nemo:dev images are versioned, so I can't provide its exact version):

FROM nvcr.io/nvidia/nemo:dev
ARG NVFUSER_REPO=git+https://github.com/NVIDIA/Fuser.git
ARG THUNDER_REPO=git+https://github.com/Lightning-AI/lightning-thunder.git

# Add the latest NeMo code from a fresh clone
RUN git clone --recursive https://github.com/NVIDIA/NeMo.git /NeMo_cloned
RUN (cd /NeMo_cloned && python -m pip install .)


# Install requirements needed for NeMo, Thunder and nvFuser.
# We must install them in this convoluted way because otherwise Thunder is not
# updated and we cannot use the latest version.
RUN python -m pip install -r /NeMo_cloned/requirements/requirements_lightning.txt && \
    python -m pip install --upgrade ${NVFUSER_REPO}  && \
    python -m pip install --upgrade ${THUNDER_REPO} && \
    python -m pip install --upgrade --no-deps --force-reinstall ${NVFUSER_REPO} && \
    python -m pip install --upgrade --no-deps --force-reinstall ${THUNDER_REPO}
 
# Install Mixology requirements (this can be skipped, so I'm commenting it out)
# COPY requirements/mixology.txt mixology_requirements.txt
# RUN pip install --upgrade -r mixology_requirements.txt

Inside the Docker container, please run:

model=mistralai/Mistral-Nemo-Base-2407
# Download the model (you might need to set HF_TOKEN and accept the model's terms of use on the website)
huggingface-cli download $model --local-dir checkpoints/$model --cache-dir checkpoints/$model 
# Run benchmark
python bench_targets/llm_peft/_nemo.py --model checkpoints/$model --mbs 1 --seq-length 2048 --jit-backend thunder

The script bench_targets/llm_peft/_nemo.py can be obtained from the internal GitLab repository akoumparouli/nemo_bench. You can contact me or @tfogal if you have any questions.

You can check that the same benchmark works with the eager backend:

python bench_targets/llm_peft/_nemo.py --model checkpoints/$model --mbs 1 --seq-length 2048 --jit-backend eager

Expected behavior

With ThunderFX we should be able to run at least the same sequence lengths as with NeMo Eager.

Environment

cc @tfogal

IvanYashchuk added the nemo (Issues needed to support NVIDIA NeMo models) and mixology (Issues that the mixology team has surfaced) labels Nov 26, 2024
@mpatel31415

As Tom suggested, the issue might be caused by this PR: #1400

@IvanYashchuk

@tfogal, why do you think that pull request could cause problems?

tfogal commented Nov 26, 2024

@tfogal, why do you think that pull request could cause problems?

I had git bisected and found that 052bac3 (the commit right before it) works well.

I am confused about what in that PR is the issue, though; I'm experimenting now, but my initial theory about overly aggressive pruning doesn't hold water, so I'm not sure at the moment.
Edit: Phi-3 doesn't change at all before/after #1400; I somehow saw a difference in memory usage with Mistral-Nemo that went away when I tested again. I will need to dig in more next week.

IvanYashchuk commented Nov 27, 2024

I had git bisected and found that 052bac3 (the commit right before it) works well.

@tfogal, what command did you run for bisection? I get an OOM error on an H100 with commit 052bac3 and a 2k sequence length, as in the issue description:

python bench_targets/llm_peft/_nemo.py --model=mistralai/Mistral-Nemo-Base-2407 --mbs 1 --seq-length 2048 --jit-backend thunder

The same OOM error occurs with the linked commit and with the commits before it, but a different error appears with the commit right after (a617503), which breaks memory consumption because a deepcopy of an fx.GraphModule also creates copies of all parameters.
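
For illustration, here is a minimal sketch (a toy nn.Linear stand-in, not the actual fx.GraphModule produced by the benchmark) of why such a deepcopy is costly: the copy gets its own parameter storage instead of aliasing the original tensors.

import copy
import torch

# Toy stand-in for the captured module; in the real case the fx.GraphModule
# holds the Mistral-Nemo submodules and their parameters.
m = torch.nn.Linear(1024, 1024, device="cuda")
before = torch.cuda.memory_allocated()

m_copy = copy.deepcopy(m)  # duplicates the parameter storage

# The copy does not alias the original weights...
print(m.weight.data_ptr() == m_copy.weight.data_ptr())  # False
# ...so allocated memory grows by roughly the size of the module's parameters.
print(torch.cuda.memory_allocated() - before)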

With a 1k sequence length, the memory consumption per commit is (a sketch of one way to record such numbers follows the list):

c9bbc5e0: 66880730624 bytes
052bac: 66880730624 bytes
a61750: OOM
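
For reference, this is one way such byte counts can be collected; it is not necessarily how the numbers above were measured, and run_step is a hypothetical stand-in for one iteration of the benchmark script.

import torch

def run_step():
    # Hypothetical placeholder for one forward/backward iteration of
    # bench_targets/llm_peft/_nemo.py; not the actual benchmark code.
    pass

torch.cuda.reset_peak_memory_stats()
run_step()
torch.cuda.synchronize()
print(torch.cuda.memory_allocated())      # bytes currently allocated
print(torch.cuda.max_memory_allocated())  # peak bytes during the step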

@kshitij12345, looks like the problem was introduced in #1400.

IvanYashchuk assigned kshitij12345 and unassigned tfogal Nov 27, 2024
kshitij12345 commented Nov 27, 2024

Interestingly, copy.deepcopy only leads to copying of parameters when torch._dynamo.config.inline_inbuilt_nn_modules=False, and since the upstream PyTorch change it defaults to True (i.e. from the PyTorch 2.5 release). The following snippet demonstrates the difference in generated graphs and memory usage with copy.deepcopy on a GraphModule.

import torch
import copy

# copy.deepcopy leads to more memory usage (as modules with parameters are saved in GraphModule).
# Eg.
# class GraphModule(torch.nn.Module):
#     def forward(self, L_args_0_: "f32[1024, 1024]"):
#         l_args_0_ = L_args_0_
        
#          # File: /home/kkalambarkar/git/pytorch/torch/_dynamo/external_utils.py:31 in inner, code: return fn(*args, **kwargs)
#         fn_0: "f32[1024, 1024]" = self.fn_0(l_args_0_);  l_args_0_ = None
#         fn_1: "f32[1024, 1024]" = self.fn_1(fn_0);  fn_0 = None
#         return (fn_1,)
# Only the last of the two assignments below takes effect; comment one out to
# compare the two modes (the numbers in the comments at the end show both).
torch._dynamo.config.inline_inbuilt_nn_modules=False

# class GraphModule(torch.nn.Module):
#     def forward(self, L_fn_modules_0_parameters_weight_: "f32[1024, 1024]", L_fn_modules_0_parameters_bias_: "f32[1024]", L_args_0_: "f32[1024, 1024]", L_fn_modules_1_parameters_weight_: "f32[1024, 1024]", L_fn_modules_1_parameters_bias_: "f32[1024]"):
#         l_fn_modules_0_parameters_weight_ = L_fn_modules_0_parameters_weight_
#         l_fn_modules_0_parameters_bias_ = L_fn_modules_0_parameters_bias_
#         l_args_0_ = L_args_0_
#         l_fn_modules_1_parameters_weight_ = L_fn_modules_1_parameters_weight_
#         l_fn_modules_1_parameters_bias_ = L_fn_modules_1_parameters_bias_
        
#          # File: /home/kkalambarkar/git/pytorch/torch/_dynamo/external_utils.py:31 in inner, code: return fn(*args, **kwargs)
#         input_1: "f32[1024, 1024]" = torch._C._nn.linear(l_args_0_, l_fn_modules_0_parameters_weight_, l_fn_modules_0_parameters_bias_);  l_args_0_ = l_fn_modules_0_parameters_weight_ = l_fn_modules_0_parameters_bias_ = None
#         input_2: "f32[1024, 1024]" = torch._C._nn.linear(input_1, l_fn_modules_1_parameters_weight_, l_fn_modules_1_parameters_bias_);  input_1 = l_fn_modules_1_parameters_weight_ = l_fn_modules_1_parameters_bias_ = None
#         return (input_2,)
torch._dynamo.config.inline_inbuilt_nn_modules=True

gm_copy = None

def backend(gm, sample_args):
    global gm_copy
    gm_copy = copy.deepcopy(gm)

    gm.print_readable()

    return gm

with torch.device("cuda"):
    models = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 1024))

opt_model = torch.compile(models, backend=backend)

x = torch.randn(1024, 1024, device="cuda")
opt_model(x)

print(torch.cuda.memory_allocated())  # no_inline = 29507584, inline = 21110784
del opt_model, models
print(torch.cuda.memory_allocated())  # no_inline = 29507584, inline = 12713984
del gm_copy
print(torch.cuda.memory_allocated())  # no_inline = 29507584, inline = 12713984

@IvanYashchuk

Cool, so as long as torch._dynamo.config.inline_inbuilt_nn_modules=True is used, there's nothing to fix on the Thunder side.
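
As an illustrative sketch (a toy model and the eager backend rather than the NeMo benchmark and thunder), the setting can be pinned explicitly before compilation to guard against the parameter-copying behaviour:

import torch
import torch._dynamo

# True is the default from PyTorch 2.5 on; setting it explicitly documents the
# requirement for this workload.
torch._dynamo.config.inline_inbuilt_nn_modules = True

model = torch.nn.Linear(1024, 1024, device="cuda")
# The "eager" backend keeps this sketch self-contained; the issue itself runs
# with --jit-backend thunder through ThunderFX.
compiled = torch.compile(model, backend="eager")
compiled(torch.randn(8, 1024, device="cuda"))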
