Files
vibepod/server/vibevoice_server.py
T
LyAhn a39ec536fd Improve CPU Inference Stability: Adaptive Buffering & Chunk Accumulation (#11)
* Improve CPU Inference Stability: Implement Adaptive Buffering and Chunk Accumulation

This change addresses audio stuttering issues when running on CPU-only hardware by:
- Implementing server-side audio chunk accumulation to reduce SSE overhead.
- Introducing device-aware default configurations for buffering and inference steps.
- Exposing key performance parameters as environment variables.
- Enabling the frontend to adaptively adjust its buffering thresholds based on the server's configuration.

Changes:
- Modified `server/vibevoice_server.py` to support accumulation and provide config via `/health`.
- Updated `web/hooks/useStreamingGeneration.ts` to accept configurable buffering parameters.
- Updated `web/app/page.tsx` to fetch and apply server-side configuration.

Verified on CPU mode in the development environment.

Co-authored-by: LyAhn <27559362+LyAhn@users.noreply.github.com>

* Improve CPU Inference Stability: Implement Adaptive Buffering and Chunk Accumulation

This change addresses audio stuttering issues when running on CPU-only hardware by:
- Implementing server-side audio chunk accumulation to reduce SSE overhead.
- Introducing device-aware default configurations for buffering and inference steps.
- Exposing key performance parameters as environment variables.
- Enabling the frontend to adaptively adjust its buffering thresholds based on the server's configuration.

Changes:
- Modified `server/vibevoice_server.py` to support accumulation and provide config via `/health`.
- Updated `web/hooks/useStreamingGeneration.ts` to accept configurable buffering parameters.
- Updated `web/app/page.tsx` to fetch and apply server-side configuration.

Verified on CPU mode in the development environment.

Co-authored-by: LyAhn <27559362+LyAhn@users.noreply.github.com>

* Improve CPU Inference Stability: Adaptive Buffering UI & Logic

This change enhances the initial CPU stability fix by:
- Exposing adaptive buffering settings (Pre-buffer, Re-buffer Threshold, Resume Threshold) in a new "Advanced Buffering" UI section.
- Managing buffering settings in the application state to allow for manual overrides.
- Implementing robust re-initialization of buffering and inference defaults whenever the server's device (CPU/CUDA) changes.
- Including the active device in the server's config object for reliable client-side detection.

Verified with frontend screenshots and full build. Responds to PR feedback regarding actioning the adaptive logic.

Co-authored-by: LyAhn <27559362+LyAhn@users.noreply.github.com>

* Refine adaptive buffering: env helpers, threshold validation, a11y fixes

- Extract _env_int/_env_float helpers in server to validate env-var config
  with graceful fallback instead of bare int/float casts
- Fix inference_steps falsy-check (0 is valid) to use explicit None guard
- Enforce rebufferThresholdSecs < resumeThresholdSecs in both the hook
  (with console.warn + clamp) and the GenerationControls UI (sliders block
  invalid states by auto-bumping or ignoring the drag)
- Add type="button", aria-expanded, aria-controls, htmlFor, and input id
  attributes to GenerationControls for accessibility
- Add .vscode/settings.json to .gitignore; sort package.json scripts

---------

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
2026-04-30 16:03:35 +01:00

496 lines
17 KiB
Python

"""
VibePod — VibeVoice FastAPI TTS Server
Startup sequence (background thread):
1. Download model weights if not cached -> status: downloading
2. Download voice preset .pt files -> status: loading
3. Load processor + model into memory -> status: loading
4. Pre-load all voice tensors -> status: loading
-> Server ready -> status: online
Generation flow:
POST /generate -> SSE stream of audio_chunk events (base64 float32 PCM),
ends with {type:"complete"}
Device selection:
Set VIBEPOD_DEVICE=cpu to force CPU inference (e.g. via --cpu flag in start.sh).
Set VIBEPOD_DEVICE=cuda to force CUDA (default when a GPU is available).
If unset, the server auto-detects: CUDA if available, otherwise CPU.
"""
import asyncio
import base64
import copy
import functools
import json
import logging
import os
import threading
import time
import urllib.request
from contextlib import asynccontextmanager
from pathlib import Path
from typing import AsyncGenerator, Literal, Optional
import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, field_validator
from tqdm import tqdm as _BaseTqdm
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
MODEL_ID = "microsoft/VibeVoice-Realtime-0.5B"
SAMPLE_RATE = 24_000
VOICES_DIR = Path(__file__).parent / "voices" / "streaming_model"
VOICE_BASE_URL = (
"https://raw.githubusercontent.com/microsoft/VibeVoice/main"
"/demo/voices/streaming_model"
)
EN_VOICES: dict[str, str] = {
"carter": "en-Carter_man.pt",
"davis": "en-Davis_man.pt",
"emma": "en-Emma_woman.pt",
"frank": "en-Frank_man.pt",
"grace": "en-Grace_woman.pt",
"mike": "en-Mike_man.pt",
}
DEFAULT_SPEAKER = "carter"
_IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.ot"]
# ── Device selection ────────────────────────────────────────────────────────────
# VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag.
# Falls back to auto-detection if not set.
def _resolve_device() -> str:
"""Resolve the target device from env var or auto-detect."""
env = os.environ.get("VIBEPOD_DEVICE", "").strip().lower()
if env in ("cpu", "cuda"):
if env == "cuda" and not torch.cuda.is_available():
logger.warning(
"VIBEPOD_DEVICE=cuda requested but CUDA is not available — falling back to CPU."
)
return "cpu"
return env
# Auto-detect
return "cuda" if torch.cuda.is_available() else "cpu"
# ── Env-var helpers ─────────────────────────────────────────────────────────────
def _env_int(name: str, default: int) -> int:
raw = os.environ.get(name, "").strip()
if not raw:
return default
try:
return int(raw)
except ValueError:
logger.warning("Invalid value for %s=%r — using default %d", name, raw, default)
return default
def _env_float(name: str, default: float) -> float:
raw = os.environ.get(name, "").strip()
if not raw:
return default
try:
return float(raw)
except ValueError:
logger.warning("Invalid value for %s=%r — using default %g", name, raw, default)
return default
# ── Global state ────────────────────────────────────────────────────────────────
ModelStatus = Literal["downloading", "loading", "online", "error"]
_processor = None
_model = None
_device: str = "cpu"
_model_status: ModelStatus = "loading"
_model_error: Optional[str] = None
_voice_presets: dict[str, object] = {}
_load_lock = threading.Lock()
_generation_lock = asyncio.Lock()
# Config defaults (can be overridden by env vars)
# These are populated in _load_model_sync once the device is known.
_config = {
"device": "cpu",
"chunk_accum": 1,
"prebuffer_secs": 2.0,
"rebuffer_threshold_secs": 0.4,
"resume_threshold_secs": 1.5,
"default_inference_steps": 10,
}
# Download progress (files downloaded so far)
_dl_progress: dict[str, int] = {"done": 0, "total": 0}
# ── Progress-tracking tqdm (for model file downloads) ──────────────────────────
def _make_dl_tqdm() -> type:
class _DlTqdm(_BaseTqdm):
def __init__(self, *args: object, **kwargs: object) -> None:
super().__init__(*args, **kwargs)
if isinstance(self.total, (int, float)) and 0 < self.total < 10_000:
_dl_progress["total"] = int(self.total)
_dl_progress["done"] = 0
def update(self, n: int = 1) -> "bool | None":
result = super().update(n)
if isinstance(self.total, (int, float)) and 0 < self.total < 10_000:
_dl_progress["done"] = int(self.n)
return result
return _DlTqdm
# ── Model / voice helpers ───────────────────────────────────────────────────────
def _is_model_cached() -> bool:
try:
from huggingface_hub import snapshot_download
snapshot_download(
MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS
)
return True
except Exception:
return False
def _download_model() -> None:
from huggingface_hub import snapshot_download
token: Optional[str] = os.environ.get("HF_TOKEN") or os.environ.get(
"HUGGINGFACE_TOKEN"
)
DlTqdm = _make_dl_tqdm()
logger.info("Model not cached — downloading %s...", MODEL_ID)
snapshot_download(
repo_id=MODEL_ID,
ignore_patterns=_IGNORE_PATTERNS,
token=token or None,
tqdm_class=DlTqdm,
)
logger.info("Model download complete.")
def _download_voices() -> None:
VOICES_DIR.mkdir(parents=True, exist_ok=True)
for name, filename in EN_VOICES.items():
dest = VOICES_DIR / filename
if not dest.exists():
url = f"{VOICE_BASE_URL}/{filename}"
logger.info("Downloading voice preset: %s", filename)
urllib.request.urlretrieve(url, dest)
logger.info("Voice presets ready.")
# ── Background model loader ─────────────────────────────────────────────────────
def _init_processor():
logger.info("Loading processor...")
from vibevoice.processor.vibevoice_streaming_processor import (
VibeVoiceStreamingProcessor,
)
return VibeVoiceStreamingProcessor.from_pretrained(MODEL_ID)
def _init_model(device: str):
logger.info("Loading model on %s...", device)
load_dtype = torch.bfloat16 if device == "cuda" else torch.float32
attn_impl = "flash_attention_2" if device == "cuda" else "sdpa"
from vibevoice.modular.modeling_vibevoice_streaming_inference import (
VibeVoiceStreamingForConditionalGenerationInference,
)
try:
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
MODEL_ID,
torch_dtype=load_dtype,
device_map=device,
attn_implementation=attn_impl,
)
except Exception:
logger.warning(
"Model load with %s failed; falling back to sdpa", attn_impl, exc_info=True
)
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
MODEL_ID,
torch_dtype=load_dtype,
device_map=device,
attn_implementation="sdpa",
)
model.eval()
model.set_ddpm_inference_steps(num_steps=_config["default_inference_steps"])
return model
def _load_voice_presets(device: str) -> dict[str, object]:
presets = {}
for name, filename in EN_VOICES.items():
path = VOICES_DIR / filename
if path.exists():
presets[name] = torch.load(path, map_location=device, weights_only=False)
return presets
def _load_model_sync() -> None:
global _processor, _model, _device, _model_status, _model_error, _voice_presets, _config
with _load_lock:
if _model is not None:
return
try:
if not _is_model_cached():
_model_status = "downloading"
_download_model()
_model_status = "loading"
_download_voices()
# Resolve device from env var (set by start.sh --cpu/--cuda) or auto-detect.
_device = _resolve_device()
logger.info("Using device: %s", _device)
# Populate config based on device
is_cpu = _device == "cpu"
_config["device"] = _device
_config["chunk_accum"] = _env_int("VIBEPOD_CHUNK_ACCUM", 4 if is_cpu else 1)
_config["prebuffer_secs"] = _env_float("VIBEPOD_PREBUFFER_SECS", 5.0 if is_cpu else 2.0)
_config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 1.0 if is_cpu else 0.4)
_config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 2.5 if is_cpu else 1.5)
_config["default_inference_steps"] = _env_int("VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10)
_processor = _init_processor()
_model = _init_model(_device)
_voice_presets = _load_voice_presets(_device)
_model_status = "online"
logger.info(
"Model ready on %s. Voices: %s", _device, list(_voice_presets.keys())
)
logger.info("Configuration: %s", _config)
except Exception as exc:
_model_status = "error"
_model_error = "Internal server error during model initialization."
logger.exception("Failed to initialise model: %s", exc)
# ── FastAPI app ─────────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
thread = threading.Thread(target=_load_model_sync, daemon=True, name="model-loader")
thread.start()
yield
app = FastAPI(title="VibePod TTS Server", version="0.1.0", lifespan=lifespan)
# ── Schemas ─────────────────────────────────────────────────────────────────────
class GenerateRequest(BaseModel):
text: str = Field(..., min_length=1, max_length=10_000)
speaker: str = Field(default=DEFAULT_SPEAKER)
cfg_scale: float = Field(default=1.5, ge=0.5, le=4.0)
inference_steps: Optional[int] = Field(default=None, ge=5, le=20)
@field_validator("text")
@classmethod
def text_not_blank(cls, v: str) -> str:
if not v.strip():
raise ValueError("text must not be blank")
return v.strip()
@field_validator("speaker")
@classmethod
def normalise_speaker(cls, v: str) -> str:
return v.lower().strip()
# ── Endpoints ───────────────────────────────────────────────────────────────────
@app.get("/health")
async def health() -> dict:
body: dict = {
"status": _model_status,
"model": MODEL_ID,
"device": _device,
"voices": list(_voice_presets.keys()),
"config": _config,
}
if _model_status == "downloading":
body["progress"] = {
"done": _dl_progress["done"],
"total": _dl_progress["total"],
}
if _model_error:
body["message"] = _model_error
return body
def _sync_generate(
req: GenerateRequest,
streamer: Optional[object] = None,
cancel_event: Optional[threading.Event] = None,
) -> str:
"""Blocking inference. Returns the speaker used.
Runs in a thread-pool executor — do not call from the event loop directly.
Pass an AsyncAudioStreamer to receive audio chunks in real time.
"""
if cancel_event and cancel_event.is_set():
raise RuntimeError("Generation cancelled.")
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
voice_preset = copy.deepcopy(_voice_presets[speaker])
steps = req.inference_steps if req.inference_steps is not None else _config["default_inference_steps"]
_model.set_ddpm_inference_steps(num_steps=steps)
inputs = _processor.process_input_with_cached_prompt(
text=req.text,
cached_prompt=voice_preset,
padding=True,
return_tensors="pt",
return_attention_mask=True,
)
for k, v in inputs.items():
if torch.is_tensor(v):
inputs[k] = v.to(_device)
outputs = _model.generate(
**inputs,
max_new_tokens=None,
cfg_scale=req.cfg_scale,
tokenizer=_processor.tokenizer,
generation_config={"do_sample": False},
verbose=True,
all_prefilled_outputs=copy.deepcopy(voice_preset),
audio_streamer=streamer,
)
if not outputs.speech_outputs or outputs.speech_outputs[0] is None:
raise ValueError("Model returned no audio output.")
return speaker
def _sse(event: dict) -> str:
return f"data: {json.dumps(event)}\n\n"
@app.post("/generate")
async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
if _model_status != "online":
detail = {
"downloading": "Model is downloading — please wait.",
"loading": "Model is loading into memory — please wait.",
"error": f"Model failed to load: {_model_error or 'unknown error'}",
}.get(_model_status, "Server not ready.")
raise HTTPException(status_code=503, detail=detail)
if _generation_lock.locked():
raise HTTPException(
status_code=503, detail="Server is already generating audio. Please wait."
)
async def event_stream() -> AsyncGenerator[str, None]:
from vibevoice.modular.streamer import AsyncAudioStreamer
start = time.monotonic()
streamer = AsyncAudioStreamer(batch_size=1)
cancel_event = threading.Event()
accum_size = max(1, _config["chunk_accum"])
accumulated_chunks = []
async with _generation_lock:
loop = asyncio.get_event_loop()
future = loop.run_in_executor(
None, functools.partial(_sync_generate, req, streamer, cancel_event)
)
# Drain audio chunks as they arrive from the diffusion head.
# stop_signal=None is the default sentinel that ends the queue.
while True:
try:
chunk = await asyncio.wait_for(
streamer.audio_queues[0].get(), timeout=120.0
)
except asyncio.TimeoutError:
cancel_event.set()
future.cancel()
yield _sse({"type": "error", "message": "Generation timed out"})
return
if await request.is_disconnected():
cancel_event.set()
future.cancel()
logger.info("Generation client disconnected; stream cancelled.")
return
if chunk is None: # stop signal
break
accumulated_chunks.append(chunk.detach().cpu().float())
if len(accumulated_chunks) >= accum_size:
combined = torch.cat(accumulated_chunks, dim=0)
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
yield _sse({"type": "audio_chunk", "data": pcm_b64})
accumulated_chunks = []
# Flush any remaining chunks
if accumulated_chunks:
combined = torch.cat(accumulated_chunks, dim=0)
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
yield _sse({"type": "audio_chunk", "data": pcm_b64})
try:
speaker = await future
except asyncio.CancelledError:
logger.info("Generation cancelled.")
yield _sse({"type": "cancelled"})
return
except Exception as exc:
logger.exception("Generation failed: %s", exc)
yield _sse(
{
"type": "error",
"message": "Internal server error during generation.",
}
)
return
elapsed = round(time.monotonic() - start, 1)
logger.info("Generation complete in %.1fs", elapsed)
yield _sse({"type": "complete", "elapsed": elapsed, "speaker": speaker})
return StreamingResponse(
event_stream(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
)