mirror of
https://github.com/JezzWTF/vibepod.git
synced 2026-06-01 15:22:14 +00:00
feat: add studio roadmap and streaming cleanup
This commit is contained in:
@@ -0,0 +1,57 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Download microsoft/VibeVoice-Realtime-0.5B to the local HuggingFace cache.
|
||||
|
||||
Run once before starting the server:
|
||||
python download_model.py
|
||||
|
||||
Set HF_HOME or HUGGINGFACE_HUB_CACHE to control where the model is stored.
|
||||
Set HF_TOKEN (or HUGGINGFACE_TOKEN) if you need an access token.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
# Hugging Face repository that holds the model weights.
MODEL_ID = "microsoft/VibeVoice-Realtime-0.5B"

# Glob patterns for files that are not needed for PyTorch inference;
# passed to snapshot_download so they are skipped.
_IGNORE = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.ot"]
|
||||
|
||||
|
||||
def download() -> str:
    """Download the model snapshot into the local Hugging Face cache.

    The access token is read from ``HF_TOKEN`` and, if that is unset or
    empty, from ``HUGGINGFACE_TOKEN``; when neither holds a value no token
    is sent.

    Returns:
        The path returned by ``snapshot_download``.

    Exits the process with status 1 (after printing a hint to stderr) when
    ``huggingface_hub`` cannot be imported.
    """
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        print(
            "ERROR: huggingface_hub is not installed.\n"
            "Run: pip install huggingface_hub",
            file=sys.stderr,
        )
        sys.exit(1)

    token: str | None = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")

    print(f"Checking / downloading model: {MODEL_ID}")
    print("(This may take several minutes on first run — the model is ~1 GB)")
    # Use the monotonic clock for the duration so that a wall-clock
    # adjustment during a long download cannot skew the reported time.
    start = time.monotonic()

    cache_path = snapshot_download(
        repo_id=MODEL_ID,
        ignore_patterns=_IGNORE,
        # Normalise an empty-string token to None.
        token=token or None,
    )

    elapsed = time.monotonic() - start
    print(f"Model ready in {elapsed:.1f}s -> {cache_path}")
    return cache_path
|
||||
|
||||
|
||||
# Script entry point: run the download only when executed directly,
# never on import.
if __name__ == "__main__":
    download()
|
||||
@@ -0,0 +1,34 @@
|
||||
[project]
|
||||
name = "vibepod-server"
|
||||
version = "0.1.0"
|
||||
description = "VibePod TTS Server — VibeVoice FastAPI backend"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
# torch is listed explicitly so uv pulls the CUDA wheel (see [tool.uv.sources]).
|
||||
# To switch back to CPU-only, remove the [tool.uv.sources] torch entry below.
|
||||
"torch>=2.0.0",
|
||||
# VibeVoice custom model + processor classes (not yet in upstream transformers)
|
||||
"vibevoice @ git+https://github.com/microsoft/VibeVoice.git",
|
||||
# Exact version required by vibevoice's streaming TTS module
|
||||
"transformers==4.51.3",
|
||||
"fastapi>=0.111.0",
|
||||
"uvicorn[standard]>=0.29.0",
|
||||
"soundfile>=0.12.1",
|
||||
"pydantic>=2.7.0",
|
||||
"huggingface_hub>=0.23.0",
|
||||
]
|
||||
|
||||
# No build-system — this is a scripts project, not an installable package.
|
||||
# Lock file is committed so installs are reproducible.
|
||||
# Run `uv lock --upgrade` to bump dependencies.
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu124"
|
||||
url = "https://download.pytorch.org/whl/cu124"
|
||||
explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
# Pull torch from the PyTorch CUDA 12.4 index instead of PyPI's CPU-only wheel.
|
||||
# CUDA 12.4 runs on any driver >= 525.60 (RTX 30/40 series all qualify).
|
||||
# To use CPU instead: remove this block and run `uv sync --reinstall-package torch`.
|
||||
torch = { index = "pytorch-cu124" }
|
||||
Executable
+40
@@ -0,0 +1,40 @@
|
||||
#!/usr/bin/env bash
# VibePod TTS server — start script
# Syncs the uv environment, then launches uvicorn; the model download and
# load happen inside the server process itself.
# Prerequisite: uv must be installed (https://docs.astral.sh/uv/getting-started/installation/)

