diff --git a/engine.py b/engine.py
index 00d4a05..dd42a88 100644
--- a/engine.py
+++ b/engine.py
@@ -1,4 +1,5 @@
 import logging
+import time
 import torch
 from typing import Optional, Tuple
 
@@ -44,6 +45,7 @@ def load_model() -> bool:
         _is_turbo = False
         _sample_rate = 24000
 
+        _patch_timing(chatterbox_model)
         logger.info("Model loaded successfully")
         return True
     except Exception:
@@ -51,6 +53,40 @@
         return False
 
 
+def _patch_timing(model) -> None:
+    """Wrap key sub-model forward() calls with timing logs."""
+    def _wrap(obj, method_name, label):
+        original = getattr(obj, method_name)
+        def timed(*args, **kwargs):
+            # CUDA launches are asynchronous: drain any previously queued GPU
+            # work *before* taking t0, so the interval measures only this call.
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+            t0 = time.monotonic()
+            result = original(*args, **kwargs)
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+            logger.info(f"[timing] {label}: {time.monotonic() - t0:.3f}s")
+            return result
+        setattr(obj, method_name, timed)
+
+    try:
+        # S3 tokenizer — processes reference audio through a conformer
+        _wrap(model.s3tokenizer, "forward", "s3tokenizer (ref audio encoding)")
+    except AttributeError:
+        pass
+    try:
+        # Speaker/voice encoder — xvector embedding from reference audio
+        _wrap(model.voice_encoder, "forward", "voice_encoder (speaker embedding)")
+    except AttributeError:
+        pass
+    try:
+        # S3Gen decode: flow matching (token -> mel) + HiFiGAN (mel -> wav)
+        _wrap(model.s3gen, "inference", "s3gen.inference (flow+vocoder)")
+    except AttributeError:
+        pass
+
+
 def get_sample_rate() -> int:
     return _sample_rate
 