Skip to content

Commit

Permalink
Merge branch 'SWivid:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
lpscr authored Oct 20, 2024
2 parents 33679b9 + b4f8142 commit 6727245
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion model/cfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def sample(
):
self.eval()

cond = cond.half()
if cond.device != torch.device('cpu'):
cond = cond.half()

# raw wave

Expand Down
3 changes: 2 additions & 1 deletion model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,8 @@ def repetition_found(text, length = 2, tolerance = 10):
# load model checkpoint for inference

def load_checkpoint(model, ckpt_path, device, use_ema = True):
model = model.half()
if device != "cpu":
model = model.half()

ckpt_type = ckpt_path.split(".")[-1]
if ckpt_type == "safetensors":
Expand Down

0 comments on commit 6727245

Please sign in to comment.