Improve code documentation and maintainer notes

- Add a top-level doc comment to useStreamingGeneration.ts and document the streaming lifecycle.
- Add docstrings to helper functions in useStreamingGeneration.ts.
- Add section comments to web/app/page.tsx around reducer state, server health polling, and generation handling.
- Add file-level comments to API proxy routes explaining the security architecture.
- Add a file map / maintainer guide comment to server/vibevoice_server.py.
- Add docstrings for key internal helpers in server/vibevoice_server.py.
- Document environment variables used by the server in server/vibevoice_server.py.
- Add comments identifying VibePod-specific patches around VibeVoice internals.
- Format server/vibevoice_server.py with black.

Co-authored-by: LyAhn <27559362+LyAhn@users.noreply.github.com>
This commit is contained in:
google-labs-jules[bot]
2026-05-02 16:44:38 +00:00
parent 0236807928
commit e64048e500
5 changed files with 219 additions and 50 deletions
+161 -50
View File
@@ -1,21 +1,38 @@
"""
VibePod — VibeVoice FastAPI TTS Server
Startup sequence (background thread):
1. Download model weights if not cached -> status: downloading
2. Download voice preset .pt files -> status: loading
3. Load processor + model into memory -> status: loading
4. Pre-load all voice tensors -> status: loading
-> Server ready -> status: online
This server provides a high-performance Text-to-Speech (TTS) interface for the VibeVoice model,
optimized for real-time streaming on both CPU and NVIDIA GPU hardware.
Generation flow:
POST /generate -> SSE stream of audio_chunk events (base64 float32 PCM),
ends with {type:"complete"}
MAINTAINER GUIDE / FILE MAP:
- Device & Env Configuration: Helpers for hardware detection and runtime tuning via env vars.
- Global State: Thread-safe storage for the loaded model, processor, and server status.
- Background Model Loader: Logic for downloading weights and initializing the model with optimizations.
- VibeVoice Patches: Performance-critical overrides for VibeVoice internals (hot-paths).
- FastAPI Application: SSE-based generation endpoint and health/status polling.
- Audio Streaming: Async bridge (NonBlockingAudioStreamer) between inference and the network.
Device selection:
Set VIBEPOD_DEVICE=cpu to force CPU inference (e.g. via --cpu flag in start.sh).
Set VIBEPOD_DEVICE=cuda to force CUDA (default when a GPU is available).
If unset, the server auto-detects: CUDA if available, otherwise CPU.
RUNTIME CONFIGURATION (Environment Variables):
- VIBEPOD_DEVICE: 'cpu' or 'cuda' (auto-detected if unset).
- VIBEPOD_CHUNK_ACCUM: Number of 20ms audio chunks to buffer before sending an SSE event (default: 4 for CPU).
- VIBEPOD_PREBUFFER_SECS: Initial client-side buffer duration (hinted to frontend).
- VIBEPOD_REBUFFER_THRESHOLD_SECS: Buffer level below which the client pauses to refill.
- VIBEPOD_RESUME_THRESHOLD_SECS: Buffer level at which the client resumes playback.
- VIBEPOD_DEFAULT_INFERENCE_STEPS: Default DDPM steps (default: 8 for CPU, 10 for CUDA).
- VIBEPOD_PROFILE_GENERATION: Set to '1' to enable detailed performance logging.
CPU-SPECIFIC OPTIMIZATIONS:
- VIBEPOD_CPU_THREADS: Number of intra-op threads (defaults to logical core count / 2).
- VIBEPOD_CPU_INTEROP_THREADS: Number of inter-op threads (default: 1).
- VIBEPOD_CPU_MKLDNN: Set to '0' to disable MKLDNN (default: 1).
- VIBEPOD_CPU_BF16: Set to '1' to force bfloat16, '0' for float32.
- VIBEPOD_ASYNC_DECODE: Set to '1' to overlap decoding with inference on a separate thread (default: 1).
- VIBEPOD_QUANTIZE: Set to '1' to enable experimental dynamic INT8 quantization.
- VIBEPOD_COMPILE: Set to '1' to enable experimental torch.compile (limited benefit for TTS).
CUDA-SPECIFIC OPTIMIZATIONS:
- VIBEPOD_CUDA_DTYPE: 'bf16' (default) or 'fp16'.
- VIBEPOD_ATTN_IMPL: 'auto', 'sdpa', 'eager', or 'flash_attention_2'.
"""
import asyncio
@@ -50,9 +67,7 @@ MODEL_ID = "microsoft/VibeVoice-Realtime-0.5B"
SAMPLE_RATE = 24_000
VOICES_DIR = Path(__file__).parent / "voices" / "streaming_model"
VOICE_BASE_URL = (
"https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model"
)
VOICE_BASE_URL = "https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model"
EN_VOICES: dict[str, str] = {
"carter": "en-Carter_man.pt",
@@ -77,7 +92,10 @@ _decode_executor: concurrent.futures.ThreadPoolExecutor | None = None
def _resolve_device() -> str:
"""Resolve the target device from env var or auto-detect."""
"""
Resolve the target device (CPU or CUDA) by checking the VIBEPOD_DEVICE environment
variable, falling back to CUDA if available, otherwise CPU.
"""
env = os.environ.get("VIBEPOD_DEVICE", "").strip().lower()
if env in ("cpu", "cuda"):
if env == "cuda" and not torch.cuda.is_available():
@@ -94,6 +112,7 @@ def _resolve_device() -> str:
def _env_int(name: str, default: int) -> int:
"""Helper to read an integer environment variable with a fallback default."""
raw = os.environ.get(name, "").strip()
if not raw:
return default
@@ -105,6 +124,7 @@ def _env_int(name: str, default: int) -> int:
def _env_float(name: str, default: float) -> float:
"""Helper to read a float environment variable with a fallback default."""
raw = os.environ.get(name, "").strip()
if not raw:
return default
@@ -125,8 +145,14 @@ def _cpu_supports_bf16() -> bool:
def _configure_cpu_runtime() -> dict[str, object]:
"""
Configure PyTorch's CPU execution engine, including thread counts and
MKLDNN acceleration.
"""
logical_cpus = os.cpu_count() or 1
default_threads = max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus
default_threads = (
max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus
)
intra_threads = _env_int("VIBEPOD_CPU_THREADS", default_threads)
interop_threads = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1)
mkldnn_enabled = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"
@@ -202,7 +228,9 @@ def _is_model_cached() -> bool:
try:
from huggingface_hub import snapshot_download
snapshot_download(MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS)
snapshot_download(
MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS
)
return True
except Exception:
return False
@@ -211,7 +239,9 @@ def _is_model_cached() -> bool:
def _download_model() -> None:
from huggingface_hub import snapshot_download
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get(
"HUGGINGFACE_TOKEN"
)
DlTqdm = _make_dl_tqdm()
logger.info("Model not cached — downloading %s...", MODEL_ID)
snapshot_download(
@@ -238,6 +268,9 @@ def _download_voices() -> None:
def _init_processor():
"""
Initialize the VibeVoiceStreamingProcessor from the model repository.
"""
logger.info("Loading processor...")
from vibevoice.processor.vibevoice_streaming_processor import (
VibeVoiceStreamingProcessor,
@@ -247,6 +280,10 @@ def _init_processor():
def _init_model(device: str):
"""
Load the VibeVoice model with appropriate precision (BF16/FP16/FP32) and
apply VibePod-specific performance optimizations.
"""
logger.info("Loading model on %s...", device)
if device == "cuda":
torch.set_float32_matmul_precision("high")
@@ -285,7 +322,9 @@ def _init_model(device: str):
logger.info("AVX512_BF16 detected — loading model in bfloat16")
else:
load_dtype = torch.float32
logger.info("No AVX512_BF16 — using float32 (set VIBEPOD_CPU_BF16=1 to override)")
logger.info(
"No AVX512_BF16 — using float32 (set VIBEPOD_CPU_BF16=1 to override)"
)
logger.info("Loading model weights with dtype %s", load_dtype)
requested_attn_impl = os.environ.get("VIBEPOD_ATTN_IMPL", "auto").lower()
has_flash_attn = importlib.util.find_spec("flash_attn") is not None
@@ -294,7 +333,9 @@ def _init_model(device: str):
elif requested_attn_impl == "flash_attention_2":
attn_impl = "flash_attention_2" if has_flash_attn else "sdpa"
else:
attn_impl = "flash_attention_2" if device == "cuda" and has_flash_attn else "sdpa"
attn_impl = (
"flash_attention_2" if device == "cuda" and has_flash_attn else "sdpa"
)
logger.info("Using Transformers attention implementation: %s", attn_impl)
if device == "cuda" and not has_flash_attn:
logger.info("flash_attn is not installed; using PyTorch SDPA attention.")
@@ -338,7 +379,10 @@ def _init_model(device: str):
def _apply_cpu_optimizations(model: object) -> object:
"""Apply optional post-load CPU optimizations. Returns (possibly new) model object."""
"""
Apply experimental CPU performance features like dynamic INT8 quantization
or torch.compile if enabled via environment variables.
"""
do_quantize = os.environ.get("VIBEPOD_QUANTIZE", "0") == "1"
do_compile = os.environ.get("VIBEPOD_COMPILE", "0") == "1"
@@ -387,10 +431,19 @@ def _apply_cpu_optimizations(model: object) -> object:
if hasattr(model, "model"):
inner = model.model
if hasattr(inner, "prediction_head"):
_compile_targets.append(("prediction_head", inner, "prediction_head", False))
if hasattr(inner, "acoustic_tokenizer") and hasattr(inner.acoustic_tokenizer, "decode"):
_compile_targets.append(
("acoustic_tokenizer.decode", inner.acoustic_tokenizer, "decode", False)
("prediction_head", inner, "prediction_head", False)
)
if hasattr(inner, "acoustic_tokenizer") and hasattr(
inner.acoustic_tokenizer, "decode"
):
_compile_targets.append(
(
"acoustic_tokenizer.decode",
inner.acoustic_tokenizer,
"decode",
False,
)
)
for label, obj, attr, dynamic in _compile_targets:
@@ -404,13 +457,20 @@ def _apply_cpu_optimizations(model: object) -> object:
setattr(obj, attr, compiled)
logger.info(" compiled: %s", label)
except Exception as exc:
logger.warning(" torch.compile failed for %s: %s — skipping", label, exc)
logger.warning(
" torch.compile failed for %s: %s — skipping", label, exc
)
return model
def _install_generation_optimizations(model: object) -> None:
"""Patch VibeVoice hot paths without changing model quality settings."""
"""
VibePod Optimization Patch:
Replaces performance-critical VibeVoice methods with optimized versions.
Includes caching for the noise scheduler and noise tensors to avoid
re-allocation overhead during the diffusion loop.
"""
def profile_enabled() -> bool:
return os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1"
@@ -483,7 +543,9 @@ def _install_generation_optimizations(model: object) -> None:
t_batches = t_batch_cache.get(t_batch_cache_key)
if t_batches is None or len(t_batches) != len(scheduler.timesteps):
t_batches = [
t.repeat(condition.shape[0]).to(device=condition.device, dtype=condition.dtype)
t.repeat(condition.shape[0]).to(
device=condition.device, dtype=condition.dtype
)
for t in scheduler.timesteps
]
t_batch_cache[t_batch_cache_key] = t_batches
@@ -500,14 +562,18 @@ def _install_generation_optimizations(model: object) -> None:
eps = self.model.prediction_head(combined, t_batch, condition=condition)
if profile_enabled():
profile_sync()
profile_record(self, "diffusion_prediction_head", time.perf_counter() - started)
profile_record(
self, "diffusion_prediction_head", time.perf_counter() - started
)
cond_eps, uncond_eps = torch.split(eps, batch_size, dim=0)
guided_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
if profile_enabled():
started = time.perf_counter()
speech = scheduler.step(guided_eps, t, speech).prev_sample
if profile_enabled():
profile_record(self, "diffusion_scheduler_step", time.perf_counter() - started)
profile_record(
self, "diffusion_scheduler_step", time.perf_counter() - started
)
return speech
@@ -532,7 +598,13 @@ def _install_generation_optimizations(model: object) -> None:
def _install_cpu_pipeline_optimizations(model: object) -> None:
"""Attach the decode executor to the model for the optimised generate() loop.
"""
VibePod Optimization Patch:
Enables asynchronous audio decoding on a background thread.
This allows the acoustic_decode (Vocoder) step to run in parallel with
the next chunk's forward_tts_lm (Inference) step, significantly reducing
the real-time factor on CPU.
The JezzWTF/VibeVoice fork's generate() checks for two optional attributes:
@@ -587,12 +659,19 @@ def _move_cached_prompt(value: object, device: str, dtype: torch.dtype) -> objec
if isinstance(value, tuple):
return tuple(_move_cached_prompt(v, device, dtype) for v in value)
if hasattr(value, "key_cache") and hasattr(value, "value_cache"):
value.key_cache = [_move_cached_prompt(t, device, dtype) for t in value.key_cache]
value.value_cache = [_move_cached_prompt(t, device, dtype) for t in value.value_cache]
value.key_cache = [
_move_cached_prompt(t, device, dtype) for t in value.key_cache
]
value.value_cache = [
_move_cached_prompt(t, device, dtype) for t in value.value_cache
]
return value
def _load_voice_presets(device: str) -> dict[str, object]:
"""
Load all pre-downloaded voice tensor files (.pt) from the voices directory.
"""
presets = {}
for name, filename in EN_VOICES.items():
path = VOICES_DIR / filename
@@ -602,6 +681,11 @@ def _load_voice_presets(device: str) -> dict[str, object]:
def _load_model_sync() -> None:
"""
Main synchronous initialization routine. Handles model/voice downloads,
device configuration, and model loading. Updates global status for the
health endpoint.
"""
global _processor, _model, _device, _model_status, _model_error, _voice_presets, _config
with _load_lock:
@@ -640,17 +724,27 @@ def _load_model_sync() -> None:
logical_cpus = os.cpu_count() or 1
_config["cpu_threads"] = _env_int(
"VIBEPOD_CPU_THREADS",
max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus,
(
max(1, logical_cpus // 2)
if platform.system() == "Windows"
else logical_cpus
),
)
_config["cpu_interop_threads"] = _env_int(
"VIBEPOD_CPU_INTEROP_THREADS", 1
)
_config["cpu_mkldnn"] = (
os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"
)
_config["cpu_interop_threads"] = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1)
_config["cpu_mkldnn"] = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"
_processor = _init_processor()
_model = _init_model(_device)
_voice_presets = _load_voice_presets(_device)
_model_status = "online"
logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys()))
logger.info(
"Model ready on %s. Voices: %s", _device, list(_voice_presets.keys())
)
logger.info("Configuration: %s", _config)
except Exception as exc:
@@ -723,16 +817,21 @@ def _sync_generate(
streamer: object | None = None,
cancel_event: threading.Event | None = None,
) -> str:
"""Blocking inference. Returns the speaker used.
Runs in a thread-pool executor — do not call from the event loop directly.
Pass an AsyncAudioStreamer to receive audio chunks in real time.
"""
Performs blocking model inference for TTS generation.
This function should always be run in a thread-pool executor to avoid
blocking the FastAPI event loop. It streams audio chunks back to the
caller via the provided streamer object.
"""
if cancel_event and cancel_event.is_set():
raise RuntimeError("Generation cancelled.")
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
model_dtype = _model_float_dtype()
voice_preset = _move_cached_prompt(copy.deepcopy(_voice_presets[speaker]), _device, model_dtype)
voice_preset = _move_cached_prompt(
copy.deepcopy(_voice_presets[speaker]), _device, model_dtype
)
steps = (
req.inference_steps
@@ -786,7 +885,11 @@ def _generation_profile() -> dict[str, dict[str, float]] | None:
key: {
"count": value["count"],
"seconds": round(value["seconds"], 3),
"avg_ms": round(value["seconds"] * 1000 / value["count"], 3) if value["count"] else 0.0,
"avg_ms": (
round(value["seconds"] * 1000 / value["count"], 3)
if value["count"]
else 0.0
),
}
for key, value in sorted(stats.items())
}
@@ -818,7 +921,9 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
self.finished_flags = [False for _ in range(batch_size)]
self.loop = asyncio.get_running_loop()
def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor) -> None:
def put(
self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor
) -> None:
for i, sample_idx in enumerate(sample_indices):
idx = sample_idx.item()
if idx < self.batch_size and not self.finished_flags[idx]:
@@ -831,7 +936,9 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
if sample_indices is None:
indices_to_end = range(self.batch_size)
else:
indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices]
indices_to_end = [
s.item() if torch.is_tensor(s) else s for s in sample_indices
]
for idx in indices_to_end:
if idx < self.batch_size and not self.finished_flags[idx]:
self.loop.call_soon_threadsafe(
@@ -863,7 +970,9 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
# stop_signal=None is the default sentinel that ends the queue.
while True:
try:
chunk = await asyncio.wait_for(streamer.audio_queues[0].get(), timeout=120.0)
chunk = await asyncio.wait_for(
streamer.audio_queues[0].get(), timeout=120.0
)
except asyncio.TimeoutError:
cancel_event.set()
future.cancel()
@@ -949,11 +1058,13 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
"elapsed": elapsed,
"speaker": speaker,
"audio_secs": round(audio_secs, 2),
"realtime_factor": round(realtime_factor, 3) if realtime_factor is not None else None,
"realtime_factor": (
round(realtime_factor, 3) if realtime_factor is not None else None
),
"chunks": chunk_count,
"first_chunk_secs": round(first_chunk_at - start, 2)
if first_chunk_at is not None
else None,
"first_chunk_secs": (
round(first_chunk_at - start, 2) if first_chunk_at is not None else None
),
"max_chunk_gap_secs": round(max_chunk_gap, 2),
}
if profile is not None: