[dev-fp16] Convert model weights to fp16 at load time

Converting t3/s3gen/ve to fp16 once at load time means:
- Warmup runs in fp16, covering the right dtypes for all real requests
- No per-call autocast casting overhead (see the sketch after this list)
- ~2x faster matrix ops and convolutions on RDNA 2 hardware
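
For contrast, a minimal sketch of the per-call autocast pattern this change avoids versus running on weights already converted at load time (synthesize_* and model are illustrative names, not from this repo):

    import torch

    def synthesize_autocast(model, tokens):
        # Avoided pattern: autocast re-selects dtypes and inserts casts
        # on every request, adding per-call overhead.
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            return model(tokens)

    def synthesize_native_fp16(model, tokens):
        # This commit's pattern: weights were converted once via
        # model.half() at load, so the forward pass runs in fp16 directly.
        return model(tokens)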

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
commit 9b62fce5c5
parent 967ed41239
Date:   2026-04-05 20:34:33 -04:00


@@ -50,6 +50,19 @@ def load_model() -> bool:
 
         _is_turbo = False
         _sample_rate = 24000
+
+        # Convert weights to fp16. Done once at load time so the warmup
+        # covers the right dtypes and there's no per-call casting overhead.
+        if torch.cuda.is_available():
+            try:
+                for attr in ("t3", "s3gen", "ve"):
+                    m = getattr(chatterbox_model, attr, None)
+                    if m is not None:
+                        m.half()
+                logger.info("Model converted to fp16")
+            except Exception:
+                logger.warning("fp16 conversion failed, running in fp32", exc_info=True)
+
         logger.info("Model loaded successfully")
         return True
     except Exception:
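
A quick way to verify the conversion after load (a sketch assuming the chatterbox_model global and the submodule names from the diff above):

    import torch

    for attr in ("t3", "s3gen", "ve"):
        m = getattr(chatterbox_model, attr, None)
        if m is not None:
            # Parameters should report torch.float16 on CUDA/ROCm builds,
            # and torch.float32 if the conversion fell back.
            print(attr, next(m.parameters()).dtype)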