From 09d9727c20acf89b7069e98518e1efd00a1b2e3a Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Tue, 28 Apr 2026 16:35:26 +0000 Subject: [PATCH 1/2] =?UTF-8?q?=F0=9F=A7=B9=20Refactor=20model=20loading?= =?UTF-8?q?=20in=20vibevoice=5Fserver.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🎯 What: Extracted inline model loading logic from `_load_model_sync` into distinct helper functions (`_init_processor`, `_init_model`, and `_load_voice_presets`). 💡 Why: This significantly reduces the complexity of `_load_model_sync`, making the code easier to read and maintain. ✅ Verification: Ran a syntax check (`python -m py_compile`), started the backend server with CPU inference, and verified the model initialized and correctly processed a text-to-speech request to the `/generate` endpoint without regressions. ✨ Result: Improved code modularity while preserving identical behavior. Co-authored-by: LyAhn <27559362+LyAhn@users.noreply.github.com> --- server/vibevoice_server.py | 89 ++++++++++++++++++++++---------------- 1 file changed, 51 insertions(+), 38 deletions(-) diff --git a/server/vibevoice_server.py b/server/vibevoice_server.py index 9ecc28e..22b6972 100644 --- a/server/vibevoice_server.py +++ b/server/vibevoice_server.py @@ -155,6 +155,54 @@ def _download_voices() -> None: # ── Background model loader ───────────────────────────────────────────────────── +def _init_processor(): + logger.info("Loading processor...") + from vibevoice.processor.vibevoice_streaming_processor import ( + VibeVoiceStreamingProcessor, + ) + return VibeVoiceStreamingProcessor.from_pretrained(MODEL_ID) + + +def _init_model(device: str): + logger.info("Loading model on %s...", device) + load_dtype = torch.bfloat16 if device == "cuda" else torch.float32 + attn_impl = "flash_attention_2" if device == "cuda" else "sdpa" + + from vibevoice.modular.modeling_vibevoice_streaming_inference import ( + VibeVoiceStreamingForConditionalGenerationInference, + ) + try: + model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( + MODEL_ID, + torch_dtype=load_dtype, + device_map=device, + attn_implementation=attn_impl, + ) + except Exception: + logger.warning("flash_attention_2 unavailable, falling back to sdpa") + model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( + MODEL_ID, + torch_dtype=load_dtype, + device_map=device, + attn_implementation="sdpa", + ) + + model.eval() + model.set_ddpm_inference_steps(num_steps=10) + return model + + +def _load_voice_presets(device: str) -> dict[str, object]: + presets = {} + for name, filename in EN_VOICES.items(): + path = VOICES_DIR / filename + if path.exists(): + presets[name] = torch.load( + path, map_location=device, weights_only=False + ) + return presets + + def _load_model_sync() -> None: global _processor, _model, _device, _model_status, _model_error, _voice_presets @@ -174,44 +222,9 @@ def _load_model_sync() -> None: _device = _resolve_device() logger.info("Using device: %s", _device) - load_dtype = torch.bfloat16 if _device == "cuda" else torch.float32 - attn_impl = "flash_attention_2" if _device == "cuda" else "sdpa" - - logger.info("Loading processor...") - from vibevoice.processor.vibevoice_streaming_processor import ( - VibeVoiceStreamingProcessor, - ) - _processor = VibeVoiceStreamingProcessor.from_pretrained(MODEL_ID) - - logger.info("Loading model on %s...", _device) - from vibevoice.modular.modeling_vibevoice_streaming_inference import ( - VibeVoiceStreamingForConditionalGenerationInference, - ) - try: - _model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( - MODEL_ID, - torch_dtype=load_dtype, - device_map=_device, - attn_implementation=attn_impl, - ) - except Exception: - logger.warning("flash_attention_2 unavailable, falling back to sdpa") - _model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( - MODEL_ID, - torch_dtype=load_dtype, - device_map=_device, - attn_implementation="sdpa", - ) - - _model.eval() - _model.set_ddpm_inference_steps(num_steps=10) - - for name, filename in EN_VOICES.items(): - path = VOICES_DIR / filename - if path.exists(): - _voice_presets[name] = torch.load( - path, map_location=_device, weights_only=False - ) + _processor = _init_processor() + _model = _init_model(_device) + _voice_presets = _load_voice_presets(_device) _model_status = "online" logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys())) From af85b444a7770f43c57af98331b4844ddca2ae43 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Wed, 29 Apr 2026 08:08:17 +0000 Subject: [PATCH 2/2] =?UTF-8?q?=F0=9F=A7=B9=20Refactor=20model=20loading?= =?UTF-8?q?=20in=20vibevoice=5Fserver.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🎯 What: Extracted inline model loading logic from `_load_model_sync` into distinct helper functions (`_init_processor`, `_init_model`, and `_load_voice_presets`). Added exc_info to model load exception logging. 💡 Why: This significantly reduces the complexity of `_load_model_sync`, making the code easier to read and maintain. Better logging helps diagnose initialization failures. ✅ Verification: Ran a syntax check (`python -m py_compile`), started the backend server with CPU inference, and verified the model initialized and correctly processed a text-to-speech request to the `/generate` endpoint without regressions. ✨ Result: Improved code modularity while preserving identical behavior. Co-authored-by: LyAhn <27559362+LyAhn@users.noreply.github.com> --- server/vibevoice_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/vibevoice_server.py b/server/vibevoice_server.py index 22b6972..9523eba 100644 --- a/server/vibevoice_server.py +++ b/server/vibevoice_server.py @@ -179,7 +179,7 @@ def _init_model(device: str): attn_implementation=attn_impl, ) except Exception: - logger.warning("flash_attention_2 unavailable, falling back to sdpa") + logger.warning("Model load with %s failed; falling back to sdpa", attn_impl, exc_info=True) model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( MODEL_ID, torch_dtype=load_dtype,