diff --git a/engine.py b/engine.py index 1cafcb2..b1f6cbb 100644 --- a/engine.py +++ b/engine.py @@ -122,7 +122,7 @@ def synthesize( kwargs["exaggeration"] = exaggeration kwargs["cfg_weight"] = cfg_weight - with torch.inference_mode(): + with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): wav = chatterbox_model.generate(text=text, **kwargs) if torch.cuda.is_available():