🧹 Refactor model loading in vibevoice_server.py

🎯 What: Extracted inline model loading logic from `_load_model_sync` into distinct helper functions (`_init_processor`, `_init_model`, and `_load_voice_presets`).
💡 Why: This significantly reduces the complexity of `_load_model_sync`, making the code easier to read and maintain.
 Verification: Ran a syntax check (`python -m py_compile`), started the backend server with CPU inference, and verified the model initialized and correctly processed a text-to-speech request to the `/generate` endpoint without regressions.
 Result: Improved code modularity while preserving identical behavior.

Co-authored-by: LyAhn <27559362+LyAhn@users.noreply.github.com>
This commit is contained in:
google-labs-jules[bot]
2026-04-28 16:35:26 +00:00
parent 59d3280cb5
commit 09d9727c20
+51 -38
View File
@@ -155,6 +155,54 @@ def _download_voices() -> None:
# ── Background model loader ───────────────────────────────────────────────────── # ── Background model loader ─────────────────────────────────────────────────────
def _init_processor():
logger.info("Loading processor...")
from vibevoice.processor.vibevoice_streaming_processor import (
VibeVoiceStreamingProcessor,
)
return VibeVoiceStreamingProcessor.from_pretrained(MODEL_ID)
def _init_model(device: str):
logger.info("Loading model on %s...", device)
load_dtype = torch.bfloat16 if device == "cuda" else torch.float32
attn_impl = "flash_attention_2" if device == "cuda" else "sdpa"
from vibevoice.modular.modeling_vibevoice_streaming_inference import (
VibeVoiceStreamingForConditionalGenerationInference,
)
try:
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
MODEL_ID,
torch_dtype=load_dtype,
device_map=device,
attn_implementation=attn_impl,
)
except Exception:
logger.warning("flash_attention_2 unavailable, falling back to sdpa")
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
MODEL_ID,
torch_dtype=load_dtype,
device_map=device,
attn_implementation="sdpa",
)
model.eval()
model.set_ddpm_inference_steps(num_steps=10)
return model
def _load_voice_presets(device: str) -> dict[str, object]:
presets = {}
for name, filename in EN_VOICES.items():
path = VOICES_DIR / filename
if path.exists():
presets[name] = torch.load(
path, map_location=device, weights_only=False
)
return presets
def _load_model_sync() -> None: def _load_model_sync() -> None:
global _processor, _model, _device, _model_status, _model_error, _voice_presets global _processor, _model, _device, _model_status, _model_error, _voice_presets
@@ -174,44 +222,9 @@ def _load_model_sync() -> None:
_device = _resolve_device() _device = _resolve_device()
logger.info("Using device: %s", _device) logger.info("Using device: %s", _device)
load_dtype = torch.bfloat16 if _device == "cuda" else torch.float32 _processor = _init_processor()
attn_impl = "flash_attention_2" if _device == "cuda" else "sdpa" _model = _init_model(_device)
_voice_presets = _load_voice_presets(_device)
logger.info("Loading processor...")
from vibevoice.processor.vibevoice_streaming_processor import (
VibeVoiceStreamingProcessor,
)
_processor = VibeVoiceStreamingProcessor.from_pretrained(MODEL_ID)
logger.info("Loading model on %s...", _device)
from vibevoice.modular.modeling_vibevoice_streaming_inference import (
VibeVoiceStreamingForConditionalGenerationInference,
)
try:
_model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
MODEL_ID,
torch_dtype=load_dtype,
device_map=_device,
attn_implementation=attn_impl,
)
except Exception:
logger.warning("flash_attention_2 unavailable, falling back to sdpa")
_model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
MODEL_ID,
torch_dtype=load_dtype,
device_map=_device,
attn_implementation="sdpa",
)
_model.eval()
_model.set_ddpm_inference_steps(num_steps=10)
for name, filename in EN_VOICES.items():
path = VOICES_DIR / filename
if path.exists():
_voice_presets[name] = torch.load(
path, map_location=_device, weights_only=False
)
_model_status = "online" _model_status = "online"
logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys())) logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys()))