[dev-fp16] Only convert T3 to fp16, leave s3gen/ve in fp32
All checks were successful
Build ROCm Image / build (push) Successful in 2m39s
All checks were successful
Build ROCm Image / build (push) Successful in 2m39s
s3gen.speaker_encoder (CAMPPlus xvector) hardcodes float32 inputs in its inference() method, causing dtype mismatch when weights are fp16. T3 (the autoregressive GPT-2-medium LLM) has no such constraint and is the token-generation bottleneck that benefits most from fp16. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
15
engine.py
15
engine.py
@@ -51,17 +51,16 @@ def load_model() -> bool:
|
|||||||
|
|
||||||
_sample_rate = 24000
|
_sample_rate = 24000
|
||||||
|
|
||||||
# Convert weights to fp16. Done once at load time so the warmup
|
# Convert T3 (the autoregressive LLM) to fp16 for faster token generation.
|
||||||
# covers the right dtypes and there's no per-call casting overhead.
|
# s3gen and ve are left in fp32 — s3gen.speaker_encoder (CAMPPlus xvector)
|
||||||
|
# hardcodes float32 inputs in its inference() method and errors on fp16 weights.
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
try:
|
try:
|
||||||
for attr in ("t3", "s3gen", "ve"):
|
if hasattr(chatterbox_model, "t3"):
|
||||||
m = getattr(chatterbox_model, attr, None)
|
chatterbox_model.t3.half()
|
||||||
if m is not None:
|
logger.info("T3 converted to fp16")
|
||||||
m.half()
|
|
||||||
logger.info("Model converted to fp16")
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("fp16 conversion failed, running in fp32", exc_info=True)
|
logger.warning("T3 fp16 conversion failed, running in fp32", exc_info=True)
|
||||||
|
|
||||||
logger.info("Model loaded successfully")
|
logger.info("Model loaded successfully")
|
||||||
return True
|
return True
|
||||||
|
|||||||
Reference in New Issue
Block a user