diff --git a/shepardtts/utils.py b/shepardtts/utils.py index c4b87d5..b91777a 100644 --- a/shepardtts/utils.py +++ b/shepardtts/utils.py @@ -33,7 +33,7 @@ def load_checkpoint() -> ShepardXtts: import intel_extension_for_pytorch as ipex model.train(mode=False) - model = ipex.optimize(model, weights_prepack=False) + model = ipex.optimize(model, weights_prepack=False, dtype=torch.float16) model = torch.compile(model, backend="ipex") return model