Replace MIOpen convolution path with torch.compile on s3gen

The GemmFwdRest workspace=0 issue is in MIOpen itself: PyTorch's ROCm
backend does not allocate workspace for convolutions, so HiFiGAN falls
back to a slow solver regardless of benchmark settings.
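
As a quick way to see the fallback cost (a hypothetical repro, not code
from this repo; shapes are illustrative guesses at a HiFiGAN layer), timing
an eager Conv1d against its compiled counterpart makes the slow solver visible:

import time
import torch

def time_conv(module, x, iters=20):
    module(x)                      # one warmup call so lazy init isn't timed
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        module(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

conv = torch.nn.Conv1d(512, 512, kernel_size=7, padding=3).cuda()
x = torch.randn(1, 512, 8192, device="cuda")
print(f"eager:    {time_conv(conv, x) * 1e3:.2f} ms/iter")

compiled = torch.compile(conv, dynamic=True)
compiled(x)                        # trigger Triton compilation outside the timing
print(f"compiled: {time_conv(compiled, x) * 1e3:.2f} ms/iter")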

torch.compile(s3gen, dynamic=True) replaces MIOpen's conv path with
Triton-generated kernels, bypassing the issue entirely. dynamic=True
handles variable audio lengths without recompiling per request, and the
warmup triggers JIT compilation so the first HA request is fast.
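
A minimal sketch of what this looks like at load time, using the module and
logger names from the diff below (the warmup text itself is a made-up
placeholder):

# Compile only the s3gen submodule (HiFiGAN vocoder + flow matching);
# dynamic=True traces with symbolic sequence lengths, so variable-length
# audio does not recompile per request.
chatterbox_model.s3gen = torch.compile(chatterbox_model.s3gen, dynamic=True)

# One throwaway generation forces Dynamo tracing and Triton codegen now,
# so the first real HA request doesn't pay the JIT cost.
with torch.inference_mode():
    chatterbox_model.generate(text="warmup")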

Also removes fp16 autocast (Triton handles precision internally) and
cudnn.benchmark (no longer needed without MIOpen convs).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

@@ -46,13 +46,18 @@ def load_model() -> bool:
     _sample_rate = 24000
-    # Enable MIOpen algorithm benchmarking. Without this, PyTorch picks
-    # convolution algorithms heuristically and passes ptr=0/size=0 workspace
-    # to MIOpen, forcing a slow fallback on every conv op. With benchmark=True,
-    # PyTorch evaluates algorithms with proper workspace on first use and caches
-    # the best result (persisted via MIOPEN_USER_DB_PATH volume mount).
     if torch.cuda.is_available():
-        torch.backends.cudnn.benchmark = True
+        # torch.compile replaces MIOpen's convolution path with Triton-generated
+        # kernels, bypassing the workspace=0 fallback entirely. We compile only
+        # s3gen (HiFiGAN vocoder + flow matching) since that's the bottleneck.
+        # suppress_errors=True falls back to eager for any op compile can't handle.
+        try:
+            import torch._dynamo
+            torch._dynamo.config.suppress_errors = True
+            chatterbox_model.s3gen = torch.compile(chatterbox_model.s3gen, dynamic=True)
+            logger.info("s3gen compiled with torch.compile")
+        except Exception:
+            logger.warning("torch.compile unavailable, running s3gen in eager mode", exc_info=True)
     _patch_timing(chatterbox_model)
     logger.info("Model loaded successfully")
@@ -122,13 +127,8 @@ def synthesize(
     kwargs["exaggeration"] = exaggeration
     kwargs["cfg_weight"] = cfg_weight
-    try:
-        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16, enabled=torch.cuda.is_available()):
-            wav = chatterbox_model.generate(text=text, **kwargs)
-    except Exception:
-        logger.warning("fp16 autocast failed, retrying in fp32", exc_info=True)
-        with torch.inference_mode():
-            wav = chatterbox_model.generate(text=text, **kwargs)
+    with torch.inference_mode():
+        wav = chatterbox_model.generate(text=text, **kwargs)
     if torch.cuda.is_available():
         torch.cuda.synchronize()
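
As a standalone sanity check of the dynamic=True behavior (not part of this
change), Dynamo's counters can confirm that a single symbolic-shape graph
serves several input lengths:

import torch
import torch._dynamo
from torch._dynamo.utils import counters

torch._dynamo.reset()
net = torch.nn.Conv1d(8, 8, kernel_size=3, padding=1)
compiled = torch.compile(net, dynamic=True)
for length in (1000, 2000, 3333):
    compiled(torch.randn(1, 8, length))
# With dynamic=True a single symbolic-shape graph covers every length here.
print(counters["stats"]["unique_graphs"])  # expect 1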