Switch to ROCm 6.1 + torch 2.5.1 to fix MIOpen workspace=0 slowness
Some checks failed
Build ROCm Image / build (push) Failing after 11s

The ROCm 7.2 + PyTorch 2.11.0 combination has a bug where PyTorch passes
workspace=0 to MIOpen convolutions, forcing a fallback to the slow
GemmFwdRest solver. This caused s3gen.inference to take 15-22s instead of
<5s, making synthesis 3-4x slower than real-time audio playback.

ROCm 6.1 allocates the workspace correctly, so MIOpen picks fast GEMM
solvers without needing torch.compile workarounds.

Changes:
- Base image: rocm/dev-ubuntu-22.04:7.2 → 6.1
- torch 2.11.0 → 2.5.1, torchvision pinned to 0.20.1 (rocm6.1 wheel index)
- Add pytorch_triton_rocm==3.1.0
- transformers 5.2.0 → 4.46.3, safetensors 0.5.3 → 0.4.0
- s3tokenizer unpinned → 0.3.0
- resemble-perth==1.0.1 directly (v1.0.1 is pip-installable; drop stub)
- Drop Dockerfile perth_stub steps
- Drop torch.compile and timing patches from engine.py (not needed)
- Drop multi-pass warmup from main.py (torch JIT warmup not needed)
- Drop ROCm 7.2-specific env vars from docker-compose.yml

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-05 17:27:21 -04:00
parent 23a0b914fa
commit 8de67c8bd9
6 changed files with 18 additions and 100 deletions

View File

@@ -1,4 +1,4 @@
FROM rocm/dev-ubuntu-22.04:7.2 FROM rocm/dev-ubuntu-22.04:6.1
ENV DEBIAN_FRONTEND=noninteractive \ ENV DEBIAN_FRONTEND=noninteractive \
PYTHONDONTWRITEBYTECODE=1 \ PYTHONDONTWRITEBYTECODE=1 \
@@ -29,12 +29,6 @@ RUN pip3 install -r requirements-rocm.txt
# Step 3: Install chatterbox with --no-deps so pip cannot replace ROCm torch. # Step 3: Install chatterbox with --no-deps so pip cannot replace ROCm torch.
RUN pip3 install --no-deps chatterbox-tts RUN pip3 install --no-deps chatterbox-tts
# Stub out resemble-perth (audio watermarking, unnecessary for self-hosted use).
# Placed after pip layers so changes to the stub don't bust the cache above.
COPY perth_stub.py .
RUN python3 -c "import site; print(site.getsitepackages()[0])" | \
xargs -I{} cp /app/perth_stub.py {}/perth.py
# Application source # Application source
COPY engine.py config.py wyoming_handler.py wyoming_voices.py main.py ./ COPY engine.py config.py wyoming_handler.py wyoming_voices.py main.py ./

View File

@@ -28,18 +28,10 @@ services:
- hf_cache:/app/hf_cache - hf_cache:/app/hf_cache
environment: environment:
- HF_HUB_ENABLE_HF_TRANSFER=1 - HF_HUB_ENABLE_HF_TRANSFER=1
# Required for RX 6700 XT (gfx1031) - not natively supported in ROCm 7.2. # Required for RX 6700 XT (gfx1031) - not natively supported in ROCm.
- HSA_OVERRIDE_GFX_VERSION=10.3.0 - HSA_OVERRIDE_GFX_VERSION=10.3.0
# Disable MIOpen's SQLite cache — avoids crashes writing benchmark results. # Disable MIOpen's SQLite cache — avoids crashes writing benchmark results.
# PyTorch's in-memory benchmark cache still applies within a container run.
- MIOPEN_DISABLE_CACHE=1 - MIOPEN_DISABLE_CACHE=1
# Disable MLIR-based ImplicitGEMM solvers. These compile MLIR kernels on the
# fly and hit 'too many open files' during the exhaustive benchmark search.
- MIOPEN_DEBUG_CONV_IMPLICIT_GEMM=0
# Suppress MIOpen workspace=0 fallback warnings (errors still shown).
# Levels: 0=quiet 1=fatal 2=error 3=warning(default) 4=info 5=debug
- MIOPEN_LOG_LEVEL=2
# - HF_TOKEN=your_token_here
volumes: volumes:
hf_cache: hf_cache:

View File

