mirror of
https://github.com/JezzWTF/vibepod.git
synced 2026-06-01 15:22:14 +00:00
perf: improve streaming generation pipeline
Add CUDA inference hot-path optimizations, safer attention fallback handling, and generation profiling hooks. Improve SSE streaming, browser buffering telemetry, and playback recovery while preserving default audio quality settings.
This commit is contained in:
@@ -0,0 +1,2 @@
|
|||||||
|
*.sh text eol=lf
|
||||||
|
*.py text eol=lf
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
3.12
|
||||||
@@ -7,6 +7,11 @@
|
|||||||
# ./start.sh — CUDA mode (default, uses PyTorch CUDA 12.4 wheel, venv: .venv)
|
# ./start.sh — CUDA mode (default, uses PyTorch CUDA 12.4 wheel, venv: .venv)
|
||||||
# ./start.sh --cpu — CPU-only mode (uses PyPI CPU torch wheel, venv: .venv-cpu)
|
# ./start.sh --cpu — CPU-only mode (uses PyPI CPU torch wheel, venv: .venv-cpu)
|
||||||
#
|
#
|
||||||
|
# Optional CUDA acceleration:
|
||||||
|
# VIBEPOD_ENABLE_FLASH_ATTN=1 ./start.sh
|
||||||
|
# Installs a matching third-party Windows flash-attn wheel when the CUDA venv
|
||||||
|
# uses Python 3.12, torch 2.6.0, and CUDA 12.4.
|
||||||
|
#
|
||||||
# The two modes maintain completely separate virtual environments so their torch
|
# The two modes maintain completely separate virtual environments so their torch
|
||||||
# installations never conflict. UV_PROJECT_ENVIRONMENT tells uv which venv to use;
|
# installations never conflict. UV_PROJECT_ENVIRONMENT tells uv which venv to use;
|
||||||
# --no-sources skips [tool.uv.sources] so the CPU run pulls the default PyPI torch wheel.
|
# --no-sources skips [tool.uv.sources] so the CPU run pulls the default PyPI torch wheel.
|
||||||
@@ -51,6 +56,19 @@ if ! command -v uv &>/dev/null; then
|
|||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
validate_flash_attn() {
|
||||||
|
uv run python -c "import flash_attn; import triton; import transformers.modeling_utils" &>/dev/null
|
||||||
|
}
|
||||||
|
|
||||||
|
remove_broken_flash_attn() {
|
||||||
|
if uv run python -c "import importlib.util; raise SystemExit(0 if importlib.util.find_spec('flash_attn') else 1)" &>/dev/null; then
|
||||||
|
if ! validate_flash_attn; then
|
||||||
|
echo " Installed flash-attn is not usable in this environment; removing it."
|
||||||
|
uv pip uninstall flash-attn
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# 2. Sync Python environment
|
# 2. Sync Python environment
|
||||||
# CPU mode: use .venv-cpu and skip [tool.uv.sources] so uv pulls the
|
# CPU mode: use .venv-cpu and skip [tool.uv.sources] so uv pulls the
|
||||||
@@ -65,6 +83,36 @@ if $CPU_MODE; then
|
|||||||
else
|
else
|
||||||
echo "--> Syncing CUDA Python environment (.venv)..."
|
echo "--> Syncing CUDA Python environment (.venv)..."
|
||||||
uv sync
|
uv sync
|
||||||
|
|
||||||
|
remove_broken_flash_attn
|
||||||
|
|
||||||
|
if [[ "${VIBEPOD_ENABLE_FLASH_ATTN:-0}" == "1" ]]; then
|
||||||
|
echo ""
|
||||||
|
echo "--> Checking optional FlashAttention wheel..."
|
||||||
|
|
||||||
|
if validate_flash_attn; then
|
||||||
|
echo " flash-attn already installed and importable."
|
||||||
|
else
|
||||||
|
PY_TAG="$(uv run python -c "import sys; print(f'cp{sys.version_info.major}{sys.version_info.minor}')")"
|
||||||
|
TORCH_TAG="$(uv run python -c "import torch; print(torch.__version__.split('+', 1)[0])")"
|
||||||
|
CUDA_TAG="$(uv run python -c "import torch; print('cu' + torch.version.cuda.replace('.', ''))")"
|
||||||
|
|
||||||
|
if [[ "$PY_TAG" == "cp312" && "$TORCH_TAG" == "2.6.0" && "$CUDA_TAG" == "cu124" ]]; then
|
||||||
|
FLASH_ATTN_WHEEL_URL="https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4%2Bcu124torch2.6.0cxx11abiFALSE-cp312-cp312-win_amd64.whl"
|
||||||
|
echo " Installing flash-attn for Python 3.12, torch 2.6.0, CUDA 12.4..."
|
||||||
|
uv pip install "$FLASH_ATTN_WHEEL_URL"
|
||||||
|
if validate_flash_attn; then
|
||||||
|
echo " flash-attn import check passed."
|
||||||
|
else
|
||||||
|
echo " flash-attn import check failed; removing it and continuing with SDPA."
|
||||||
|
uv pip uninstall flash-attn
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo " No known wheel for Python tag $PY_TAG, torch $TORCH_TAG, CUDA $CUDA_TAG."
|
||||||
|
echo " Continuing with PyTorch SDPA attention."
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
+336
-29
@@ -22,11 +22,13 @@ import asyncio
|
|||||||
import base64
|
import base64
|
||||||
import copy
|
import copy
|
||||||
import functools
|
import functools
|
||||||
|
import importlib.util
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
import types
|
||||||
import urllib.request
|
import urllib.request
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -211,8 +213,39 @@ def _init_processor():
|
|||||||
|
|
||||||
def _init_model(device: str):
|
def _init_model(device: str):
|
||||||
logger.info("Loading model on %s...", device)
|
logger.info("Loading model on %s...", device)
|
||||||
load_dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
if device == "cuda":
|
||||||
attn_impl = "flash_attention_2" if device == "cuda" else "sdpa"
|
torch.set_float32_matmul_precision("high")
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
torch.backends.cuda.enable_flash_sdp(True)
|
||||||
|
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
||||||
|
torch.backends.cuda.enable_math_sdp(True)
|
||||||
|
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
|
||||||
|
logger.info(
|
||||||
|
"PyTorch SDPA backends: flash=%s, mem_efficient=%s, math=%s",
|
||||||
|
torch.backends.cuda.flash_sdp_enabled(),
|
||||||
|
torch.backends.cuda.mem_efficient_sdp_enabled(),
|
||||||
|
torch.backends.cuda.math_sdp_enabled(),
|
||||||
|
)
|
||||||
|
|
||||||
|
cuda_dtype = os.environ.get("VIBEPOD_CUDA_DTYPE", "bf16").lower()
|
||||||
|
if device == "cuda" and cuda_dtype == "fp16":
|
||||||
|
load_dtype = torch.float16
|
||||||
|
else:
|
||||||
|
load_dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
||||||
|
logger.info("Loading model weights with dtype %s", load_dtype)
|
||||||
|
requested_attn_impl = os.environ.get("VIBEPOD_ATTN_IMPL", "auto").lower()
|
||||||
|
has_flash_attn = importlib.util.find_spec("flash_attn") is not None
|
||||||
|
if requested_attn_impl in {"eager", "sdpa"}:
|
||||||
|
attn_impl = requested_attn_impl
|
||||||
|
elif requested_attn_impl == "flash_attention_2":
|
||||||
|
attn_impl = "flash_attention_2" if has_flash_attn else "sdpa"
|
||||||
|
else:
|
||||||
|
attn_impl = "flash_attention_2" if device == "cuda" and has_flash_attn else "sdpa"
|
||||||
|
logger.info("Using Transformers attention implementation: %s", attn_impl)
|
||||||
|
if device == "cuda" and not has_flash_attn:
|
||||||
|
logger.info("flash_attn is not installed; using PyTorch SDPA attention.")
|
||||||
|
|
||||||
from vibevoice.modular.modeling_vibevoice_streaming_inference import (
|
from vibevoice.modular.modeling_vibevoice_streaming_inference import (
|
||||||
VibeVoiceStreamingForConditionalGenerationInference,
|
VibeVoiceStreamingForConditionalGenerationInference,
|
||||||
@@ -225,9 +258,13 @@ def _init_model(device: str):
|
|||||||
device_map=device,
|
device_map=device,
|
||||||
attn_implementation=attn_impl,
|
attn_implementation=attn_impl,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception as exc:
|
||||||
|
if attn_impl == "sdpa":
|
||||||
|
raise
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Model load with %s failed; falling back to sdpa", attn_impl, exc_info=True
|
"Model load with %s failed (%s); falling back to sdpa",
|
||||||
|
attn_impl,
|
||||||
|
exc,
|
||||||
)
|
)
|
||||||
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
|
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
|
||||||
MODEL_ID,
|
MODEL_ID,
|
||||||
@@ -238,9 +275,164 @@ def _init_model(device: str):
|
|||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
model.set_ddpm_inference_steps(num_steps=_config["default_inference_steps"])
|
model.set_ddpm_inference_steps(num_steps=_config["default_inference_steps"])
|
||||||
|
_install_generation_optimizations(model)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def _install_generation_optimizations(model: object) -> None:
|
||||||
|
"""Patch VibeVoice hot paths without changing model quality settings."""
|
||||||
|
|
||||||
|
def profile_enabled() -> bool:
|
||||||
|
return os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1"
|
||||||
|
|
||||||
|
def profile_sync() -> None:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
def profile_record(self, key: str, elapsed: float) -> None:
|
||||||
|
stats = getattr(self, "_vibepod_profile", None)
|
||||||
|
if stats is None:
|
||||||
|
stats = {}
|
||||||
|
self._vibepod_profile = stats
|
||||||
|
bucket = stats.setdefault(key, {"count": 0, "seconds": 0.0})
|
||||||
|
bucket["count"] += 1
|
||||||
|
bucket["seconds"] += elapsed
|
||||||
|
|
||||||
|
def timed_method(self, key: str, fn, *args, **kwargs):
|
||||||
|
if not profile_enabled():
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
profile_sync()
|
||||||
|
started = time.perf_counter()
|
||||||
|
result = fn(*args, **kwargs)
|
||||||
|
profile_sync()
|
||||||
|
profile_record(self, key, time.perf_counter() - started)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def prepare_noise_scheduler(self):
|
||||||
|
scheduler = self.model.noise_scheduler
|
||||||
|
cache_key = self.ddpm_inference_steps
|
||||||
|
cache = getattr(self, "_vibepod_scheduler_cache", {})
|
||||||
|
cached = cache.get(cache_key)
|
||||||
|
|
||||||
|
if cached is None:
|
||||||
|
scheduler.set_timesteps(self.ddpm_inference_steps)
|
||||||
|
cached = {
|
||||||
|
"num_inference_steps": scheduler.num_inference_steps,
|
||||||
|
"timesteps": scheduler.timesteps,
|
||||||
|
"sigmas": scheduler.sigmas,
|
||||||
|
}
|
||||||
|
cache[cache_key] = cached
|
||||||
|
self._vibepod_scheduler_cache = cache
|
||||||
|
else:
|
||||||
|
scheduler.num_inference_steps = cached["num_inference_steps"]
|
||||||
|
scheduler.timesteps = cached["timesteps"]
|
||||||
|
scheduler.sigmas = cached["sigmas"]
|
||||||
|
scheduler.model_outputs = [None] * scheduler.config.solver_order
|
||||||
|
scheduler.lower_order_nums = 0
|
||||||
|
scheduler._step_index = None
|
||||||
|
scheduler._begin_index = None
|
||||||
|
|
||||||
|
return scheduler
|
||||||
|
|
||||||
|
def sample_speech_tokens_optimized(self, condition, neg_condition, cfg_scale=3.0):
|
||||||
|
scheduler = prepare_noise_scheduler(self)
|
||||||
|
|
||||||
|
condition = torch.cat([condition, neg_condition], dim=0).to(
|
||||||
|
self.model.prediction_head.device
|
||||||
|
)
|
||||||
|
batch_size = condition.shape[0] // 2
|
||||||
|
speech = torch.randn(batch_size, self.config.acoustic_vae_dim).to(condition)
|
||||||
|
t_batch_cache_key = (
|
||||||
|
self.ddpm_inference_steps,
|
||||||
|
condition.device.type,
|
||||||
|
condition.device.index,
|
||||||
|
condition.dtype,
|
||||||
|
batch_size,
|
||||||
|
)
|
||||||
|
t_batch_cache = getattr(self, "_vibepod_t_batch_cache", {})
|
||||||
|
t_batches = t_batch_cache.get(t_batch_cache_key)
|
||||||
|
if t_batches is None or len(t_batches) != len(scheduler.timesteps):
|
||||||
|
t_batches = [
|
||||||
|
t.repeat(condition.shape[0]).to(
|
||||||
|
device=condition.device, dtype=condition.dtype
|
||||||
|
)
|
||||||
|
for t in scheduler.timesteps
|
||||||
|
]
|
||||||
|
t_batch_cache[t_batch_cache_key] = t_batches
|
||||||
|
self._vibepod_t_batch_cache = t_batch_cache
|
||||||
|
|
||||||
|
for t, t_batch in zip(scheduler.timesteps, t_batches):
|
||||||
|
if batch_size == 1:
|
||||||
|
combined = speech.expand(condition.shape[0], -1)
|
||||||
|
else:
|
||||||
|
combined = torch.cat([speech, speech], dim=0)
|
||||||
|
if profile_enabled():
|
||||||
|
profile_sync()
|
||||||
|
started = time.perf_counter()
|
||||||
|
eps = self.model.prediction_head(combined, t_batch, condition=condition)
|
||||||
|
if profile_enabled():
|
||||||
|
profile_sync()
|
||||||
|
profile_record(self, "diffusion_prediction_head", time.perf_counter() - started)
|
||||||
|
cond_eps, uncond_eps = torch.split(eps, batch_size, dim=0)
|
||||||
|
guided_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
|
||||||
|
if profile_enabled():
|
||||||
|
started = time.perf_counter()
|
||||||
|
speech = scheduler.step(guided_eps, t, speech).prev_sample
|
||||||
|
if profile_enabled():
|
||||||
|
profile_record(self, "diffusion_scheduler_step", time.perf_counter() - started)
|
||||||
|
|
||||||
|
return speech
|
||||||
|
|
||||||
|
forward_lm = model.forward_lm
|
||||||
|
forward_tts_lm = model.forward_tts_lm
|
||||||
|
acoustic_decode = model.model.acoustic_tokenizer.decode
|
||||||
|
|
||||||
|
def forward_lm_profiled(*args, **kwargs):
|
||||||
|
return timed_method(model, "forward_lm", forward_lm, *args, **kwargs)
|
||||||
|
|
||||||
|
def forward_tts_lm_profiled(*args, **kwargs):
|
||||||
|
return timed_method(model, "forward_tts_lm", forward_tts_lm, *args, **kwargs)
|
||||||
|
|
||||||
|
def acoustic_decode_profiled(*args, **kwargs):
|
||||||
|
return timed_method(model, "acoustic_decode", acoustic_decode, *args, **kwargs)
|
||||||
|
|
||||||
|
model.forward_lm = forward_lm_profiled
|
||||||
|
model.forward_tts_lm = forward_tts_lm_profiled
|
||||||
|
model.model.acoustic_tokenizer.decode = acoustic_decode_profiled
|
||||||
|
model.sample_speech_tokens = types.MethodType(sample_speech_tokens_optimized, model)
|
||||||
|
logger.info("Installed VibeVoice generation hot-path optimizations.")
|
||||||
|
|
||||||
|
|
||||||
|
def _model_float_dtype() -> torch.dtype:
|
||||||
|
try:
|
||||||
|
return next(_model.parameters()).dtype
|
||||||
|
except StopIteration:
|
||||||
|
return torch.float32
|
||||||
|
|
||||||
|
|
||||||
|
def _move_cached_prompt(value: object, device: str, dtype: torch.dtype) -> object:
|
||||||
|
if torch.is_tensor(value):
|
||||||
|
if torch.is_floating_point(value):
|
||||||
|
return value.to(device=device, dtype=dtype)
|
||||||
|
return value.to(device=device)
|
||||||
|
if isinstance(value, dict):
|
||||||
|
for k in list(value.keys()):
|
||||||
|
value[k] = _move_cached_prompt(value[k], device, dtype)
|
||||||
|
return value
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [_move_cached_prompt(v, device, dtype) for v in value]
|
||||||
|
if isinstance(value, tuple):
|
||||||
|
return tuple(_move_cached_prompt(v, device, dtype) for v in value)
|
||||||
|
if hasattr(value, "key_cache") and hasattr(value, "value_cache"):
|
||||||
|
value.key_cache = [
|
||||||
|
_move_cached_prompt(t, device, dtype) for t in value.key_cache
|
||||||
|
]
|
||||||
|
value.value_cache = [
|
||||||
|
_move_cached_prompt(t, device, dtype) for t in value.value_cache
|
||||||
|
]
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
def _load_voice_presets(device: str) -> dict[str, object]:
|
def _load_voice_presets(device: str) -> dict[str, object]:
|
||||||
presets = {}
|
presets = {}
|
||||||
for name, filename in EN_VOICES.items():
|
for name, filename in EN_VOICES.items():
|
||||||
@@ -273,9 +465,9 @@ def _load_model_sync() -> None:
|
|||||||
is_cpu = _device == "cpu"
|
is_cpu = _device == "cpu"
|
||||||
_config["device"] = _device
|
_config["device"] = _device
|
||||||
_config["chunk_accum"] = _env_int("VIBEPOD_CHUNK_ACCUM", 4 if is_cpu else 1)
|
_config["chunk_accum"] = _env_int("VIBEPOD_CHUNK_ACCUM", 4 if is_cpu else 1)
|
||||||
_config["prebuffer_secs"] = _env_float("VIBEPOD_PREBUFFER_SECS", 5.0 if is_cpu else 2.0)
|
_config["prebuffer_secs"] = _env_float("VIBEPOD_PREBUFFER_SECS", 6.0 if is_cpu else 5.0)
|
||||||
_config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 1.0 if is_cpu else 0.4)
|
_config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 1.5 if is_cpu else 1.0)
|
||||||
_config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 2.5 if is_cpu else 1.5)
|
_config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 4.0 if is_cpu else 3.0)
|
||||||
_config["default_inference_steps"] = _env_int("VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10)
|
_config["default_inference_steps"] = _env_int("VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10)
|
||||||
|
|
||||||
_processor = _init_processor()
|
_processor = _init_processor()
|
||||||
@@ -364,10 +556,15 @@ def _sync_generate(
|
|||||||
raise RuntimeError("Generation cancelled.")
|
raise RuntimeError("Generation cancelled.")
|
||||||
|
|
||||||
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
|
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
|
||||||
voice_preset = copy.deepcopy(_voice_presets[speaker])
|
model_dtype = _model_float_dtype()
|
||||||
|
voice_preset = _move_cached_prompt(
|
||||||
|
copy.deepcopy(_voice_presets[speaker]), _device, model_dtype
|
||||||
|
)
|
||||||
|
|
||||||
steps = req.inference_steps if req.inference_steps is not None else _config["default_inference_steps"]
|
steps = req.inference_steps if req.inference_steps is not None else _config["default_inference_steps"]
|
||||||
_model.set_ddpm_inference_steps(num_steps=steps)
|
_model.set_ddpm_inference_steps(num_steps=steps)
|
||||||
|
if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1":
|
||||||
|
_model._vibepod_profile = {}
|
||||||
|
|
||||||
inputs = _processor.process_input_with_cached_prompt(
|
inputs = _processor.process_input_with_cached_prompt(
|
||||||
text=req.text,
|
text=req.text,
|
||||||
@@ -380,19 +577,20 @@ def _sync_generate(
|
|||||||
if torch.is_tensor(v):
|
if torch.is_tensor(v):
|
||||||
inputs[k] = v.to(_device)
|
inputs[k] = v.to(_device)
|
||||||
|
|
||||||
outputs = _model.generate(
|
with torch.inference_mode():
|
||||||
**inputs,
|
_model.generate(
|
||||||
max_new_tokens=None,
|
**inputs,
|
||||||
cfg_scale=req.cfg_scale,
|
max_new_tokens=None,
|
||||||
tokenizer=_processor.tokenizer,
|
cfg_scale=req.cfg_scale,
|
||||||
generation_config={"do_sample": False},
|
tokenizer=_processor.tokenizer,
|
||||||
verbose=True,
|
generation_config={"do_sample": False},
|
||||||
all_prefilled_outputs=copy.deepcopy(voice_preset),
|
verbose=False,
|
||||||
audio_streamer=streamer,
|
show_progress_bar=False,
|
||||||
)
|
return_speech=False,
|
||||||
|
stop_check_fn=cancel_event.is_set if cancel_event else None,
|
||||||
if not outputs.speech_outputs or outputs.speech_outputs[0] is None:
|
all_prefilled_outputs=voice_preset,
|
||||||
raise ValueError("Model returned no audio output.")
|
audio_streamer=streamer,
|
||||||
|
)
|
||||||
|
|
||||||
return speaker
|
return speaker
|
||||||
|
|
||||||
@@ -401,6 +599,24 @@ def _sse(event: dict) -> str:
|
|||||||
return f"data: {json.dumps(event)}\n\n"
|
return f"data: {json.dumps(event)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _generation_profile() -> Optional[dict[str, dict[str, float]]]:
|
||||||
|
if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") != "1":
|
||||||
|
return None
|
||||||
|
stats = getattr(_model, "_vibepod_profile", None)
|
||||||
|
if not stats:
|
||||||
|
return {}
|
||||||
|
return {
|
||||||
|
key: {
|
||||||
|
"count": value["count"],
|
||||||
|
"seconds": round(value["seconds"], 3),
|
||||||
|
"avg_ms": round(value["seconds"] * 1000 / value["count"], 3)
|
||||||
|
if value["count"]
|
||||||
|
else 0.0,
|
||||||
|
}
|
||||||
|
for key, value in sorted(stats.items())
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/generate")
|
@app.post("/generate")
|
||||||
async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||||
if _model_status != "online":
|
if _model_status != "online":
|
||||||
@@ -417,20 +633,58 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def event_stream() -> AsyncGenerator[str, None]:
|
async def event_stream() -> AsyncGenerator[str, None]:
|
||||||
from vibevoice.modular.streamer import AsyncAudioStreamer
|
class NonBlockingAudioStreamer:
|
||||||
|
"""Async streamer that keeps GPU->CPU copies out of the model thread."""
|
||||||
|
|
||||||
|
def __init__(self, batch_size: int, stop_signal: object = None) -> None:
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.stop_signal = stop_signal
|
||||||
|
self.audio_queues = [asyncio.Queue() for _ in range(batch_size)]
|
||||||
|
self.finished_flags = [False for _ in range(batch_size)]
|
||||||
|
self.loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor) -> None:
|
||||||
|
for i, sample_idx in enumerate(sample_indices):
|
||||||
|
idx = sample_idx.item()
|
||||||
|
if idx < self.batch_size and not self.finished_flags[idx]:
|
||||||
|
self.loop.call_soon_threadsafe(
|
||||||
|
self.audio_queues[idx].put_nowait,
|
||||||
|
audio_chunks[i].detach(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def end(self, sample_indices: Optional[torch.Tensor] = None) -> None:
|
||||||
|
if sample_indices is None:
|
||||||
|
indices_to_end = range(self.batch_size)
|
||||||
|
else:
|
||||||
|
indices_to_end = [
|
||||||
|
s.item() if torch.is_tensor(s) else s for s in sample_indices
|
||||||
|
]
|
||||||
|
for idx in indices_to_end:
|
||||||
|
if idx < self.batch_size and not self.finished_flags[idx]:
|
||||||
|
self.loop.call_soon_threadsafe(
|
||||||
|
self.audio_queues[idx].put_nowait, self.stop_signal
|
||||||
|
)
|
||||||
|
self.finished_flags[idx] = True
|
||||||
|
|
||||||
start = time.monotonic()
|
start = time.monotonic()
|
||||||
streamer = AsyncAudioStreamer(batch_size=1)
|
streamer = NonBlockingAudioStreamer(batch_size=1)
|
||||||
cancel_event = threading.Event()
|
cancel_event = threading.Event()
|
||||||
|
|
||||||
accum_size = max(1, _config["chunk_accum"])
|
accum_size = max(1, _config["chunk_accum"])
|
||||||
accumulated_chunks = []
|
accumulated_chunks = []
|
||||||
|
chunk_count = 0
|
||||||
|
audio_samples = 0
|
||||||
|
first_chunk_at: Optional[float] = None
|
||||||
|
last_chunk_at: Optional[float] = None
|
||||||
|
max_chunk_gap = 0.0
|
||||||
|
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
|
||||||
|
|
||||||
async with _generation_lock:
|
async with _generation_lock:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
future = loop.run_in_executor(
|
future = loop.run_in_executor(
|
||||||
None, functools.partial(_sync_generate, req, streamer, cancel_event)
|
None, functools.partial(_sync_generate, req, streamer, cancel_event)
|
||||||
)
|
)
|
||||||
|
future.add_done_callback(lambda _: streamer.end())
|
||||||
|
|
||||||
# Drain audio chunks as they arrive from the diffusion head.
|
# Drain audio chunks as they arrive from the diffusion head.
|
||||||
# stop_signal=None is the default sentinel that ends the queue.
|
# stop_signal=None is the default sentinel that ends the queue.
|
||||||
@@ -454,17 +708,45 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
|||||||
if chunk is None: # stop signal
|
if chunk is None: # stop signal
|
||||||
break
|
break
|
||||||
|
|
||||||
accumulated_chunks.append(chunk.detach().cpu().float())
|
accumulated_chunks.append(chunk.detach())
|
||||||
|
|
||||||
if len(accumulated_chunks) >= accum_size:
|
if len(accumulated_chunks) >= accum_size:
|
||||||
combined = torch.cat(accumulated_chunks, dim=0)
|
now = time.monotonic()
|
||||||
|
if first_chunk_at is None:
|
||||||
|
first_chunk_at = now
|
||||||
|
if last_chunk_at is not None:
|
||||||
|
max_chunk_gap = max(max_chunk_gap, now - last_chunk_at)
|
||||||
|
last_chunk_at = now
|
||||||
|
|
||||||
|
combined = (
|
||||||
|
torch.cat(accumulated_chunks, dim=0)
|
||||||
|
.detach()
|
||||||
|
.to("cpu", dtype=torch.float32)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
|
chunk_count += 1
|
||||||
|
audio_samples += combined.numel()
|
||||||
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
|
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
|
||||||
yield _sse({"type": "audio_chunk", "data": pcm_b64})
|
yield _sse({"type": "audio_chunk", "data": pcm_b64})
|
||||||
accumulated_chunks = []
|
accumulated_chunks = []
|
||||||
|
|
||||||
# Flush any remaining chunks
|
# Flush any remaining chunks
|
||||||
if accumulated_chunks:
|
if accumulated_chunks:
|
||||||
combined = torch.cat(accumulated_chunks, dim=0)
|
now = time.monotonic()
|
||||||
|
if first_chunk_at is None:
|
||||||
|
first_chunk_at = now
|
||||||
|
if last_chunk_at is not None:
|
||||||
|
max_chunk_gap = max(max_chunk_gap, now - last_chunk_at)
|
||||||
|
last_chunk_at = now
|
||||||
|
|
||||||
|
combined = (
|
||||||
|
torch.cat(accumulated_chunks, dim=0)
|
||||||
|
.detach()
|
||||||
|
.to("cpu", dtype=torch.float32)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
|
chunk_count += 1
|
||||||
|
audio_samples += combined.numel()
|
||||||
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
|
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
|
||||||
yield _sse({"type": "audio_chunk", "data": pcm_b64})
|
yield _sse({"type": "audio_chunk", "data": pcm_b64})
|
||||||
|
|
||||||
@@ -479,17 +761,42 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
|||||||
yield _sse(
|
yield _sse(
|
||||||
{
|
{
|
||||||
"type": "error",
|
"type": "error",
|
||||||
"message": "Internal server error during generation.",
|
"message": f"Generation failed: {exc}",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
elapsed = round(time.monotonic() - start, 1)
|
elapsed = round(time.monotonic() - start, 1)
|
||||||
|
audio_secs = audio_samples / SAMPLE_RATE
|
||||||
|
realtime_factor = audio_secs / elapsed if elapsed > 0 else None
|
||||||
|
profile = _generation_profile()
|
||||||
|
if profile is not None:
|
||||||
|
logger.info("Generation profile: %s", profile)
|
||||||
logger.info("Generation complete in %.1fs", elapsed)
|
logger.info("Generation complete in %.1fs", elapsed)
|
||||||
yield _sse({"type": "complete", "elapsed": elapsed, "speaker": speaker})
|
complete_event = {
|
||||||
|
"type": "complete",
|
||||||
|
"elapsed": elapsed,
|
||||||
|
"speaker": speaker,
|
||||||
|
"audio_secs": round(audio_secs, 2),
|
||||||
|
"realtime_factor": round(realtime_factor, 3)
|
||||||
|
if realtime_factor is not None
|
||||||
|
else None,
|
||||||
|
"chunks": chunk_count,
|
||||||
|
"first_chunk_secs": round(first_chunk_at - start, 2)
|
||||||
|
if first_chunk_at is not None
|
||||||
|
else None,
|
||||||
|
"max_chunk_gap_secs": round(max_chunk_gap, 2),
|
||||||
|
}
|
||||||
|
if profile is not None:
|
||||||
|
complete_event["profile"] = profile
|
||||||
|
yield _sse(complete_event)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
event_stream(),
|
event_stream(),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
headers={
|
||||||
|
"Cache-Control": "no-cache, no-transform",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
"X-Content-Type-Options": "nosniff",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
import { NextRequest, NextResponse } from "next/server";
|
import { NextRequest, NextResponse } from "next/server";
|
||||||
|
|
||||||
|
export const dynamic = "force-dynamic";
|
||||||
|
export const runtime = "nodejs";
|
||||||
|
|
||||||
export async function POST(request: NextRequest) {
|
export async function POST(request: NextRequest) {
|
||||||
const pythonServerUrl = process.env.VIBEVOICE_SERVER_URL ?? "http://localhost:8000";
|
const pythonServerUrl = process.env.VIBEVOICE_SERVER_URL ?? "http://localhost:8000";
|
||||||
|
|
||||||
@@ -24,6 +27,7 @@ export async function POST(request: NextRequest) {
|
|||||||
cfg_scale: body.cfg_scale ?? 1.5,
|
cfg_scale: body.cfg_scale ?? 1.5,
|
||||||
inference_steps: body.inference_steps ?? 10,
|
inference_steps: body.inference_steps ?? 10,
|
||||||
}),
|
}),
|
||||||
|
signal: request.signal,
|
||||||
});
|
});
|
||||||
|
|
||||||
if (!upstream.ok) {
|
if (!upstream.ok) {
|
||||||
@@ -36,8 +40,9 @@ export async function POST(request: NextRequest) {
|
|||||||
status: 200,
|
status: 200,
|
||||||
headers: {
|
headers: {
|
||||||
"Content-Type": "text/event-stream",
|
"Content-Type": "text/event-stream",
|
||||||
"Cache-Control": "no-cache",
|
"Cache-Control": "no-cache, no-transform",
|
||||||
"Connection": "keep-alive",
|
"Connection": "keep-alive",
|
||||||
|
"X-Content-Type-Options": "nosniff",
|
||||||
"X-Accel-Buffering": "no",
|
"X-Accel-Buffering": "no",
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ export async function GET() {
|
|||||||
message: data.message,
|
message: data.message,
|
||||||
progress: data.progress ?? null,
|
progress: data.progress ?? null,
|
||||||
voices: data.voices ?? [],
|
voices: data.voices ?? [],
|
||||||
|
config: data.config ?? null,
|
||||||
},
|
},
|
||||||
COMMON_OPTIONS
|
COMMON_OPTIONS
|
||||||
);
|
);
|
||||||
|
|||||||
+3
-3
@@ -130,9 +130,9 @@ const initialState: AppState = {
|
|||||||
speaker: "carter",
|
speaker: "carter",
|
||||||
cfgScale: 1.5,
|
cfgScale: 1.5,
|
||||||
inferenceSteps: 10,
|
inferenceSteps: 10,
|
||||||
prebufferSecs: 2.0,
|
prebufferSecs: 5.0,
|
||||||
rebufferThresholdSecs: 0.4,
|
rebufferThresholdSecs: 1.0,
|
||||||
resumeThresholdSecs: 1.5,
|
resumeThresholdSecs: 3.0,
|
||||||
isGenerating: false,
|
isGenerating: false,
|
||||||
genElapsed: 0,
|
genElapsed: 0,
|
||||||
genPct: null,
|
genPct: null,
|
||||||
|
|||||||
@@ -3,9 +3,10 @@
|
|||||||
import { useCallback, useEffect, useRef, useState } from "react";
|
import { useCallback, useEffect, useRef, useState } from "react";
|
||||||
|
|
||||||
const SAMPLE_RATE = 24_000;
|
const SAMPLE_RATE = 24_000;
|
||||||
const DEFAULT_PREBUFFER_SECS = 2.0;
|
const DEFAULT_PREBUFFER_SECS = 5.0;
|
||||||
const DEFAULT_REBUFFER_THRESHOLD_SECS = 0.4;
|
const DEFAULT_REBUFFER_THRESHOLD_SECS = 1.0;
|
||||||
const DEFAULT_RESUME_THRESHOLD_SECS = 1.5;
|
const DEFAULT_RESUME_THRESHOLD_SECS = 3.0;
|
||||||
|
const MAX_ADAPTIVE_RESUME_SECS = 18.0;
|
||||||
|
|
||||||
interface GenerateOptions {
|
interface GenerateOptions {
|
||||||
text: string;
|
text: string;
|
||||||
@@ -104,6 +105,10 @@ export function useStreamingGeneration({
|
|||||||
const isAutoBufferingRef = useRef(false);
|
const isAutoBufferingRef = useRef(false);
|
||||||
const isUserPausedRef = useRef(false);
|
const isUserPausedRef = useRef(false);
|
||||||
const audioUrlRef = useRef<string | null>(null);
|
const audioUrlRef = useRef<string | null>(null);
|
||||||
|
const firstChunkSeenRef = useRef(false);
|
||||||
|
const underrunCountRef = useRef(0);
|
||||||
|
const totalAudioSamplesRef = useRef(0);
|
||||||
|
const adaptiveResumeSecsRef = useRef(DEFAULT_RESUME_THRESHOLD_SECS);
|
||||||
|
|
||||||
const revokeCurrentUrl = useCallback(() => {
|
const revokeCurrentUrl = useCallback(() => {
|
||||||
if (audioUrlRef.current) {
|
if (audioUrlRef.current) {
|
||||||
@@ -122,8 +127,12 @@ export function useStreamingGeneration({
|
|||||||
hasStartedPlaybackRef.current = false;
|
hasStartedPlaybackRef.current = false;
|
||||||
isAutoBufferingRef.current = false;
|
isAutoBufferingRef.current = false;
|
||||||
isUserPausedRef.current = false;
|
isUserPausedRef.current = false;
|
||||||
|
firstChunkSeenRef.current = false;
|
||||||
|
underrunCountRef.current = 0;
|
||||||
|
totalAudioSamplesRef.current = 0;
|
||||||
|
adaptiveResumeSecsRef.current = resumeThresholdSecs;
|
||||||
setIsStreamPaused(false);
|
setIsStreamPaused(false);
|
||||||
}, []);
|
}, [resumeThresholdSecs]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
return () => {
|
return () => {
|
||||||
@@ -158,10 +167,17 @@ export function useStreamingGeneration({
|
|||||||
if (!ctx) return;
|
if (!ctx) return;
|
||||||
|
|
||||||
chunksRef.current.push(chunk);
|
chunksRef.current.push(chunk);
|
||||||
|
totalAudioSamplesRef.current += chunk.length;
|
||||||
|
|
||||||
|
if (!firstChunkSeenRef.current) {
|
||||||
|
firstChunkSeenRef.current = true;
|
||||||
|
onLog("First audio chunk received");
|
||||||
|
}
|
||||||
|
|
||||||
if (!hasStartedPlaybackRef.current) {
|
if (!hasStartedPlaybackRef.current) {
|
||||||
const bufferedSecs = chunksRef.current.reduce((sum, c) => sum + c.length, 0) / SAMPLE_RATE;
|
const bufferedSecs = chunksRef.current.reduce((sum, c) => sum + c.length, 0) / SAMPLE_RATE;
|
||||||
if (bufferedSecs >= prebufferSecs) {
|
if (bufferedSecs >= prebufferSecs) {
|
||||||
|
onLog(`Playback started after ${bufferedSecs.toFixed(1)}s buffered`);
|
||||||
flushBufferedAudio();
|
flushBufferedAudio();
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
@@ -171,18 +187,30 @@ export function useStreamingGeneration({
|
|||||||
if (isUserPausedRef.current) return;
|
if (isUserPausedRef.current) return;
|
||||||
|
|
||||||
const ahead = nextStartTimeRef.current - ctx.currentTime;
|
const ahead = nextStartTimeRef.current - ctx.currentTime;
|
||||||
if (ctx.state === "running" && ahead < rebufferThresholdSecs) {
|
if (
|
||||||
ctx.suspend().catch(() => {});
|
ctx.state === "running" &&
|
||||||
isAutoBufferingRef.current = true;
|
!isAutoBufferingRef.current &&
|
||||||
} else if (
|
ahead < rebufferThresholdSecs
|
||||||
ctx.state === "suspended" &&
|
) {
|
||||||
isAutoBufferingRef.current &&
|
isAutoBufferingRef.current = true;
|
||||||
ahead >= resumeThresholdSecs
|
underrunCountRef.current += 1;
|
||||||
|
adaptiveResumeSecsRef.current = Math.min(
|
||||||
|
MAX_ADAPTIVE_RESUME_SECS,
|
||||||
|
Math.max(resumeThresholdSecs, prebufferSecs + underrunCountRef.current * 2),
|
||||||
|
);
|
||||||
|
ctx.suspend().catch(() => {});
|
||||||
|
onLog(
|
||||||
|
`Buffer underrun ${underrunCountRef.current}; refilling to ${adaptiveResumeSecsRef.current.toFixed(1)}s`,
|
||||||
|
);
|
||||||
|
} else if (
|
||||||
|
isAutoBufferingRef.current &&
|
||||||
|
ahead >= adaptiveResumeSecsRef.current
|
||||||
) {
|
) {
|
||||||
ctx.resume().catch(() => {});
|
|
||||||
isAutoBufferingRef.current = false;
|
isAutoBufferingRef.current = false;
|
||||||
|
ctx.resume().catch(() => {});
|
||||||
|
onLog(`Buffer recovered with ${ahead.toFixed(1)}s queued`);
|
||||||
}
|
}
|
||||||
}, [enqueue, flushBufferedAudio, prebufferSecs, rebufferThresholdSecs, resumeThresholdSecs]);
|
}, [enqueue, flushBufferedAudio, onLog, prebufferSecs, rebufferThresholdSecs, resumeThresholdSecs]);
|
||||||
|
|
||||||
const generate = useCallback(async (options: GenerateOptions) => {
|
const generate = useCallback(async (options: GenerateOptions) => {
|
||||||
if (!options.text.trim()) return;
|
if (!options.text.trim()) return;
|
||||||
@@ -239,6 +267,11 @@ export function useStreamingGeneration({
|
|||||||
type: "audio_chunk" | "complete" | "error" | "cancelled";
|
type: "audio_chunk" | "complete" | "error" | "cancelled";
|
||||||
data?: string;
|
data?: string;
|
||||||
elapsed?: number;
|
elapsed?: number;
|
||||||
|
audio_secs?: number;
|
||||||
|
realtime_factor?: number | null;
|
||||||
|
chunks?: number;
|
||||||
|
first_chunk_secs?: number | null;
|
||||||
|
max_chunk_gap_secs?: number;
|
||||||
message?: string;
|
message?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -247,12 +280,26 @@ export function useStreamingGeneration({
|
|||||||
} else if (event.type === "complete") {
|
} else if (event.type === "complete") {
|
||||||
if (!hasStartedPlaybackRef.current) {
|
if (!hasStartedPlaybackRef.current) {
|
||||||
flushBufferedAudio();
|
flushBufferedAudio();
|
||||||
|
} else if (isAutoBufferingRef.current) {
|
||||||
|
isAutoBufferingRef.current = false;
|
||||||
|
audioCtxRef.current?.resume().catch(() => {});
|
||||||
}
|
}
|
||||||
const wavBlob = buildWav(mergeFloat32Arrays(chunksRef.current), SAMPLE_RATE);
|
const wavBlob = buildWav(mergeFloat32Arrays(chunksRef.current), SAMPLE_RATE);
|
||||||
const audioUrl = URL.createObjectURL(wavBlob);
|
const audioUrl = URL.createObjectURL(wavBlob);
|
||||||
audioUrlRef.current = audioUrl;
|
audioUrlRef.current = audioUrl;
|
||||||
const kb = (wavBlob.size / 1024).toFixed(0);
|
const kb = (wavBlob.size / 1024).toFixed(0);
|
||||||
onLog(`Done in ${event.elapsed}s - ${kb} KB`);
|
const audioSecs = event.audio_secs ?? totalAudioSamplesRef.current / SAMPLE_RATE;
|
||||||
|
const realtimeFactor =
|
||||||
|
event.realtime_factor ??
|
||||||
|
(event.elapsed && event.elapsed > 0 ? audioSecs / event.elapsed : null);
|
||||||
|
const speedText =
|
||||||
|
realtimeFactor === null ? "" : ` - ${realtimeFactor.toFixed(2)}x realtime`;
|
||||||
|
onLog(`Done in ${event.elapsed}s - ${audioSecs.toFixed(1)}s audio${speedText} - ${kb} KB`);
|
||||||
|
if (event.chunks && event.first_chunk_secs !== undefined) {
|
||||||
|
onLog(
|
||||||
|
`Stream: first chunk ${event.first_chunk_secs}s, ${event.chunks} chunks, max gap ${event.max_chunk_gap_secs}s`,
|
||||||
|
);
|
||||||
|
}
|
||||||
onSuccess(audioUrl);
|
onSuccess(audioUrl);
|
||||||
} else if (event.type === "cancelled") {
|
} else if (event.type === "cancelled") {
|
||||||
throw new DOMException("Generation cancelled", "AbortError");
|
throw new DOMException("Generation cancelled", "AbortError");
|
||||||
|
|||||||
+1
-1
@@ -4,7 +4,7 @@
|
|||||||
"private": true,
|
"private": true,
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"dev": "next dev --turbopack",
|
"dev": "next dev --turbopack",
|
||||||
"build": "next build --turbopack",
|
"build": "next build",
|
||||||
"start": "next start"
|
"start": "next start"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
|||||||
Reference in New Issue
Block a user