diff --git a/engine.py b/engine.py index 557b2c1..e1691f7 100644 --- a/engine.py +++ b/engine.py @@ -1,6 +1,7 @@ import logging import time import torch +import torch._dynamo from typing import Optional, Tuple logger = logging.getLogger(__name__) @@ -52,7 +53,6 @@ def load_model() -> bool: # s3gen (HiFiGAN vocoder + flow matching) since that's the bottleneck. # suppress_errors=True falls back to eager for any op compile can't handle. try: - import torch._dynamo torch._dynamo.config.suppress_errors = True chatterbox_model.s3gen = torch.compile(chatterbox_model.s3gen, dynamic=True) logger.info("s3gen compiled with torch.compile")