mirror of
https://github.com/JezzWTF/vibepod.git
synced 2026-06-01 15:22:14 +00:00
Merge pull request #5 from JezzWTF/refactor-load-model-1716154043227557412
🧹 Refactor _load_model_sync to improve code readability
This commit is contained in:
+51
-38
@@ -155,6 +155,54 @@ def _download_voices() -> None:
|
||||
|
||||
# ── Background model loader ─────────────────────────────────────────────────────
|
||||
|
||||
def _init_processor():
|
||||
logger.info("Loading processor...")
|
||||
from vibevoice.processor.vibevoice_streaming_processor import (
|
||||
VibeVoiceStreamingProcessor,
|
||||
)
|
||||
return VibeVoiceStreamingProcessor.from_pretrained(MODEL_ID)
|
||||
|
||||
|
||||
def _init_model(device: str):
|
||||
logger.info("Loading model on %s...", device)
|
||||
load_dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
||||
attn_impl = "flash_attention_2" if device == "cuda" else "sdpa"
|
||||
|
||||
from vibevoice.modular.modeling_vibevoice_streaming_inference import (
|
||||
VibeVoiceStreamingForConditionalGenerationInference,
|
||||
)
|
||||
try:
|
||||
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
|
||||
MODEL_ID,
|
||||
torch_dtype=load_dtype,
|
||||
device_map=device,
|
||||
attn_implementation=attn_impl,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Model load with %s failed; falling back to sdpa", attn_impl, exc_info=True)
|
||||
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
|
||||
MODEL_ID,
|
||||
torch_dtype=load_dtype,
|
||||
device_map=device,
|
||||
attn_implementation="sdpa",
|
||||
)
|
||||
|
||||
model.eval()
|
||||
model.set_ddpm_inference_steps(num_steps=10)
|
||||
return model
|
||||
|
||||
|
||||
def _load_voice_presets(device: str) -> dict[str, object]:
|
||||
presets = {}
|
||||
for name, filename in EN_VOICES.items():
|
||||
path = VOICES_DIR / filename
|
||||
if path.exists():
|
||||
presets[name] = torch.load(
|
||||
path, map_location=device, weights_only=False
|
||||
)
|
||||
return presets
|
||||
|
||||
|
||||
def _load_model_sync() -> None:
|
||||
global _processor, _model, _device, _model_status, _model_error, _voice_presets
|
||||
|
||||
@@ -174,44 +222,9 @@ def _load_model_sync() -> None:
|
||||
_device = _resolve_device()
|
||||
logger.info("Using device: %s", _device)
|
||||
|
||||
load_dtype = torch.bfloat16 if _device == "cuda" else torch.float32
|
||||
attn_impl = "flash_attention_2" if _device == "cuda" else "sdpa"
|
||||
|
||||
logger.info("Loading processor...")
|
||||
from vibevoice.processor.vibevoice_streaming_processor import (
|
||||
VibeVoiceStreamingProcessor,
|
||||
)
|
||||
_processor = VibeVoiceStreamingProcessor.from_pretrained(MODEL_ID)
|
||||
|
||||
logger.info("Loading model on %s...", _device)
|
||||
from vibevoice.modular.modeling_vibevoice_streaming_inference import (
|
||||
VibeVoiceStreamingForConditionalGenerationInference,
|
||||
)
|
||||
try:
|
||||
_model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
|
||||
MODEL_ID,
|
||||
torch_dtype=load_dtype,
|
||||
device_map=_device,
|
||||
attn_implementation=attn_impl,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("flash_attention_2 unavailable, falling back to sdpa")
|
||||
_model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
|
||||
MODEL_ID,
|
||||
torch_dtype=load_dtype,
|
||||
device_map=_device,
|
||||
attn_implementation="sdpa",
|
||||
)
|
||||
|
||||
_model.eval()
|
||||
_model.set_ddpm_inference_steps(num_steps=10)
|
||||
|
||||
for name, filename in EN_VOICES.items():
|
||||
path = VOICES_DIR / filename
|
||||
if path.exists():
|
||||
_voice_presets[name] = torch.load(
|
||||
path, map_location=_device, weights_only=False
|
||||
)
|
||||
_processor = _init_processor()
|
||||
_model = _init_model(_device)
|
||||
_voice_presets = _load_voice_presets(_device)
|
||||
|
||||
_model_status = "online"
|
||||
logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys()))
|
||||
|
||||
Reference in New Issue
Block a user