From 7babd0584e520f5f1872aeb8d73e9d9a9efe761a Mon Sep 17 00:00:00 2001 From: scott Date: Sun, 5 Apr 2026 14:27:09 -0400 Subject: [PATCH] Replace MIOpen convolution path with torch.compile on s3gen MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The GemmFwdRest workspace=0 issue is in MIOpen itself — PyTorch's ROCm backend does not allocate workspace for convolutions, causing HiFiGAN to use a slow fallback solver regardless of benchmark settings. torch.compile(s3gen, dynamic=True) replaces MIOpen's conv path with Triton-generated kernels, bypassing the issue entirely. dynamic=True handles variable audio lengths without recompiling per request. A warmup request at startup triggers JIT compilation so that the first HA request is fast (the warmup itself is performed outside this patch). Also removes fp16 autocast (Triton handles precision internally) and cudnn.benchmark (no longer needed without MIOpen convs). Co-Authored-By: Claude Sonnet 4.6 --- engine.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/engine.py b/engine.py index 0b951d2..557b2c1 100644 --- a/engine.py +++ b/engine.py @@ -46,13 +46,18 @@ def load_model() -> bool: _sample_rate = 24000 - # Enable MIOpen algorithm benchmarking. Without this, PyTorch picks - # convolution algorithms heuristically and passes ptr=0/size=0 workspace - # to MIOpen, forcing a slow fallback on every conv op. With benchmark=True, - # PyTorch evaluates algorithms with proper workspace on first use and caches - # the best result (persisted via MIOPEN_USER_DB_PATH volume mount). if torch.cuda.is_available(): - torch.backends.cudnn.benchmark = True + # torch.compile replaces MIOpen's convolution path with Triton-generated + # kernels, bypassing the workspace=0 fallback entirely. We compile only + # s3gen (HiFiGAN vocoder + flow matching) since that's the bottleneck. + # suppress_errors=True falls back to eager for any op compile can't handle. 
+ try: + import torch._dynamo + torch._dynamo.config.suppress_errors = True + chatterbox_model.s3gen = torch.compile(chatterbox_model.s3gen, dynamic=True) + logger.info("s3gen compiled with torch.compile") + except Exception: + logger.warning("torch.compile unavailable, running s3gen in eager mode", exc_info=True) _patch_timing(chatterbox_model) logger.info("Model loaded successfully") @@ -122,13 +127,8 @@ def synthesize( kwargs["exaggeration"] = exaggeration kwargs["cfg_weight"] = cfg_weight - try: - with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16, enabled=torch.cuda.is_available()): - wav = chatterbox_model.generate(text=text, **kwargs) - except Exception: - logger.warning("fp16 autocast failed, retrying in fp32", exc_info=True) - with torch.inference_mode(): - wav = chatterbox_model.generate(text=text, **kwargs) + with torch.inference_mode(): + wav = chatterbox_model.generate(text=text, **kwargs) if torch.cuda.is_available(): torch.cuda.synchronize()