Files
vibepod/server/vibevoice_server.py
google-labs-jules[bot] e64048e500 Improve code documentation and maintainer notes
- Add a top-level doc comment to useStreamingGeneration.ts and document the streaming lifecycle.
- Add docstrings to helper functions in useStreamingGeneration.ts.
- Add section comments to web/app/page.tsx around reducer state, server health polling, and generation handling.
- Add file-level comments to API proxy routes explaining the security architecture.
- Add a file map / maintainer guide comment to server/vibevoice_server.py.
- Add docstrings for key internal helpers in server/vibevoice_server.py.
- Document environment variables used by the server in server/vibevoice_server.py.
- Add comments identifying VibePod-specific patches around VibeVoice internals.
- Format server/vibevoice_server.py with black.

Co-authored-by: LyAhn <27559362+LyAhn@users.noreply.github.com>
2026-05-02 16:44:38 +00:00

1083 lines
41 KiB
Python

"""
VibePod — VibeVoice FastAPI TTS Server
This server provides a high-performance Text-to-Speech (TTS) interface for the VibeVoice model,
optimized for real-time streaming on both CPU and NVIDIA GPU hardware.
MAINTAINER GUIDE / FILE MAP:
- Device & Env Configuration: Helpers for hardware detection and runtime tuning via env vars.
- Global State: Thread-safe storage for the loaded model, processor, and server status.
- Background Model Loader: Logic for downloading weights and initializing the model with optimizations.
- VibeVoice Patches: Performance-critical overrides for VibeVoice internals (hot-paths).
- FastAPI Application: SSE-based generation endpoint and health/status polling.
- Audio Streaming: Async bridge (NonBlockingAudioStreamer) between inference and the network.
RUNTIME CONFIGURATION (Environment Variables):
- VIBEPOD_DEVICE: 'cpu' or 'cuda' (auto-detected if unset).
- VIBEPOD_CHUNK_ACCUM: Number of 20ms audio chunks to buffer before sending an SSE event (default: 4 for CPU, 1 for CUDA).
- VIBEPOD_PREBUFFER_SECS: Initial client-side buffer duration, hinted to the frontend (default: 24.0 for CPU, 5.0 for CUDA).
- VIBEPOD_REBUFFER_THRESHOLD_SECS: Buffer level below which the client pauses to refill (default: 2.0 for CPU, 1.0 for CUDA).
- VIBEPOD_RESUME_THRESHOLD_SECS: Buffer level at which the client resumes playback (default: 12.0 for CPU, 3.0 for CUDA).
- VIBEPOD_DEFAULT_INFERENCE_STEPS: Default DDPM steps (default: 8 for CPU, 10 for CUDA).
- VIBEPOD_PROFILE_GENERATION: Set to '1' to enable detailed performance logging.
CPU-SPECIFIC OPTIMIZATIONS:
- VIBEPOD_CPU_THREADS: Number of intra-op threads (default: half the logical cores on Windows, all logical cores elsewhere).
- VIBEPOD_CPU_INTEROP_THREADS: Number of inter-op threads (default: 1).
- VIBEPOD_CPU_MKLDNN: Set to '0' to disable MKLDNN (default: 1).
- VIBEPOD_CPU_BF16: Set to '1' to force bfloat16, '0' to force float32 (default: bfloat16 only if the CPU reports AVX512_BF16 support).
- VIBEPOD_ASYNC_DECODE: Set to '1' to overlap decoding with inference on a separate thread (default: 1).
- VIBEPOD_QUANTIZE: Set to '1' to enable experimental dynamic INT8 quantization.
- VIBEPOD_COMPILE: Set to '1' to enable experimental torch.compile (limited benefit for TTS).
- VIBEPOD_COMPILE_MODE: torch.compile mode used when VIBEPOD_COMPILE=1 (default: 'reduce-overhead').
CUDA-SPECIFIC OPTIMIZATIONS:
- VIBEPOD_CUDA_DTYPE: 'bf16' (default) or 'fp16'.
- VIBEPOD_ATTN_IMPL: 'auto', 'sdpa', 'eager', or 'flash_attention_2'.
"""
import asyncio
import base64
import concurrent.futures
import copy
import functools
import importlib.util
import json
import logging
import os
import platform
import threading
import time
import types
import urllib.request
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Literal
import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, field_validator
from tqdm import tqdm as _BaseTqdm
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
MODEL_ID = "microsoft/VibeVoice-Realtime-0.5B"
SAMPLE_RATE = 24_000
VOICES_DIR = Path(__file__).parent / "voices" / "streaming_model"
VOICE_BASE_URL = "https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model"
EN_VOICES: dict[str, str] = {
"carter": "en-Carter_man.pt",
"davis": "en-Davis_man.pt",
"emma": "en-Emma_woman.pt",
"frank": "en-Frank_man.pt",
"grace": "en-Grace_woman.pt",
"mike": "en-Mike_man.pt",
}
DEFAULT_SPEAKER = "carter"
_IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.ot"]
# ── Pipeline executor ──────────────────────────────────────────────────────────
# Overlaps acoustic_decode with forward_tts_lm on a background thread (1 worker).
_decode_executor: concurrent.futures.ThreadPoolExecutor | None = None
# ── Device selection ────────────────────────────────────────────────────────────
# VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag.
# Falls back to auto-detection if not set.
def _resolve_device() -> str:
    """
    Return the torch device string ("cpu" or "cuda") to run the model on.

    VIBEPOD_DEVICE is read case-insensitively with surrounding whitespace ignored:
      - "cpu"  -> "cpu".
      - "cuda" -> "cuda" if CUDA is available, else "cpu" with a warning.
      - unset/empty -> auto-detect ("cuda" when available, else "cpu").
      - anything else -> a warning is logged, then auto-detect as above.
    """
    requested = os.environ.get("VIBEPOD_DEVICE", "").strip().lower()
    if requested == "cpu":
        return "cpu"
    if requested == "cuda":
        if torch.cuda.is_available():
            return "cuda"
        logger.warning(
            "VIBEPOD_DEVICE=cuda requested but CUDA is not available — falling back to CPU."
        )
        return "cpu"
    if requested:
        # Previously a typo such as "gpu" was ignored silently; surface it like
        # the other env-var helpers do for invalid values.
        logger.warning(
            "Unrecognised VIBEPOD_DEVICE=%r (expected 'cpu' or 'cuda') — auto-detecting.",
            requested,
        )
    # Auto-detect
    return "cuda" if torch.cuda.is_available() else "cpu"
