Llama-3_1-Nemotron-Ultra-253B-v1 support #12843
base: master
Conversation
Thanks a lot for your contribution! I will try Llama-3_1-Nemotron-Ultra-253B-v1 and let you know shortly. I'm currently running convert_hf_to_gguf.py and everything is working great so far.
@ymcki Something unfortunately seems to be broken. The output seems to be just random tokens.
Prompt: What is the meaning of life?
Response: 8!B"(1D<)<,4@3-A'3(<5,72A9.F62AC"%D08);E)6CDHCA0C.HC!%85>8DD(3!=&;48<"802=A,%0,6%D@0/'D<%(11@=&:.F0A)!91.#;2,&;)
Note: I canceled token generation after a while, as it likely would have continued generating garbage until reaching the context size.
Steps to reproduce:
rm -rf llama.cpp
git clone --recursive https://github.com/ymcki/llama.cpp.git
cd llama.cpp
python -m venv venv
venv/bin/pip install -r requirements.txt
cmake -B build -DGGML_CUDA=ON
cmake --build build --config Release -j
venv/bin/python convert_hf_to_gguf.py /mradermacher/tmp/quant/Llama-3_1-Nemotron-Ultra-253B-v1 --outfile /dpool/Llama-3_1-Nemotron-Ultra-253B-v1.gguf
cd build/bin
./llama-quantize /dpool/Llama-3_1-Nemotron-Ultra-253B-v1.gguf /bpool/Llama-3_1-Nemotron-Ultra-253B-v1.Q4_K_M.gguf Q4_K_M
./llama-cli -m /bpool/Llama-3_1-Nemotron-Ultra-253B-v1.Q4_K_M.gguf
Log:
root@AI:/apool/llama.cpp/build/bin# ./llama-cli -m /bpool/Llama-3_1-Nemotron-Ultra-253B-v1.Q4_K_M.gguf
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 2 CUDA devices:
Device 0: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes
Device 1: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes
build: 5099 (80af2e33) with cc (Debian 12.2.0-14) 12.2.0 for x86_64-linux-gnu
main: llama backend init
main: load the model and apply lora adapter, if any
llama_model_load_from_file_impl: using device CUDA0 (NVIDIA GeForce RTX 4090) - 23628 MiB free
llama_model_load_from_file_impl: using device CUDA1 (NVIDIA GeForce RTX 4090) - 23663 MiB free
llama_model_loader: loaded meta data with 34 key-value pairs and 648 tensors from /bpool/Llama-3_1-Nemotron-Ultra-253B-v1.Q4_K_M.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv 0: general.architecture str = deci
llama_model_loader: - kv 1: general.type str = model
llama_model_loader: - kv 2: general.name str = Llama_Nemotron_Ultra
llama_model_loader: - kv 3: general.version str = v1
llama_model_loader: - kv 4: general.finetune str = 3_1-Nemotron-Ultra
llama_model_loader: - kv 5: general.basename str = Llama
llama_model_loader: - kv 6: general.size_label str = 253B
llama_model_loader: - kv 7: general.license str = other
llama_model_loader: - kv 8: general.license.name str = nvidia-open-model-license
llama_model_loader: - kv 9: general.license.link str = https://www.nvidia.com/en-us/agreemen...
llama_model_loader: - kv 10: general.tags arr[str,4] = ["nvidia", "llama-3", "pytorch", "tex...
llama_model_loader: - kv 11: general.languages arr[str,1] = ["en"]
llama_model_loader: - kv 12: deci.rope.freq_base f32 = 500000.000000
llama_model_loader: - kv 13: deci.attention.head_count_kv arr[i32,162] = [8, 8, 8, 8, 8, 8, 8, 8, 8, 0, 0, 0, ...
llama_model_loader: - kv 14: deci.attention.head_count arr[i32,162] = [128, 128, 128, 128, 128, 128, 128, 1...
llama_model_loader: - kv 15: deci.feed_forward_length arr[i32,162] = [5376, 10752, 16128, 16128, 16128, 16...
llama_model_loader: - kv 16: deci.block_count u32 = 162
llama_model_loader: - kv 17: deci.context_length u32 = 131072
llama_model_loader: - kv 18: deci.embedding_length u32 = 16384
llama_model_loader: - kv 19: deci.attention.layer_norm_rms_epsilon f32 = 0.000010
llama_model_loader: - kv 20: deci.attention.key_length u32 = 128
llama_model_loader: - kv 21: deci.attention.value_length u32 = 128
llama_model_loader: - kv 22: deci.vocab_size u32 = 128256
llama_model_loader: - kv 23: deci.rope.dimension_count u32 = 128
llama_model_loader: - kv 24: tokenizer.ggml.model str = gpt2
llama_model_loader: - kv 25: tokenizer.ggml.pre str = llama-bpe
llama_model_loader: - kv 26: tokenizer.ggml.tokens arr[str,128256] = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv 27: tokenizer.ggml.token_type arr[i32,128256] = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv 28: tokenizer.ggml.merges arr[str,280147] = ["Ġ Ġ", "Ġ ĠĠĠ", "ĠĠ ĠĠ", "...
llama_model_loader: - kv 29: tokenizer.ggml.bos_token_id u32 = 128000
llama_model_loader: - kv 30: tokenizer.ggml.eos_token_id u32 = 128009
llama_model_loader: - kv 31: tokenizer.chat_template str = {{- bos_token }}{%- if messages[0]['r...
llama_model_loader: - kv 32: general.quantization_version u32 = 2
llama_model_loader: - kv 33: general.file_type u32 = 15
llama_model_loader: - type f32: 147 tensors
llama_model_loader: - type q4_K: 428 tensors
llama_model_loader: - type q6_K: 73 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type = Q4_K - Medium
print_info: file size = 140.56 GiB (4.76 BPW)
load: special tokens cache size = 256
load: token to piece cache size = 0.7999 MB
print_info: arch = deci
print_info: vocab_only = 0
print_info: n_ctx_train = 131072
print_info: n_embd = 16384
print_info: n_layer = 162
print_info: n_head = [128, 128, 128, 128, 128, 128, 128, 128, 128, 0, 0, 0, 0, 128, 128, 128, 128, 128, 0, 0, 0, 0, 0, 0, 128, 128, 128, 0, 0, 0, 0, 0, 128, 128, 128, 128, 0, 0, 0, 128, 128, 128, 0, 128, 0, 0, 0, 0, 0, 0, 128, 128, 128, 128, 0, 0, 0, 0, 0, 128, 128, 128, 128, 0, 0, 0, 0, 0, 128, 128, 128, 128, 0, 0, 0, 0, 0, 128, 128, 128, 128, 0, 0, 0, 0, 0, 128, 128, 128, 128, 0, 0, 128, 128, 128, 128, 0, 0, 128, 0, 0, 0, 0, 0, 0, 0, 0, 128, 0, 0, 0, 0, 0, 128, 128, 0, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 128, 0, 128, 128, 128, 128, 128, 128, 128, 128]
print_info: n_head_kv = [8, 8, 8, 8, 8, 8, 8, 8, 8, 0, 0, 0, 0, 8, 8, 8, 8, 8, 0, 0, 0, 0, 0, 0, 8, 8, 8, 0, 0, 0, 0, 0, 8, 8, 8, 8, 0, 0, 0, 8, 8, 8, 0, 8, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 0, 0, 0, 0, 0, 8, 8, 8, 8, 0, 0, 0, 0, 0, 8, 8, 8, 8, 0, 0, 0, 0, 0, 8, 8, 8, 8, 0, 0, 0, 0, 0, 8, 8, 8, 8, 0, 0, 8, 8, 8, 8, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 8, 8, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 0, 8, 8, 8, 8, 8, 8, 8, 8]
print_info: n_rot = 128
print_info: n_swa = 0
print_info: n_swa_pattern = 1
print_info: n_embd_head_k = 128
print_info: n_embd_head_v = 128
print_info: n_gqa = [16, 16, 16, 16, 16, 16, 16, 16, 16, 0, 0, 0, 0, 16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 16, 16, 16, 0, 0, 0, 0, 0, 16, 16, 16, 16, 0, 0, 0, 16, 16, 16, 0, 16, 0, 0, 0, 0, 0, 0, 16, 16, 16, 16, 0, 0, 0, 0, 0, 16, 16, 16, 16, 0, 0, 0, 0, 0, 16, 16, 16, 16, 0, 0, 0, 0, 0, 16, 16, 16, 16, 0, 0, 0, 0, 0, 16, 16, 16, 16, 0, 0, 16, 16, 16, 16, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 0, 16, 16, 0, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 16, 0, 16, 16, 16, 16, 16, 16, 16, 16]
print_info: n_embd_k_gqa = [1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 0, 0, 0, 0, 1024, 1024, 1024, 1024, 1024, 0, 0, 0, 0, 0, 0, 1024, 1024, 1024, 0, 0, 0, 0, 0, 1024, 1024, 1024, 1024, 0, 0, 0, 1024, 1024, 1024, 0, 1024, 0, 0, 0, 0, 0, 0, 1024, 1024, 1024, 1024, 0, 0, 0, 0, 0, 1024, 1024, 1024, 1024, 0, 0, 0, 0, 0, 1024, 1024, 1024, 1024, 0, 0, 0, 0, 0, 1024, 1024, 1024, 1024, 0, 0, 0, 0, 0, 1024, 1024, 1024, 1024, 0, 0, 1024, 1024, 1024, 1024, 0, 0, 1024, 0, 0, 0, 0, 0, 0, 0, 0, 1024, 0, 0, 0, 0, 0, 1024, 1024, 0, 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1024, 1024, 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024]
print_info: n_embd_v_gqa = [1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 0, 0, 0, 0, 1024, 1024, 1024, 1024, 1024, 0, 0, 0, 0, 0, 0, 1024, 1024, 1024, 0, 0, 0, 0, 0, 1024, 1024, 1024, 1024, 0, 0, 0, 1024, 1024, 1024, 0, 1024, 0, 0, 0, 0, 0, 0, 1024, 1024, 1024, 1024, 0, 0, 0, 0, 0, 1024, 1024, 1024, 1024, 0, 0, 0, 0, 0, 1024, 1024, 1024, 1024, 0, 0, 0, 0, 0, 1024, 1024, 1024, 1024, 0, 0, 0, 0, 0, 1024, 1024, 1024, 1024, 0, 0, 1024, 1024, 1024, 1024, 0, 0, 1024, 0, 0, 0, 0, 0, 0, 0, 0, 1024, 0, 0, 0, 0, 0, 1024, 1024, 0, 1024, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1024, 1024, 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024]
print_info: f_norm_eps = 0.0e+00
print_info: f_norm_rms_eps = 1.0e-05
print_info: f_clamp_kqv = 0.0e+00
print_info: f_max_alibi_bias = 0.0e+00
print_info: f_logit_scale = 0.0e+00
print_info: f_attn_scale = 0.0e+00
print_info: n_ff = [5376, 10752, 16128, 16128, 16128, 16128, 16128, 16128, 21504, 0, 0, 0, 0, 21504, 21504, 21504, 53248, 53248, 0, 0, 0, 0, 0, 0, 53248, 53248, 53248, 0, 0, 0, 0, 0, 53248, 53248, 53248, 26624, 0, 0, 0, 21504, 21504, 21504, 21504, 53248, 53248, 0, 0, 0, 0, 0, 53248, 53248, 53248, 53248, 0, 0, 0, 0, 0, 53248, 53248, 53248, 53248, 0, 0, 0, 0, 0, 53248, 53248, 53248, 53248, 0, 0, 0, 0, 0, 53248, 53248, 53248, 53248, 0, 0, 0, 0, 0, 53248, 37376, 37376, 37376, 0, 0, 32000, 26624, 26624, 26624, 26624, 26624, 26624, 0, 26624, 26624, 26624, 26624, 26624, 26624, 26624, 26624, 0, 0, 0, 0, 0, 32000, 53248, 53248, 53248, 0, 0, 0, 0, 0, 0, 0, 0, 399360, 0, 0, 0, 0, 0, 0, 0, 0, 425984, 0, 0, 0, 0, 0, 0, 0, 0, 343040, 0, 0, 0, 0, 0, 301056, 21504, 21504, 26624, 0, 26624, 26624, 37376, 53248, 53248, 53248, 53248, 26624]
print_info: n_expert = 0
print_info: n_expert_used = 0
print_info: causal attn = 1
print_info: pooling type = 0
print_info: rope type = 0
print_info: rope scaling = linear
print_info: freq_base_train = 500000.0
print_info: freq_scale_train = 1
print_info: n_ctx_orig_yarn = 131072
print_info: rope_finetuned = unknown
print_info: ssm_d_conv = 0
print_info: ssm_d_inner = 0
print_info: ssm_d_state = 0
print_info: ssm_dt_rank = 0
print_info: ssm_dt_b_c_rms = 0
print_info: model type = 405B
print_info: model params = 253.40 B
print_info: general.name = Llama_Nemotron_Ultra
print_info: vocab type = BPE
print_info: n_vocab = 128256
print_info: n_merges = 280147
print_info: BOS token = 128000 '<|begin_of_text|>'
print_info: EOS token = 128009 '<|eot_id|>'
print_info: EOT token = 128009 '<|eot_id|>'
print_info: EOM token = 128008 '<|eom_id|>'
print_info: LF token = 198 'Ċ'
print_info: EOG token = 128008 '<|eom_id|>'
print_info: EOG token = 128009 '<|eot_id|>'
print_info: max token length = 256
load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors: offloading 0 repeating layers to GPU
load_tensors: offloaded 0/163 layers to GPU
load_tensors: CPU_Mapped model buffer size = 143937.13 MiB
...................................................................................
llama_context: constructing llama_context
llama_context: n_seq_max = 1
llama_context: n_ctx = 4096
llama_context: n_ctx_per_seq = 4096
llama_context: n_batch = 2048
llama_context: n_ubatch = 512
llama_context: causal_attn = 1
llama_context: flash_attn = 0
llama_context: freq_base = 500000.0
llama_context: freq_scale = 1
llama_context: n_ctx_per_seq (4096) < n_ctx_train (131072) -- the full capacity of the model will not be utilized
llama_context: CPU output buffer size = 0.49 MiB
init: kv_size = 4096, offload = 1, type_k = 'f16', type_v = 'f16', n_layer = 162, can_shift = 1
init: CPU KV buffer size = 1024.00 MiB
llama_context: KV self size = 1024.00 MiB, K (f16): 512.00 MiB, V (f16): 512.00 MiB
llama_context: CUDA0 compute buffer size = 10132.00 MiB
llama_context: CUDA_Host compute buffer size = 40.01 MiB
llama_context: graph nodes = 2399
llama_context: graph splits = 777 (with bs=512), 1 (with bs=1)
common_init_from_params: setting dry_penalty_last_n to ctx_size = 4096
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
main: llama threadpool init, n_threads = 32
main: chat template is available, enabling conversation mode (disable it with -no-cnv)
main: chat template example:
<|start_header_id|>system<|end_header_id|>
You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
Hello<|eot_id|><|start_header_id|>assistant<|end_header_id|>
Hi there<|eot_id|><|start_header_id|>user<|end_header_id|>
How are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
system_info: n_threads = 32 (n_threads_batch = 32) / 64 | CUDA : ARCHS = 890 | USE_GRAPHS = 1 | PEER_MAX_BATCH_SIZE = 128 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | BMI2 = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | AVX512_BF16 = 1 | LLAMAFILE = 1 | OPENMP = 1 | AARCH64_REPACK = 1 |
main: interactive mode on.
sampler seed: 1575017034
sampler params:
repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
dry_multiplier = 0.000, dry_base = 1.750, dry_allowed_length = 2, dry_penalty_last_n = 4096
top_k = 40, top_p = 0.950, min_p = 0.050, xtc_probability = 0.000, xtc_threshold = 0.100, typical_p = 1.000, top_n_sigma = -1.000, temp = 0.800
mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> dry -> top-k -> typical -> top-p -> min-p -> xtc -> temp-ext -> dist
generate: n_ctx = 4096, n_batch = 2048, n_predict = -1, n_keep = 1
== Running in interactive mode. ==
- Press Ctrl+C to interject at any time.
- Press Return to return control to the AI.
- To return control without starting a new line, end your input with '/'.
- If you want to submit another line, end your input with '\'.
- Not using system message. To change it, set a different value via -sys PROMPT
> What is the meaning of life?
8!B"(1D<)<,4@*3-A'3(<5,72A9.F62AC"%D08);E)6CDHCA0C.HC!%85>8DD(3!=&;48<"802=A,%0,6%D@0/'D<%(11@=&:.F*0A)!91.#;2,&;))
>
llama_perf_sampler_print: sampling time = 5.23 ms / 131 runs ( 0.04 ms per token, 25052.59 tokens per second)
llama_perf_context_print: load time = 213563.07 ms
llama_perf_context_print: prompt eval time = 4242.12 ms / 17 tokens ( 249.54 ms per token, 4.01 tokens per second)
llama_perf_context_print: eval time = 90563.05 ms / 114 runs ( 794.41 ms per token, 1.26 tokens per second)
llama_perf_context_print: total time = 108716.49 ms / 131 tokens
Interrupted by user
Thanks for your update. I will take a closer look and compare the modeling_*.py files to see if I can spot more differences. Can your llama-cli binary work with the 49B/51B ggufs?
I'm just wondering whether we can rely on ffn.ffn_mult is None, or whether a safer approach would be ffn.no_op is True, to decide if a layer is a dummy? https://huggingface.co/nvidia/Llama-3_1-Nemotron-Ultra-253B-v1/discussions/1
I believe the error may be due to me not skipping this part of the code in the layer loop when we have a dummy layer.
I made a fix to skip this when we have a dummy layer. It doesn't break 51B inference. Can you give this a try? I believe you don't need to re-convert the gguf; just re-compile and run llama-cli. Thanks a lot in advance.
There are 10 layers with ffn_mult 1.95, but all of them have ffn no_op False. I believe you are talking about attention no_op True for some 1.95 layers. Those belong to the attention-free layers, which I specifically handle when n_head == 0.
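For anyone following along, here is a minimal sketch of how the two flags could be combined during conversion. The field names (block_configs, attention.no_op, ffn.no_op, ffn.ffn_mult) are assumed from the model's HF config.json and may differ from what convert_hf_to_gguf.py actually reads:

import json

def classify_block(cfg: dict) -> str:
    """Classify one entry of block_configs from the HF config.json.
    Field names (attention.no_op, ffn.no_op) are assumed from the
    Nemotron/Deci configs and may need adjusting."""
    attn_off = cfg["attention"].get("no_op", False)
    ffn_off = cfg["ffn"].get("no_op", False)
    if attn_off and ffn_off:
        return "dummy"           # new in the 253B model: no attention, no FFN
    if attn_off:
        return "attention-free"  # exported with n_head == 0 in the GGUF metadata
    if ffn_off:
        return "ffn-free"
    return "regular"

with open("config.json") as f:   # the model's HF config
    config = json.load(f)

for i, block in enumerate(config["block_configs"]):
    # For dummy layers ffn_mult should also be None, so either check works,
    # but no_op states the intent explicitly.
    print(i, classify_block(block), block["ffn"].get("ffn_mult"))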
@ymcki Wow, amazing, this fixed it. Thanks a lot for the quick fix! Here is the result of my latest test. The model gave a perfect answer.
Good news. Now bartowski can start making the ggufs. :)
I uploaded the Q4_K_M quants I made for testing to https://huggingface.co/nicoboss/Llama-3_1-Nemotron-Ultra-253B-v1-GGUF in case anyone wants to try out this PR and the Llama-3_1-Nemotron-Ultra-253B-v1 model. Edit: I added Q3_K_M quants to the above repository for those unable to run Q4_K_M. While doing so I also retested convert_hf_to_gguf.py and it still worked perfectly fine.
I can also confirm it works; Q2 quants are uploading to DevQuasar/nvidia.Llama-3_1-Nemotron-Ultra-253B-v1-GGUF.
I can confirm it works; I tested the Q3_K_M that @nicoboss uploaded.
I tried out this branch.
This is a known problem that has been reported before. I don't have the resources to implement this, as I only have one card.
Unfortunately, when I try to run the model in LM Studio, I get this error:
I searched high and low and can't figure out why the committed changes from pull/12843 don't allow it to run. Thoughts?
Do other people have the same error? The 10th layer is the new dummy layer that has no self-attention and no FFN, so the gguf itself should have no ffn_norm weight in it. The reason you are getting this error is that your LM Studio doesn't have the code to skip reading this weight. Please try compiling llama.cpp and running with llama-cli to make sure you have the new code.
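To illustrate the point (this is not the actual loader code, which lives in C++ in src/llama-model.cpp, and the tensor list below is simplified and not exhaustive), here is a small sketch of how the per-layer head_count and feed_forward_length arrays shown in the dump above decide which blk.N.* tensors a loader should even look for:

def expected_tensors(layer: int, n_head: int, n_ff: int) -> list[str]:
    """Tensor names a loader should expect for one block, based on the
    per-layer metadata arrays (deci.attention.head_count and
    deci.feed_forward_length). A dummy layer (n_head == 0 and n_ff == 0)
    contributes nothing at all, so there is no blk.N.ffn_norm.weight
    to read either."""
    names: list[str] = []
    if n_head > 0:  # attention block present
        names += [f"blk.{layer}.attn_norm.weight",
                  f"blk.{layer}.attn_q.weight", f"blk.{layer}.attn_k.weight",
                  f"blk.{layer}.attn_v.weight", f"blk.{layer}.attn_output.weight"]
    if n_ff > 0:    # FFN block present
        names += [f"blk.{layer}.ffn_norm.weight",
                  f"blk.{layer}.ffn_gate.weight", f"blk.{layer}.ffn_up.weight",
                  f"blk.{layer}.ffn_down.weight"]
    return names

# Layer 9 (the 10th layer) has head_count 0 and feed_forward_length 0 in the
# dump above, i.e. it is a dummy layer:
print(expected_tensors(9, n_head=0, n_ff=0))        # -> []
print(expected_tensors(0, n_head=128, n_ff=5376))   # -> full attention + FFN list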
Someone has left a comment with the same issue on my quant: https://huggingface.co/DevQuasar/nvidia.Llama-3_1-Nemotron-Ultra-253B-v1-GGUF/discussions/1 It worked fine for me with llama.cpp built from your branch.
I thought it was merged into main; am I wrong? I recompiled it today (the main version) and no joy.
Well, the status of this PR is "Open", not "Merged". Apparently, it is not merged into main. Now all we need is a llama.cpp god to approve the PR...
I tried your PR @ymcki - great work! It does work; however, offloading to the GPU seems partially broken - using
But overall, I tried using NVIDIA's suggested
I have loaded it on 4 GPUs (2x24 + 1x32 + 1x48 GB), tinkering a lot with -ts, and it seems to work fine with the default split mode. Where do you set thinking = on? (Sorry, new to GGUF)
What do you mean by "offloading to the GPU"? Do you mean "offloading to CPU"? Or do you mean "--split-mode row" can make the VRAM distribution more even?
Can you try "--split-mode row" and see if it makes any difference to the VRAM distribution? Thanks
I get OOM; I guess -sm row is like tensor parallel? Since Q3_K_M is ~110GB and my GPU with the least VRAM is 24GB, it seems to load up to ~90-93GB until a 24GB VRAM GPU goes OOM.
Thanks for the info. It seems "-sm row" will use the main_gpu for small tensors and intermediate results. Supposedly it will run faster, so it doesn't do much for VRAM distribution. Perhaps combining "-sm row" with manually tuning "-ts" can work?
Sorry for the late answer; I still have to test this properly. I tested a bit, and it seems -ts values have to be wildly different when using -sm row. I was testing UD-Q3_K_XL (https://huggingface.co/unsloth/Llama-3_1-Nemotron-Ultra-253B-v1-GGUF/tree/main/UD-Q3_K_XL) but couldn't load all 163 layers on GPU, compared to Q3_K_M. In theory it should fit, but I couldn't find a -ts value that let me load more on either the pair of gpu0/1 (4090 x2) or gpu 2 (5090) while also using just 1.9GB less on gpu 3 (A6000), even though the other 3 combined have 3GB free or so. This seems to happen for that specific model with both the default -sm and row. I probably have to test -sm row on the normal Q3_K_M.
I tried the Q4_K_M on 4x3090 with CPU offloading and it seems to share layers weirdly; I also tried -sm row with -ts and that does not work either, so always OOM.
Q4_K_M itself is 145GB, and the f16 KV cache at 4k context is 1GB. That's way over your VRAM. How much RAM do you have? Maybe you should first try IQ3_XXS, which is 97.6GB? You can download it from
I have created a table of the amount of VRAM needed for each layer. Can someone tell me what "-ts" does exactly? For example, if you have three 3090s and pass "-sm layer" and "-ts 2,3,1", does that mean the first 54 layers go to card 1, the next 81 layers to card 2, and the final 27 layers to card 3?
I'm not sure how -ts works either. I think it is a ratio for how the layers are split, but I have to use very weird values to make it work on Nemotron (e.g. 17.3,16,24,22.5). I think because the layers of Nemotron seem to have different sizes, and some are empty/dummy layers, the load is not even across layers. You can kind of make the GPUs with more VRAM take the larger layers if you load the layers in the 2nd and 4th (last) fraction of the model.
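For what it's worth, my understanding (not the exact llama.cpp algorithm) is that with -sm layer the -ts values act as per-device weights and whole layers are assigned roughly in proportion to them. A toy sketch of that proportional split, assuming equal-sized layers (exactly the assumption Nemotron breaks):

def split_layers(n_layers: int, tensor_split: list[float]) -> list[int]:
    """Toy proportional split: assign n_layers across devices according to
    the -ts ratios, assuming every layer costs the same amount of VRAM.
    Nemotron's layers are not equal-sized (and some are dummies), which is
    why the ratios need so much manual tuning in practice."""
    total = sum(tensor_split)
    counts = [int(n_layers * w / total) for w in tensor_split]
    # hand out any rounding remainder to the largest shares first
    for i in sorted(range(len(counts)), key=lambda i: tensor_split[i], reverse=True):
        if sum(counts) >= n_layers:
            break
        counts[i] += 1
    return counts

# The example from the question above: three cards, "-ts 2,3,1" over 162 layers
print(split_layers(162, [2, 3, 1]))  # -> [54, 81, 27]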
I created a spreadsheet of the exact number of parameters and KV cache size per layer that can help people manually distribute their VRAM. Layer 0 is the tokenizer weight, which I believe is not loaded into VRAM. Layer 163 is the output layer. All dummy layers are skipped, as they don't have any weights. This is an example for IQ3_M 3.66bpw with IQ4_NL KV cache 4.5bpw and a context length of 65536. In this case, suppose you have 5x3090: then layers 1 to 43 go to card 1, layers 44 to 80 to card 2, layers 81 to 125 to card 3, layers 126 to 149 to card 4, and layers 150 to 163 to card 5.
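As a cross-check on those numbers, here is a short sketch of the KV-cache arithmetic. The 1024 MiB figure in the log above implies 64 attention-bearing layers (512 MiB / 8 MiB per layer for K at 4096 context, n_embd_k_gqa = 1024, f16); that layer count and the 4.5 bpw figure for IQ4_NL cache from the comment are assumptions of this sketch:

def kv_cache_mib(n_ctx: int, n_embd_kv: int, bytes_per_elt: float,
                 n_attn_layers: int) -> float:
    """K and V caches each take n_ctx * n_embd_kv * bytes_per_elt per
    attention-bearing layer; dummy and attention-free layers take nothing."""
    per_layer = n_ctx * n_embd_kv * bytes_per_elt / 2**20  # MiB, K or V alone
    return 2 * per_layer * n_attn_layers                   # K + V

# Reproduces the log above: 64 layers with n_embd_k_gqa = 1024, f16, 4096 ctx
print(kv_cache_mib(4096, 1024, 2.0, 64))        # -> 1024.0 MiB

# The spreadsheet scenario: 65536 ctx with a ~4.5 bpw (IQ4_NL-style) KV cache
print(kv_cache_mib(65536, 1024, 4.5 / 8, 64))   # -> 4608.0 MiB across all cards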
Dear all,
I was the person who made the PR for Llama-3_1-Nemotron-51B support.
#10669
I noticed that there is a new Deci model called Llama-3_1-Nemotron-Ultra-253B-v1.
Based on my understanding, in addition to the original three types of layers, it
adds a new type of layer that has neither attention nor FFN, which I call
a dummy layer.
So I modified convert_hf_to_gguf.py and src/llama-model.* to support this dummy
layer.
I tested the code against the original 51B model and it seems to have no error
during conversion and inference.
However, I don't have the resources to test it on the 253B model. Is it possible
for someone here to try this PR and see if it works for the 253B model? Thanks a lot in advance.
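For anyone skimming the PR: conceptually, a dummy layer is just an identity block, so it contributes no tensors to the GGUF and the forward pass simply passes the hidden state through. A minimal conceptual sketch (not the actual C++ graph-building code):

from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class Block:
    attention: Optional[Callable[[float], float]] = None  # None when attention is absent
    ffn: Optional[Callable[[float], float]] = None         # None when the FFN is absent

def forward(hidden: float, block: Block) -> float:
    """One transformer block: each sub-block is a residual update when present.
    A dummy layer has neither sub-block and passes the hidden state through."""
    if block.attention is not None:
        hidden = hidden + block.attention(hidden)
    if block.ffn is not None:
        hidden = hidden + block.ffn(hidden)
    return hidden

dummy = Block()             # the new layer type in the 253B model
print(forward(1.0, dummy))  # -> 1.0, unchanged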