feat: add studio roadmap and streaming cleanup

This commit is contained in:
2026-04-28 00:09:15 +01:00
parent 11ffc7df7c
commit 34ec879cdb
45 changed files with 5899 additions and 2659 deletions
+57
View File
@@ -0,0 +1,57 @@
#!/usr/bin/env python3
"""
Download microsoft/VibeVoice-Realtime-0.5B to the local HuggingFace cache.
Run once before starting the server:
python download_model.py
Set HF_HOME or HUGGINGFACE_HUB_CACHE to control where the model is stored.
Set HF_TOKEN (or HUGGINGFACE_TOKEN) if you need an access token.
"""
import os
import sys
import time
MODEL_ID = "microsoft/VibeVoice-Realtime-0.5B"
# Patterns that are not needed for PyTorch inference
_IGNORE = [
"*.msgpack",
"flax_model*",
"tf_model*",
"rust_model*",
"*.ot",
]
def download() -> str:
    """Download (or verify) the model snapshot in the local HuggingFace cache.

    The access token is read from ``HF_TOKEN`` or, failing that,
    ``HUGGINGFACE_TOKEN``; an unset or empty value means no token is sent.

    Returns:
        The local path returned by ``snapshot_download``.

    Exits the process with status 1 if ``huggingface_hub`` is not installed.
    """
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        print(
            "ERROR: huggingface_hub is not installed.\n"
            "Run: pip install huggingface_hub",
            file=sys.stderr,
        )
        sys.exit(1)

    token: str | None = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")

    print(f"Checking / downloading model: {MODEL_ID}")
    print("(This may take several minutes on first run — the model is ~1 GB)")

    # Use the monotonic clock for the duration: the wall clock can be adjusted
    # (NTP, DST) while a long download is in flight.
    start = time.monotonic()
    cache_path = snapshot_download(
        repo_id=MODEL_ID,
        ignore_patterns=_IGNORE,
        # ``or None`` turns an empty-string token into "no token".
        token=token or None,
    )
    elapsed = time.monotonic() - start

    print(f"Model ready in {elapsed:.1f}s -> {cache_path}")
    return cache_path
