mirror of
https://github.com/JezzWTF/vibepod.git
synced 2026-06-01 15:22:14 +00:00
09d9727c20
🎯 What: Extracted inline model loading logic from `_load_model_sync` into distinct helper functions (`_init_processor`, `_init_model`, and `_load_voice_presets`). 💡 Why: This significantly reduces the complexity of `_load_model_sync`, making the code easier to read and maintain. ✅ Verification: Ran a syntax check (`python -m py_compile`), started the backend server with CPU inference, and verified the model initialized and correctly processed a text-to-speech request to the `/generate` endpoint without regressions. ✨ Result: Improved code modularity while preserving identical behavior. Co-authored-by: LyAhn <27559362+LyAhn@users.noreply.github.com>
410 lines
14 KiB
Python
410 lines
14 KiB
Python
"""
VibePod — VibeVoice FastAPI TTS Server

Startup sequence (background thread):
  1. Download model weights if not cached   -> status: downloading
  2. Download voice preset .pt files        -> status: loading
  3. Load processor + model into memory     -> status: loading
  4. Pre-load all voice tensors             -> status: loading
  -> Server ready                           -> status: online

Generation flow:
  POST /generate -> SSE stream of audio_chunk events (base64 float32 PCM),
                    ends with {type:"complete"}

Device selection:
  Set VIBEPOD_DEVICE=cpu to force CPU inference (e.g. via --cpu flag in start.sh).
  Set VIBEPOD_DEVICE=cuda to force CUDA (default when a GPU is available).
  If unset, the server auto-detects: CUDA if available, otherwise CPU.
"""
|
|
|
|
import asyncio
|
|
import base64
|
|
import copy
|
|
import functools
|
|
import json
|
|
import logging
|
|
import os
|
|
import threading
|
|
import time
|
|
import urllib.request
|
|
from contextlib import asynccontextmanager
|
|
from pathlib import Path
|
|
from typing import AsyncGenerator, Literal, Optional
|
|
|
|
import torch
|
|
from fastapi import FastAPI, HTTPException, Request
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel, Field, field_validator
|
|
from tqdm import tqdm as _BaseTqdm
|
|
|
|
# Process-wide logging: timestamp, level, message.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger(__name__)

# Hugging Face repository id of the model served by this process.
MODEL_ID = "microsoft/VibeVoice-Realtime-0.5B"
# Audio sample rate in Hz.
SAMPLE_RATE = 24_000

# Voice preset files are stored next to this module under voices/streaming_model.
VOICES_DIR = Path(__file__).parent / "voices" / "streaming_model"
# Base URL that the preset file names below are appended to when downloading.
VOICE_BASE_URL = (
    "https://raw.githubusercontent.com/microsoft/VibeVoice/main"
    "/demo/voices/streaming_model"
)

# Speaker key (as accepted by the API) -> preset file name.
EN_VOICES: dict[str, str] = {
    "carter": "en-Carter_man.pt",
    "davis": "en-Davis_man.pt",
    "emma": "en-Emma_woman.pt",
    "frank": "en-Frank_man.pt",
    "grace": "en-Grace_woman.pt",
    "mike": "en-Mike_man.pt",
}
# Speaker used when a request names a voice that is not loaded.
DEFAULT_SPEAKER = "carter"

# Glob patterns excluded from the Hugging Face snapshot download.
_IGNORE_PATTERNS = [
    "*.msgpack",
    "flax_model*",
    "tf_model*",
    "rust_model*",
    "*.ot",
]
|
|
|
|
# ── Device selection ────────────────────────────────────────────────────────────
|
|
# VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag.
|
|
# Falls back to auto-detection if not set.
|
|
|
|
def _resolve_device() -> str:
    """Pick the inference device: honour VIBEPOD_DEVICE, else auto-detect.

    Returns:
        ``"cpu"`` or ``"cuda"``. An explicit ``cuda`` request is downgraded
        to ``"cpu"`` (with a warning) when CUDA is not available. Any other
        value of the variable, including unset, triggers auto-detection.
    """
    requested = os.environ.get("VIBEPOD_DEVICE", "").strip().lower()

    # An explicit CPU request never needs to probe CUDA.
    if requested == "cpu":
        return "cpu"

    cuda_available = torch.cuda.is_available()
    if requested == "cuda" and not cuda_available:
        logger.warning(
            "VIBEPOD_DEVICE=cuda requested but CUDA is not available — falling back to CPU."
        )
        return "cpu"

    # Either CUDA was requested and is present, or nothing valid was requested.
    if cuda_available:
        return "cuda"
    return "cpu"
|
|
|
|
|
|
# ── Global state ────────────────────────────────────────────────────────────────
|
|
|
|
# Lifecycle states reported by /health and checked by /generate.
ModelStatus = Literal["downloading", "loading", "online", "error"]

# Populated by the background loader (_load_model_sync); None until loaded.
_processor = None
_model = None

# Device in use; overwritten with the resolved device during loading.
_device: str = "cpu"

# Current lifecycle state and, on failure, a client-safe error message.
_model_status: ModelStatus = "loading"
_model_error: Optional[str] = None

# Speaker key -> loaded voice preset object.
_voice_presets: dict[str, object] = {}

# _load_lock serialises model loading; _generation_lock is held for the
# duration of one /generate stream.
_load_lock = threading.Lock()
_generation_lock = asyncio.Lock()

# Download progress (files downloaded so far)
_dl_progress: dict[str, int] = {"done": 0, "total": 0}
|
|
|
|
|
|
|
|
# ── Progress-tracking tqdm (for model file downloads) ──────────────────────────
|
|
|
|
def _make_dl_tqdm() -> type:
    """Build a tqdm subclass that mirrors its progress into ``_dl_progress``.

    Returns:
        A subclass of the imported tqdm class. Instances whose ``total`` is a
        number strictly between 0 and 10_000 copy ``total`` and ``n`` into
        the module-level ``_dl_progress`` dict; other bars are left alone.
    """

    def _is_tracked(total: object) -> bool:
        # NOTE(review): the 10_000 ceiling presumably separates the per-file
        # counter from byte-sized bars — confirm against huggingface_hub.
        return isinstance(total, (int, float)) and 0 < total < 10_000

    class _DlTqdm(_BaseTqdm):
        def __init__(self, *args: object, **kwargs: object) -> None:
            super().__init__(*args, **kwargs)
            if _is_tracked(self.total):
                # A new tracked bar resets the shared counters.
                _dl_progress["total"] = int(self.total)
                _dl_progress["done"] = 0

        def update(self, n: int = 1) -> "bool | None":
            outcome = super().update(n)
            if _is_tracked(self.total):
                _dl_progress["done"] = int(self.n)
            return outcome

    return _DlTqdm
|
|
|
|
|
|
# ── Model / voice helpers ───────────────────────────────────────────────────────
|
|
|
|
def _is_model_cached() -> bool:
    """Report whether a local-only snapshot lookup for ``MODEL_ID`` succeeds.

    Any exception (including a failed ``huggingface_hub`` import) is treated
    as "not cached".
    """
    try:
        from huggingface_hub import snapshot_download

        # local_files_only=True makes this a cache probe: no network fetch.
        snapshot_download(
            MODEL_ID,
            local_files_only=True,
            ignore_patterns=_IGNORE_PATTERNS,
        )
    except Exception:
        return False
    return True
|
|
|
|
|
|
def _download_model() -> None:
    """Download the ``MODEL_ID`` snapshot, reporting progress via ``_dl_progress``.

    The access token is read from ``HF_TOKEN`` or, failing that,
    ``HUGGINGFACE_TOKEN``; an empty value is passed on as ``None``.
    """
    from huggingface_hub import snapshot_download

    hf_token: Optional[str] = (
        os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN") or None
    )
    progress_cls = _make_dl_tqdm()

    logger.info("Model not cached — downloading %s...", MODEL_ID)
    snapshot_download(
        repo_id=MODEL_ID,
        ignore_patterns=_IGNORE_PATTERNS,
        token=hf_token,
        tqdm_class=progress_cls,
    )
    logger.info("Model download complete.")
|
|
|
|
|
|
def _download_voices() -> None:
    """Fetch every voice preset in ``EN_VOICES`` that is not already on disk.

    Each file is downloaded to a sibling ``<name>.part`` file and renamed onto
    its final path only after the transfer finished. An interrupted download
    therefore never leaves a truncated preset that a later run would mistake
    for a complete one (presence on disk is the only "already downloaded"
    check).

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises on network or HTTP
        failure; the partial file is removed before the error propagates.
    """
    VOICES_DIR.mkdir(parents=True, exist_ok=True)

    for filename in EN_VOICES.values():
        dest = VOICES_DIR / filename
        if dest.exists():
            continue

        url = f"{VOICE_BASE_URL}/{filename}"
        partial = dest.with_name(dest.name + ".part")
        logger.info("Downloading voice preset: %s", filename)
        try:
            urllib.request.urlretrieve(url, partial)
            # Same-directory rename: the final name appears only when complete.
            partial.replace(dest)
        finally:
            # No-op after a successful rename; cleans up after a failure.
            partial.unlink(missing_ok=True)

    logger.info("Voice presets ready.")