# ── Env-var helpers ─────────────────────────────────────────────────────────────
def _env_int(name: str, default: int) -> int:
    """
    Read environment variable *name* as an int.

    Surrounding whitespace is ignored. Returns *default* when the variable is
    unset or blank, and also (after logging a warning) when it does not parse.
    """
    text = os.environ.get(name, "").strip()
    if text:
        try:
            return int(text)
        except ValueError:
            logger.warning(
                "Invalid value for %s=%r — using default %d", name, text, default
            )
    return default
def _env_float(name: str, default: float) -> float:
    """
    Read environment variable *name* as a float.

    Surrounding whitespace is ignored. Returns *default* when the variable is
    unset or blank, and also (after logging a warning) when it does not parse.
    """
    text = os.environ.get(name, "").strip()
    if text:
        try:
            return float(text)
        except ValueError:
            logger.warning(
                "Invalid value for %s=%r — using default %g", name, text, default
            )
    return default
def _cpu_supports_bf16() -> bool:
    """
    Report whether the CPU advertises AVX512_BF16 support.

    The probe is looked up defensively: a torch build without
    ``torch.cpu.is_avx512_bf16_supported`` yields False instead of raising.
    """
    cpu_module = getattr(torch, "cpu", None)
    probe = getattr(cpu_module, "is_avx512_bf16_supported", None)
    if probe is None:
        return False
    return probe()
def _configure_cpu_runtime() -> dict[str, object]:
    """
    Configure PyTorch's CPU threading and MKLDNN from environment variables.

    Reads:
      - VIBEPOD_CPU_THREADS: intra-op threads. Default is half the logical CPUs
        on Windows, all logical CPUs elsewhere; clamped to at least 1.
      - VIBEPOD_CPU_INTEROP_THREADS: inter-op threads (default 1, min 1).
      - VIBEPOD_CPU_MKLDNN: "0" disables MKLDNN; anything else enables it.

    Returns:
        The effective settings as reported back by torch, for logging.
    """
    cpu_count = os.cpu_count() or 1
    on_windows = platform.system() == "Windows"
    default_intra = max(1, cpu_count // 2) if on_windows else cpu_count

    wanted_intra = _env_int("VIBEPOD_CPU_THREADS", default_intra)
    wanted_interop = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1)
    use_mkldnn = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"

    torch.set_num_threads(max(1, wanted_intra))
    try:
        torch.set_num_interop_threads(max(1, wanted_interop))
    except RuntimeError as exc:
        # PyTorch refuses to change inter-op threads once inter-op parallel work
        # has started (or after they were already set); keep going with the
        # current value.
        logger.warning("Could not set CPU inter-op threads: %s", exc)
    torch.backends.mkldnn.enabled = use_mkldnn

    return dict(
        logical_cpus=cpu_count,
        threads=torch.get_num_threads(),
        interop_threads=torch.get_num_interop_threads(),
        mkldnn_available=torch.backends.mkldnn.is_available(),
        mkldnn_enabled=torch.backends.mkldnn.enabled,
    )
# ── Global state ────────────────────────────────────────────────────────────────
ModelStatus = Literal["downloading", "loading", "online", "error"]
_processor = None
_model = None
_device: str = "cpu"
_model_status: ModelStatus = "loading"
_model_error: str | None = None
_voice_presets: dict[str, object] = {}
_load_lock = threading.Lock()
_generation_lock = asyncio.Lock()
# Config defaults (can be overridden by env vars)
# These are populated in _load_model_sync once the device is known.
_config = {
"device": "cpu",
"chunk_accum": 1,
"prebuffer_secs": 2.0,
"rebuffer_threshold_secs": 0.4,
"resume_threshold_secs": 1.5,
"default_inference_steps": 10,
}
# Download progress (files downloaded so far)
_dl_progress: dict[str, int] = {"done": 0, "total": 0}
# ── Progress-tracking tqdm (for model file downloads) ──────────────────────────
def _make_dl_tqdm() -> type:
    """
    Build a tqdm subclass that mirrors download progress into ``_dl_progress``.

    Only bars whose ``total`` is a number in the open range (0, 10_000) are
    tracked. Presumably this selects the file-count bar of ``snapshot_download``
    rather than byte-level bars (assumption; verify against huggingface_hub).
    A tracked bar publishes ``total`` and resets ``done`` to 0 when created, then
    publishes its running count ``n`` after every update, for /health.
    """

    def _is_tracked(bar: _BaseTqdm) -> bool:
        size = bar.total
        return isinstance(size, (int, float)) and 0 < size < 10_000

    class _DlTqdm(_BaseTqdm):
        def __init__(self, *args: object, **kwargs: object) -> None:
            super().__init__(*args, **kwargs)
            if _is_tracked(self):
                _dl_progress["total"] = int(self.total)
                _dl_progress["done"] = 0

        def update(self, n: int = 1) -> "bool | None":
            outcome = super().update(n)
            if _is_tracked(self):
                _dl_progress["done"] = int(self.n)
            return outcome

    return _DlTqdm
# ── Model / voice helpers ───────────────────────────────────────────────────────
def _is_model_cached() -> bool:
    """
    Return True if MODEL_ID resolves from the local Hugging Face cache alone.

    Any failure counts as "not cached" and yields False. That includes
    huggingface_hub being unavailable and cache misses under local_files_only.
    """
    try:
        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id=MODEL_ID,
            local_files_only=True,
            ignore_patterns=_IGNORE_PATTERNS,
        )
    except Exception:
        return False
    return True
def _download_model() -> None:
    """
    Download MODEL_ID into the Hugging Face cache, skipping _IGNORE_PATTERNS.

    Authenticates with HF_TOKEN (or HUGGINGFACE_TOKEN) when set, and reports
    progress to /health through the tqdm class from _make_dl_tqdm().
    """
    from huggingface_hub import snapshot_download

    hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
    progress_cls = _make_dl_tqdm()
    logger.info("Model not cached — downloading %s...", MODEL_ID)
    snapshot_download(
        repo_id=MODEL_ID,
        ignore_patterns=_IGNORE_PATTERNS,
        # Normalise an empty-string token to None (anonymous access).
        token=hf_token or None,
        tqdm_class=progress_cls,
    )
    logger.info("Model download complete.")