set -euo pipefail

# Run from the directory containing this script, wherever it was invoked from.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"

cat <<'BANNER'
================================================
 VibePod TTS Server
================================================
BANNER

# 1. Check uv is available
if ! command -v uv >/dev/null 2>&1; then
    cat <<'MISSING_UV'

ERROR: uv is not installed.
Install it first:
 Windows: winget install astral-sh.uv
 macOS/Linux: curl -LsSf https://astral.sh/uv/install.sh | sh

MISSING_UV
    exit 1
fi

# 2. Sync Python environment (creates .venv on first run, no-op afterwards)
printf '\n%s\n' "--> Syncing Python environment..."
uv sync

# 3. Start the server — model download + load happens inside the server process
#    so the /health endpoint is reachable immediately and can report progress.
printf '\n%s\n' "--> Starting uvicorn on http://0.0.0.0:8000"
export PYTHONUTF8=1
# exec replaces this shell with the server; extra CLI arguments are forwarded.
exec uv run uvicorn vibevoice_server:app \
    --host 0.0.0.0 \
    --port 8000 \
    --log-level info \
    "$@"
|
||||
Generated
+3321
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,370 @@
|
||||
"""
|
||||
VibePod — VibeVoice FastAPI TTS Server
|
||||
|
||||
Startup sequence (background thread):
|
||||
1. Download model weights if not cached -> status: downloading
|
||||
2. Download voice preset .pt files -> status: loading
|
||||
3. Load processor + model into memory -> status: loading
|
||||
4. Pre-load all voice tensors -> status: loading
|
||||
-> Server ready -> status: online
|
||||
|
||||
Generation flow:
|
||||
POST /generate -> SSE stream of audio_chunk events (base64 float32 PCM),
|
||||
ends with {type:"complete"}
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import copy
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import urllib.request
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator, Literal, Optional
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from tqdm import tqdm as _BaseTqdm
|
||||
|
||||
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger(__name__)

# Hugging Face repository of the model, and the audio sample rate constant.
MODEL_ID = "microsoft/VibeVoice-Realtime-0.5B"
SAMPLE_RATE = 24_000

# Local directory for voice preset files and the remote location they are
# fetched from.
VOICES_DIR = Path(__file__).parent.joinpath("voices", "streaming_model")
VOICE_BASE_URL = (
    "https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model"
)

# Speaker key -> preset filename.
EN_VOICES: dict[str, str] = {
    "carter": "en-Carter_man.pt",
    "davis": "en-Davis_man.pt",
    "emma": "en-Emma_woman.pt",
    "frank": "en-Frank_man.pt",
    "grace": "en-Grace_woman.pt",
    "mike": "en-Mike_man.pt",
}
DEFAULT_SPEAKER = "carter"

# Snapshot files skipped when downloading / checking the model.
_IGNORE_PATTERNS = [
    "*.msgpack",
    "flax_model*",
    "tf_model*",
    "rust_model*",
    "*.ot",
]

# ── Global state ────────────────────────────────────────────────────────────────

ModelStatus = Literal["downloading", "loading", "online", "error"]

_processor = None
_model = None
_device: str = "cpu"
_model_status: ModelStatus = "loading"
_model_error: Optional[str] = None
_voice_presets: dict[str, object] = {}
_load_lock = threading.Lock()        # guards the one-time model load
_generation_lock = asyncio.Lock()    # held while a /generate stream is active

# Download progress (files downloaded so far)
_dl_progress: dict[str, int] = dict(done=0, total=0)
|
||||
|
||||
|
||||
|
||||
# ── Progress-tracking tqdm (for model file downloads) ──────────────────────────
|
||||
|
||||
def _make_dl_tqdm() -> type:
    """Build a tqdm subclass that mirrors its progress into ``_dl_progress``.

    Only bars whose ``total`` is a number strictly between 0 and 10 000 are
    mirrored; presumably this separates the file-count bar from per-file
    byte bars — confirm against huggingface_hub's use of ``tqdm_class``.
    """

    def _in_tracked_range(total: object) -> bool:
        # Same test as applied on construction and on every update.
        if not isinstance(total, (int, float)):
            return False
        return 0 < total < 10_000

    class _DlTqdm(_BaseTqdm):
        def __init__(self, *args: object, **kwargs: object) -> None:
            super().__init__(*args, **kwargs)
            if _in_tracked_range(self.total):
                # A new tracked bar resets the shared counters.
                _dl_progress["total"] = int(self.total)
                _dl_progress["done"] = 0

        def update(self, n: int = 1) -> "bool | None":
            outcome = super().update(n)
            if _in_tracked_range(self.total):
                _dl_progress["done"] = int(self.n)
            return outcome

    return _DlTqdm
|
||||
|
||||
|
||||
# ── Model / voice helpers ───────────────────────────────────────────────────────
|
||||
|
||||
def _is_model_cached() -> bool:
    """Report whether the model snapshot can be resolved without network access.

    Any exception from the import or from the ``local_files_only`` lookup is
    treated as "not cached".
    """
    try:
        from huggingface_hub import snapshot_download

        snapshot_download(
            MODEL_ID,
            local_files_only=True,
            ignore_patterns=_IGNORE_PATTERNS,
        )
    except Exception:
        return False
    else:
        return True
|
||||
|
||||
|
||||
def _download_model() -> None:
    """Download the model snapshot, reporting progress via ``_dl_progress``.

    The token comes from ``HF_TOKEN`` or, failing that, ``HUGGINGFACE_TOKEN``;
    empty values count as unset.
    """
    from huggingface_hub import snapshot_download

    # First non-empty value wins; None when neither variable has one.
    candidates = (os.environ.get(var) for var in ("HF_TOKEN", "HUGGINGFACE_TOKEN"))
    hf_token: Optional[str] = next((value for value in candidates if value), None)

    progress_cls = _make_dl_tqdm()
    logger.info("Model not cached — downloading %s...", MODEL_ID)

    download_kwargs = {
        "repo_id": MODEL_ID,
        "ignore_patterns": _IGNORE_PATTERNS,
        "token": hf_token,
        "tqdm_class": progress_cls,
    }
    snapshot_download(**download_kwargs)

    logger.info("Model download complete.")
|
||||
|
||||
|
||||
def _download_voices() -> None:
    """Fetch any voice preset files that are missing from ``VOICES_DIR``.

    Each file is downloaded to a temporary ``<name>.part`` sibling and only
    renamed to its final name once the transfer has finished. Writing
    straight to the final path would let an interrupted download leave a
    truncated file behind that later runs accept as present because it
    exists.

    Raises:
        Whatever ``urllib.request.urlretrieve`` or the rename raises; the
        temporary file is removed before the exception propagates.
    """
    VOICES_DIR.mkdir(parents=True, exist_ok=True)
    for filename in EN_VOICES.values():
        dest = VOICES_DIR / filename
        if dest.exists():
            continue

        url = f"{VOICE_BASE_URL}/{filename}"
        logger.info("Downloading voice preset: %s", filename)
        partial = dest.with_name(dest.name + ".part")
        try:
            urllib.request.urlretrieve(url, partial)
            # Same-directory rename, so the final name only ever appears
            # for a completed download.
            partial.replace(dest)
        finally:
            # No-op after a successful rename; cleans up after a failure.
            partial.unlink(missing_ok=True)
    logger.info("Voice presets ready.")
|
||||
|
||||
|
||||
# ── Background model loader ─────────────────────────────────────────────────────
|
||||
|
||||
def _load_model_sync() -> None:
    """Download (if needed) and load the processor, model and voice presets.

    Blocking; started in a background thread by ``lifespan``. Progress is
    published through the module globals instead of being returned:
    ``_model_status`` moves through "downloading" / "loading" to "online",
    or to "error" with the message stored in ``_model_error``. Exceptions
    are logged and recorded, never propagated to the caller.
    """
    global _processor, _model, _device, _model_status, _model_error, _voice_presets

    with _load_lock:
        # Already loaded by an earlier call — nothing to do.
        if _model is not None:
            return

        try:
            if not _is_model_cached():
                _model_status = "downloading"
                _download_model()

            _model_status = "loading"
            _download_voices()

            _device = "cuda" if torch.cuda.is_available() else "cpu"
            load_dtype = torch.bfloat16 if _device == "cuda" else torch.float32
            attn_impl = "flash_attention_2" if _device == "cuda" else "sdpa"

            logger.info("Loading processor...")
            from vibevoice.processor.vibevoice_streaming_processor import (
                VibeVoiceStreamingProcessor,
            )

            _processor = VibeVoiceStreamingProcessor.from_pretrained(MODEL_ID)

            logger.info("Loading model on %s...", _device)
            from vibevoice.modular.modeling_vibevoice_streaming_inference import (
                VibeVoiceStreamingForConditionalGenerationInference,
            )

            try:
                _model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                    MODEL_ID,
                    torch_dtype=load_dtype,
                    device_map=_device,
                    attn_implementation=attn_impl,
                )
            except Exception:
                # Retrying only makes sense when the first attempt used a
                # different attention backend. On CPU it was already "sdpa",
                # so a retry would repeat the same failing load and log a
                # misleading warning.
                if attn_impl == "sdpa":
                    raise
                # exc_info keeps the real cause visible: any exception lands
                # here, not only a missing flash-attention install.
                logger.warning(
                    "flash_attention_2 unavailable, falling back to sdpa",
                    exc_info=True,
                )
                _model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                    MODEL_ID,
                    torch_dtype=load_dtype,
                    device_map=_device,
                    attn_implementation="sdpa",
                )

            _model.eval()
            _model.set_ddpm_inference_steps(num_steps=10)

            for name, filename in EN_VOICES.items():
                path = VOICES_DIR / filename
                if path.exists():
                    # SECURITY NOTE(review): weights_only=False unpickles
                    # arbitrary objects from files fetched over the network
                    # by _download_voices. Confirm the presets cannot be
                    # loaded with weights_only=True, or pin/verify them.
                    _voice_presets[name] = torch.load(
                        path, map_location=_device, weights_only=False
                    )

            _model_status = "online"
            logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys()))

        except Exception as exc:
            _model_status = "error"
            _model_error = str(exc)
            logger.exception("Failed to initialise model: %s", exc)
|
||||
|
||||
|
||||
# ── FastAPI app ─────────────────────────────────────────────────────────────────
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Start the model loader in a daemon thread, then hand control to the app.

    The load is not awaited, so the server accepts requests while it runs.
    Nothing is done on shutdown.
    """
    loader = threading.Thread(
        name="model-loader",
        target=_load_model_sync,
        daemon=True,
    )
    loader.start()
    yield