|
|
|
|
|
|
# ── Background model loader ─────────────────────────────────────────────────────
|
|
|
|
def _init_processor():
    """Load and return the streaming processor for ``MODEL_ID``."""
    logger.info("Loading processor...")

    # Imported lazily, after the log line, so the import happens at load time.
    from vibevoice.processor.vibevoice_streaming_processor import (
        VibeVoiceStreamingProcessor as processor_cls,
    )

    processor = processor_cls.from_pretrained(MODEL_ID)
    return processor
|
|
|
|
|
|
def _init_model(device: str):
    """Load the inference model for ``MODEL_ID`` onto ``device``.

    On CUDA the model is loaded in bfloat16 with ``flash_attention_2``; if
    that load fails it is retried once with ``sdpa``. On any other device the
    model is loaded in float32 with ``sdpa`` and a failure propagates
    immediately — there is no different implementation to fall back to, so
    retrying would only repeat the same load and mislabel the error.

    Args:
        device: Target device string, passed as ``device_map``.

    Returns:
        The model in eval mode with 10 DDPM inference steps configured.

    Raises:
        Whatever ``from_pretrained`` raises when no fallback remains.
    """
    logger.info("Loading model on %s...", device)

    on_cuda = device == "cuda"
    load_dtype = torch.bfloat16 if on_cuda else torch.float32
    attn_impl = "flash_attention_2" if on_cuda else "sdpa"

    from vibevoice.modular.modeling_vibevoice_streaming_inference import (
        VibeVoiceStreamingForConditionalGenerationInference,
    )

    def _load(attn_implementation: str):
        # Single place for the from_pretrained arguments shared by both attempts.
        return VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
            MODEL_ID,
            torch_dtype=load_dtype,
            device_map=device,
            attn_implementation=attn_implementation,
        )

    try:
        model = _load(attn_impl)
    except Exception:
        if attn_impl == "sdpa":
            # Already on the fallback implementation: surface the real error.
            raise
        # exc_info keeps the original failure visible in the log.
        logger.warning(
            "flash_attention_2 unavailable, falling back to sdpa", exc_info=True
        )
        model = _load("sdpa")

    model.eval()
    model.set_ddpm_inference_steps(num_steps=10)
    return model
|
|
|
|
|
|
def _load_voice_presets(device: str) -> dict[str, object]:
    """Load every available voice preset file onto ``device``.

    Args:
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        Mapping of speaker key to the loaded preset object. Speakers whose
        file is missing from ``VOICES_DIR`` are skipped with a warning, so
        the result may contain fewer entries than ``EN_VOICES``.
    """
    presets: dict[str, object] = {}
    for name, filename in EN_VOICES.items():
        path = VOICES_DIR / filename
        if not path.exists():
            # Make the gap visible: requests for this speaker will otherwise
            # be served with the default voice without any trace in the log.
            logger.warning("Voice preset file missing, skipping '%s': %s", name, path)
            continue

        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code embedded in the file. These files
        # are fetched over the network by _download_voices, so this trusts
        # the download source. Switch to weights_only=True if the preset
        # format allows it.
        presets[name] = torch.load(path, map_location=device, weights_only=False)
    return presets
|
|
|
|
|
|
def _load_model_sync() -> None:
    """Blocking end-to-end model bring-up; run on a background thread.

    Under ``_load_lock`` it downloads the model if it is not cached, fetches
    the voice presets, resolves the device, then loads the processor, the
    model and the presets. Progress is published through ``_model_status``.
    Returns at once if a model is already loaded. Any failure is logged and
    recorded as status ``"error"`` with a generic message in ``_model_error``;
    nothing is raised to the caller.
    """
    global _processor, _model, _device, _model_status, _model_error, _voice_presets

    with _load_lock:
        # Another call already finished loading.
        if _model is not None:
            return

        try:
            needs_download = not _is_model_cached()
            if needs_download:
                _model_status = "downloading"
                _download_model()

            _model_status = "loading"
            _download_voices()

            # Device comes from VIBEPOD_DEVICE (set by start.sh) or auto-detection.
            _device = _resolve_device()
            logger.info("Using device: %s", _device)

            _processor = _init_processor()
            _model = _init_model(_device)
            _voice_presets = _load_voice_presets(_device)

            _model_status = "online"
            logger.info(
                "Model ready on %s. Voices: %s",
                _device,
                list(_voice_presets.keys()),
            )
        except Exception as err:
            _model_status = "error"
            # Details go to the log only; clients see this generic text.
            _model_error = "Internal server error during model initialization."
            logger.exception("Failed to initialise model: %s", err)