def _download_voices() -> None:
    """
    Ensure every preset in EN_VOICES exists in VOICES_DIR.

    Missing files are fetched from VOICE_BASE_URL. Existing files are never
    re-downloaded.

    Each download goes to a ``<name>.part`` sibling first. It is renamed into
    place only once complete, so an interrupted download cannot leave a
    truncated ``.pt``. Such a file would pass the exists() check on every
    later start and then fail in torch.load.

    Raises:
        OSError (including urllib.error.URLError): if a download or file
        operation fails.
    """
    VOICES_DIR.mkdir(parents=True, exist_ok=True)
    for filename in EN_VOICES.values():
        dest = VOICES_DIR / filename
        if dest.exists():
            continue
        url = f"{VOICE_BASE_URL}/{filename}"
        logger.info("Downloading voice preset: %s", filename)
        partial = dest.with_name(dest.name + ".part")
        try:
            urllib.request.urlretrieve(url, partial)
            # Atomic on the same filesystem; overwrites an existing target on all OSes.
            os.replace(partial, dest)
        finally:
            # Clean up a half-written file if the download or rename failed.
            partial.unlink(missing_ok=True)
    logger.info("Voice presets ready.")
# ── Background model loader ─────────────────────────────────────────────────────
def _init_processor():
    """
    Load the VibeVoice streaming processor for MODEL_ID.

    vibevoice is imported lazily so the module imports cleanly before the
    model package is needed.
    """
    logger.info("Loading processor...")
    from vibevoice.processor.vibevoice_streaming_processor import (
        VibeVoiceStreamingProcessor as _StreamingProcessor,
    )

    processor = _StreamingProcessor.from_pretrained(MODEL_ID)
    return processor
def _configure_torch_backends(device: str) -> None:
    """
    Apply process-wide PyTorch backend settings for *device* before loading weights.

    CUDA: TF32 matmuls/cuDNN, cuDNN autotuning, all SDPA kernels enabled.
    CPU: float32 matmul precision "medium" plus the thread/MKLDNN settings from
    _configure_cpu_runtime().
    """
    if device == "cuda":
        torch.set_float32_matmul_precision("high")
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(True)
        torch.backends.cuda.enable_math_sdp(True)
        torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
        logger.info(
            "PyTorch SDPA backends: flash=%s, mem_efficient=%s, math=%s",
            torch.backends.cuda.flash_sdp_enabled(),
            torch.backends.cuda.mem_efficient_sdp_enabled(),
            torch.backends.cuda.math_sdp_enabled(),
        )
    elif device == "cpu":
        torch.set_float32_matmul_precision("medium")
        logger.info("CPU runtime configuration: %s", _configure_cpu_runtime())


def _resolve_load_dtype(device: str) -> torch.dtype:
    """
    Choose the dtype the model weights are loaded in.

    CUDA: VIBEPOD_CUDA_DTYPE "fp16" -> float16; "bf16" (default) -> bfloat16.
    Any other value is warned about and treated as bf16.
    Other devices: VIBEPOD_CPU_BF16 "1" -> bfloat16, "0" -> float32. Otherwise
    bfloat16 only if the CPU reports AVX512_BF16 support.
    """
    if device == "cuda":
        cuda_dtype = os.environ.get("VIBEPOD_CUDA_DTYPE", "bf16").lower()
        if cuda_dtype == "fp16":
            return torch.float16
        if cuda_dtype != "bf16":
            logger.warning(
                "Unrecognised VIBEPOD_CUDA_DTYPE=%r — using bf16.", cuda_dtype
            )
        return torch.bfloat16
    cpu_bf16_env = os.environ.get("VIBEPOD_CPU_BF16", "auto").lower()
    if cpu_bf16_env == "1":
        logger.info("CPU BF16 forced via VIBEPOD_CPU_BF16=1")
        return torch.bfloat16
    if cpu_bf16_env == "0":
        logger.info("CPU float32 forced via VIBEPOD_CPU_BF16=0")
        return torch.float32
    if _cpu_supports_bf16():
        logger.info("AVX512_BF16 detected — loading model in bfloat16")
        return torch.bfloat16
    logger.info("No AVX512_BF16 — using float32 (set VIBEPOD_CPU_BF16=1 to override)")
    return torch.float32


def _resolve_attn_impl(device: str) -> str:
    """
    Pick the Transformers ``attn_implementation`` from VIBEPOD_ATTN_IMPL.

    "eager"/"sdpa" are used as given. "flash_attention_2" is used only if the
    flash_attn package is importable, otherwise "sdpa" with a warning.
    "auto" (default), and any unrecognised value after a warning, selects
    flash_attention_2 on CUDA when available and "sdpa" otherwise.
    """
    requested = os.environ.get("VIBEPOD_ATTN_IMPL", "auto").lower()
    has_flash_attn = importlib.util.find_spec("flash_attn") is not None
    if requested in {"eager", "sdpa"}:
        attn_impl = requested
    elif requested == "flash_attention_2":
        if has_flash_attn:
            attn_impl = "flash_attention_2"
        else:
            logger.warning(
                "VIBEPOD_ATTN_IMPL=flash_attention_2 requested but flash_attn is "
                "not installed — using sdpa."
            )
            attn_impl = "sdpa"
    else:
        if requested != "auto":
            logger.warning(
                "Unrecognised VIBEPOD_ATTN_IMPL=%r — using automatic selection.",
                requested,
            )
        attn_impl = (
            "flash_attention_2" if device == "cuda" and has_flash_attn else "sdpa"
        )
    logger.info("Using Transformers attention implementation: %s", attn_impl)
    # Only claim SDPA when SDPA was actually chosen (not for an explicit "eager").
    if device == "cuda" and not has_flash_attn and attn_impl == "sdpa":
        logger.info("flash_attn is not installed; using PyTorch SDPA attention.")
    return attn_impl


