From 51188ca973f9523327e7453ae777d8e61e10ca1e Mon Sep 17 00:00:00 2001
From: scott
Date: Sun, 5 Apr 2026 20:41:24 -0400
Subject: [PATCH] [dev-fp16] Only convert T3 to fp16, leave s3gen/ve in fp32

s3gen.speaker_encoder (CAMPPlus xvector) hardcodes float32 inputs in its
inference() method, causing a dtype mismatch when weights are fp16. T3 (the
autoregressive GPT-2-medium LLM) has no such constraint and is the
token-generation bottleneck that benefits most from fp16.

Co-Authored-By: Claude Sonnet 4.6
---
 engine.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/engine.py b/engine.py
index a66d95f..ee2a572 100644
--- a/engine.py
+++ b/engine.py
@@ -51,17 +51,16 @@ def load_model() -> bool:
 
     _sample_rate = 24000
 
-    # Convert weights to fp16. Done once at load time so the warmup
-    # covers the right dtypes and there's no per-call casting overhead.
+    # Convert T3 (the autoregressive LLM) to fp16 for faster token generation.
+    # s3gen and ve are left in fp32 — s3gen.speaker_encoder (CAMPPlus xvector)
+    # hardcodes float32 inputs in its inference() method and errors on fp16 weights.
     if torch.cuda.is_available():
         try:
-            for attr in ("t3", "s3gen", "ve"):
-                m = getattr(chatterbox_model, attr, None)
-                if m is not None:
-                    m.half()
-            logger.info("Model converted to fp16")
+            if hasattr(chatterbox_model, "t3"):
+                chatterbox_model.t3.half()
+                logger.info("T3 converted to fp16")
         except Exception:
-            logger.warning("fp16 conversion failed, running in fp32", exc_info=True)
+            logger.warning("T3 fp16 conversion failed, running in fp32", exc_info=True)
 
     logger.info("Model loaded successfully")
     return True
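
Reviewer note: a minimal sketch of the failure mode the commit message
describes, with stand-in modules (SpeakerEncoderLike / T3Like and their layer
sizes are hypothetical; only the hardcoded float32 cast inside inference() is
taken from the message):

import torch
import torch.nn as nn

class SpeakerEncoderLike(nn.Module):
    # Stands in for s3gen.speaker_encoder: inference() forces fp32 inputs.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(80, 192)

    def inference(self, feats):
        feats = feats.float()    # hardcoded fp32 cast, per the commit message
        return self.proj(feats)  # fp32 input x fp16 weight -> RuntimeError

class T3Like(nn.Module):
    # Stands in for T3: makes no dtype assumptions about its inputs.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(80, 192)

    def forward(self, x):
        return self.proj(x)      # runs in whatever dtype the weights are

# The patch only converts on CUDA; fp16 matmul may also be unsupported on CPU
# for older PyTorch builds.
device = "cuda" if torch.cuda.is_available() else "cpu"
enc = SpeakerEncoderLike().to(device).half()  # what the old loop did to every submodule
t3 = T3Like().to(device).half()               # what the patch keeps doing

x = torch.randn(1, 80, device=device)
try:
    enc.inference(x)             # mat1/mat2 dtype mismatch
except RuntimeError as e:
    print("speaker-encoder-style module fails in fp16:", e)

print(t3(x.half()).dtype)        # torch.float16: the T3-style module is fine

Running the sketch shows why t3.half() is safe while the old blanket .half()
over s3gen and ve was not: the failure comes from the submodule's own input
cast, not from fp16 itself.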