Skip to content

Commit

Permalink
fix: use torch to detect GPU number instead of pynvml
Browse files Browse the repository at this point in the history
  • Loading branch information
ignorejjj committed Dec 10, 2024
1 parent 3e8c0ad commit 6a8d968
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions flashrag/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,11 @@ def _init_device(self):
if gpu_id is not None:
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
try:
import pynvml
pynvml.nvmlInit()
gpu_num = pynvml.nvmlDeviceGetCount()
# import pynvml
# pynvml.nvmlInit()
# gpu_num = pynvml.nvmlDeviceGetCount()
import torch
gpu_num = torch.cuda.device_count()
except:
gpu_num = 0
self.final_config['gpu_num'] = gpu_num
Expand Down

0 comments on commit 6a8d968

Please sign in to comment.