mirror of
https://github.com/JezzWTF/vibepod.git
synced 2026-06-01 15:22:14 +00:00
🧹 Refactor model loading in vibevoice_server.py
🎯 What: Extracted inline model loading logic from `_load_model_sync` into distinct helper functions (`_init_processor`, `_init_model`, and `_load_voice_presets`). 💡 Why: This significantly reduces the complexity of `_load_model_sync`, making the code easier to read and maintain. ✅ Verification: Ran a syntax check (`python -m py_compile`), started the backend server with CPU inference, and verified the model initialized and correctly processed a text-to-speech request to the `/generate` endpoint without regressions. ✨ Result: Improved code modularity while preserving identical behavior. Co-authored-by: LyAhn <27559362+LyAhn@users.noreply.github.com>
This commit is contained in:
+51
-38
@@ -155,6 +155,54 @@ def _download_voices() -> None:
|
|||||||
|
|
||||||
# ── Background model loader ─────────────────────────────────────────────────────
|
# ── Background model loader ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _init_processor():
|
||||||
|
logger.info("Loading processor...")
|
||||||
|
from vibevoice.processor.vibevoice_streaming_processor import (
|
||||||
|
VibeVoiceStreamingProcessor,
|
||||||
|
)
|
||||||
|
return VibeVoiceStreamingProcessor.from_pretrained(MODEL_ID)
|
||||||
|
|
||||||
|
|
||||||
|
def _init_model(device: str):
|
||||||
|
logger.info("Loading model on %s...", device)
|
||||||
|
load_dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
||||||
|
attn_impl = "flash_attention_2" if device == "cuda" else "sdpa"
|
||||||
|
|
||||||
|
from vibevoice.modular.modeling_vibevoice_streaming_inference import (
|
||||||
|
VibeVoiceStreamingForConditionalGenerationInference,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
|
||||||
|
MODEL_ID,
|
||||||
|
torch_dtype=load_dtype,
|
||||||
|
device_map=device,
|
||||||
|
attn_implementation=attn_impl,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("flash_attention_2 unavailable, falling back to sdpa")
|
||||||
|
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
|
||||||
|
MODEL_ID,
|
||||||
|
torch_dtype=load_dtype,
|
||||||
|
device_map=device,
|
||||||
|
attn_implementation="sdpa",
|
||||||
|
)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
model.set_ddpm_inference_steps(num_steps=10)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def _load_voice_presets(device: str) -> dict[str, object]:
|
||||||
|
presets = {}
|
||||||
|
for name, filename in EN_VOICES.items():
|
||||||
|
path = VOICES_DIR / filename
|
||||||
|
if path.exists():
|
||||||
|
presets[name] = torch.load(
|
||||||
|
path, map_location=device, weights_only=False
|
||||||
|
)
|
||||||
|
return presets
|
||||||
|
|
||||||
|
|
||||||
def _load_model_sync() -> None:
|
def _load_model_sync() -> None:
|
||||||
global _processor, _model, _device, _model_status, _model_error, _voice_presets
|
global _processor, _model, _device, _model_status, _model_error, _voice_presets
|
||||||
|
|
||||||
@@ -174,44 +222,9 @@ def _load_model_sync() -> None:
|
|||||||
_device = _resolve_device()
|
_device = _resolve_device()
|
||||||
logger.info("Using device: %s", _device)
|
logger.info("Using device: %s", _device)
|
||||||
|
|
||||||
load_dtype = torch.bfloat16 if _device == "cuda" else torch.float32
|
_processor = _init_processor()
|
||||||
attn_impl = "flash_attention_2" if _device == "cuda" else "sdpa"
|
_model = _init_model(_device)
|
||||||
|
_voice_presets = _load_voice_presets(_device)
|
||||||
logger.info("Loading processor...")
|
|
||||||
from vibevoice.processor.vibevoice_streaming_processor import (
|
|
||||||
VibeVoiceStreamingProcessor,
|
|
||||||
)
|
|
||||||
_processor = VibeVoiceStreamingProcessor.from_pretrained(MODEL_ID)
|
|
||||||
|
|
||||||
logger.info("Loading model on %s...", _device)
|
|
||||||
from vibevoice.modular.modeling_vibevoice_streaming_inference import (
|
|
||||||
VibeVoiceStreamingForConditionalGenerationInference,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
_model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
|
|
||||||
MODEL_ID,
|
|
||||||
torch_dtype=load_dtype,
|
|
||||||
device_map=_device,
|
|
||||||
attn_implementation=attn_impl,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.warning("flash_attention_2 unavailable, falling back to sdpa")
|
|
||||||
_model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
|
|
||||||
MODEL_ID,
|
|
||||||
torch_dtype=load_dtype,
|
|
||||||
device_map=_device,
|
|
||||||
attn_implementation="sdpa",
|
|
||||||
)
|
|
||||||
|
|
||||||
_model.eval()
|
|
||||||
_model.set_ddpm_inference_steps(num_steps=10)
|
|
||||||
|
|
||||||
for name, filename in EN_VOICES.items():
|
|
||||||
path = VOICES_DIR / filename
|
|
||||||
if path.exists():
|
|
||||||
_voice_presets[name] = torch.load(
|
|
||||||
path, map_location=_device, weights_only=False
|
|
||||||
)
|
|
||||||
|
|
||||||
_model_status = "online"
|
_model_status = "online"
|
||||||
logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys()))
|
logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys()))
|
||||||
|
|||||||
Reference in New Issue
Block a user