perf: improve streaming generation pipeline

Add CUDA inference hot-path optimizations, safer attention fallback handling, and generation profiling hooks. Improve SSE streaming, browser buffering telemetry, and playback recovery while preserving default audio quality settings.
This commit is contained in:
2026-04-30 18:54:14 +01:00
parent a39ec536fd
commit 75b84b211b
9 changed files with 459 additions and 48 deletions
+2
View File
@@ -0,0 +1,2 @@
*.sh text eol=lf
*.py text eol=lf
+1
View File
@@ -0,0 +1 @@
3.12
+48
View File
@@ -7,6 +7,11 @@
# ./start.sh — CUDA mode (default, uses PyTorch CUDA 12.4 wheel, venv: .venv) # ./start.sh — CUDA mode (default, uses PyTorch CUDA 12.4 wheel, venv: .venv)
# ./start.sh --cpu — CPU-only mode (uses PyPI CPU torch wheel, venv: .venv-cpu) # ./start.sh --cpu — CPU-only mode (uses PyPI CPU torch wheel, venv: .venv-cpu)
# #
# Optional CUDA acceleration:
# VIBEPOD_ENABLE_FLASH_ATTN=1 ./start.sh
# Installs a matching third-party Windows flash-attn wheel when the CUDA venv
# uses Python 3.12, torch 2.6.0, and CUDA 12.4.
#
# The two modes maintain completely separate virtual environments so their torch # The two modes maintain completely separate virtual environments so their torch
# installations never conflict. UV_PROJECT_ENVIRONMENT tells uv which venv to use; # installations never conflict. UV_PROJECT_ENVIRONMENT tells uv which venv to use;
# --no-sources skips [tool.uv.sources] so the CPU run pulls the default PyPI torch wheel. # --no-sources skips [tool.uv.sources] so the CPU run pulls the default PyPI torch wheel.
@@ -51,6 +56,19 @@ if ! command -v uv &>/dev/null; then
exit 1 exit 1
fi fi
validate_flash_attn() {
uv run python -c "import flash_attn; import triton; import transformers.modeling_utils" &>/dev/null
}
remove_broken_flash_attn() {
if uv run python -c "import importlib.util; raise SystemExit(0 if importlib.util.find_spec('flash_attn') else 1)" &>/dev/null; then
if ! validate_flash_attn; then
echo " Installed flash-attn is not usable in this environment; removing it."
uv pip uninstall flash-attn
fi
fi
}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# 2. Sync Python environment # 2. Sync Python environment
# CPU mode: use .venv-cpu and skip [tool.uv.sources] so uv pulls the # CPU mode: use .venv-cpu and skip [tool.uv.sources] so uv pulls the
@@ -65,6 +83,36 @@ if $CPU_MODE; then
else else
echo "--> Syncing CUDA Python environment (.venv)..." echo "--> Syncing CUDA Python environment (.venv)..."
uv sync uv sync
remove_broken_flash_attn
if [[ "${VIBEPOD_ENABLE_FLASH_ATTN:-0}" == "1" ]]; then
echo ""
echo "--> Checking optional FlashAttention wheel..."
if validate_flash_attn; then
echo " flash-attn already installed and importable."
else
PY_TAG="$(uv run python -c "import sys; print(f'cp{sys.version_info.major}{sys.version_info.minor}')")"
TORCH_TAG="$(uv run python -c "import torch; print(torch.__version__.split('+', 1)[0])")"
CUDA_TAG="$(uv run python -c "import torch; print('cu' + torch.version.cuda.replace('.', ''))")"
if [[ "$PY_TAG" == "cp312" && "$TORCH_TAG" == "2.6.0" && "$CUDA_TAG" == "cu124" ]]; then
FLASH_ATTN_WHEEL_URL="https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4%2Bcu124torch2.6.0cxx11abiFALSE-cp312-cp312-win_amd64.whl"
echo " Installing flash-attn for Python 3.12, torch 2.6.0, CUDA 12.4..."
uv pip install "$FLASH_ATTN_WHEEL_URL"
if validate_flash_attn; then
echo " flash-attn import check passed."
else
echo " flash-attn import check failed; removing it and continuing with SDPA."
uv pip uninstall flash-attn
fi
else
echo " No known wheel for Python tag $PY_TAG, torch $TORCH_TAG, CUDA $CUDA_TAG."
echo " Continuing with PyTorch SDPA attention."
fi
fi
fi
fi fi
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
+336 -29
View File
@@ -22,11 +22,13 @@ import asyncio
import base64 import base64
import copy import copy
import functools import functools
import importlib.util
import json import json
import logging import logging
import os import os
import threading import threading
import time import time
import types
import urllib.request import urllib.request
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
@@ -211,8 +213,39 @@ def _init_processor():
def _init_model(device: str): def _init_model(device: str):
logger.info("Loading model on %s...", device) logger.info("Loading model on %s...", device)
load_dtype = torch.bfloat16 if device == "cuda" else torch.float32 if device == "cuda":
attn_impl = "flash_attention_2" if device == "cuda" else "sdpa" torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
logger.info(
"PyTorch SDPA backends: flash=%s, mem_efficient=%s, math=%s",
torch.backends.cuda.flash_sdp_enabled(),
torch.backends.cuda.mem_efficient_sdp_enabled(),
torch.backends.cuda.math_sdp_enabled(),
)
cuda_dtype = os.environ.get("VIBEPOD_CUDA_DTYPE", "bf16").lower()
if device == "cuda" and cuda_dtype == "fp16":
load_dtype = torch.float16
else:
load_dtype = torch.bfloat16 if device == "cuda" else torch.float32
logger.info("Loading model weights with dtype %s", load_dtype)
requested_attn_impl = os.environ.get("VIBEPOD_ATTN_IMPL", "auto").lower()
has_flash_attn = importlib.util.find_spec("flash_attn") is not None
if requested_attn_impl in {"eager", "sdpa"}:
attn_impl = requested_attn_impl
elif requested_attn_impl == "flash_attention_2":
attn_impl = "flash_attention_2" if has_flash_attn else "sdpa"
else:
attn_impl = "flash_attention_2" if device == "cuda" and has_flash_attn else "sdpa"
logger.info("Using Transformers attention implementation: %s", attn_impl)
if device == "cuda" and not has_flash_attn:
logger.info("flash_attn is not installed; using PyTorch SDPA attention.")
from vibevoice.modular.modeling_vibevoice_streaming_inference import ( from vibevoice.modular.modeling_vibevoice_streaming_inference import (
VibeVoiceStreamingForConditionalGenerationInference, VibeVoiceStreamingForConditionalGenerationInference,
@@ -225,9 +258,13 @@ def _init_model(device: str):
device_map=device, device_map=device,
attn_implementation=attn_impl, attn_implementation=attn_impl,
) )
except Exception: except Exception as exc:
if attn_impl == "sdpa":
raise
logger.warning( logger.warning(
"Model load with %s failed; falling back to sdpa", attn_impl, exc_info=True "Model load with %s failed (%s); falling back to sdpa",
attn_impl,
exc,
) )
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
MODEL_ID, MODEL_ID,
@@ -238,9 +275,164 @@ def _init_model(device: str):
model.eval() model.eval()
model.set_ddpm_inference_steps(num_steps=_config["default_inference_steps"]) model.set_ddpm_inference_steps(num_steps=_config["default_inference_steps"])
_install_generation_optimizations(model)
return model return model
def _install_generation_optimizations(model: object) -> None:
"""Patch VibeVoice hot paths without changing model quality settings."""
def profile_enabled() -> bool:
return os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1"
def profile_sync() -> None:
if torch.cuda.is_available():
torch.cuda.synchronize()
def profile_record(self, key: str, elapsed: float) -> None:
stats = getattr(self, "_vibepod_profile", None)
if stats is None:
stats = {}
self._vibepod_profile = stats
bucket = stats.setdefault(key, {"count": 0, "seconds": 0.0})
bucket["count"] += 1
bucket["seconds"] += elapsed
def timed_method(self, key: str, fn, *args, **kwargs):
if not profile_enabled():
return fn(*args, **kwargs)
profile_sync()
started = time.perf_counter()
result = fn(*args, **kwargs)
profile_sync()
profile_record(self, key, time.perf_counter() - started)
return result
def prepare_noise_scheduler(self):
scheduler = self.model.noise_scheduler
cache_key = self.ddpm_inference_steps
cache = getattr(self, "_vibepod_scheduler_cache", {})
cached = cache.get(cache_key)
if cached is None:
scheduler.set_timesteps(self.ddpm_inference_steps)
cached = {
"num_inference_steps": scheduler.num_inference_steps,
"timesteps": scheduler.timesteps,
"sigmas": scheduler.sigmas,
}
cache[cache_key] = cached
self._vibepod_scheduler_cache = cache
else:
scheduler.num_inference_steps = cached["num_inference_steps"]
scheduler.timesteps = cached["timesteps"]
scheduler.sigmas = cached["sigmas"]
scheduler.model_outputs = [None] * scheduler.config.solver_order
scheduler.lower_order_nums = 0
scheduler._step_index = None
scheduler._begin_index = None
return scheduler
def sample_speech_tokens_optimized(self, condition, neg_condition, cfg_scale=3.0):
scheduler = prepare_noise_scheduler(self)
condition = torch.cat([condition, neg_condition], dim=0).to(
self.model.prediction_head.device
)
batch_size = condition.shape[0] // 2
speech = torch.randn(batch_size, self.config.acoustic_vae_dim).to(condition)
t_batch_cache_key = (
self.ddpm_inference_steps,
condition.device.type,
condition.device.index,
condition.dtype,
batch_size,
)
t_batch_cache = getattr(self, "_vibepod_t_batch_cache", {})
t_batches = t_batch_cache.get(t_batch_cache_key)
if t_batches is None or len(t_batches) != len(scheduler.timesteps):
t_batches = [
t.repeat(condition.shape[0]).to(
device=condition.device, dtype=condition.dtype
)
for t in scheduler.timesteps
]
t_batch_cache[t_batch_cache_key] = t_batches
self._vibepod_t_batch_cache = t_batch_cache
for t, t_batch in zip(scheduler.timesteps, t_batches):
if batch_size == 1:
combined = speech.expand(condition.shape[0], -1)
else:
combined = torch.cat([speech, speech], dim=0)
if profile_enabled():
profile_sync()
started = time.perf_counter()
eps = self.model.prediction_head(combined, t_batch, condition=condition)
if profile_enabled():
profile_sync()
profile_record(self, "diffusion_prediction_head", time.perf_counter() - started)
cond_eps, uncond_eps = torch.split(eps, batch_size, dim=0)
guided_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
if profile_enabled():
started = time.perf_counter()
speech = scheduler.step(guided_eps, t, speech).prev_sample
if profile_enabled():
profile_record(self, "diffusion_scheduler_step", time.perf_counter() - started)
return speech
forward_lm = model.forward_lm
forward_tts_lm = model.forward_tts_lm
acoustic_decode = model.model.acoustic_tokenizer.decode
def forward_lm_profiled(*args, **kwargs):
return timed_method(model, "forward_lm", forward_lm, *args, **kwargs)
def forward_tts_lm_profiled(*args, **kwargs):
return timed_method(model, "forward_tts_lm", forward_tts_lm, *args, **kwargs)
def acoustic_decode_profiled(*args, **kwargs):
return timed_method(model, "acoustic_decode", acoustic_decode, *args, **kwargs)
model.forward_lm = forward_lm_profiled
model.forward_tts_lm = forward_tts_lm_profiled
model.model.acoustic_tokenizer.decode = acoustic_decode_profiled
model.sample_speech_tokens = types.MethodType(sample_speech_tokens_optimized, model)
logger.info("Installed VibeVoice generation hot-path optimizations.")
def _model_float_dtype() -> torch.dtype:
try:
return next(_model.parameters()).dtype
except StopIteration:
return torch.float32
def _move_cached_prompt(value: object, device: str, dtype: torch.dtype) -> object:
if torch.is_tensor(value):
if torch.is_floating_point(value):
return value.to(device=device, dtype=dtype)
return value.to(device=device)
if isinstance(value, dict):
for k in list(value.keys()):
value[k] = _move_cached_prompt(value[k], device, dtype)
return value
if isinstance(value, list):
return [_move_cached_prompt(v, device, dtype) for v in value]
if isinstance(value, tuple):
return tuple(_move_cached_prompt(v, device, dtype) for v in value)
if hasattr(value, "key_cache") and hasattr(value, "value_cache"):
value.key_cache = [
_move_cached_prompt(t, device, dtype) for t in value.key_cache
]
value.value_cache = [
_move_cached_prompt(t, device, dtype) for t in value.value_cache
]
return value
def _load_voice_presets(device: str) -> dict[str, object]: def _load_voice_presets(device: str) -> dict[str, object]:
presets = {} presets = {}
for name, filename in EN_VOICES.items(): for name, filename in EN_VOICES.items():
@@ -273,9 +465,9 @@ def _load_model_sync() -> None:
is_cpu = _device == "cpu" is_cpu = _device == "cpu"
_config["device"] = _device _config["device"] = _device
_config["chunk_accum"] = _env_int("VIBEPOD_CHUNK_ACCUM", 4 if is_cpu else 1) _config["chunk_accum"] = _env_int("VIBEPOD_CHUNK_ACCUM", 4 if is_cpu else 1)
_config["prebuffer_secs"] = _env_float("VIBEPOD_PREBUFFER_SECS", 5.0 if is_cpu else 2.0) _config["prebuffer_secs"] = _env_float("VIBEPOD_PREBUFFER_SECS", 6.0 if is_cpu else 5.0)
_config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 1.0 if is_cpu else 0.4) _config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 1.5 if is_cpu else 1.0)
_config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 2.5 if is_cpu else 1.5) _config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 4.0 if is_cpu else 3.0)
_config["default_inference_steps"] = _env_int("VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10) _config["default_inference_steps"] = _env_int("VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10)
_processor = _init_processor() _processor = _init_processor()
@@ -364,10 +556,15 @@ def _sync_generate(
raise RuntimeError("Generation cancelled.") raise RuntimeError("Generation cancelled.")
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
voice_preset = copy.deepcopy(_voice_presets[speaker]) model_dtype = _model_float_dtype()
voice_preset = _move_cached_prompt(
copy.deepcopy(_voice_presets[speaker]), _device, model_dtype
)
steps = req.inference_steps if req.inference_steps is not None else _config["default_inference_steps"] steps = req.inference_steps if req.inference_steps is not None else _config["default_inference_steps"]
_model.set_ddpm_inference_steps(num_steps=steps) _model.set_ddpm_inference_steps(num_steps=steps)
if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1":
_model._vibepod_profile = {}
inputs = _processor.process_input_with_cached_prompt( inputs = _processor.process_input_with_cached_prompt(
text=req.text, text=req.text,
@@ -380,19 +577,20 @@ def _sync_generate(
if torch.is_tensor(v): if torch.is_tensor(v):
inputs[k] = v.to(_device) inputs[k] = v.to(_device)
outputs = _model.generate( with torch.inference_mode():
**inputs, _model.generate(
max_new_tokens=None, **inputs,
cfg_scale=req.cfg_scale, max_new_tokens=None,
tokenizer=_processor.tokenizer, cfg_scale=req.cfg_scale,
generation_config={"do_sample": False}, tokenizer=_processor.tokenizer,
verbose=True, generation_config={"do_sample": False},
all_prefilled_outputs=copy.deepcopy(voice_preset), verbose=False,
audio_streamer=streamer, show_progress_bar=False,
) return_speech=False,
stop_check_fn=cancel_event.is_set if cancel_event else None,
if not outputs.speech_outputs or outputs.speech_outputs[0] is None: all_prefilled_outputs=voice_preset,
raise ValueError("Model returned no audio output.") audio_streamer=streamer,
)
return speaker return speaker
@@ -401,6 +599,24 @@ def _sse(event: dict) -> str:
return f"data: {json.dumps(event)}\n\n" return f"data: {json.dumps(event)}\n\n"
def _generation_profile() -> Optional[dict[str, dict[str, float]]]:
if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") != "1":
return None
stats = getattr(_model, "_vibepod_profile", None)
if not stats:
return {}
return {
key: {
"count": value["count"],
"seconds": round(value["seconds"], 3),
"avg_ms": round(value["seconds"] * 1000 / value["count"], 3)
if value["count"]
else 0.0,
}
for key, value in sorted(stats.items())
}
@app.post("/generate") @app.post("/generate")
async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
if _model_status != "online": if _model_status != "online":
@@ -417,20 +633,58 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
) )
async def event_stream() -> AsyncGenerator[str, None]: async def event_stream() -> AsyncGenerator[str, None]:
from vibevoice.modular.streamer import AsyncAudioStreamer class NonBlockingAudioStreamer:
"""Async streamer that keeps GPU->CPU copies out of the model thread."""
def __init__(self, batch_size: int, stop_signal: object = None) -> None:
self.batch_size = batch_size
self.stop_signal = stop_signal
self.audio_queues = [asyncio.Queue() for _ in range(batch_size)]
self.finished_flags = [False for _ in range(batch_size)]
self.loop = asyncio.get_running_loop()
def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor) -> None:
for i, sample_idx in enumerate(sample_indices):
idx = sample_idx.item()
if idx < self.batch_size and not self.finished_flags[idx]:
self.loop.call_soon_threadsafe(
self.audio_queues[idx].put_nowait,
audio_chunks[i].detach(),
)
def end(self, sample_indices: Optional[torch.Tensor] = None) -> None:
if sample_indices is None:
indices_to_end = range(self.batch_size)
else:
indices_to_end = [
s.item() if torch.is_tensor(s) else s for s in sample_indices
]
for idx in indices_to_end:
if idx < self.batch_size and not self.finished_flags[idx]:
self.loop.call_soon_threadsafe(
self.audio_queues[idx].put_nowait, self.stop_signal
)
self.finished_flags[idx] = True
start = time.monotonic() start = time.monotonic()
streamer = AsyncAudioStreamer(batch_size=1) streamer = NonBlockingAudioStreamer(batch_size=1)
cancel_event = threading.Event() cancel_event = threading.Event()
accum_size = max(1, _config["chunk_accum"]) accum_size = max(1, _config["chunk_accum"])
accumulated_chunks = [] accumulated_chunks = []
chunk_count = 0
audio_samples = 0
first_chunk_at: Optional[float] = None
last_chunk_at: Optional[float] = None
max_chunk_gap = 0.0
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
async with _generation_lock: async with _generation_lock:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
future = loop.run_in_executor( future = loop.run_in_executor(
None, functools.partial(_sync_generate, req, streamer, cancel_event) None, functools.partial(_sync_generate, req, streamer, cancel_event)
) )
future.add_done_callback(lambda _: streamer.end())
# Drain audio chunks as they arrive from the diffusion head. # Drain audio chunks as they arrive from the diffusion head.
# stop_signal=None is the default sentinel that ends the queue. # stop_signal=None is the default sentinel that ends the queue.
@@ -454,17 +708,45 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
if chunk is None: # stop signal if chunk is None: # stop signal
break break
accumulated_chunks.append(chunk.detach().cpu().float()) accumulated_chunks.append(chunk.detach())
if len(accumulated_chunks) >= accum_size: if len(accumulated_chunks) >= accum_size:
combined = torch.cat(accumulated_chunks, dim=0) now = time.monotonic()
if first_chunk_at is None:
first_chunk_at = now
if last_chunk_at is not None:
max_chunk_gap = max(max_chunk_gap, now - last_chunk_at)
last_chunk_at = now
combined = (
torch.cat(accumulated_chunks, dim=0)
.detach()
.to("cpu", dtype=torch.float32)
.contiguous()
)
chunk_count += 1
audio_samples += combined.numel()
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode() pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
yield _sse({"type": "audio_chunk", "data": pcm_b64}) yield _sse({"type": "audio_chunk", "data": pcm_b64})
accumulated_chunks = [] accumulated_chunks = []
# Flush any remaining chunks # Flush any remaining chunks
if accumulated_chunks: if accumulated_chunks:
combined = torch.cat(accumulated_chunks, dim=0) now = time.monotonic()
if first_chunk_at is None:
first_chunk_at = now
if last_chunk_at is not None:
max_chunk_gap = max(max_chunk_gap, now - last_chunk_at)
last_chunk_at = now
combined = (
torch.cat(accumulated_chunks, dim=0)
.detach()
.to("cpu", dtype=torch.float32)
.contiguous()
)
chunk_count += 1
audio_samples += combined.numel()
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode() pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
yield _sse({"type": "audio_chunk", "data": pcm_b64}) yield _sse({"type": "audio_chunk", "data": pcm_b64})
@@ -479,17 +761,42 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
yield _sse( yield _sse(
{ {
"type": "error", "type": "error",
"message": "Internal server error during generation.", "message": f"Generation failed: {exc}",
} }
) )
return return
elapsed = round(time.monotonic() - start, 1) elapsed = round(time.monotonic() - start, 1)
audio_secs = audio_samples / SAMPLE_RATE
realtime_factor = audio_secs / elapsed if elapsed > 0 else None
profile = _generation_profile()
if profile is not None:
logger.info("Generation profile: %s", profile)
logger.info("Generation complete in %.1fs", elapsed) logger.info("Generation complete in %.1fs", elapsed)
yield _sse({"type": "complete", "elapsed": elapsed, "speaker": speaker}) complete_event = {
"type": "complete",
"elapsed": elapsed,
"speaker": speaker,
"audio_secs": round(audio_secs, 2),
"realtime_factor": round(realtime_factor, 3)
if realtime_factor is not None
else None,
"chunks": chunk_count,
"first_chunk_secs": round(first_chunk_at - start, 2)
if first_chunk_at is not None
else None,
"max_chunk_gap_secs": round(max_chunk_gap, 2),
}
if profile is not None:
complete_event["profile"] = profile
yield _sse(complete_event)
return StreamingResponse( return StreamingResponse(
event_stream(), event_stream(),
media_type="text/event-stream", media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, headers={
"Cache-Control": "no-cache, no-transform",
"X-Accel-Buffering": "no",
"X-Content-Type-Options": "nosniff",
},
) )
+6 -1
View File
@@ -1,5 +1,8 @@
import { NextRequest, NextResponse } from "next/server"; import { NextRequest, NextResponse } from "next/server";
export const dynamic = "force-dynamic";
export const runtime = "nodejs";
export async function POST(request: NextRequest) { export async function POST(request: NextRequest) {
const pythonServerUrl = process.env.VIBEVOICE_SERVER_URL ?? "http://localhost:8000"; const pythonServerUrl = process.env.VIBEVOICE_SERVER_URL ?? "http://localhost:8000";
@@ -24,6 +27,7 @@ export async function POST(request: NextRequest) {
cfg_scale: body.cfg_scale ?? 1.5, cfg_scale: body.cfg_scale ?? 1.5,
inference_steps: body.inference_steps ?? 10, inference_steps: body.inference_steps ?? 10,
}), }),
signal: request.signal,
}); });
if (!upstream.ok) { if (!upstream.ok) {
@@ -36,8 +40,9 @@ export async function POST(request: NextRequest) {
status: 200, status: 200,
headers: { headers: {
"Content-Type": "text/event-stream", "Content-Type": "text/event-stream",
"Cache-Control": "no-cache", "Cache-Control": "no-cache, no-transform",
"Connection": "keep-alive", "Connection": "keep-alive",
"X-Content-Type-Options": "nosniff",
"X-Accel-Buffering": "no", "X-Accel-Buffering": "no",
}, },
}); });
+1
View File
@@ -27,6 +27,7 @@ export async function GET() {
message: data.message, message: data.message,
progress: data.progress ?? null, progress: data.progress ?? null,
voices: data.voices ?? [], voices: data.voices ?? [],
config: data.config ?? null,
}, },
COMMON_OPTIONS COMMON_OPTIONS
); );
+3 -3
View File
@@ -130,9 +130,9 @@ const initialState: AppState = {
speaker: "carter", speaker: "carter",
cfgScale: 1.5, cfgScale: 1.5,
inferenceSteps: 10, inferenceSteps: 10,
prebufferSecs: 2.0, prebufferSecs: 5.0,
rebufferThresholdSecs: 0.4, rebufferThresholdSecs: 1.0,
resumeThresholdSecs: 1.5, resumeThresholdSecs: 3.0,
isGenerating: false, isGenerating: false,
genElapsed: 0, genElapsed: 0,
genPct: null, genPct: null,
+61 -14
View File
@@ -3,9 +3,10 @@
import { useCallback, useEffect, useRef, useState } from "react"; import { useCallback, useEffect, useRef, useState } from "react";
const SAMPLE_RATE = 24_000; const SAMPLE_RATE = 24_000;
const DEFAULT_PREBUFFER_SECS = 2.0; const DEFAULT_PREBUFFER_SECS = 5.0;
const DEFAULT_REBUFFER_THRESHOLD_SECS = 0.4; const DEFAULT_REBUFFER_THRESHOLD_SECS = 1.0;
const DEFAULT_RESUME_THRESHOLD_SECS = 1.5; const DEFAULT_RESUME_THRESHOLD_SECS = 3.0;
const MAX_ADAPTIVE_RESUME_SECS = 18.0;
interface GenerateOptions { interface GenerateOptions {
text: string; text: string;
@@ -104,6 +105,10 @@ export function useStreamingGeneration({
const isAutoBufferingRef = useRef(false); const isAutoBufferingRef = useRef(false);
const isUserPausedRef = useRef(false); const isUserPausedRef = useRef(false);
const audioUrlRef = useRef<string | null>(null); const audioUrlRef = useRef<string | null>(null);
const firstChunkSeenRef = useRef(false);
const underrunCountRef = useRef(0);
const totalAudioSamplesRef = useRef(0);
const adaptiveResumeSecsRef = useRef(DEFAULT_RESUME_THRESHOLD_SECS);
const revokeCurrentUrl = useCallback(() => { const revokeCurrentUrl = useCallback(() => {
if (audioUrlRef.current) { if (audioUrlRef.current) {
@@ -122,8 +127,12 @@ export function useStreamingGeneration({
hasStartedPlaybackRef.current = false; hasStartedPlaybackRef.current = false;
isAutoBufferingRef.current = false; isAutoBufferingRef.current = false;
isUserPausedRef.current = false; isUserPausedRef.current = false;
firstChunkSeenRef.current = false;
underrunCountRef.current = 0;
totalAudioSamplesRef.current = 0;
adaptiveResumeSecsRef.current = resumeThresholdSecs;
setIsStreamPaused(false); setIsStreamPaused(false);
}, []); }, [resumeThresholdSecs]);
useEffect(() => { useEffect(() => {
return () => { return () => {
@@ -158,10 +167,17 @@ export function useStreamingGeneration({
if (!ctx) return; if (!ctx) return;
chunksRef.current.push(chunk); chunksRef.current.push(chunk);
totalAudioSamplesRef.current += chunk.length;
if (!firstChunkSeenRef.current) {
firstChunkSeenRef.current = true;
onLog("First audio chunk received");
}
if (!hasStartedPlaybackRef.current) { if (!hasStartedPlaybackRef.current) {
const bufferedSecs = chunksRef.current.reduce((sum, c) => sum + c.length, 0) / SAMPLE_RATE; const bufferedSecs = chunksRef.current.reduce((sum, c) => sum + c.length, 0) / SAMPLE_RATE;
if (bufferedSecs >= prebufferSecs) { if (bufferedSecs >= prebufferSecs) {
onLog(`Playback started after ${bufferedSecs.toFixed(1)}s buffered`);
flushBufferedAudio(); flushBufferedAudio();
} }
return; return;
@@ -171,18 +187,30 @@ export function useStreamingGeneration({
if (isUserPausedRef.current) return; if (isUserPausedRef.current) return;
const ahead = nextStartTimeRef.current - ctx.currentTime; const ahead = nextStartTimeRef.current - ctx.currentTime;
if (ctx.state === "running" && ahead < rebufferThresholdSecs) { if (
ctx.suspend().catch(() => {}); ctx.state === "running" &&
isAutoBufferingRef.current = true; !isAutoBufferingRef.current &&
} else if ( ahead < rebufferThresholdSecs
ctx.state === "suspended" && ) {
isAutoBufferingRef.current && isAutoBufferingRef.current = true;
ahead >= resumeThresholdSecs underrunCountRef.current += 1;
adaptiveResumeSecsRef.current = Math.min(
MAX_ADAPTIVE_RESUME_SECS,
Math.max(resumeThresholdSecs, prebufferSecs + underrunCountRef.current * 2),
);
ctx.suspend().catch(() => {});
onLog(
`Buffer underrun ${underrunCountRef.current}; refilling to ${adaptiveResumeSecsRef.current.toFixed(1)}s`,
);
} else if (
isAutoBufferingRef.current &&
ahead >= adaptiveResumeSecsRef.current
) { ) {
ctx.resume().catch(() => {});
isAutoBufferingRef.current = false; isAutoBufferingRef.current = false;
ctx.resume().catch(() => {});
onLog(`Buffer recovered with ${ahead.toFixed(1)}s queued`);
} }
}, [enqueue, flushBufferedAudio, prebufferSecs, rebufferThresholdSecs, resumeThresholdSecs]); }, [enqueue, flushBufferedAudio, onLog, prebufferSecs, rebufferThresholdSecs, resumeThresholdSecs]);
const generate = useCallback(async (options: GenerateOptions) => { const generate = useCallback(async (options: GenerateOptions) => {
if (!options.text.trim()) return; if (!options.text.trim()) return;
@@ -239,6 +267,11 @@ export function useStreamingGeneration({
type: "audio_chunk" | "complete" | "error" | "cancelled"; type: "audio_chunk" | "complete" | "error" | "cancelled";
data?: string; data?: string;
elapsed?: number; elapsed?: number;
audio_secs?: number;
realtime_factor?: number | null;
chunks?: number;
first_chunk_secs?: number | null;
max_chunk_gap_secs?: number;
message?: string; message?: string;
}; };
@@ -247,12 +280,26 @@ export function useStreamingGeneration({
} else if (event.type === "complete") { } else if (event.type === "complete") {
if (!hasStartedPlaybackRef.current) { if (!hasStartedPlaybackRef.current) {
flushBufferedAudio(); flushBufferedAudio();
} else if (isAutoBufferingRef.current) {
isAutoBufferingRef.current = false;
audioCtxRef.current?.resume().catch(() => {});
} }
const wavBlob = buildWav(mergeFloat32Arrays(chunksRef.current), SAMPLE_RATE); const wavBlob = buildWav(mergeFloat32Arrays(chunksRef.current), SAMPLE_RATE);
const audioUrl = URL.createObjectURL(wavBlob); const audioUrl = URL.createObjectURL(wavBlob);
audioUrlRef.current = audioUrl; audioUrlRef.current = audioUrl;
const kb = (wavBlob.size / 1024).toFixed(0); const kb = (wavBlob.size / 1024).toFixed(0);
onLog(`Done in ${event.elapsed}s - ${kb} KB`); const audioSecs = event.audio_secs ?? totalAudioSamplesRef.current / SAMPLE_RATE;
const realtimeFactor =
event.realtime_factor ??
(event.elapsed && event.elapsed > 0 ? audioSecs / event.elapsed : null);
const speedText =
realtimeFactor === null ? "" : ` - ${realtimeFactor.toFixed(2)}x realtime`;
onLog(`Done in ${event.elapsed}s - ${audioSecs.toFixed(1)}s audio${speedText} - ${kb} KB`);
if (event.chunks && event.first_chunk_secs !== undefined) {
onLog(
`Stream: first chunk ${event.first_chunk_secs}s, ${event.chunks} chunks, max gap ${event.max_chunk_gap_secs}s`,
);
}
onSuccess(audioUrl); onSuccess(audioUrl);
} else if (event.type === "cancelled") { } else if (event.type === "cancelled") {
throw new DOMException("Generation cancelled", "AbortError"); throw new DOMException("Generation cancelled", "AbortError");
+1 -1
View File
@@ -4,7 +4,7 @@
"private": true, "private": true,
"scripts": { "scripts": {
"dev": "next dev --turbopack", "dev": "next dev --turbopack",
"build": "next build --turbopack", "build": "next build",
"start": "next start" "start": "next start"
}, },
"dependencies": { "dependencies": {