[Bug]: Running paddlenlp after installing paddlecustom raises an error #9403

Open
1 task done
programmer-lxj opened this issue Nov 11, 2024 · 9 comments
Assignees
Labels
bug Something isn't working

Comments

@programmer-lxj

Software environment

- paddlepaddle: 3.0.0-beta0
- paddlepaddle-gpu: not used (running on CPU)
- paddlenlp: 3.0.0b2

Duplicate issues

  • I have searched the existing issues

Error description

I installed the custom_cpu backend of paddlecustom by following this link:
https://github.com/PaddlePaddle/PaddleCustomDevice/tree/develop/backends/custom_cpu
I then installed paddlenlp, and running it produces the error below:
root@5d501bece6db:~/sycl_workspace# python3 testllama.py 
I1111 08:07:07.623097 809117 init.cc:236] ENV [CUSTOM_DEVICE_ROOT]=/usr/local/lib/python3.8/dist-packages/paddle_custom_device
I1111 08:07:07.623117 809117 init.cc:145] Try loading custom device libs from: [/usr/local/lib/python3.8/dist-packages/paddle_custom_device]
custom_cpu plugin compiled with gcc
I1111 08:07:07.624505 809117 custom_device.cc:1099] Succeed in loading custom runtime in lib: /usr/local/lib/python3.8/dist-packages/paddle_custom_device/libpaddle-custom-cpu.so
I1111 08:07:07.624509 809117 custom_kernel.cc:39] No custom kernel info found in loaded lib(s).
I1111 08:07:07.624511 809117 init.cc:157] Finished in LoadCustomDevice with libs_path: [/usr/local/lib/python3.8/dist-packages/paddle_custom_device]
I1111 08:07:07.624516 809117 init.cc:242] CustomDevice: custom_cpu, visible devices count: 2
[2024-11-11 08:07:08,573] [    INFO] - We are using <class 'paddlenlp.transformers.llama.tokenizer.Llama3Tokenizer'> to load 'meta-llama/Llama-3.2-1B'.
[2024-11-11 08:07:08,923] [    INFO] - We are using <class 'paddlenlp.transformers.llama.modeling.LlamaForCausalLM'> to load 'meta-llama/Llama-3.2-1B'.
[2024-11-11 08:07:08,923] [    INFO] - Loading configuration file /root/.paddlenlp/models/meta-llama/Llama-3.2-1B/config.json
[2024-11-11 08:07:08,924] [    INFO] - Loading weights file from cache at /root/.paddlenlp/models/meta-llama/Llama-3.2-1B/model.safetensors
[2024-11-11 08:07:09,350] [    INFO] - Loaded weights file from disk, setting weights to model.
[2024-11-11 08:08:36,107] [    INFO] - All model checkpoint weights were used when initializing LlamaForCausalLM.

[2024-11-11 08:08:36,108] [ WARNING] - Some weights of LlamaForCausalLM were not initialized from the model checkpoint at meta-llama/Llama-3.2-1B and are newly initialized: ['lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[2024-11-11 08:08:36,108] [    INFO] - Loading configuration file /root/.paddlenlp/models/meta-llama/Llama-3.2-1B/generation_config.json
[2024-11-11 08:08:36,108] [    INFO] - Generation config file not found, using a generation config created from the model config.
[2024-11-11 08:08:36,233] [ WARNING] - `max_length` will be deprecated in future releases, use `max_new_tokens` instead.
Traceback (most recent call last):
  File "testllama.py", line 7, in <module>
    outputs = model.generate(**input_features, max_length=128)
  File "/usr/local/lib/python3.8/dist-packages/decorator.py", line 232, in fun
    return caller(func, *(extras + args), **kw)
  File "/usr/local/lib/python3.8/dist-packages/paddle/base/dygraph/base.py", line 337, in _decorate_function
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/paddlenlp/generation/utils.py", line 927, in generate
    return self.greedy_search(
  File "/usr/local/lib/python3.8/dist-packages/paddlenlp/generation/utils.py", line 1081, in greedy_search
    outputs = self(**model_inputs)
  File "/usr/local/lib/python3.8/dist-packages/paddle/nn/layer/layers.py", line 1426, in __call__
    return self.forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/paddlenlp/transformers/llama/modeling.py", line 2013, in forward
    outputs = self.llama(
  File "/usr/local/lib/python3.8/dist-packages/paddle/nn/layer/layers.py", line 1426, in __call__
    return self.forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/paddlenlp/transformers/llama/modeling.py", line 1725, in forward
    layer_outputs = decoder_layer(
  File "/usr/local/lib/python3.8/dist-packages/paddle/nn/layer/layers.py", line 1426, in __call__
    return self.forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/paddlenlp/transformers/llama/modeling.py", line 1189, in forward
    outputs = self.self_attn(
  File "/usr/local/lib/python3.8/dist-packages/paddle/nn/layer/layers.py", line 1426, in __call__
    return self.forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/paddlenlp/transformers/llama/modeling.py", line 1083, in forward
    outputs = self.attn_func(
  File "/usr/local/lib/python3.8/dist-packages/paddlenlp/transformers/llama/modeling.py", line 278, in scaled_dot_product_attention
    attn_weights = attn_weights + attention_mask
RuntimeError: (PreconditionNotMet) Tensor holds no memory. Call Tensor::mutable_data firstly.
  [Hint: holder_ should not be null.] (at /paddle/paddle/phi/core/dense_tensor_impl.cc:44)

Steps to reproduce & code

from paddlenlp.transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
#model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B", dtype="bfloat16")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", dtype="float32")
input_features = tokenizer("你好!请自我介绍一下。", return_tensors="pd")
outputs = model.generate(**input_features, max_length=128)
#tokenizer.batch_decode(outputs[0])
print(tokenizer.batch_decode(outputs[0]))

@programmer-lxj programmer-lxj added the bug Something isn't working label Nov 11, 2024
@xiaoguoguo626807
Contributor

xiaoguoguo626807 commented Nov 11, 2024

Try the CPU build of the paddlepaddle package.

@programmer-lxj
Author

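The CPU build of paddlepaddle is exactly what I installed. I am not trying to run on an NPU; I only want to run on the CPU. We added our own SYCL branch, which runs on the CPU, and hit an error. I assumed the problem was in our new branch, so I tried the custom_cpu branch, and it fails too. In short: without paddlecustom installed, paddlenlp runs; as soon as paddlecustom is installed, paddlenlp raises the error. Is this a compatibility gap? Which parts of the paddlenlp code would need to change? Is it modeling.py?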

@xiaoguoguo626807
Contributor

Does python ../tests/test_MNIST_model.py run? If it does, PaddleNLP probably needs to be adapted. If it does not, the problem is in the custom device.

@DrownFish19 DrownFish19 assigned DrownFish19 and unassigned KB-Ding Nov 11, 2024
@programmer-lxj
Author

python ../tests/test_MNIST_model.py runs. It is only that paddlenlp stops working once paddlecustom is installed, so it looks like paddlenlp and paddlecustom have not been adapted to each other.

@programmer-lxj
Author

At line 81 of https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/transformers/llama/modeling.py I see: if get_env_device() in ["npu", "mlu", "gcu"]:. At line 129 of https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/utils/tools.py I see: elif "gcu" in paddle.device.get_all_custom_device_type(): return "gcu". The devices handled in these two files do not cover all of the devices in paddlecustom, so it looks like the adaptation was started but never finished. I tried making analogous changes myself, but that produced other errors.
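The gap described in this comment can be illustrated with a small sketch. It is not PaddleNLP's actual code: resolve_env_device, KNOWN_CUSTOM_DEVICES, takes_device_branch, and the "cpu" fallback are hypothetical, and the real get_env_device may check other devices as well. It only shows why a lookup limited to a fixed list of device types would not pick up custom_cpu.

```python
from typing import Iterable, Tuple

# Hypothetical list, mirroring the three names quoted from modeling.py.
KNOWN_CUSTOM_DEVICES: Tuple[str, ...] = ("npu", "mlu", "gcu")


def resolve_env_device(custom_types: Iterable[str], fallback: str = "cpu") -> str:
    """Return the first known device type that is registered, else the fallback.

    `custom_types` stands in for the list returned by
    paddle.device.get_all_custom_device_type(); the fallback is an assumption.
    """
    registered = set(custom_types)
    for name in KNOWN_CUSTOM_DEVICES:
        if name in registered:
            return name
    return fallback


def takes_device_branch(device: str) -> bool:
    """Mimic a check of the form `if get_env_device() in [...]`."""
    return device in KNOWN_CUSTOM_DEVICES


# custom_cpu is registered, but it is not in the known list,
# so the lookup falls through and the device branch is skipped.
device = resolve_env_device(["custom_cpu"])
print(device, takes_device_branch(device))
```

In a live session the list would come from paddle.device.get_all_custom_device_type(), which reports ['custom_cpu'] for the setup in this issue. Adding a name to such a list only changes which branch is taken; it does not provide the operators that the branch relies on.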

@xiaoguoguo626807
Contributor

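https://github.com/PaddlePaddle/PaddleCustomDevice/tree/develop/backends/custom_cpu Does the verify section there run correctly for you? This looks like a build problem: the customdevice seems to have been built as the GCU version. A correct build prints custom_cpu.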
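One way to confirm which plugin was actually loaded is to read the device name from the startup log. The following is a minimal sketch, assuming the log keeps the line format seen in this issue ("CustomDevice: <name>, visible devices count: <n>"). The helper name loaded_custom_devices is hypothetical and the format could change between Paddle versions.

```python
import re
from typing import Dict

# Matches the init.cc log line quoted in this issue.
CUSTOM_DEVICE_LINE = re.compile(
    r"CustomDevice: (?P<name>\w+), visible devices count: (?P<count>\d+)"
)


def loaded_custom_devices(log_text: str) -> Dict[str, int]:
    """Map each custom device type found in the log to its visible device count."""
    return {
        match.group("name"): int(match.group("count"))
        for match in CUSTOM_DEVICE_LINE.finditer(log_text)
    }


# Log lines copied from the issue output.
sample_log = """\
custom_cpu plugin compiled with gcc
I1111 08:07:07.624516 809117 init.cc:242] CustomDevice: custom_cpu, visible devices count: 2
"""

print(loaded_custom_devices(sample_log))
```

If the build had produced a GCU plugin, the parsed name would be gcu rather than custom_cpu.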

@programmer-lxj
Author

It runs correctly.
python3 -c "import paddle; print(paddle.device.get_all_custom_device_type())"
I1111 09:25:59.223500 857147 init.cc:236] ENV [CUSTOM_DEVICE_ROOT]=/usr/local/lib/python3.8/dist-packages/paddle_custom_device
I1111 09:25:59.223515 857147 init.cc:145] Try loading custom device libs from: [/usr/local/lib/python3.8/dist-packages/paddle_custom_device]
custom_cpu plugin compiled with gcc
I1111 09:25:59.229365 857147 custom_device.cc:1099] Succeed in loading custom runtime in lib: /usr/local/lib/python3.8/dist-packages/paddle_custom_device/libpaddle-custom-cpu.so
I1111 09:25:59.229370 857147 custom_kernel.cc:39] No custom kernel info found in loaded lib(s).
I1111 09:25:59.229372 857147 init.cc:157] Finished in LoadCustomDevice with libs_path: [/usr/local/lib/python3.8/dist-packages/paddle_custom_device]
I1111 09:25:59.229377 857147 init.cc:242] CustomDevice: custom_cpu, visible devices count: 2
['custom_cpu']

python3 ../tests/test_MNIST_model.py
I1111 09:26:33.263067 857586 init.cc:236] ENV [CUSTOM_DEVICE_ROOT]=/usr/local/lib/python3.8/dist-packages/paddle_custom_device
I1111 09:26:33.263084 857586 init.cc:145] Try loading custom device libs from: [/usr/local/lib/python3.8/dist-packages/paddle_custom_device]
custom_cpu plugin compiled with gcc
I1111 09:26:33.264544 857586 custom_device.cc:1099] Succeed in loading custom runtime in lib: /usr/local/lib/python3.8/dist-packages/paddle_custom_device/libpaddle-custom-cpu.so
I1111 09:26:33.264549 857586 custom_kernel.cc:39] No custom kernel info found in loaded lib(s).
I1111 09:26:33.264549 857586 init.cc:157] Finished in LoadCustomDevice with libs_path: [/usr/local/lib/python3.8/dist-packages/paddle_custom_device]
I1111 09:26:33.264554 857586 init.cc:242] CustomDevice: custom_cpu, visible devices count: 2
Epoch 0 step 0, Loss = 2.2956037521362305, Accuracy = 0.15625
Epoch 0 step 100, Loss = 2.155289649963379, Accuracy = 0.3125
Epoch 0 step 200, Loss = 2.1177732944488525, Accuracy = 0.4375
Epoch 0 step 300, Loss = 2.0089213848114014, Accuracy = 0.53125
Epoch 0 step 400, Loss = 2.0845465660095215, Accuracy = 0.421875
Epoch 0 step 500, Loss = 2.047300100326538, Accuracy = 0.453125
Epoch 0 step 600, Loss = 1.8561761379241943, Accuracy = 0.71875
Epoch 0 step 700, Loss = 1.9915285110473633, Accuracy = 0.53125
Epoch 0 step 800, Loss = 1.8925955295562744, Accuracy = 0.640625
Epoch 0 step 900, Loss = 1.8199623823165894, Accuracy = 0.734375

/usr/local/lib/python3.8/site-packages/paddlenlp/transformers/llama/modeling.py is produced by installing paddlenlp, right? Building paddlecustomdevice would not generate a modeling.py. Does the order in which paddlenlp and paddlecustomdevice are installed matter?

@programmer-lxj
Author

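On their own, paddlecustom and paddlenlp each run the test examples from their GitHub pages without problems. But once paddlecustom is installed, running inference with paddlenlp raises the error. What is the cause?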

@DrownFish19
Collaborator

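On their own, paddlecustom and paddlenlp each run the test examples from their GitHub pages without problems. But once paddlecustom is installed, running inference with paddlenlp raises the error. What is the cause?

This is probably because the modeling code in paddlenlp has not yet been adapted to, or verified against, the corresponding operators, so on a custom device some operators may be missing or may not behave consistently. Resolving this takes work on both sides: paddlenlp needs to be adapted, and the device developers need to adapt the operators. If possible, please submit your paddlenlp modeling changes as a PR so that we can support the new hardware together.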
