mirror of
https://github.com/JezzWTF/vibepod.git
synced 2026-06-13 03:58:07 +00:00
Improve dev startup: model download script, loading state in health check, faster polling
Agent-Logs-Url: https://github.com/JezzWTF/vibepod/sessions/3c05c740-b0a3-497d-88f1-dfa63121424d Co-authored-by: LyAhn <27559362+LyAhn@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
3974a4cf69
commit
11ffc7df7c
@@ -6,13 +6,16 @@ exposes a POST /generate endpoint that accepts { text, cfg_scale, inference_step
|
||||
and returns a WAV audio blob.
|
||||
|
||||
Start with:
|
||||
./start.sh
|
||||
or directly:
|
||||
uvicorn vibevoice_server:app --host 0.0.0.0 --port 8000
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import threading
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator, Optional
|
||||
from typing import AsyncGenerator, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
@@ -26,36 +29,54 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(mess
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MODEL_ID = "microsoft/VibeVoice-Realtime-0.5B"
|
||||
DEFAULT_SAMPLE_RATE = 24_000 # fallback sample rate when not specified by model config
|
||||
|
||||
# ─── Global model state ────────────────────────────────────────────────────────
|
||||
ModelStatus = Literal["loading", "online", "error"]
|
||||
|
||||
_processor: Optional[object] = None
|
||||
_model: Optional[object] = None
|
||||
_device: str = "cpu"
|
||||
_model_status: ModelStatus = "loading"
|
||||
_model_error: Optional[str] = None
|
||||
_load_lock = threading.Lock()
|
||||
|
||||
|
||||
def _load_model() -> None:
|
||||
global _processor, _model, _device
|
||||
def _load_model_sync() -> None:
|
||||
"""Load the model synchronously. Called from a background thread at startup."""
|
||||
global _processor, _model, _device, _model_status, _model_error
|
||||
|
||||
if _model is not None:
|
||||
return
|
||||
with _load_lock:
|
||||
if _model is not None:
|
||||
return
|
||||
|
||||
_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
logger.info("Loading %s on %s …", MODEL_ID, _device)
|
||||
_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
logger.info("Loading %s on %s …", MODEL_ID, _device)
|
||||
|
||||
_processor = AutoProcessor.from_pretrained(MODEL_ID)
|
||||
_model = AutoModel.from_pretrained(
|
||||
MODEL_ID,
|
||||
torch_dtype=torch.float16 if _device == "cuda" else torch.float32,
|
||||
)
|
||||
_model = _model.to(_device)
|
||||
_model.eval()
|
||||
try:
|
||||
_processor = AutoProcessor.from_pretrained(MODEL_ID)
|
||||
_model = AutoModel.from_pretrained(
|
||||
MODEL_ID,
|
||||
torch_dtype=torch.float16 if _device == "cuda" else torch.float32,
|
||||
)
|
||||
_model = _model.to(_device)
|
||||
_model.eval()
|
||||
|
||||
logger.info("Model loaded successfully.")
|
||||
_model_status = "online"
|
||||
logger.info("Model loaded successfully on %s.", _device)
|
||||
except Exception as exc:
|
||||
_model_status = "error"
|
||||
_model_error = str(exc)
|
||||
logger.exception("Failed to load model: %s", exc)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
_load_model()
|
||||
# Start model loading in a background thread so the server answers
|
||||
# health-check requests immediately (status="loading") rather than
|
||||
# blocking startup for the full model download/load time.
|
||||
thread = threading.Thread(target=_load_model_sync, daemon=True, name="model-loader")
|
||||
thread.start()
|
||||
yield
|
||||
|
||||
|
||||
@@ -81,8 +102,16 @@ class GenerateRequest(BaseModel):
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> dict:
|
||||
"""Liveness probe used by the Next.js /api/health route."""
|
||||
return {"status": "online", "model": MODEL_ID}
|
||||
"""
|
||||
Liveness / readiness probe used by the Next.js /api/health route.
|
||||
|
||||
Returns:
|
||||
{ status: "loading" | "online" | "error", model: str, message?: str }
|
||||
"""
|
||||
body: dict = {"status": _model_status, "model": MODEL_ID}
|
||||
if _model_error:
|
||||
body["message"] = _model_error
|
||||
return body
|
||||
|
||||
|
||||
@app.post("/generate")
|
||||
@@ -90,8 +119,16 @@ async def generate(req: GenerateRequest) -> StreamingResponse:
|
||||
"""
|
||||
Generate speech from text and return a WAV audio stream.
|
||||
"""
|
||||
if _model is None or _processor is None:
|
||||
raise HTTPException(status_code=503, detail="Model not loaded yet — please retry in a moment.")
|
||||
if _model_status == "loading":
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Model is still loading — please retry in a moment.",
|
||||
)
|
||||
if _model_status == "error" or _model is None or _processor is None:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=f"Model failed to load: {_model_error or 'unknown error'}",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Generating audio for %d chars (cfg=%.1f, steps=%d)",
|
||||
@@ -113,7 +150,8 @@ async def generate(req: GenerateRequest) -> StreamingResponse:
|
||||
# output is typically a tensor of shape (1, num_samples) or (num_samples,)
|
||||
audio_array = output.squeeze().cpu().numpy()
|
||||
|
||||
# Normalise to [-1, 1] float32 for WAV
|
||||
# Normalise to [-1, 1] float32 for WAV.
|
||||
# astype() may copy the array, but we need float32 for soundfile — this is intentional.
|
||||
if audio_array.dtype != np.float32:
|
||||
audio_array = audio_array.astype(np.float32)
|
||||
peak = np.abs(audio_array).max()
|
||||
@@ -124,7 +162,7 @@ async def generate(req: GenerateRequest) -> StreamingResponse:
|
||||
sample_rate: int = (
|
||||
getattr(_model.config, "sampling_rate", None)
|
||||
or getattr(_model.config, "sample_rate", None)
|
||||
or 24_000
|
||||
or DEFAULT_SAMPLE_RATE
|
||||
)
|
||||
|
||||
buf = io.BytesIO()
|
||||
@@ -148,3 +186,4 @@ async def generate(req: GenerateRequest) -> StreamingResponse:
|
||||
logger.exception("Generation failed: %s", exc)
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user