app = FastAPI(title="VibePod TTS Server", version="0.1.0", lifespan=lifespan)
|
||||
|
||||
|
||||
# ── Schemas ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
class GenerateRequest(BaseModel):
    """Request body accepted by ``POST /generate``."""

    # Text to synthesise: 1 to 10 000 characters, stored stripped.
    text: str = Field(..., min_length=1, max_length=10_000)
    # Speaker key, stored lower-cased and stripped.
    speaker: str = Field(default=DEFAULT_SPEAKER)
    # Forwarded to the model's generate() call as cfg_scale.
    cfg_scale: float = Field(default=1.5, ge=0.5, le=4.0)
    # Forwarded to set_ddpm_inference_steps().
    inference_steps: int = Field(default=10, ge=5, le=20)

    @field_validator("text")
    @classmethod
    def text_not_blank(cls, v: str) -> str:
        """Reject whitespace-only text and return the stripped value."""
        trimmed = v.strip()
        if trimmed:
            return trimmed
        raise ValueError("text must not be blank")

    @field_validator("speaker")
    @classmethod
    def normalise_speaker(cls, v: str) -> str:
        """Lower-case the speaker name, then strip surrounding whitespace."""
        lowered = v.lower()
        return lowered.strip()
|
||||
|
||||
|
||||
# ── Endpoints ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@app.get("/health")
async def health() -> dict:
    """Report model status, model id and the names of the loaded voices.

    Adds ``progress`` (done/total) while the model is downloading and
    ``message`` whenever an error message has been recorded.
    """
    body: dict = {
        "status": _model_status,
        "model": MODEL_ID,
        "voices": list(_voice_presets),
    }
    if _model_status == "downloading":
        body["progress"] = {field: _dl_progress[field] for field in ("done", "total")}
    if _model_error:
        body["message"] = _model_error
    return body
