From 2930f03d6ce7af0a118c642d75e9c029b61093c2 Mon Sep 17 00:00:00 2001
From: asagi4 <130366179+asagi4@users.noreply.github.com>
Date: Thu, 22 Aug 2024 21:38:02 +0300
Subject: [PATCH] Make sure that LoRAs are loaded to the correct device

---
 prompt_control/utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/prompt_control/utils.py b/prompt_control/utils.py
index b59622a..07f2bd3 100644
--- a/prompt_control/utils.py
+++ b/prompt_control/utils.py
@@ -173,10 +173,10 @@ def _patch_model(model, forget=False, orig=None, offload_to_cpu=False):
     if offload_to_cpu:
         saved_offload = model.offload_device
         model.offload_device = torch.device("cpu")
-    log.info("Patching model, cpu_offload=%s", model.offload_device == torch.device("cpu"))
+    log.info("Patching model, model.load_device=%s model.model.device=%s cpu_offload=%s", model.load_device, model.model.device, model.offload_device == torch.device("cpu"))
     if orig:
         model.backup = orig.backup
-    model.patch_model()
+    model.patch_model(device_to=model.load_device)
     if offload_to_cpu:
         model.offload_device = saved_offload
     if forget:
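
Note: the fix is passing device_to=model.load_device so the patched weights end up on the
load device rather than wherever they happened to live before patching. Below is a minimal
sketch of the failure mode this addresses; MockPatcher is a hypothetical stand-in written
for illustration, not ComfyUI's actual ModelPatcher. Only the patch_model(device_to=...)
keyword is taken from the diff above; everything else is assumed.

    import torch


    class MockPatcher:
        """Illustrative stand-in for a model patcher with separate load/offload devices."""

        def __init__(self, load_device, offload_device):
            self.load_device = load_device
            self.offload_device = offload_device
            # Stand-in for a weight tensor that LoRA patches get applied to;
            # it starts out on the offload device.
            self.weight = torch.zeros(4, device=offload_device)

        def patch_model(self, device_to=None):
            # Apply a fake LoRA delta in place; without device_to, the patched
            # weight stays on whatever device it currently occupies.
            self.weight = self.weight + 1.0
            if device_to is not None:
                self.weight = self.weight.to(device_to)


    if __name__ == "__main__":
        load = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        patcher = MockPatcher(load_device=load, offload_device=torch.device("cpu"))

        patcher.patch_model()  # old call: patched weight remains on the offload device
        print(patcher.weight.device)

        patcher.patch_model(device_to=patcher.load_device)  # call as changed by the patch
        print(patcher.weight.device)  # now matches load_device

With a CUDA load device, the first call leaves the patched weight on cpu while the second
moves it to cuda, which mirrors why LoRAs previously landed on the wrong device when the
offload device differed from the load device.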