diff --git a/server/vibevoice_server.py b/server/vibevoice_server.py
index 9ecc28e..9523eba 100644
--- a/server/vibevoice_server.py
+++ b/server/vibevoice_server.py
@@ -155,6 +155,83 @@ def _download_voices() -> None:
 
 # ── Background model loader ─────────────────────────────────────────────────────
 
+def _init_processor():
+    """Load the VibeVoice streaming processor for ``MODEL_ID``.
+
+    The ``vibevoice`` import is deferred to call time, keeping it out of
+    module import.
+    """
+    logger.info("Loading processor...")
+    from vibevoice.processor.vibevoice_streaming_processor import (
+        VibeVoiceStreamingProcessor,
+    )
+    return VibeVoiceStreamingProcessor.from_pretrained(MODEL_ID)
+
+
+def _init_model(device: str):
+    """Load the VibeVoice streaming inference model onto ``device``.
+
+    On ``"cuda"`` the model is loaded in bfloat16 with ``flash_attention_2``;
+    if that load fails it is retried once with ``sdpa``. On any other device
+    the model is loaded in float32 with ``sdpa`` and a failure propagates,
+    because there is no other attention backend left to try.
+
+    Returns the model in eval mode with 10 DDPM inference steps configured.
+    """
+    logger.info("Loading model on %s...", device)
+    on_cuda = device == "cuda"
+    load_dtype = torch.bfloat16 if on_cuda else torch.float32
+    attn_impl = "flash_attention_2" if on_cuda else "sdpa"
+
+    from vibevoice.modular.modeling_vibevoice_streaming_inference import (
+        VibeVoiceStreamingForConditionalGenerationInference,
+    )
+
+    def _load(attn_implementation: str):
+        return VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
+            MODEL_ID,
+            torch_dtype=load_dtype,
+            device_map=device,
+            attn_implementation=attn_implementation,
+        )
+
+    try:
+        model = _load(attn_impl)
+    except Exception:
+        if attn_impl == "sdpa":
+            # Already on the fallback backend; retrying the identical load
+            # would only repeat the failure.
+            raise
+        logger.warning(
+            "Model load with %s failed; falling back to sdpa",
+            attn_impl,
+            exc_info=True,
+        )
+        model = _load("sdpa")
+
+    model.eval()
+    model.set_ddpm_inference_steps(num_steps=10)
+    return model
+
+
+def _load_voice_presets(device: str) -> dict[str, object]:
+    """Load the ``EN_VOICES`` presets that exist under ``VOICES_DIR``.
+
+    Presets whose file is missing are skipped. Returns a mapping from voice
+    name to the object ``torch.load`` deserializes onto ``device``.
+    """
+    presets: dict[str, object] = {}
+    for name, filename in EN_VOICES.items():
+        path = VOICES_DIR / filename
+        if path.exists():
+            # SECURITY: weights_only=False unpickles arbitrary objects; only
+            # safe while the preset files come from a trusted source.
+            presets[name] = torch.load(
+                path, map_location=device, weights_only=False
+            )
+    return presets
+
+
 def _load_model_sync() -> None:
     global _processor, _model, _device, _model_status, _model_error, _voice_presets
 
@@ -174,44 +251,9 @@ def _load_model_sync() -> None:
         _device = _resolve_device()
         logger.info("Using device: %s", _device)
 
-        load_dtype = torch.bfloat16 if _device == "cuda" else torch.float32
-        attn_impl = "flash_attention_2" if _device == "cuda" else "sdpa"
-
-        logger.info("Loading processor...")
-        from vibevoice.processor.vibevoice_streaming_processor import (
-            VibeVoiceStreamingProcessor,
-        )
-        _processor = VibeVoiceStreamingProcessor.from_pretrained(MODEL_ID)
-
-        logger.info("Loading model on %s...", _device)
-        from vibevoice.modular.modeling_vibevoice_streaming_inference import (
-            VibeVoiceStreamingForConditionalGenerationInference,
-        )
-        try:
-            _model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
-                MODEL_ID,
-                torch_dtype=load_dtype,
-                device_map=_device,
-                attn_implementation=attn_impl,
-            )
-        except Exception:
-            logger.warning("flash_attention_2 unavailable, falling back to sdpa")
-            _model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
-                MODEL_ID,
-                torch_dtype=load_dtype,
-                device_map=_device,
-                attn_implementation="sdpa",
-            )
-
-        _model.eval()
-        _model.set_ddpm_inference_steps(num_steps=10)
-
-        for name, filename in EN_VOICES.items():
-            path = VOICES_DIR / filename
-            if path.exists():
-                _voice_presets[name] = torch.load(
-                    path, map_location=_device, weights_only=False
-                )
+        _processor = _init_processor()
+        _model = _init_model(_device)
+        _voice_presets = _load_voice_presets(_device)
 
         _model_status = "online"
         logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys()))