mirror of
https://github.com/JezzWTF/vibepod.git
synced 2026-06-01 15:22:14 +00:00
Merge pull request #5 from JezzWTF/refactor-load-model-1716154043227557412
🧹 Refactor _load_model_sync to improve code readability
This commit is contained in:
+51
-38
@@ -155,6 +155,54 @@ def _download_voices() -> None:
|
|||||||
|
|
||||||
# ── Background model loader ─────────────────────────────────────────────────────
|
# ── Background model loader ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _init_processor():
    """Load and return the VibeVoice streaming processor for ``MODEL_ID``.

    Returns:
        The object produced by
        ``VibeVoiceStreamingProcessor.from_pretrained(MODEL_ID)``.
    """
    logger.info("Loading processor...")

    # Function-scope import: the vibevoice package is only imported once
    # this loader actually runs.
    from vibevoice.processor.vibevoice_streaming_processor import (
        VibeVoiceStreamingProcessor as _ProcessorCls,
    )

    processor = _ProcessorCls.from_pretrained(MODEL_ID)
    return processor
|
||||||
|
|
||||||
|
|
||||||
|
def _init_model(device: str):
    """Load the VibeVoice streaming inference model onto ``device``.

    On ``"cuda"`` the model is loaded in bfloat16 with ``flash_attention_2``;
    on any other device it is loaded in float32 with ``sdpa``. If the
    ``flash_attention_2`` load fails, the load is retried once with ``sdpa``.

    Args:
        device: Device string; forwarded to ``from_pretrained`` as
            ``device_map`` and used to pick dtype and attention backend.

    Returns:
        The loaded model, switched to eval mode and configured with 10 DDPM
        inference steps.

    Raises:
        Exception: Whatever ``from_pretrained`` raises when the ``sdpa`` load
            fails, since there is no further backend to fall back to.
    """
    logger.info("Loading model on %s...", device)
    on_cuda = device == "cuda"
    load_dtype = torch.bfloat16 if on_cuda else torch.float32
    attn_impl = "flash_attention_2" if on_cuda else "sdpa"

    # Function-scope import: the vibevoice package is only imported once
    # this loader actually runs.
    from vibevoice.modular.modeling_vibevoice_streaming_inference import (
        VibeVoiceStreamingForConditionalGenerationInference,
    )

    def _load(attn_implementation: str):
        # Single place for the from_pretrained arguments so the primary and
        # fallback attempts cannot drift apart.
        return VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
            MODEL_ID,
            torch_dtype=load_dtype,
            device_map=device,
            attn_implementation=attn_implementation,
        )

    if attn_impl == "sdpa":
        # Already on the fallback backend: retrying with identical arguments
        # would only repeat the same load and log a misleading warning, so
        # let any failure propagate to the caller directly.
        model = _load("sdpa")
    else:
        try:
            model = _load(attn_impl)
        except Exception:
            logger.warning(
                "Model load with %s failed; falling back to sdpa",
                attn_impl,
                exc_info=True,
            )
            model = _load("sdpa")

    model.eval()
    model.set_ddpm_inference_steps(num_steps=10)
    return model
|
||||||
|
|
||||||
|
|
||||||
|
def _load_voice_presets(device: str) -> dict[str, object]:
    """Load every available English voice preset onto ``device``.

    Iterates ``EN_VOICES`` (voice name -> filename), resolves each filename
    under ``VOICES_DIR``, and loads the files that exist with ``torch.load``.
    Entries whose file is missing are skipped without error.

    Args:
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        Mapping from voice name to the object returned by ``torch.load``.
    """
    voice_paths = (
        (voice_name, VOICES_DIR / voice_file)
        for voice_name, voice_file in EN_VOICES.items()
    )
    # NOTE(review): weights_only=False permits arbitrary unpickling, so the
    # preset files must come from a trusted source.
    return {
        voice_name: torch.load(preset_path, map_location=device, weights_only=False)
        for voice_name, preset_path in voice_paths
        if preset_path.exists()
    }
|
||||||
|
|
||||||
|
|
||||||
def _load_model_sync() -> None:
|
def _load_model_sync() -> None:
|
||||||
global _processor, _model, _device, _model_status, _model_error, _voice_presets
|
global _processor, _model, _device, _model_status, _model_error, _voice_presets
|
||||||
|
|
||||||
@@ -174,44 +222,9 @@ def _load_model_sync() -> None:
|
|||||||
_device = _resolve_device()
|
_device = _resolve_device()
|
||||||
logger.info("Using device: %s", _device)
|
logger.info("Using device: %s", _device)
|
||||||
|
|
||||||
load_dtype = torch.bfloat16 if _device == "cuda" else torch.float32
|
_processor = _init_processor()
|
||||||
attn_impl = "flash_attention_2" if _device == "cuda" else "sdpa"
|
_model = _init_model(_device)
|
||||||
|
_voice_presets = _load_voice_presets(_device)
|
||||||
logger.info("Loading processor...")
|
|
||||||
from vibevoice.processor.vibevoice_streaming_processor import (
|
|
||||||
VibeVoiceStreamingProcessor,
|
|
||||||
)
|
|
||||||
_processor = VibeVoiceStreamingProcessor.from_pretrained(MODEL_ID)
|
|
||||||
|
|
||||||
logger.info("Loading model on %s...", _device)
|
|
||||||
from vibevoice.modular.modeling_vibevoice_streaming_inference import (
|
|
||||||
VibeVoiceStreamingForConditionalGenerationInference,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
_model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
|
|
||||||
MODEL_ID,
|
|
||||||
torch_dtype=load_dtype,
|
|
||||||
device_map=_device,
|
|
||||||
attn_implementation=attn_impl,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.warning("flash_attention_2 unavailable, falling back to sdpa")
|
|
||||||
_model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
|
|
||||||
MODEL_ID,
|
|
||||||
torch_dtype=load_dtype,
|
|
||||||
device_map=_device,
|
|
||||||
attn_implementation="sdpa",
|
|
||||||
)
|
|
||||||
|
|
||||||
_model.eval()
|
|
||||||
_model.set_ddpm_inference_steps(num_steps=10)
|
|
||||||
|
|
||||||
for name, filename in EN_VOICES.items():
|
|
||||||
path = VOICES_DIR / filename
|
|
||||||
if path.exists():
|
|
||||||
_voice_presets[name] = torch.load(
|
|
||||||
path, map_location=_device, weights_only=False
|
|
||||||
)
|
|
||||||
|
|
||||||
_model_status = "online"
|
_model_status = "online"
|
||||||
logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys()))
|
logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys()))
|
||||||
|
|||||||
Reference in New Issue
Block a user