From 75b84b211bb9e0c1ec31295ae866d3110f9220a9 Mon Sep 17 00:00:00 2001 From: LyAhn Date: Thu, 30 Apr 2026 18:54:14 +0100 Subject: [PATCH] perf: improve streaming generation pipeline Add CUDA inference hot-path optimizations, safer attention fallback handling, and generation profiling hooks. Improve SSE streaming, browser buffering telemetry, and playback recovery while preserving default audio quality settings. --- .gitattributes | 2 + server/.python-version | 1 + server/start.sh | 48 ++++ server/vibevoice_server.py | 365 +++++++++++++++++++++++++--- web/app/api/generate/route.ts | 7 +- web/app/api/health/route.ts | 1 + web/app/page.tsx | 6 +- web/hooks/useStreamingGeneration.ts | 75 ++++-- web/package.json | 2 +- 9 files changed, 459 insertions(+), 48 deletions(-) create mode 100644 .gitattributes create mode 100644 server/.python-version diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..1bbd695 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +*.sh text eol=lf +*.py text eol=lf diff --git a/server/.python-version b/server/.python-version new file mode 100644 index 0000000..e4fba21 --- /dev/null +++ b/server/.python-version @@ -0,0 +1 @@ +3.12 diff --git a/server/start.sh b/server/start.sh index befa203..2daa340 100755 --- a/server/start.sh +++ b/server/start.sh @@ -7,6 +7,11 @@ # ./start.sh — CUDA mode (default, uses PyTorch CUDA 12.4 wheel, venv: .venv) # ./start.sh --cpu — CPU-only mode (uses PyPI CPU torch wheel, venv: .venv-cpu) # +# Optional CUDA acceleration: +# VIBEPOD_ENABLE_FLASH_ATTN=1 ./start.sh +# Installs a matching third-party Windows flash-attn wheel when the CUDA venv +# uses Python 3.12, torch 2.6.0, and CUDA 12.4. +# # The two modes maintain completely separate virtual environments so their torch # installations never conflict. UV_PROJECT_ENVIRONMENT tells uv which venv to use; # --no-sources skips [tool.uv.sources] so the CPU run pulls the default PyPI torch wheel. @@ -51,6 +56,19 @@ if ! command -v uv &>/dev/null; then exit 1 fi +validate_flash_attn() { + uv run python -c "import flash_attn; import triton; import transformers.modeling_utils" &>/dev/null +} + +remove_broken_flash_attn() { + if uv run python -c "import importlib.util; raise SystemExit(0 if importlib.util.find_spec('flash_attn') else 1)" &>/dev/null; then + if ! validate_flash_attn; then + echo " Installed flash-attn is not usable in this environment; removing it." + uv pip uninstall flash-attn + fi + fi +} + # --------------------------------------------------------------------------- # 2. Sync Python environment # CPU mode: use .venv-cpu and skip [tool.uv.sources] so uv pulls the @@ -65,6 +83,36 @@ if $CPU_MODE; then else echo "--> Syncing CUDA Python environment (.venv)..." uv sync + + remove_broken_flash_attn + + if [[ "${VIBEPOD_ENABLE_FLASH_ATTN:-0}" == "1" ]]; then + echo "" + echo "--> Checking optional FlashAttention wheel..." + + if validate_flash_attn; then + echo " flash-attn already installed and importable." + else + PY_TAG="$(uv run python -c "import sys; print(f'cp{sys.version_info.major}{sys.version_info.minor}')")" + TORCH_TAG="$(uv run python -c "import torch; print(torch.__version__.split('+', 1)[0])")" + CUDA_TAG="$(uv run python -c "import torch; print('cu' + torch.version.cuda.replace('.', ''))")" + + if [[ "$PY_TAG" == "cp312" && "$TORCH_TAG" == "2.6.0" && "$CUDA_TAG" == "cu124" ]]; then + FLASH_ATTN_WHEEL_URL="https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4%2Bcu124torch2.6.0cxx11abiFALSE-cp312-cp312-win_amd64.whl" + echo " Installing flash-attn for Python 3.12, torch 2.6.0, CUDA 12.4..." + uv pip install "$FLASH_ATTN_WHEEL_URL" + if validate_flash_attn; then + echo " flash-attn import check passed." + else + echo " flash-attn import check failed; removing it and continuing with SDPA." + uv pip uninstall flash-attn + fi + else + echo " No known wheel for Python tag $PY_TAG, torch $TORCH_TAG, CUDA $CUDA_TAG." + echo " Continuing with PyTorch SDPA attention." + fi + fi + fi fi # --------------------------------------------------------------------------- diff --git a/server/vibevoice_server.py b/server/vibevoice_server.py index f5bc012..39541f5 100644 --- a/server/vibevoice_server.py +++ b/server/vibevoice_server.py @@ -22,11 +22,13 @@ import asyncio import base64 import copy import functools +import importlib.util import json import logging import os import threading import time +import types import urllib.request from contextlib import asynccontextmanager from pathlib import Path @@ -211,8 +213,39 @@ def _init_processor(): def _init_model(device: str): logger.info("Loading model on %s...", device) - load_dtype = torch.bfloat16 if device == "cuda" else torch.float32 - attn_impl = "flash_attention_2" if device == "cuda" else "sdpa" + if device == "cuda": + torch.set_float32_matmul_precision("high") + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cuda.enable_flash_sdp(True) + torch.backends.cuda.enable_mem_efficient_sdp(True) + torch.backends.cuda.enable_math_sdp(True) + torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True) + logger.info( + "PyTorch SDPA backends: flash=%s, mem_efficient=%s, math=%s", + torch.backends.cuda.flash_sdp_enabled(), + torch.backends.cuda.mem_efficient_sdp_enabled(), + torch.backends.cuda.math_sdp_enabled(), + ) + + cuda_dtype = os.environ.get("VIBEPOD_CUDA_DTYPE", "bf16").lower() + if device == "cuda" and cuda_dtype == "fp16": + load_dtype = torch.float16 + else: + load_dtype = torch.bfloat16 if device == "cuda" else torch.float32 + logger.info("Loading model weights with dtype %s", load_dtype) + requested_attn_impl = os.environ.get("VIBEPOD_ATTN_IMPL", "auto").lower() + has_flash_attn = importlib.util.find_spec("flash_attn") is not None + if requested_attn_impl in {"eager", "sdpa"}: + attn_impl = requested_attn_impl + elif requested_attn_impl == "flash_attention_2": + attn_impl = "flash_attention_2" if has_flash_attn else "sdpa" + else: + attn_impl = "flash_attention_2" if device == "cuda" and has_flash_attn else "sdpa" + logger.info("Using Transformers attention implementation: %s", attn_impl) + if device == "cuda" and not has_flash_attn: + logger.info("flash_attn is not installed; using PyTorch SDPA attention.") from vibevoice.modular.modeling_vibevoice_streaming_inference import ( VibeVoiceStreamingForConditionalGenerationInference, @@ -225,9 +258,13 @@ def _init_model(device: str): device_map=device, attn_implementation=attn_impl, ) - except Exception: + except Exception as exc: + if attn_impl == "sdpa": + raise logger.warning( - "Model load with %s failed; falling back to sdpa", attn_impl, exc_info=True + "Model load with %s failed (%s); falling back to sdpa", + attn_impl, + exc, ) model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( MODEL_ID, @@ -238,9 +275,164 @@ def _init_model(device: str): model.eval() model.set_ddpm_inference_steps(num_steps=_config["default_inference_steps"]) + _install_generation_optimizations(model) return model +def _install_generation_optimizations(model: object) -> None: + """Patch VibeVoice hot paths without changing model quality settings.""" + + def profile_enabled() -> bool: + return os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1" + + def profile_sync() -> None: + if torch.cuda.is_available(): + torch.cuda.synchronize() + + def profile_record(self, key: str, elapsed: float) -> None: + stats = getattr(self, "_vibepod_profile", None) + if stats is None: + stats = {} + self._vibepod_profile = stats + bucket = stats.setdefault(key, {"count": 0, "seconds": 0.0}) + bucket["count"] += 1 + bucket["seconds"] += elapsed + + def timed_method(self, key: str, fn, *args, **kwargs): + if not profile_enabled(): + return fn(*args, **kwargs) + profile_sync() + started = time.perf_counter() + result = fn(*args, **kwargs) + profile_sync() + profile_record(self, key, time.perf_counter() - started) + return result + + def prepare_noise_scheduler(self): + scheduler = self.model.noise_scheduler + cache_key = self.ddpm_inference_steps + cache = getattr(self, "_vibepod_scheduler_cache", {}) + cached = cache.get(cache_key) + + if cached is None: + scheduler.set_timesteps(self.ddpm_inference_steps) + cached = { + "num_inference_steps": scheduler.num_inference_steps, + "timesteps": scheduler.timesteps, + "sigmas": scheduler.sigmas, + } + cache[cache_key] = cached + self._vibepod_scheduler_cache = cache + else: + scheduler.num_inference_steps = cached["num_inference_steps"] + scheduler.timesteps = cached["timesteps"] + scheduler.sigmas = cached["sigmas"] + scheduler.model_outputs = [None] * scheduler.config.solver_order + scheduler.lower_order_nums = 0 + scheduler._step_index = None + scheduler._begin_index = None + + return scheduler + + def sample_speech_tokens_optimized(self, condition, neg_condition, cfg_scale=3.0): + scheduler = prepare_noise_scheduler(self) + + condition = torch.cat([condition, neg_condition], dim=0).to( + self.model.prediction_head.device + ) + batch_size = condition.shape[0] // 2 + speech = torch.randn(batch_size, self.config.acoustic_vae_dim).to(condition) + t_batch_cache_key = ( + self.ddpm_inference_steps, + condition.device.type, + condition.device.index, + condition.dtype, + batch_size, + ) + t_batch_cache = getattr(self, "_vibepod_t_batch_cache", {}) + t_batches = t_batch_cache.get(t_batch_cache_key) + if t_batches is None or len(t_batches) != len(scheduler.timesteps): + t_batches = [ + t.repeat(condition.shape[0]).to( + device=condition.device, dtype=condition.dtype + ) + for t in scheduler.timesteps + ] + t_batch_cache[t_batch_cache_key] = t_batches + self._vibepod_t_batch_cache = t_batch_cache + + for t, t_batch in zip(scheduler.timesteps, t_batches): + if batch_size == 1: + combined = speech.expand(condition.shape[0], -1) + else: + combined = torch.cat([speech, speech], dim=0) + if profile_enabled(): + profile_sync() + started = time.perf_counter() + eps = self.model.prediction_head(combined, t_batch, condition=condition) + if profile_enabled(): + profile_sync() + profile_record(self, "diffusion_prediction_head", time.perf_counter() - started) + cond_eps, uncond_eps = torch.split(eps, batch_size, dim=0) + guided_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + if profile_enabled(): + started = time.perf_counter() + speech = scheduler.step(guided_eps, t, speech).prev_sample + if profile_enabled(): + profile_record(self, "diffusion_scheduler_step", time.perf_counter() - started) + + return speech + + forward_lm = model.forward_lm + forward_tts_lm = model.forward_tts_lm + acoustic_decode = model.model.acoustic_tokenizer.decode + + def forward_lm_profiled(*args, **kwargs): + return timed_method(model, "forward_lm", forward_lm, *args, **kwargs) + + def forward_tts_lm_profiled(*args, **kwargs): + return timed_method(model, "forward_tts_lm", forward_tts_lm, *args, **kwargs) + + def acoustic_decode_profiled(*args, **kwargs): + return timed_method(model, "acoustic_decode", acoustic_decode, *args, **kwargs) + + model.forward_lm = forward_lm_profiled + model.forward_tts_lm = forward_tts_lm_profiled + model.model.acoustic_tokenizer.decode = acoustic_decode_profiled + model.sample_speech_tokens = types.MethodType(sample_speech_tokens_optimized, model) + logger.info("Installed VibeVoice generation hot-path optimizations.") + + +def _model_float_dtype() -> torch.dtype: + try: + return next(_model.parameters()).dtype + except StopIteration: + return torch.float32 + + +def _move_cached_prompt(value: object, device: str, dtype: torch.dtype) -> object: + if torch.is_tensor(value): + if torch.is_floating_point(value): + return value.to(device=device, dtype=dtype) + return value.to(device=device) + if isinstance(value, dict): + for k in list(value.keys()): + value[k] = _move_cached_prompt(value[k], device, dtype) + return value + if isinstance(value, list): + return [_move_cached_prompt(v, device, dtype) for v in value] + if isinstance(value, tuple): + return tuple(_move_cached_prompt(v, device, dtype) for v in value) + if hasattr(value, "key_cache") and hasattr(value, "value_cache"): + value.key_cache = [ + _move_cached_prompt(t, device, dtype) for t in value.key_cache + ] + value.value_cache = [ + _move_cached_prompt(t, device, dtype) for t in value.value_cache + ] + return value + + def _load_voice_presets(device: str) -> dict[str, object]: presets = {} for name, filename in EN_VOICES.items(): @@ -273,9 +465,9 @@ def _load_model_sync() -> None: is_cpu = _device == "cpu" _config["device"] = _device _config["chunk_accum"] = _env_int("VIBEPOD_CHUNK_ACCUM", 4 if is_cpu else 1) - _config["prebuffer_secs"] = _env_float("VIBEPOD_PREBUFFER_SECS", 5.0 if is_cpu else 2.0) - _config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 1.0 if is_cpu else 0.4) - _config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 2.5 if is_cpu else 1.5) + _config["prebuffer_secs"] = _env_float("VIBEPOD_PREBUFFER_SECS", 6.0 if is_cpu else 5.0) + _config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 1.5 if is_cpu else 1.0) + _config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 4.0 if is_cpu else 3.0) _config["default_inference_steps"] = _env_int("VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10) _processor = _init_processor() @@ -364,10 +556,15 @@ def _sync_generate( raise RuntimeError("Generation cancelled.") speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER - voice_preset = copy.deepcopy(_voice_presets[speaker]) + model_dtype = _model_float_dtype() + voice_preset = _move_cached_prompt( + copy.deepcopy(_voice_presets[speaker]), _device, model_dtype + ) steps = req.inference_steps if req.inference_steps is not None else _config["default_inference_steps"] _model.set_ddpm_inference_steps(num_steps=steps) + if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1": + _model._vibepod_profile = {} inputs = _processor.process_input_with_cached_prompt( text=req.text, @@ -380,19 +577,20 @@ def _sync_generate( if torch.is_tensor(v): inputs[k] = v.to(_device) - outputs = _model.generate( - **inputs, - max_new_tokens=None, - cfg_scale=req.cfg_scale, - tokenizer=_processor.tokenizer, - generation_config={"do_sample": False}, - verbose=True, - all_prefilled_outputs=copy.deepcopy(voice_preset), - audio_streamer=streamer, - ) - - if not outputs.speech_outputs or outputs.speech_outputs[0] is None: - raise ValueError("Model returned no audio output.") + with torch.inference_mode(): + _model.generate( + **inputs, + max_new_tokens=None, + cfg_scale=req.cfg_scale, + tokenizer=_processor.tokenizer, + generation_config={"do_sample": False}, + verbose=False, + show_progress_bar=False, + return_speech=False, + stop_check_fn=cancel_event.is_set if cancel_event else None, + all_prefilled_outputs=voice_preset, + audio_streamer=streamer, + ) return speaker @@ -401,6 +599,24 @@ def _sse(event: dict) -> str: return f"data: {json.dumps(event)}\n\n" +def _generation_profile() -> Optional[dict[str, dict[str, float]]]: + if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") != "1": + return None + stats = getattr(_model, "_vibepod_profile", None) + if not stats: + return {} + return { + key: { + "count": value["count"], + "seconds": round(value["seconds"], 3), + "avg_ms": round(value["seconds"] * 1000 / value["count"], 3) + if value["count"] + else 0.0, + } + for key, value in sorted(stats.items()) + } + + @app.post("/generate") async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: if _model_status != "online": @@ -417,20 +633,58 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: ) async def event_stream() -> AsyncGenerator[str, None]: - from vibevoice.modular.streamer import AsyncAudioStreamer + class NonBlockingAudioStreamer: + """Async streamer that keeps GPU->CPU copies out of the model thread.""" + + def __init__(self, batch_size: int, stop_signal: object = None) -> None: + self.batch_size = batch_size + self.stop_signal = stop_signal + self.audio_queues = [asyncio.Queue() for _ in range(batch_size)] + self.finished_flags = [False for _ in range(batch_size)] + self.loop = asyncio.get_running_loop() + + def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor) -> None: + for i, sample_idx in enumerate(sample_indices): + idx = sample_idx.item() + if idx < self.batch_size and not self.finished_flags[idx]: + self.loop.call_soon_threadsafe( + self.audio_queues[idx].put_nowait, + audio_chunks[i].detach(), + ) + + def end(self, sample_indices: Optional[torch.Tensor] = None) -> None: + if sample_indices is None: + indices_to_end = range(self.batch_size) + else: + indices_to_end = [ + s.item() if torch.is_tensor(s) else s for s in sample_indices + ] + for idx in indices_to_end: + if idx < self.batch_size and not self.finished_flags[idx]: + self.loop.call_soon_threadsafe( + self.audio_queues[idx].put_nowait, self.stop_signal + ) + self.finished_flags[idx] = True start = time.monotonic() - streamer = AsyncAudioStreamer(batch_size=1) + streamer = NonBlockingAudioStreamer(batch_size=1) cancel_event = threading.Event() accum_size = max(1, _config["chunk_accum"]) accumulated_chunks = [] + chunk_count = 0 + audio_samples = 0 + first_chunk_at: Optional[float] = None + last_chunk_at: Optional[float] = None + max_chunk_gap = 0.0 + speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER async with _generation_lock: loop = asyncio.get_event_loop() future = loop.run_in_executor( None, functools.partial(_sync_generate, req, streamer, cancel_event) ) + future.add_done_callback(lambda _: streamer.end()) # Drain audio chunks as they arrive from the diffusion head. # stop_signal=None is the default sentinel that ends the queue. @@ -454,17 +708,45 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: if chunk is None: # stop signal break - accumulated_chunks.append(chunk.detach().cpu().float()) + accumulated_chunks.append(chunk.detach()) if len(accumulated_chunks) >= accum_size: - combined = torch.cat(accumulated_chunks, dim=0) + now = time.monotonic() + if first_chunk_at is None: + first_chunk_at = now + if last_chunk_at is not None: + max_chunk_gap = max(max_chunk_gap, now - last_chunk_at) + last_chunk_at = now + + combined = ( + torch.cat(accumulated_chunks, dim=0) + .detach() + .to("cpu", dtype=torch.float32) + .contiguous() + ) + chunk_count += 1 + audio_samples += combined.numel() pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode() yield _sse({"type": "audio_chunk", "data": pcm_b64}) accumulated_chunks = [] # Flush any remaining chunks if accumulated_chunks: - combined = torch.cat(accumulated_chunks, dim=0) + now = time.monotonic() + if first_chunk_at is None: + first_chunk_at = now + if last_chunk_at is not None: + max_chunk_gap = max(max_chunk_gap, now - last_chunk_at) + last_chunk_at = now + + combined = ( + torch.cat(accumulated_chunks, dim=0) + .detach() + .to("cpu", dtype=torch.float32) + .contiguous() + ) + chunk_count += 1 + audio_samples += combined.numel() pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode() yield _sse({"type": "audio_chunk", "data": pcm_b64}) @@ -479,17 +761,42 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: yield _sse( { "type": "error", - "message": "Internal server error during generation.", + "message": f"Generation failed: {exc}", } ) return elapsed = round(time.monotonic() - start, 1) + audio_secs = audio_samples / SAMPLE_RATE + realtime_factor = audio_secs / elapsed if elapsed > 0 else None + profile = _generation_profile() + if profile is not None: + logger.info("Generation profile: %s", profile) logger.info("Generation complete in %.1fs", elapsed) - yield _sse({"type": "complete", "elapsed": elapsed, "speaker": speaker}) + complete_event = { + "type": "complete", + "elapsed": elapsed, + "speaker": speaker, + "audio_secs": round(audio_secs, 2), + "realtime_factor": round(realtime_factor, 3) + if realtime_factor is not None + else None, + "chunks": chunk_count, + "first_chunk_secs": round(first_chunk_at - start, 2) + if first_chunk_at is not None + else None, + "max_chunk_gap_secs": round(max_chunk_gap, 2), + } + if profile is not None: + complete_event["profile"] = profile + yield _sse(complete_event) return StreamingResponse( event_stream(), media_type="text/event-stream", - headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, + headers={ + "Cache-Control": "no-cache, no-transform", + "X-Accel-Buffering": "no", + "X-Content-Type-Options": "nosniff", + }, ) diff --git a/web/app/api/generate/route.ts b/web/app/api/generate/route.ts index 8bd1b94..180c659 100644 --- a/web/app/api/generate/route.ts +++ b/web/app/api/generate/route.ts @@ -1,5 +1,8 @@ import { NextRequest, NextResponse } from "next/server"; +export const dynamic = "force-dynamic"; +export const runtime = "nodejs"; + export async function POST(request: NextRequest) { const pythonServerUrl = process.env.VIBEVOICE_SERVER_URL ?? "http://localhost:8000"; @@ -24,6 +27,7 @@ export async function POST(request: NextRequest) { cfg_scale: body.cfg_scale ?? 1.5, inference_steps: body.inference_steps ?? 10, }), + signal: request.signal, }); if (!upstream.ok) { @@ -36,8 +40,9 @@ export async function POST(request: NextRequest) { status: 200, headers: { "Content-Type": "text/event-stream", - "Cache-Control": "no-cache", + "Cache-Control": "no-cache, no-transform", "Connection": "keep-alive", + "X-Content-Type-Options": "nosniff", "X-Accel-Buffering": "no", }, }); diff --git a/web/app/api/health/route.ts b/web/app/api/health/route.ts index dffb3f8..ba17edb 100644 --- a/web/app/api/health/route.ts +++ b/web/app/api/health/route.ts @@ -27,6 +27,7 @@ export async function GET() { message: data.message, progress: data.progress ?? null, voices: data.voices ?? [], + config: data.config ?? null, }, COMMON_OPTIONS ); diff --git a/web/app/page.tsx b/web/app/page.tsx index 275d658..a9f6317 100644 --- a/web/app/page.tsx +++ b/web/app/page.tsx @@ -130,9 +130,9 @@ const initialState: AppState = { speaker: "carter", cfgScale: 1.5, inferenceSteps: 10, - prebufferSecs: 2.0, - rebufferThresholdSecs: 0.4, - resumeThresholdSecs: 1.5, + prebufferSecs: 5.0, + rebufferThresholdSecs: 1.0, + resumeThresholdSecs: 3.0, isGenerating: false, genElapsed: 0, genPct: null, diff --git a/web/hooks/useStreamingGeneration.ts b/web/hooks/useStreamingGeneration.ts index 192cb27..257f0c8 100644 --- a/web/hooks/useStreamingGeneration.ts +++ b/web/hooks/useStreamingGeneration.ts @@ -3,9 +3,10 @@ import { useCallback, useEffect, useRef, useState } from "react"; const SAMPLE_RATE = 24_000; -const DEFAULT_PREBUFFER_SECS = 2.0; -const DEFAULT_REBUFFER_THRESHOLD_SECS = 0.4; -const DEFAULT_RESUME_THRESHOLD_SECS = 1.5; +const DEFAULT_PREBUFFER_SECS = 5.0; +const DEFAULT_REBUFFER_THRESHOLD_SECS = 1.0; +const DEFAULT_RESUME_THRESHOLD_SECS = 3.0; +const MAX_ADAPTIVE_RESUME_SECS = 18.0; interface GenerateOptions { text: string; @@ -104,6 +105,10 @@ export function useStreamingGeneration({ const isAutoBufferingRef = useRef(false); const isUserPausedRef = useRef(false); const audioUrlRef = useRef(null); + const firstChunkSeenRef = useRef(false); + const underrunCountRef = useRef(0); + const totalAudioSamplesRef = useRef(0); + const adaptiveResumeSecsRef = useRef(DEFAULT_RESUME_THRESHOLD_SECS); const revokeCurrentUrl = useCallback(() => { if (audioUrlRef.current) { @@ -122,8 +127,12 @@ export function useStreamingGeneration({ hasStartedPlaybackRef.current = false; isAutoBufferingRef.current = false; isUserPausedRef.current = false; + firstChunkSeenRef.current = false; + underrunCountRef.current = 0; + totalAudioSamplesRef.current = 0; + adaptiveResumeSecsRef.current = resumeThresholdSecs; setIsStreamPaused(false); - }, []); + }, [resumeThresholdSecs]); useEffect(() => { return () => { @@ -158,10 +167,17 @@ export function useStreamingGeneration({ if (!ctx) return; chunksRef.current.push(chunk); + totalAudioSamplesRef.current += chunk.length; + + if (!firstChunkSeenRef.current) { + firstChunkSeenRef.current = true; + onLog("First audio chunk received"); + } if (!hasStartedPlaybackRef.current) { const bufferedSecs = chunksRef.current.reduce((sum, c) => sum + c.length, 0) / SAMPLE_RATE; if (bufferedSecs >= prebufferSecs) { + onLog(`Playback started after ${bufferedSecs.toFixed(1)}s buffered`); flushBufferedAudio(); } return; @@ -171,18 +187,30 @@ export function useStreamingGeneration({ if (isUserPausedRef.current) return; const ahead = nextStartTimeRef.current - ctx.currentTime; - if (ctx.state === "running" && ahead < rebufferThresholdSecs) { - ctx.suspend().catch(() => {}); - isAutoBufferingRef.current = true; - } else if ( - ctx.state === "suspended" && - isAutoBufferingRef.current && - ahead >= resumeThresholdSecs + if ( + ctx.state === "running" && + !isAutoBufferingRef.current && + ahead < rebufferThresholdSecs + ) { + isAutoBufferingRef.current = true; + underrunCountRef.current += 1; + adaptiveResumeSecsRef.current = Math.min( + MAX_ADAPTIVE_RESUME_SECS, + Math.max(resumeThresholdSecs, prebufferSecs + underrunCountRef.current * 2), + ); + ctx.suspend().catch(() => {}); + onLog( + `Buffer underrun ${underrunCountRef.current}; refilling to ${adaptiveResumeSecsRef.current.toFixed(1)}s`, + ); + } else if ( + isAutoBufferingRef.current && + ahead >= adaptiveResumeSecsRef.current ) { - ctx.resume().catch(() => {}); isAutoBufferingRef.current = false; + ctx.resume().catch(() => {}); + onLog(`Buffer recovered with ${ahead.toFixed(1)}s queued`); } - }, [enqueue, flushBufferedAudio, prebufferSecs, rebufferThresholdSecs, resumeThresholdSecs]); + }, [enqueue, flushBufferedAudio, onLog, prebufferSecs, rebufferThresholdSecs, resumeThresholdSecs]); const generate = useCallback(async (options: GenerateOptions) => { if (!options.text.trim()) return; @@ -239,6 +267,11 @@ export function useStreamingGeneration({ type: "audio_chunk" | "complete" | "error" | "cancelled"; data?: string; elapsed?: number; + audio_secs?: number; + realtime_factor?: number | null; + chunks?: number; + first_chunk_secs?: number | null; + max_chunk_gap_secs?: number; message?: string; }; @@ -247,12 +280,26 @@ export function useStreamingGeneration({ } else if (event.type === "complete") { if (!hasStartedPlaybackRef.current) { flushBufferedAudio(); + } else if (isAutoBufferingRef.current) { + isAutoBufferingRef.current = false; + audioCtxRef.current?.resume().catch(() => {}); } const wavBlob = buildWav(mergeFloat32Arrays(chunksRef.current), SAMPLE_RATE); const audioUrl = URL.createObjectURL(wavBlob); audioUrlRef.current = audioUrl; const kb = (wavBlob.size / 1024).toFixed(0); - onLog(`Done in ${event.elapsed}s - ${kb} KB`); + const audioSecs = event.audio_secs ?? totalAudioSamplesRef.current / SAMPLE_RATE; + const realtimeFactor = + event.realtime_factor ?? + (event.elapsed && event.elapsed > 0 ? audioSecs / event.elapsed : null); + const speedText = + realtimeFactor === null ? "" : ` - ${realtimeFactor.toFixed(2)}x realtime`; + onLog(`Done in ${event.elapsed}s - ${audioSecs.toFixed(1)}s audio${speedText} - ${kb} KB`); + if (event.chunks && event.first_chunk_secs !== undefined) { + onLog( + `Stream: first chunk ${event.first_chunk_secs}s, ${event.chunks} chunks, max gap ${event.max_chunk_gap_secs}s`, + ); + } onSuccess(audioUrl); } else if (event.type === "cancelled") { throw new DOMException("Generation cancelled", "AbortError"); diff --git a/web/package.json b/web/package.json index 3f88249..f5c4ca1 100644 --- a/web/package.json +++ b/web/package.json @@ -4,7 +4,7 @@ "private": true, "scripts": { "dev": "next dev --turbopack", - "build": "next build --turbopack", + "build": "next build", "start": "next start" }, "dependencies": {