diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 9a00a4cc0..e22b2ae80 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -122,7 +122,10 @@ def _load( assert gpt_ckpt_path, 'gpt_ckpt_path should not be None' gpt.load_state_dict(torch.load(gpt_ckpt_path)) if compile and 'cuda' in str(device): - gpt.gpt.forward = torch.compile(gpt.gpt.forward, backend='inductor', dynamic=True) + try: + gpt.gpt.forward = torch.compile(gpt.gpt.forward, backend='inductor', dynamic=True) + except RuntimeError as e: + logging.warning(f'Compile failed,{e}. fallback to normal mode.') self.pretrain_models['gpt'] = gpt spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), 'spk_stat.pt') assert os.path.exists(spk_stat_path), f'Missing spk_stat.pt: {spk_stat_path}'