
[BUG] Quantized MoE model is generating invalid response #991

Open
BodhiHu opened this issue Jan 2, 2025 · 9 comments
Labels
bug Something isn't working

Comments

BodhiHu commented Jan 2, 2025

Describe the bug

Hi,

After quantizing the model, it generates a repeating response.

Below is the conversion and test script:

from datasets import load_dataset
from transformers import AutoTokenizer
from gptqmodel import GPTQModel, QuantizeConfig

model_id = "/path/to/LLaMA-MoE_8B-2_8-sft"
quant_path = "/path/to/LLaMA-MoE_8B-2_8-sft-GPTQ-w4g128"

tokenizer = AutoTokenizer.from_pretrained(model_id)

calibration_dataset = [
  tokenizer(example["text"])
  for example in load_dataset(
    "allenai/c4",
    data_files="en/c4-train.00001-of-01024.json.gz",
    split="train"
  ).select(range(1024))
]

quant_config = QuantizeConfig(bits=4, group_size=128, desc_act=False)

model = GPTQModel.load(model_id, quant_config)

model.quantize(calibration_dataset)

model.save(quant_path)

model = GPTQModel.load(quant_path)

result = model.generate(
  **tokenizer(
      "Good Morning! Once upon a time, there's a company called", return_tensors="pt"
  ).to(model.device)
)[0]

print(f"\n{tokenizer.decode(result, skip_special_tokens=True)}\n")

Quantized model output (identical to the input):

Good Morning! Once upon a time, there's a company called

GPU Info

Using CPU: x86_64

Software Info

Ubuntu 22.04.4 LTS + Python 3.12.8

Show output of:

pip show gptqmodel torch transformers accelerate triton

Name: gptqmodel
Version: 1.4.6.dev0
Summary: A LLM quantization package with user-friendly apis. Based on GPTQ algorithm.
Home-page: https://github.com/ModelCloud/GPTQModel
Author: ModelCloud
Author-email: [email protected]
License:
Location: /home/huaishun/miniconda3/envs/GPTQModel/lib/python3.12/site-packages
Requires: accelerate, datasets, device-smi, numpy, packaging, pillow, protobuf, safetensors, sentencepiece, threadpoolctl, torch, transformers
Required-by:
---
Name: torch
Version: 2.5.1
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: [email protected]
License: BSD-3-Clause
Location: /home/huaishun/miniconda3/envs/GPTQModel/lib/python3.12/site-packages
Requires: filelock, fsspec, jinja2, networkx, nvidia-cublas-cu12, nvidia-cuda-cupti-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-runtime-cu12, nvidia-cudnn-cu12, nvidia-cufft-cu12, nvidia-curand-cu12, nvidia-cusolver-cu12, nvidia-cusparse-cu12, nvidia-nccl-cu12, nvidia-nvjitlink-cu12, nvidia-nvtx-cu12, setuptools, sympy, triton, typing-extensions
Required-by: accelerate, gptqmodel
---
Name: transformers
Version: 4.47.1
Summary: State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow
Home-page: https://github.com/huggingface/transformers
Author: The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)
Author-email: [email protected]
License: Apache 2.0 License
Location: /home/huaishun/miniconda3/envs/GPTQModel/lib/python3.12/site-packages
Requires: filelock, huggingface-hub, numpy, packaging, pyyaml, regex, requests, safetensors, tokenizers, tqdm
Required-by: gptqmodel
---
Name: accelerate
Version: 1.2.1
Summary: Accelerate
Home-page: https://github.com/huggingface/accelerate
Author: The HuggingFace team
Author-email: [email protected]
License: Apache
Location: /home/huaishun/miniconda3/envs/GPTQModel/lib/python3.12/site-packages
Requires: huggingface-hub, numpy, packaging, psutil, pyyaml, safetensors, torch
Required-by: gptqmodel
---
Name: triton
Version: 3.1.0
Summary: A language and compiler for custom Deep Learning operations
Home-page: https://github.com/triton-lang/triton/
Author: Philippe Tillet
Author-email: [email protected]
License:
Location: /home/huaishun/miniconda3/envs/GPTQModel/lib/python3.12/site-packages
Requires: filelock
Required-by: torch

Model/Datasets

Model: https://huggingface.co/llama-moe/LLaMA-MoE-v2-3_8B-2_8-sft
Dataset: allenai/c4

BodhiHu added the bug label on Jan 2, 2025

Qubitium (Collaborator) commented Jan 2, 2025

After quantizing the model, it generates a repeating response.

Can you post samples of the repeating response?

BodhiHu (Author) commented Jan 6, 2025

Hi @Qubitium ,

The input is:

Good Morning! Once upon a time, there's a company called

And below is the model output, which is exactly the same as the input:

Good Morning! Once upon a time, there's a company called

Qubitium (Collaborator) commented Jan 6, 2025

@BodhiHu All LLM models echo the input back. This is normal. The question is: is the output actually repeating? Like:

Good Morning! Once upon a time, there's a company called Good Morning! Once upon a time, there's a company called Good Morning! Once upon a time, there's a company called

BodhiHu (Author) commented Jan 6, 2025

Hi @Qubitium, then it's not repeating, it just outputs:

Good Morning! Once upon a time, there's a company called

without any further generated text.

Thanks a lot.

Qubitium (Collaborator) commented Jan 7, 2025

@BodhiHu Make the model public and we can check. We can't debug a private model that may deviate from normal Llama models. I see you are using a special Llama MoE model.

BodhiHu (Author) commented Jan 7, 2025

Hi @Qubitium, the model is publicly available: https://huggingface.co/llama-moe/LLaMA-MoE-v2-3_8B-2_8-sft
You can download the original model from the link above.

And here is the quantization script:

from datasets import load_dataset
from transformers import AutoTokenizer
from gptqmodel import GPTQModel, QuantizeConfig

model_id = "/path/to/LLaMA-MoE-v2-3_8B-2_8-sft"
quant_path = "/path/to/LLaMA-MoE-v2-3_8B-2_8-sft-GPTQ-w4g128"

tokenizer = AutoTokenizer.from_pretrained(model_id)

calibration_dataset = [
  tokenizer(example["text"])
  for example in load_dataset(
    "allenai/c4",
    data_files="en/c4-train.00001-of-01024.json.gz",
    split="train"
  ).select(range(1024))
]

quant_config = QuantizeConfig(bits=4, group_size=128, desc_act=False)

model = GPTQModel.load(model_id, quant_config)

model.quantize(calibration_dataset)

model.save(quant_path)

model = GPTQModel.load(quant_path)

result = model.generate(
  **tokenizer(
      "Good Morning! Once upon a time, there's a company called", return_tensors="pt"
  ).to(model.device)
)[0]

print(f"\n{tokenizer.decode(result, skip_special_tokens=True)}\n")

Qubitium (Collaborator) commented Jan 7, 2025

This is a Mistral-style "MoE" model, which is very hard to quantize. You need to exclude all MoE-related layers from quantization.
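
For reference, one possible way to exclude the MoE modules without editing the model definition is a per-module override in QuantizeConfig. This is only a minimal sketch: the dynamic option and the "-:" skip prefix are assumptions about this GPTQModel version, and the block_sparse_moe module name is taken from the model definition shown later in this thread; verify both against the docs for the installed release.

from gptqmodel import QuantizeConfig

quant_config = QuantizeConfig(
    bits=4,
    group_size=128,
    desc_act=False,
    # Regex keys prefixed with "-:" are meant to skip matching modules entirely,
    # leaving the MoE experts and router in full precision (assumed syntax).
    dynamic={
        r"-:.*block_sparse_moe.*": {},
    },
)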

BodhiHu (Author) commented Jan 7, 2025

This is a Mistral-style "MoE" model, which is very hard to quantize. You need to exclude all MoE-related layers from quantization.

OK, thanks a lot for your help, we'll try that :D

BodhiHu (Author) commented Jan 8, 2025

This is a Mistral-style "MoE" model, which is very hard to quantize. You need to exclude all MoE-related layers from quantization.

@Qubitium
We skipped all the MoE layers during quantization, but now the output is:

<|begin_of_text|>Good morning, let's
.........................................................

Test code:

import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from gptqmodel import GPTQModel, QuantizeConfig

model_id = "path/to/LLaMA-MoE-v2-3_8B-2_8-sft-GPTQ-w4g128"

tokenizer = AutoTokenizer.from_pretrained(model_id)

model = GPTQModel.load(model_id, device="cpu")

# model = torch.nn.DataParallel(model)

result = model.generate(
  **tokenizer(
      "Good morning, let's", return_tensors="pt"
  ).to(model.device)
)[0]

print(tokenizer.decode(result))

MoE layers skipped during quantization (commented out in the model definition):

class MixtralGPTQ(BaseGPTQModel):
    base_modules = ["model.embed_tokens", "model.norm"]

    layers_node = "model.layers"
    layer_type = "MixtralDecoderLayer"
    layer_modules = [
        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
        ["self_attn.o_proj"],
        # Please see issue for MoE models quantize:
        #   https://github.com/ModelCloud/GPTQModel/issues/991#issuecomment-2574533252
        # [
        #     "block_sparse_moe.experts.0.w1",
        #     "block_sparse_moe.experts.1.w1",
        #     "block_sparse_moe.experts.2.w1",
        #     "block_sparse_moe.experts.3.w1",
        #     "block_sparse_moe.experts.4.w1",
        #     "block_sparse_moe.experts.5.w1",
        #     "block_sparse_moe.experts.6.w1",
        #     "block_sparse_moe.experts.7.w1",
        #     "block_sparse_moe.experts.0.w3",
        #     "block_sparse_moe.experts.1.w3",
        #     "block_sparse_moe.experts.2.w3",
        #     "block_sparse_moe.experts.3.w3",
        #     "block_sparse_moe.experts.4.w3",
        #     "block_sparse_moe.experts.5.w3",
        #     "block_sparse_moe.experts.6.w3",
        #     "block_sparse_moe.experts.7.w3",
        # ],
        # [
        #     "block_sparse_moe.experts.0.w2",
        #     "block_sparse_moe.experts.1.w2",
        #     "block_sparse_moe.experts.2.w2",
        #     "block_sparse_moe.experts.3.w2",
        #     "block_sparse_moe.experts.4.w2",
        #     "block_sparse_moe.experts.5.w2",
        #     "block_sparse_moe.experts.6.w2",
        #     "block_sparse_moe.experts.7.w2",
        # ],
    ]
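
As a side check when generation stops right after the prompt, it can help to pass an explicit token budget. This is a minimal sketch that assumes GPTQModel's generate forwards keyword arguments to transformers' generate (max_new_tokens is a standard transformers generation parameter):

inputs = tokenizer("Good morning, let's", return_tensors="pt").to(model.device)

# Request a fixed number of new tokens so the default generation length
# is not what cuts the output short.
result = model.generate(**inputs, max_new_tokens=64)[0]

print(tokenizer.decode(result, skip_special_tokens=True))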

Qubitium changed the title from "[BUG] Quantized and model is generating repeating response" to "[BUG] Quantized MoE model is generating invalid response" on Jan 17, 2025