mirror of
https://github.com/JezzWTF/vibepod.git
synced 2026-06-01 15:22:14 +00:00
Merge pull request #15 from JezzWTF/improve-docs-and-maintainer-notes-9165053560558121838
Improve codebase documentation and maintainer notes
This commit is contained in:
+161
-50
@@ -1,21 +1,38 @@
|
|||||||
"""
|
"""
|
||||||
VibePod — VibeVoice FastAPI TTS Server
|
VibePod — VibeVoice FastAPI TTS Server
|
||||||
|
|
||||||
Startup sequence (background thread):
|
This server provides a high-performance Text-to-Speech (TTS) interface for the VibeVoice model,
|
||||||
1. Download model weights if not cached -> status: downloading
|
optimized for real-time streaming on both CPU and NVIDIA GPU hardware.
|
||||||
2. Download voice preset .pt files -> status: loading
|
|
||||||
3. Load processor + model into memory -> status: loading
|
|
||||||
4. Pre-load all voice tensors -> status: loading
|
|
||||||
-> Server ready -> status: online
|
|
||||||
|
|
||||||
Generation flow:
|
MAINTAINER GUIDE / FILE MAP:
|
||||||
POST /generate -> SSE stream of audio_chunk events (base64 float32 PCM),
|
- Device & Env Configuration: Helpers for hardware detection and runtime tuning via env vars.
|
||||||
ends with {type:"complete"}
|
- Global State: Thread-safe storage for the loaded model, processor, and server status.
|
||||||
|
- Background Model Loader: Logic for downloading weights and initializing the model with optimizations.
|
||||||
|
- VibeVoice Patches: Performance-critical overrides for VibeVoice internals (hot-paths).
|
||||||
|
- FastAPI Application: SSE-based generation endpoint and health/status polling.
|
||||||
|
- Audio Streaming: Async bridge (NonBlockingAudioStreamer) between inference and the network.
|
||||||
|
|
||||||
Device selection:
|
RUNTIME CONFIGURATION (Environment Variables):
|
||||||
Set VIBEPOD_DEVICE=cpu to force CPU inference (e.g. via --cpu flag in start.sh).
|
- VIBEPOD_DEVICE: 'cpu' or 'cuda' (auto-detected if unset).
|
||||||
Set VIBEPOD_DEVICE=cuda to force CUDA (default when a GPU is available).
|
- VIBEPOD_CHUNK_ACCUM: Number of 20ms audio chunks to buffer before sending an SSE event (default: 4 for CPU).
|
||||||
If unset, the server auto-detects: CUDA if available, otherwise CPU.
|
- VIBEPOD_PREBUFFER_SECS: Initial client-side buffer duration (hinted to frontend).
|
||||||
|
- VIBEPOD_REBUFFER_THRESHOLD_SECS: Buffer level below which the client pauses to refill.
|
||||||
|
- VIBEPOD_RESUME_THRESHOLD_SECS: Buffer level at which the client resumes playback.
|
||||||
|
- VIBEPOD_DEFAULT_INFERENCE_STEPS: Default DDPM steps (default: 8 for CPU, 10 for CUDA).
|
||||||
|
- VIBEPOD_PROFILE_GENERATION: Set to '1' to enable detailed performance logging.
|
||||||
|
|
||||||
|
CPU-SPECIFIC OPTIMIZATIONS:
|
||||||
|
- VIBEPOD_CPU_THREADS: Number of intra-op threads (defaults to the logical core count; half of it, minimum 1, on Windows).
|
||||||
|
- VIBEPOD_CPU_INTEROP_THREADS: Number of inter-op threads (default: 1).
|
||||||
|
- VIBEPOD_CPU_MKLDNN: Set to '0' to disable MKLDNN (default: 1).
|
||||||
|
- VIBEPOD_CPU_BF16: Set to '1' to force bfloat16, '0' for float32.
|
||||||
|
- VIBEPOD_ASYNC_DECODE: Set to '1' to overlap decoding with inference on a separate thread (default: 1).
|
||||||
|
- VIBEPOD_QUANTIZE: Set to '1' to enable experimental dynamic INT8 quantization.
|
||||||
|
- VIBEPOD_COMPILE: Set to '1' to enable experimental torch.compile (limited benefit for TTS).
|
||||||
|
|
||||||
|
CUDA-SPECIFIC OPTIMIZATIONS:
|
||||||
|
- VIBEPOD_CUDA_DTYPE: 'bf16' (default) or 'fp16'.
|
||||||
|
- VIBEPOD_ATTN_IMPL: 'auto', 'sdpa', 'eager', or 'flash_attention_2'.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -50,9 +67,7 @@ MODEL_ID = "microsoft/VibeVoice-Realtime-0.5B"
|
|||||||
SAMPLE_RATE = 24_000
|
SAMPLE_RATE = 24_000
|
||||||
|
|
||||||
VOICES_DIR = Path(__file__).parent / "voices" / "streaming_model"
|
VOICES_DIR = Path(__file__).parent / "voices" / "streaming_model"
|
||||||
VOICE_BASE_URL = (
|
VOICE_BASE_URL = "https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model"
|
||||||
"https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model"
|
|
||||||
)
|
|
||||||
|
|
||||||
EN_VOICES: dict[str, str] = {
|
EN_VOICES: dict[str, str] = {
|
||||||
"carter": "en-Carter_man.pt",
|
"carter": "en-Carter_man.pt",
|
||||||
@@ -77,7 +92,10 @@ _decode_executor: concurrent.futures.ThreadPoolExecutor | None = None
|
|||||||
|
|
||||||
|
|
||||||
def _resolve_device() -> str:
|
def _resolve_device() -> str:
|
||||||
"""Resolve the target device from env var or auto-detect."""
|
"""
|
||||||
|
Resolve the target device (CPU or CUDA) by checking the VIBEPOD_DEVICE environment
|
||||||
|
variable, falling back to CUDA if available, otherwise CPU.
|
||||||
|
"""
|
||||||
env = os.environ.get("VIBEPOD_DEVICE", "").strip().lower()
|
env = os.environ.get("VIBEPOD_DEVICE", "").strip().lower()
|
||||||
if env in ("cpu", "cuda"):
|
if env in ("cpu", "cuda"):
|
||||||
if env == "cuda" and not torch.cuda.is_available():
|
if env == "cuda" and not torch.cuda.is_available():
|
||||||
@@ -94,6 +112,7 @@ def _resolve_device() -> str:
|
|||||||
|
|
||||||
|
|
||||||
def _env_int(name: str, default: int) -> int:
|
def _env_int(name: str, default: int) -> int:
|
||||||
|
"""Helper to read an integer environment variable with a fallback default."""
|
||||||
raw = os.environ.get(name, "").strip()
|
raw = os.environ.get(name, "").strip()
|
||||||
if not raw:
|
if not raw:
|
||||||
return default
|
return default
|
||||||
@@ -105,6 +124,7 @@ def _env_int(name: str, default: int) -> int:
|
|||||||
|
|
||||||
|
|
||||||
def _env_float(name: str, default: float) -> float:
|
def _env_float(name: str, default: float) -> float:
|
||||||
|
"""Helper to read a float environment variable with a fallback default."""
|
||||||
raw = os.environ.get(name, "").strip()
|
raw = os.environ.get(name, "").strip()
|
||||||
if not raw:
|
if not raw:
|
||||||
return default
|
return default
|
||||||
@@ -125,8 +145,14 @@ def _cpu_supports_bf16() -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def _configure_cpu_runtime() -> dict[str, object]:
|
def _configure_cpu_runtime() -> dict[str, object]:
|
||||||
|
"""
|
||||||
|
Configure PyTorch's CPU execution engine, including thread counts and
|
||||||
|
MKLDNN acceleration.
|
||||||
|
"""
|
||||||
logical_cpus = os.cpu_count() or 1
|
logical_cpus = os.cpu_count() or 1
|
||||||
default_threads = max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus
|
default_threads = (
|
||||||
|
max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus
|
||||||
|
)
|
||||||
intra_threads = _env_int("VIBEPOD_CPU_THREADS", default_threads)
|
intra_threads = _env_int("VIBEPOD_CPU_THREADS", default_threads)
|
||||||
interop_threads = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1)
|
interop_threads = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1)
|
||||||
mkldnn_enabled = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"
|
mkldnn_enabled = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"
|
||||||
@@ -202,7 +228,9 @@ def _is_model_cached() -> bool:
|
|||||||
try:
|
try:
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
snapshot_download(MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS)
|
snapshot_download(
|
||||||
|
MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
@@ -211,7 +239,9 @@ def _is_model_cached() -> bool:
|
|||||||
def _download_model() -> None:
|
def _download_model() -> None:
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
|
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get(
|
||||||
|
"HUGGINGFACE_TOKEN"
|
||||||
|
)
|
||||||
DlTqdm = _make_dl_tqdm()
|
DlTqdm = _make_dl_tqdm()
|
||||||
logger.info("Model not cached — downloading %s...", MODEL_ID)
|
logger.info("Model not cached — downloading %s...", MODEL_ID)
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
@@ -238,6 +268,9 @@ def _download_voices() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def _init_processor():
|
def _init_processor():
|
||||||
|
"""
|
||||||
|
Initialize the VibeVoiceStreamingProcessor from the model repository.
|
||||||
|
"""
|
||||||
logger.info("Loading processor...")
|
logger.info("Loading processor...")
|
||||||
from vibevoice.processor.vibevoice_streaming_processor import (
|
from vibevoice.processor.vibevoice_streaming_processor import (
|
||||||
VibeVoiceStreamingProcessor,
|
VibeVoiceStreamingProcessor,
|
||||||
@@ -247,6 +280,10 @@ def _init_processor():
|
|||||||
|
|
||||||
|
|
||||||
def _init_model(device: str):
|
def _init_model(device: str):
|
||||||
|
"""
|
||||||
|
Load the VibeVoice model with appropriate precision (BF16/FP16/FP32) and
|
||||||
|
apply VibePod-specific performance optimizations.
|
||||||
|
"""
|
||||||
logger.info("Loading model on %s...", device)
|
logger.info("Loading model on %s...", device)
|
||||||
if device == "cuda":
|
if device == "cuda":
|
||||||
torch.set_float32_matmul_precision("high")
|
torch.set_float32_matmul_precision("high")
|
||||||
@@ -285,7 +322,9 @@ def _init_model(device: str):
|
|||||||
logger.info("AVX512_BF16 detected — loading model in bfloat16")
|
logger.info("AVX512_BF16 detected — loading model in bfloat16")
|
||||||
else:
|
else:
|
||||||
load_dtype = torch.float32
|
load_dtype = torch.float32
|
||||||
logger.info("No AVX512_BF16 — using float32 (set VIBEPOD_CPU_BF16=1 to override)")
|
logger.info(
|
||||||
|
"No AVX512_BF16 — using float32 (set VIBEPOD_CPU_BF16=1 to override)"
|
||||||
|
)
|
||||||
logger.info("Loading model weights with dtype %s", load_dtype)
|
logger.info("Loading model weights with dtype %s", load_dtype)
|
||||||
requested_attn_impl = os.environ.get("VIBEPOD_ATTN_IMPL", "auto").lower()
|
requested_attn_impl = os.environ.get("VIBEPOD_ATTN_IMPL", "auto").lower()
|
||||||
has_flash_attn = importlib.util.find_spec("flash_attn") is not None
|
has_flash_attn = importlib.util.find_spec("flash_attn") is not None
|
||||||
@@ -294,7 +333,9 @@ def _init_model(device: str):
|
|||||||
elif requested_attn_impl == "flash_attention_2":
|
elif requested_attn_impl == "flash_attention_2":
|
||||||
attn_impl = "flash_attention_2" if has_flash_attn else "sdpa"
|
attn_impl = "flash_attention_2" if has_flash_attn else "sdpa"
|
||||||
else:
|
else:
|
||||||
attn_impl = "flash_attention_2" if device == "cuda" and has_flash_attn else "sdpa"
|
attn_impl = (
|
||||||
|
"flash_attention_2" if device == "cuda" and has_flash_attn else "sdpa"
|
||||||
|
)
|
||||||
logger.info("Using Transformers attention implementation: %s", attn_impl)
|
logger.info("Using Transformers attention implementation: %s", attn_impl)
|
||||||
if device == "cuda" and not has_flash_attn:
|
if device == "cuda" and not has_flash_attn:
|
||||||
logger.info("flash_attn is not installed; using PyTorch SDPA attention.")
|
logger.info("flash_attn is not installed; using PyTorch SDPA attention.")
|
||||||
@@ -338,7 +379,10 @@ def _init_model(device: str):
|
|||||||
|
|
||||||
|
|
||||||
def _apply_cpu_optimizations(model: object) -> object:
|
def _apply_cpu_optimizations(model: object) -> object:
|
||||||
"""Apply optional post-load CPU optimizations. Returns (possibly new) model object."""
|
"""
|
||||||
|
Apply experimental CPU performance features like dynamic INT8 quantization
|
||||||
|
or torch.compile if enabled via environment variables.
|
||||||
|
"""
|
||||||
|
|
||||||
do_quantize = os.environ.get("VIBEPOD_QUANTIZE", "0") == "1"
|
do_quantize = os.environ.get("VIBEPOD_QUANTIZE", "0") == "1"
|
||||||
do_compile = os.environ.get("VIBEPOD_COMPILE", "0") == "1"
|
do_compile = os.environ.get("VIBEPOD_COMPILE", "0") == "1"
|
||||||
@@ -387,10 +431,19 @@ def _apply_cpu_optimizations(model: object) -> object:
|
|||||||
if hasattr(model, "model"):
|
if hasattr(model, "model"):
|
||||||
inner = model.model
|
inner = model.model
|
||||||
if hasattr(inner, "prediction_head"):
|
if hasattr(inner, "prediction_head"):
|
||||||
_compile_targets.append(("prediction_head", inner, "prediction_head", False))
|
|
||||||
if hasattr(inner, "acoustic_tokenizer") and hasattr(inner.acoustic_tokenizer, "decode"):
|
|
||||||
_compile_targets.append(
|
_compile_targets.append(
|
||||||
("acoustic_tokenizer.decode", inner.acoustic_tokenizer, "decode", False)
|
("prediction_head", inner, "prediction_head", False)
|
||||||
|
)
|
||||||
|
if hasattr(inner, "acoustic_tokenizer") and hasattr(
|
||||||
|
inner.acoustic_tokenizer, "decode"
|
||||||
|
):
|
||||||
|
_compile_targets.append(
|
||||||
|
(
|
||||||
|
"acoustic_tokenizer.decode",
|
||||||
|
inner.acoustic_tokenizer,
|
||||||
|
"decode",
|
||||||
|
False,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
for label, obj, attr, dynamic in _compile_targets:
|
for label, obj, attr, dynamic in _compile_targets:
|
||||||
@@ -404,13 +457,20 @@ def _apply_cpu_optimizations(model: object) -> object:
|
|||||||
setattr(obj, attr, compiled)
|
setattr(obj, attr, compiled)
|
||||||
logger.info(" compiled: %s", label)
|
logger.info(" compiled: %s", label)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning(" torch.compile failed for %s: %s — skipping", label, exc)
|
logger.warning(
|
||||||
|
" torch.compile failed for %s: %s — skipping", label, exc
|
||||||
|
)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def _install_generation_optimizations(model: object) -> None:
|
def _install_generation_optimizations(model: object) -> None:
|
||||||
"""Patch VibeVoice hot paths without changing model quality settings."""
|
"""
|
||||||
|
VibePod Optimization Patch:
|
||||||
|
Replaces performance-critical VibeVoice methods with optimized versions.
|
||||||
|
Includes caching for the noise scheduler and noise tensors to avoid
|
||||||
|
re-allocation overhead during the diffusion loop.
|
||||||
|
"""
|
||||||
|
|
||||||
def profile_enabled() -> bool:
|
def profile_enabled() -> bool:
|
||||||
return os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1"
|
return os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1"
|
||||||
@@ -483,7 +543,9 @@ def _install_generation_optimizations(model: object) -> None:
|
|||||||
t_batches = t_batch_cache.get(t_batch_cache_key)
|
t_batches = t_batch_cache.get(t_batch_cache_key)
|
||||||
if t_batches is None or len(t_batches) != len(scheduler.timesteps):
|
if t_batches is None or len(t_batches) != len(scheduler.timesteps):
|
||||||
t_batches = [
|
t_batches = [
|
||||||
t.repeat(condition.shape[0]).to(device=condition.device, dtype=condition.dtype)
|
t.repeat(condition.shape[0]).to(
|
||||||
|
device=condition.device, dtype=condition.dtype
|
||||||
|
)
|
||||||
for t in scheduler.timesteps
|
for t in scheduler.timesteps
|
||||||
]
|
]
|
||||||
t_batch_cache[t_batch_cache_key] = t_batches
|
t_batch_cache[t_batch_cache_key] = t_batches
|
||||||
@@ -500,14 +562,18 @@ def _install_generation_optimizations(model: object) -> None:
|
|||||||
eps = self.model.prediction_head(combined, t_batch, condition=condition)
|
eps = self.model.prediction_head(combined, t_batch, condition=condition)
|
||||||
if profile_enabled():
|
if profile_enabled():
|
||||||
profile_sync()
|
profile_sync()
|
||||||
profile_record(self, "diffusion_prediction_head", time.perf_counter() - started)
|
profile_record(
|
||||||
|
self, "diffusion_prediction_head", time.perf_counter() - started
|
||||||
|
)
|
||||||
cond_eps, uncond_eps = torch.split(eps, batch_size, dim=0)
|
cond_eps, uncond_eps = torch.split(eps, batch_size, dim=0)
|
||||||
guided_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
|
guided_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
|
||||||
if profile_enabled():
|
if profile_enabled():
|
||||||
started = time.perf_counter()
|
started = time.perf_counter()
|
||||||
speech = scheduler.step(guided_eps, t, speech).prev_sample
|
speech = scheduler.step(guided_eps, t, speech).prev_sample
|
||||||
if profile_enabled():
|
if profile_enabled():
|
||||||
profile_record(self, "diffusion_scheduler_step", time.perf_counter() - started)
|
profile_record(
|
||||||
|
self, "diffusion_scheduler_step", time.perf_counter() - started
|
||||||
|
)
|
||||||
|
|
||||||
return speech
|
return speech
|
||||||
|
|
||||||
@@ -532,7 +598,13 @@ def _install_generation_optimizations(model: object) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def _install_cpu_pipeline_optimizations(model: object) -> None:
|
def _install_cpu_pipeline_optimizations(model: object) -> None:
|
||||||
"""Attach the decode executor to the model for the optimised generate() loop.
|
"""
|
||||||
|
VibePod Optimization Patch:
|
||||||
|
Enables asynchronous audio decoding on a background thread.
|
||||||
|
|
||||||
|
This allows the acoustic_decode (Vocoder) step to run in parallel with
|
||||||
|
the next chunk's forward_tts_lm (Inference) step, significantly reducing
|
||||||
|
the real-time factor on CPU.
|
||||||
|
|
||||||
The JezzWTF/VibeVoice fork's generate() checks for two optional attributes:
|
The JezzWTF/VibeVoice fork's generate() checks for two optional attributes:
|
||||||
|
|
||||||
@@ -587,12 +659,19 @@ def _move_cached_prompt(value: object, device: str, dtype: torch.dtype) -> objec
|
|||||||
if isinstance(value, tuple):
|
if isinstance(value, tuple):
|
||||||
return tuple(_move_cached_prompt(v, device, dtype) for v in value)
|
return tuple(_move_cached_prompt(v, device, dtype) for v in value)
|
||||||
if hasattr(value, "key_cache") and hasattr(value, "value_cache"):
|
if hasattr(value, "key_cache") and hasattr(value, "value_cache"):
|
||||||
value.key_cache = [_move_cached_prompt(t, device, dtype) for t in value.key_cache]
|
value.key_cache = [
|
||||||
value.value_cache = [_move_cached_prompt(t, device, dtype) for t in value.value_cache]
|
_move_cached_prompt(t, device, dtype) for t in value.key_cache
|
||||||
|
]
|
||||||
|
value.value_cache = [
|
||||||
|
_move_cached_prompt(t, device, dtype) for t in value.value_cache
|
||||||
|
]
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
def _load_voice_presets(device: str) -> dict[str, object]:
|
def _load_voice_presets(device: str) -> dict[str, object]:
|
||||||
|
"""
|
||||||
|
Load all pre-downloaded voice tensor files (.pt) from the voices directory.
|
||||||
|
"""
|
||||||
presets = {}
|
presets = {}
|
||||||
for name, filename in EN_VOICES.items():
|
for name, filename in EN_VOICES.items():
|
||||||
path = VOICES_DIR / filename
|
path = VOICES_DIR / filename
|
||||||
@@ -602,6 +681,11 @@ def _load_voice_presets(device: str) -> dict[str, object]:
|
|||||||
|
|
||||||
|
|
||||||
def _load_model_sync() -> None:
|
def _load_model_sync() -> None:
|
||||||
|
"""
|
||||||
|
Main synchronous initialization routine. Handles model/voice downloads,
|
||||||
|
device configuration, and model loading. Updates global status for the
|
||||||
|
health endpoint.
|
||||||
|
"""
|
||||||
global _processor, _model, _device, _model_status, _model_error, _voice_presets, _config
|
global _processor, _model, _device, _model_status, _model_error, _voice_presets, _config
|
||||||
|
|
||||||
with _load_lock:
|
with _load_lock:
|
||||||
@@ -640,17 +724,27 @@ def _load_model_sync() -> None:
|
|||||||
logical_cpus = os.cpu_count() or 1
|
logical_cpus = os.cpu_count() or 1
|
||||||
_config["cpu_threads"] = _env_int(
|
_config["cpu_threads"] = _env_int(
|
||||||
"VIBEPOD_CPU_THREADS",
|
"VIBEPOD_CPU_THREADS",
|
||||||
max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus,
|
(
|
||||||
|
max(1, logical_cpus // 2)
|
||||||
|
if platform.system() == "Windows"
|
||||||
|
else logical_cpus
|
||||||
|
),
|
||||||
|
)
|
||||||
|
_config["cpu_interop_threads"] = _env_int(
|
||||||
|
"VIBEPOD_CPU_INTEROP_THREADS", 1
|
||||||
|
)
|
||||||
|
_config["cpu_mkldnn"] = (
|
||||||
|
os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"
|
||||||
)
|
)
|
||||||
_config["cpu_interop_threads"] = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1)
|
|
||||||
_config["cpu_mkldnn"] = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"
|
|
||||||
|
|
||||||
_processor = _init_processor()
|
_processor = _init_processor()
|
||||||
_model = _init_model(_device)
|
_model = _init_model(_device)
|
||||||
_voice_presets = _load_voice_presets(_device)
|
_voice_presets = _load_voice_presets(_device)
|
||||||
|
|
||||||
_model_status = "online"
|
_model_status = "online"
|
||||||
logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys()))
|
logger.info(
|
||||||
|
"Model ready on %s. Voices: %s", _device, list(_voice_presets.keys())
|
||||||
|
)
|
||||||
logger.info("Configuration: %s", _config)
|
logger.info("Configuration: %s", _config)
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
@@ -723,16 +817,21 @@ def _sync_generate(
|
|||||||
streamer: object | None = None,
|
streamer: object | None = None,
|
||||||
cancel_event: threading.Event | None = None,
|
cancel_event: threading.Event | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Blocking inference. Returns the speaker used.
|
"""
|
||||||
Runs in a thread-pool executor — do not call from the event loop directly.
|
Performs blocking model inference for TTS generation.
|
||||||
Pass an AsyncAudioStreamer to receive audio chunks in real time.
|
|
||||||
|
This function should always be run in a thread-pool executor to avoid
|
||||||
|
blocking the FastAPI event loop. It streams audio chunks back to the
|
||||||
|
caller via the provided streamer object.
|
||||||
"""
|
"""
|
||||||
if cancel_event and cancel_event.is_set():
|
if cancel_event and cancel_event.is_set():
|
||||||
raise RuntimeError("Generation cancelled.")
|
raise RuntimeError("Generation cancelled.")
|
||||||
|
|
||||||
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
|
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
|
||||||
model_dtype = _model_float_dtype()
|
model_dtype = _model_float_dtype()
|
||||||
voice_preset = _move_cached_prompt(copy.deepcopy(_voice_presets[speaker]), _device, model_dtype)
|
voice_preset = _move_cached_prompt(
|
||||||
|
copy.deepcopy(_voice_presets[speaker]), _device, model_dtype
|
||||||
|
)
|
||||||
|
|
||||||
steps = (
|
steps = (
|
||||||
req.inference_steps
|
req.inference_steps
|
||||||
@@ -786,7 +885,11 @@ def _generation_profile() -> dict[str, dict[str, float]] | None:
|
|||||||
key: {
|
key: {
|
||||||
"count": value["count"],
|
"count": value["count"],
|
||||||
"seconds": round(value["seconds"], 3),
|
"seconds": round(value["seconds"], 3),
|
||||||
"avg_ms": round(value["seconds"] * 1000 / value["count"], 3) if value["count"] else 0.0,
|
"avg_ms": (
|
||||||
|
round(value["seconds"] * 1000 / value["count"], 3)
|
||||||
|
if value["count"]
|
||||||
|
else 0.0
|
||||||
|
),
|
||||||
}
|
}
|
||||||
for key, value in sorted(stats.items())
|
for key, value in sorted(stats.items())
|
||||||
}
|
}
|
||||||
@@ -818,7 +921,9 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
|||||||
self.finished_flags = [False for _ in range(batch_size)]
|
self.finished_flags = [False for _ in range(batch_size)]
|
||||||
self.loop = asyncio.get_running_loop()
|
self.loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor) -> None:
|
def put(
|
||||||
|
self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor
|
||||||
|
) -> None:
|
||||||
for i, sample_idx in enumerate(sample_indices):
|
for i, sample_idx in enumerate(sample_indices):
|
||||||
idx = sample_idx.item()
|
idx = sample_idx.item()
|
||||||
if idx < self.batch_size and not self.finished_flags[idx]:
|
if idx < self.batch_size and not self.finished_flags[idx]:
|
||||||
@@ -831,7 +936,9 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
|||||||
if sample_indices is None:
|
if sample_indices is None:
|
||||||
indices_to_end = range(self.batch_size)
|
indices_to_end = range(self.batch_size)
|
||||||
else:
|
else:
|
||||||
indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices]
|
indices_to_end = [
|
||||||
|
s.item() if torch.is_tensor(s) else s for s in sample_indices
|
||||||
|
]
|
||||||
for idx in indices_to_end:
|
for idx in indices_to_end:
|
||||||
if idx < self.batch_size and not self.finished_flags[idx]:
|
if idx < self.batch_size and not self.finished_flags[idx]:
|
||||||
self.loop.call_soon_threadsafe(
|
self.loop.call_soon_threadsafe(
|
||||||
@@ -863,7 +970,9 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
|||||||
# stop_signal=None is the default sentinel that ends the queue.
|
# stop_signal=None is the default sentinel that ends the queue.
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
chunk = await asyncio.wait_for(streamer.audio_queues[0].get(), timeout=120.0)
|
chunk = await asyncio.wait_for(
|
||||||
|
streamer.audio_queues[0].get(), timeout=120.0
|
||||||
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
cancel_event.set()
|
cancel_event.set()
|
||||||
future.cancel()
|
future.cancel()
|
||||||
@@ -949,11 +1058,13 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
|||||||
"elapsed": elapsed,
|
"elapsed": elapsed,
|
||||||
"speaker": speaker,
|
"speaker": speaker,
|
||||||
"audio_secs": round(audio_secs, 2),
|
"audio_secs": round(audio_secs, 2),
|
||||||
"realtime_factor": round(realtime_factor, 3) if realtime_factor is not None else None,
|
"realtime_factor": (
|
||||||
|
round(realtime_factor, 3) if realtime_factor is not None else None
|
||||||
|
),
|
||||||
"chunks": chunk_count,
|
"chunks": chunk_count,
|
||||||
"first_chunk_secs": round(first_chunk_at - start, 2)
|
"first_chunk_secs": (
|
||||||
if first_chunk_at is not None
|
round(first_chunk_at - start, 2) if first_chunk_at is not None else None
|
||||||
else None,
|
),
|
||||||
"max_chunk_gap_secs": round(max_chunk_gap, 2),
|
"max_chunk_gap_secs": round(max_chunk_gap, 2),
|
||||||
}
|
}
|
||||||
if profile is not None:
|
if profile is not None:
|
||||||
|
|||||||
@@ -1,3 +1,14 @@
|
|||||||
|
/**
|
||||||
|
* API Proxy Route: POST /api/generate
|
||||||
|
*
|
||||||
|
* This route proxies requests from the frontend to the FastAPI backend's /generate endpoint.
|
||||||
|
*
|
||||||
|
* Security Architecture:
|
||||||
|
* The FastAPI backend is configured to bind only to localhost (127.0.0.1). This prevents
|
||||||
|
* unauthenticated public access to the model inference engine. Next.js acts as a secure
|
||||||
|
* proxy, allowing the frontend to interact with the backend while maintaining a
|
||||||
|
* single public-facing origin.
|
||||||
|
*/
|
||||||
import { NextRequest, NextResponse } from "next/server";
|
import { NextRequest, NextResponse } from "next/server";
|
||||||
|
|
||||||
export const dynamic = "force-dynamic";
|
export const dynamic = "force-dynamic";
|
||||||
|
|||||||
@@ -1,3 +1,14 @@
|
|||||||
|
/**
|
||||||
|
* API Proxy Route: GET /api/health
|
||||||
|
*
|
||||||
|
* This route proxies health check requests from the frontend to the FastAPI backend's /health endpoint.
|
||||||
|
*
|
||||||
|
* Security Architecture:
|
||||||
|
* The FastAPI backend is configured to bind only to localhost (127.0.0.1). This prevents
|
||||||
|
* unauthenticated public access to the server status and configuration. Next.js acts as a secure
|
||||||
|
* proxy, allowing the frontend to poll for server readiness and adaptive configuration
|
||||||
|
* while maintaining a single public-facing origin.
|
||||||
|
*/
|
||||||
import { NextResponse } from "next/server";
|
import { NextResponse } from "next/server";
|
||||||
|
|
||||||
const OFFLINE_RESPONSE = { status: "offline" };
|
const OFFLINE_RESPONSE = { status: "offline" };
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ export interface ServerConfig {
|
|||||||
default_inference_steps: number;
|
default_inference_steps: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- State Management ---
|
||||||
|
|
||||||
interface AppState {
|
interface AppState {
|
||||||
script: string;
|
script: string;
|
||||||
speaker: string;
|
speaker: string;
|
||||||
@@ -199,6 +201,8 @@ export default function HomePage() {
|
|||||||
resumeThresholdSecs: state.resumeThresholdSecs,
|
resumeThresholdSecs: state.resumeThresholdSecs,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// --- Server Health & Status Polling ---
|
||||||
|
|
||||||
// Server health polling — fast while not ready, slow when online
|
// Server health polling — fast while not ready, slow when online
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
let timeoutId: ReturnType<typeof setTimeout>;
|
let timeoutId: ReturnType<typeof setTimeout>;
|
||||||
@@ -246,6 +250,8 @@ export default function HomePage() {
|
|||||||
};
|
};
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
// --- Generation Handling ---
|
||||||
|
|
||||||
const handleGenerate = useCallback(async () => {
|
const handleGenerate = useCallback(async () => {
|
||||||
if (!state.script.trim() || state.isGenerating) return;
|
if (!state.script.trim() || state.isGenerating) return;
|
||||||
addLog(`${wordCount} words queued`);
|
addLog(`${wordCount} words queued`);
|
||||||
|
|||||||
@@ -1,5 +1,16 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Hook for managing real-time streaming audio generation from the VibeVoice server.
|
||||||
|
*
|
||||||
|
* Streaming Lifecycle:
|
||||||
|
* 1. fetch /api/generate: Initiates a POST request to the generation endpoint.
|
||||||
|
* 2. parse SSE chunks: Listens for Server-Sent Events (SSE) containing audio data or status updates.
|
||||||
|
* 3. decode base64 float32 PCM: Converts incoming base64-encoded strings into raw Float32 audio samples.
|
||||||
|
* 4. schedule Web Audio playback: Enqueues audio chunks into an AudioContext for low-latency playback.
|
||||||
|
* 5. handle adaptive buffering: Monitors playback progress and pauses to refill the buffer if an underrun is detected.
|
||||||
|
* 6. assemble final WAV Blob: Combines all received chunks into a single WAV file once generation is complete.
|
||||||
|
*/
|
||||||
import { useCallback, useEffect, useRef, useState } from "react";
|
import { useCallback, useEffect, useRef, useState } from "react";
|
||||||
|
|
||||||
const SAMPLE_RATE = 24_000;
|
const SAMPLE_RATE = 24_000;
|
||||||
@@ -30,6 +41,9 @@ interface UseStreamingGenerationOptions {
|
|||||||
resumeThresholdSecs?: number;
|
resumeThresholdSecs?: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Concatenates multiple Float32Array chunks into a single Float32Array.
|
||||||
|
*/
|
||||||
function mergeFloat32Arrays(chunks: Float32Array<ArrayBuffer>[]): Float32Array<ArrayBuffer> {
|
function mergeFloat32Arrays(chunks: Float32Array<ArrayBuffer>[]): Float32Array<ArrayBuffer> {
|
||||||
const total = chunks.reduce((sum, chunk) => sum + chunk.length, 0);
|
const total = chunks.reduce((sum, chunk) => sum + chunk.length, 0);
|
||||||
const out = new Float32Array(total);
|
const out = new Float32Array(total);
|
||||||
@@ -41,6 +55,9 @@ function mergeFloat32Arrays(chunks: Float32Array<ArrayBuffer>[]): Float32Array<A
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wraps Float32 PCM samples into a WAV file Blob with a standard header.
|
||||||
|
*/
|
||||||
function buildWav(samples: Float32Array<ArrayBuffer>, sampleRate: number): Blob {
|
function buildWav(samples: Float32Array<ArrayBuffer>, sampleRate: number): Blob {
|
||||||
const dataSize = samples.length * 4;
|
const dataSize = samples.length * 4;
|
||||||
const buffer = new ArrayBuffer(44 + dataSize);
|
const buffer = new ArrayBuffer(44 + dataSize);
|
||||||
@@ -68,6 +85,9 @@ function buildWav(samples: Float32Array<ArrayBuffer>, sampleRate: number): Blob
|
|||||||
return new Blob([buffer], { type: "audio/wav" });
|
return new Blob([buffer], { type: "audio/wav" });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Decodes a base64-encoded string into a Float32Array of PCM samples.
|
||||||
|
*/
|
||||||
function decodeFloat32Chunk(data: string): Float32Array<ArrayBuffer> {
|
function decodeFloat32Chunk(data: string): Float32Array<ArrayBuffer> {
|
||||||
const raw = atob(data);
|
const raw = atob(data);
|
||||||
const bytes = new Uint8Array(raw.length);
|
const bytes = new Uint8Array(raw.length);
|
||||||
@@ -141,6 +161,9 @@ export function useStreamingGeneration({
|
|||||||
};
|
};
|
||||||
}, [resetPlayback, revokeCurrentUrl]);
|
}, [resetPlayback, revokeCurrentUrl]);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an AudioBuffer from a chunk and schedules it for playback in the AudioContext.
|
||||||
|
*/
|
||||||
const enqueue = useCallback((ctx: AudioContext, chunk: Float32Array<ArrayBuffer>) => {
|
const enqueue = useCallback((ctx: AudioContext, chunk: Float32Array<ArrayBuffer>) => {
|
||||||
const audioBuffer = ctx.createBuffer(1, chunk.length, SAMPLE_RATE);
|
const audioBuffer = ctx.createBuffer(1, chunk.length, SAMPLE_RATE);
|
||||||
audioBuffer.copyToChannel(chunk, 0);
|
audioBuffer.copyToChannel(chunk, 0);
|
||||||
@@ -152,6 +175,9 @@ export function useStreamingGeneration({
|
|||||||
nextStartTimeRef.current = startAt + audioBuffer.duration;
|
nextStartTimeRef.current = startAt + audioBuffer.duration;
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Resets the playback timing and enqueues all currently buffered chunks for immediate playback.
|
||||||
|
*/
|
||||||
const flushBufferedAudio = useCallback(() => {
|
const flushBufferedAudio = useCallback(() => {
|
||||||
const ctx = audioCtxRef.current;
|
const ctx = audioCtxRef.current;
|
||||||
if (!ctx || chunksRef.current.length === 0) return;
|
if (!ctx || chunksRef.current.length === 0) return;
|
||||||
@@ -162,6 +188,10 @@ export function useStreamingGeneration({
|
|||||||
hasStartedPlaybackRef.current = true;
|
hasStartedPlaybackRef.current = true;
|
||||||
}, [enqueue]);
|
}, [enqueue]);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Processes a new audio chunk, either buffering it for initial playback or enqueuing it for
|
||||||
|
* immediate playback with adaptive buffering logic.
|
||||||
|
*/
|
||||||
const handleAudioChunk = useCallback(
|
const handleAudioChunk = useCallback(
|
||||||
(chunk: Float32Array<ArrayBuffer>) => {
|
(chunk: Float32Array<ArrayBuffer>) => {
|
||||||
const ctx = audioCtxRef.current;
|
const ctx = audioCtxRef.current;
|
||||||
|
|||||||
Reference in New Issue
Block a user