@@ -1,7 +1,5 @@
import logging import logging
import time
import torch import torch
import torch._dynamo
from typing import Optional, Tuple from typing import Optional, Tuple
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -46,23 +44,6 @@ def load_model() -> bool:
_is_turbo = False _is_turbo = False
_sample_rate = 24000 _sample_rate = 24000
if torch.cuda.is_available():
# torch.compile replaces MIOpen's convolution path with Triton-generated
# kernels, bypassing the workspace=0 fallback entirely. We compile s3gen
# (HiFiGAN vocoder + flow matching) and the voice encoder (ve) since both
# are convolutional and hit the workspace=0 bug.
# suppress_errors=True falls back to eager for any op compile can't handle.
torch._dynamo.config.suppress_errors = True
for attr, label in [("s3gen", "s3gen"), ("ve", "ve")]:
try:
obj = getattr(chatterbox_model, attr)
setattr(chatterbox_model, attr, torch.compile(obj, dynamic=True))
logger.info(f"{label} compiled with torch.compile")
except Exception:
logger.warning(f"torch.compile unavailable for {label}, running in eager mode", exc_info=True)
_patch_timing(chatterbox_model)
logger.info("Model loaded successfully") logger.info("Model loaded successfully")
return True return True
except Exception: except Exception:
@@ -70,36 +51,6 @@ def load_model() -> bool:
return False return False
def _patch_timing(model) -> None:
"""Wrap key sub-model forward() calls with timing logs."""
def _wrap(obj, method_name, label):
original = getattr(obj, method_name)
def timed(*args, **kwargs):
t0 = time.monotonic()
result = original(*args, **kwargs)
if torch.cuda.is_available():
torch.cuda.synchronize()
logger.info(f"[timing] {label}: {time.monotonic() - t0:.3f}s")
return result
setattr(obj, method_name, timed)
try:
# S3 tokenizer — processes reference audio through a conformer
_wrap(model.s3tokenizer, "forward", "s3tokenizer (ref audio encoding)")
except AttributeError:
pass
try:
# Speaker/voice encoder — xvector embedding from reference audio
_wrap(model.ve, "forward", "ve (speaker embedding)")
except AttributeError:
pass
try:
# S3Gen decode: flow matching (token -> mel) + HiFiGAN (mel -> wav)
_wrap(model.s3gen, "inference", "s3gen.inference (flow+vocoder)")
except AttributeError:
pass
def get_sample_rate() -> int: def get_sample_rate() -> int:
return _sample_rate return _sample_rate

31
main.py
View File

@@ -18,33 +18,16 @@ logging.basicConfig(
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_WARMUP_TEXTS = [
# Short: covers brief HA notifications (lights on/off, etc.)
"Okay.",
# Medium: covers typical HA announcements
"The front door is open. Please close it.",
# Long: covers longer TTS requests and pre-compiles dynamic shape graph
(
"This is a warmup synthesis to pre-compile neural network kernels "
"for longer text lengths used in Home Assistant announcements and notifications."
),
]
def _warmup(voices: dict) -> None: def _warmup(voices: dict) -> None:
"""Run one synthesis to populate MIOpen's in-memory kernel cache."""
from wyoming_voices import resolve_voice from wyoming_voices import resolve_voice
audio_prompt = resolve_voice(None, voices) if voices else None audio_prompt = resolve_voice(None, voices) if voices else None
logger.info( logger.info("Running warmup synthesis...")
f"Running {len(_WARMUP_TEXTS)}-pass warmup to pre-compile torch kernels "
"for short, medium, and long text lengths..."
)
for i, text in enumerate(_WARMUP_TEXTS, 1):
try: try:
engine.synthesize(text=text, audio_prompt_path=audio_prompt) engine.synthesize(text="Warmup.", audio_prompt_path=audio_prompt)
logger.info(f"Warmup pass {i}/{len(_WARMUP_TEXTS)} complete")
except Exception:
logger.warning(f"Warmup pass {i} failed (non-fatal)", exc_info=True)
logger.info("Warmup complete") logger.info("Warmup complete")
except Exception:
logger.warning("Warmup synthesis failed (non-fatal)", exc_info=True)
async def main() -> None: async def main() -> None:
@@ -58,10 +41,6 @@ async def main() -> None:
voices = load_voices() voices = load_voices()
wyoming_info = create_wyoming_info(engine.get_sample_rate(), voices) wyoming_info = create_wyoming_info(engine.get_sample_rate(), voices)
# Run a warmup synthesis before accepting connections so MIOpen benchmarks
# and caches the best convolution algorithms for all layer shapes. Without
# this, the first real HA request triggers benchmarking (hundreds of runs)
# and times out before any audio is returned.
_warmup(voices) _warmup(voices)
host = get_wyoming_host() host = get_wyoming_host()

View File

@@ -1,3 +1,5 @@
--index-url https://download.pytorch.org/whl/rocm7.2 --index-url https://download.pytorch.org/whl/rocm6.1
torch==2.11.0 torch==2.5.1
torchaudio==2.11.0 torchaudio==2.5.1
torchvision==0.20.1
pytorch_triton_rocm==3.1.0

View File

@@ -5,18 +5,18 @@ librosa==0.11.0
pyloudnorm pyloudnorm
# ML dependencies (pinned to match chatterbox without overwriting ROCm torch) # ML dependencies (pinned to match chatterbox without overwriting ROCm torch)
transformers==5.2.0 transformers==4.46.3
diffusers==0.29.0 diffusers==0.29.0
safetensors==0.5.3 safetensors==0.4.0
huggingface-hub huggingface-hub
omegaconf omegaconf
# Chatterbox dependencies (installed separately since chatterbox uses --no-deps) # Chatterbox dependencies (installed separately since chatterbox uses --no-deps)
# Note: resemble-perth is stubbed out in perth_stub.py (watermarking unneeded for self-hosted use)
conformer==0.3.2 conformer==0.3.2
s3tokenizer s3tokenizer==0.3.0
spacy-pkuseg spacy-pkuseg
pykakasi==2.3.0 pykakasi==2.3.0
resemble-perth==1.0.1
# Wyoming protocol # Wyoming protocol
wyoming>=1.5.4 wyoming>=1.5.4