def _init_model(device: str):
    """
    Load the VibeVoice streaming model on *device* and install VibePod patches.

    Order matters:
      1. configure backends,
      2. resolve dtype and attention implementation,
      3. load weights, retrying once with "sdpa" if another attention
         implementation fails to load,
      4. eval(),
      5. CPU-only optimizations,
      6. set the default DDPM step count,
      7. generation hot-path patches,
      8. (CPU only) the async-decode pipeline.

    Returns:
        The ready-to-use model. On CPU this may be a different object than the
        one loaded, e.g. after dynamic quantization.
    """
    logger.info("Loading model on %s...", device)
    _configure_torch_backends(device)
    load_dtype = _resolve_load_dtype(device)
    logger.info("Loading model weights with dtype %s", load_dtype)
    attn_impl = _resolve_attn_impl(device)

    from vibevoice.modular.modeling_vibevoice_streaming_inference import (
        VibeVoiceStreamingForConditionalGenerationInference,
    )

    def load(attn_implementation: str):
        return VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
            MODEL_ID,
            torch_dtype=load_dtype,
            device_map=device,
            attn_implementation=attn_implementation,
        )

    try:
        model = load(attn_impl)
    except Exception as exc:
        if attn_impl == "sdpa":
            raise
        logger.warning(
            "Model load with %s failed (%s); falling back to sdpa",
            attn_impl,
            exc,
        )
        model = load("sdpa")
    model.eval()
    if device == "cpu":
        model = _apply_cpu_optimizations(model)
    model.set_ddpm_inference_steps(num_steps=_config["default_inference_steps"])
    _install_generation_optimizations(model)
    if device == "cpu":
        # Must run after _install_generation_optimizations so the async wrapper
        # sits outside the profiling wrapper (VibeVoice calls async → profiling → real decode).
        _install_cpu_pipeline_optimizations(model)
    return model
def _apply_cpu_optimizations(model: object) -> object:
    """
    Apply opt-in experimental CPU optimizations to *model*.

    VIBEPOD_QUANTIZE=1:
        Dynamic INT8 quantization of ``torch.nn.Linear`` layers.
        ``model.model.prediction_head`` (if present) is kept unquantized.
    VIBEPOD_COMPILE=1:
        ``torch.compile`` (inductor backend, mode from VIBEPOD_COMPILE_MODE,
        default "reduce-overhead") of forward_tts_lm, the prediction head and
        the acoustic tokenizer's decode, where present.

    A failure in either step is logged and that step skipped; the model always
    keeps its prediction head.

    Returns:
        The model to use from now on. quantize_dynamic returns a new object,
        so callers must use the return value.
    """
    do_quantize = os.environ.get("VIBEPOD_QUANTIZE", "0") == "1"
    do_compile = os.environ.get("VIBEPOD_COMPILE", "0") == "1"
    if do_quantize:
        logger.info("Applying dynamic INT8 quantization to Linear layers...")
        # The diffusion prediction_head operates on small fixed-size tensors where
        # INT8 pack/unpack overhead exceeds the matmul savings (~+20% regression in
        # testing). Detach it before quantizing and reattach it afterwards.
        saved_prediction_head = None
        try:
            import torch.ao.quantization

            if hasattr(model, "model") and hasattr(model.model, "prediction_head"):
                saved_prediction_head = model.model.prediction_head
                del model.model.prediction_head
            model = torch.ao.quantization.quantize_dynamic(
                model, {torch.nn.Linear}, dtype=torch.qint8
            )
        except Exception as exc:
            logger.warning("Dynamic quantization failed: %s — skipping", exc)
        else:
            if saved_prediction_head is not None:
                logger.info(
                    "Dynamic INT8 quantization applied (prediction_head excluded — stays float32)."
                )
            else:
                logger.info("Dynamic INT8 quantization applied.")
        finally:
            # Reattach on success (to the quantized model) AND on failure (to the
            # original model, which `model` still refers to). Previously a failed
            # quantization left the model without its diffusion head, breaking
            # every subsequent generation.
            if saved_prediction_head is not None:
                model.model.prediction_head = saved_prediction_head
    if do_compile:
        # torch.compile with inductor on CPU is ineffective for autoregressive TTS:
        # each token step produces a unique input shape, so every step triggers a new
        # kernel compile event rather than reusing compiled code. Kept as an escape
        # hatch but not recommended.
        compile_mode = os.environ.get("VIBEPOD_COMPILE_MODE", "reduce-overhead")
        logger.info(
            "torch.compile enabled (mode=%s) — NOTE: limited benefit for autoregressive"
            " models on CPU due to dynamic sequence lengths.",
            compile_mode,
        )
        # (log label, owner object, attribute name, dynamic-shapes flag)
        _compile_targets: list[tuple[str, object, str, bool]] = [
            ("forward_tts_lm", model, "forward_tts_lm", True),
        ]
        if hasattr(model, "model"):
            inner = model.model
            if hasattr(inner, "prediction_head"):
                _compile_targets.append(
                    ("prediction_head", inner, "prediction_head", False)
                )
            if hasattr(inner, "acoustic_tokenizer") and hasattr(
                inner.acoustic_tokenizer, "decode"
            ):
                _compile_targets.append(
                    (
                        "acoustic_tokenizer.decode",
                        inner.acoustic_tokenizer,
                        "decode",
                        False,
                    )
                )
        for label, obj, attr, dynamic in _compile_targets:
            try:
                compiled = torch.compile(
                    getattr(obj, attr),
                    backend="inductor",
                    mode=compile_mode,
                    dynamic=dynamic,
                )
                setattr(obj, attr, compiled)
                logger.info("  compiled: %s", label)
            except Exception as exc:
                logger.warning(
                    "  torch.compile failed for %s: %s — skipping", label, exc
                )
    return model
def _install_generation_optimizations(model: object) -> None:
    """
    VibePod Optimization Patch: monkey-patch generation hot paths on *model*.

    *model* is mutated in place:
      - ``forward_lm``, ``forward_tts_lm`` and ``model.acoustic_tokenizer.decode``
        are replaced by wrappers that call the originals. When
        VIBEPOD_PROFILE_GENERATION=1, each call's wall time is added to
        ``model._vibepod_profile[key]`` as ``{"count", "seconds"}``.
      - ``sample_speech_tokens`` is replaced by a version that reuses the noise
        scheduler's timestep/sigma schedule per step count and caches the
        per-timestep batch tensors, instead of rebuilding them on every call.
        The initial noise is still freshly sampled on each call.

    The profiling env var is re-read on every call, so profiling can be toggled
    without reinstalling the patch.
    """

    def profile_enabled() -> bool:
        return os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1"

    def profile_sync() -> None:
        # Wait for queued CUDA work so timings cover execution, not just launch.
        # NOTE(review): syncs whenever CUDA is available, even if the model runs on CPU.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def profile_record(self, key: str, elapsed: float) -> None:
        # Accumulate call count and total seconds under self._vibepod_profile[key],
        # creating the stats dict lazily.
        stats = getattr(self, "_vibepod_profile", None)
        if stats is None:
            stats = {}
            self._vibepod_profile = stats
        bucket = stats.setdefault(key, {"count": 0, "seconds": 0.0})
        bucket["count"] += 1
        bucket["seconds"] += elapsed

    def timed_method(self, key: str, fn, *args, **kwargs):
        # Call fn(*args, **kwargs); when profiling is on, also record its wall time.
        if not profile_enabled():
            return fn(*args, **kwargs)
        profile_sync()
        started = time.perf_counter()
        result = fn(*args, **kwargs)
        profile_sync()
        profile_record(self, key, time.perf_counter() - started)
        return result

    def prepare_noise_scheduler(self):
        # Bring self.model.noise_scheduler to the schedule for the current
        # ddpm_inference_steps. set_timesteps() runs only on the first use of a
        # step count; later calls restore the cached num_inference_steps /
        # timesteps / sigmas instead.
        # NOTE(review): assumes set_timesteps has no other per-run side effects
        # beyond those restored/reset here — verify against the scheduler in use.
        scheduler = self.model.noise_scheduler
        cache_key = self.ddpm_inference_steps
        cache = getattr(self, "_vibepod_scheduler_cache", {})
        cached = cache.get(cache_key)
        if cached is None:
            scheduler.set_timesteps(self.ddpm_inference_steps)
            cached = {
                "num_inference_steps": scheduler.num_inference_steps,
                "timesteps": scheduler.timesteps,
                "sigmas": scheduler.sigmas,
            }
            cache[cache_key] = cached
            self._vibepod_scheduler_cache = cache
        else:
            scheduler.num_inference_steps = cached["num_inference_steps"]
            scheduler.timesteps = cached["timesteps"]
            scheduler.sigmas = cached["sigmas"]
        # Reset per-run multistep solver state so every sampling run starts fresh
        # (assumes a scheduler exposing config.solver_order, e.g. a multistep solver).
        scheduler.model_outputs = [None] * scheduler.config.solver_order
        scheduler.lower_order_nums = 0
        scheduler._step_index = None
        scheduler._begin_index = None
        return scheduler

    def sample_speech_tokens_optimized(self, condition, neg_condition, cfg_scale=3.0):
        # Classifier-free guidance: positive and negative conditions are stacked
        # along dim 0 and run through the prediction head in one batched call.
        scheduler = prepare_noise_scheduler(self)
        condition = torch.cat([condition, neg_condition], dim=0).to(
            self.model.prediction_head.device
        )
        batch_size = condition.shape[0] // 2
        # Fresh initial noise of shape (batch_size, acoustic_vae_dim), matching
        # condition's device and dtype.
        speech = torch.randn(batch_size, self.config.acoustic_vae_dim).to(condition)
        # Timestep tensors repeated to the stacked batch size are cached per
        # (steps, device, dtype, batch size); the length check rebuilds them if
        # the schedule length no longer matches.
        t_batch_cache_key = (
            self.ddpm_inference_steps,
            condition.device.type,
            condition.device.index,
            condition.dtype,
            batch_size,
        )
        t_batch_cache = getattr(self, "_vibepod_t_batch_cache", {})
        t_batches = t_batch_cache.get(t_batch_cache_key)
        if t_batches is None or len(t_batches) != len(scheduler.timesteps):
            t_batches = [
                t.repeat(condition.shape[0]).to(
                    device=condition.device, dtype=condition.dtype
                )
                for t in scheduler.timesteps
            ]
            t_batch_cache[t_batch_cache_key] = t_batches
            self._vibepod_t_batch_cache = t_batch_cache
        for t, t_batch in zip(scheduler.timesteps, t_batches):
            # Duplicate the current sample for the cond/uncond halves; for a single
            # sample expand() gives a view instead of the copy torch.cat makes.
            if batch_size == 1:
                combined = speech.expand(condition.shape[0], -1)
            else:
                combined = torch.cat([speech, speech], dim=0)
            if profile_enabled():
                profile_sync()
                started = time.perf_counter()
            eps = self.model.prediction_head(combined, t_batch, condition=condition)
            if profile_enabled():
                profile_sync()
                profile_record(
                    self, "diffusion_prediction_head", time.perf_counter() - started
                )
            # guided = uncond + cfg_scale * (cond - uncond)
            cond_eps, uncond_eps = torch.split(eps, batch_size, dim=0)
            guided_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
            # Scheduler-step timing is not CUDA-synchronised (no profile_sync here).
            if profile_enabled():
                started = time.perf_counter()
            speech = scheduler.step(guided_eps, t, speech).prev_sample
            if profile_enabled():
                profile_record(
                    self, "diffusion_scheduler_step", time.perf_counter() - started
                )
        return speech

    # Capture the original bound callables before overwriting them, so the
    # wrappers call the real implementations rather than themselves.
    forward_lm = model.forward_lm
    forward_tts_lm = model.forward_tts_lm
    acoustic_decode = model.model.acoustic_tokenizer.decode

    def forward_lm_profiled(*args, **kwargs):
        return timed_method(model, "forward_lm", forward_lm, *args, **kwargs)

    def forward_tts_lm_profiled(*args, **kwargs):
        return timed_method(model, "forward_tts_lm", forward_tts_lm, *args, **kwargs)

    def acoustic_decode_profiled(*args, **kwargs):
        return timed_method(model, "acoustic_decode", acoustic_decode, *args, **kwargs)

    # Wrappers are plain instance attributes (they close over `model`);
    # sample_speech_tokens is bound as a real method so it receives `self`.
    model.forward_lm = forward_lm_profiled
    model.forward_tts_lm = forward_tts_lm_profiled
    model.model.acoustic_tokenizer.decode = acoustic_decode_profiled
    model.sample_speech_tokens = types.MethodType(sample_speech_tokens_optimized, model)
    logger.info("Installed VibeVoice generation hot-path optimizations.")
def _install_cpu_pipeline_optimizations(model: object) -> None:
    """
    VibePod Optimization Patch: attach a single-worker decode executor to *model*.

    The JezzWTF/VibeVoice fork's generate() looks for two optional hooks:

    * ``model._vibepod_decode_executor`` — set here unless VIBEPOD_ASYNC_DECODE
      is something other than "1". A 1-worker ThreadPoolExecutor lets
      acoustic_decode (vocoder) overlap acoustic_connector + forward_tts_lm.
      Profiling showed this hides ~72s of decode cost behind tts_lm work,
      capturing ~96% of the theoretical overlap savings.
    * ``model._vibepod_cfg_executor`` — deliberately left unset. Running the
      positive/negative forward_tts_lm passes on a second thread made both
      compete for the same MKL OpenMP worker pool, so each call got slower
      (net ~6% regression on CPU). The hook stays in the fork for GPU or
      future use.

    Both hooks default to None, so the fork's generate() keeps its sequential
    behaviour on CUDA or any non-VibePod install. The executor is also stored
    in the module-level ``_decode_executor`` so the app lifespan can shut it down.
    """
    global _decode_executor

    if os.environ.get("VIBEPOD_ASYNC_DECODE", "1") == "1":
        pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=1, thread_name_prefix="vibepod-decode"
        )
        _decode_executor = pool
        model._vibepod_decode_executor = pool
        logger.info(
            "CPU pipeline: decode executor attached — acoustic_decode overlaps tts_lm. Disable with VIBEPOD_ASYNC_DECODE=0."
        )
    else:
        logger.info("CPU async decode disabled via VIBEPOD_ASYNC_DECODE=0.")
def _model_float_dtype() -> torch.dtype:
    """
    Return the dtype of the loaded model's first parameter.

    Falls back to float32 when the model has no parameters. Used as the cast
    target for floating-point tensors in cached voice prompts.
    """
    first_param = next(_model.parameters(), None)
    if first_param is None:
        return torch.float32
    return first_param.dtype
def _move_cached_prompt(value: object, device: str, dtype: torch.dtype) -> object:
if torch.is_tensor(value):
if torch.is_floating_point(value):
return value.to(device=device, dtype=dtype)
return value.to(device=device)
if isinstance(value, dict):
for k in list(value.keys()):
value[k] = _move_cached_prompt(value[k], device, dtype)
return value
if isinstance(value, list):
return [_move_cached_prompt(v, device, dtype) for v in value]
if isinstance(value, tuple):
return tuple(_move_cached_prompt(v, device, dtype) for v in value)
if hasattr(value, "key_cache") and hasattr(value, "value_cache"):
value.key_cache = [
_move_cached_prompt(t, device, dtype) for t in value.key_cache
]
value.value_cache = [
_move_cached_prompt(t, device, dtype) for t in value.value_cache
]
return value
def _load_voice_presets(device: str) -> dict[str, object]:
    """
    Load every voice preset from VOICES_DIR that exists on disk.

    Args:
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        Speaker name (EN_VOICES key) -> loaded preset, in EN_VOICES order.
        Presets whose file is missing are silently omitted.
    """
    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects. These
    # files are fetched from VOICE_BASE_URL; never point this at untrusted files.
    on_disk = ((name, VOICES_DIR / filename) for name, filename in EN_VOICES.items())
    return {
        name: torch.load(path, map_location=device, weights_only=False)
        for name, path in on_disk
        if path.exists()
    }
def _load_model_sync() -> None:
    """
    Initialise everything the server needs, blocking until done.

    Runs on the background loader thread. Steps:
      1. download the model if it is not cached, then the voice presets;
      2. resolve the device and fill ``_config`` with device-specific,
         env-overridable defaults;
      3. load the processor, model and voice presets.

    ``_model_status`` moves through "downloading" (only if a download is
    needed), "loading" and finally "online". Any exception sets "error", stores
    a generic message in ``_model_error`` and logs the traceback. It is a no-op
    if the model is already loaded.
    """
    global _processor, _model, _device, _model_status, _model_error, _voice_presets, _config

    with _load_lock:
        if _model is not None:
            return
        try:
            if not _is_model_cached():
                _model_status = "downloading"
                _download_model()
            _model_status = "loading"
            _download_voices()

            # Resolve device from env var (set by start.sh --cpu/--cuda) or auto-detect.
            _device = _resolve_device()
            logger.info("Using device: %s", _device)

            cpu_mode = _device == "cpu"
            # Device-specific defaults; each can be overridden by its env var.
            _config.update(
                device=_device,
                chunk_accum=_env_int("VIBEPOD_CHUNK_ACCUM", 4 if cpu_mode else 1),
                prebuffer_secs=_env_float(
                    "VIBEPOD_PREBUFFER_SECS", 24.0 if cpu_mode else 5.0
                ),
                rebuffer_threshold_secs=_env_float(
                    "VIBEPOD_REBUFFER_THRESHOLD_SECS", 2.0 if cpu_mode else 1.0
                ),
                resume_threshold_secs=_env_float(
                    "VIBEPOD_RESUME_THRESHOLD_SECS", 12.0 if cpu_mode else 3.0
                ),
                default_inference_steps=_env_int(
                    "VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if cpu_mode else 10
                ),
            )
            if cpu_mode:
                # Report the requested CPU tuning; these values are not clamped
                # to the >= 1 minimum the CPU runtime setup applies.
                n_logical = os.cpu_count() or 1
                thread_default = (
                    max(1, n_logical // 2)
                    if platform.system() == "Windows"
                    else n_logical
                )
                _config.update(
                    cpu_threads=_env_int("VIBEPOD_CPU_THREADS", thread_default),
                    cpu_interop_threads=_env_int("VIBEPOD_CPU_INTEROP_THREADS", 1),
                    cpu_mkldnn=(
                        os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"
                    ),
                )

            _processor = _init_processor()
            _model = _init_model(_device)
            _voice_presets = _load_voice_presets(_device)
            _model_status = "online"
            logger.info(
                "Model ready on %s. Voices: %s", _device, list(_voice_presets.keys())
            )
            logger.info("Configuration: %s", _config)
        except Exception as exc:
            _model_status = "error"
            # Generic message for clients; details only go to the server log.
            _model_error = "Internal server error during model initialization."
            logger.exception("Failed to initialise model: %s", exc)
# ── FastAPI app ─────────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """
    FastAPI lifespan hook.

    Startup: run _load_model_sync on a daemon thread, so the server can serve
    /health (download/load progress) while the model initialises.
    Shutdown: stop the CPU decode executor, if one was created, without waiting.
    """
    loader = threading.Thread(
        target=_load_model_sync, name="model-loader", daemon=True
    )
    loader.start()
    yield
    decode_pool = _decode_executor
    if decode_pool is not None:
        decode_pool.shutdown(wait=False)


app = FastAPI(title="VibePod TTS Server", version="0.1.0", lifespan=lifespan)
# ── Schemas ─────────────────────────────────────────────────────────────────────
class GenerateRequest(BaseModel):
    """
    JSON body accepted by ``POST /generate``.

    ``text`` is returned stripped and must not be blank. ``speaker`` is
    lower-cased and stripped; unknown names are not rejected here — generation
    falls back to DEFAULT_SPEAKER.
    """

    # Text to synthesise (1–10,000 characters).
    text: str = Field(..., min_length=1, max_length=10_000)
    # Voice preset name, e.g. "carter".
    speaker: str = Field(default=DEFAULT_SPEAKER)
    # Guidance scale forwarded to the model's generate() as cfg_scale.
    cfg_scale: float = Field(default=1.5, ge=0.5, le=4.0)
    # DDPM steps for this request; None means the server's default.
    inference_steps: int | None = Field(default=None, ge=5, le=20)

    @field_validator("text")
    @classmethod
    def text_not_blank(cls, v: str) -> str:
        """Strip surrounding whitespace and reject whitespace-only text."""
        stripped = v.strip()
        if not stripped:
            raise ValueError("text must not be blank")
        return stripped

    @field_validator("speaker")
    @classmethod
    def normalise_speaker(cls, v: str) -> str:
        """Lower-case and strip the speaker name so lookups are case-insensitive."""
        lowered = v.lower()
        return lowered.strip()
# ── Endpoints ───────────────────────────────────────────────────────────────────
@app.get("/health")
async def health() -> dict:
    """
    Report loader status for frontend polling.

    Always includes status, model id, device, available voices and the live
    config. Adds ``progress`` ({"done", "total"}) while downloading and
    ``message`` when an initialisation error was recorded.
    """
    payload: dict = dict(
        status=_model_status,
        model=MODEL_ID,
        device=_device,
        voices=list(_voice_presets),
        config=_config,
    )
    if _model_status == "downloading":
        payload["progress"] = {field: _dl_progress[field] for field in ("done", "total")}
    if _model_error:
        payload["message"] = _model_error
    return payload
def _sync_generate(
    req: GenerateRequest,
    streamer: object | None = None,
    cancel_event: threading.Event | None = None,
) -> str:
    """
    Run one blocking TTS generation, streaming audio into *streamer*.

    Blocks for the whole generation, so it must run in a worker thread, never
    on the event loop. It mutates shared model state (DDPM step count, profile
    stats), so callers should serialise calls, as generate() does with
    _generation_lock.

    A deep copy of the chosen voice preset is moved to the model's device and
    float dtype. It is then passed to the model as both the processor's cached
    prompt and ``all_prefilled_outputs``. The copy presumably guards the
    shared preset against mutation during generation (TODO confirm).

    Args:
        req: validated request; an unknown speaker falls back to DEFAULT_SPEAKER.
        streamer: passed to the model's generate() as ``audio_streamer``.
        cancel_event: checked before starting. Its ``is_set`` is also passed
            as ``stop_check_fn`` so generate() can stop early.

    Returns:
        The speaker name actually used.

    Raises:
        RuntimeError: if *cancel_event* is already set on entry.
    """
    if cancel_event and cancel_event.is_set():
        raise RuntimeError("Generation cancelled.")

    speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
    target_dtype = _model_float_dtype()
    prompt_copy = copy.deepcopy(_voice_presets[speaker])
    voice_preset = _move_cached_prompt(prompt_copy, _device, target_dtype)

    steps = req.inference_steps
    if steps is None:
        steps = _config["default_inference_steps"]
    _model.set_ddpm_inference_steps(num_steps=steps)

    if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1":
        # Start this request with empty profile buckets.
        _model._vibepod_profile = {}

    inputs = _processor.process_input_with_cached_prompt(
        text=req.text,
        cached_prompt=voice_preset,
        padding=True,
        return_tensors="pt",
        return_attention_mask=True,
    )
    for field_name, field_value in inputs.items():
        if torch.is_tensor(field_value):
            inputs[field_name] = field_value.to(_device)

    stop_check = cancel_event.is_set if cancel_event else None
    with torch.inference_mode():
        _model.generate(
            **inputs,
            max_new_tokens=None,
            cfg_scale=req.cfg_scale,
            tokenizer=_processor.tokenizer,
            generation_config={"do_sample": False},
            verbose=False,
            show_progress_bar=False,
            return_speech=False,
            stop_check_fn=stop_check,
            all_prefilled_outputs=voice_preset,
            audio_streamer=streamer,
        )
    return speaker
def _sse(event: dict) -> str:
    """Encode *event* as one Server-Sent Events ``data:`` frame (JSON payload)."""
    payload = json.dumps(event)
    return "data: " + payload + "\n\n"
def _generation_profile() -> dict[str, dict[str, float]] | None:
if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") != "1":
return None
stats = getattr(_model, "_vibepod_profile", None)
if not stats:
return {}
return {
key: {
"count": value["count"],
"seconds": round(value["seconds"], 3),
"avg_ms": (
round(value["seconds"] * 1000 / value["count"], 3)
if value["count"]
else 0.0
),
}
for key, value in sorted(stats.items())
}
@app.post("/generate")
async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
    """
    Stream synthesised speech for *req* as Server-Sent Events.

    Each ``data:`` frame is a JSON object whose ``type`` is one of:
      - ``audio_chunk``: ``data`` is base64 of float32 PCM samples at
        SAMPLE_RATE, in the host's native byte order;
      - ``complete``: timing/throughput stats, sent once after the last chunk;
      - ``error``: generation failed or produced no audio for 120 s.

    Only one generation runs at a time. ``_generation_lock`` stays held until
    the worker thread has actually finished, even if the client disconnects
    or the stream times out.

    Raises:
        HTTPException: 503 if the model is not online or a generation is running.
    """
    if _model_status != "online":
        detail = {
            "downloading": "Model is downloading — please wait.",
            "loading": "Model is loading into memory — please wait.",
            "error": f"Model failed to load: {_model_error or 'unknown error'}",
        }.get(_model_status, "Server not ready.")
        raise HTTPException(status_code=503, detail=detail)
    if _generation_lock.locked():
        raise HTTPException(
            status_code=503, detail="Server is already generating audio. Please wait."
        )

    async def event_stream() -> AsyncGenerator[str, None]:
        class NonBlockingAudioStreamer:
            """
            Audio sink passed to the model's generate() as ``audio_streamer``.

            put()/end() may be called from the worker thread. They only schedule
            queue operations on the event loop via call_soon_threadsafe. Chunks
            are queued still on their original device, so the GPU->CPU copy
            happens in the consumer, not the model thread.
            """

            def __init__(self, batch_size: int, stop_signal: object = None) -> None:
                self.batch_size = batch_size
                self.stop_signal = stop_signal
                self.audio_queues = [asyncio.Queue() for _ in range(batch_size)]
                self.finished_flags = [False for _ in range(batch_size)]
                self.loop = asyncio.get_running_loop()

            def put(
                self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor
            ) -> None:
                """Queue ``audio_chunks[i]`` for each unfinished sample index."""
                for i, sample_idx in enumerate(sample_indices):
                    idx = sample_idx.item()
                    if idx < self.batch_size and not self.finished_flags[idx]:
                        self.loop.call_soon_threadsafe(
                            self.audio_queues[idx].put_nowait,
                            audio_chunks[i].detach(),
                        )

            def end(self, sample_indices: torch.Tensor | None = None) -> None:
                """Send stop_signal once to the given (default: all) sample queues."""
                if sample_indices is None:
                    indices_to_end = range(self.batch_size)
                else:
                    indices_to_end = [
                        s.item() if torch.is_tensor(s) else s for s in sample_indices
                    ]
                for idx in indices_to_end:
                    if idx < self.batch_size and not self.finished_flags[idx]:
                        self.loop.call_soon_threadsafe(
                            self.audio_queues[idx].put_nowait, self.stop_signal
                        )
                        self.finished_flags[idx] = True

        start = time.monotonic()
        streamer = NonBlockingAudioStreamer(batch_size=1)
        cancel_event = threading.Event()
        accum_size = max(1, _config["chunk_accum"])
        accumulated_chunks: list[torch.Tensor] = []
        chunk_count = 0
        audio_samples = 0
        first_chunk_at: float | None = None
        last_chunk_at: float | None = None
        max_chunk_gap = 0.0
        speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER

        # The lock must stay held for as long as the worker thread runs, not just
        # while this generator is alive. After a timeout or disconnect the worker
        # only stops at its next stop_check_fn poll (a running executor job cannot
        # be cancelled). Releasing earlier would let the next request drive the
        # same model concurrently. So the lock is released from the worker future's
        # done-callback.
        await _generation_lock.acquire()
        loop = asyncio.get_running_loop()
        try:
            future = loop.run_in_executor(
                None, functools.partial(_sync_generate, req, streamer, cancel_event)
            )
        except BaseException:
            _generation_lock.release()
            raise

        def on_worker_done(fut: asyncio.Future) -> None:
            # Unblock the consumer even if generate() never called streamer.end().
            streamer.end()
            if not fut.cancelled():
                # Mark any exception as retrieved: when the stream exits early
                # nobody awaits the future, and asyncio would log a warning.
                fut.exception()
            _generation_lock.release()

        future.add_done_callback(on_worker_done)

        try:
            # Drain audio chunks as they arrive from the diffusion head.
            while True:
                try:
                    chunk = await asyncio.wait_for(
                        streamer.audio_queues[0].get(), timeout=120.0
                    )
                except asyncio.TimeoutError:
                    cancel_event.set()
                    yield _sse({"type": "error", "message": "Generation timed out"})
                    return
                if await request.is_disconnected():
                    cancel_event.set()
                    logger.info("Generation client disconnected; stream cancelled.")
                    return
                # stop_signal=None is the default sentinel that ends the queue.
                finished = chunk is None
                if not finished:
                    accumulated_chunks.append(chunk.detach())
                # Emit once accum_size chunks are buffered; on the stop signal,
                # flush whatever remains.
                if accumulated_chunks and (
                    finished or len(accumulated_chunks) >= accum_size
                ):
                    now = time.monotonic()
                    if first_chunk_at is None:
                        first_chunk_at = now
                    if last_chunk_at is not None:
                        max_chunk_gap = max(max_chunk_gap, now - last_chunk_at)
                    last_chunk_at = now
                    combined = (
                        torch.cat(accumulated_chunks, dim=0)
                        .detach()
                        .to("cpu", dtype=torch.float32)
                        .contiguous()
                    )
                    chunk_count += 1
                    audio_samples += combined.numel()
                    pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
                    yield _sse({"type": "audio_chunk", "data": pcm_b64})
                    accumulated_chunks = []
                if finished:
                    break

            try:
                speaker = await future
            except asyncio.CancelledError:
                # Only the task itself being cancelled (e.g. client gone) lands
                # here; cancellation must propagate, not be swallowed.
                logger.info("Generation cancelled.")
                raise
            except Exception as exc:
                logger.exception("Generation failed: %s", exc)
                yield _sse(
                    {
                        "type": "error",
                        "message": f"Generation failed: {exc}",
                    }
                )
                return
        finally:
            # On any early exit (timeout, disconnect, task cancellation, error)
            # tell the worker to stop at its next poll. Harmless once it finished.
            cancel_event.set()

        elapsed_secs = time.monotonic() - start
        elapsed = round(elapsed_secs, 1)
        audio_secs = audio_samples / SAMPLE_RATE
        # Use the unrounded duration so short runs still get an accurate RTF.
        realtime_factor = audio_secs / elapsed_secs if elapsed_secs > 0 else None
        profile = _generation_profile()
        if profile is not None:
            logger.info("Generation profile: %s", profile)
        logger.info("Generation complete in %.1fs", elapsed)
        complete_event = {
            "type": "complete",
            "elapsed": elapsed,
            "speaker": speaker,
            "audio_secs": round(audio_secs, 2),
            "realtime_factor": (
                round(realtime_factor, 3) if realtime_factor is not None else None
            ),
            "chunks": chunk_count,
            "first_chunk_secs": (
                round(first_chunk_at - start, 2) if first_chunk_at is not None else None
            ),
            "max_chunk_gap_secs": round(max_chunk_gap, 2),
        }
        if profile is not None:
            complete_event["profile"] = profile
        yield _sse(complete_event)

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache, no-transform",
            "X-Accel-Buffering": "no",
            "X-Content-Type-Options": "nosniff",
        },
    )