if __name__ == "__main__":
download()
+34
View File
@@ -0,0 +1,34 @@
[project]
name = "vibepod-server"
version = "0.1.0"
description = "VibePod TTS Server — VibeVoice FastAPI backend"
requires-python = ">=3.10"
dependencies = [
# torch is listed explicitly so uv pulls the CUDA wheel (see [tool.uv.sources]).
# To switch back to CPU-only, remove the [tool.uv.sources] torch entry below.
"torch>=2.0.0",
# VibeVoice custom model + processor classes (not yet in upstream transformers)
"vibevoice @ git+https://github.com/microsoft/VibeVoice.git",
# Exact version required by vibevoice's streaming TTS module
"transformers==4.51.3",
"fastapi>=0.111.0",
"uvicorn[standard]>=0.29.0",
"soundfile>=0.12.1",
"pydantic>=2.7.0",
"huggingface_hub>=0.23.0",
]
# No build-system — this is a scripts project, not an installable package.
# Lock file is committed so installs are reproducible.
# Run `uv lock --upgrade` to bump dependencies.
[[tool.uv.index]]
name = "pytorch-cu124"
url = "https://download.pytorch.org/whl/cu124"
explicit = true
[tool.uv.sources]
# Pull torch from the PyTorch CUDA 12.4 index instead of PyPI's CPU-only wheel.
# CUDA 12.4 runs on any driver >= 525.60 (RTX 30/40 series all qualify).
# To use CPU instead: remove this block and run `uv sync --reinstall-package torch`.
torch = { index = "pytorch-cu124" }
+40
View File
@@ -0,0 +1,40 @@
#!/usr/bin/env bash
# VibePod TTS server — start script
# Syncs the uv environment, then launches uvicorn; the server process itself downloads the model on first run.
# Prerequisite: uv must be installed (https://docs.astral.sh/uv/getting-started/installation/)
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"
echo "================================================"
echo " VibePod TTS Server"
echo "================================================"
# 1. Check uv is available
if ! command -v uv &>/dev/null; then
echo ""
echo "ERROR: uv is not installed."
echo "Install it first:"
echo " Windows: winget install astral-sh.uv"
echo " macOS/Linux: curl -LsSf https://astral.sh/uv/install.sh | sh"
echo ""
exit 1
fi
# 2. Sync Python environment (creates .venv on first run, no-op afterwards)
echo ""
echo "--> Syncing Python environment..."
uv sync
# 3. Start the server — model download + load happens inside the server process
# so the /health endpoint is reachable immediately and can report progress.
echo ""
echo "--> Starting uvicorn on http://0.0.0.0:8000"
export PYTHONUTF8=1
exec uv run uvicorn vibevoice_server:app \
--host 0.0.0.0 \
--port 8000 \
--log-level info \
"$@"
+3321
View File
File diff suppressed because it is too large Load Diff
+370
View File
@@ -0,0 +1,370 @@
"""
VibePod — VibeVoice FastAPI TTS Server
Startup sequence (background thread):
1. Download model weights if not cached -> status: downloading
2. Download voice preset .pt files -> status: loading
3. Load processor + model into memory -> status: loading
4. Pre-load all voice tensors -> status: loading
-> Server ready -> status: online
Generation flow:
POST /generate -> SSE stream of audio_chunk events (base64 float32 PCM),
ends with {type:"complete"}
"""
import asyncio
import base64
import copy
import functools
import json
import logging
import os
import threading
import time
import urllib.request
from contextlib import asynccontextmanager
from pathlib import Path
from typing import AsyncGenerator, Literal, Optional
import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, field_validator
from tqdm import tqdm as _BaseTqdm
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
MODEL_ID = "microsoft/VibeVoice-Realtime-0.5B"
SAMPLE_RATE = 24_000
VOICES_DIR = Path(__file__).parent / "voices" / "streaming_model"
VOICE_BASE_URL = (
"https://raw.githubusercontent.com/microsoft/VibeVoice/main"
"/demo/voices/streaming_model"
)
EN_VOICES: dict[str, str] = {
"carter": "en-Carter_man.pt",
"davis": "en-Davis_man.pt",
"emma": "en-Emma_woman.pt",
"frank": "en-Frank_man.pt",
"grace": "en-Grace_woman.pt",
"mike": "en-Mike_man.pt",
}
DEFAULT_SPEAKER = "carter"
_IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.ot"]
# ── Global state ────────────────────────────────────────────────────────────────
ModelStatus = Literal["downloading", "loading", "online", "error"]
_processor = None
_model = None
_device: str = "cpu"
_model_status: ModelStatus = "loading"
_model_error: Optional[str] = None
_voice_presets: dict[str, object] = {}
_load_lock = threading.Lock()
_generation_lock = asyncio.Lock()
# Download progress (files downloaded so far)
_dl_progress: dict[str, int] = {"done": 0, "total": 0}
# ── Progress-tracking tqdm (for model file downloads) ──────────────────────────
def _make_dl_tqdm() -> type:
    """Build a tqdm subclass that mirrors small-total bars into ``_dl_progress``."""

    def _is_tracked(total: object) -> bool:
        # Only bars whose total is a number strictly between 0 and 10_000 are
        # mirrored -- presumably this singles out the per-file counter rather
        # than byte-level bars (TODO confirm against huggingface_hub).
        return isinstance(total, (int, float)) and 0 < total < 10_000

    class _DlTqdm(_BaseTqdm):
        def __init__(self, *args: object, **kwargs: object) -> None:
            super().__init__(*args, **kwargs)
            if not _is_tracked(self.total):
                return
            # A new tracked bar resets the shared counters.
            _dl_progress["total"] = int(self.total)
            _dl_progress["done"] = 0

        def update(self, n: int = 1) -> "bool | None":
            outcome = super().update(n)
            if _is_tracked(self.total):
                _dl_progress["done"] = int(self.n)
            return outcome

    return _DlTqdm
# ── Model / voice helpers ───────────────────────────────────────────────────────
def _is_model_cached() -> bool:
    """Return True when a local-files-only snapshot lookup for MODEL_ID succeeds.

    Any exception (including a missing ``huggingface_hub``) counts as
    "not cached".
    """
    try:
        from huggingface_hub import snapshot_download

        snapshot_download(
            MODEL_ID,
            local_files_only=True,
            ignore_patterns=_IGNORE_PATTERNS,
        )
    except Exception:
        return False
    else:
        return True
