Replace MIOpen convolution path with torch.compile on s3gen
All checks were successful
Build ROCm Image / build (push) Successful in 2m47s
The GemmFwdRest workspace=0 issue lives in MIOpen itself: PyTorch's ROCm backend does not allocate workspace for convolutions, so HiFiGAN lands on a slow fallback solver regardless of benchmark settings.

torch.compile(s3gen, dynamic=True) replaces MIOpen's conv path with Triton-generated kernels, bypassing the issue entirely. dynamic=True handles variable audio lengths without recompiling per request, and the warmup triggers JIT compilation so the first HA request is fast.

This also removes the fp16 autocast (Triton handles precision internally) and cudnn.benchmark (no longer needed once MIOpen convs are out of the path).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
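The warmup mentioned above is not shown in this diff; a minimal sketch of what it could look like, assuming the module-level chatterbox_model and logger that engine.py already uses (the helper name and sample text are hypothetical):

```python
def _warmup() -> None:
    # Hypothetical helper, not part of this commit: run one throwaway
    # generation at startup so torch.compile JIT-compiles s3gen before
    # the first real request arrives.
    if chatterbox_model is None:
        return
    with torch.inference_mode():
        chatterbox_model.generate(text="warmup")
    logger.info("s3gen warmup complete")
```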
engine.py
@@ -46,13 +46,18 @@ def load_model() -> bool:
     _sample_rate = 24000
 
-    # Enable MIOpen algorithm benchmarking. Without this, PyTorch picks
-    # convolution algorithms heuristically and passes ptr=0/size=0 workspace
-    # to MIOpen, forcing a slow fallback on every conv op. With benchmark=True,
-    # PyTorch evaluates algorithms with proper workspace on first use and caches
-    # the best result (persisted via MIOPEN_USER_DB_PATH volume mount).
     if torch.cuda.is_available():
-        torch.backends.cudnn.benchmark = True
+        # torch.compile replaces MIOpen's convolution path with Triton-generated
+        # kernels, bypassing the workspace=0 fallback entirely. We compile only
+        # s3gen (HiFiGAN vocoder + flow matching) since that's the bottleneck.
+        # suppress_errors=True falls back to eager for any op compile can't handle.
+        try:
+            import torch._dynamo
+            torch._dynamo.config.suppress_errors = True
+            chatterbox_model.s3gen = torch.compile(chatterbox_model.s3gen, dynamic=True)
+            logger.info("s3gen compiled with torch.compile")
+        except Exception:
+            logger.warning("torch.compile unavailable, running s3gen in eager mode", exc_info=True)
 
     _patch_timing(chatterbox_model)
     logger.info("Model loaded successfully")
 
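As a standalone illustration of the dynamic=True behaviour the new comments rely on, a minimal sketch (the Conv1d dimensions are arbitrary stand-ins, not taken from s3gen): Dynamo treats the varying length axis as symbolic, so one compiled graph serves many input lengths.

```python
import torch

# Arbitrary 1-D conv standing in for a vocoder layer; dimensions are made up.
conv = torch.nn.Conv1d(80, 80, kernel_size=7, padding=3)
if torch.cuda.is_available():
    conv = conv.cuda()
compiled = torch.compile(conv, dynamic=True)

with torch.inference_mode():
    for frames in (200, 350, 513):  # stand-ins for variable audio lengths
        x = torch.randn(1, 80, frames, device=next(conv.parameters()).device)
        compiled(x)  # first call compiles; later lengths reuse the same graph
```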
@@ -122,11 +127,6 @@ def synthesize(
     kwargs["exaggeration"] = exaggeration
     kwargs["cfg_weight"] = cfg_weight
 
-    try:
-        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16, enabled=torch.cuda.is_available()):
-            wav = chatterbox_model.generate(text=text, **kwargs)
-    except Exception:
-        logger.warning("fp16 autocast failed, retrying in fp32", exc_info=True)
-        with torch.inference_mode():
-            wav = chatterbox_model.generate(text=text, **kwargs)
+    with torch.inference_mode():
+        wav = chatterbox_model.generate(text=text, **kwargs)
 
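If variable-length requests ever did trigger recompiles, that would be visible through Dynamo's logging; a small diagnostic sketch (torch._logging.set_logs is available in recent PyTorch releases and is an assumption here, not part of this commit):

```python
import torch

# Assumption: recent PyTorch. Emits a log line whenever Dynamo recompiles a
# guarded graph, e.g. because an input shape failed an existing guard.
torch._logging.set_logs(recompiles=True)
```

With dynamic=True in place, varying the request length should not produce such recompile log lines.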