mirror of
https://github.com/JezzWTF/vibepod.git
synced 2026-06-01 15:22:14 +00:00
Merge pull request #15 from JezzWTF/improve-docs-and-maintainer-notes-9165053560558121838
Improve codebase documentation and maintainer notes
This commit is contained in:
+161
-50
@@ -1,21 +1,38 @@
|
||||
"""
|
||||
VibePod — VibeVoice FastAPI TTS Server
|
||||
|
||||
Startup sequence (background thread):
|
||||
1. Download model weights if not cached -> status: downloading
|
||||
2. Download voice preset .pt files -> status: loading
|
||||
3. Load processor + model into memory -> status: loading
|
||||
4. Pre-load all voice tensors -> status: loading
|
||||
-> Server ready -> status: online
|
||||
This server provides a high-performance Text-to-Speech (TTS) interface for the VibeVoice model,
|
||||
optimized for real-time streaming on both CPU and NVIDIA GPU hardware.
|
||||
|
||||
Generation flow:
|
||||
POST /generate -> SSE stream of audio_chunk events (base64 float32 PCM),
|
||||
ends with {type:"complete"}
|
||||
MAINTAINER GUIDE / FILE MAP:
|
||||
- Device & Env Configuration: Helpers for hardware detection and runtime tuning via env vars.
|
||||
- Global State: Thread-safe storage for the loaded model, processor, and server status.
|
||||
- Background Model Loader: Logic for downloading weights and initializing the model with optimizations.
|
||||
- VibeVoice Patches: Performance-critical overrides for VibeVoice internals (hot-paths).
|
||||
- FastAPI Application: SSE-based generation endpoint and health/status polling.
|
||||
- Audio Streaming: Async bridge (NonBlockingAudioStreamer) between inference and the network.
|
||||
|
||||
Device selection:
|
||||
Set VIBEPOD_DEVICE=cpu to force CPU inference (e.g. via --cpu flag in start.sh).
|
||||
Set VIBEPOD_DEVICE=cuda to force CUDA (default when a GPU is available).
|
||||
If unset, the server auto-detects: CUDA if available, otherwise CPU.
|
||||
RUNTIME CONFIGURATION (Environment Variables):
|
||||
- VIBEPOD_DEVICE: 'cpu' or 'cuda' (auto-detected if unset).
|
||||
- VIBEPOD_CHUNK_ACCUM: Number of 20ms audio chunks to buffer before sending an SSE event (default: 4 for CPU).
|
||||
- VIBEPOD_PREBUFFER_SECS: Initial client-side buffer duration (hinted to frontend).
|
||||
- VIBEPOD_REBUFFER_THRESHOLD_SECS: Buffer level below which the client pauses to refill.
|
||||
- VIBEPOD_RESUME_THRESHOLD_SECS: Buffer level at which the client resumes playback.
|
||||
- VIBEPOD_DEFAULT_INFERENCE_STEPS: Default DDPM steps (default: 8 for CPU, 10 for CUDA).
|
||||
- VIBEPOD_PROFILE_GENERATION: Set to '1' to enable detailed performance logging.
|
||||
|
||||
CPU-SPECIFIC OPTIMIZATIONS:
|
||||
- VIBEPOD_CPU_THREADS: Number of intra-op threads (default: half the logical core count on Windows, all logical cores elsewhere).
|
||||
- VIBEPOD_CPU_INTEROP_THREADS: Number of inter-op threads (default: 1).
|
||||
- VIBEPOD_CPU_MKLDNN: Set to '0' to disable MKLDNN (default: 1).
|
||||
- VIBEPOD_CPU_BF16: Set to '1' to force bfloat16, '0' for float32. If unset, bfloat16 is used only when AVX512_BF16 support is detected.
|
||||
- VIBEPOD_ASYNC_DECODE: Set to '1' to overlap decoding with inference on a separate thread (default: 1).
|
||||
- VIBEPOD_QUANTIZE: Set to '1' to enable experimental dynamic INT8 quantization.
|
||||
- VIBEPOD_COMPILE: Set to '1' to enable experimental torch.compile (limited benefit for TTS).
|
||||
|
||||
CUDA-SPECIFIC OPTIMIZATIONS:
|
||||
- VIBEPOD_CUDA_DTYPE: 'bf16' (default) or 'fp16'.
|
||||
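- VIBEPOD_ATTN_IMPL: 'auto' (default), 'sdpa', 'eager', or 'flash_attention_2'. 'auto' selects flash_attention_2 on CUDA when flash_attn is installed, otherwise sdpa; an explicit 'flash_attention_2' also falls back to sdpa when flash_attn is missing.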
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -50,9 +67,7 @@ MODEL_ID = "microsoft/VibeVoice-Realtime-0.5B"
|
||||
SAMPLE_RATE = 24_000
|
||||
|
||||
VOICES_DIR = Path(__file__).parent / "voices" / "streaming_model"
|
||||
VOICE_BASE_URL = (
|
||||
"https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model"
|
||||
)
|
||||
VOICE_BASE_URL = "https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model"
|
||||
|
||||
EN_VOICES: dict[str, str] = {
|
||||
"carter": "en-Carter_man.pt",
|
||||
@@ -77,7 +92,10 @@ _decode_executor: concurrent.futures.ThreadPoolExecutor | None = None
|
||||
|
||||
|
||||
def _resolve_device() -> str:
|
||||
"""Resolve the target device from env var or auto-detect."""
|
||||
"""
|
||||
Resolve the target device (CPU or CUDA) by checking the VIBEPOD_DEVICE environment
|
||||
variable, falling back to CUDA if available, otherwise CPU.
|
||||
"""
|
||||
env = os.environ.get("VIBEPOD_DEVICE", "").strip().lower()
|
||||
if env in ("cpu", "cuda"):
|
||||
if env == "cuda" and not torch.cuda.is_available():
|
||||
@@ -94,6 +112,7 @@ def _resolve_device() -> str:
|
||||
|
||||
|
||||
def _env_int(name: str, default: int) -> int:
|
||||
"""Helper to read an integer environment variable with a fallback default."""
|
||||
raw = os.environ.get(name, "").strip()
|
||||
if not raw:
|
||||
return default
|
||||
@@ -105,6 +124,7 @@ def _env_int(name: str, default: int) -> int:
|
||||
|
||||
|
||||
def _env_float(name: str, default: float) -> float:
|
||||
"""Helper to read a float environment variable with a fallback default."""
|
||||
raw = os.environ.get(name, "").strip()
|
||||
if not raw:
|
||||
return default
|
||||
@@ -125,8 +145,14 @@ def _cpu_supports_bf16() -> bool:
|
||||
|
||||
|
||||
def _configure_cpu_runtime() -> dict[str, object]:
|
||||
"""
|
||||
Configure PyTorch's CPU execution engine, including thread counts and
|
||||
MKLDNN acceleration.
|
||||
"""
|
||||
logical_cpus = os.cpu_count() or 1
|
||||
default_threads = max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus
|
||||
default_threads = (
|
||||
max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus
|
||||
)
|
||||
intra_threads = _env_int("VIBEPOD_CPU_THREADS", default_threads)
|
||||
interop_threads = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1)
|
||||
mkldnn_enabled = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"
|
||||
@@ -202,7 +228,9 @@ def _is_model_cached() -> bool:
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
snapshot_download(MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS)
|
||||
snapshot_download(
|
||||
MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
@@ -211,7 +239,9 @@ def _is_model_cached() -> bool:
|
||||
def _download_model() -> None:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
|
||||
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get(
|
||||
"HUGGINGFACE_TOKEN"
|
||||
)
|
||||
DlTqdm = _make_dl_tqdm()
|
||||
logger.info("Model not cached — downloading %s...", MODEL_ID)
|
||||
snapshot_download(
|
||||
@@ -238,6 +268,9 @@ def _download_voices() -> None:
|
||||
|
||||
|
||||
def _init_processor():
|
||||
"""
|
||||
Initialize the VibeVoiceStreamingProcessor from the model repository.
|
||||
"""
|
||||
logger.info("Loading processor...")
|
||||
from vibevoice.processor.vibevoice_streaming_processor import (
|
||||
VibeVoiceStreamingProcessor,
|
||||
@@ -247,6 +280,10 @@ def _init_processor():
|
||||
|
||||
|
||||
def _init_model(device: str):
|
||||
"""
|
||||
Load the VibeVoice model with appropriate precision (BF16/FP16/FP32) and
|
||||
apply VibePod-specific performance optimizations.
|
||||
"""
|
||||
logger.info("Loading model on %s...", device)
|
||||
if device == "cuda":
|
||||
torch.set_float32_matmul_precision("high")
|
||||
@@ -285,7 +322,9 @@ def _init_model(device: str):
|
||||
logger.info("AVX512_BF16 detected — loading model in bfloat16")
|
||||
else:
|
||||
load_dtype = torch.float32
|
||||
logger.info("No AVX512_BF16 — using float32 (set VIBEPOD_CPU_BF16=1 to override)")
|
||||
logger.info(
|
||||
"No AVX512_BF16 — using float32 (set VIBEPOD_CPU_BF16=1 to override)"
|
||||
)
|
||||
logger.info("Loading model weights with dtype %s", load_dtype)
|
||||
requested_attn_impl = os.environ.get("VIBEPOD_ATTN_IMPL", "auto").lower()
|
||||
has_flash_attn = importlib.util.find_spec("flash_attn") is not None
|
||||
@@ -294,7 +333,9 @@ def _init_model(device: str):
|
||||
elif requested_attn_impl == "flash_attention_2":
|
||||
attn_impl = "flash_attention_2" if has_flash_attn else "sdpa"
|
||||
else:
|
||||
attn_impl = "flash_attention_2" if device == "cuda" and has_flash_attn else "sdpa"
|
||||
attn_impl = (
|
||||
"flash_attention_2" if device == "cuda" and has_flash_attn else "sdpa"
|
||||
)
|
||||
logger.info("Using Transformers attention implementation: %s", attn_impl)
|
||||
if device == "cuda" and not has_flash_attn:
|
||||
logger.info("flash_attn is not installed; using PyTorch SDPA attention.")
|
||||
@@ -338,7 +379,10 @@ def _init_model(device: str):
|
||||
|
||||
|
||||
def _apply_cpu_optimizations(model: object) -> object:
|
||||
"""Apply optional post-load CPU optimizations. Returns (possibly new) model object."""
|
||||
"""
|
||||
Apply experimental CPU performance features like dynamic INT8 quantization
|
||||
or torch.compile if enabled via environment variables.
|
||||
"""
|
||||
|
||||
do_quantize = os.environ.get("VIBEPOD_QUANTIZE", "0") == "1"
|
||||
do_compile = os.environ.get("VIBEPOD_COMPILE", "0") == "1"
|
||||
@@ -387,10 +431,19 @@ def _apply_cpu_optimizations(model: object) -> object:
|
||||
if hasattr(model, "model"):
|
||||
inner = model.model
|
||||
if hasattr(inner, "prediction_head"):
|
||||
_compile_targets.append(("prediction_head", inner, "prediction_head", False))
|
||||
if hasattr(inner, "acoustic_tokenizer") and hasattr(inner.acoustic_tokenizer, "decode"):
|
||||
_compile_targets.append(
|
||||
("acoustic_tokenizer.decode", inner.acoustic_tokenizer, "decode", False)
|
||||
("prediction_head", inner, "prediction_head", False)
|
||||
)
|
||||
if hasattr(inner, "acoustic_tokenizer") and hasattr(
|
||||
inner.acoustic_tokenizer, "decode"
|
||||
):
|
||||
_compile_targets.append(
|
||||
(
|
||||
"acoustic_tokenizer.decode",
|
||||
inner.acoustic_tokenizer,
|
||||
"decode",
|
||||
False,
|
||||
)
|
||||
)
|
||||
|
||||
for label, obj, attr, dynamic in _compile_targets:
|
||||
@@ -404,13 +457,20 @@ def _apply_cpu_optimizations(model: object) -> object:
|
||||
setattr(obj, attr, compiled)
|
||||
logger.info(" compiled: %s", label)
|
||||
except Exception as exc:
|
||||
logger.warning(" torch.compile failed for %s: %s — skipping", label, exc)
|
||||
logger.warning(
|
||||
" torch.compile failed for %s: %s — skipping", label, exc
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def _install_generation_optimizations(model: object) -> None:
|
||||
"""Patch VibeVoice hot paths without changing model quality settings."""
|
||||
"""
|
||||
VibePod Optimization Patch:
|
||||
Replaces performance-critical VibeVoice methods with optimized versions.
|
||||
Includes caching for the noise scheduler and noise tensors to avoid
|
||||
re-allocation overhead during the diffusion loop.
|
||||
"""
|
||||
|
||||
def profile_enabled() -> bool:
|
||||
return os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1"
|
||||
@@ -483,7 +543,9 @@ def _install_generation_optimizations(model: object) -> None:
|
||||
t_batches = t_batch_cache.get(t_batch_cache_key)
|
||||
if t_batches is None or len(t_batches) != len(scheduler.timesteps):
|
||||
t_batches = [
|
||||
t.repeat(condition.shape[0]).to(device=condition.device, dtype=condition.dtype)
|
||||
t.repeat(condition.shape[0]).to(
|
||||
device=condition.device, dtype=condition.dtype
|
||||
)
|
||||
for t in scheduler.timesteps
|
||||
]
|
||||
t_batch_cache[t_batch_cache_key] = t_batches
|
||||
@@ -500,14 +562,18 @@ def _install_generation_optimizations(model: object) -> None:
|
||||
eps = self.model.prediction_head(combined, t_batch, condition=condition)
|
||||
if profile_enabled():
|
||||
profile_sync()
|
||||
profile_record(self, "diffusion_prediction_head", time.perf_counter() - started)
|
||||
profile_record(
|
||||
self, "diffusion_prediction_head", time.perf_counter() - started
|
||||
)
|
||||
cond_eps, uncond_eps = torch.split(eps, batch_size, dim=0)
|
||||
guided_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
|
||||
if profile_enabled():
|
||||
started = time.perf_counter()
|
||||
speech = scheduler.step(guided_eps, t, speech).prev_sample
|
||||
if profile_enabled():
|
||||
profile_record(self, "diffusion_scheduler_step", time.perf_counter() - started)
|
||||
profile_record(
|
||||
self, "diffusion_scheduler_step", time.perf_counter() - started
|
||||
)
|
||||
|
||||
return speech
|
||||
|
||||
@@ -532,7 +598,13 @@ def _install_generation_optimizations(model: object) -> None:
|
||||
|
||||
|
||||
def _install_cpu_pipeline_optimizations(model: object) -> None:
|
||||
"""Attach the decode executor to the model for the optimised generate() loop.
|
||||
"""
|
||||
VibePod Optimization Patch:
|
||||
Enables asynchronous audio decoding on a background thread.
|
||||
|
||||
This allows the acoustic_decode (Vocoder) step to run in parallel with
|
||||
the next chunk's forward_tts_lm (Inference) step, significantly reducing
|
||||
the real-time factor on CPU.
|
||||
|
||||
The JezzWTF/VibeVoice fork's generate() checks for two optional attributes:
|
||||
|
||||
@@ -587,12 +659,19 @@ def _move_cached_prompt(value: object, device: str, dtype: torch.dtype) -> objec
|
||||
if isinstance(value, tuple):
|
||||
return tuple(_move_cached_prompt(v, device, dtype) for v in value)
|
||||
if hasattr(value, "key_cache") and hasattr(value, "value_cache"):
|
||||
value.key_cache = [_move_cached_prompt(t, device, dtype) for t in value.key_cache]
|
||||
value.value_cache = [_move_cached_prompt(t, device, dtype) for t in value.value_cache]
|
||||
value.key_cache = [
|
||||
_move_cached_prompt(t, device, dtype) for t in value.key_cache
|
||||
]
|
||||
value.value_cache = [
|
||||
_move_cached_prompt(t, device, dtype) for t in value.value_cache
|
||||
]
|
||||
return value
|
||||
|
||||
|
||||
def _load_voice_presets(device: str) -> dict[str, object]:
|
||||
"""
|
||||
Load all pre-downloaded voice tensor files (.pt) from the voices directory.
|
||||
"""
|
||||
presets = {}
|
||||
for name, filename in EN_VOICES.items():
|
||||
path = VOICES_DIR / filename
|
||||
@@ -602,6 +681,11 @@ def _load_voice_presets(device: str) -> dict[str, object]:
|
||||
|
||||
|
||||
def _load_model_sync() -> None:
|
||||
"""
|
||||
Main synchronous initialization routine. Handles model/voice downloads,
|
||||
device configuration, and model loading. Updates global status for the
|
||||
health endpoint.
|
||||
"""
|
||||
global _processor, _model, _device, _model_status, _model_error, _voice_presets, _config
|
||||
|
||||
with _load_lock:
|
||||
@@ -640,17 +724,27 @@ def _load_model_sync() -> None:
|
||||
logical_cpus = os.cpu_count() or 1
|
||||
_config["cpu_threads"] = _env_int(
|
||||
"VIBEPOD_CPU_THREADS",
|
||||
max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus,
|
||||
(
|
||||
max(1, logical_cpus // 2)
|
||||
if platform.system() == "Windows"
|
||||
else logical_cpus
|
||||
),
|
||||
)
|
||||
_config["cpu_interop_threads"] = _env_int(
|
||||
"VIBEPOD_CPU_INTEROP_THREADS", 1
|
||||
)
|
||||
_config["cpu_mkldnn"] = (
|
||||
os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"
|
||||
)
|
||||
_config["cpu_interop_threads"] = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1)
|
||||
_config["cpu_mkldnn"] = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"
|
||||
|
||||
_processor = _init_processor()
|
||||
_model = _init_model(_device)
|
||||
_voice_presets = _load_voice_presets(_device)
|
||||
|
||||
_model_status = "online"
|
||||
logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys()))
|
||||
logger.info(
|
||||
"Model ready on %s. Voices: %s", _device, list(_voice_presets.keys())
|
||||
)
|
||||
logger.info("Configuration: %s", _config)
|
||||
|
||||
except Exception as exc:
|
||||
@@ -723,16 +817,21 @@ def _sync_generate(
|
||||
streamer: object | None = None,
|
||||
cancel_event: threading.Event | None = None,
|
||||
) -> str:
|
||||
"""Blocking inference. Returns the speaker used.
|
||||
Runs in a thread-pool executor — do not call from the event loop directly.
|
||||
Pass an AsyncAudioStreamer to receive audio chunks in real time.
|
||||
"""
|
||||
Performs blocking model inference for TTS generation.
|
||||
|
||||
This function should always be run in a thread-pool executor to avoid
|
||||
blocking the FastAPI event loop. It streams audio chunks back to the
|
||||
caller via the provided streamer object.
|
||||
"""
|
||||
if cancel_event and cancel_event.is_set():
|
||||
raise RuntimeError("Generation cancelled.")
|
||||
|
||||
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
|
||||
model_dtype = _model_float_dtype()
|
||||
voice_preset = _move_cached_prompt(copy.deepcopy(_voice_presets[speaker]), _device, model_dtype)
|
||||
voice_preset = _move_cached_prompt(
|
||||
copy.deepcopy(_voice_presets[speaker]), _device, model_dtype
|
||||
)
|
||||
|
||||
steps = (
|
||||
req.inference_steps
|
||||
@@ -786,7 +885,11 @@ def _generation_profile() -> dict[str, dict[str, float]] | None:
|
||||
key: {
|
||||
"count": value["count"],
|
||||
"seconds": round(value["seconds"], 3),
|
||||
"avg_ms": round(value["seconds"] * 1000 / value["count"], 3) if value["count"] else 0.0,
|
||||
"avg_ms": (
|
||||
round(value["seconds"] * 1000 / value["count"], 3)
|
||||
if value["count"]
|
||||
else 0.0
|
||||
),
|
||||
}
|
||||
for key, value in sorted(stats.items())
|
||||
}
|
||||
@@ -818,7 +921,9 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
self.finished_flags = [False for _ in range(batch_size)]
|
||||
self.loop = asyncio.get_running_loop()
|
||||
|
||||
def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor) -> None:
|
||||
def put(
|
||||
self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor
|
||||
) -> None:
|
||||
for i, sample_idx in enumerate(sample_indices):
|
||||
idx = sample_idx.item()
|
||||
if idx < self.batch_size and not self.finished_flags[idx]:
|
||||
@@ -831,7 +936,9 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
if sample_indices is None:
|
||||
indices_to_end = range(self.batch_size)
|
||||
else:
|
||||
indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices]
|
||||
indices_to_end = [
|
||||
s.item() if torch.is_tensor(s) else s for s in sample_indices
|
||||
]
|
||||
for idx in indices_to_end:
|
||||
if idx < self.batch_size and not self.finished_flags[idx]:
|
||||
self.loop.call_soon_threadsafe(
|
||||
@@ -863,7 +970,9 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
# stop_signal=None is the default sentinel that ends the queue.
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(streamer.audio_queues[0].get(), timeout=120.0)
|
||||
chunk = await asyncio.wait_for(
|
||||
streamer.audio_queues[0].get(), timeout=120.0
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
cancel_event.set()
|
||||
future.cancel()
|
||||
@@ -949,11 +1058,13 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
"elapsed": elapsed,
|
||||
"speaker": speaker,
|
||||
"audio_secs": round(audio_secs, 2),
|
||||
"realtime_factor": round(realtime_factor, 3) if realtime_factor is not None else None,
|
||||
"realtime_factor": (
|
||||
round(realtime_factor, 3) if realtime_factor is not None else None
|
||||
),
|
||||
"chunks": chunk_count,
|
||||
"first_chunk_secs": round(first_chunk_at - start, 2)
|
||||
if first_chunk_at is not None
|
||||
else None,
|
||||
"first_chunk_secs": (
|
||||
round(first_chunk_at - start, 2) if first_chunk_at is not None else None
|
||||
),
|
||||
"max_chunk_gap_secs": round(max_chunk_gap, 2),
|
||||
}
|
||||
if profile is not None:
|
||||
|
||||
@@ -1,3 +1,14 @@
|
||||
/**
|
||||
* API Proxy Route: POST /api/generate
|
||||
*
|
||||
* This route proxies requests from the frontend to the FastAPI backend's /generate endpoint.
|
||||
*
|
||||
* Security Architecture:
|
||||
* The FastAPI backend is configured to bind only to localhost (127.0.0.1). This prevents
|
||||
* unauthenticated public access to the model inference engine. Next.js acts as a secure
|
||||
* proxy, allowing the frontend to interact with the backend while maintaining a
|
||||
* single public-facing origin.
|
||||
*/
|
||||
import { NextRequest, NextResponse } from "next/server";
|
||||
|
||||
export const dynamic = "force-dynamic";
|
||||
|
||||
@@ -1,3 +1,14 @@
|
||||
/**
|
||||
* API Proxy Route: GET /api/health
|
||||
*
|
||||
* This route proxies health check requests from the frontend to the FastAPI backend's /health endpoint.
|
||||
*
|
||||
* Security Architecture:
|
||||
* The FastAPI backend is configured to bind only to localhost (127.0.0.1). This prevents
|
||||
* unauthenticated public access to the server status and configuration. Next.js acts as a secure
|
||||
* proxy, allowing the frontend to poll for server readiness and adaptive configuration
|
||||
* while maintaining a single public-facing origin.
|
||||
*/
|
||||
import { NextResponse } from "next/server";
|
||||
|
||||
const OFFLINE_RESPONSE = { status: "offline" };
|
||||
|
||||
@@ -24,6 +24,8 @@ export interface ServerConfig {
|
||||
default_inference_steps: number;
|
||||
}
|
||||
|
||||
// --- State Management ---
|
||||
|
||||
interface AppState {
|
||||
script: string;
|
||||
speaker: string;
|
||||
@@ -199,6 +201,8 @@ export default function HomePage() {
|
||||
resumeThresholdSecs: state.resumeThresholdSecs,
|
||||
});
|
||||
|
||||
// --- Server Health & Status Polling ---
|
||||
|
||||
// Server health polling — fast while not ready, slow when online
|
||||
useEffect(() => {
|
||||
let timeoutId: ReturnType<typeof setTimeout>;
|
||||
@@ -246,6 +250,8 @@ export default function HomePage() {
|
||||
};
|
||||
}, []);
|
||||
|
||||
// --- Generation Handling ---
|
||||
|
||||
const handleGenerate = useCallback(async () => {
|
||||
if (!state.script.trim() || state.isGenerating) return;
|
||||
addLog(`${wordCount} words queued`);
|
||||
|
||||
@@ -1,5 +1,16 @@
|
||||
"use client";
|
||||
|
||||
/**
|
||||
* Hook for managing real-time streaming audio generation from the VibeVoice server.
|
||||
*
|
||||
* Streaming Lifecycle:
|
||||
* 1. fetch /api/generate: Initiates a POST request to the generation endpoint.
|
||||
* 2. parse SSE chunks: Listens for Server-Sent Events (SSE) containing audio data or status updates.
|
||||
* 3. decode base64 float32 PCM: Converts incoming base64-encoded strings into raw Float32 audio samples.
|
||||
* 4. schedule Web Audio playback: Enqueues audio chunks into an AudioContext for low-latency playback.
|
||||
* 5. handle adaptive buffering: Monitors playback progress and pauses to refill the buffer if an underrun is detected.
|
||||
* 6. assemble final WAV Blob: Combines all received chunks into a single WAV file once generation is complete.
|
||||
*/
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
|
||||
const SAMPLE_RATE = 24_000;
|
||||
@@ -30,6 +41,9 @@ interface UseStreamingGenerationOptions {
|
||||
resumeThresholdSecs?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Concatenates multiple Float32Array chunks into a single Float32Array.
|
||||
*/
|
||||
function mergeFloat32Arrays(chunks: Float32Array<ArrayBuffer>[]): Float32Array<ArrayBuffer> {
|
||||
const total = chunks.reduce((sum, chunk) => sum + chunk.length, 0);
|
||||
const out = new Float32Array(total);
|
||||
@@ -41,6 +55,9 @@ function mergeFloat32Arrays(chunks: Float32Array<ArrayBuffer>[]): Float32Array<A
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
* Wraps Float32 PCM samples into a WAV file Blob with a standard header.
|
||||
*/
|
||||
function buildWav(samples: Float32Array<ArrayBuffer>, sampleRate: number): Blob {
|
||||
const dataSize = samples.length * 4;
|
||||
const buffer = new ArrayBuffer(44 + dataSize);
|
||||
@@ -68,6 +85,9 @@ function buildWav(samples: Float32Array<ArrayBuffer>, sampleRate: number): Blob
|
||||
return new Blob([buffer], { type: "audio/wav" });
|
||||
}
|
||||
|
||||
/**
|
||||
* Decodes a base64-encoded string into a Float32Array of PCM samples.
|
||||
*/
|
||||
function decodeFloat32Chunk(data: string): Float32Array<ArrayBuffer> {
|
||||
const raw = atob(data);
|
||||
const bytes = new Uint8Array(raw.length);
|
||||
@@ -141,6 +161,9 @@ export function useStreamingGeneration({
|
||||
};
|
||||
}, [resetPlayback, revokeCurrentUrl]);
|
||||
|
||||
/**
|
||||
* Creates an AudioBuffer from a chunk and schedules it for playback in the AudioContext.
|
||||
*/
|
||||
const enqueue = useCallback((ctx: AudioContext, chunk: Float32Array<ArrayBuffer>) => {
|
||||
const audioBuffer = ctx.createBuffer(1, chunk.length, SAMPLE_RATE);
|
||||
audioBuffer.copyToChannel(chunk, 0);
|
||||
@@ -152,6 +175,9 @@ export function useStreamingGeneration({
|
||||
nextStartTimeRef.current = startAt + audioBuffer.duration;
|
||||
}, []);
|
||||
|
||||
/**
|
||||
* Resets the playback timing and enqueues all currently buffered chunks for immediate playback.
|
||||
*/
|
||||
const flushBufferedAudio = useCallback(() => {
|
||||
const ctx = audioCtxRef.current;
|
||||
if (!ctx || chunksRef.current.length === 0) return;
|
||||
@@ -162,6 +188,10 @@ export function useStreamingGeneration({
|
||||
hasStartedPlaybackRef.current = true;
|
||||
}, [enqueue]);
|
||||
|
||||
/**
|
||||
* Processes a new audio chunk, either buffering it for initial playback or enqueuing it for
|
||||
* immediate playback with adaptive buffering logic.
|
||||
*/
|
||||
const handleAudioChunk = useCallback(
|
||||
(chunk: Float32Array<ArrayBuffer>) => {
|
||||
const ctx = audioCtxRef.current;
|
||||
|
||||
Reference in New Issue
Block a user