def _download_model() -> None:
    """Download the MODEL_ID snapshot, reporting progress via ``_make_dl_tqdm``."""
    from huggingface_hub import snapshot_download

    # HF_TOKEN wins over HUGGINGFACE_TOKEN; empty values fall through to None.
    hf_token: Optional[str] = (
        os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
    )
    progress_cls = _make_dl_tqdm()

    logger.info("Model not cached — downloading %s...", MODEL_ID)
    snapshot_download(
        repo_id=MODEL_ID,
        ignore_patterns=_IGNORE_PATTERNS,
        token=hf_token or None,
        tqdm_class=progress_cls,
    )
    logger.info("Model download complete.")
def _download_voices() -> None:
    """Fetch any missing English voice preset files into ``VOICES_DIR``.

    Each preset is first written to a sibling ``<name>.part`` file and only
    renamed to its final name once the transfer has finished.  Without this,
    an interrupted download would leave a truncated file at the final path,
    and the ``exists()`` check below would treat it as complete on every
    later start.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises on a failed transfer;
        the partial file is removed before the exception propagates.
    """
    VOICES_DIR.mkdir(parents=True, exist_ok=True)
    for filename in EN_VOICES.values():
        dest = VOICES_DIR / filename
        if dest.exists():
            continue

        url = f"{VOICE_BASE_URL}/{filename}"
        partial = dest.with_name(dest.name + ".part")
        logger.info("Downloading voice preset: %s", filename)
        try:
            urllib.request.urlretrieve(url, partial)
            # Same-directory rename: the final name only ever refers to a
            # fully downloaded file.
            partial.replace(dest)
        finally:
            # No-op after a successful rename; cleans up after a failure.
            partial.unlink(missing_ok=True)
    logger.info("Voice presets ready.")
# ── Background model loader ─────────────────────────────────────────────────────
def _load_model_sync() -> None:
    """Download (if needed) and load the processor, model and voice presets.

    Started on a background thread by ``lifespan``.  Progress is published
    through the module globals: ``_model_status`` moves through
    "downloading" -> "loading" -> "online", or ends at "error" with the
    message stored in ``_model_error``.  Exceptions are logged, never raised.

    The call is a no-op when ``_model`` is already set.
    """
    global _processor, _model, _device, _model_status, _model_error, _voice_presets
    with _load_lock:
        if _model is not None:
            return
        try:
            if not _is_model_cached():
                _model_status = "downloading"
                _download_model()
            _model_status = "loading"
            _download_voices()

            _device = "cuda" if torch.cuda.is_available() else "cpu"
            load_dtype = torch.bfloat16 if _device == "cuda" else torch.float32
            attn_impl = "flash_attention_2" if _device == "cuda" else "sdpa"

            logger.info("Loading processor...")
            from vibevoice.processor.vibevoice_streaming_processor import (
                VibeVoiceStreamingProcessor,
            )

            _processor = VibeVoiceStreamingProcessor.from_pretrained(MODEL_ID)

            logger.info("Loading model on %s...", _device)
            from vibevoice.modular.modeling_vibevoice_streaming_inference import (
                VibeVoiceStreamingForConditionalGenerationInference,
            )

            try:
                _model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                    MODEL_ID,
                    torch_dtype=load_dtype,
                    device_map=_device,
                    attn_implementation=attn_impl,
                )
            except Exception:
                # Only a flash_attention_2 attempt has something to fall back
                # to.  If the first attempt already used sdpa (the CPU path),
                # retrying with identical arguments would just repeat the
                # failure under a misleading warning, so let it propagate.
                if attn_impl == "sdpa":
                    raise
                # exc_info keeps the real cause in the log.
                logger.warning(
                    "flash_attention_2 unavailable, falling back to sdpa",
                    exc_info=True,
                )
                _model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                    MODEL_ID,
                    torch_dtype=load_dtype,
                    device_map=_device,
                    attn_implementation="sdpa",
                )

            _model.eval()
            _model.set_ddpm_inference_steps(num_steps=10)

            for name, filename in EN_VOICES.items():
                path = VOICES_DIR / filename
                if path.exists():
                    # SECURITY NOTE(review): weights_only=False lets torch.load
                    # unpickle arbitrary objects.  These files are fetched from
                    # VOICE_BASE_URL, so this trusts that source completely.
                    # Confirm whether the presets can load with
                    # weights_only=True.
                    _voice_presets[name] = torch.load(
                        path, map_location=_device, weights_only=False
                    )

            _model_status = "online"
            logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys()))
        except Exception as exc:
            _model_status = "error"
            _model_error = str(exc)
            logger.exception("Failed to initialise model: %s", exc)