|
||||
|
||||
|
||||
def _sync_generate(
    req: GenerateRequest,
    streamer: Optional[object] = None,
    cancel_event: Optional[threading.Event] = None,
) -> str:
    """Blocking inference. Returns the speaker used.

    Runs in a thread-pool executor — do not call from the event loop directly.
    Pass an AsyncAudioStreamer to receive audio chunks in real time.

    Args:
        req: Validated request (text, speaker, cfg_scale, inference_steps).
        streamer: Forwarded to the model as ``audio_streamer``.
        cancel_event: When already set on entry, generation is not started.

    Returns:
        The speaker key actually used; an unknown ``req.speaker`` falls back
        to ``DEFAULT_SPEAKER``.

    Raises:
        RuntimeError: ``cancel_event`` was set before generation started.
        ValueError: The model produced no speech output.

    On any failure the streamer is ended before the exception propagates.
    Without that, a consumer blocked on the streamer's queue never receives
    a stop signal and can only give up through its own timeout.
    """
    try:
        if cancel_event and cancel_event.is_set():
            raise RuntimeError("Generation cancelled.")

        speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
        # Work on a copy so the cached preset is never mutated.
        voice_preset = copy.deepcopy(_voice_presets[speaker])

        _model.set_ddpm_inference_steps(num_steps=req.inference_steps)

        inputs = _processor.process_input_with_cached_prompt(
            text=req.text,
            cached_prompt=voice_preset,
            padding=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        for k, v in inputs.items():
            if torch.is_tensor(v):
                inputs[k] = v.to(_device)

        outputs = _model.generate(
            **inputs,
            max_new_tokens=None,
            cfg_scale=req.cfg_scale,
            tokenizer=_processor.tokenizer,
            generation_config={"do_sample": False},
            verbose=True,
            all_prefilled_outputs=copy.deepcopy(voice_preset),
            audio_streamer=streamer,
        )

        if not outputs.speech_outputs or outputs.speech_outputs[0] is None:
            raise ValueError("Model returned no audio output.")

        return speaker
    except Exception:
        # NOTE(review): assumes the streamer exposes an ``end()`` method that
        # enqueues its stop signal — confirm against
        # vibevoice.modular.streamer. The call is guarded so a streamer
        # without it is left alone; an extra stop signal is harmless to a
        # consumer that stops at the first one.
        end_stream = getattr(streamer, "end", None)
        if callable(end_stream):
            try:
                end_stream()
            except Exception:
                # Never let cleanup mask the original failure.
                logger.debug("Could not end audio streamer after failure", exc_info=True)
        raise
|
||||
|
||||
|
||||
def _sse(event: dict) -> str:
    """Serialise ``event`` as one Server-Sent Events ``data:`` frame."""
    payload = json.dumps(event)
    # A blank line terminates the frame.
    return "data: " + payload + "\n\n"
|
||||
|
||||
|
||||
@app.post("/generate")
async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
    """Stream synthesised speech for ``req`` as Server-Sent Events.

    Each frame is ``data: <json>`` with one of these payloads:
      * ``{"type": "audio_chunk", "data": ...}`` — base64 of the chunk's
        float32 bytes, one per chunk taken from the streamer queue.
      * ``{"type": "complete", "elapsed": ..., "speaker": ...}`` — final frame
        on success.
      * ``{"type": "error", "message": ...}`` — timeout, worker failure, or
        the server turned out to be busy once the stream started.
      * ``{"type": "cancelled"}`` — the worker future was cancelled.

    Raises:
        HTTPException: 503 when the model is not online, or when another
            generation is already holding the lock at request time.
    """
    if _model_status != "online":
        detail = {
            "downloading": "Model is downloading — please wait.",
            "loading": "Model is loading into memory — please wait.",
            "error": f"Model failed to load: {_model_error or 'unknown error'}",
        }.get(_model_status, "Server not ready.")
        raise HTTPException(status_code=503, detail=detail)

    if _generation_lock.locked():
        raise HTTPException(status_code=503, detail="Server is already generating audio. Please wait.")

    async def event_stream() -> AsyncGenerator[str, None]:
        from vibevoice.modular.streamer import AsyncAudioStreamer

        start = time.monotonic()
        streamer = AsyncAudioStreamer(batch_size=1)
        cancel_event = threading.Event()

        # The check above ran before this generator started, so another
        # request may have taken the lock in between. Re-check right before
        # acquiring (no await in between) so a late arrival is told the
        # server is busy instead of silently queueing behind the first one.
        if _generation_lock.locked():
            yield _sse({"type": "error", "message": "Server is already generating audio. Please wait."})
            return

        async with _generation_lock:
            loop = asyncio.get_running_loop()
            future = loop.run_in_executor(
                None, functools.partial(_sync_generate, req, streamer, cancel_event)
            )
            # NOTE(review): future.cancel() cannot interrupt an executor
            # thread that is already running, and _sync_generate only looks
            # at cancel_event before it starts. After a timeout or disconnect
            # the worker may therefore still be using the model when this
            # lock is released.

            # Drain audio chunks as they arrive from the diffusion head.
            # stop_signal=None is the default sentinel that ends the queue.
            while True:
                try:
                    chunk = await asyncio.wait_for(
                        streamer.audio_queues[0].get(), timeout=120.0
                    )
                except asyncio.TimeoutError:
                    cancel_event.set()
                    future.cancel()
                    yield _sse({"type": "error", "message": "Generation timed out"})
                    return

                if await request.is_disconnected():
                    cancel_event.set()
                    future.cancel()
                    logger.info("Generation client disconnected; stream cancelled.")
                    return

                if chunk is None:  # stop signal
                    break

                pcm_b64 = base64.b64encode(
                    chunk.detach().cpu().float().numpy().tobytes()
                ).decode()
                yield _sse({"type": "audio_chunk", "data": pcm_b64})

            # The queue is drained; collect the worker's result or failure.
            try:
                speaker = await future
            except asyncio.CancelledError:
                logger.info("Generation cancelled.")
                yield _sse({"type": "cancelled"})
                return
            except Exception as exc:
                logger.exception("Generation failed: %s", exc)
                yield _sse({"type": "error", "message": str(exc)})
                return

            elapsed = round(time.monotonic() - start, 1)
            logger.info("Generation complete in %.1fs", elapsed)
            yield _sse({"type": "complete", "elapsed": elapsed, "speaker": speaker})

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        # Disable caching and proxy buffering so frames reach the client promptly.
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
||||
|
||||
Reference in New Issue
Block a user