From 75b84b211bb9e0c1ec31295ae866d3110f9220a9 Mon Sep 17 00:00:00 2001 From: LyAhn Date: Thu, 30 Apr 2026 18:54:14 +0100 Subject: [PATCH 01/11] perf: improve streaming generation pipeline Add CUDA inference hot-path optimizations, safer attention fallback handling, and generation profiling hooks. Improve SSE streaming, browser buffering telemetry, and playback recovery while preserving default audio quality settings. --- .gitattributes | 2 + server/.python-version | 1 + server/start.sh | 48 ++++ server/vibevoice_server.py | 365 +++++++++++++++++++++++++--- web/app/api/generate/route.ts | 7 +- web/app/api/health/route.ts | 1 + web/app/page.tsx | 6 +- web/hooks/useStreamingGeneration.ts | 75 ++++-- web/package.json | 2 +- 9 files changed, 459 insertions(+), 48 deletions(-) create mode 100644 .gitattributes create mode 100644 server/.python-version diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..1bbd695 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +*.sh text eol=lf +*.py text eol=lf diff --git a/server/.python-version b/server/.python-version new file mode 100644 index 0000000..e4fba21 --- /dev/null +++ b/server/.python-version @@ -0,0 +1 @@ +3.12 diff --git a/server/start.sh b/server/start.sh index befa203..2daa340 100755 --- a/server/start.sh +++ b/server/start.sh @@ -7,6 +7,11 @@ # ./start.sh — CUDA mode (default, uses PyTorch CUDA 12.4 wheel, venv: .venv) # ./start.sh --cpu — CPU-only mode (uses PyPI CPU torch wheel, venv: .venv-cpu) # +# Optional CUDA acceleration: +# VIBEPOD_ENABLE_FLASH_ATTN=1 ./start.sh +# Installs a matching third-party Windows flash-attn wheel when the CUDA venv +# uses Python 3.12, torch 2.6.0, and CUDA 12.4. +# # The two modes maintain completely separate virtual environments so their torch # installations never conflict. UV_PROJECT_ENVIRONMENT tells uv which venv to use; # --no-sources skips [tool.uv.sources] so the CPU run pulls the default PyPI torch wheel. @@ -51,6 +56,19 @@ if ! 
command -v uv &>/dev/null; then exit 1 fi +validate_flash_attn() { + uv run python -c "import flash_attn; import triton; import transformers.modeling_utils" &>/dev/null +} + +remove_broken_flash_attn() { + if uv run python -c "import importlib.util; raise SystemExit(0 if importlib.util.find_spec('flash_attn') else 1)" &>/dev/null; then + if ! validate_flash_attn; then + echo " Installed flash-attn is not usable in this environment; removing it." + uv pip uninstall flash-attn + fi + fi +} + # --------------------------------------------------------------------------- # 2. Sync Python environment # CPU mode: use .venv-cpu and skip [tool.uv.sources] so uv pulls the @@ -65,6 +83,36 @@ if $CPU_MODE; then else echo "--> Syncing CUDA Python environment (.venv)..." uv sync + + remove_broken_flash_attn + + if [[ "${VIBEPOD_ENABLE_FLASH_ATTN:-0}" == "1" ]]; then + echo "" + echo "--> Checking optional FlashAttention wheel..." + + if validate_flash_attn; then + echo " flash-attn already installed and importable." + else + PY_TAG="$(uv run python -c "import sys; print(f'cp{sys.version_info.major}{sys.version_info.minor}')")" + TORCH_TAG="$(uv run python -c "import torch; print(torch.__version__.split('+', 1)[0])")" + CUDA_TAG="$(uv run python -c "import torch; print('cu' + torch.version.cuda.replace('.', ''))")" + + if [[ "$PY_TAG" == "cp312" && "$TORCH_TAG" == "2.6.0" && "$CUDA_TAG" == "cu124" ]]; then + FLASH_ATTN_WHEEL_URL="https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4%2Bcu124torch2.6.0cxx11abiFALSE-cp312-cp312-win_amd64.whl" + echo " Installing flash-attn for Python 3.12, torch 2.6.0, CUDA 12.4..." + uv pip install "$FLASH_ATTN_WHEEL_URL" + if validate_flash_attn; then + echo " flash-attn import check passed." + else + echo " flash-attn import check failed; removing it and continuing with SDPA." + uv pip uninstall flash-attn + fi + else + echo " No known wheel for Python tag $PY_TAG, torch $TORCH_TAG, CUDA $CUDA_TAG." 
+ echo " Continuing with PyTorch SDPA attention." + fi + fi + fi fi # --------------------------------------------------------------------------- diff --git a/server/vibevoice_server.py b/server/vibevoice_server.py index f5bc012..39541f5 100644 --- a/server/vibevoice_server.py +++ b/server/vibevoice_server.py @@ -22,11 +22,13 @@ import asyncio import base64 import copy import functools +import importlib.util import json import logging import os import threading import time +import types import urllib.request from contextlib import asynccontextmanager from pathlib import Path @@ -211,8 +213,39 @@ def _init_processor(): def _init_model(device: str): logger.info("Loading model on %s...", device) - load_dtype = torch.bfloat16 if device == "cuda" else torch.float32 - attn_impl = "flash_attention_2" if device == "cuda" else "sdpa" + if device == "cuda": + torch.set_float32_matmul_precision("high") + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cuda.enable_flash_sdp(True) + torch.backends.cuda.enable_mem_efficient_sdp(True) + torch.backends.cuda.enable_math_sdp(True) + torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True) + logger.info( + "PyTorch SDPA backends: flash=%s, mem_efficient=%s, math=%s", + torch.backends.cuda.flash_sdp_enabled(), + torch.backends.cuda.mem_efficient_sdp_enabled(), + torch.backends.cuda.math_sdp_enabled(), + ) + + cuda_dtype = os.environ.get("VIBEPOD_CUDA_DTYPE", "bf16").lower() + if device == "cuda" and cuda_dtype == "fp16": + load_dtype = torch.float16 + else: + load_dtype = torch.bfloat16 if device == "cuda" else torch.float32 + logger.info("Loading model weights with dtype %s", load_dtype) + requested_attn_impl = os.environ.get("VIBEPOD_ATTN_IMPL", "auto").lower() + has_flash_attn = importlib.util.find_spec("flash_attn") is not None + if requested_attn_impl in {"eager", "sdpa"}: + attn_impl = requested_attn_impl + elif requested_attn_impl 
== "flash_attention_2": + attn_impl = "flash_attention_2" if has_flash_attn else "sdpa" + else: + attn_impl = "flash_attention_2" if device == "cuda" and has_flash_attn else "sdpa" + logger.info("Using Transformers attention implementation: %s", attn_impl) + if device == "cuda" and not has_flash_attn: + logger.info("flash_attn is not installed; using PyTorch SDPA attention.") from vibevoice.modular.modeling_vibevoice_streaming_inference import ( VibeVoiceStreamingForConditionalGenerationInference, @@ -225,9 +258,13 @@ def _init_model(device: str): device_map=device, attn_implementation=attn_impl, ) - except Exception: + except Exception as exc: + if attn_impl == "sdpa": + raise logger.warning( - "Model load with %s failed; falling back to sdpa", attn_impl, exc_info=True + "Model load with %s failed (%s); falling back to sdpa", + attn_impl, + exc, ) model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( MODEL_ID, @@ -238,9 +275,164 @@ def _init_model(device: str): model.eval() model.set_ddpm_inference_steps(num_steps=_config["default_inference_steps"]) + _install_generation_optimizations(model) return model +def _install_generation_optimizations(model: object) -> None: + """Patch VibeVoice hot paths without changing model quality settings.""" + + def profile_enabled() -> bool: + return os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1" + + def profile_sync() -> None: + if torch.cuda.is_available(): + torch.cuda.synchronize() + + def profile_record(self, key: str, elapsed: float) -> None: + stats = getattr(self, "_vibepod_profile", None) + if stats is None: + stats = {} + self._vibepod_profile = stats + bucket = stats.setdefault(key, {"count": 0, "seconds": 0.0}) + bucket["count"] += 1 + bucket["seconds"] += elapsed + + def timed_method(self, key: str, fn, *args, **kwargs): + if not profile_enabled(): + return fn(*args, **kwargs) + profile_sync() + started = time.perf_counter() + result = fn(*args, **kwargs) + profile_sync() + 
profile_record(self, key, time.perf_counter() - started) + return result + + def prepare_noise_scheduler(self): + scheduler = self.model.noise_scheduler + cache_key = self.ddpm_inference_steps + cache = getattr(self, "_vibepod_scheduler_cache", {}) + cached = cache.get(cache_key) + + if cached is None: + scheduler.set_timesteps(self.ddpm_inference_steps) + cached = { + "num_inference_steps": scheduler.num_inference_steps, + "timesteps": scheduler.timesteps, + "sigmas": scheduler.sigmas, + } + cache[cache_key] = cached + self._vibepod_scheduler_cache = cache + else: + scheduler.num_inference_steps = cached["num_inference_steps"] + scheduler.timesteps = cached["timesteps"] + scheduler.sigmas = cached["sigmas"] + scheduler.model_outputs = [None] * scheduler.config.solver_order + scheduler.lower_order_nums = 0 + scheduler._step_index = None + scheduler._begin_index = None + + return scheduler + + def sample_speech_tokens_optimized(self, condition, neg_condition, cfg_scale=3.0): + scheduler = prepare_noise_scheduler(self) + + condition = torch.cat([condition, neg_condition], dim=0).to( + self.model.prediction_head.device + ) + batch_size = condition.shape[0] // 2 + speech = torch.randn(batch_size, self.config.acoustic_vae_dim).to(condition) + t_batch_cache_key = ( + self.ddpm_inference_steps, + condition.device.type, + condition.device.index, + condition.dtype, + batch_size, + ) + t_batch_cache = getattr(self, "_vibepod_t_batch_cache", {}) + t_batches = t_batch_cache.get(t_batch_cache_key) + if t_batches is None or len(t_batches) != len(scheduler.timesteps): + t_batches = [ + t.repeat(condition.shape[0]).to( + device=condition.device, dtype=condition.dtype + ) + for t in scheduler.timesteps + ] + t_batch_cache[t_batch_cache_key] = t_batches + self._vibepod_t_batch_cache = t_batch_cache + + for t, t_batch in zip(scheduler.timesteps, t_batches): + if batch_size == 1: + combined = speech.expand(condition.shape[0], -1) + else: + combined = torch.cat([speech, speech], dim=0) 
+ if profile_enabled(): + profile_sync() + started = time.perf_counter() + eps = self.model.prediction_head(combined, t_batch, condition=condition) + if profile_enabled(): + profile_sync() + profile_record(self, "diffusion_prediction_head", time.perf_counter() - started) + cond_eps, uncond_eps = torch.split(eps, batch_size, dim=0) + guided_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + if profile_enabled(): + started = time.perf_counter() + speech = scheduler.step(guided_eps, t, speech).prev_sample + if profile_enabled(): + profile_record(self, "diffusion_scheduler_step", time.perf_counter() - started) + + return speech + + forward_lm = model.forward_lm + forward_tts_lm = model.forward_tts_lm + acoustic_decode = model.model.acoustic_tokenizer.decode + + def forward_lm_profiled(*args, **kwargs): + return timed_method(model, "forward_lm", forward_lm, *args, **kwargs) + + def forward_tts_lm_profiled(*args, **kwargs): + return timed_method(model, "forward_tts_lm", forward_tts_lm, *args, **kwargs) + + def acoustic_decode_profiled(*args, **kwargs): + return timed_method(model, "acoustic_decode", acoustic_decode, *args, **kwargs) + + model.forward_lm = forward_lm_profiled + model.forward_tts_lm = forward_tts_lm_profiled + model.model.acoustic_tokenizer.decode = acoustic_decode_profiled + model.sample_speech_tokens = types.MethodType(sample_speech_tokens_optimized, model) + logger.info("Installed VibeVoice generation hot-path optimizations.") + + +def _model_float_dtype() -> torch.dtype: + try: + return next(_model.parameters()).dtype + except StopIteration: + return torch.float32 + + +def _move_cached_prompt(value: object, device: str, dtype: torch.dtype) -> object: + if torch.is_tensor(value): + if torch.is_floating_point(value): + return value.to(device=device, dtype=dtype) + return value.to(device=device) + if isinstance(value, dict): + for k in list(value.keys()): + value[k] = _move_cached_prompt(value[k], device, dtype) + return value + if isinstance(value, 
list): + return [_move_cached_prompt(v, device, dtype) for v in value] + if isinstance(value, tuple): + return tuple(_move_cached_prompt(v, device, dtype) for v in value) + if hasattr(value, "key_cache") and hasattr(value, "value_cache"): + value.key_cache = [ + _move_cached_prompt(t, device, dtype) for t in value.key_cache + ] + value.value_cache = [ + _move_cached_prompt(t, device, dtype) for t in value.value_cache + ] + return value + + def _load_voice_presets(device: str) -> dict[str, object]: presets = {} for name, filename in EN_VOICES.items(): @@ -273,9 +465,9 @@ def _load_model_sync() -> None: is_cpu = _device == "cpu" _config["device"] = _device _config["chunk_accum"] = _env_int("VIBEPOD_CHUNK_ACCUM", 4 if is_cpu else 1) - _config["prebuffer_secs"] = _env_float("VIBEPOD_PREBUFFER_SECS", 5.0 if is_cpu else 2.0) - _config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 1.0 if is_cpu else 0.4) - _config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 2.5 if is_cpu else 1.5) + _config["prebuffer_secs"] = _env_float("VIBEPOD_PREBUFFER_SECS", 6.0 if is_cpu else 5.0) + _config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 1.5 if is_cpu else 1.0) + _config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 4.0 if is_cpu else 3.0) _config["default_inference_steps"] = _env_int("VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10) _processor = _init_processor() @@ -364,10 +556,15 @@ def _sync_generate( raise RuntimeError("Generation cancelled.") speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER - voice_preset = copy.deepcopy(_voice_presets[speaker]) + model_dtype = _model_float_dtype() + voice_preset = _move_cached_prompt( + copy.deepcopy(_voice_presets[speaker]), _device, model_dtype + ) steps = req.inference_steps if req.inference_steps is not None else _config["default_inference_steps"] _model.set_ddpm_inference_steps(num_steps=steps) + if 
os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1": + _model._vibepod_profile = {} inputs = _processor.process_input_with_cached_prompt( text=req.text, @@ -380,19 +577,20 @@ def _sync_generate( if torch.is_tensor(v): inputs[k] = v.to(_device) - outputs = _model.generate( - **inputs, - max_new_tokens=None, - cfg_scale=req.cfg_scale, - tokenizer=_processor.tokenizer, - generation_config={"do_sample": False}, - verbose=True, - all_prefilled_outputs=copy.deepcopy(voice_preset), - audio_streamer=streamer, - ) - - if not outputs.speech_outputs or outputs.speech_outputs[0] is None: - raise ValueError("Model returned no audio output.") + with torch.inference_mode(): + _model.generate( + **inputs, + max_new_tokens=None, + cfg_scale=req.cfg_scale, + tokenizer=_processor.tokenizer, + generation_config={"do_sample": False}, + verbose=False, + show_progress_bar=False, + return_speech=False, + stop_check_fn=cancel_event.is_set if cancel_event else None, + all_prefilled_outputs=voice_preset, + audio_streamer=streamer, + ) return speaker @@ -401,6 +599,24 @@ def _sse(event: dict) -> str: return f"data: {json.dumps(event)}\n\n" +def _generation_profile() -> Optional[dict[str, dict[str, float]]]: + if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") != "1": + return None + stats = getattr(_model, "_vibepod_profile", None) + if not stats: + return {} + return { + key: { + "count": value["count"], + "seconds": round(value["seconds"], 3), + "avg_ms": round(value["seconds"] * 1000 / value["count"], 3) + if value["count"] + else 0.0, + } + for key, value in sorted(stats.items()) + } + + @app.post("/generate") async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: if _model_status != "online": @@ -417,20 +633,58 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: ) async def event_stream() -> AsyncGenerator[str, None]: - from vibevoice.modular.streamer import AsyncAudioStreamer + class NonBlockingAudioStreamer: + """Async 
streamer that keeps GPU->CPU copies out of the model thread.""" + + def __init__(self, batch_size: int, stop_signal: object = None) -> None: + self.batch_size = batch_size + self.stop_signal = stop_signal + self.audio_queues = [asyncio.Queue() for _ in range(batch_size)] + self.finished_flags = [False for _ in range(batch_size)] + self.loop = asyncio.get_running_loop() + + def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor) -> None: + for i, sample_idx in enumerate(sample_indices): + idx = sample_idx.item() + if idx < self.batch_size and not self.finished_flags[idx]: + self.loop.call_soon_threadsafe( + self.audio_queues[idx].put_nowait, + audio_chunks[i].detach(), + ) + + def end(self, sample_indices: Optional[torch.Tensor] = None) -> None: + if sample_indices is None: + indices_to_end = range(self.batch_size) + else: + indices_to_end = [ + s.item() if torch.is_tensor(s) else s for s in sample_indices + ] + for idx in indices_to_end: + if idx < self.batch_size and not self.finished_flags[idx]: + self.loop.call_soon_threadsafe( + self.audio_queues[idx].put_nowait, self.stop_signal + ) + self.finished_flags[idx] = True start = time.monotonic() - streamer = AsyncAudioStreamer(batch_size=1) + streamer = NonBlockingAudioStreamer(batch_size=1) cancel_event = threading.Event() accum_size = max(1, _config["chunk_accum"]) accumulated_chunks = [] + chunk_count = 0 + audio_samples = 0 + first_chunk_at: Optional[float] = None + last_chunk_at: Optional[float] = None + max_chunk_gap = 0.0 + speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER async with _generation_lock: loop = asyncio.get_event_loop() future = loop.run_in_executor( None, functools.partial(_sync_generate, req, streamer, cancel_event) ) + future.add_done_callback(lambda _: streamer.end()) # Drain audio chunks as they arrive from the diffusion head. # stop_signal=None is the default sentinel that ends the queue. 
@@ -454,17 +708,45 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: if chunk is None: # stop signal break - accumulated_chunks.append(chunk.detach().cpu().float()) + accumulated_chunks.append(chunk.detach()) if len(accumulated_chunks) >= accum_size: - combined = torch.cat(accumulated_chunks, dim=0) + now = time.monotonic() + if first_chunk_at is None: + first_chunk_at = now + if last_chunk_at is not None: + max_chunk_gap = max(max_chunk_gap, now - last_chunk_at) + last_chunk_at = now + + combined = ( + torch.cat(accumulated_chunks, dim=0) + .detach() + .to("cpu", dtype=torch.float32) + .contiguous() + ) + chunk_count += 1 + audio_samples += combined.numel() pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode() yield _sse({"type": "audio_chunk", "data": pcm_b64}) accumulated_chunks = [] # Flush any remaining chunks if accumulated_chunks: - combined = torch.cat(accumulated_chunks, dim=0) + now = time.monotonic() + if first_chunk_at is None: + first_chunk_at = now + if last_chunk_at is not None: + max_chunk_gap = max(max_chunk_gap, now - last_chunk_at) + last_chunk_at = now + + combined = ( + torch.cat(accumulated_chunks, dim=0) + .detach() + .to("cpu", dtype=torch.float32) + .contiguous() + ) + chunk_count += 1 + audio_samples += combined.numel() pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode() yield _sse({"type": "audio_chunk", "data": pcm_b64}) @@ -479,17 +761,42 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: yield _sse( { "type": "error", - "message": "Internal server error during generation.", + "message": f"Generation failed: {exc}", } ) return elapsed = round(time.monotonic() - start, 1) + audio_secs = audio_samples / SAMPLE_RATE + realtime_factor = audio_secs / elapsed if elapsed > 0 else None + profile = _generation_profile() + if profile is not None: + logger.info("Generation profile: %s", profile) logger.info("Generation complete in %.1fs", elapsed) - yield 
_sse({"type": "complete", "elapsed": elapsed, "speaker": speaker}) + complete_event = { + "type": "complete", + "elapsed": elapsed, + "speaker": speaker, + "audio_secs": round(audio_secs, 2), + "realtime_factor": round(realtime_factor, 3) + if realtime_factor is not None + else None, + "chunks": chunk_count, + "first_chunk_secs": round(first_chunk_at - start, 2) + if first_chunk_at is not None + else None, + "max_chunk_gap_secs": round(max_chunk_gap, 2), + } + if profile is not None: + complete_event["profile"] = profile + yield _sse(complete_event) return StreamingResponse( event_stream(), media_type="text/event-stream", - headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, + headers={ + "Cache-Control": "no-cache, no-transform", + "X-Accel-Buffering": "no", + "X-Content-Type-Options": "nosniff", + }, ) diff --git a/web/app/api/generate/route.ts b/web/app/api/generate/route.ts index 8bd1b94..180c659 100644 --- a/web/app/api/generate/route.ts +++ b/web/app/api/generate/route.ts @@ -1,5 +1,8 @@ import { NextRequest, NextResponse } from "next/server"; +export const dynamic = "force-dynamic"; +export const runtime = "nodejs"; + export async function POST(request: NextRequest) { const pythonServerUrl = process.env.VIBEVOICE_SERVER_URL ?? "http://localhost:8000"; @@ -24,6 +27,7 @@ export async function POST(request: NextRequest) { cfg_scale: body.cfg_scale ?? 1.5, inference_steps: body.inference_steps ?? 
10, }), + signal: request.signal, }); if (!upstream.ok) { @@ -36,8 +40,9 @@ export async function POST(request: NextRequest) { status: 200, headers: { "Content-Type": "text/event-stream", - "Cache-Control": "no-cache", + "Cache-Control": "no-cache, no-transform", "Connection": "keep-alive", + "X-Content-Type-Options": "nosniff", "X-Accel-Buffering": "no", }, }); diff --git a/web/app/api/health/route.ts b/web/app/api/health/route.ts index dffb3f8..ba17edb 100644 --- a/web/app/api/health/route.ts +++ b/web/app/api/health/route.ts @@ -27,6 +27,7 @@ export async function GET() { message: data.message, progress: data.progress ?? null, voices: data.voices ?? [], + config: data.config ?? null, }, COMMON_OPTIONS ); diff --git a/web/app/page.tsx b/web/app/page.tsx index 275d658..a9f6317 100644 --- a/web/app/page.tsx +++ b/web/app/page.tsx @@ -130,9 +130,9 @@ const initialState: AppState = { speaker: "carter", cfgScale: 1.5, inferenceSteps: 10, - prebufferSecs: 2.0, - rebufferThresholdSecs: 0.4, - resumeThresholdSecs: 1.5, + prebufferSecs: 5.0, + rebufferThresholdSecs: 1.0, + resumeThresholdSecs: 3.0, isGenerating: false, genElapsed: 0, genPct: null, diff --git a/web/hooks/useStreamingGeneration.ts b/web/hooks/useStreamingGeneration.ts index 192cb27..257f0c8 100644 --- a/web/hooks/useStreamingGeneration.ts +++ b/web/hooks/useStreamingGeneration.ts @@ -3,9 +3,10 @@ import { useCallback, useEffect, useRef, useState } from "react"; const SAMPLE_RATE = 24_000; -const DEFAULT_PREBUFFER_SECS = 2.0; -const DEFAULT_REBUFFER_THRESHOLD_SECS = 0.4; -const DEFAULT_RESUME_THRESHOLD_SECS = 1.5; +const DEFAULT_PREBUFFER_SECS = 5.0; +const DEFAULT_REBUFFER_THRESHOLD_SECS = 1.0; +const DEFAULT_RESUME_THRESHOLD_SECS = 3.0; +const MAX_ADAPTIVE_RESUME_SECS = 18.0; interface GenerateOptions { text: string; @@ -104,6 +105,10 @@ export function useStreamingGeneration({ const isAutoBufferingRef = useRef(false); const isUserPausedRef = useRef(false); const audioUrlRef = useRef(null); + const 
firstChunkSeenRef = useRef(false); + const underrunCountRef = useRef(0); + const totalAudioSamplesRef = useRef(0); + const adaptiveResumeSecsRef = useRef(DEFAULT_RESUME_THRESHOLD_SECS); const revokeCurrentUrl = useCallback(() => { if (audioUrlRef.current) { @@ -122,8 +127,12 @@ export function useStreamingGeneration({ hasStartedPlaybackRef.current = false; isAutoBufferingRef.current = false; isUserPausedRef.current = false; + firstChunkSeenRef.current = false; + underrunCountRef.current = 0; + totalAudioSamplesRef.current = 0; + adaptiveResumeSecsRef.current = resumeThresholdSecs; setIsStreamPaused(false); - }, []); + }, [resumeThresholdSecs]); useEffect(() => { return () => { @@ -158,10 +167,17 @@ export function useStreamingGeneration({ if (!ctx) return; chunksRef.current.push(chunk); + totalAudioSamplesRef.current += chunk.length; + + if (!firstChunkSeenRef.current) { + firstChunkSeenRef.current = true; + onLog("First audio chunk received"); + } if (!hasStartedPlaybackRef.current) { const bufferedSecs = chunksRef.current.reduce((sum, c) => sum + c.length, 0) / SAMPLE_RATE; if (bufferedSecs >= prebufferSecs) { + onLog(`Playback started after ${bufferedSecs.toFixed(1)}s buffered`); flushBufferedAudio(); } return; @@ -171,18 +187,30 @@ export function useStreamingGeneration({ if (isUserPausedRef.current) return; const ahead = nextStartTimeRef.current - ctx.currentTime; - if (ctx.state === "running" && ahead < rebufferThresholdSecs) { - ctx.suspend().catch(() => {}); - isAutoBufferingRef.current = true; - } else if ( - ctx.state === "suspended" && - isAutoBufferingRef.current && - ahead >= resumeThresholdSecs + if ( + ctx.state === "running" && + !isAutoBufferingRef.current && + ahead < rebufferThresholdSecs + ) { + isAutoBufferingRef.current = true; + underrunCountRef.current += 1; + adaptiveResumeSecsRef.current = Math.min( + MAX_ADAPTIVE_RESUME_SECS, + Math.max(resumeThresholdSecs, prebufferSecs + underrunCountRef.current * 2), + ); + ctx.suspend().catch(() => 
{}); + onLog( + `Buffer underrun ${underrunCountRef.current}; refilling to ${adaptiveResumeSecsRef.current.toFixed(1)}s`, + ); + } else if ( + isAutoBufferingRef.current && + ahead >= adaptiveResumeSecsRef.current ) { - ctx.resume().catch(() => {}); isAutoBufferingRef.current = false; + ctx.resume().catch(() => {}); + onLog(`Buffer recovered with ${ahead.toFixed(1)}s queued`); } - }, [enqueue, flushBufferedAudio, prebufferSecs, rebufferThresholdSecs, resumeThresholdSecs]); + }, [enqueue, flushBufferedAudio, onLog, prebufferSecs, rebufferThresholdSecs, resumeThresholdSecs]); const generate = useCallback(async (options: GenerateOptions) => { if (!options.text.trim()) return; @@ -239,6 +267,11 @@ export function useStreamingGeneration({ type: "audio_chunk" | "complete" | "error" | "cancelled"; data?: string; elapsed?: number; + audio_secs?: number; + realtime_factor?: number | null; + chunks?: number; + first_chunk_secs?: number | null; + max_chunk_gap_secs?: number; message?: string; }; @@ -247,12 +280,26 @@ export function useStreamingGeneration({ } else if (event.type === "complete") { if (!hasStartedPlaybackRef.current) { flushBufferedAudio(); + } else if (isAutoBufferingRef.current) { + isAutoBufferingRef.current = false; + audioCtxRef.current?.resume().catch(() => {}); } const wavBlob = buildWav(mergeFloat32Arrays(chunksRef.current), SAMPLE_RATE); const audioUrl = URL.createObjectURL(wavBlob); audioUrlRef.current = audioUrl; const kb = (wavBlob.size / 1024).toFixed(0); - onLog(`Done in ${event.elapsed}s - ${kb} KB`); + const audioSecs = event.audio_secs ?? totalAudioSamplesRef.current / SAMPLE_RATE; + const realtimeFactor = + event.realtime_factor ?? + (event.elapsed && event.elapsed > 0 ? audioSecs / event.elapsed : null); + const speedText = + realtimeFactor === null ? 
"" : ` - ${realtimeFactor.toFixed(2)}x realtime`; + onLog(`Done in ${event.elapsed}s - ${audioSecs.toFixed(1)}s audio${speedText} - ${kb} KB`); + if (event.chunks && event.first_chunk_secs !== undefined) { + onLog( + `Stream: first chunk ${event.first_chunk_secs}s, ${event.chunks} chunks, max gap ${event.max_chunk_gap_secs}s`, + ); + } onSuccess(audioUrl); } else if (event.type === "cancelled") { throw new DOMException("Generation cancelled", "AbortError"); diff --git a/web/package.json b/web/package.json index 3f88249..f5c4ca1 100644 --- a/web/package.json +++ b/web/package.json @@ -4,7 +4,7 @@ "private": true, "scripts": { "dev": "next dev --turbopack", - "build": "next build --turbopack", + "build": "next build", "start": "next start" }, "dependencies": { From 7591d15a5221b417a57a33e9bf7956e72b0628ac Mon Sep 17 00:00:00 2001 From: LyAhn Date: Thu, 30 Apr 2026 20:46:29 +0100 Subject: [PATCH 02/11] perf: CPU async pipeline overlap + INT8 quantization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Overlap acoustic_decode with forward_tts_lm calls using a background ThreadPoolExecutor, hiding ~72s of decode cost behind tts_lm work. Achieved 0.67x realtime (up from 0.43x, ~56% improvement). - vibevoice_generate_patch.py: patched generate() loop reordered to submit decode to thread before running connector + tts_lm×2, then resolve future. Installed as instance method via types.MethodType so uv sync reinstalling the package cannot revert the patch. - Dynamic INT8 quantization of Linear layers (VIBEPOD_QUANTIZE=1, default on CPU). prediction_head excluded — small fixed-size tensors regressed ~20% with INT8 due to pack/unpack overhead. - Auto-detect AVX512_BF16 and load model in bfloat16 if supported (VIBEPOD_CPU_BF16=auto, overridable with 0/1). - CPU thread count auto-configured from logical CPU count; OMP/MKL env vars set accordingly. 
Lock file preserved around uv sync --no-sources so CPU mode does not alter the shared uv.lock. - torch.compile retained as opt-in (VIBEPOD_COMPILE=1) but marked not recommended — dynamic KV cache shapes prevent kernel reuse. --- server/start.sh | 28 +- server/vibevoice_generate_patch.py | 463 +++++++++++++++++++++++++++++ server/vibevoice_server.py | 196 +++++++++++- 3 files changed, 685 insertions(+), 2 deletions(-) create mode 100644 server/vibevoice_generate_patch.py diff --git a/server/start.sh b/server/start.sh index 2daa340..995fbc1 100755 --- a/server/start.sh +++ b/server/start.sh @@ -79,7 +79,16 @@ echo "" if $CPU_MODE; then echo "--> Syncing CPU Python environment (.venv-cpu)..." export UV_PROJECT_ENVIRONMENT=".venv-cpu" + LOCK_BACKUP="" + if [[ -f uv.lock ]]; then + LOCK_BACKUP="$(mktemp)" + cp uv.lock "$LOCK_BACKUP" + fi uv sync --no-sources + if [[ -n "$LOCK_BACKUP" ]]; then + cp "$LOCK_BACKUP" uv.lock + rm -f "$LOCK_BACKUP" + fi else echo "--> Syncing CUDA Python environment (.venv)..." uv sync @@ -126,11 +135,28 @@ export PYTHONUTF8=1 if $CPU_MODE; then export VIBEPOD_DEVICE="cpu" export UV_PROJECT_ENVIRONMENT=".venv-cpu" + if [[ -z "${VIBEPOD_CPU_THREADS:-}" ]]; then + VIBEPOD_CPU_THREADS="$(uv run --no-sources python -c "import os; print(max(1, (os.cpu_count() or 2) // 2))")" + export VIBEPOD_CPU_THREADS + fi + export OMP_NUM_THREADS="${OMP_NUM_THREADS:-$VIBEPOD_CPU_THREADS}" + export MKL_NUM_THREADS="${MKL_NUM_THREADS:-$VIBEPOD_CPU_THREADS}" + # Dynamic INT8 quantization — on by default for CPU (~22% faster, prediction_head + # excluded automatically to avoid regression on small fixed-size tensors). + # Set VIBEPOD_QUANTIZE=0 to disable if you notice audio quality differences. 
+ export VIBEPOD_QUANTIZE="${VIBEPOD_QUANTIZE:-1}" + # Optional CPU flags: + # VIBEPOD_ASYNC_DECODE=0 Disable async decode/tts_lm overlap (on by default) + # VIBEPOD_CPU_BF16=1 Force bfloat16 weights (auto-detected via AVX512_BF16) + # VIBEPOD_COMPILE=1 torch.compile hot paths (ineffective for autoregressive + # models on CPU — not recommended, kept for experimentation) + UV_RUN_ARGS=(--no-sync --no-sources) else export VIBEPOD_DEVICE="cuda" + UV_RUN_ARGS=() fi -exec uv run uvicorn vibevoice_server:app \ +exec uv run "${UV_RUN_ARGS[@]}" uvicorn vibevoice_server:app \ --host 127.0.0.1 \ --port 8000 \ --log-level info \ diff --git a/server/vibevoice_generate_patch.py b/server/vibevoice_generate_patch.py new file mode 100644 index 0000000..825577d --- /dev/null +++ b/server/vibevoice_generate_patch.py @@ -0,0 +1,463 @@ +""" +VibePod CPU pipeline optimisation — patched VibeVoice generate() loop. + +WHY THIS FILE EXISTS +-------------------- +The VibeVoice inner speech-generation loop runs: + + decode(speech_latent) # 87 ms — VAE decode to audio waveform + audio_chunks.append(chunk) # store for final return value + audio_streamer.put(chunk) # stream to client + acoustic_connector(speech_latent) -> acoustic_embed # 1 ms + forward_tts_lm(acoustic_embed) # ~49 ms (positive) + forward_tts_lm(acoustic_embed) # ~49 ms (negative CFG) + +acoustic_connector and both forward_tts_lm calls depend only on speech_latent / +acoustic_embed — they are completely independent of the decoded audio waveform. +Running decode in a thread while connector + tts_lm run on the main thread hides +~87 ms of decode cost per token behind the ~99 ms of tts_lm work: + + Before: 87 + 1 + 49 + 49 = 186 ms / token + After: max(87, 1 + 49 + 49) = 99 ms / token (~47 % reduction) + +HOW IT WORKS +------------ +At model load time, _install_cpu_pipeline_optimizations() in vibevoice_server.py: + 1. Creates a single-worker ThreadPoolExecutor and attaches it to the model as + model._vibepod_decode_executor. + 2. 
def _run_with_autograd_state(
    grad_enabled: bool,
    inference_mode: bool,
    fn: Callable[..., torch.Tensor],
    *args,
    **kwargs,
) -> torch.Tensor:
    """[VibePod] Call ``fn(*args, **kwargs)`` under explicit autograd modes.

    ``torch.no_grad()`` / ``torch.inference_mode()`` are thread-local, so work
    handed to the decode executor does not inherit them from the thread that
    called generate(). This wrapper re-applies the modes captured on the
    calling thread so the background decode behaves like the sequential path.

    Args:
        grad_enabled: Value of ``torch.is_grad_enabled()`` on the calling thread.
        inference_mode: Value of ``torch.is_inference_mode_enabled()`` on the
            calling thread.
        fn: Callable to run (here: the acoustic tokenizer's ``decode``).
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        Whatever ``fn`` returns.
    """
    with torch.inference_mode(inference_mode), torch.set_grad_enabled(grad_enabled):
        return fn(*args, **kwargs)


def patched_generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None,
    negative_prompt_ids: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    speech_tensors: Optional[torch.FloatTensor] = None,
    speech_masks: Optional[torch.BoolTensor] = None,
    speech_input_mask: Optional[torch.BoolTensor] = None,
    tts_text_ids: Optional[torch.LongTensor] = None,
    return_speech: bool = True,
    cfg_scale: float = 1.0,
    stop_check_fn: Optional[Callable[[], bool]] = None,
    **kwargs,
) -> Union[torch.LongTensor, VibeVoiceGenerationOutput]:
    """Streaming speech generation with optional background acoustic decode.

    Copy of VibeVoice's generate() with the inner speech loop reordered: when
    the model instance carries ``_vibepod_decode_executor``, the acoustic
    decode of each speech latent is submitted to that executor while the
    acoustic connector and both ``forward_tts_lm`` calls run on the calling
    thread; the decoded chunk is appended/streamed once the future resolves.
    Without the executor the decode runs inline.

    Required entries in ``kwargs`` (popped/read below): ``tokenizer``,
    ``tts_lm_input_ids``, ``tts_lm_attention_mask``, ``all_prefilled_outputs``
    (dict with keys ``"lm"``, ``"tts_lm"``, ``"neg_lm"``, ``"neg_tts_lm"``) and
    ``input_ids``. Optional: ``max_new_tokens``, ``verbose``,
    ``show_progress_bar``.

    Args:
        audio_streamer: If given, receives every decoded chunk via ``put`` and
            is closed via ``end``.
        tts_text_ids: Text token ids consumed in windows of
            ``TTS_TEXT_WINDOW_SIZE``.
        return_speech: When False, ``speech_outputs`` in the result is None.
        cfg_scale: Guidance scale forwarded to ``sample_speech_tokens``.
        stop_check_fn: Polled at the top of every outer iteration; a truthy
            return stops generation.

    Returns:
        ``VibeVoiceGenerationOutput`` with the tts-lm token sequence, the
        per-sample concatenated audio (or None) and the max-length flags.

    Raises:
        AssertionError: If the batch size is not 1.
    """
    # ── Setup (unchanged from original) ─────────────────────────────────────
    tokenizer = kwargs.pop("tokenizer", None)
    neg_text_input_id = tokenizer.convert_tokens_to_ids("<|image_pad|>")

    tts_lm_input_ids = kwargs.pop("tts_lm_input_ids", None)
    tts_lm_attention_mask = kwargs.pop("tts_lm_attention_mask", None)
    all_prefilled_outputs = kwargs.pop("all_prefilled_outputs", None)
    tts_text_ids = tts_text_ids.to(self.device)

    if kwargs.get("max_new_tokens", None) is None:
        kwargs["max_new_tokens"] = (
            self.config.decoder_config.max_position_embeddings - tts_lm_input_ids.shape[-1]
        )

    generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria = (
        self._build_generate_config_model_kwargs(
            generation_config, inputs, tokenizer, return_processors=True, **kwargs
        )
    )

    negative_kwargs = {
        "input_ids": torch.full(
            (kwargs["input_ids"].shape[0], 1),
            neg_text_input_id,
            dtype=torch.long,
            device=kwargs["input_ids"].device,
        ),
        "attention_mask": torch.ones(
            (kwargs["input_ids"].shape[0], 1),
            dtype=torch.long,
            device=kwargs["input_ids"].device,
        ),
        "max_new_tokens": kwargs.get("max_new_tokens", 100),
    }
    negative_generation_config, negative_model_kwargs, negative_input_ids = (
        self._build_generate_config_model_kwargs(
            None, None, tokenizer, return_processors=False, **negative_kwargs
        )
    )

    tts_lm_kwargs = {
        "input_ids": tts_lm_input_ids,
        "attention_mask": tts_lm_attention_mask,
        "max_new_tokens": kwargs.get("max_new_tokens", 100),
    }
    tts_lm_generation_config, tts_lm_model_kwargs, tts_lm_input_ids = (
        self._build_generate_config_model_kwargs(
            None, None, tokenizer, return_processors=False, **tts_lm_kwargs
        )
    )

    tts_lm_negative_kwargs = {
        "input_ids": torch.full(
            (kwargs["input_ids"].shape[0], 1),
            neg_text_input_id,
            dtype=torch.long,
            device=kwargs["input_ids"].device,
        ),
        "attention_mask": torch.ones(
            (kwargs["input_ids"].shape[0], 1),
            dtype=torch.long,
            device=kwargs["input_ids"].device,
        ),
        "max_new_tokens": kwargs.get("max_new_tokens", 100),
    }
    tts_lm_negative_generation_config, tts_lm_negative_model_kwargs, tts_lm_negative_input_ids = (
        self._build_generate_config_model_kwargs(
            None, None, tokenizer, return_processors=False, **tts_lm_negative_kwargs
        )
    )

    acoustic_cache = VibeVoiceTokenizerStreamingCache()
    batch_size = input_ids.shape[0]
    assert batch_size == 1, "Currently only supports batch size == 1"
    device = input_ids.device
    finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device)
    verbose = kwargs.get("verbose", False)

    audio_chunks = [[] for _ in range(batch_size)]
    tts_text_window_index = 0
    reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device)
    first_text_window_size = (
        TTS_TEXT_WINDOW_SIZE
        if tts_text_ids.shape[1] >= TTS_TEXT_WINDOW_SIZE
        else tts_text_ids.shape[1]
    )

    outputs = all_prefilled_outputs["lm"]
    tts_lm_outputs = all_prefilled_outputs["tts_lm"]
    negative_outputs = all_prefilled_outputs["neg_lm"]
    tts_lm_negative_outputs = all_prefilled_outputs["neg_tts_lm"]

    model_kwargs = _update_model_kwargs_for_generation(
        outputs, model_kwargs, num_new_tokens=first_text_window_size
    )
    tts_lm_model_kwargs = _update_model_kwargs_for_generation(
        tts_lm_outputs, tts_lm_model_kwargs, num_new_tokens=first_text_window_size
    )
    negative_model_kwargs = self._update_model_kwargs_for_generation(
        negative_outputs, negative_model_kwargs, is_encoder_decoder=False
    )
    tts_lm_negative_model_kwargs = self._update_model_kwargs_for_generation(
        tts_lm_negative_outputs, tts_lm_negative_model_kwargs, is_encoder_decoder=False
    )

    step = tts_lm_input_ids.shape[1]
    total_generated_speech_tokens = 0
    total_prefilled_text_tokens = 0
    if kwargs.get("show_progress_bar", True):
        progress_bar = tqdm(
            total=tts_lm_generation_config.max_length,
            desc=f"Prefilled {step} tokens, current step ({step} / {tts_lm_generation_config.max_length})",
            initial=step,
            leave=False,
        )
    else:
        progress_bar = None

    # [VibePod] Grab the executor once; None means standard sequential path.
    _vp_executor: Optional[concurrent.futures.ThreadPoolExecutor] = getattr(
        self, "_vibepod_decode_executor", None
    )
    # [VibePod] Autograd modes are thread-local. Capture the caller's modes so
    # the background decode does not silently run with grad enabled (which
    # would record graphs that the cache and audio_chunks then keep alive).
    _vp_grad_enabled = torch.is_grad_enabled()
    _vp_inference_mode = torch.is_inference_mode_enabled()

    # ── Main generation loop (unchanged from original) ───────────────────────
    while True:
        if stop_check_fn is not None and stop_check_fn():
            if verbose:
                print(f"Generation stopped externally at step {step + 1}")
            if audio_streamer is not None:
                audio_streamer.end()
            break

        if finished_tags.all():
            if hasattr(progress_bar, "set_description"):
                progress_bar.set_description("Generation complete")
            break

        cur_input_tts_text_ids = tts_text_ids[
            :,
            tts_text_window_index * TTS_TEXT_WINDOW_SIZE : (tts_text_window_index + 1)
            * TTS_TEXT_WINDOW_SIZE,
        ]
        next_text_window_size = tts_text_ids[
            :,
            (tts_text_window_index + 1)
            * TTS_TEXT_WINDOW_SIZE : (tts_text_window_index + 2)
            * TTS_TEXT_WINDOW_SIZE,
        ].shape[1]
        tts_text_window_index += 1

        if cur_input_tts_text_ids.shape[1] > 0:
            input_ids = torch.cat([input_ids, cur_input_tts_text_ids], dim=-1)
            tts_lm_input_ids = torch.cat([tts_lm_input_ids, cur_input_tts_text_ids], dim=-1)

            if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length:
                if verbose:
                    print(
                        f"Reached maximum generation length {generation_config.max_length}, stopped it."
                    )
                reached_samples = torch.arange(batch_size, device=device)[~finished_tags]
                if reached_samples.numel() > 0:
                    reach_max_step_sample[reached_samples] = True
                break

            step += cur_input_tts_text_ids.shape[1]
            total_prefilled_text_tokens += cur_input_tts_text_ids.shape[1]
            if progress_bar is not None:
                progress_bar.update(cur_input_tts_text_ids.shape[1])
                progress_bar.set_description(
                    f"Prefilled {total_prefilled_text_tokens} text tokens, "
                    f"generated {total_generated_speech_tokens} speech tokens, "
                    f"current step ({step} / {tts_lm_generation_config.max_length})"
                )

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self.forward_lm(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )
            model_kwargs = _update_model_kwargs_for_generation(
                outputs, model_kwargs, num_new_tokens=next_text_window_size
            )

            tts_lm_model_inputs = self.prepare_inputs_for_generation(
                tts_lm_input_ids, **tts_lm_model_kwargs
            )
            tts_lm_additional_inputs = {
                "tts_text_masks": torch.ones_like(tts_lm_input_ids[:, -1:]),
                "lm_last_hidden_state": outputs.last_hidden_state,
            }
            tts_lm_outputs = self.forward_tts_lm(
                **tts_lm_model_inputs,
                **tts_lm_additional_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )
            tts_lm_model_kwargs = self._update_model_kwargs_for_generation(
                tts_lm_outputs, tts_lm_model_kwargs, is_encoder_decoder=False
            )

        diffusion_indices = torch.LongTensor([0])

        # ── Inner speech loop ────────────────────────────────────────────────
        for cur_speech_index in range(TTS_SPEECH_WINDOW_SIZE):
            positive_condition = tts_lm_outputs.last_hidden_state[diffusion_indices, -1, :]
            negative_condition = tts_lm_negative_outputs.last_hidden_state[
                diffusion_indices, -1, :
            ]

            speech_latent = self.sample_speech_tokens(
                positive_condition,
                negative_condition,
                cfg_scale=cfg_scale,
            ).unsqueeze(1)

            scaled_latent = (
                speech_latent / self.model.speech_scaling_factor.to(speech_latent.device)
                - self.model.speech_bias_factor.to(speech_latent.device)
            )

            # [VibePod] If a decode executor is configured, submit decode to a
            # background thread so acoustic_connector and forward_tts_lm can run
            # concurrently on the main thread. The future is resolved after both
            # tts_lm calls complete, before appending/streaming the audio chunk.
            # Without the executor, the original sequential path is used unchanged.
            if _vp_executor is not None:
                _decode_future: concurrent.futures.Future[torch.Tensor] = _vp_executor.submit(
                    _run_with_autograd_state,
                    _vp_grad_enabled,
                    _vp_inference_mode,
                    self.model.acoustic_tokenizer.decode,
                    scaled_latent.to(self.model.acoustic_tokenizer.device),
                    cache=acoustic_cache,
                    sample_indices=diffusion_indices.to(self.model.acoustic_tokenizer.device),
                    use_cache=True,
                    debug=False,
                )
            else:
                audio_chunk = self.model.acoustic_tokenizer.decode(
                    scaled_latent.to(self.model.acoustic_tokenizer.device),
                    cache=acoustic_cache,
                    sample_indices=diffusion_indices.to(self.model.acoustic_tokenizer.device),
                    use_cache=True,
                    debug=False,
                )

            # [VibePod] connector + tts_lm run here while decode is in the thread.
            acoustic_embed = self.model.acoustic_connector(speech_latent)
            tts_lm_input_ids = torch.cat(
                [tts_lm_input_ids, torch.ones_like(tts_lm_input_ids[:, -1:])], dim=-1
            )

            if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length:
                # [VibePod] Resolve before break so audio_chunks stays consistent.
                if _vp_executor is not None:
                    audio_chunk = _decode_future.result()
                for i, sample_idx in enumerate(diffusion_indices):
                    idx = sample_idx.item()
                    if not finished_tags[idx]:
                        audio_chunks[idx].append(audio_chunk[i])
                if audio_streamer is not None:
                    audio_streamer.put(audio_chunk, diffusion_indices)
                break

            step += 1
            total_generated_speech_tokens += 1
            if progress_bar is not None:
                progress_bar.update(1)
                progress_bar.set_description(
                    f"Prefilled {total_prefilled_text_tokens} text tokens, "
                    f"generated {total_generated_speech_tokens} speech tokens, "
                    f"current step ({step} / {tts_lm_generation_config.max_length})"
                )

            tts_lm_model_inputs = self.prepare_inputs_for_generation(
                tts_lm_input_ids, **tts_lm_model_kwargs
            )
            tts_lm_additional_inputs = {
                "tts_text_masks": torch.zeros_like(tts_lm_input_ids[:, -1:]),
                "lm_last_hidden_state": acoustic_embed,
            }
            tts_lm_outputs = self.forward_tts_lm(
                **tts_lm_model_inputs,
                **tts_lm_additional_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )
            if cur_speech_index == TTS_SPEECH_WINDOW_SIZE - 1 and next_text_window_size > 0:
                tts_lm_model_kwargs = _update_model_kwargs_for_generation(
                    tts_lm_outputs, tts_lm_model_kwargs, num_new_tokens=next_text_window_size
                )
            else:
                tts_lm_model_kwargs = self._update_model_kwargs_for_generation(
                    tts_lm_outputs, tts_lm_model_kwargs, is_encoder_decoder=False
                )

            tts_lm_negative_input_ids = torch.cat(
                [tts_lm_negative_input_ids, torch.ones_like(tts_lm_input_ids[:, -1:])], dim=-1
            )
            tts_lm_negative_model_inputs = self.prepare_inputs_for_generation(
                tts_lm_negative_input_ids, **tts_lm_negative_model_kwargs
            )
            tts_lm_negative_additional_inputs = {
                "tts_text_masks": torch.zeros_like(tts_lm_negative_input_ids[:, -1:]),
                "lm_last_hidden_state": acoustic_embed,
            }
            tts_lm_negative_outputs = self.forward_tts_lm(
                **tts_lm_negative_model_inputs,
                **tts_lm_negative_additional_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )
            tts_lm_negative_model_kwargs = self._update_model_kwargs_for_generation(
                tts_lm_negative_outputs,
                tts_lm_negative_model_kwargs,
                is_encoder_decoder=False,
            )

            # [VibePod] Decode is done (or was never async). Resolve future,
            # then append + stream — moved here from before connector/tts_lm.
            if _vp_executor is not None:
                audio_chunk = _decode_future.result()
            for i, sample_idx in enumerate(diffusion_indices):
                idx = sample_idx.item()
                if not finished_tags[idx]:
                    audio_chunks[idx].append(audio_chunk[i])
            if audio_streamer is not None:
                audio_streamer.put(audio_chunk, diffusion_indices)

            tts_eos_logits = torch.sigmoid(
                self.tts_eos_classifier(
                    tts_lm_outputs.last_hidden_state[diffusion_indices, -1, :]
                )
            )
            if tts_eos_logits[0].item() > 0.5:
                finished_tags[diffusion_indices] = True
                if audio_streamer is not None:
                    audio_streamer.end(diffusion_indices)

        if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length:
            if verbose:
                print(
                    f"Reached maximum generation length {tts_lm_generation_config.max_length}, stopped it."
                )
            reached_samples = torch.arange(batch_size, device=device)[~finished_tags]
            if reached_samples.numel() > 0:
                reach_max_step_sample[reached_samples] = True
            break

    if audio_streamer is not None:
        audio_streamer.end()

    # ── Audio finalisation (unchanged from original) ─────────────────────────
    final_audio_outputs = []
    for sample_chunks in audio_chunks:
        if sample_chunks:
            concatenated_audio = torch.cat(sample_chunks, dim=-1)
            final_audio_outputs.append(concatenated_audio)
        else:
            final_audio_outputs.append(None)

    if reach_max_step_sample is not None and reach_max_step_sample.any():
        print(
            f"Reached maximum generation length {tts_lm_generation_config.max_length}, stopped it."
        )

    return VibeVoiceGenerationOutput(
        sequences=tts_lm_input_ids,
        speech_outputs=final_audio_outputs if return_speech else None,
        reach_max_step_sample=reach_max_step_sample,
    )


def install(model: object, executor: concurrent.futures.ThreadPoolExecutor) -> None:
    """Install the patched generate() on a model instance and attach the executor.

    Args:
        model: Model instance to patch; ``generate`` is rebound on the instance
            (not the class) via ``types.MethodType``.
        executor: Executor stored as ``model._vibepod_decode_executor`` and used
            by ``patched_generate`` for the background acoustic decode.
    """
    model._vibepod_decode_executor = executor
    model.generate = types.MethodType(patched_generate, model)
def _cpu_supports_bf16() -> bool:
    """Return True if the CPU has AVX512_BF16 hardware support.

    Looks for the capability probe on ``torch.cpu`` under both the public
    spelling and the underscore-prefixed one, and returns False when neither
    is available.
    """
    cpu_module = getattr(torch, "cpu", None)
    # NOTE(review): PyTorch appears to ship this probe only under the private
    # name ``_is_avx512_bf16_supported``; with just the public spelling the
    # auto-detection would always report False — confirm against the pinned
    # torch version.
    for probe_name in ("is_avx512_bf16_supported", "_is_avx512_bf16_supported"):
        probe = getattr(cpu_module, probe_name, None)
        if callable(probe):
            return bool(probe())
    return False


def _configure_cpu_runtime() -> dict[str, object]:
    """Apply CPU threading / oneDNN settings and report the effective values.

    Environment variables read:
        VIBEPOD_CPU_THREADS: intra-op threads (default: half the logical CPUs
            on Windows, all logical CPUs elsewhere).
        VIBEPOD_CPU_INTEROP_THREADS: inter-op threads (default 1).
        VIBEPOD_CPU_MKLDNN: any value other than "0" keeps oneDNN enabled.

    Returns:
        Dict with ``logical_cpus``, ``threads``, ``interop_threads``,
        ``mkldnn_available`` and ``mkldnn_enabled`` as read back from torch.
    """
    logical_cpus = os.cpu_count() or 1
    default_threads = (
        max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus
    )
    intra_threads = _env_int("VIBEPOD_CPU_THREADS", default_threads)
    interop_threads = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1)
    mkldnn_enabled = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"

    torch.set_num_threads(max(1, intra_threads))
    try:
        torch.set_num_interop_threads(max(1, interop_threads))
    except RuntimeError as exc:
        # Best effort: keep whatever inter-op setting torch already has.
        logger.warning("Could not set CPU inter-op threads: %s", exc)

    torch.backends.mkldnn.enabled = mkldnn_enabled
    return {
        "logical_cpus": logical_cpus,
        "threads": torch.get_num_threads(),
        "interop_threads": torch.get_num_interop_threads(),
        "mkldnn_available": torch.backends.mkldnn.is_available(),
        "mkldnn_enabled": torch.backends.mkldnn.enabled,
    }


def _apply_cpu_optimizations(model: object) -> object:
    """Apply optional post-load CPU optimizations. Returns (possibly new) model object.

    Environment variables read:
        VIBEPOD_QUANTIZE: "1" applies dynamic INT8 quantization to Linear layers.
        VIBEPOD_COMPILE: "1" wraps selected hot paths with ``torch.compile``.
        VIBEPOD_COMPILE_MODE: compile mode (default "reduce-overhead").

    Both steps are best-effort: a failure is logged and the step is skipped.
    """

    do_quantize = os.environ.get("VIBEPOD_QUANTIZE", "0") == "1"
    do_compile = os.environ.get("VIBEPOD_COMPILE", "0") == "1"

    if do_quantize:
        logger.info("Applying dynamic INT8 quantization to Linear layers...")
        try:
            # Import the function, not ``import torch.ao.quantization``: the
            # latter binds ``torch`` as a function-local name, which leaves
            # ``torch`` unbound in the compile branch below whenever this
            # branch is not executed.
            from torch.ao.quantization import quantize_dynamic

            # The diffusion prediction_head operates on small fixed-size tensors where
            # INT8 pack/unpack overhead exceeds the matmul savings (~+20% regression in
            # testing). Save and restore it so it stays in float32.
            saved_prediction_head = None
            if hasattr(model, "model") and hasattr(model.model, "prediction_head"):
                saved_prediction_head = model.model.prediction_head
                del model.model.prediction_head

            try:
                model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
            finally:
                # Re-attach on success (``model`` is the quantized result) and on
                # failure (``model`` is still the original), so a skipped
                # quantization never returns a model without its head.
                if saved_prediction_head is not None:
                    model.model.prediction_head = saved_prediction_head

            if saved_prediction_head is not None:
                logger.info(
                    "Dynamic INT8 quantization applied (prediction_head excluded — stays float32)."
                )
            else:
                logger.info("Dynamic INT8 quantization applied.")
        except Exception as exc:
            logger.warning("Dynamic quantization failed: %s — skipping", exc)

    if do_compile:
        # torch.compile with inductor on CPU is ineffective for autoregressive TTS:
        # each token step produces a unique input shape, so every step triggers a new
        # kernel compile event rather than reusing compiled code. Kept as an escape
        # hatch but not recommended.
        compile_mode = os.environ.get("VIBEPOD_COMPILE_MODE", "reduce-overhead")
        logger.info(
            "torch.compile enabled (mode=%s) — NOTE: limited benefit for autoregressive"
            " models on CPU due to dynamic sequence lengths.",
            compile_mode,
        )
        # (label for logging, owner object, attribute name, dynamic-shapes flag)
        _compile_targets: list[tuple[str, object, str, bool]] = [
            ("forward_tts_lm", model, "forward_tts_lm", True),
        ]
        if hasattr(model, "model"):
            inner = model.model
            if hasattr(inner, "prediction_head"):
                _compile_targets.append(
                    ("prediction_head", inner, "prediction_head", False)
                )
            if hasattr(inner, "acoustic_tokenizer") and hasattr(
                inner.acoustic_tokenizer, "decode"
            ):
                _compile_targets.append(
                    ("acoustic_tokenizer.decode", inner.acoustic_tokenizer, "decode", False)
                )

        for label, obj, attr, dynamic in _compile_targets:
            try:
                compiled = torch.compile(
                    getattr(obj, attr),
                    backend="inductor",
                    mode=compile_mode,
                    dynamic=dynamic,
                )
                setattr(obj, attr, compiled)
                logger.info("  compiled: %s", label)
            except Exception as exc:
                logger.warning("  torch.compile failed for %s: %s — skipping", label, exc)

    return model


def _install_cpu_pipeline_optimizations(model: object) -> None:
    """Install the async-decode generate() patch and its thread pool on the model instance.

    The VibeVoice inner loop runs:
        decode(speech_latent) → append → put → connector → tts_lm(pos) → tts_lm(neg)

    connector and both tts_lm calls only need speech_latent/acoustic_embed, not
    audio_chunk. The patched generate() reorders this to:
        submit decode to thread → connector → tts_lm(pos) → tts_lm(neg)
        → wait for decode future → append → put

    The patch is applied as an instance method via types.MethodType, which shadows
    the class-level generate() and is immune to uv sync reinstalling the package.

    Does nothing (beyond logging) when VIBEPOD_ASYNC_DECODE is set to anything
    other than "1" or when vibevoice_generate_patch cannot be imported.
    """
    global _decode_executor

    if os.environ.get("VIBEPOD_ASYNC_DECODE", "1") != "1":
        logger.info("CPU async decode disabled via VIBEPOD_ASYNC_DECODE=0.")
        return

    try:
        import vibevoice_generate_patch
    except ImportError:
        logger.warning(
            "vibevoice_generate_patch not found — async decode unavailable. "
            "Ensure vibevoice_generate_patch.py is in the server directory."
        )
        return

    # Single worker: decode calls share one streaming cache, so they must run
    # one at a time and in submission order.
    _decode_executor = concurrent.futures.ThreadPoolExecutor(
        max_workers=1, thread_name_prefix="vibepod-decode"
    )
    vibevoice_generate_patch.install(model, _decode_executor)
    logger.info(
        "CPU pipeline: patched generate() installed (async decode enabled) — "
        "acoustic_decode overlaps forward_tts_lm. Disable with VIBEPOD_ASYNC_DECODE=0."
    )
server/vibevoice_server.py: - Add _cfg_executor (ThreadPoolExecutor, 1 worker) alongside _decode_executor - _install_cpu_pipeline_optimizations now sets both executors directly as model._vibepod_decode_executor and model._vibepod_cfg_executor - Both executors shut down in lifespan on exit - Remove vibevoice_generate_patch import/install (no longer needed) server/pyproject.toml: - vibevoice source changed to git+https://github.com/JezzWTF/VibeVoice.git - No machine-local paths; works identically on any clone --- server/pyproject.toml | 3 +- server/uv.lock | 16 +- server/vibevoice_generate_patch.py | 463 ----------------------------- server/vibevoice_server.py | 50 ++-- 4 files changed, 36 insertions(+), 496 deletions(-) delete mode 100644 server/vibevoice_generate_patch.py diff --git a/server/pyproject.toml b/server/pyproject.toml index 3756ed3..5099808 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -8,7 +8,8 @@ dependencies = [ # To switch back to CPU-only, remove the [tool.uv.sources] torch entry below. "torch>=2.0.0", # VibeVoice custom model + processor classes (not yet in upstream transformers) - "vibevoice @ git+https://github.com/microsoft/VibeVoice.git", + # Uses JezzWTF/VibeVoice fork so VibePod-specific optimisations land here. 
+ "vibevoice @ git+https://github.com/JezzWTF/VibeVoice.git", # Exact version required by vibevoice's streaming TTS module "transformers==4.51.3", "fastapi>=0.111.0", diff --git a/server/uv.lock b/server/uv.lock index 7fc34c0..187530e 100644 --- a/server/uv.lock +++ b/server/uv.lock @@ -1479,7 +1479,7 @@ name = "nvidia-cudnn-cu12" version = "9.1.0.70" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, + { name = "nvidia-cublas-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, @@ -1490,7 +1490,7 @@ name = "nvidia-cufft-cu12" version = "11.2.1.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 }, @@ -1509,9 +1509,9 @@ name = "nvidia-cusolver-cu12" version = "11.6.1.9" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, - { name = "nvidia-cusparse-cu12" }, - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-cublas-cu12", marker = 
"(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, + { name = "nvidia-cusparse-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057 }, @@ -1522,7 +1522,7 @@ name = "nvidia-cusparse-cu12" version = "12.3.1.170" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 }, @@ -3109,13 +3109,13 @@ requires-dist = [ { name = "torch", specifier = ">=2.0.0", index = "https://download.pytorch.org/whl/cu124" }, { name = "transformers", specifier = "==4.51.3" }, { name = "uvicorn", extras = ["standard"], specifier = ">=0.29.0" }, - { name = "vibevoice", git = 
"https://github.com/microsoft/VibeVoice.git" }, + { name = "vibevoice", git = "https://github.com/JezzWTF/VibeVoice.git" }, ] [[package]] name = "vibevoice" version = "1.0.0" -source = { git = "https://github.com/microsoft/VibeVoice.git#e73d1e17c3754f046352014856a922f8208fb5d3" } +source = { git = "https://github.com/JezzWTF/VibeVoice.git#e76701f17a0d93cd612d56f0db5865a615c4d16e" } dependencies = [ { name = "absl-py" }, { name = "accelerate" }, diff --git a/server/vibevoice_generate_patch.py b/server/vibevoice_generate_patch.py deleted file mode 100644 index 825577d..0000000 --- a/server/vibevoice_generate_patch.py +++ /dev/null @@ -1,463 +0,0 @@ -""" -VibePod CPU pipeline optimisation — patched VibeVoice generate() loop. - -WHY THIS FILE EXISTS --------------------- -The VibeVoice inner speech-generation loop runs: - - decode(speech_latent) # 87 ms — VAE decode to audio waveform - audio_chunks.append(chunk) # store for final return value - audio_streamer.put(chunk) # stream to client - acoustic_connector(speech_latent) -> acoustic_embed # 1 ms - forward_tts_lm(acoustic_embed) # ~49 ms (positive) - forward_tts_lm(acoustic_embed) # ~49 ms (negative CFG) - -acoustic_connector and both forward_tts_lm calls depend only on speech_latent / -acoustic_embed — they are completely independent of the decoded audio waveform. -Running decode in a thread while connector + tts_lm run on the main thread hides -~87 ms of decode cost per token behind the ~99 ms of tts_lm work: - - Before: 87 + 1 + 49 + 49 = 186 ms / token - After: max(87, 1 + 49 + 49) = 99 ms / token (~47 % reduction) - -HOW IT WORKS ------------- -At model load time, _install_cpu_pipeline_optimizations() in vibevoice_server.py: - 1. Creates a single-worker ThreadPoolExecutor and attaches it to the model as - model._vibepod_decode_executor. - 2. Installs this module's `patched_generate` as a bound method on the model - instance via types.MethodType, shadowing the class-level generate(). 
- -Because the patch lives on the *instance*, uv sync reinstalling the VibeVoice -package has no effect — Python resolves instance attributes before class ones. - -MAINTENANCE ------------ -This is a verbatim copy of VibeVoice's generate() method (lines 574–910 of -modeling_vibevoice_streaming_inference.py) with the inner speech loop reordered. -The only changed region is marked with # [VibePod] comments. - -If VibeVoice updates its generate() method, diff the new version against this -file and merge carefully. The sentinel string "[VibePod]" marks every changed -line to make diffing easy. -""" - -import concurrent.futures -import types -from typing import Callable, List, Optional, Union - -import torch -from tqdm import tqdm -from transformers.generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList -from transformers.modeling_utils import PreTrainedModel - -from vibevoice.modular.modeling_vibevoice_streaming_inference import ( - TTS_TEXT_WINDOW_SIZE, - TTS_SPEECH_WINDOW_SIZE, - VibeVoiceGenerationOutput, - _update_model_kwargs_for_generation, -) -from vibevoice.modular.modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache -from vibevoice.modular.streamer import AudioStreamer, AsyncAudioStreamer - - -def patched_generate( - self, - inputs: Optional[torch.Tensor] = None, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - synced_gpus: Optional[bool] = None, - assistant_model: Optional["PreTrainedModel"] = None, - audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None, - negative_prompt_ids: Optional[torch.Tensor] = None, - negative_prompt_attention_mask: Optional[torch.Tensor] = None, - speech_tensors: Optional[torch.FloatTensor] = None, - speech_masks: Optional[torch.BoolTensor] = None, - 
speech_input_mask: Optional[torch.BoolTensor] = None, - tts_text_ids: Optional[torch.LongTensor] = None, - return_speech: bool = True, - cfg_scale: float = 1.0, - stop_check_fn: Optional[Callable[[], bool]] = None, - **kwargs, -) -> Union[torch.LongTensor, VibeVoiceGenerationOutput]: - # ── Setup (unchanged from original) ───────────────────────────────────── - tokenizer = kwargs.pop("tokenizer", None) - neg_text_input_id = tokenizer.convert_tokens_to_ids("<|image_pad|>") - - tts_lm_input_ids = kwargs.pop("tts_lm_input_ids", None) - tts_lm_attention_mask = kwargs.pop("tts_lm_attention_mask", None) - all_prefilled_outputs = kwargs.pop("all_prefilled_outputs", None) - tts_text_ids = tts_text_ids.to(self.device) - - if kwargs.get("max_new_tokens", None) is None: - kwargs["max_new_tokens"] = ( - self.config.decoder_config.max_position_embeddings - tts_lm_input_ids.shape[-1] - ) - - generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria = ( - self._build_generate_config_model_kwargs( - generation_config, inputs, tokenizer, return_processors=True, **kwargs - ) - ) - - negative_kwargs = { - "input_ids": torch.full( - (kwargs["input_ids"].shape[0], 1), - neg_text_input_id, - dtype=torch.long, - device=kwargs["input_ids"].device, - ), - "attention_mask": torch.ones( - (kwargs["input_ids"].shape[0], 1), - dtype=torch.long, - device=kwargs["input_ids"].device, - ), - "max_new_tokens": kwargs.get("max_new_tokens", 100), - } - negative_generation_config, negative_model_kwargs, negative_input_ids = ( - self._build_generate_config_model_kwargs( - None, None, tokenizer, return_processors=False, **negative_kwargs - ) - ) - - tts_lm_kwargs = { - "input_ids": tts_lm_input_ids, - "attention_mask": tts_lm_attention_mask, - "max_new_tokens": kwargs.get("max_new_tokens", 100), - } - tts_lm_generation_config, tts_lm_model_kwargs, tts_lm_input_ids = ( - self._build_generate_config_model_kwargs( - None, None, tokenizer, return_processors=False, **tts_lm_kwargs - ) - 
) - - tts_lm_negative_kwargs = { - "input_ids": torch.full( - (kwargs["input_ids"].shape[0], 1), - neg_text_input_id, - dtype=torch.long, - device=kwargs["input_ids"].device, - ), - "attention_mask": torch.ones( - (kwargs["input_ids"].shape[0], 1), - dtype=torch.long, - device=kwargs["input_ids"].device, - ), - "max_new_tokens": kwargs.get("max_new_tokens", 100), - } - tts_lm_negative_generation_config, tts_lm_negative_model_kwargs, tts_lm_negative_input_ids = ( - self._build_generate_config_model_kwargs( - None, None, tokenizer, return_processors=False, **tts_lm_negative_kwargs - ) - ) - - acoustic_cache = VibeVoiceTokenizerStreamingCache() - batch_size = input_ids.shape[0] - assert batch_size == 1, "Currently only supports batch size == 1" - device = input_ids.device - finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device) - verbose = kwargs.get("verbose", False) - - audio_chunks = [[] for _ in range(batch_size)] - tts_text_window_index = 0 - reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device) - first_text_window_size = ( - TTS_TEXT_WINDOW_SIZE - if tts_text_ids.shape[1] >= TTS_TEXT_WINDOW_SIZE - else tts_text_ids.shape[1] - ) - - outputs = all_prefilled_outputs["lm"] - tts_lm_outputs = all_prefilled_outputs["tts_lm"] - negative_outputs = all_prefilled_outputs["neg_lm"] - tts_lm_negative_outputs = all_prefilled_outputs["neg_tts_lm"] - - model_kwargs = _update_model_kwargs_for_generation( - outputs, model_kwargs, num_new_tokens=first_text_window_size - ) - tts_lm_model_kwargs = _update_model_kwargs_for_generation( - tts_lm_outputs, tts_lm_model_kwargs, num_new_tokens=first_text_window_size - ) - negative_model_kwargs = self._update_model_kwargs_for_generation( - negative_outputs, negative_model_kwargs, is_encoder_decoder=False - ) - tts_lm_negative_model_kwargs = self._update_model_kwargs_for_generation( - tts_lm_negative_outputs, tts_lm_negative_model_kwargs, is_encoder_decoder=False - ) - - step = 
tts_lm_input_ids.shape[1] - total_generated_speech_tokens = 0 - total_prefilled_text_tokens = 0 - if kwargs.get("show_progress_bar", True): - progress_bar = tqdm( - total=tts_lm_generation_config.max_length, - desc=f"Prefilled {step} tokens, current step ({step} / {tts_lm_generation_config.max_length})", - initial=step, - leave=False, - ) - else: - progress_bar = None - - # [VibePod] Grab the executor once; None means standard sequential path. - _vp_executor: Optional[concurrent.futures.ThreadPoolExecutor] = getattr( - self, "_vibepod_decode_executor", None - ) - - # ── Main generation loop (unchanged from original) ─────────────────────── - while True: - if stop_check_fn is not None and stop_check_fn(): - if verbose: - print(f"Generation stopped externally at step {step + 1}") - if audio_streamer is not None: - audio_streamer.end() - break - - if finished_tags.all(): - if hasattr(progress_bar, "set_description"): - progress_bar.set_description("Generation complete") - break - - cur_input_tts_text_ids = tts_text_ids[ - :, - tts_text_window_index * TTS_TEXT_WINDOW_SIZE : (tts_text_window_index + 1) - * TTS_TEXT_WINDOW_SIZE, - ] - next_text_window_size = tts_text_ids[ - :, - (tts_text_window_index + 1) - * TTS_TEXT_WINDOW_SIZE : (tts_text_window_index + 2) - * TTS_TEXT_WINDOW_SIZE, - ].shape[1] - tts_text_window_index += 1 - - if cur_input_tts_text_ids.shape[1] > 0: - input_ids = torch.cat([input_ids, cur_input_tts_text_ids], dim=-1) - tts_lm_input_ids = torch.cat([tts_lm_input_ids, cur_input_tts_text_ids], dim=-1) - - if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length: - if verbose: - print( - f"Reached maximum generation length {generation_config.max_length}, stopped it." 
- ) - reached_samples = torch.arange(batch_size, device=device)[~finished_tags] - if reached_samples.numel() > 0: - reach_max_step_sample[reached_samples] = True - break - - step += cur_input_tts_text_ids.shape[1] - total_prefilled_text_tokens += cur_input_tts_text_ids.shape[1] - if progress_bar is not None: - progress_bar.update(cur_input_tts_text_ids.shape[1]) - progress_bar.set_description( - f"Prefilled {total_prefilled_text_tokens} text tokens, " - f"generated {total_generated_speech_tokens} speech tokens, " - f"current step ({step} / {tts_lm_generation_config.max_length})" - ) - - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - outputs = self.forward_lm( - **model_inputs, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - model_kwargs = _update_model_kwargs_for_generation( - outputs, model_kwargs, num_new_tokens=next_text_window_size - ) - - tts_lm_model_inputs = self.prepare_inputs_for_generation( - tts_lm_input_ids, **tts_lm_model_kwargs - ) - tts_lm_additional_inputs = { - "tts_text_masks": torch.ones_like(tts_lm_input_ids[:, -1:]), - "lm_last_hidden_state": outputs.last_hidden_state, - } - tts_lm_outputs = self.forward_tts_lm( - **tts_lm_model_inputs, - **tts_lm_additional_inputs, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - tts_lm_model_kwargs = self._update_model_kwargs_for_generation( - tts_lm_outputs, tts_lm_model_kwargs, is_encoder_decoder=False - ) - - diffusion_indices = torch.LongTensor([0]) - - # ── Inner speech loop ──────────────────────────────────────────────── - for cur_speech_index in range(TTS_SPEECH_WINDOW_SIZE): - positive_condition = tts_lm_outputs.last_hidden_state[diffusion_indices, -1, :] - negative_condition = tts_lm_negative_outputs.last_hidden_state[ - diffusion_indices, -1, : - ] - - speech_latent = self.sample_speech_tokens( - positive_condition, - negative_condition, - cfg_scale=cfg_scale, - ).unsqueeze(1) - - scaled_latent = ( 
- speech_latent / self.model.speech_scaling_factor.to(speech_latent.device) - - self.model.speech_bias_factor.to(speech_latent.device) - ) - - # [VibePod] If a decode executor is configured, submit decode to a - # background thread so acoustic_connector and forward_tts_lm can run - # concurrently on the main thread. The future is resolved after both - # tts_lm calls complete, before appending/streaming the audio chunk. - # Without the executor, the original sequential path is used unchanged. - if _vp_executor is not None: - _decode_future: concurrent.futures.Future[torch.Tensor] = _vp_executor.submit( - self.model.acoustic_tokenizer.decode, - scaled_latent.to(self.model.acoustic_tokenizer.device), - cache=acoustic_cache, - sample_indices=diffusion_indices.to(self.model.acoustic_tokenizer.device), - use_cache=True, - debug=False, - ) - else: - audio_chunk = self.model.acoustic_tokenizer.decode( - scaled_latent.to(self.model.acoustic_tokenizer.device), - cache=acoustic_cache, - sample_indices=diffusion_indices.to(self.model.acoustic_tokenizer.device), - use_cache=True, - debug=False, - ) - - # [VibePod] connector + tts_lm run here while decode is in the thread. - acoustic_embed = self.model.acoustic_connector(speech_latent) - tts_lm_input_ids = torch.cat( - [tts_lm_input_ids, torch.ones_like(tts_lm_input_ids[:, -1:])], dim=-1 - ) - - if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length: - # [VibePod] Resolve before break so audio_chunks stays consistent. 
- if _vp_executor is not None: - audio_chunk = _decode_future.result() - for i, sample_idx in enumerate(diffusion_indices): - idx = sample_idx.item() - if not finished_tags[idx]: - audio_chunks[idx].append(audio_chunk[i]) - if audio_streamer is not None: - audio_streamer.put(audio_chunk, diffusion_indices) - break - - step += 1 - total_generated_speech_tokens += 1 - if progress_bar is not None: - progress_bar.update(1) - progress_bar.set_description( - f"Prefilled {total_prefilled_text_tokens} text tokens, " - f"generated {total_generated_speech_tokens} speech tokens, " - f"current step ({step} / {tts_lm_generation_config.max_length})" - ) - - tts_lm_model_inputs = self.prepare_inputs_for_generation( - tts_lm_input_ids, **tts_lm_model_kwargs - ) - tts_lm_additional_inputs = { - "tts_text_masks": torch.zeros_like(tts_lm_input_ids[:, -1:]), - "lm_last_hidden_state": acoustic_embed, - } - tts_lm_outputs = self.forward_tts_lm( - **tts_lm_model_inputs, - **tts_lm_additional_inputs, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - if cur_speech_index == TTS_SPEECH_WINDOW_SIZE - 1 and next_text_window_size > 0: - tts_lm_model_kwargs = _update_model_kwargs_for_generation( - tts_lm_outputs, tts_lm_model_kwargs, num_new_tokens=next_text_window_size - ) - else: - tts_lm_model_kwargs = self._update_model_kwargs_for_generation( - tts_lm_outputs, tts_lm_model_kwargs, is_encoder_decoder=False - ) - - tts_lm_negative_input_ids = torch.cat( - [tts_lm_negative_input_ids, torch.ones_like(tts_lm_input_ids[:, -1:])], dim=-1 - ) - tts_lm_negative_model_inputs = self.prepare_inputs_for_generation( - tts_lm_negative_input_ids, **tts_lm_negative_model_kwargs - ) - tts_lm_negative_additional_inputs = { - "tts_text_masks": torch.zeros_like(tts_lm_negative_input_ids[:, -1:]), - "lm_last_hidden_state": acoustic_embed, - } - tts_lm_negative_outputs = self.forward_tts_lm( - **tts_lm_negative_model_inputs, - **tts_lm_negative_additional_inputs, - 
return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - tts_lm_negative_model_kwargs = self._update_model_kwargs_for_generation( - tts_lm_negative_outputs, - tts_lm_negative_model_kwargs, - is_encoder_decoder=False, - ) - - # [VibePod] Decode is done (or was never async). Resolve future, - # then append + stream — moved here from before connector/tts_lm. - if _vp_executor is not None: - audio_chunk = _decode_future.result() - for i, sample_idx in enumerate(diffusion_indices): - idx = sample_idx.item() - if not finished_tags[idx]: - audio_chunks[idx].append(audio_chunk[i]) - if audio_streamer is not None: - audio_streamer.put(audio_chunk, diffusion_indices) - - tts_eos_logits = torch.sigmoid( - self.tts_eos_classifier( - tts_lm_outputs.last_hidden_state[diffusion_indices, -1, :] - ) - ) - if tts_eos_logits[0].item() > 0.5: - finished_tags[diffusion_indices] = True - if audio_streamer is not None: - audio_streamer.end(diffusion_indices) - - if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length: - if verbose: - print( - f"Reached maximum generation length {tts_lm_generation_config.max_length}, stopped it." - ) - reached_samples = torch.arange(batch_size, device=device)[~finished_tags] - if reached_samples.numel() > 0: - reach_max_step_sample[reached_samples] = True - break - - if audio_streamer is not None: - audio_streamer.end() - - # ── Audio finalisation (unchanged from original) ───────────────────────── - final_audio_outputs = [] - for sample_chunks in audio_chunks: - if sample_chunks: - concatenated_audio = torch.cat(sample_chunks, dim=-1) - final_audio_outputs.append(concatenated_audio) - else: - final_audio_outputs.append(None) - - if reach_max_step_sample is not None and reach_max_step_sample.any(): - print( - f"Reached maximum generation length {tts_lm_generation_config.max_length}, stopped it." 
- ) - - return VibeVoiceGenerationOutput( - sequences=tts_lm_input_ids, - speech_outputs=final_audio_outputs if return_speech else None, - reach_max_step_sample=reach_max_step_sample, - ) - - -def install(model: object, executor: concurrent.futures.ThreadPoolExecutor) -> None: - """Install the patched generate() on a model instance and attach the executor.""" - model._vibepod_decode_executor = executor - model.generate = types.MethodType(patched_generate, model) diff --git a/server/vibevoice_server.py b/server/vibevoice_server.py index 14ccb36..2516a59 100644 --- a/server/vibevoice_server.py +++ b/server/vibevoice_server.py @@ -66,9 +66,12 @@ DEFAULT_SPEAKER = "carter" _IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.ot"] -# ── Decode pipeline executor ──────────────────────────────────────────────────── +# ── Pipeline executors ───────────────────────────────────────────────────────── +# _decode_executor: overlaps acoustic_decode with forward_tts_lm (1 worker). +# _cfg_executor: runs positive + negative forward_tts_lm in parallel (1 worker). _decode_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None +_cfg_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None # ── Device selection ──────────────────────────────────────────────────────────── # VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag. @@ -543,41 +546,38 @@ def _install_generation_optimizations(model: object) -> None: def _install_cpu_pipeline_optimizations(model: object) -> None: - """Install the async-decode generate() patch and its thread pool on the model instance. + """Attach pipeline executors to the model for the optimised generate() loop. - The VibeVoice inner loop runs: - decode(speech_latent) → append → put → connector → tts_lm(pos) → tts_lm(neg) + The JezzWTF/VibeVoice fork's generate() checks for two optional attributes: - connector and both tts_lm calls only need speech_latent/acoustic_embed, not - audio_chunk. 
The patched generate() reorders this to: - submit decode to thread → connector → tts_lm(pos) → tts_lm(neg) - → wait for decode future → append → put + model._vibepod_decode_executor — ThreadPoolExecutor (1 worker) used to + overlap acoustic_decode with acoustic_connector + forward_tts_lm. - The patch is applied as an instance method via types.MethodType, which shadows - the class-level generate() and is immune to uv sync reinstalling the package. + model._vibepod_cfg_executor — ThreadPoolExecutor (1 worker) used to + run the positive and negative forward_tts_lm calls in parallel, so + both CFG passes execute concurrently instead of sequentially. + + Both are None by default, making the fork's generate() behave identically + to upstream on CUDA or any machine where these aren't set. """ - global _decode_executor + global _decode_executor, _cfg_executor if os.environ.get("VIBEPOD_ASYNC_DECODE", "1") != "1": - logger.info("CPU async decode disabled via VIBEPOD_ASYNC_DECODE=0.") - return - - try: - import vibevoice_generate_patch - except ImportError: - logger.warning( - "vibevoice_generate_patch not found — async decode unavailable. " - "Ensure vibevoice_generate_patch.py is in the server directory." - ) + logger.info("CPU async decode/CFG parallelism disabled via VIBEPOD_ASYNC_DECODE=0.") return _decode_executor = concurrent.futures.ThreadPoolExecutor( max_workers=1, thread_name_prefix="vibepod-decode" ) - vibevoice_generate_patch.install(model, _decode_executor) + _cfg_executor = concurrent.futures.ThreadPoolExecutor( + max_workers=1, thread_name_prefix="vibepod-cfg" + ) + model._vibepod_decode_executor = _decode_executor + model._vibepod_cfg_executor = _cfg_executor logger.info( - "CPU pipeline: patched generate() installed (async decode enabled) — " - "acoustic_decode overlaps forward_tts_lm. Disable with VIBEPOD_ASYNC_DECODE=0." + "CPU pipeline: decode executor and CFG executor attached — " + "acoustic_decode overlaps tts_lm, pos/neg CFG runs in parallel. 
" + "Disable with VIBEPOD_ASYNC_DECODE=0." ) @@ -688,6 +688,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: yield if _decode_executor is not None: _decode_executor.shutdown(wait=False) + if _cfg_executor is not None: + _cfg_executor.shutdown(wait=False) app = FastAPI(title="VibePod TTS Server", version="0.1.0", lifespan=lifespan) From d80d5ba46bc51f3eac10d626861de849fdb42d93 Mon Sep 17 00:00:00 2001 From: LyAhn Date: Thu, 30 Apr 2026 21:46:02 +0100 Subject: [PATCH 04/11] fix: update lock to vibevoice fe832f2 (inference_mode thread fix) --- server/uv.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/uv.lock b/server/uv.lock index 187530e..1650583 100644 --- a/server/uv.lock +++ b/server/uv.lock @@ -3115,7 +3115,7 @@ requires-dist = [ [[package]] name = "vibevoice" version = "1.0.0" -source = { git = "https://github.com/JezzWTF/VibeVoice.git#e76701f17a0d93cd612d56f0db5865a615c4d16e" } +source = { git = "https://github.com/JezzWTF/VibeVoice.git#fe832f20e3d1638594f551a08f02253f14408dbd" } dependencies = [ { name = "absl-py" }, { name = "accelerate" }, From 01ab3d1fc47587f29b0ea56091f41f4d34be17ba Mon Sep 17 00:00:00 2001 From: LyAhn Date: Thu, 30 Apr 2026 23:20:46 +0100 Subject: [PATCH 05/11] perf(cpu): tune streaming playback Keep CPU async decode enabled without CFG parallelism, expand CPU buffering defaults for smooth playback, prevent CPU startup from mutating the lockfile during thread autodetection, and document runtime tuning variables in the example environment file. 
--- .env.example | 39 ++++++++++++++++++++++ server/start.sh | 2 +- server/vibevoice_server.py | 47 ++++++++++++--------------- web/components/GenerationControls.tsx | 6 ++-- web/hooks/useStreamingGeneration.ts | 2 +- 5 files changed, 65 insertions(+), 31 deletions(-) diff --git a/.env.example b/.env.example index 537770d..dc32c2b 100644 --- a/.env.example +++ b/.env.example @@ -8,3 +8,42 @@ HF_TOKEN= # Override the HuggingFace model cache directory (optional) # HF_HOME=/path/to/hf-cache + +# --------------------------------------------------------------------------- +# Runtime tuning +# --------------------------------------------------------------------------- + +# Force the Python server device. Usually set by `pnpm dev` / `pnpm dev:cpu`. +# VIBEPOD_DEVICE=cuda +# VIBEPOD_DEVICE=cpu + +# CPU mode: keep async decode enabled. This overlaps acoustic decoding with +# language-model work and measured ~20% faster on an 8-thread CPU run. +VIBEPOD_ASYNC_DECODE=1 + +# CPU mode: thread tuning. On an 8-core / 16-thread Ryzen test system, +# 8 worker threads with 1 inter-op thread gave the best wall time, while 12 +# over-subscribed and regressed. +# VIBEPOD_CPU_THREADS=8 +# VIBEPOD_CPU_INTEROP_THREADS=1 + +# CPU mode: playback buffering. CPU generation is slower than realtime, so +# smooth streaming needs a larger initial buffer than CUDA. Lower these for +# faster startup if you are OK with occasional rebuffering. +# VIBEPOD_PREBUFFER_SECS=24 +# VIBEPOD_REBUFFER_THRESHOLD_SECS=2 +# VIBEPOD_RESUME_THRESHOLD_SECS=12 + +# CPU mode: dynamic INT8 quantization is enabled by default in start.sh. +# Set to 0 if you are comparing quality/performance or debugging. +# VIBEPOD_QUANTIZE=1 + +# CUDA mode: dtype and attention selection. Defaults are bf16 + SDPA unless +# optional FlashAttention is explicitly enabled and importable. +# VIBEPOD_CUDA_DTYPE=bf16 +# VIBEPOD_ATTN_IMPL=sdpa +# VIBEPOD_ENABLE_FLASH_ATTN=0 + +# Debug/profiling. 
Keep disabled for benchmark timing; async CPU profiling +# double-counts overlapped decode work. +# VIBEPOD_PROFILE_GENERATION=0 diff --git a/server/start.sh b/server/start.sh index 995fbc1..060cde5 100755 --- a/server/start.sh +++ b/server/start.sh @@ -136,7 +136,7 @@ if $CPU_MODE; then export VIBEPOD_DEVICE="cpu" export UV_PROJECT_ENVIRONMENT=".venv-cpu" if [[ -z "${VIBEPOD_CPU_THREADS:-}" ]]; then - VIBEPOD_CPU_THREADS="$(uv run --no-sources python -c "import os; print(max(1, (os.cpu_count() or 2) // 2))")" + VIBEPOD_CPU_THREADS="$(uv run --no-sync --no-sources python -c "import os; print(max(1, (os.cpu_count() or 2) // 2))")" export VIBEPOD_CPU_THREADS fi export OMP_NUM_THREADS="${OMP_NUM_THREADS:-$VIBEPOD_CPU_THREADS}" diff --git a/server/vibevoice_server.py b/server/vibevoice_server.py index 2516a59..74fbd94 100644 --- a/server/vibevoice_server.py +++ b/server/vibevoice_server.py @@ -66,12 +66,10 @@ DEFAULT_SPEAKER = "carter" _IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.ot"] -# ── Pipeline executors ───────────────────────────────────────────────────────── -# _decode_executor: overlaps acoustic_decode with forward_tts_lm (1 worker). -# _cfg_executor: runs positive + negative forward_tts_lm in parallel (1 worker). +# ── Pipeline executor ────────────────────────────────────────────────────────── +# Overlaps acoustic_decode with forward_tts_lm on a background thread (1 worker). _decode_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None -_cfg_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None # ── Device selection ──────────────────────────────────────────────────────────── # VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag. @@ -546,38 +544,37 @@ def _install_generation_optimizations(model: object) -> None: def _install_cpu_pipeline_optimizations(model: object) -> None: - """Attach pipeline executors to the model for the optimised generate() loop. 
+ """Attach the decode executor to the model for the optimised generate() loop. The JezzWTF/VibeVoice fork's generate() checks for two optional attributes: - model._vibepod_decode_executor — ThreadPoolExecutor (1 worker) used to - overlap acoustic_decode with acoustic_connector + forward_tts_lm. + model._vibepod_decode_executor — ThreadPoolExecutor (1 worker) that + overlaps acoustic_decode with acoustic_connector + forward_tts_lm. + Profiling showed this hides ~72s of decode cost behind tts_lm work, + capturing ~96% of the theoretical overlap savings. - model._vibepod_cfg_executor — ThreadPoolExecutor (1 worker) used to - run the positive and negative forward_tts_lm calls in parallel, so - both CFG passes execute concurrently instead of sequentially. + model._vibepod_cfg_executor — intentionally NOT set. Parallel pos/neg + forward_tts_lm via a second thread causes MKL OpenMP thread-pool + contention on CPU: both threads compete for the same OMP worker pool, + making each call slower rather than faster. Net effect: ~6% regression. + The hook remains in the fork for potential GPU or future use. - Both are None by default, making the fork's generate() behave identically - to upstream on CUDA or any machine where these aren't set. + Attributes default to None, so the fork's generate() falls back to the + original sequential behaviour on CUDA or any non-VibePod install. 
""" - global _decode_executor, _cfg_executor + global _decode_executor if os.environ.get("VIBEPOD_ASYNC_DECODE", "1") != "1": - logger.info("CPU async decode/CFG parallelism disabled via VIBEPOD_ASYNC_DECODE=0.") + logger.info("CPU async decode disabled via VIBEPOD_ASYNC_DECODE=0.") return _decode_executor = concurrent.futures.ThreadPoolExecutor( max_workers=1, thread_name_prefix="vibepod-decode" ) - _cfg_executor = concurrent.futures.ThreadPoolExecutor( - max_workers=1, thread_name_prefix="vibepod-cfg" - ) model._vibepod_decode_executor = _decode_executor - model._vibepod_cfg_executor = _cfg_executor logger.info( - "CPU pipeline: decode executor and CFG executor attached — " - "acoustic_decode overlaps tts_lm, pos/neg CFG runs in parallel. " - "Disable with VIBEPOD_ASYNC_DECODE=0." + "CPU pipeline: decode executor attached — acoustic_decode overlaps " + "tts_lm. Disable with VIBEPOD_ASYNC_DECODE=0." ) @@ -643,9 +640,9 @@ def _load_model_sync() -> None: is_cpu = _device == "cpu" _config["device"] = _device _config["chunk_accum"] = _env_int("VIBEPOD_CHUNK_ACCUM", 4 if is_cpu else 1) - _config["prebuffer_secs"] = _env_float("VIBEPOD_PREBUFFER_SECS", 6.0 if is_cpu else 5.0) - _config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 1.5 if is_cpu else 1.0) - _config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 4.0 if is_cpu else 3.0) + _config["prebuffer_secs"] = _env_float("VIBEPOD_PREBUFFER_SECS", 24.0 if is_cpu else 5.0) + _config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 2.0 if is_cpu else 1.0) + _config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 12.0 if is_cpu else 3.0) _config["default_inference_steps"] = _env_int("VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10) if is_cpu: logical_cpus = os.cpu_count() or 1 @@ -688,8 +685,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: yield if _decode_executor is not None: 
_decode_executor.shutdown(wait=False) - if _cfg_executor is not None: - _cfg_executor.shutdown(wait=False) app = FastAPI(title="VibePod TTS Server", version="0.1.0", lifespan=lifespan) diff --git a/web/components/GenerationControls.tsx b/web/components/GenerationControls.tsx index f9a7d4c..41bce72 100644 --- a/web/components/GenerationControls.tsx +++ b/web/components/GenerationControls.tsx @@ -157,7 +157,7 @@ export default function GenerationControls({
onPrebufferSecsChange(parseFloat(e.target.value))} @@ -271,7 +271,7 @@ export default function GenerationControls({ id="resume-threshold" type="range" min={0.5} - max={5.0} + max={30.0} step={0.1} value={resumeThresholdSecs} onChange={(e) => { diff --git a/web/hooks/useStreamingGeneration.ts b/web/hooks/useStreamingGeneration.ts index 257f0c8..a8dbcbb 100644 --- a/web/hooks/useStreamingGeneration.ts +++ b/web/hooks/useStreamingGeneration.ts @@ -6,7 +6,7 @@ const SAMPLE_RATE = 24_000; const DEFAULT_PREBUFFER_SECS = 5.0; const DEFAULT_REBUFFER_THRESHOLD_SECS = 1.0; const DEFAULT_RESUME_THRESHOLD_SECS = 3.0; -const MAX_ADAPTIVE_RESUME_SECS = 18.0; +const MAX_ADAPTIVE_RESUME_SECS = 30.0; interface GenerateOptions { text: string; From 737d315c1af4b41de4b65f78c0998d0c2146c1bd Mon Sep 17 00:00:00 2001 From: LyAhn Date: Thu, 30 Apr 2026 23:32:20 +0100 Subject: [PATCH 06/11] docs: update roadmap and ignore Claude settings --- .claude/settings.local.json | 20 -------------------- .gitignore | 1 + roadmap.md | 7 +++++++ 3 files changed, 8 insertions(+), 20 deletions(-) delete mode 100644 .claude/settings.local.json diff --git a/.claude/settings.local.json b/.claude/settings.local.json deleted file mode 100644 index cf36061..0000000 --- a/.claude/settings.local.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "permissions": { - "allow": [ - "Bash(mv podcast-forge/pnpm-lock.yaml /tmp/vibepod-pnpm-lock.yaml)", - "Bash(git mv *)", - "Bash(mv /tmp/vibepod-pnpm-lock.yaml web/pnpm-lock.yaml)", - "Bash(git rm *)", - "Bash(uv lock *)", - "Bash(pnpm install *)", - "Bash(git add *)", - "Bash(command -v uv)", - "Bash(uv --version)", - "Bash(uv sync *)", - "Bash(pnpm --filter vibepod-web exec tsc --noEmit)", - "Bash(xargs cat *)", - "Bash(.venv/Scripts/python.exe -c \"import torch; print\\('torch:', torch.__version__\\); print\\('CUDA available:', torch.cuda.is_available\\(\\)\\); print\\('CUDA version:', torch.version.cuda\\)\")", - "Bash(nvidia-smi)" - ] - } -} diff --git a/.gitignore 
b/.gitignore index 13db197..040007f 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,4 @@ web/node_modules/ .DS_Store Thumbs.db .vscode/settings.json +.claude/settings.local.json diff --git a/roadmap.md b/roadmap.md index b2035fe..5825102 100644 --- a/roadmap.md +++ b/roadmap.md @@ -39,6 +39,13 @@ VibePod Studio will turn generated audio from a one-shot download into a reusabl - Add project save/load, autosave, and recoverable render jobs. - Prepare the audio pipeline for queueing longer renders outside the request lifecycle. +## Later: VibeVoice Performance Research + +- Move the current VibePod hot-path monkey patches into the `JezzWTF/VibeVoice` fork once the feature direction has settled. +- Add clearer generation profiling for overlapped CPU work, especially decode wait time versus total acoustic decode time. +- Prototype batched positive/negative CFG TTS LM inference behind an opt-in flag and benchmark it against the current sequential path on CPU and CUDA. +- Keep experimental performance work isolated from user-facing feature work unless it shows a clear speedup without audio quality regressions. + ## Foundation Work Needed First - Persist generated outputs with stable IDs. 
From d60c5ae4983fdd6f809fd264e631f1951eda4a41 Mon Sep 17 00:00:00 2001 From: LyAhn Date: Fri, 1 May 2026 18:36:04 +0100 Subject: [PATCH 07/11] chore: add prettier + enforce LF line endings - Add .prettierrc (double quotes, 2-space, trailing comma es5, LF, 100 cols) - Add .prettierignore (excludes node_modules, .next, server/, lock files) - Add .editorconfig (LF + per-language indent rules for all editors) - Expand .gitattributes to cover all text file types with eol=lf - Add prettier@^3.5.3 devDep at workspace root with format/format:check scripts - Add format/format:check scripts to web/package.json --- .editorconfig | 37 +++++++++++++++++++++++++++++++++++++ .gitattributes | 22 ++++++++++++++++++++++ .prettierignore | 18 ++++++++++++++++++ .prettierrc | 8 ++++++++ package.json | 7 ++++++- pnpm-lock.yaml | 13 ++++++++++++- web/package.json | 4 +++- 7 files changed, 106 insertions(+), 3 deletions(-) create mode 100644 .editorconfig create mode 100644 .prettierignore create mode 100644 .prettierrc diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..d37e399 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,37 @@ +root = true + +[*] +end_of_line = lf +charset = utf-8 +trim_trailing_whitespace = true +insert_final_newline = true + +[*.{ts,tsx,js,jsx,mjs,cjs,mts,cts}] +indent_style = space +indent_size = 2 + +[*.{json,jsonc}] +indent_style = space +indent_size = 2 + +[*.{css,html}] +indent_style = space +indent_size = 2 + +[*.{yaml,yml}] +indent_style = space +indent_size = 2 + +[*.py] +indent_style = space +indent_size = 4 + +[*.{toml}] +indent_style = space +indent_size = 4 + +[*.md] +trim_trailing_whitespace = false + +[Makefile] +indent_style = tab diff --git a/.gitattributes b/.gitattributes index 1bbd695..629ff9c 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,24 @@ +* text=auto eol=lf + *.sh text eol=lf *.py text eol=lf +*.ts text eol=lf +*.tsx text eol=lf +*.js text eol=lf +*.jsx text eol=lf +*.mjs text eol=lf +*.cjs text eol=lf 
+*.mts text eol=lf +*.cts text eol=lf +*.css text eol=lf +*.html text eol=lf +*.json text eol=lf +*.jsonc text eol=lf +*.yaml text eol=lf +*.yml text eol=lf +*.toml text eol=lf +*.md text eol=lf +*.mdx text eol=lf +*.lock text eol=lf +*.env text eol=lf +*.env.* text eol=lf diff --git a/.prettierignore b/.prettierignore new file mode 100644 index 0000000..08615cc --- /dev/null +++ b/.prettierignore @@ -0,0 +1,18 @@ +# Dependencies +node_modules/ +web/node_modules/ + +# Build outputs +web/.next/ +web/tsconfig.tsbuildinfo +web/next-env.d.ts + +# Python / server +server/ + +# Lock files +pnpm-lock.yaml +web/pnpm-lock.yaml + +# Generated +web/public/ diff --git a/.prettierrc b/.prettierrc new file mode 100644 index 0000000..2aa8717 --- /dev/null +++ b/.prettierrc @@ -0,0 +1,8 @@ +{ + "semi": true, + "singleQuote": false, + "tabWidth": 2, + "trailingComma": "es5", + "printWidth": 100, + "endOfLine": "lf" +} diff --git a/package.json b/package.json index 034f905..16d0345 100644 --- a/package.json +++ b/package.json @@ -8,7 +8,12 @@ "dev:cpu": "bash dev.sh --cpu", "dev:server": "bash server/start.sh", "dev:server:cpu": "bash server/start.sh --cpu", - "dev:web": "pnpm --filter vibepod-web dev" + "dev:web": "pnpm --filter vibepod-web dev", + "format": "prettier --write .", + "format:check": "prettier --check ." 
+ }, + "devDependencies": { + "prettier": "^3.5.3" }, "packageManager": "pnpm@10.33.2+sha512.a90faf6feeab71ad6c6e57f94e0fe1a12f5dcc22cd754db40ae9593eb6a3e0b6b12e3540218bb37ae083404b1f2ce6db2a4121e979829b4aff94b99f49da1cf8" } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 07d0a9a..b89acfe 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -6,7 +6,11 @@ settings: importers: - .: {} + .: + devDependencies: + prettier: + specifier: ^3.5.3 + version: 3.8.3 web: dependencies: @@ -516,6 +520,11 @@ packages: resolution: {integrity: sha512-W62t/Se6rA0Az3DfCL0AqJwXuKwBeYg6nOaIgzP+xZ7N5BFCI7DYi1qs6ygUYT6rvfi6t9k65UMLJC+PHZpDAA==} engines: {node: ^10 || ^12 || >=14} + prettier@3.8.3: + resolution: {integrity: sha512-7igPTM53cGHMW8xWuVTydi2KO233VFiTNyF5hLJqpilHfmn8C8gPf+PS7dUT64YcXFbiMGZxS9pCSxL/Dxm/Jw==} + engines: {node: '>=14'} + hasBin: true + react-dom@19.1.0: resolution: {integrity: sha512-Xs1hdnE+DyKgeHJeJznQmYMIBG3TKIHJJT95Q58nHLSrElKlGQqDTR2HQ9fx5CN/Gk6Vh/kupBTDLU11/nDk/g==} peerDependencies: @@ -917,6 +926,8 @@ snapshots: picocolors: 1.1.1 source-map-js: 1.2.1 + prettier@3.8.3: {} + react-dom@19.1.0(react@19.1.0): dependencies: react: 19.1.0 diff --git a/web/package.json b/web/package.json index f5c4ca1..416b784 100644 --- a/web/package.json +++ b/web/package.json @@ -5,7 +5,9 @@ "scripts": { "dev": "next dev --turbopack", "build": "next build", - "start": "next start" + "start": "next start", + "format": "prettier --write .", + "format:check": "prettier --check ." 
}, "dependencies": { "next": "15.5.15", From a351910fd2fe7a25872d631cbb42de2610d45039 Mon Sep 17 00:00:00 2001 From: LyAhn Date: Fri, 1 May 2026 18:36:42 +0100 Subject: [PATCH 08/11] style: apply prettier formatting across all source files --- AGENTS.md | 28 ++- DESIGN.md | 5 + README.md | 36 +-- pnpm-workspace.yaml | 2 +- web/app/api/generate/route.ts | 4 +- web/app/api/health/route.ts | 3 +- web/app/globals.css | 6 +- web/app/page.tsx | 84 ++++--- web/components/AudioPlayer.tsx | 36 +-- web/components/GenerationControls.tsx | 85 +++++-- web/components/Header.tsx | 47 ++-- web/components/StatusLog.tsx | 3 +- web/components/TextInputPanel.tsx | 22 +- web/hooks/useAudioPlayer.ts | 16 +- web/hooks/useStreamingGeneration.ts | 317 +++++++++++++------------- 15 files changed, 376 insertions(+), 318 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index ed18570..7ca403c 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -8,10 +8,10 @@ This file gives AI coding agents (Jules, Copilot, Claude Code, etc.) the context VibePod is a text-to-speech web app. It has two services that must both run for the app to work: -| Service | Language | Entry point | Port | -|---------|----------|-------------|------| -| **server** | Python 3.10+ (FastAPI + VibeVoice) | `server/start.sh` | 8000 | -| **web** | TypeScript (Next.js 15, React 19) | `pnpm --filter vibepod-web dev` | 3000 | +| Service | Language | Entry point | Port | +| ---------- | ---------------------------------- | ------------------------------- | ---- | +| **server** | Python 3.10+ (FastAPI + VibeVoice) | `server/start.sh` | 8000 | +| **web** | TypeScript (Next.js 15, React 19) | `pnpm --filter vibepod-web dev` | 3000 | The Next.js frontend proxies all model requests through its own API routes to the FastAPI server — it never calls the Python server directly from the browser. 
@@ -51,12 +51,12 @@ pnpm build The `--cpu` flag in `start.sh` sets `VIBEPOD_DEVICE=cpu` and uses a separate venv (`server/.venv-cpu`) so CUDA and CPU installs never conflict. `vibevoice_server.py` reads `VIBEPOD_DEVICE` at startup via `_resolve_device()` — do not remove or rename that function. -| Env var | Values | Set by | -|---------|--------|--------| -| `VIBEPOD_DEVICE` | `cpu` \| `cuda` | `server/start.sh` | -| `UV_PROJECT_ENVIRONMENT` | `.venv-cpu` \| `.venv` | `server/start.sh` | -| `HF_TOKEN` | HuggingFace token | Jules secret / `.env.local` | -| `VIBEVOICE_SERVER_URL` | `http://localhost:8000` | `.env.local` | +| Env var | Values | Set by | +| ------------------------ | ----------------------- | --------------------------- | +| `VIBEPOD_DEVICE` | `cpu` \| `cuda` | `server/start.sh` | +| `UV_PROJECT_ENVIRONMENT` | `.venv-cpu` \| `.venv` | `server/start.sh` | +| `HF_TOKEN` | HuggingFace token | Jules secret / `.env.local` | +| `VIBEVOICE_SERVER_URL` | `http://localhost:8000` | `.env.local` | --- @@ -94,7 +94,9 @@ dev.sh Concurrent launcher (forwards flags to start.sh) ## API reference ### `GET /health` + Returns server status. Safe to poll. + ```json { "status": "online", @@ -103,13 +105,17 @@ Returns server status. Safe to poll. "voices": ["carter", "davis", "emma", "frank", "grace", "mike"] } ``` + `status` values: `downloading` | `loading` | `online` | `error` ### `POST /generate` + Streams audio as SSE events. 
+ ```json { "text": "Hello world", "speaker": "carter", "cfg_scale": 1.5, "inference_steps": 10 } ``` + Event types: `audio_chunk` (base64 float32 PCM) | `complete` | `error` | `cancelled` --- @@ -117,12 +123,14 @@ Event types: `audio_chunk` (base64 float32 PCM) | `complete` | `error` | `cancel ## Do / Don't **Do:** + - Use `pnpm dev:cpu` in Jules — never plain `pnpm dev` - Run `git checkout server/uv.lock` if uv rewrites it during setup - Keep `_resolve_device()` in `vibevoice_server.py` — it's the CPU/CUDA switching logic - Test server changes against `GET /health` and `POST /generate` **Don't:** + - Run `uv sync` without `UV_PROJECT_ENVIRONMENT=.venv-cpu` in the Jules sandbox - Install Python packages with pip - Modify `server/uv.lock` manually diff --git a/DESIGN.md b/DESIGN.md index 2654734..42a00df 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -173,16 +173,21 @@ The shape language is a hybrid of structural precision and tactile softness. ## Components ### Card Containers + The fundamental building block of the UI. Every distinct section (Script, Player, Controls, Logs) is housed in a card featuring the `card-bg`, a 1px `border`, and `rounded-xl` corners. The internal layout always features an uppercase teal header for immediate section identification. ### Primary Action Buttons + Used for high-leverage actions like "Generate Audio" and "Play/Pause." These buttons utilize the `gradient-primary-dim` background, bold white text, and emit a soft teal glow to draw the eye and signify their importance. ### Range Sliders + Custom-styled input ranges replace default browser styles. The tracks are muted and slim, while the thumbs are bright teal, fully rounded, and emit a glow that intensifies on hover, providing a premium, tactile scrubbing experience. ### Status Indicators & Logs + A critical component of the application. Status badges utilize a minimalist pill shape with a pulsing ring animation to indicate active server processing. 
The log panel explicitly uses monospace typography and color-codes messages (green for success, red for error, white for neutral) to provide a terminal-like readout of the backend systems. ### Gradients + Gradients are used purposefully to indicate progress, activity, or brand presence. The primary gradient (`135deg` from teal to violet) is used for branding (the logo icon and text) and primary buttons. Horizontal gradients (`90deg`) are used dynamically in progress bars to represent the flow of data over time (e.g., loading, downloading, and audio generation). diff --git a/README.md b/README.md index f8202f5..ab76d3c 100644 --- a/README.md +++ b/README.md @@ -14,12 +14,12 @@ The Next.js app proxies audio generation requests to the FastAPI server, keeping ## Prerequisites -| Tool | Install | -|------|---------| -| [Node.js 20+](https://nodejs.org) | `winget install OpenJS.NodeJS.LTS` | -| [pnpm](https://pnpm.io) | `npm i -g pnpm` | +| Tool | Install | +| ---------------------------------- | ----------------------------------- | +| [Node.js 20+](https://nodejs.org) | `winget install OpenJS.NodeJS.LTS` | +| [pnpm](https://pnpm.io) | `npm i -g pnpm` | | [Python 3.10+](https://python.org) | `winget install Python.Python.3.13` | -| [uv](https://docs.astral.sh/uv/) | `winget install astral-sh.uv` | +| [uv](https://docs.astral.sh/uv/) | `winget install astral-sh.uv` | ## Getting started @@ -50,10 +50,10 @@ The frontend shows a loading indicator while the model downloads. 
Once the serve VibePod maintains two completely separate Python virtual environments so CUDA and CPU torch installs never conflict: -| Mode | Command | venv | torch source | -|------|---------|------|--------------| -| CUDA (default) | `pnpm dev` | `server/.venv` | PyTorch CUDA 12.4 index | -| CPU-only | `pnpm dev:cpu` | `server/.venv-cpu` | PyPI (CPU wheel) | +| Mode | Command | venv | torch source | +| -------------- | -------------- | ------------------ | ----------------------- | +| CUDA (default) | `pnpm dev` | `server/.venv` | PyTorch CUDA 12.4 index | +| CPU-only | `pnpm dev:cpu` | `server/.venv-cpu` | PyPI (CPU wheel) | On first run, each mode creates its own venv automatically. You can switch between them freely — they are fully independent. The active device is reported by the `/health` endpoint as `"device": "cpu"` or `"device": "cuda"`. @@ -74,11 +74,11 @@ pnpm build # Production build of the frontend Copy `.env.example` to `.env.local` and set: -| Variable | Default | Description | -|----------|---------|-------------| +| Variable | Default | Description | +| ---------------------- | ----------------------- | --------------------------------------------------------- | | `VIBEVOICE_SERVER_URL` | `http://localhost:8000` | URL the Next.js API routes use to reach the Python server | -| `HF_TOKEN` | — | HuggingFace token (required if the model repo is gated) | -| `HF_HOME` | — | Override the HuggingFace model cache directory | +| `HF_TOKEN` | — | HuggingFace token (required if the model repo is gated) | +| `HF_HOME` | — | Override the HuggingFace model cache directory | ## Project structure @@ -107,11 +107,11 @@ server/ ## Generation parameters -| Parameter | Range | Default | Effect | -|-----------|-------|---------|--------| -| `speaker` | `carter`, `davis`, `emma`, `frank`, `grace`, `mike` | `carter` | Voice preset used for the generated audio | -| `cfg_scale` | 0.5 – 4.0 | 1.5 | Higher = more expressive guidance | -| `inference_steps` | 5 – 20 | 10 | 
More steps = higher quality, slower generation | +| Parameter | Range | Default | Effect | +| ----------------- | --------------------------------------------------- | -------- | ---------------------------------------------- | +| `speaker` | `carter`, `davis`, `emma`, `frank`, `grace`, `mike` | `carter` | Voice preset used for the generated audio | +| `cfg_scale` | 0.5 – 4.0 | 1.5 | Higher = more expressive guidance | +| `inference_steps` | 5 – 20 | 10 | More steps = higher quality, slower generation | ## How it works diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml index b1cedb5..92a7e8b 100644 --- a/pnpm-workspace.yaml +++ b/pnpm-workspace.yaml @@ -1,2 +1,2 @@ packages: - - 'web' + - "web" diff --git a/web/app/api/generate/route.ts b/web/app/api/generate/route.ts index 180c659..310bb01 100644 --- a/web/app/api/generate/route.ts +++ b/web/app/api/generate/route.ts @@ -7,7 +7,7 @@ export async function POST(request: NextRequest) { const pythonServerUrl = process.env.VIBEVOICE_SERVER_URL ?? "http://localhost:8000"; try { - const body = await request.json() as { + const body = (await request.json()) as { text: string; speaker?: string; cfg_scale?: number; @@ -41,7 +41,7 @@ export async function POST(request: NextRequest) { headers: { "Content-Type": "text/event-stream", "Cache-Control": "no-cache, no-transform", - "Connection": "keep-alive", + Connection: "keep-alive", "X-Content-Type-Options": "nosniff", "X-Accel-Buffering": "no", }, diff --git a/web/app/api/health/route.ts b/web/app/api/health/route.ts index ba17edb..e4d3506 100644 --- a/web/app/api/health/route.ts +++ b/web/app/api/health/route.ts @@ -4,8 +4,7 @@ const OFFLINE_RESPONSE = { status: "offline" }; const COMMON_OPTIONS = { headers: { "Cache-Control": "no-store" } }; export async function GET() { - const pythonServerUrl = - process.env.VIBEVOICE_SERVER_URL ?? "http://localhost:8000"; + const pythonServerUrl = process.env.VIBEVOICE_SERVER_URL ?? 
"http://localhost:8000"; try { const res = await fetch(`${pythonServerUrl}/health`, { diff --git a/web/app/globals.css b/web/app/globals.css index 9388e7f..d4569ee 100644 --- a/web/app/globals.css +++ b/web/app/globals.css @@ -12,8 +12,10 @@ --muted: #64748b; --success: #22c55e; --error: #ef4444; - --font-sans: ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif; - --font-mono: ui-monospace, SFMono-Regular, "SF Mono", Menlo, Consolas, "Liberation Mono", monospace; + --font-sans: + ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif; + --font-mono: + ui-monospace, SFMono-Regular, "SF Mono", Menlo, Consolas, "Liberation Mono", monospace; } @theme inline { diff --git a/web/app/page.tsx b/web/app/page.tsx index a9f6317..128824a 100644 --- a/web/app/page.tsx +++ b/web/app/page.tsx @@ -69,19 +69,39 @@ type AppAction = function reducer(state: AppState, action: AppAction): AppState { switch (action.type) { - case "SET_SCRIPT": return { ...state, script: action.payload }; - case "SET_SPEAKER": return { ...state, speaker: action.payload }; - case "SET_CFG_SCALE": return { ...state, cfgScale: action.payload }; - case "SET_INFERENCE_STEPS": return { ...state, inferenceSteps: action.payload }; - case "SET_PREBUFFER_SECS": return { ...state, prebufferSecs: action.payload }; - case "SET_REBUFFER_THRESHOLD": return { ...state, rebufferThresholdSecs: action.payload }; - case "SET_RESUME_THRESHOLD": return { ...state, resumeThresholdSecs: action.payload }; + case "SET_SCRIPT": + return { ...state, script: action.payload }; + case "SET_SPEAKER": + return { ...state, speaker: action.payload }; + case "SET_CFG_SCALE": + return { ...state, cfgScale: action.payload }; + case "SET_INFERENCE_STEPS": + return { ...state, inferenceSteps: action.payload }; + case "SET_PREBUFFER_SECS": + return { ...state, prebufferSecs: action.payload }; + case "SET_REBUFFER_THRESHOLD": + return { ...state, rebufferThresholdSecs: 
action.payload }; + case "SET_RESUME_THRESHOLD": + return { ...state, resumeThresholdSecs: action.payload }; case "START_GENERATION": - return { ...state, isGenerating: true, audioUrl: null, logs: [], genElapsed: 0, genPct: null }; + return { + ...state, + isGenerating: true, + audioUrl: null, + logs: [], + genElapsed: 0, + genPct: null, + }; case "GEN_PROGRESS": return { ...state, genElapsed: action.elapsed, genPct: action.pct }; case "GENERATION_SUCCESS": - return { ...state, isGenerating: false, genElapsed: 0, genPct: null, audioUrl: action.payload }; + return { + ...state, + isGenerating: false, + genElapsed: 0, + genPct: null, + audioUrl: action.payload, + }; case "GENERATION_CANCELLED": case "GENERATION_ERROR": return { ...state, isGenerating: false, genElapsed: 0, genPct: null }; @@ -89,21 +109,27 @@ function reducer(state: AppState, action: AppAction): AppState { return { ...state, logs: [...state.logs, action.payload] }; case "SET_SERVER_STATUS": { const isNewConfig = !state.serverConfig && action.payload.config; - const deviceChanged = !!(state.serverConfig && action.payload.config && state.serverConfig.device !== action.payload.config.device); + const deviceChanged = !!( + state.serverConfig && + action.payload.config && + state.serverConfig.device !== action.payload.config.device + ); - const nextSteps = (isNewConfig || deviceChanged) + const nextSteps = + isNewConfig || deviceChanged ? action.payload.config!.default_inference_steps : state.inferenceSteps; - const nextPrebuffer = (isNewConfig || deviceChanged) - ? action.payload.config!.prebuffer_secs - : state.prebufferSecs; + const nextPrebuffer = + isNewConfig || deviceChanged ? action.payload.config!.prebuffer_secs : state.prebufferSecs; - const nextRebuffer = (isNewConfig || deviceChanged) + const nextRebuffer = + isNewConfig || deviceChanged ? 
action.payload.config!.rebuffer_threshold_secs : state.rebufferThresholdSecs; - const nextResume = (isNewConfig || deviceChanged) + const nextResume = + isNewConfig || deviceChanged ? action.payload.config!.resume_threshold_secs : state.resumeThresholdSecs; @@ -121,7 +147,8 @@ function reducer(state: AppState, action: AppAction): AppState { resumeThresholdSecs: nextResume, }; } - default: return state; + default: + return state; } } @@ -213,7 +240,10 @@ export default function HomePage() { } poll(); - return () => { cancelled = true; clearTimeout(timeoutId); }; + return () => { + cancelled = true; + clearTimeout(timeoutId); + }; }, []); const handleGenerate = useCallback(async () => { @@ -241,7 +271,6 @@ export default function HomePage() {
- {/* Left: script + audio player */}
dispatch({ type: "SET_CFG_SCALE", payload: v })} inferenceSteps={state.inferenceSteps} onInferenceStepsChange={(v) => dispatch({ type: "SET_INFERENCE_STEPS", payload: v })} - prebufferSecs={state.prebufferSecs} - onPrebufferSecsChange={(v) => dispatch({ type: "SET_PREBUFFER_SECS", payload: v })} - rebufferThresholdSecs={state.rebufferThresholdSecs} - onRebufferThresholdChange={(v) => dispatch({ type: "SET_REBUFFER_THRESHOLD", payload: v })} - resumeThresholdSecs={state.resumeThresholdSecs} - onResumeThresholdChange={(v) => dispatch({ type: "SET_RESUME_THRESHOLD", payload: v })} + prebufferSecs={state.prebufferSecs} + onPrebufferSecsChange={(v) => dispatch({ type: "SET_PREBUFFER_SECS", payload: v })} + rebufferThresholdSecs={state.rebufferThresholdSecs} + onRebufferThresholdChange={(v) => + dispatch({ type: "SET_REBUFFER_THRESHOLD", payload: v }) + } + resumeThresholdSecs={state.resumeThresholdSecs} + onResumeThresholdChange={(v) => + dispatch({ type: "SET_RESUME_THRESHOLD", payload: v }) + } onGenerate={handleGenerate} onStop={stop} onPauseStream={pauseStream} @@ -281,7 +314,6 @@ export default function HomePage() { />
-
diff --git a/web/components/AudioPlayer.tsx b/web/components/AudioPlayer.tsx index f54f25e..36ca5b1 100644 --- a/web/components/AudioPlayer.tsx +++ b/web/components/AudioPlayer.tsx @@ -14,15 +14,8 @@ function formatTime(seconds: number): string { } export default function AudioPlayer({ audioUrl }: AudioPlayerProps) { - const { - isPlaying, - currentTime, - duration, - volume, - toggle, - seek, - setVolume, - } = useAudioPlayer(audioUrl); + const { isPlaying, currentTime, duration, volume, toggle, seek, setVolume } = + useAudioPlayer(audioUrl); if (!audioUrl) return null; @@ -56,12 +49,10 @@ export default function AudioPlayer({ audioUrl }: AudioPlayerProps) { background: "rgba(45, 212, 191, 0.05)", }} onMouseEnter={(e) => { - (e.currentTarget as HTMLButtonElement).style.background = - "rgba(45, 212, 191, 0.15)"; + (e.currentTarget as HTMLButtonElement).style.background = "rgba(45, 212, 191, 0.15)"; }} onMouseLeave={(e) => { - (e.currentTarget as HTMLButtonElement).style.background = - "rgba(45, 212, 191, 0.05)"; + (e.currentTarget as HTMLButtonElement).style.background = "rgba(45, 212, 191, 0.05)"; }} > {isPlaying ? ( - + ) : ( - + )} @@ -143,9 +125,7 @@ export default function AudioPlayer({ audioUrl }: AudioPlayerProps) { {/* Duration info */}
- - {formatTime(currentTime)} - + {formatTime(currentTime)} / {formatTime(duration)}
diff --git a/web/components/GenerationControls.tsx b/web/components/GenerationControls.tsx index 41bce72..373fe0d 100644 --- a/web/components/GenerationControls.tsx +++ b/web/components/GenerationControls.tsx @@ -36,18 +36,27 @@ const STATUS_CONFIG: Record< Exclude, { color: string; label: (p: DownloadProgress | null) => string } > = { - offline: { color: "var(--error)", label: () => "Server offline — waiting for connection..." }, - downloading: { color: "#60a5fa", label: (p) => p && p.total > 0 ? `Downloading model... (${p.done} / ${p.total} files)` : "Downloading model (~1 GB)..." }, - loading: { color: "#fbbf24", label: () => "Loading model into memory..." }, - error: { color: "var(--error)", label: () => "Server error — check the terminal for details." }, + offline: { color: "var(--error)", label: () => "Server offline — waiting for connection..." }, + downloading: { + color: "#60a5fa", + label: (p) => + p && p.total > 0 + ? `Downloading model... (${p.done} / ${p.total} files)` + : "Downloading model (~1 GB)...", + }, + loading: { color: "#fbbf24", label: () => "Loading model into memory..." }, + error: { color: "var(--error)", label: () => "Server error — check the terminal for details." }, }; - function SpinnerIcon() { return ( - + ); } @@ -146,7 +155,10 @@ export default function GenerationControls({ onChange={(e) => onCfgScaleChange(parseFloat(e.target.value))} className="w-full" /> -
+
Flat (0.5) CFG Scale Expressive (4.0) @@ -176,7 +188,10 @@ export default function GenerationControls({ className="w-full" style={{ "--thumb-color": "var(--accent-violet)" } as React.CSSProperties} /> -
+
Faster (5) Diffusion Steps Better (20) @@ -207,7 +222,11 @@ export default function GenerationControls({
{showAdvanced && ( -
+
{/* Pre-buffer */}
@@ -232,7 +251,11 @@ export default function GenerationControls({ {/* Re-buffer threshold */}
-