# ── FastAPI app ─────────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Start the model loader on a daemon thread, then yield to the app.

    Nothing is awaited before the ``yield``, so startup does not wait for the
    model to finish loading.
    """
    threading.Thread(
        target=_load_model_sync,
        name="model-loader",
        daemon=True,
    ).start()
    yield
app = FastAPI(title="VibePod TTS Server", version="0.1.0", lifespan=lifespan)
# ── Schemas ─────────────────────────────────────────────────────────────────────
class GenerateRequest(BaseModel):
    # Request body for POST /generate.
    # (Documented with comments rather than a docstring so the generated
    # schema description stays exactly as before.)

    # Text to synthesise; stored with surrounding whitespace removed.
    text: str = Field(..., min_length=1, max_length=10_000)
    # Voice key; stored lower-cased and stripped.
    speaker: str = Field(default=DEFAULT_SPEAKER)
    cfg_scale: float = Field(default=1.5, ge=0.5, le=4.0)
    inference_steps: int = Field(default=10, ge=5, le=20)

    @field_validator("text")
    @classmethod
    def text_not_blank(cls, v: str) -> str:
        # Reject whitespace-only input; otherwise keep the stripped form.
        stripped = v.strip()
        if stripped:
            return stripped
        raise ValueError("text must not be blank")

    @field_validator("speaker")
    @classmethod
    def normalise_speaker(cls, v: str) -> str:
        lowered = v.lower()
        return lowered.strip()
# ── Endpoints ───────────────────────────────────────────────────────────────────
@app.get("/health")
async def health() -> dict:
    # Report loader status, the model id and the voices loaded so far.
    # "progress" is included only while downloading; "message" only when an
    # error string has been recorded.
    report: dict = {
        "status": _model_status,
        "model": MODEL_ID,
        "voices": list(_voice_presets),
    }

    if _model_status == "downloading":
        report["progress"] = {key: _dl_progress[key] for key in ("done", "total")}

    if _model_error:
        report["message"] = _model_error

    return report
def _sync_generate(
    req: GenerateRequest,
    streamer: Optional[object] = None,
    cancel_event: Optional[threading.Event] = None,
) -> str:
    """Run one blocking synthesis pass and return the speaker key that was used.

    Meant for a worker thread (the endpoint dispatches it through
    ``run_in_executor``); do not call it on the event loop.  ``streamer``, if
    given, is handed to the model as ``audio_streamer``.

    Raises:
        RuntimeError: ``cancel_event`` was already set on entry.
        ValueError: the model produced no speech output.
    """
    # The cancel flag is only consulted here, before any work starts.
    if cancel_event and cancel_event.is_set():
        raise RuntimeError("Generation cancelled.")

    # Unknown speaker names fall back to the default voice.
    if req.speaker in _voice_presets:
        speaker = req.speaker
    else:
        speaker = DEFAULT_SPEAKER
    preset = copy.deepcopy(_voice_presets[speaker])

    _model.set_ddpm_inference_steps(num_steps=req.inference_steps)

    model_inputs = _processor.process_input_with_cached_prompt(
        text=req.text,
        cached_prompt=preset,
        padding=True,
        return_tensors="pt",
        return_attention_mask=True,
    )
    # Move every tensor input onto the device the model was loaded on.
    for key, value in model_inputs.items():
        if torch.is_tensor(value):
            model_inputs[key] = value.to(_device)

    result = _model.generate(
        **model_inputs,
        max_new_tokens=None,
        cfg_scale=req.cfg_scale,
        tokenizer=_processor.tokenizer,
        generation_config={"do_sample": False},
        verbose=True,
        # A second, independent copy of the preset is passed here.
        all_prefilled_outputs=copy.deepcopy(preset),
        audio_streamer=streamer,
    )

    speech = result.speech_outputs
    if not speech or speech[0] is None:
        raise ValueError("Model returned no audio output.")
    return speaker
def _sse(event: dict) -> str:
return f"data: {json.dumps(event)}\n\n"
@app.post("/generate")
async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
    # Stream synthesised audio for ``req`` as Server-Sent Events.
    #
    # Events emitted on the stream:
    #   {"type": "audio_chunk", "data": <base64 of the chunk's float32 bytes>}
    #   {"type": "complete", "elapsed": <seconds>, "speaker": <voice used>}
    #   {"type": "error", "message": ...} / {"type": "cancelled"}
    #
    # Responds 503 when the model is not online or a generation is already
    # holding the lock.
    # (Documented with comments rather than a docstring so the OpenAPI
    # description stays exactly as before.)
    if _model_status != "online":
        detail = {
            "downloading": "Model is downloading — please wait.",
            "loading": "Model is loading into memory — please wait.",
            "error": f"Model failed to load: {_model_error or 'unknown error'}",
        }.get(_model_status, "Server not ready.")
        raise HTTPException(status_code=503, detail=detail)

    if _generation_lock.locked():
        raise HTTPException(status_code=503, detail="Server is already generating audio. Please wait.")

    async def event_stream() -> AsyncGenerator[str, None]:
        from vibevoice.modular.streamer import AsyncAudioStreamer

        start = time.monotonic()
        streamer = AsyncAudioStreamer(batch_size=1)
        cancel_event = threading.Event()

        async with _generation_lock:
            # We are inside a coroutine, so the running loop is the right one.
            loop = asyncio.get_running_loop()
            audio_queue = streamer.audio_queues[0]
            future = loop.run_in_executor(
                None, functools.partial(_sync_generate, req, streamer, cancel_event)
            )

            def _wake_consumer_on_failure(finished: "asyncio.Future[str]") -> None:
                # If the worker raises before the streamer queues its stop
                # signal, the drain loop below would otherwise sit in get()
                # for the full timeout, hold the generation lock, and report
                # "timed out" instead of the real error.  Queue the stop
                # sentinel so the loop exits and ``await future`` surfaces the
                # exception.  Done-callbacks run on the event loop, so
                # touching the asyncio queue here is safe.
                if finished.cancelled() or finished.exception() is None:
                    return
                audio_queue.put_nowait(None)

            future.add_done_callback(_wake_consumer_on_failure)

            # Drain audio chunks as they arrive from the diffusion head.
            # stop_signal=None is the default sentinel that ends the queue.
            while True:
                try:
                    chunk = await asyncio.wait_for(audio_queue.get(), timeout=120.0)
                except asyncio.TimeoutError:
                    cancel_event.set()
                    # NOTE(review): cancel() cannot interrupt a worker thread
                    # that is already running, and the lock is released when
                    # we return -- confirm inference cannot overlap with the
                    # next request.
                    future.cancel()
                    yield _sse({"type": "error", "message": "Generation timed out"})
                    return

                if await request.is_disconnected():
                    cancel_event.set()
                    future.cancel()
                    logger.info("Generation client disconnected; stream cancelled.")
                    return

                if chunk is None:  # stop signal
                    break

                pcm_b64 = base64.b64encode(
                    chunk.detach().cpu().float().numpy().tobytes()
                ).decode()
                yield _sse({"type": "audio_chunk", "data": pcm_b64})

            try:
                speaker = await future
            except asyncio.CancelledError:
                logger.info("Generation cancelled.")
                yield _sse({"type": "cancelled"})
                return
            except Exception as exc:
                logger.exception("Generation failed: %s", exc)
                yield _sse({"type": "error", "message": str(exc)})
                return

            elapsed = round(time.monotonic() - start, 1)
            logger.info("Generation complete in %.1fs", elapsed)
            yield _sse({"type": "complete", "elapsed": elapsed, "speaker": speaker})

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )