mirror of
https://github.com/JezzWTF/vibepod.git
synced 2026-06-01 15:22:14 +00:00
perf: improve streaming generation pipeline
Add CUDA inference hot-path optimizations, safer attention fallback handling, and generation profiling hooks. Improve SSE streaming, browser buffering telemetry, and playback recovery while preserving default audio quality settings.
This commit is contained in:
@@ -0,0 +1,2 @@
|
||||
*.sh text eol=lf
|
||||
*.py text eol=lf
|
||||
@@ -0,0 +1 @@
|
||||
3.12
|
||||
@@ -7,6 +7,11 @@
|
||||
# ./start.sh — CUDA mode (default, uses PyTorch CUDA 12.4 wheel, venv: .venv)
|
||||
# ./start.sh --cpu — CPU-only mode (uses PyPI CPU torch wheel, venv: .venv-cpu)
|
||||
#
|
||||
# Optional CUDA acceleration:
|
||||
# VIBEPOD_ENABLE_FLASH_ATTN=1 ./start.sh
|
||||
# Installs a matching third-party Windows flash-attn wheel when the CUDA venv
|
||||
# uses Python 3.12, torch 2.6.0, and CUDA 12.4.
|
||||
#
|
||||
# The two modes maintain completely separate virtual environments so their torch
|
||||
# installations never conflict. UV_PROJECT_ENVIRONMENT tells uv which venv to use;
|
||||
# --no-sources skips [tool.uv.sources] so the CPU run pulls the default PyPI torch wheel.
|
||||
@@ -51,6 +56,19 @@ if ! command -v uv &>/dev/null; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
validate_flash_attn() {
|
||||
uv run python -c "import flash_attn; import triton; import transformers.modeling_utils" &>/dev/null
|
||||
}
|
||||
|
||||
remove_broken_flash_attn() {
|
||||
if uv run python -c "import importlib.util; raise SystemExit(0 if importlib.util.find_spec('flash_attn') else 1)" &>/dev/null; then
|
||||
if ! validate_flash_attn; then
|
||||
echo " Installed flash-attn is not usable in this environment; removing it."
|
||||
uv pip uninstall flash-attn
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Sync Python environment
|
||||
# CPU mode: use .venv-cpu and skip [tool.uv.sources] so uv pulls the
|
||||
@@ -65,6 +83,36 @@ if $CPU_MODE; then
|
||||
else
|
||||
echo "--> Syncing CUDA Python environment (.venv)..."
|
||||
uv sync
|
||||
|
||||
remove_broken_flash_attn
|
||||
|
||||
if [[ "${VIBEPOD_ENABLE_FLASH_ATTN:-0}" == "1" ]]; then
|
||||
echo ""
|
||||
echo "--> Checking optional FlashAttention wheel..."
|
||||
|
||||
if validate_flash_attn; then
|
||||
echo " flash-attn already installed and importable."
|
||||
else
|
||||
PY_TAG="$(uv run python -c "import sys; print(f'cp{sys.version_info.major}{sys.version_info.minor}')")"
|
||||
TORCH_TAG="$(uv run python -c "import torch; print(torch.__version__.split('+', 1)[0])")"
|
||||
CUDA_TAG="$(uv run python -c "import torch; print('cu' + torch.version.cuda.replace('.', ''))")"
|
||||
|
||||
if [[ "$PY_TAG" == "cp312" && "$TORCH_TAG" == "2.6.0" && "$CUDA_TAG" == "cu124" ]]; then
|
||||
FLASH_ATTN_WHEEL_URL="https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4%2Bcu124torch2.6.0cxx11abiFALSE-cp312-cp312-win_amd64.whl"
|
||||
echo " Installing flash-attn for Python 3.12, torch 2.6.0, CUDA 12.4..."
|
||||
uv pip install "$FLASH_ATTN_WHEEL_URL"
|
||||
if validate_flash_attn; then
|
||||
echo " flash-attn import check passed."
|
||||
else
|
||||
echo " flash-attn import check failed; removing it and continuing with SDPA."
|
||||
uv pip uninstall flash-attn
|
||||
fi
|
||||
else
|
||||
echo " No known wheel for Python tag $PY_TAG, torch $TORCH_TAG, CUDA $CUDA_TAG."
|
||||
echo " Continuing with PyTorch SDPA attention."
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
+328
-21
@@ -22,11 +22,13 @@ import asyncio
|
||||
import base64
|
||||
import copy
|
||||
import functools
|
||||
import importlib.util
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import types
|
||||
import urllib.request
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
@@ -211,8 +213,39 @@ def _init_processor():
|
||||
|
||||
def _init_model(device: str):
|
||||
logger.info("Loading model on %s...", device)
|
||||
if device == "cuda":
|
||||
torch.set_float32_matmul_precision("high")
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.enable_flash_sdp(True)
|
||||
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
||||
torch.backends.cuda.enable_math_sdp(True)
|
||||
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
|
||||
logger.info(
|
||||
"PyTorch SDPA backends: flash=%s, mem_efficient=%s, math=%s",
|
||||
torch.backends.cuda.flash_sdp_enabled(),
|
||||
torch.backends.cuda.mem_efficient_sdp_enabled(),
|
||||
torch.backends.cuda.math_sdp_enabled(),
|
||||
)
|
||||
|
||||
cuda_dtype = os.environ.get("VIBEPOD_CUDA_DTYPE", "bf16").lower()
|
||||
if device == "cuda" and cuda_dtype == "fp16":
|
||||
load_dtype = torch.float16
|
||||
else:
|
||||
load_dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
||||
attn_impl = "flash_attention_2" if device == "cuda" else "sdpa"
|
||||
logger.info("Loading model weights with dtype %s", load_dtype)
|
||||
requested_attn_impl = os.environ.get("VIBEPOD_ATTN_IMPL", "auto").lower()
|
||||
has_flash_attn = importlib.util.find_spec("flash_attn") is not None
|
||||
if requested_attn_impl in {"eager", "sdpa"}:
|
||||
attn_impl = requested_attn_impl
|
||||
elif requested_attn_impl == "flash_attention_2":
|
||||
attn_impl = "flash_attention_2" if has_flash_attn else "sdpa"
|
||||
else:
|
||||
attn_impl = "flash_attention_2" if device == "cuda" and has_flash_attn else "sdpa"
|
||||
logger.info("Using Transformers attention implementation: %s", attn_impl)
|
||||
if device == "cuda" and not has_flash_attn:
|
||||
logger.info("flash_attn is not installed; using PyTorch SDPA attention.")
|
||||
|
||||
from vibevoice.modular.modeling_vibevoice_streaming_inference import (
|
||||
VibeVoiceStreamingForConditionalGenerationInference,
|
||||
@@ -225,9 +258,13 @@ def _init_model(device: str):
|
||||
device_map=device,
|
||||
attn_implementation=attn_impl,
|
||||
)
|
||||
except Exception:
|
||||
except Exception as exc:
|
||||
if attn_impl == "sdpa":
|
||||
raise
|
||||
logger.warning(
|
||||
"Model load with %s failed; falling back to sdpa", attn_impl, exc_info=True
|
||||
"Model load with %s failed (%s); falling back to sdpa",
|
||||
attn_impl,
|
||||
exc,
|
||||
)
|
||||
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
|
||||
MODEL_ID,
|
||||
@@ -238,9 +275,164 @@ def _init_model(device: str):
|
||||
|
||||
model.eval()
|
||||
model.set_ddpm_inference_steps(num_steps=_config["default_inference_steps"])
|
||||
_install_generation_optimizations(model)
|
||||
return model
|
||||
|
||||
|
||||
def _install_generation_optimizations(model: object) -> None:
|
||||
"""Patch VibeVoice hot paths without changing model quality settings."""
|
||||
|
||||
def profile_enabled() -> bool:
|
||||
return os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1"
|
||||
|
||||
def profile_sync() -> None:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def profile_record(self, key: str, elapsed: float) -> None:
|
||||
stats = getattr(self, "_vibepod_profile", None)
|
||||
if stats is None:
|
||||
stats = {}
|
||||
self._vibepod_profile = stats
|
||||
bucket = stats.setdefault(key, {"count": 0, "seconds": 0.0})
|
||||
bucket["count"] += 1
|
||||
bucket["seconds"] += elapsed
|
||||
|
||||
def timed_method(self, key: str, fn, *args, **kwargs):
|
||||
if not profile_enabled():
|
||||
return fn(*args, **kwargs)
|
||||
profile_sync()
|
||||
started = time.perf_counter()
|
||||
result = fn(*args, **kwargs)
|
||||
profile_sync()
|
||||
profile_record(self, key, time.perf_counter() - started)
|
||||
return result
|
||||
|
||||
def prepare_noise_scheduler(self):
|
||||
scheduler = self.model.noise_scheduler
|
||||
cache_key = self.ddpm_inference_steps
|
||||
cache = getattr(self, "_vibepod_scheduler_cache", {})
|
||||
cached = cache.get(cache_key)
|
||||
|
||||
if cached is None:
|
||||
scheduler.set_timesteps(self.ddpm_inference_steps)
|
||||
cached = {
|
||||
"num_inference_steps": scheduler.num_inference_steps,
|
||||
"timesteps": scheduler.timesteps,
|
||||
"sigmas": scheduler.sigmas,
|
||||
}
|
||||
cache[cache_key] = cached
|
||||
self._vibepod_scheduler_cache = cache
|
||||
else:
|
||||
scheduler.num_inference_steps = cached["num_inference_steps"]
|
||||
scheduler.timesteps = cached["timesteps"]
|
||||
scheduler.sigmas = cached["sigmas"]
|
||||
scheduler.model_outputs = [None] * scheduler.config.solver_order
|
||||
scheduler.lower_order_nums = 0
|
||||
scheduler._step_index = None
|
||||
scheduler._begin_index = None
|
||||
|
||||
return scheduler
|
||||
|
||||
def sample_speech_tokens_optimized(self, condition, neg_condition, cfg_scale=3.0):
|
||||
scheduler = prepare_noise_scheduler(self)
|
||||
|
||||
condition = torch.cat([condition, neg_condition], dim=0).to(
|
||||
self.model.prediction_head.device
|
||||
)
|
||||
batch_size = condition.shape[0] // 2
|
||||
speech = torch.randn(batch_size, self.config.acoustic_vae_dim).to(condition)
|
||||
t_batch_cache_key = (
|
||||
self.ddpm_inference_steps,
|
||||
condition.device.type,
|
||||
condition.device.index,
|
||||
condition.dtype,
|
||||
batch_size,
|
||||
)
|
||||
t_batch_cache = getattr(self, "_vibepod_t_batch_cache", {})
|
||||
t_batches = t_batch_cache.get(t_batch_cache_key)
|
||||
if t_batches is None or len(t_batches) != len(scheduler.timesteps):
|
||||
t_batches = [
|
||||
t.repeat(condition.shape[0]).to(
|
||||
device=condition.device, dtype=condition.dtype
|
||||
)
|
||||
for t in scheduler.timesteps
|
||||
]
|
||||
t_batch_cache[t_batch_cache_key] = t_batches
|
||||
self._vibepod_t_batch_cache = t_batch_cache
|
||||
|
||||
for t, t_batch in zip(scheduler.timesteps, t_batches):
|
||||
if batch_size == 1:
|
||||
combined = speech.expand(condition.shape[0], -1)
|
||||
else:
|
||||
combined = torch.cat([speech, speech], dim=0)
|
||||
if profile_enabled():
|
||||
profile_sync()
|
||||
started = time.perf_counter()
|
||||
eps = self.model.prediction_head(combined, t_batch, condition=condition)
|
||||
if profile_enabled():
|
||||
profile_sync()
|
||||
profile_record(self, "diffusion_prediction_head", time.perf_counter() - started)
|
||||
cond_eps, uncond_eps = torch.split(eps, batch_size, dim=0)
|
||||
guided_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
|
||||
if profile_enabled():
|
||||
started = time.perf_counter()
|
||||
speech = scheduler.step(guided_eps, t, speech).prev_sample
|
||||
if profile_enabled():
|
||||
profile_record(self, "diffusion_scheduler_step", time.perf_counter() - started)
|
||||
|
||||
return speech
|
||||
|
||||
forward_lm = model.forward_lm
|
||||
forward_tts_lm = model.forward_tts_lm
|
||||
acoustic_decode = model.model.acoustic_tokenizer.decode
|
||||
|
||||
def forward_lm_profiled(*args, **kwargs):
|
||||
return timed_method(model, "forward_lm", forward_lm, *args, **kwargs)
|
||||
|
||||
def forward_tts_lm_profiled(*args, **kwargs):
|
||||
return timed_method(model, "forward_tts_lm", forward_tts_lm, *args, **kwargs)
|
||||
|
||||
def acoustic_decode_profiled(*args, **kwargs):
|
||||
return timed_method(model, "acoustic_decode", acoustic_decode, *args, **kwargs)
|
||||
|
||||
model.forward_lm = forward_lm_profiled
|
||||
model.forward_tts_lm = forward_tts_lm_profiled
|
||||
model.model.acoustic_tokenizer.decode = acoustic_decode_profiled
|
||||
model.sample_speech_tokens = types.MethodType(sample_speech_tokens_optimized, model)
|
||||
logger.info("Installed VibeVoice generation hot-path optimizations.")
|
||||
|
||||
|
||||
def _model_float_dtype() -> torch.dtype:
|
||||
try:
|
||||
return next(_model.parameters()).dtype
|
||||
except StopIteration:
|
||||
return torch.float32
|
||||
|
||||
|
||||
def _move_cached_prompt(value: object, device: str, dtype: torch.dtype) -> object:
|
||||
if torch.is_tensor(value):
|
||||
if torch.is_floating_point(value):
|
||||
return value.to(device=device, dtype=dtype)
|
||||
return value.to(device=device)
|
||||
if isinstance(value, dict):
|
||||
for k in list(value.keys()):
|
||||
value[k] = _move_cached_prompt(value[k], device, dtype)
|
||||
return value
|
||||
if isinstance(value, list):
|
||||
return [_move_cached_prompt(v, device, dtype) for v in value]
|
||||
if isinstance(value, tuple):
|
||||
return tuple(_move_cached_prompt(v, device, dtype) for v in value)
|
||||
if hasattr(value, "key_cache") and hasattr(value, "value_cache"):
|
||||
value.key_cache = [
|
||||
_move_cached_prompt(t, device, dtype) for t in value.key_cache
|
||||
]
|
||||
value.value_cache = [
|
||||
_move_cached_prompt(t, device, dtype) for t in value.value_cache
|
||||
]
|
||||
return value
|
||||
|
||||
|
||||
def _load_voice_presets(device: str) -> dict[str, object]:
|
||||
presets = {}
|
||||
for name, filename in EN_VOICES.items():
|
||||
@@ -273,9 +465,9 @@ def _load_model_sync() -> None:
|
||||
is_cpu = _device == "cpu"
|
||||
_config["device"] = _device
|
||||
_config["chunk_accum"] = _env_int("VIBEPOD_CHUNK_ACCUM", 4 if is_cpu else 1)
|
||||
_config["prebuffer_secs"] = _env_float("VIBEPOD_PREBUFFER_SECS", 5.0 if is_cpu else 2.0)
|
||||
_config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 1.0 if is_cpu else 0.4)
|
||||
_config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 2.5 if is_cpu else 1.5)
|
||||
_config["prebuffer_secs"] = _env_float("VIBEPOD_PREBUFFER_SECS", 6.0 if is_cpu else 5.0)
|
||||
_config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 1.5 if is_cpu else 1.0)
|
||||
_config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 4.0 if is_cpu else 3.0)
|
||||
_config["default_inference_steps"] = _env_int("VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10)
|
||||
|
||||
_processor = _init_processor()
|
||||
@@ -364,10 +556,15 @@ def _sync_generate(
|
||||
raise RuntimeError("Generation cancelled.")
|
||||
|
||||
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
|
||||
voice_preset = copy.deepcopy(_voice_presets[speaker])
|
||||
model_dtype = _model_float_dtype()
|
||||
voice_preset = _move_cached_prompt(
|
||||
copy.deepcopy(_voice_presets[speaker]), _device, model_dtype
|
||||
)
|
||||
|
||||
steps = req.inference_steps if req.inference_steps is not None else _config["default_inference_steps"]
|
||||
_model.set_ddpm_inference_steps(num_steps=steps)
|
||||
if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1":
|
||||
_model._vibepod_profile = {}
|
||||
|
||||
inputs = _processor.process_input_with_cached_prompt(
|
||||
text=req.text,
|
||||
@@ -380,20 +577,21 @@ def _sync_generate(
|
||||
if torch.is_tensor(v):
|
||||
inputs[k] = v.to(_device)
|
||||
|
||||
outputs = _model.generate(
|
||||
with torch.inference_mode():
|
||||
_model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=None,
|
||||
cfg_scale=req.cfg_scale,
|
||||
tokenizer=_processor.tokenizer,
|
||||
generation_config={"do_sample": False},
|
||||
verbose=True,
|
||||
all_prefilled_outputs=copy.deepcopy(voice_preset),
|
||||
verbose=False,
|
||||
show_progress_bar=False,
|
||||
return_speech=False,
|
||||
stop_check_fn=cancel_event.is_set if cancel_event else None,
|
||||
all_prefilled_outputs=voice_preset,
|
||||
audio_streamer=streamer,
|
||||
)
|
||||
|
||||
if not outputs.speech_outputs or outputs.speech_outputs[0] is None:
|
||||
raise ValueError("Model returned no audio output.")
|
||||
|
||||
return speaker
|
||||
|
||||
|
||||
@@ -401,6 +599,24 @@ def _sse(event: dict) -> str:
|
||||
return f"data: {json.dumps(event)}\n\n"
|
||||
|
||||
|
||||
def _generation_profile() -> Optional[dict[str, dict[str, float]]]:
|
||||
if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") != "1":
|
||||
return None
|
||||
stats = getattr(_model, "_vibepod_profile", None)
|
||||
if not stats:
|
||||
return {}
|
||||
return {
|
||||
key: {
|
||||
"count": value["count"],
|
||||
"seconds": round(value["seconds"], 3),
|
||||
"avg_ms": round(value["seconds"] * 1000 / value["count"], 3)
|
||||
if value["count"]
|
||||
else 0.0,
|
||||
}
|
||||
for key, value in sorted(stats.items())
|
||||
}
|
||||
|
||||
|
||||
@app.post("/generate")
|
||||
async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
if _model_status != "online":
|
||||
@@ -417,20 +633,58 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
)
|
||||
|
||||
async def event_stream() -> AsyncGenerator[str, None]:
|
||||
from vibevoice.modular.streamer import AsyncAudioStreamer
|
||||
class NonBlockingAudioStreamer:
|
||||
"""Async streamer that keeps GPU->CPU copies out of the model thread."""
|
||||
|
||||
def __init__(self, batch_size: int, stop_signal: object = None) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.stop_signal = stop_signal
|
||||
self.audio_queues = [asyncio.Queue() for _ in range(batch_size)]
|
||||
self.finished_flags = [False for _ in range(batch_size)]
|
||||
self.loop = asyncio.get_running_loop()
|
||||
|
||||
def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor) -> None:
|
||||
for i, sample_idx in enumerate(sample_indices):
|
||||
idx = sample_idx.item()
|
||||
if idx < self.batch_size and not self.finished_flags[idx]:
|
||||
self.loop.call_soon_threadsafe(
|
||||
self.audio_queues[idx].put_nowait,
|
||||
audio_chunks[i].detach(),
|
||||
)
|
||||
|
||||
def end(self, sample_indices: Optional[torch.Tensor] = None) -> None:
|
||||
if sample_indices is None:
|
||||
indices_to_end = range(self.batch_size)
|
||||
else:
|
||||
indices_to_end = [
|
||||
s.item() if torch.is_tensor(s) else s for s in sample_indices
|
||||
]
|
||||
for idx in indices_to_end:
|
||||
if idx < self.batch_size and not self.finished_flags[idx]:
|
||||
self.loop.call_soon_threadsafe(
|
||||
self.audio_queues[idx].put_nowait, self.stop_signal
|
||||
)
|
||||
self.finished_flags[idx] = True
|
||||
|
||||
start = time.monotonic()
|
||||
streamer = AsyncAudioStreamer(batch_size=1)
|
||||
streamer = NonBlockingAudioStreamer(batch_size=1)
|
||||
cancel_event = threading.Event()
|
||||
|
||||
accum_size = max(1, _config["chunk_accum"])
|
||||
accumulated_chunks = []
|
||||
chunk_count = 0
|
||||
audio_samples = 0
|
||||
first_chunk_at: Optional[float] = None
|
||||
last_chunk_at: Optional[float] = None
|
||||
max_chunk_gap = 0.0
|
||||
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
|
||||
|
||||
async with _generation_lock:
|
||||
loop = asyncio.get_event_loop()
|
||||
future = loop.run_in_executor(
|
||||
None, functools.partial(_sync_generate, req, streamer, cancel_event)
|
||||
)
|
||||
future.add_done_callback(lambda _: streamer.end())
|
||||
|
||||
# Drain audio chunks as they arrive from the diffusion head.
|
||||
# stop_signal=None is the default sentinel that ends the queue.
|
||||
@@ -454,17 +708,45 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
if chunk is None: # stop signal
|
||||
break
|
||||
|
||||
accumulated_chunks.append(chunk.detach().cpu().float())
|
||||
accumulated_chunks.append(chunk.detach())
|
||||
|
||||
if len(accumulated_chunks) >= accum_size:
|
||||
combined = torch.cat(accumulated_chunks, dim=0)
|
||||
now = time.monotonic()
|
||||
if first_chunk_at is None:
|
||||
first_chunk_at = now
|
||||
if last_chunk_at is not None:
|
||||
max_chunk_gap = max(max_chunk_gap, now - last_chunk_at)
|
||||
last_chunk_at = now
|
||||
|
||||
combined = (
|
||||
torch.cat(accumulated_chunks, dim=0)
|
||||
.detach()
|
||||
.to("cpu", dtype=torch.float32)
|
||||
.contiguous()
|
||||
)
|
||||
chunk_count += 1
|
||||
audio_samples += combined.numel()
|
||||
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
|
||||
yield _sse({"type": "audio_chunk", "data": pcm_b64})
|
||||
accumulated_chunks = []
|
||||
|
||||
# Flush any remaining chunks
|
||||
if accumulated_chunks:
|
||||
combined = torch.cat(accumulated_chunks, dim=0)
|
||||
now = time.monotonic()
|
||||
if first_chunk_at is None:
|
||||
first_chunk_at = now
|
||||
if last_chunk_at is not None:
|
||||
max_chunk_gap = max(max_chunk_gap, now - last_chunk_at)
|
||||
last_chunk_at = now
|
||||
|
||||
combined = (
|
||||
torch.cat(accumulated_chunks, dim=0)
|
||||
.detach()
|
||||
.to("cpu", dtype=torch.float32)
|
||||
.contiguous()
|
||||
)
|
||||
chunk_count += 1
|
||||
audio_samples += combined.numel()
|
||||
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
|
||||
yield _sse({"type": "audio_chunk", "data": pcm_b64})
|
||||
|
||||
@@ -479,17 +761,42 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
yield _sse(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "Internal server error during generation.",
|
||||
"message": f"Generation failed: {exc}",
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
elapsed = round(time.monotonic() - start, 1)
|
||||
audio_secs = audio_samples / SAMPLE_RATE
|
||||
realtime_factor = audio_secs / elapsed if elapsed > 0 else None
|
||||
profile = _generation_profile()
|
||||
if profile is not None:
|
||||
logger.info("Generation profile: %s", profile)
|
||||
logger.info("Generation complete in %.1fs", elapsed)
|
||||
yield _sse({"type": "complete", "elapsed": elapsed, "speaker": speaker})
|
||||
complete_event = {
|
||||
"type": "complete",
|
||||
"elapsed": elapsed,
|
||||
"speaker": speaker,
|
||||
"audio_secs": round(audio_secs, 2),
|
||||
"realtime_factor": round(realtime_factor, 3)
|
||||
if realtime_factor is not None
|
||||
else None,
|
||||
"chunks": chunk_count,
|
||||
"first_chunk_secs": round(first_chunk_at - start, 2)
|
||||
if first_chunk_at is not None
|
||||
else None,
|
||||
"max_chunk_gap_secs": round(max_chunk_gap, 2),
|
||||
}
|
||||
if profile is not None:
|
||||
complete_event["profile"] = profile
|
||||
yield _sse(complete_event)
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
headers={
|
||||
"Cache-Control": "no-cache, no-transform",
|
||||
"X-Accel-Buffering": "no",
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import { NextRequest, NextResponse } from "next/server";
|
||||
|
||||
export const dynamic = "force-dynamic";
|
||||
export const runtime = "nodejs";
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
const pythonServerUrl = process.env.VIBEVOICE_SERVER_URL ?? "http://localhost:8000";
|
||||
|
||||
@@ -24,6 +27,7 @@ export async function POST(request: NextRequest) {
|
||||
cfg_scale: body.cfg_scale ?? 1.5,
|
||||
inference_steps: body.inference_steps ?? 10,
|
||||
}),
|
||||
signal: request.signal,
|
||||
});
|
||||
|
||||
if (!upstream.ok) {
|
||||
@@ -36,8 +40,9 @@ export async function POST(request: NextRequest) {
|
||||
status: 200,
|
||||
headers: {
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"Cache-Control": "no-cache, no-transform",
|
||||
"Connection": "keep-alive",
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
});
|
||||
|
||||
@@ -27,6 +27,7 @@ export async function GET() {
|
||||
message: data.message,
|
||||
progress: data.progress ?? null,
|
||||
voices: data.voices ?? [],
|
||||
config: data.config ?? null,
|
||||
},
|
||||
COMMON_OPTIONS
|
||||
);
|
||||
|
||||
+3
-3
@@ -130,9 +130,9 @@ const initialState: AppState = {
|
||||
speaker: "carter",
|
||||
cfgScale: 1.5,
|
||||
inferenceSteps: 10,
|
||||
prebufferSecs: 2.0,
|
||||
rebufferThresholdSecs: 0.4,
|
||||
resumeThresholdSecs: 1.5,
|
||||
prebufferSecs: 5.0,
|
||||
rebufferThresholdSecs: 1.0,
|
||||
resumeThresholdSecs: 3.0,
|
||||
isGenerating: false,
|
||||
genElapsed: 0,
|
||||
genPct: null,
|
||||
|
||||
@@ -3,9 +3,10 @@
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
|
||||
const SAMPLE_RATE = 24_000;
|
||||
const DEFAULT_PREBUFFER_SECS = 2.0;
|
||||
const DEFAULT_REBUFFER_THRESHOLD_SECS = 0.4;
|
||||
const DEFAULT_RESUME_THRESHOLD_SECS = 1.5;
|
||||
const DEFAULT_PREBUFFER_SECS = 5.0;
|
||||
const DEFAULT_REBUFFER_THRESHOLD_SECS = 1.0;
|
||||
const DEFAULT_RESUME_THRESHOLD_SECS = 3.0;
|
||||
const MAX_ADAPTIVE_RESUME_SECS = 18.0;
|
||||
|
||||
interface GenerateOptions {
|
||||
text: string;
|
||||
@@ -104,6 +105,10 @@ export function useStreamingGeneration({
|
||||
const isAutoBufferingRef = useRef(false);
|
||||
const isUserPausedRef = useRef(false);
|
||||
const audioUrlRef = useRef<string | null>(null);
|
||||
const firstChunkSeenRef = useRef(false);
|
||||
const underrunCountRef = useRef(0);
|
||||
const totalAudioSamplesRef = useRef(0);
|
||||
const adaptiveResumeSecsRef = useRef(DEFAULT_RESUME_THRESHOLD_SECS);
|
||||
|
||||
const revokeCurrentUrl = useCallback(() => {
|
||||
if (audioUrlRef.current) {
|
||||
@@ -122,8 +127,12 @@ export function useStreamingGeneration({
|
||||
hasStartedPlaybackRef.current = false;
|
||||
isAutoBufferingRef.current = false;
|
||||
isUserPausedRef.current = false;
|
||||
firstChunkSeenRef.current = false;
|
||||
underrunCountRef.current = 0;
|
||||
totalAudioSamplesRef.current = 0;
|
||||
adaptiveResumeSecsRef.current = resumeThresholdSecs;
|
||||
setIsStreamPaused(false);
|
||||
}, []);
|
||||
}, [resumeThresholdSecs]);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
@@ -158,10 +167,17 @@ export function useStreamingGeneration({
|
||||
if (!ctx) return;
|
||||
|
||||
chunksRef.current.push(chunk);
|
||||
totalAudioSamplesRef.current += chunk.length;
|
||||
|
||||
if (!firstChunkSeenRef.current) {
|
||||
firstChunkSeenRef.current = true;
|
||||
onLog("First audio chunk received");
|
||||
}
|
||||
|
||||
if (!hasStartedPlaybackRef.current) {
|
||||
const bufferedSecs = chunksRef.current.reduce((sum, c) => sum + c.length, 0) / SAMPLE_RATE;
|
||||
if (bufferedSecs >= prebufferSecs) {
|
||||
onLog(`Playback started after ${bufferedSecs.toFixed(1)}s buffered`);
|
||||
flushBufferedAudio();
|
||||
}
|
||||
return;
|
||||
@@ -171,18 +187,30 @@ export function useStreamingGeneration({
|
||||
if (isUserPausedRef.current) return;
|
||||
|
||||
const ahead = nextStartTimeRef.current - ctx.currentTime;
|
||||
if (ctx.state === "running" && ahead < rebufferThresholdSecs) {
|
||||
ctx.suspend().catch(() => {});
|
||||
isAutoBufferingRef.current = true;
|
||||
} else if (
|
||||
ctx.state === "suspended" &&
|
||||
isAutoBufferingRef.current &&
|
||||
ahead >= resumeThresholdSecs
|
||||
if (
|
||||
ctx.state === "running" &&
|
||||
!isAutoBufferingRef.current &&
|
||||
ahead < rebufferThresholdSecs
|
||||
) {
|
||||
isAutoBufferingRef.current = true;
|
||||
underrunCountRef.current += 1;
|
||||
adaptiveResumeSecsRef.current = Math.min(
|
||||
MAX_ADAPTIVE_RESUME_SECS,
|
||||
Math.max(resumeThresholdSecs, prebufferSecs + underrunCountRef.current * 2),
|
||||
);
|
||||
ctx.suspend().catch(() => {});
|
||||
onLog(
|
||||
`Buffer underrun ${underrunCountRef.current}; refilling to ${adaptiveResumeSecsRef.current.toFixed(1)}s`,
|
||||
);
|
||||
} else if (
|
||||
isAutoBufferingRef.current &&
|
||||
ahead >= adaptiveResumeSecsRef.current
|
||||
) {
|
||||
ctx.resume().catch(() => {});
|
||||
isAutoBufferingRef.current = false;
|
||||
ctx.resume().catch(() => {});
|
||||
onLog(`Buffer recovered with ${ahead.toFixed(1)}s queued`);
|
||||
}
|
||||
}, [enqueue, flushBufferedAudio, prebufferSecs, rebufferThresholdSecs, resumeThresholdSecs]);
|
||||
}, [enqueue, flushBufferedAudio, onLog, prebufferSecs, rebufferThresholdSecs, resumeThresholdSecs]);
|
||||
|
||||
const generate = useCallback(async (options: GenerateOptions) => {
|
||||
if (!options.text.trim()) return;
|
||||
@@ -239,6 +267,11 @@ export function useStreamingGeneration({
|
||||
type: "audio_chunk" | "complete" | "error" | "cancelled";
|
||||
data?: string;
|
||||
elapsed?: number;
|
||||
audio_secs?: number;
|
||||
realtime_factor?: number | null;
|
||||
chunks?: number;
|
||||
first_chunk_secs?: number | null;
|
||||
max_chunk_gap_secs?: number;
|
||||
message?: string;
|
||||
};
|
||||
|
||||
@@ -247,12 +280,26 @@ export function useStreamingGeneration({
|
||||
} else if (event.type === "complete") {
|
||||
if (!hasStartedPlaybackRef.current) {
|
||||
flushBufferedAudio();
|
||||
} else if (isAutoBufferingRef.current) {
|
||||
isAutoBufferingRef.current = false;
|
||||
audioCtxRef.current?.resume().catch(() => {});
|
||||
}
|
||||
const wavBlob = buildWav(mergeFloat32Arrays(chunksRef.current), SAMPLE_RATE);
|
||||
const audioUrl = URL.createObjectURL(wavBlob);
|
||||
audioUrlRef.current = audioUrl;
|
||||
const kb = (wavBlob.size / 1024).toFixed(0);
|
||||
onLog(`Done in ${event.elapsed}s - ${kb} KB`);
|
||||
const audioSecs = event.audio_secs ?? totalAudioSamplesRef.current / SAMPLE_RATE;
|
||||
const realtimeFactor =
|
||||
event.realtime_factor ??
|
||||
(event.elapsed && event.elapsed > 0 ? audioSecs / event.elapsed : null);
|
||||
const speedText =
|
||||
realtimeFactor === null ? "" : ` - ${realtimeFactor.toFixed(2)}x realtime`;
|
||||
onLog(`Done in ${event.elapsed}s - ${audioSecs.toFixed(1)}s audio${speedText} - ${kb} KB`);
|
||||
if (event.chunks && event.first_chunk_secs !== undefined) {
|
||||
onLog(
|
||||
`Stream: first chunk ${event.first_chunk_secs}s, ${event.chunks} chunks, max gap ${event.max_chunk_gap_secs}s`,
|
||||
);
|
||||
}
|
||||
onSuccess(audioUrl);
|
||||
} else if (event.type === "cancelled") {
|
||||
throw new DOMException("Generation cancelled", "AbortError");
|
||||
|
||||
+1
-1
@@ -4,7 +4,7 @@
|
||||
"private": true,
|
||||
"scripts": {
|
||||
"dev": "next dev --turbopack",
|
||||
"build": "next build --turbopack",
|
||||
"build": "next build",
|
||||
"start": "next start"
|
||||
},
|
||||
"dependencies": {
|
||||
|
||||
Reference in New Issue
Block a user