|
|
|
|
|
|
# ── FastAPI app ─────────────────────────────────────────────────────────────────
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Start model loading on a daemon thread when the application starts.

    The loader runs in the background, so the server accepts requests while
    the model is still downloading or loading. Nothing is done on shutdown.
    """
    loader = threading.Thread(
        target=_load_model_sync,
        name="model-loader",
        daemon=True,
    )
    loader.start()
    yield


app = FastAPI(
    title="VibePod TTS Server",
    version="0.1.0",
    lifespan=lifespan,
)
|
|
|
|
|
|
# ── Schemas ─────────────────────────────────────────────────────────────────────
|
|
|
|
class GenerateRequest(BaseModel):
    """Request body for ``POST /generate``."""

    # Text to synthesise: 1 to 10_000 characters, stored stripped.
    text: str = Field(..., min_length=1, max_length=10_000)
    # Voice key, stored lower-cased and stripped.
    speaker: str = Field(default=DEFAULT_SPEAKER)
    # Value forwarded to the model's generate call as cfg_scale.
    cfg_scale: float = Field(default=1.5, ge=0.5, le=4.0)
    # Number of DDPM inference steps configured before generation.
    inference_steps: int = Field(default=10, ge=5, le=20)

    @field_validator("text")
    @classmethod
    def text_not_blank(cls, v: str) -> str:
        """Reject whitespace-only text and return the stripped value."""
        stripped = v.strip()
        if stripped:
            return stripped
        raise ValueError("text must not be blank")

    @field_validator("speaker")
    @classmethod
    def normalise_speaker(cls, v: str) -> str:
        """Lower-case the speaker key, then strip surrounding whitespace."""
        lowered = v.lower()
        return lowered.strip()
|
|
|
|
|
|
# ── Endpoints ───────────────────────────────────────────────────────────────────
|
|
|
|
@app.get("/health")
async def health() -> dict:
    """Report loader status, model id, device and the loaded voice keys.

    Adds ``progress`` while the model is downloading and ``message`` when an
    error message has been recorded.
    """
    payload: dict = {
        "status": _model_status,
        "model": MODEL_ID,
        "device": _device,
        "voices": list(_voice_presets.keys()),
    }

    downloading = _model_status == "downloading"
    if downloading:
        # Snapshot the counters rather than exposing the shared dict.
        payload["progress"] = {
            "done": _dl_progress["done"],
            "total": _dl_progress["total"],
        }

    if _model_error:
        payload["message"] = _model_error

    return payload
|
|
|
|
|
|
def _sync_generate(
    req: GenerateRequest,
    streamer: Optional[object] = None,
    cancel_event: Optional[threading.Event] = None,
) -> str:
    """Run one blocking inference pass and return the speaker key used.

    Intended for a thread-pool executor — do not call from the event loop.

    Args:
        req: Validated generation request.
        streamer: Optional audio streamer handed to the model as
            ``audio_streamer`` so chunks are delivered while generating.
        cancel_event: Optional flag. It is consulted once, before any work
            starts; it is not polled during generation.

    Returns:
        ``req.speaker`` if that voice is loaded, otherwise ``DEFAULT_SPEAKER``.

    Raises:
        RuntimeError: ``cancel_event`` was already set on entry.
        ValueError: The model produced no speech output.
    """
    if cancel_event is not None and cancel_event.is_set():
        raise RuntimeError("Generation cancelled.")

    # Unknown voices fall back to the default speaker.
    speaker = req.speaker
    if speaker not in _voice_presets:
        speaker = DEFAULT_SPEAKER

    # Work on a private copy so the shared preset is never mutated.
    prompt_cache = copy.deepcopy(_voice_presets[speaker])

    _model.set_ddpm_inference_steps(num_steps=req.inference_steps)

    encoded = _processor.process_input_with_cached_prompt(
        text=req.text,
        cached_prompt=prompt_cache,
        padding=True,
        return_tensors="pt",
        return_attention_mask=True,
    )
    # Move tensor inputs to the active device; pass everything else through.
    model_inputs = {
        key: (value.to(_device) if torch.is_tensor(value) else value)
        for key, value in encoded.items()
    }

    result = _model.generate(
        **model_inputs,
        max_new_tokens=None,
        cfg_scale=req.cfg_scale,
        tokenizer=_processor.tokenizer,
        generation_config={"do_sample": False},
        verbose=True,
        # A second copy, separate from the one given to the processor above.
        all_prefilled_outputs=copy.deepcopy(prompt_cache),
        audio_streamer=streamer,
    )

    speech = result.speech_outputs
    if not speech or speech[0] is None:
        raise ValueError("Model returned no audio output.")

    return speaker
|
|
|
|
|
|
def _sse(event: dict) -> str:
    """Serialise ``event`` as one Server-Sent Events ``data:`` frame."""
    payload = json.dumps(event)
    # A blank line terminates the SSE frame.
    return "data: " + payload + "\n\n"
|
|
|
|
|
|
@app.post("/generate")
async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
    """Stream synthesised audio for ``req`` as Server-Sent Events.

    Emits ``audio_chunk`` events (base64 of the chunk's float32 bytes) and
    ends with exactly one of ``complete``, ``error`` or ``cancelled``. Only
    one generation runs at a time.

    Raises:
        HTTPException: 503 when the model is not online, or when another
            generation currently holds the generation lock.
    """
    if _model_status != "online":
        detail = {
            "downloading": "Model is downloading — please wait.",
            "loading": "Model is loading into memory — please wait.",
            "error": f"Model failed to load: {_model_error or 'unknown error'}",
        }.get(_model_status, "Server not ready.")
        raise HTTPException(status_code=503, detail=detail)

    if _generation_lock.locked():
        raise HTTPException(status_code=503, detail="Server is already generating audio. Please wait.")

    async def event_stream() -> AsyncGenerator[str, None]:
        from vibevoice.modular.streamer import AsyncAudioStreamer

        start = time.monotonic()
        streamer = AsyncAudioStreamer(batch_size=1)
        cancel_event = threading.Event()

        async with _generation_lock:
            # We are inside a coroutine, so a running loop always exists.
            loop = asyncio.get_running_loop()
            audio_queue = streamer.audio_queues[0]
            future = loop.run_in_executor(
                None, functools.partial(_sync_generate, req, streamer, cancel_event)
            )

            def _unblock_on_failure(done: "asyncio.Future") -> None:
                # If the worker dies before the stop sentinel reaches the
                # queue, the drain loop below would wait out its full timeout
                # and report "timed out" instead of the real error. Push the
                # sentinel so the loop ends and `await future` raises it.
                # NOTE(review): assumes audio_queues[0] is an asyncio.Queue
                # whose end-of-stream marker is None — confirm in the streamer.
                if done.cancelled() or done.exception() is not None:
                    try:
                        audio_queue.put_nowait(None)
                    except asyncio.QueueFull:
                        pass

            future.add_done_callback(_unblock_on_failure)

            # NOTE(review): future.cancel() below cannot interrupt a worker
            # thread that is already running, and _sync_generate reads
            # cancel_event only on entry. After a timeout or disconnect the
            # model may still be generating once the lock is released.

            # Drain audio chunks as they arrive from the diffusion head.
            # stop_signal=None is the default sentinel that ends the queue.
            while True:
                try:
                    chunk = await asyncio.wait_for(audio_queue.get(), timeout=120.0)
                except asyncio.TimeoutError:
                    cancel_event.set()
                    future.cancel()
                    yield _sse({"type": "error", "message": "Generation timed out"})
                    return

                if await request.is_disconnected():
                    cancel_event.set()
                    future.cancel()
                    logger.info("Generation client disconnected; stream cancelled.")
                    return

                if chunk is None:  # stop signal
                    break

                pcm_b64 = base64.b64encode(
                    chunk.detach().cpu().float().numpy().tobytes()
                ).decode()
                yield _sse({"type": "audio_chunk", "data": pcm_b64})

            try:
                speaker = await future
            except asyncio.CancelledError:
                logger.info("Generation cancelled.")
                yield _sse({"type": "cancelled"})
                return
            except Exception as exc:
                # Full details go to the log; the client gets a generic message.
                logger.exception("Generation failed: %s", exc)
                yield _sse({"type": "error", "message": "Internal server error during generation."})
                return

            elapsed = round(time.monotonic() - start, 1)
            logger.info("Generation complete in %.1fs", elapsed)
            yield _sse({"type": "complete", "elapsed": elapsed, "speaker": speaker})

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        # Ask caches and buffering proxies to pass events through immediately.
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|