diff --git a/model/cfm.py b/model/cfm.py index fdf42937..f70b097f 100644 --- a/model/cfm.py +++ b/model/cfm.py @@ -96,7 +96,8 @@ def sample( ): self.eval() - cond = cond.half() + if cond.device != torch.device('cpu'): + cond = cond.half() # raw wave diff --git a/model/utils.py b/model/utils.py index e6494a22..ae64b0c6 100644 --- a/model/utils.py +++ b/model/utils.py @@ -555,7 +555,8 @@ def repetition_found(text, length = 2, tolerance = 10): # load model checkpoint for inference def load_checkpoint(model, ckpt_path, device, use_ema = True): - model = model.half() + if device != "cpu": + model = model.half() ckpt_type = ckpt_path.split(".")[-1] if ckpt_type == "safetensors":