diff --git a/server/vibevoice_server.py b/server/vibevoice_server.py index 0c32e6c..b5f7880 100644 --- a/server/vibevoice_server.py +++ b/server/vibevoice_server.py @@ -1,21 +1,38 @@ """ VibePod — VibeVoice FastAPI TTS Server -Startup sequence (background thread): - 1. Download model weights if not cached -> status: downloading - 2. Download voice preset .pt files -> status: loading - 3. Load processor + model into memory -> status: loading - 4. Pre-load all voice tensors -> status: loading - -> Server ready -> status: online +This server provides a high-performance Text-to-Speech (TTS) interface for the VibeVoice model, +optimized for real-time streaming on both CPU and NVIDIA GPU hardware. -Generation flow: - POST /generate -> SSE stream of audio_chunk events (base64 float32 PCM), - ends with {type:"complete"} +MAINTAINER GUIDE / FILE MAP: +- Device & Env Configuration: Helpers for hardware detection and runtime tuning via env vars. +- Global State: Thread-safe storage for the loaded model, processor, and server status. +- Background Model Loader: Logic for downloading weights and initializing the model with optimizations. +- VibeVoice Patches: Performance-critical overrides for VibeVoice internals (hot-paths). +- FastAPI Application: SSE-based generation endpoint and health/status polling. +- Audio Streaming: Async bridge (NonBlockingAudioStreamer) between inference and the network. -Device selection: - Set VIBEPOD_DEVICE=cpu to force CPU inference (e.g. via --cpu flag in start.sh). - Set VIBEPOD_DEVICE=cuda to force CUDA (default when a GPU is available). - If unset, the server auto-detects: CUDA if available, otherwise CPU. +RUNTIME CONFIGURATION (Environment Variables): +- VIBEPOD_DEVICE: 'cpu' or 'cuda' (auto-detected if unset). +- VIBEPOD_CHUNK_ACCUM: Number of 20ms audio chunks to buffer before sending an SSE event (default: 4 for CPU). +- VIBEPOD_PREBUFFER_SECS: Initial client-side buffer duration (hinted to frontend). 
+- VIBEPOD_REBUFFER_THRESHOLD_SECS: Buffer level below which the client pauses to refill. +- VIBEPOD_RESUME_THRESHOLD_SECS: Buffer level at which the client resumes playback. +- VIBEPOD_DEFAULT_INFERENCE_STEPS: Default DDPM steps (default: 8 for CPU, 10 for CUDA). +- VIBEPOD_PROFILE_GENERATION: Set to '1' to enable detailed performance logging. + +CPU-SPECIFIC OPTIMIZATIONS: +- VIBEPOD_CPU_THREADS: Number of intra-op threads (default: half the logical core count on Windows, all logical cores elsewhere). +- VIBEPOD_CPU_INTEROP_THREADS: Number of inter-op threads (default: 1). +- VIBEPOD_CPU_MKLDNN: Set to '0' to disable MKLDNN (default: 1). +- VIBEPOD_CPU_BF16: Set to '1' to force bfloat16, '0' for float32. +- VIBEPOD_ASYNC_DECODE: Set to '1' to overlap decoding with inference on a separate thread (default: 1). +- VIBEPOD_QUANTIZE: Set to '1' to enable experimental dynamic INT8 quantization. +- VIBEPOD_COMPILE: Set to '1' to enable experimental torch.compile (limited benefit for TTS). + +CUDA-SPECIFIC OPTIMIZATIONS: +- VIBEPOD_CUDA_DTYPE: 'bf16' (default) or 'fp16'. +- VIBEPOD_ATTN_IMPL: 'auto', 'sdpa', 'eager', or 'flash_attention_2'. """ import asyncio @@ -50,9 +67,7 @@ MODEL_ID = "microsoft/VibeVoice-Realtime-0.5B" SAMPLE_RATE = 24_000 VOICES_DIR = Path(__file__).parent / "voices" / "streaming_model" -VOICE_BASE_URL = ( - "https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model" -) +VOICE_BASE_URL = "https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model" EN_VOICES: dict[str, str] = { "carter": "en-Carter_man.pt", @@ -77,7 +92,10 @@ _decode_executor: concurrent.futures.ThreadPoolExecutor | None = None def _resolve_device() -> str: - """Resolve the target device from env var or auto-detect.""" + """ + Resolve the target device (CPU or CUDA) by checking the VIBEPOD_DEVICE environment + variable, falling back to CUDA if available, otherwise CPU. 
+ """ env = os.environ.get("VIBEPOD_DEVICE", "").strip().lower() if env in ("cpu", "cuda"): if env == "cuda" and not torch.cuda.is_available(): @@ -94,6 +112,7 @@ def _resolve_device() -> str: def _env_int(name: str, default: int) -> int: + """Helper to read an integer environment variable with a fallback default.""" raw = os.environ.get(name, "").strip() if not raw: return default @@ -105,6 +124,7 @@ def _env_int(name: str, default: int) -> int: def _env_float(name: str, default: float) -> float: + """Helper to read a float environment variable with a fallback default.""" raw = os.environ.get(name, "").strip() if not raw: return default @@ -125,8 +145,14 @@ def _cpu_supports_bf16() -> bool: def _configure_cpu_runtime() -> dict[str, object]: + """ + Configure PyTorch's CPU execution engine, including thread counts and + MKLDNN acceleration. + """ logical_cpus = os.cpu_count() or 1 - default_threads = max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus + default_threads = ( + max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus + ) intra_threads = _env_int("VIBEPOD_CPU_THREADS", default_threads) interop_threads = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1) mkldnn_enabled = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0" @@ -202,7 +228,9 @@ def _is_model_cached() -> bool: try: from huggingface_hub import snapshot_download - snapshot_download(MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS) + snapshot_download( + MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS + ) return True except Exception: return False @@ -211,7 +239,9 @@ def _is_model_cached() -> bool: def _download_model() -> None: from huggingface_hub import snapshot_download - token: str | None = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN") + token: str | None = os.environ.get("HF_TOKEN") or os.environ.get( + "HUGGINGFACE_TOKEN" + ) DlTqdm = _make_dl_tqdm() logger.info("Model not cached — 
downloading %s...", MODEL_ID) snapshot_download( @@ -238,6 +268,9 @@ def _download_voices() -> None: def _init_processor(): + """ + Initialize the VibeVoiceStreamingProcessor from the model repository. + """ logger.info("Loading processor...") from vibevoice.processor.vibevoice_streaming_processor import ( VibeVoiceStreamingProcessor, @@ -247,6 +280,10 @@ def _init_processor(): def _init_model(device: str): + """ + Load the VibeVoice model with appropriate precision (BF16/FP16/FP32) and + apply VibePod-specific performance optimizations. + """ logger.info("Loading model on %s...", device) if device == "cuda": torch.set_float32_matmul_precision("high") @@ -285,7 +322,9 @@ def _init_model(device: str): logger.info("AVX512_BF16 detected — loading model in bfloat16") else: load_dtype = torch.float32 - logger.info("No AVX512_BF16 — using float32 (set VIBEPOD_CPU_BF16=1 to override)") + logger.info( + "No AVX512_BF16 — using float32 (set VIBEPOD_CPU_BF16=1 to override)" + ) logger.info("Loading model weights with dtype %s", load_dtype) requested_attn_impl = os.environ.get("VIBEPOD_ATTN_IMPL", "auto").lower() has_flash_attn = importlib.util.find_spec("flash_attn") is not None @@ -294,7 +333,9 @@ def _init_model(device: str): elif requested_attn_impl == "flash_attention_2": attn_impl = "flash_attention_2" if has_flash_attn else "sdpa" else: - attn_impl = "flash_attention_2" if device == "cuda" and has_flash_attn else "sdpa" + attn_impl = ( + "flash_attention_2" if device == "cuda" and has_flash_attn else "sdpa" + ) logger.info("Using Transformers attention implementation: %s", attn_impl) if device == "cuda" and not has_flash_attn: logger.info("flash_attn is not installed; using PyTorch SDPA attention.") @@ -338,7 +379,10 @@ def _init_model(device: str): def _apply_cpu_optimizations(model: object) -> object: - """Apply optional post-load CPU optimizations. 
Returns (possibly new) model object.""" + """ + Apply experimental CPU performance features like dynamic INT8 quantization + or torch.compile if enabled via environment variables. + """ do_quantize = os.environ.get("VIBEPOD_QUANTIZE", "0") == "1" do_compile = os.environ.get("VIBEPOD_COMPILE", "0") == "1" @@ -387,10 +431,19 @@ def _apply_cpu_optimizations(model: object) -> object: if hasattr(model, "model"): inner = model.model if hasattr(inner, "prediction_head"): - _compile_targets.append(("prediction_head", inner, "prediction_head", False)) - if hasattr(inner, "acoustic_tokenizer") and hasattr(inner.acoustic_tokenizer, "decode"): _compile_targets.append( - ("acoustic_tokenizer.decode", inner.acoustic_tokenizer, "decode", False) + ("prediction_head", inner, "prediction_head", False) + ) + if hasattr(inner, "acoustic_tokenizer") and hasattr( + inner.acoustic_tokenizer, "decode" + ): + _compile_targets.append( + ( + "acoustic_tokenizer.decode", + inner.acoustic_tokenizer, + "decode", + False, + ) ) for label, obj, attr, dynamic in _compile_targets: @@ -404,13 +457,20 @@ def _apply_cpu_optimizations(model: object) -> object: setattr(obj, attr, compiled) logger.info(" compiled: %s", label) except Exception as exc: - logger.warning(" torch.compile failed for %s: %s — skipping", label, exc) + logger.warning( + " torch.compile failed for %s: %s — skipping", label, exc + ) return model def _install_generation_optimizations(model: object) -> None: - """Patch VibeVoice hot paths without changing model quality settings.""" + """ + VibePod Optimization Patch: + Replaces performance-critical VibeVoice methods with optimized versions. + Includes caching for the noise scheduler and noise tensors to avoid + re-allocation overhead during the diffusion loop. 
+ """ def profile_enabled() -> bool: return os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1" @@ -483,7 +543,9 @@ def _install_generation_optimizations(model: object) -> None: t_batches = t_batch_cache.get(t_batch_cache_key) if t_batches is None or len(t_batches) != len(scheduler.timesteps): t_batches = [ - t.repeat(condition.shape[0]).to(device=condition.device, dtype=condition.dtype) + t.repeat(condition.shape[0]).to( + device=condition.device, dtype=condition.dtype + ) for t in scheduler.timesteps ] t_batch_cache[t_batch_cache_key] = t_batches @@ -500,14 +562,18 @@ def _install_generation_optimizations(model: object) -> None: eps = self.model.prediction_head(combined, t_batch, condition=condition) if profile_enabled(): profile_sync() - profile_record(self, "diffusion_prediction_head", time.perf_counter() - started) + profile_record( + self, "diffusion_prediction_head", time.perf_counter() - started + ) cond_eps, uncond_eps = torch.split(eps, batch_size, dim=0) guided_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) if profile_enabled(): started = time.perf_counter() speech = scheduler.step(guided_eps, t, speech).prev_sample if profile_enabled(): - profile_record(self, "diffusion_scheduler_step", time.perf_counter() - started) + profile_record( + self, "diffusion_scheduler_step", time.perf_counter() - started + ) return speech @@ -532,7 +598,13 @@ def _install_generation_optimizations(model: object) -> None: def _install_cpu_pipeline_optimizations(model: object) -> None: - """Attach the decode executor to the model for the optimised generate() loop. + """ + VibePod Optimization Patch: + Enables asynchronous audio decoding on a background thread. + + This allows the acoustic_decode (Vocoder) step to run in parallel with + the next chunk's forward_tts_lm (Inference) step, significantly reducing + the real-time factor on CPU. 
The JezzWTF/VibeVoice fork's generate() checks for two optional attributes: @@ -587,12 +659,19 @@ def _move_cached_prompt(value: object, device: str, dtype: torch.dtype) -> objec if isinstance(value, tuple): return tuple(_move_cached_prompt(v, device, dtype) for v in value) if hasattr(value, "key_cache") and hasattr(value, "value_cache"): - value.key_cache = [_move_cached_prompt(t, device, dtype) for t in value.key_cache] - value.value_cache = [_move_cached_prompt(t, device, dtype) for t in value.value_cache] + value.key_cache = [ + _move_cached_prompt(t, device, dtype) for t in value.key_cache + ] + value.value_cache = [ + _move_cached_prompt(t, device, dtype) for t in value.value_cache + ] return value def _load_voice_presets(device: str) -> dict[str, object]: + """ + Load all pre-downloaded voice tensor files (.pt) from the voices directory. + """ presets = {} for name, filename in EN_VOICES.items(): path = VOICES_DIR / filename @@ -602,6 +681,11 @@ def _load_voice_presets(device: str) -> dict[str, object]: def _load_model_sync() -> None: + """ + Main synchronous initialization routine. Handles model/voice downloads, + device configuration, and model loading. Updates global status for the + health endpoint. 
+ """ global _processor, _model, _device, _model_status, _model_error, _voice_presets, _config with _load_lock: @@ -640,17 +724,27 @@ def _load_model_sync() -> None: logical_cpus = os.cpu_count() or 1 _config["cpu_threads"] = _env_int( "VIBEPOD_CPU_THREADS", - max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus, + ( + max(1, logical_cpus // 2) + if platform.system() == "Windows" + else logical_cpus + ), + ) + _config["cpu_interop_threads"] = _env_int( + "VIBEPOD_CPU_INTEROP_THREADS", 1 + ) + _config["cpu_mkldnn"] = ( + os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0" ) - _config["cpu_interop_threads"] = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1) - _config["cpu_mkldnn"] = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0" _processor = _init_processor() _model = _init_model(_device) _voice_presets = _load_voice_presets(_device) _model_status = "online" - logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys())) + logger.info( + "Model ready on %s. Voices: %s", _device, list(_voice_presets.keys()) + ) logger.info("Configuration: %s", _config) except Exception as exc: @@ -723,16 +817,21 @@ def _sync_generate( streamer: object | None = None, cancel_event: threading.Event | None = None, ) -> str: - """Blocking inference. Returns the speaker used. - Runs in a thread-pool executor — do not call from the event loop directly. - Pass an AsyncAudioStreamer to receive audio chunks in real time. + """ + Performs blocking model inference for TTS generation. + + This function should always be run in a thread-pool executor to avoid + blocking the FastAPI event loop. It streams audio chunks back to the + caller via the provided streamer object. 
""" if cancel_event and cancel_event.is_set(): raise RuntimeError("Generation cancelled.") speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER model_dtype = _model_float_dtype() - voice_preset = _move_cached_prompt(copy.deepcopy(_voice_presets[speaker]), _device, model_dtype) + voice_preset = _move_cached_prompt( + copy.deepcopy(_voice_presets[speaker]), _device, model_dtype + ) steps = ( req.inference_steps @@ -786,7 +885,11 @@ def _generation_profile() -> dict[str, dict[str, float]] | None: key: { "count": value["count"], "seconds": round(value["seconds"], 3), - "avg_ms": round(value["seconds"] * 1000 / value["count"], 3) if value["count"] else 0.0, + "avg_ms": ( + round(value["seconds"] * 1000 / value["count"], 3) + if value["count"] + else 0.0 + ), } for key, value in sorted(stats.items()) } @@ -818,7 +921,9 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: self.finished_flags = [False for _ in range(batch_size)] self.loop = asyncio.get_running_loop() - def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor) -> None: + def put( + self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor + ) -> None: for i, sample_idx in enumerate(sample_indices): idx = sample_idx.item() if idx < self.batch_size and not self.finished_flags[idx]: @@ -831,7 +936,9 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: if sample_indices is None: indices_to_end = range(self.batch_size) else: - indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices] + indices_to_end = [ + s.item() if torch.is_tensor(s) else s for s in sample_indices + ] for idx in indices_to_end: if idx < self.batch_size and not self.finished_flags[idx]: self.loop.call_soon_threadsafe( @@ -863,7 +970,9 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: # stop_signal=None is the default sentinel that ends the queue. 
while True: try: - chunk = await asyncio.wait_for(streamer.audio_queues[0].get(), timeout=120.0) + chunk = await asyncio.wait_for( + streamer.audio_queues[0].get(), timeout=120.0 + ) except asyncio.TimeoutError: cancel_event.set() future.cancel() @@ -949,11 +1058,13 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: "elapsed": elapsed, "speaker": speaker, "audio_secs": round(audio_secs, 2), - "realtime_factor": round(realtime_factor, 3) if realtime_factor is not None else None, + "realtime_factor": ( + round(realtime_factor, 3) if realtime_factor is not None else None + ), "chunks": chunk_count, - "first_chunk_secs": round(first_chunk_at - start, 2) - if first_chunk_at is not None - else None, + "first_chunk_secs": ( + round(first_chunk_at - start, 2) if first_chunk_at is not None else None + ), "max_chunk_gap_secs": round(max_chunk_gap, 2), } if profile is not None: diff --git a/web/app/api/generate/route.ts b/web/app/api/generate/route.ts index 310bb01..deb34ed 100644 --- a/web/app/api/generate/route.ts +++ b/web/app/api/generate/route.ts @@ -1,3 +1,14 @@ +/** + * API Proxy Route: POST /api/generate + * + * This route proxies requests from the frontend to the FastAPI backend's /generate endpoint. + * + * Security Architecture: + * The FastAPI backend is configured to bind only to localhost (127.0.0.1). This prevents + * unauthenticated public access to the model inference engine. Next.js acts as a secure + * proxy, allowing the frontend to interact with the backend while maintaining a + * single public-facing origin. 
+ */ import { NextRequest, NextResponse } from "next/server"; export const dynamic = "force-dynamic"; diff --git a/web/app/api/health/route.ts b/web/app/api/health/route.ts index e4d3506..3c24435 100644 --- a/web/app/api/health/route.ts +++ b/web/app/api/health/route.ts @@ -1,3 +1,14 @@ +/** + * API Proxy Route: GET /api/health + * + * This route proxies health check requests from the frontend to the FastAPI backend's /health endpoint. + * + * Security Architecture: + * The FastAPI backend is configured to bind only to localhost (127.0.0.1). This prevents + * unauthenticated public access to the server status and configuration. Next.js acts as a secure + * proxy, allowing the frontend to poll for server readiness and adaptive configuration + * while maintaining a single public-facing origin. + */ import { NextResponse } from "next/server"; const OFFLINE_RESPONSE = { status: "offline" }; diff --git a/web/app/page.tsx b/web/app/page.tsx index 128824a..710251e 100644 --- a/web/app/page.tsx +++ b/web/app/page.tsx @@ -24,6 +24,8 @@ export interface ServerConfig { default_inference_steps: number; } +// --- State Management --- + interface AppState { script: string; speaker: string; @@ -199,6 +201,8 @@ export default function HomePage() { resumeThresholdSecs: state.resumeThresholdSecs, }); + // --- Server Health & Status Polling --- + // Server health polling — fast while not ready, slow when online useEffect(() => { let timeoutId: ReturnType<typeof setTimeout>; @@ -246,6 +250,8 @@ export default function HomePage() { }; }, []); + // --- Generation Handling --- + const handleGenerate = useCallback(async () => { if (!state.script.trim() || state.isGenerating) return; addLog(`${wordCount} words queued`); diff --git a/web/hooks/useStreamingGeneration.ts b/web/hooks/useStreamingGeneration.ts index f5970d8..d6f9fdd 100644 --- a/web/hooks/useStreamingGeneration.ts +++ b/web/hooks/useStreamingGeneration.ts @@ -1,5 +1,16 @@ "use client"; +/** + * Hook for managing real-time streaming audio 
generation from the VibeVoice server. + * + * Streaming Lifecycle: + * 1. fetch /api/generate: Initiates a POST request to the generation endpoint. + * 2. parse SSE chunks: Listens for Server-Sent Events (SSE) containing audio data or status updates. + * 3. decode base64 float32 PCM: Converts incoming base64-encoded strings into raw Float32 audio samples. + * 4. schedule Web Audio playback: Enqueues audio chunks into an AudioContext for low-latency playback. + * 5. handle adaptive buffering: Monitors playback progress and pauses to refill the buffer if an underrun is detected. + * 6. assemble final WAV Blob: Combines all received chunks into a single WAV file once generation is complete. + */ import { useCallback, useEffect, useRef, useState } from "react"; const SAMPLE_RATE = 24_000; @@ -30,6 +41,9 @@ interface UseStreamingGenerationOptions { resumeThresholdSecs?: number; } +/** + * Concatenates multiple Float32Array chunks into a single Float32Array. + */ function mergeFloat32Arrays(chunks: Float32Array[]): Float32Array { const total = chunks.reduce((sum, chunk) => sum + chunk.length, 0); const out = new Float32Array(total); @@ -41,6 +55,9 @@ function mergeFloat32Arrays(chunks: Float32Array[]): Float32Array, sampleRate: number): Blob { const dataSize = samples.length * 4; const buffer = new ArrayBuffer(44 + dataSize); @@ -68,6 +85,9 @@ function buildWav(samples: Float32Array, sampleRate: number): Blob return new Blob([buffer], { type: "audio/wav" }); } +/** + * Decodes a base64-encoded string into a Float32Array of PCM samples. + */ function decodeFloat32Chunk(data: string): Float32Array { const raw = atob(data); const bytes = new Uint8Array(raw.length); @@ -141,6 +161,9 @@ export function useStreamingGeneration({ }; }, [resetPlayback, revokeCurrentUrl]); + /** + * Creates an AudioBuffer from a chunk and schedules it for playback in the AudioContext. 
+ */ const enqueue = useCallback((ctx: AudioContext, chunk: Float32Array) => { const audioBuffer = ctx.createBuffer(1, chunk.length, SAMPLE_RATE); audioBuffer.copyToChannel(chunk, 0); @@ -152,6 +175,9 @@ export function useStreamingGeneration({ nextStartTimeRef.current = startAt + audioBuffer.duration; }, []); + /** + * Resets the playback timing and enqueues all currently buffered chunks for immediate playback. + */ const flushBufferedAudio = useCallback(() => { const ctx = audioCtxRef.current; if (!ctx || chunksRef.current.length === 0) return; @@ -162,6 +188,10 @@ export function useStreamingGeneration({ hasStartedPlaybackRef.current = true; }, [enqueue]); + /** + * Processes a new audio chunk, either buffering it for initial playback or enqueuing it for + * immediate playback with adaptive buffering logic. + */ const handleAudioChunk = useCallback( (chunk: Float32Array) => { const ctx = audioCtxRef.current;