ARM aarch64 server build failed (host OS: Ubuntu 22.04.3) #2021
Actually, nvcc is OK to run, as these show:
root@f8c2e06fbf8b:/mnt/vllm# nvcc -v
and there is CUDA:
root@f8c2e06fbf8b:/mnt/vllm# echo $CUDA_HOME
root@f8c2e06fbf8b:/mnt/vllm# type nvcc
root@f8c2e06fbf8b:/mnt/vllm# python3 -c "import torch; print(torch.cuda.is_available()); print(torch.version);" |
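As an aside, torch.version in that one-liner prints a module rather than a version string; a quick sanity check along these lines (illustrative only, assuming the same container) shows whether the build environment actually sees CUDA:
nvcc --version
python3 -c "import torch; print(torch.__version__); print(torch.version.cuda); print(torch.cuda.is_available())"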
add [code not preserved here] to setup.py at line 268 |
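The exact snippet was not preserved here; a hypothetical sketch of the kind of workaround being described (names and placement are guesses based on the NameError quoted at the bottom of this issue) would define nvcc_cuda_version before get_vllm_version() uses it, for example by querying nvcc directly:
# Hypothetical sketch only: define nvcc_cuda_version by parsing nvcc's output
# when the CUDA runtime is not detected at build time.
import re
import subprocess
from packaging.version import Version

def nvcc_release_version(cuda_home="/usr/local/cuda"):
    # nvcc --version prints a line like "Cuda compilation tools, release 12.2, V12.2.140"
    out = subprocess.check_output([cuda_home + "/bin/nvcc", "--version"], text=True)
    match = re.search(r"release (\d+\.\d+)", out)
    if match is None:
        raise RuntimeError("could not parse nvcc --version output")
    return Version(match.group(1))

nvcc_cuda_version = nvcc_release_version()  # name referenced by get_vllm_version()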
I have the same problem and would be glad for any help. I am running inside the nvidia pytorch_23.12 container. |
Got it working with the changes in this branch: https://github.com/haileyschoelkopf/vllm/tree/aarm64-dockerfile , with built images here: https://hub.docker.com/r/haileysch/vllm-aarch64-base https://hub.docker.com/r/haileysch/vllm-aarch64-openai hopefully this'll be helpful to others! |
Hi guys, have you solved the issue? |
@tuanhe |
Had a similar problem on the GH200 (aarch64 Grace CPU). Main issues that needed to be overcome:
For future updating, you can see the changes here: drikster80@359fd4f |
Thank you all. I have built the image using the script provided by @drikster80, and it takes about 12 hours (most of the time is spent on the mamba builder and xformers). So to save time for others, I have made the image public at https://hub.docker.com/r/zihaokevinzhou/vllm-aarch64-openai . I have validated that it works well for my personal hosting of an fp8-quantized version of llama-3-70b. |
@ZihaoZhou, thank you. It normally only takes ~80 min on my system, so 12 hrs seems excessive. I'm working on an update for v0.5.2, but haven't gotten the new build working yet. I haven't been uploading images since the container is ~33GB. It looks like the one you uploaded is 13GB? Is that just from native compression? I'm sure there are some ways to cut it down (e.g. remove some of the build artifacts from the last image?). |
@ZihaoZhou @drikster80 The step that took me the most time was the mamba builder; for the xformers part, I also spent an entire afternoon, and it seemed to be stuck there. Now vllm is successfully running on GH200, thanks to your selfless contribution! May I ask, regarding the Docker image on aarch64: compared to the original version, is the main difference just commenting out the items you mentioned in requirements.txt? |
@cyc00518 You can see the list of full changes here: main...drikster80:vllm:gh200-docker Effectively, yes. As a side note, if you're using a GH200 bare metal, you might also want to check out my auto-install for GH200s. Getting it set up with optimizations, NCCL, OFED, for high-speed distributed training/inference was a pain, so I automated it for people to use or reference: https://github.com/drikster80/gh200-Ubuntu-22.04-autoinstall |
@drikster80 I have learned a lot, and I also appreciate the additional information you provided! |
Updated the aarch64 remote branch to v0.5.2: https://github.com/drikster80/vllm/tree/gh200-docker Pushed up a GH200 specific version (built for SM 9.0+PTX) to https://hub.docker.com/r/drikster80/vllm-gh200-openai Building a more generic version now and will update this comment when complete. |
If anyone comes across this and is trying to get Llama-3.1 to work with the GH200 (or aarch64 + H100), I have the latest working container image (v0.5.3-post1 with a couple more commits) up at https://hub.docker.com/r/drikster80/vllm-gh200-openai Code is still in the https://github.com/drikster80/vllm/tree/gh200-docker branch. Validated that Llama-3.1-8b-Instruct works, and working to test 405B-FP8 now (with cpu-offload). |
+1 |
Also built some images for arm64 with cuda arch 9.0 (for GH200/H100) and for amd64 for cuda arch 8.0 and 9.0 (A100 and H100) from a fork of @drikster80 's installation to focus on the reproducibility of the build and to have both architectures start from the NGC PyTorch images. |
@FanZhang91, I still maintain two docker images for aarch64 on DockerHub. These have both been updated to v0.6.1 as of 30 min ago.
All supported CUDA caps: drikster80/vllm-aarch64-openai:latest
They are slightly different from upstream in a couple of small ways:
You can pull and build yourself with:
git clone -b gh200-docker https://github.com/drikster80/vllm.git
cd ./vllm
# Update the max_jobs and nvcc_threads as needed to prevent OOM. This is good for a GH200.
docker build . --target vllm-openai -t drikster80/vllm-aarch64-openai:v0.6.1 --build-arg max_jobs=10 --build-arg nvcc_threads=8
# Can also pin to a specific Nvidia GPU Capability:
# docker build . --target vllm-openai -t drikster80/vllm-gh200-openai:v0.6.1 --build-arg max_jobs=10 --build-arg nvcc_threads=8 --build-arg torch_cuda_arch_list="9.0+PTX"
It takes ~1 hr to build on a pinned capability, and ~3+ hours to build for all GPU capability levels; longer if you reduce the max_jobs variable.
@skandermoalla, I've been meaning to make a PR for a merged Dockerfile that can produce both arm64 & amd64... I just haven't had the time to work on it. This was requested by some of the vllm maintainers and would make my life a lot easier, since I wouldn't need to maintain a separate fork. Is this something you'd be interested in collaborating on? |
There weren't any changes needed in the Dockerfile or dependencies to compile for arm64 and amd64, as most of the tricky packages are compiled from source. |
Hi @drikster80, thanks for your docker images. After pulling the docker image, is it still necessary to rebuild or recompile from the source code? I got this error:
I am using an NVIDIA Jetson Orin with the vllm v0.6.1 docker image, and the device is different from yours. |
Can you please try out #8713? @drikster80 @gongchengli
I spared some time to investigate the issue, and it looks like the most complicated part is bringing your own pytorch (@drikster80 does this by using the NGC pytorch container). Other than that, it is pretty straightforward. On that branch, I can easily build vllm from scratch with nightly pytorch, in a fresh new environment:
$ pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124
$ git clone https://github.com/vllm-project/vllm.git
$ cd vllm
$ python use_existing_torch.py
$ pip install -r requirements-build.txt
$ pip install -vvv -e . --no-build-isolation |
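A quick way to confirm such a from-source build picked up the existing torch and that the compiled extensions load (a minimal check, not part of the original steps):
$ python3 -c "import torch, vllm; print(vllm.__version__, torch.__version__, torch.cuda.is_available())"
$ python3 -c "import vllm._C"  # should import cleanly if the CUDA kernels were built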
Is your environment arm? |
yes, I built it on GH200 successfully. |
So many errors: 32 errors detected in the compilation of "/home/qz/zww/vllm/csrc/quantization/gptq/q_gemm.cu". |
I failed |
Which dockerfile did you use? |
My platform is aarch64-linux on Jetson. |
I failed: Feature 'f16 arithemetic and compare instructions' requires .target sm_53 or higher |
Hi, I followed your steps here, but I either got stuck at the first step (building pytorch from scratch in a new python=3.10 virtual env) or failed at the last step (with a pre-installed torch=2.3.1+cuda12.0 and python=3.9 virtual env).
Info of my machine: GH200 aarch64 node (Linux 5.14.0-427.37.1.el9_4.aarch64+64k aarch64)
pre-installed env: |
This might be a problem with your |
Thanks for replying. I also tried this before but got the same error as in case 2. |
I am experiencing the same issue as @Jerrrrykun on a GH200 node (aarch64). A bunch of errors occurs during the build, but I can verify that my PyTorch installation is working: |
you can run some pytorch cuda program to verify if the installation is correct. |
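For example, something as small as this (a generic sketch, not a script from this thread) exercises a CUDA kernel end to end; failures here point at the PyTorch/CUDA install rather than at vllm:
import torch
x = torch.randn(1024, 1024, device="cuda")  # allocates on the GPU
y = x @ x                                   # runs a CUDA matmul
torch.cuda.synchronize()
print(torch.__version__, torch.version.cuda, y.sum().item())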
Should we reopen this until we have a docker image that just works on GH200? Can also file a new issue. I tried the vllm/vllm-openai:v0.6.3.post1 docker image and it doesn't work on GH200. It throws this error:
The image from @drikster80 works great! Is someone working on a PR to merge the required changes that @drikster80 made back into upstream? |
It worked for me on Jetson Orin Nano; vllm is able to start, but I am facing an issue when running models.
vllm serve --device cuda ibm-granite/granite-3.0-2b-instruct
The error I get:
raise ValueError(
environment:
(myenv) [yajuvendra@llmhost vllm]$ python3 collect_env.py
OS: Red Hat Enterprise Linux 9.4 (Plow) (aarch64)
Python version: 3.10.15 (main, Oct 3 2024, 07:21:53) [GCC 11.2.0] (64-bit runtime)
CPU:
Versions of relevant libraries: |
this error is clear. it is not supported.
this is strange. we need more logging output for this. |
Thanks a lot @youkaichao. No matter what model I use, I am landing on the same error. Attached are two logs; thanks a lot for your help.
(myenv) [yajuvendra@llmhost vllm]$ git remote -v
br.. |
@yajuvendrarawat you are running your code under |
@youkaichao @samos123 |
If anyone with an MGX or GH200 wants to test the latest version with different models, let me know if any supported models fail. Uploaded the latest (v0.6.4.post1) to dockerhub: https://hub.docker.com/r/drikster80/vllm-gh200-openai/tags The new version ( PR #10499 ) is unique in that it doesn't use the Nvidia Pytorch container anymore and matches the standard vLLM container more closely. It uses the nightly version of Pytorch (which now supports aarch64) and compiles the modules that don't release an aarch64 wheel (e.g. bitsandbytes, flashinfer, triton, mamba, causal-conv1d, etc.). So it should be considered experimental, but I've run some tests on a couple of models and haven't had any issues. |
Hello @youkaichao, I am not sure what I am doing wrong now. My device is a Jetson Orin Nano and compilation is fine; one additional step I had to take was to install xformers, as the device is being detected as Volta and Turing. I am getting an error while executing vllm serve. Any advice?
(myenv) [yajuvendra@llmhost ~]$ vllm serve --device cuda ibm-granite/granite-3.0-2b-instruct
br.. |
@yajuvendrarawat I think you need #9735 , which is landed just now. |
Hi @drikster80, I tried this command:
sudo docker run --runtime nvidia --gpus all \
-v ~/.cache/huggingface:/root/.cache/huggingface \
--env "HUGGING_FACE_HUB_TOKEN=xxx" \
--env NCCL_TIMEOUT=600 \
-p 8000:8000 \
--ipc=host \
--name vllm \
drikster80/vllm-gh200-openai:latest \
--model meta-llama/Meta-Llama-3.1-70B-Instruct \
--max-num-seqs 1 \
--tensor-parallel-size 1 \
--max-model-len 65536 \
--api-key eyJhIjoiYmI5ZW \
--trust-remote-code \
--gpu-memory-utilization 0.85
But it reports an error like this:
INFO 11-26 10:36:50 model_runner.py:1072] Starting to load model meta-llama/Meta-Llama-3.1-70B-Instruct...
ERROR 11-26 10:36:58 engine.py:366] CUDA out of memory. Tried to allocate 896.00 MiB. GPU 0 has a total capacity of 94.50 GiB of which 767.00 MiB is free. Including non-PyTorch memory, this process has 0 bytes memory in use. Of the allocated memory 93.12 GiB is allocated by PyTorch, and 224.00 KiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
ERROR 11-26 10:36:58 engine.py:366] Traceback (most recent call last):
ERROR 11-26 10:36:58 engine.py:366] File "/usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/engine.py", line 357, in run_mp_engine
ERROR 11-26 10:36:58 engine.py:366] engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
ERROR 11-26 10:36:58 engine.py:366] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 11-26 10:36:58 engine.py:366] File "/usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/engine.py", line 119, in from_engine_args
ERROR 11-26 10:36:58 engine.py:366] return cls(ipc_path=ipc_path,
ERROR 11-26 10:36:58 engine.py:366] ^^^^^^^^^^^^^^^^^^^^^^
ERROR 11-26 10:36:58 engine.py:366] File "/usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/engine.py", line 71, in __init__
ERROR 11-26 10:36:58 engine.py:366] self.engine = LLMEngine(*args, **kwargs)
ERROR 11-26 10:36:58 engine.py:366] ^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 11-26 10:36:58 engine.py:366] File "/usr/local/lib/python3.12/dist-packages/vllm/engine/llm_engine.py", line 345, in __init__
ERROR 11-26 10:36:58 engine.py:366] self.model_executor = executor_class(vllm_config=vllm_config, )
ERROR 11-26 10:36:58 engine.py:366] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 11-26 10:36:58 engine.py:366] File "/usr/local/lib/python3.12/dist-packages/vllm/executor/executor_base.py", line 36, in __init__
ERROR 11-26 10:36:58 engine.py:366] self._init_executor()
ERROR 11-26 10:36:58 engine.py:366] File "/usr/local/lib/python3.12/dist-packages/vllm/executor/gpu_executor.py", line 40, in _init_executor
ERROR 11-26 10:36:58 engine.py:366] self.driver_worker.load_model()
ERROR 11-26 10:36:58 engine.py:366] File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 152, in load_model
ERROR 11-26 10:36:58 engine.py:366] self.model_runner.load_model()
ERROR 11-26 10:36:58 engine.py:366] File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1074, in load_model
ERROR 11-26 10:36:58 engine.py:366] self.model = get_model(vllm_config=self.vllm_config)
ERROR 11-26 10:36:58 engine.py:366] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 11-26 10:36:58 engine.py:366] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/__init__.py", line 12, in get_model
ERROR 11-26 10:36:58 engine.py:366] return loader.load_model(vllm_config=vllm_config)
ERROR 11-26 10:36:58 engine.py:366] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 11-26 10:36:58 engine.py:366] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/loader.py", line 337, in load_model
ERROR 11-26 10:36:58 engine.py:366] model = _initialize_model(vllm_config=vllm_config)
ERROR 11-26 10:36:58 engine.py:366] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 11-26 10:36:58 engine.py:366] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/loader.py", line 104, in _initialize_model
ERROR 11-26 10:36:58 engine.py:366] return model_class(vllm_config=vllm_config, prefix=prefix)
ERROR 11-26 10:36:58 engine.py:366] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 11-26 10:36:58 engine.py:366] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/llama.py", line 507, in __init__
ERROR 11-26 10:36:58 engine.py:366] self.model = LlamaModel(vllm_config=vllm_config,
ERROR 11-26 10:36:58 engine.py:366] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 11-26 10:36:58 engine.py:366] File "/usr/local/lib/python3.12/dist-packages/vllm/compilation/decorators.py", line 124, in __init__
ERROR 11-26 10:36:58 engine.py:366] old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
ERROR 11-26 10:36:58 engine.py:366] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/llama.py", line 298, in __init__
ERROR 11-26 10:36:58 engine.py:366] self.start_layer, self.end_layer, self.layers = make_layers(
ERROR 11-26 10:36:58 engine.py:366] ^^^^^^^^^^^^
ERROR 11-26 10:36:58 engine.py:366] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/utils.py", line 511, in make_layers
ERROR 11-26 10:36:58 engine.py:366] maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
ERROR 11-26 10:36:58 engine.py:366] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 11-26 10:36:58 engine.py:366] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/llama.py", line 300, in <lambda>
ERROR 11-26 10:36:58 engine.py:366] lambda prefix: LlamaDecoderLayer(config=config,
ERROR 11-26 10:36:58 engine.py:366] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 11-26 10:36:58 engine.py:366] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/llama.py", line 231, in __init__
ERROR 11-26 10:36:58 engine.py:366] self.mlp = LlamaMLP(
ERROR 11-26 10:36:58 engine.py:366] ^^^^^^^^^
ERROR 11-26 10:36:58 engine.py:366] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/llama.py", line 73, in __init__
ERROR 11-26 10:36:58 engine.py:366] self.gate_up_proj = MergedColumnParallelLinear(
ERROR 11-26 10:36:58 engine.py:366] ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 11-26 10:36:58 engine.py:366] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/linear.py", line 424, in __init__
ERROR 11-26 10:36:58 engine.py:366] super().__init__(input_size=input_size,
ERROR 11-26 10:36:58 engine.py:366] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/linear.py", line 304, in __init__
ERROR 11-26 10:36:58 engine.py:366] self.quant_method.create_weights(
ERROR 11-26 10:36:58 engine.py:366] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/linear.py", line 122, in create_weights
ERROR 11-26 10:36:58 engine.py:366] weight = Parameter(torch.empty(sum(output_partition_sizes),
ERROR 11-26 10:36:58 engine.py:366] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 11-26 10:36:58 engine.py:366] File "/usr/local/lib/python3.12/dist-packages/torch/utils/_device.py", line 106, in __torch_function__
ERROR 11-26 10:36:58 engine.py:366] return func(*args, **kwargs)
ERROR 11-26 10:36:58 engine.py:366] ^^^^^^^^^^^^^^^^^^^^^
ERROR 11-26 10:36:58 engine.py:366] torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 896.00 MiB. GPU 0 has a total capacity of 94.50 GiB of which 767.00 MiB is free. Including non-PyTorch memory, this process has 0 bytes memory in use. Of the allocated memory 93.12 GiB is allocated by PyTorch, and 224.00 KiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Process SpawnProcess-1:
Traceback (most recent call last):
File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
self.run()
File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/engine.py", line 368, in run_mp_engine
raise e
File "/usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/engine.py", line 357, in run_mp_engine
engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/engine.py", line 119, in from_engine_args
return cls(ipc_path=ipc_path,
^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/engine.py", line 71, in __init__
self.engine = LLMEngine(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/engine/llm_engine.py", line 345, in __init__
self.model_executor = executor_class(vllm_config=vllm_config, )
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/executor/executor_base.py", line 36, in __init__
self._init_executor()
File "/usr/local/lib/python3.12/dist-packages/vllm/executor/gpu_executor.py", line 40, in _init_executor
self.driver_worker.load_model()
File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 152, in load_model
self.model_runner.load_model()
File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1074, in load_model
self.model = get_model(vllm_config=self.vllm_config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/__init__.py", line 12, in get_model
return loader.load_model(vllm_config=vllm_config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/loader.py", line 337, in load_model
model = _initialize_model(vllm_config=vllm_config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/loader.py", line 104, in _initialize_model
return model_class(vllm_config=vllm_config, prefix=prefix)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/llama.py", line 507, in __init__
self.model = LlamaModel(vllm_config=vllm_config,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/compilation/decorators.py", line 124, in __init__
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/llama.py", line 298, in __init__
self.start_layer, self.end_layer, self.layers = make_layers(
^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/utils.py", line 511, in make_layers
maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/llama.py", line 300, in <lambda>
lambda prefix: LlamaDecoderLayer(config=config,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/llama.py", line 231, in __init__
self.mlp = LlamaMLP(
^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/llama.py", line 73, in __init__
self.gate_up_proj = MergedColumnParallelLinear(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/linear.py", line 424, in __init__
super().__init__(input_size=input_size,
File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/linear.py", line 304, in __init__
self.quant_method.create_weights(
File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/linear.py", line 122, in create_weights
weight = Parameter(torch.empty(sum(output_partition_sizes),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_device.py", line 106, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 896.00 MiB. GPU 0 has a total capacity of 94.50 GiB of which 767.00 MiB is free. Including non-PyTorch memory, this process has 0 bytes memory in use. Of the allocated memory 93.12 GiB is allocated by PyTorch, and 224.00 KiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
[rank0]:[W1126 10:36:58.745447229 ProcessGroupNCCL.cpp:1432] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator())
Task exception was never retrieved
future: <Task finished name='Task-2' coro=<MQLLMEngineClient.run_output_handler_loop() done, defined at /usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/client.py:178> exception=ZMQError('Operation not supported')>
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/client.py", line 184, in run_output_handler_loop
while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/zmq/_future.py", line 400, in poll
raise _zmq.ZMQError(_zmq.ENOTSUP)
zmq.error.ZMQError: Operation not supported
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 650, in <module>
uvloop.run(run_server(args))
File "/usr/local/lib/python3.12/dist-packages/uvloop/__init__.py", line 109, in run
return __asyncio.run(
^^^^^^^^^^^^^^
File "/usr/lib/python3.12/asyncio/runners.py", line 194, in run
return runner.run(main)
^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/asyncio/runners.py", line 118, in run
return self._loop.run_until_complete(task)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
File "/usr/local/lib/python3.12/dist-packages/uvloop/__init__.py", line 61, in wrapper
return await main
^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 616, in run_server
async with build_async_engine_client(args) as engine_client:
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/contextlib.py", line 210, in __aenter__
return await anext(self.gen)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 114, in build_async_engine_client
async with build_async_engine_client_from_engine_args(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/contextlib.py", line 210, in __aenter__
return await anext(self.gen)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 211, in build_async_engine_client_from_engine_args
raise RuntimeError(
RuntimeError: Engine process failed to start. See stack trace for the root cause.
Could you please check why? |
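A rough back-of-the-envelope check (assuming bf16 weights and no quantization) suggests an unquantized 70B model simply does not fit in a single GH200's ~96 GB of GPU memory, which is consistent with the fp8-quantized and cpu-offload approaches mentioned earlier in this thread:
# Weights alone for an unquantized 70B model:
params = 70e9
bytes_per_param = 2                        # bf16/fp16
print(params * bytes_per_param / 1024**3)  # ~130 GiB, before KV cache and activations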
Thanks a lot @youkaichao, it moved ahead and I have faced another error; any other pointers? The pytorch is built with cuda:
Detected 1 CUDA Capable device(s)
Device 0: "Orin"
(myenv1) [yajuvendra@llmhost ~]$ nvcc --version
(myenv1) [yajuvendra@llmhost vllm]$ vllm serve --device cuda ibm-granite/granite-3.0-2b-instruct
[rank0]:[W1127 15:31:43.495876646 ProcessGroupNCCL.cpp:1427] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator()) |
@yajuvendrarawat how did you install vllm? it seems you don't have the kernels compiled for the cuda architecture. |
My environment is RHEL 9 on a Jetson Orin Nano. I created a conda environment the way it's explained in the vllm documentation and then ran the commands below.
$ pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124 |
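Since the Orin is compute capability 8.7 (the sm_87 arch mentioned a couple of comments below), pinning the arch list before the editable build may help make sure the kernels get compiled for it; a sketch only, assuming the build-from-source steps above:
$ export TORCH_CUDA_ARCH_LIST="8.7"
$ pip install -r requirements-build.txt
$ pip install -vvv -e . --no-build-isolation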
I don't have a jetson machine to try it, you can contact the author @conroy-cheers of #9735 . |
Hello @youkaichao, the issue was that my pytorch did not support the sm_87 arch, which I have now fixed, but now I am facing another issue on my Jetson Orin Nano. The error file is attached. Can you please help me see why I am getting this error?
br.. |
@yajuvendrarawat It means something is wrong with this line: Line 310 in 17138af
You can try to debug here. |
I followed the instructions at https://docs.vllm.ai/en/latest/getting_started/installation.html
Here are the details inside the docker instance:
root@f8c2e06fbf8b:/mnt/vllm# pip install -e .
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Obtaining file:///mnt/vllm
Installing build dependencies ... done
Checking if build backend supports build_editable ... done
Getting requirements to build editable ... error
error: subprocess-exited-with-error
× Getting requirements to build editable did not run successfully.
│ exit code: 1
╰─> [22 lines of output]
/tmp/pip-build-env-4xoxai9j/overlay/local/lib/python3.10/dist-packages/torch/nn/modules/transformer.py:20: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:84.)
device: torch.device = torch.device(torch._C._get_default_device()), # torch.device('cpu'),
No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
:142: UserWarning: Unsupported CUDA/ROCM architectures ({'6.1', '7.2', '8.7', '5.2', '6.0'}) are excluded from the TORCH_CUDA_ARCH_LIST env variable (5.2 6.0 6.1 7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX). Supported CUDA/ROCM architectures are: {'7.5', '8.0', '9.0', '7.0', '8.6+PTX', '9.0+PTX', '8.6', '8.0+PTX', '8.9+PTX', '8.9', '7.0+PTX', '7.5+PTX'}.
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 353, in
main()
File "/usr/local/lib/python3.10/dist-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 335, in main
json_out['return_val'] = hook(**hook_input['kwargs'])
File "/usr/local/lib/python3.10/dist-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 132, in get_requires_for_build_editable
return hook(config_settings)
File "/tmp/pip-build-env-4xoxai9j/overlay/local/lib/python3.10/dist-packages/setuptools/build_meta.py", line 441, in get_requires_for_build_editable
return self.get_requires_for_build_wheel(config_settings)
File "/tmp/pip-build-env-4xoxai9j/overlay/local/lib/python3.10/dist-packages/setuptools/build_meta.py", line 325, in get_requires_for_build_wheel
return self._get_build_requires(config_settings, requirements=['wheel'])
File "/tmp/pip-build-env-4xoxai9j/overlay/local/lib/python3.10/dist-packages/setuptools/build_meta.py", line 295, in _get_build_requires
self.run_setup()
File "/tmp/pip-build-env-4xoxai9j/overlay/local/lib/python3.10/dist-packages/setuptools/build_meta.py", line 311, in run_setup
exec(code, locals())
File "", line 297, in
File "", line 267, in get_vllm_version
NameError: name 'nvcc_cuda_version' is not defined. Did you mean: 'cuda_version'?
[end of output]
note: This error originates from a subprocess, and is likely not a problem with pip.
error: subprocess-exited-with-error
× Getting requirements to build editable did not run successfully.
│ exit code: 1
╰─> See above for output.
note: This error originates from a subprocess, and is likely not a problem with pip.
[notice] A new release of pip is available: 23.2.1 -> 23.3.1
[notice] To update, run: python -m pip install --upgrade pip