Improve dev startup: model download script, loading state in health check, faster polling

Agent-Logs-Url: https://github.com/JezzWTF/vibepod/sessions/3c05c740-b0a3-497d-88f1-dfa63121424d

Co-authored-by: LyAhn <27559362+LyAhn@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2026-04-27 16:00:53 +00:00
committed by GitHub
parent 3974a4cf69
commit 11ffc7df7c
8 changed files with 546 additions and 46 deletions
+61 -22
View File
@@ -6,13 +6,16 @@ exposes a POST /generate endpoint that accepts { text, cfg_scale, inference_step
and returns a WAV audio blob.
Start with:
./start.sh
or directly:
uvicorn vibevoice_server:app --host 0.0.0.0 --port 8000
"""
import io
import logging
import threading
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional
from typing import AsyncGenerator, Literal, Optional
import numpy as np
import soundfile as sf
@@ -26,36 +29,54 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(mess
logger = logging.getLogger(__name__)
# Identifier passed to AutoProcessor/AutoModel.from_pretrained and echoed by /health.
MODEL_ID = "microsoft/VibeVoice-Realtime-0.5B"
DEFAULT_SAMPLE_RATE = 24_000 # fallback sample rate when not specified by model config
# ─── Global model state ────────────────────────────────────────────────────────
# Closed set of values that _model_status may take; /health returns it as "status".
ModelStatus = Literal["loading", "online", "error"]
# Processor and model objects; both stay None until _load_model_sync assigns them.
_processor: Optional[object] = None
_model: Optional[object] = None
# "cuda" when torch.cuda.is_available() is true at load time, otherwise "cpu".
_device: str = "cpu"
# Starts as "loading"; the loader sets "online" on success or "error" on failure.
_model_status: ModelStatus = "loading"
# str() of the exception raised during loading, or None if no failure was recorded.
_model_error: Optional[str] = None
# Held by _load_model_sync for the whole load so only one load runs at a time.
_load_lock = threading.Lock()
def _load_model_sync() -> None:
    """Load the processor and model into the module globals (blocking).

    Called from a background thread at startup. The whole load runs under
    ``_load_lock`` so concurrent callers cannot load the model twice; a caller
    that finds ``_model`` already set returns immediately.

    The processor and model are built in local variables and only published
    to the globals once every step (including ``.to(device)`` and ``.eval()``)
    has succeeded. A failure therefore never leaves a half-initialised
    ``_model`` behind, and a later call can retry the load.

    Outcome is reported through module state rather than by raising:
    ``_model_status`` becomes ``"online"`` on success or ``"error"`` on
    failure, in which case ``_model_error`` holds ``str(exc)``.
    """
    global _processor, _model, _device, _model_status, _model_error

    with _load_lock:
        if _model is not None:
            return

        device = "cuda" if torch.cuda.is_available() else "cpu"
        _device = device
        logger.info("Loading %s on %s", MODEL_ID, device)

        try:
            processor = AutoProcessor.from_pretrained(MODEL_ID)
            model = AutoModel.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            )
            model = model.to(device)
            model.eval()
        except Exception as exc:
            # Top-level boundary of the loader thread: record the failure so
            # /health and /generate can report it instead of crashing silently.
            _model_status = "error"
            _model_error = str(exc)
            logger.exception("Failed to load model: %s", exc)
            return

        # Publish the objects before flipping the status, so any reader that
        # observes "online" also observes a usable processor and model.
        _processor = processor
        _model = model
        _model_error = None
        _model_status = "online"
        logger.info("Model loaded successfully on %s.", device)


def _load_model() -> None:
    """Backward-compatible name for ``_load_model_sync`` (blocks until done)."""
    _load_model_sync()
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Application lifespan hook: kick off model loading without blocking startup.

    The model is loaded by ``_load_model_sync`` in a daemon thread named
    ``"model-loader"``, so the server can answer health-check requests
    immediately (status="loading") instead of blocking startup for the full
    model download/load time. No synchronous load is performed here.
    """
    loader = threading.Thread(
        target=_load_model_sync,
        daemon=True,  # do not keep the process alive if the server exits mid-load
        name="model-loader",
    )
    loader.start()
    yield
@@ -81,8 +102,16 @@ class GenerateRequest(BaseModel):
@app.get("/health")
async def health() -> dict:
    """
    Liveness / readiness probe used by the Next.js /api/health route.

    Reports the current model state rather than a fixed "online" value, so
    callers can distinguish a model that is still loading from one that
    failed to load.

    Returns:
        { status: "loading" | "online" | "error", model: str, message?: str }
        ``message`` is included only when a load error has been recorded.
    """
    body: dict = {"status": _model_status, "model": MODEL_ID}
    if _model_error:
        body["message"] = _model_error
    return body
@app.post("/generate")
@@ -90,8 +119,16 @@ async def generate(req: GenerateRequest) -> StreamingResponse:
"""
Generate speech from text and return a WAV audio stream.
"""
if _model is None or _processor is None:
raise HTTPException(status_code=503, detail="Model not loaded yet — please retry in a moment.")
if _model_status == "loading":
raise HTTPException(
status_code=503,
detail="Model is still loading — please retry in a moment.",
)
if _model_status == "error" or _model is None or _processor is None:
raise HTTPException(
status_code=503,
detail=f"Model failed to load: {_model_error or 'unknown error'}",
)
logger.info(
"Generating audio for %d chars (cfg=%.1f, steps=%d)",
@@ -113,7 +150,8 @@ async def generate(req: GenerateRequest) -> StreamingResponse:
# output is typically a tensor of shape (1, num_samples) or (num_samples,)
audio_array = output.squeeze().cpu().numpy()
# Normalise to [-1, 1] float32 for WAV
# Normalise to [-1, 1] float32 for WAV.
# astype() may copy the array, but we need float32 for soundfile — this is intentional.
if audio_array.dtype != np.float32:
audio_array = audio_array.astype(np.float32)
peak = np.abs(audio_array).max()
@@ -124,7 +162,7 @@ async def generate(req: GenerateRequest) -> StreamingResponse:
sample_rate: int = (
getattr(_model.config, "sampling_rate", None)
or getattr(_model.config, "sample_rate", None)
or 24_000
or DEFAULT_SAMPLE_RATE
)
buf = io.BytesIO()
@@ -148,3 +186,4 @@ async def generate(req: GenerateRequest) -> StreamingResponse:
logger.exception("Generation failed: %s", exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc