Improve CPU Inference Stability: Adaptive Buffering & Chunk Accumulation (#11)

* Improve CPU Inference Stability: Implement Adaptive Buffering and Chunk Accumulation

This change addresses audio stuttering issues when running on CPU-only hardware by:
- Implementing server-side audio chunk accumulation to reduce SSE overhead.
- Introducing device-aware default configurations for buffering and inference steps.
- Exposing key performance parameters as environment variables.
- Enabling the frontend to adaptively adjust its buffering thresholds based on the server's configuration.

Changes:
- Modified `server/vibevoice_server.py` to support accumulation and provide config via `/health`.
- Updated `web/hooks/useStreamingGeneration.ts` to accept configurable buffering parameters.
- Updated `web/app/page.tsx` to fetch and apply server-side configuration.

Verified on CPU mode in the development environment.

Co-authored-by: LyAhn <27559362+LyAhn@users.noreply.github.com>

* Improve CPU Inference Stability: Implement Adaptive Buffering and Chunk Accumulation

This change addresses audio stuttering issues when running on CPU-only hardware by:
- Implementing server-side audio chunk accumulation to reduce SSE overhead.
- Introducing device-aware default configurations for buffering and inference steps.
- Exposing key performance parameters as environment variables.
- Enabling the frontend to adaptively adjust its buffering thresholds based on the server's configuration.

Changes:
- Modified `server/vibevoice_server.py` to support accumulation and provide config via `/health`.
- Updated `web/hooks/useStreamingGeneration.ts` to accept configurable buffering parameters.
- Updated `web/app/page.tsx` to fetch and apply server-side configuration.

Verified on CPU mode in the development environment.

Co-authored-by: LyAhn <27559362+LyAhn@users.noreply.github.com>

* Improve CPU Inference Stability: Adaptive Buffering UI & Logic

This change enhances the initial CPU stability fix by:
- Exposing adaptive buffering settings (Pre-buffer, Re-buffer Threshold, Resume Threshold) in a new "Advanced Buffering" UI section.
- Managing buffering settings in the application state to allow for manual overrides.
- Implementing robust re-initialization of buffering and inference defaults whenever the server's device (CPU/CUDA) changes.
- Including the active device in the server's config object for reliable client-side detection.

Verified with frontend screenshots and full build. Responds to PR feedback regarding actioning the adaptive logic.

Co-authored-by: LyAhn <27559362+LyAhn@users.noreply.github.com>

* Refine adaptive buffering: env helpers, threshold validation, a11y fixes

- Extract _env_int/_env_float helpers in server to validate env-var config
  with graceful fallback instead of bare int/float casts
- Fix inference_steps falsy-check (0 is valid) to use explicit None guard
- Enforce rebufferThresholdSecs < resumeThresholdSecs in both the hook
  (with console.warn + clamp) and the GenerationControls UI (sliders block
  invalid states by auto-bumping or ignoring the drag)
- Add type="button", aria-expanded, aria-controls, htmlFor, and input id
  attributes to GenerationControls for accessibility
- Add .vscode/settings.json to .gitignore; sort package.json scripts

---------

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
This commit is contained in:
2026-04-30 16:03:35 +01:00
committed by GitHub
parent 87185e6289
commit a39ec536fd
7 changed files with 339 additions and 49 deletions
+1
View File
@@ -23,3 +23,4 @@ web/node_modules/
# OS
.DS_Store
Thumbs.db
.vscode/settings.json
+2 -2
View File
@@ -3,12 +3,12 @@
"version": "0.1.0",
"private": true,
"scripts": {
"build": "pnpm --filter vibepod-web build",
"dev": "bash dev.sh",
"dev:cpu": "bash dev.sh --cpu",
"dev:server": "bash server/start.sh",
"dev:server:cpu": "bash server/start.sh --cpu",
"dev:web": "pnpm --filter vibepod-web dev",
"build": "pnpm --filter vibepod-web build"
"dev:web": "pnpm --filter vibepod-web dev"
},
"packageManager": "pnpm@10.33.2+sha512.a90faf6feeab71ad6c6e57f94e0fe1a12f5dcc22cd754db40ae9593eb6a3e0b6b12e3540218bb37ae083404b1f2ce6db2a4121e979829b4aff94b99f49da1cf8"
}
+3 -1
View File
@@ -36,7 +36,9 @@ def download() -> str:
)
sys.exit(1)
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get(
"HUGGINGFACE_TOKEN"
)
print(f"Checking / downloading model: {MODEL_ID}")
print("(This may take several minutes on first run — the model is ~1 GB)")
+111 -25
View File
@@ -52,11 +52,11 @@ VOICE_BASE_URL = (
EN_VOICES: dict[str, str] = {
"carter": "en-Carter_man.pt",
"davis": "en-Davis_man.pt",
"emma": "en-Emma_woman.pt",
"frank": "en-Frank_man.pt",
"grace": "en-Grace_woman.pt",
"mike": "en-Mike_man.pt",
"davis": "en-Davis_man.pt",
"emma": "en-Emma_woman.pt",
"frank": "en-Frank_man.pt",
"grace": "en-Grace_woman.pt",
"mike": "en-Mike_man.pt",
}
DEFAULT_SPEAKER = "carter"
@@ -66,6 +66,7 @@ _IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.o
# VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag.
# Falls back to auto-detection if not set.
def _resolve_device() -> str:
"""Resolve the target device from env var or auto-detect."""
env = os.environ.get("VIBEPOD_DEVICE", "").strip().lower()
@@ -80,6 +81,31 @@ def _resolve_device() -> str:
return "cuda" if torch.cuda.is_available() else "cpu"
# ── Env-var helpers ─────────────────────────────────────────────────────────────
def _env_int(name: str, default: int) -> int:
raw = os.environ.get(name, "").strip()
if not raw:
return default
try:
return int(raw)
except ValueError:
logger.warning("Invalid value for %s=%r — using default %d", name, raw, default)
return default
def _env_float(name: str, default: float) -> float:
raw = os.environ.get(name, "").strip()
if not raw:
return default
try:
return float(raw)
except ValueError:
logger.warning("Invalid value for %s=%r — using default %g", name, raw, default)
return default
# ── Global state ────────────────────────────────────────────────────────────────
ModelStatus = Literal["downloading", "loading", "online", "error"]
@@ -93,13 +119,24 @@ _voice_presets: dict[str, object] = {}
_load_lock = threading.Lock()
_generation_lock = asyncio.Lock()
# Config defaults (can be overridden by env vars)
# These are populated in _load_model_sync once the device is known.
_config = {
"device": "cpu",
"chunk_accum": 1,
"prebuffer_secs": 2.0,
"rebuffer_threshold_secs": 0.4,
"resume_threshold_secs": 1.5,
"default_inference_steps": 10,
}
# Download progress (files downloaded so far)
_dl_progress: dict[str, int] = {"done": 0, "total": 0}
# ── Progress-tracking tqdm (for model file downloads) ──────────────────────────
def _make_dl_tqdm() -> type:
class _DlTqdm(_BaseTqdm):
def __init__(self, *args: object, **kwargs: object) -> None:
@@ -119,10 +156,14 @@ def _make_dl_tqdm() -> type:
# ── Model / voice helpers ───────────────────────────────────────────────────────
def _is_model_cached() -> bool:
try:
from huggingface_hub import snapshot_download
snapshot_download(MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS)
snapshot_download(
MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS
)
return True
except Exception:
return False
@@ -130,7 +171,10 @@ def _is_model_cached() -> bool:
def _download_model() -> None:
from huggingface_hub import snapshot_download
token: Optional[str] = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
token: Optional[str] = os.environ.get("HF_TOKEN") or os.environ.get(
"HUGGINGFACE_TOKEN"
)
DlTqdm = _make_dl_tqdm()
logger.info("Model not cached — downloading %s...", MODEL_ID)
snapshot_download(
@@ -155,11 +199,13 @@ def _download_voices() -> None:
# ── Background model loader ─────────────────────────────────────────────────────
def _init_processor():
logger.info("Loading processor...")
from vibevoice.processor.vibevoice_streaming_processor import (
VibeVoiceStreamingProcessor,
)
return VibeVoiceStreamingProcessor.from_pretrained(MODEL_ID)
@@ -171,6 +217,7 @@ def _init_model(device: str):
from vibevoice.modular.modeling_vibevoice_streaming_inference import (
VibeVoiceStreamingForConditionalGenerationInference,
)
try:
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
MODEL_ID,
@@ -179,7 +226,9 @@ def _init_model(device: str):
attn_implementation=attn_impl,
)
except Exception:
logger.warning("Model load with %s failed; falling back to sdpa", attn_impl, exc_info=True)
logger.warning(
"Model load with %s failed; falling back to sdpa", attn_impl, exc_info=True
)
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
MODEL_ID,
torch_dtype=load_dtype,
@@ -188,7 +237,7 @@ def _init_model(device: str):
)
model.eval()
model.set_ddpm_inference_steps(num_steps=10)
model.set_ddpm_inference_steps(num_steps=_config["default_inference_steps"])
return model
@@ -197,14 +246,12 @@ def _load_voice_presets(device: str) -> dict[str, object]:
for name, filename in EN_VOICES.items():
path = VOICES_DIR / filename
if path.exists():
presets[name] = torch.load(
path, map_location=device, weights_only=False
)
presets[name] = torch.load(path, map_location=device, weights_only=False)
return presets
def _load_model_sync() -> None:
global _processor, _model, _device, _model_status, _model_error, _voice_presets
global _processor, _model, _device, _model_status, _model_error, _voice_presets, _config
with _load_lock:
if _model is not None:
@@ -222,12 +269,24 @@ def _load_model_sync() -> None:
_device = _resolve_device()
logger.info("Using device: %s", _device)
# Populate config based on device
is_cpu = _device == "cpu"
_config["device"] = _device
_config["chunk_accum"] = _env_int("VIBEPOD_CHUNK_ACCUM", 4 if is_cpu else 1)
_config["prebuffer_secs"] = _env_float("VIBEPOD_PREBUFFER_SECS", 5.0 if is_cpu else 2.0)
_config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 1.0 if is_cpu else 0.4)
_config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 2.5 if is_cpu else 1.5)
_config["default_inference_steps"] = _env_int("VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10)
_processor = _init_processor()
_model = _init_model(_device)
_voice_presets = _load_voice_presets(_device)
_model_status = "online"
logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys()))
logger.info(
"Model ready on %s. Voices: %s", _device, list(_voice_presets.keys())
)
logger.info("Configuration: %s", _config)
except Exception as exc:
_model_status = "error"
@@ -237,6 +296,7 @@ def _load_model_sync() -> None:
# ── FastAPI app ─────────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
thread = threading.Thread(target=_load_model_sync, daemon=True, name="model-loader")
@@ -249,11 +309,12 @@ app = FastAPI(title="VibePod TTS Server", version="0.1.0", lifespan=lifespan)
# ── Schemas ─────────────────────────────────────────────────────────────────────
class GenerateRequest(BaseModel):
text: str = Field(..., min_length=1, max_length=10_000)
speaker: str = Field(default=DEFAULT_SPEAKER)
cfg_scale: float = Field(default=1.5, ge=0.5, le=4.0)
inference_steps: int = Field(default=10, ge=5, le=20)
inference_steps: Optional[int] = Field(default=None, ge=5, le=20)
@field_validator("text")
@classmethod
@@ -270,6 +331,7 @@ class GenerateRequest(BaseModel):
# ── Endpoints ───────────────────────────────────────────────────────────────────
@app.get("/health")
async def health() -> dict:
body: dict = {
@@ -277,9 +339,13 @@ async def health() -> dict:
"model": MODEL_ID,
"device": _device,
"voices": list(_voice_presets.keys()),
"config": _config,
}
if _model_status == "downloading":
body["progress"] = {"done": _dl_progress["done"], "total": _dl_progress["total"]}
body["progress"] = {
"done": _dl_progress["done"],
"total": _dl_progress["total"],
}
if _model_error:
body["message"] = _model_error
return body
@@ -300,7 +366,8 @@ def _sync_generate(
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
voice_preset = copy.deepcopy(_voice_presets[speaker])
_model.set_ddpm_inference_steps(num_steps=req.inference_steps)
steps = req.inference_steps if req.inference_steps is not None else _config["default_inference_steps"]
_model.set_ddpm_inference_steps(num_steps=steps)
inputs = _processor.process_input_with_cached_prompt(
text=req.text,
@@ -339,13 +406,15 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
if _model_status != "online":
detail = {
"downloading": "Model is downloading — please wait.",
"loading": "Model is loading into memory — please wait.",
"error": f"Model failed to load: {_model_error or 'unknown error'}",
"loading": "Model is loading into memory — please wait.",
"error": f"Model failed to load: {_model_error or 'unknown error'}",
}.get(_model_status, "Server not ready.")
raise HTTPException(status_code=503, detail=detail)
if _generation_lock.locked():
raise HTTPException(status_code=503, detail="Server is already generating audio. Please wait.")
raise HTTPException(
status_code=503, detail="Server is already generating audio. Please wait."
)
async def event_stream() -> AsyncGenerator[str, None]:
from vibevoice.modular.streamer import AsyncAudioStreamer
@@ -354,6 +423,9 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
streamer = AsyncAudioStreamer(batch_size=1)
cancel_event = threading.Event()
accum_size = max(1, _config["chunk_accum"])
accumulated_chunks = []
async with _generation_lock:
loop = asyncio.get_event_loop()
future = loop.run_in_executor(
@@ -382,9 +454,18 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
if chunk is None: # stop signal
break
pcm_b64 = base64.b64encode(
chunk.detach().cpu().float().numpy().tobytes()
).decode()
accumulated_chunks.append(chunk.detach().cpu().float())
if len(accumulated_chunks) >= accum_size:
combined = torch.cat(accumulated_chunks, dim=0)
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
yield _sse({"type": "audio_chunk", "data": pcm_b64})
accumulated_chunks = []
# Flush any remaining chunks
if accumulated_chunks:
combined = torch.cat(accumulated_chunks, dim=0)
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
yield _sse({"type": "audio_chunk", "data": pcm_b64})
try:
@@ -395,7 +476,12 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
return
except Exception as exc:
logger.exception("Generation failed: %s", exc)
yield _sse({"type": "error", "message": "Internal server error during generation."})
yield _sse(
{
"type": "error",
"message": "Internal server error during generation.",
}
)
return
elapsed = round(time.monotonic() - start, 1)
+81 -13
View File
@@ -15,11 +15,23 @@ export interface DownloadProgress {
total: number;
}
export interface ServerConfig {
device: string;
chunk_accum: number;
prebuffer_secs: number;
rebuffer_threshold_secs: number;
resume_threshold_secs: number;
default_inference_steps: number;
}
interface AppState {
script: string;
speaker: string;
cfgScale: number;
inferenceSteps: number;
prebufferSecs: number;
rebufferThresholdSecs: number;
resumeThresholdSecs: number;
isGenerating: boolean;
genElapsed: number;
genPct: number | null;
@@ -28,6 +40,7 @@ interface AppState {
serverStatus: ServerStatus;
downloadProgress: DownloadProgress | null;
availableVoices: string[];
serverConfig: ServerConfig | null;
}
type AppAction =
@@ -35,6 +48,9 @@ type AppAction =
| { type: "SET_SPEAKER"; payload: string }
| { type: "SET_CFG_SCALE"; payload: number }
| { type: "SET_INFERENCE_STEPS"; payload: number }
| { type: "SET_PREBUFFER_SECS"; payload: number }
| { type: "SET_REBUFFER_THRESHOLD"; payload: number }
| { type: "SET_RESUME_THRESHOLD"; payload: number }
| { type: "START_GENERATION" }
| { type: "GEN_PROGRESS"; elapsed: number; pct: number | null }
| { type: "GENERATION_SUCCESS"; payload: string }
@@ -43,7 +59,12 @@ type AppAction =
| { type: "ADD_LOG"; payload: string }
| {
type: "SET_SERVER_STATUS";
payload: { status: ServerStatus; progress?: DownloadProgress | null; voices?: string[] };
payload: {
status: ServerStatus;
progress?: DownloadProgress | null;
voices?: string[];
config?: ServerConfig | null;
};
};
function reducer(state: AppState, action: AppAction): AppState {
@@ -52,6 +73,9 @@ function reducer(state: AppState, action: AppAction): AppState {
case "SET_SPEAKER": return { ...state, speaker: action.payload };
case "SET_CFG_SCALE": return { ...state, cfgScale: action.payload };
case "SET_INFERENCE_STEPS": return { ...state, inferenceSteps: action.payload };
case "SET_PREBUFFER_SECS": return { ...state, prebufferSecs: action.payload };
case "SET_REBUFFER_THRESHOLD": return { ...state, rebufferThresholdSecs: action.payload };
case "SET_RESUME_THRESHOLD": return { ...state, resumeThresholdSecs: action.payload };
case "START_GENERATION":
return { ...state, isGenerating: true, audioUrl: null, logs: [], genElapsed: 0, genPct: null };
case "GEN_PROGRESS":
@@ -63,14 +87,40 @@ function reducer(state: AppState, action: AppAction): AppState {
return { ...state, isGenerating: false, genElapsed: 0, genPct: null };
case "ADD_LOG":
return { ...state, logs: [...state.logs, action.payload] };
case "SET_SERVER_STATUS":
case "SET_SERVER_STATUS": {
const isNewConfig = !state.serverConfig && action.payload.config;
const deviceChanged = !!(state.serverConfig && action.payload.config && state.serverConfig.device !== action.payload.config.device);
const nextSteps = (isNewConfig || deviceChanged)
? action.payload.config!.default_inference_steps
: state.inferenceSteps;
const nextPrebuffer = (isNewConfig || deviceChanged)
? action.payload.config!.prebuffer_secs
: state.prebufferSecs;
const nextRebuffer = (isNewConfig || deviceChanged)
? action.payload.config!.rebuffer_threshold_secs
: state.rebufferThresholdSecs;
const nextResume = (isNewConfig || deviceChanged)
? action.payload.config!.resume_threshold_secs
: state.resumeThresholdSecs;
return {
...state,
serverStatus: action.payload.status,
downloadProgress: action.payload.progress ?? null,
availableVoices:
action.payload.voices?.length ? action.payload.voices : state.availableVoices,
availableVoices: action.payload.voices?.length
? action.payload.voices
: state.availableVoices,
serverConfig: action.payload.config ?? state.serverConfig,
inferenceSteps: nextSteps,
prebufferSecs: nextPrebuffer,
rebufferThresholdSecs: nextRebuffer,
resumeThresholdSecs: nextResume,
};
}
default: return state;
}
}
@@ -80,6 +130,9 @@ const initialState: AppState = {
speaker: "carter",
cfgScale: 1.5,
inferenceSteps: 10,
prebufferSecs: 2.0,
rebufferThresholdSecs: 0.4,
resumeThresholdSecs: 1.5,
isGenerating: false,
genElapsed: 0,
genPct: null,
@@ -88,6 +141,7 @@ const initialState: AppState = {
serverStatus: "offline",
downloadProgress: null,
availableVoices: [],
serverConfig: null,
};
export default function HomePage() {
@@ -106,19 +160,16 @@ export default function HomePage() {
const handleGenerationCancel = useCallback(() => dispatch({ type: "GENERATION_CANCELLED" }), []);
const handleGenerationError = useCallback(() => dispatch({ type: "GENERATION_ERROR" }), []);
const {
generate,
pauseStream,
resumeStream,
stop,
isStreamPaused,
} = useStreamingGeneration({
const { generate, pauseStream, resumeStream, stop, isStreamPaused } = useStreamingGeneration({
onLog: addLog,
onStart: handleGenerationStart,
onProgress: handleGenerationProgress,
onSuccess: handleGenerationSuccess,
onCancel: handleGenerationCancel,
onError: handleGenerationError,
prebufferSecs: state.prebufferSecs,
rebufferThresholdSecs: state.rebufferThresholdSecs,
resumeThresholdSecs: state.resumeThresholdSecs,
});
// Server health polling — fast while not ready, slow when online
@@ -131,21 +182,32 @@ export default function HomePage() {
let nextStatus: ServerStatus = "offline";
let nextProgress: DownloadProgress | null = null;
let nextVoices: string[] = [];
let nextConfig: ServerConfig | null = null;
try {
const res = await fetch("/api/health", { cache: "no-store" });
const data = await res.json() as {
const data = (await res.json()) as {
status: ServerStatus;
progress?: DownloadProgress | null;
voices?: string[];
config?: ServerConfig;
};
nextStatus = data.status ?? "offline";
nextProgress = data.progress ?? null;
nextVoices = data.voices ?? [];
nextConfig = data.config ?? null;
} catch {
nextStatus = "offline";
}
if (!cancelled) {
dispatch({ type: "SET_SERVER_STATUS", payload: { status: nextStatus, progress: nextProgress, voices: nextVoices } });
dispatch({
type: "SET_SERVER_STATUS",
payload: {
status: nextStatus,
progress: nextProgress,
voices: nextVoices,
config: nextConfig,
},
});
timeoutId = setTimeout(poll, nextStatus === "online" ? 15_000 : 2_000);
}
}
@@ -199,6 +261,12 @@ export default function HomePage() {
onCfgScaleChange={(v) => dispatch({ type: "SET_CFG_SCALE", payload: v })}
inferenceSteps={state.inferenceSteps}
onInferenceStepsChange={(v) => dispatch({ type: "SET_INFERENCE_STEPS", payload: v })}
prebufferSecs={state.prebufferSecs}
onPrebufferSecsChange={(v) => dispatch({ type: "SET_PREBUFFER_SECS", payload: v })}
rebufferThresholdSecs={state.rebufferThresholdSecs}
onRebufferThresholdChange={(v) => dispatch({ type: "SET_REBUFFER_THRESHOLD", payload: v })}
resumeThresholdSecs={state.resumeThresholdSecs}
onResumeThresholdChange={(v) => dispatch({ type: "SET_RESUME_THRESHOLD", payload: v })}
onGenerate={handleGenerate}
onStop={stop}
onPauseStream={pauseStream}
+117 -1
View File
@@ -1,5 +1,6 @@
"use client";
import { useState } from "react";
import type { ServerStatus, DownloadProgress } from "@/app/page";
const FALLBACK_VOICES = ["carter", "davis", "emma", "frank", "grace", "mike"];
@@ -12,6 +13,12 @@ interface GenerationControlsProps {
onCfgScaleChange: (v: number) => void;
inferenceSteps: number;
onInferenceStepsChange: (v: number) => void;
prebufferSecs: number;
onPrebufferSecsChange: (v: number) => void;
rebufferThresholdSecs: number;
onRebufferThresholdChange: (v: number) => void;
resumeThresholdSecs: number;
onResumeThresholdChange: (v: number) => void;
onGenerate: () => void;
onStop: () => void;
onPauseStream: () => void;
@@ -53,6 +60,12 @@ export default function GenerationControls({
onCfgScaleChange,
inferenceSteps,
onInferenceStepsChange,
prebufferSecs,
onPrebufferSecsChange,
rebufferThresholdSecs,
onRebufferThresholdChange,
resumeThresholdSecs,
onResumeThresholdChange,
onGenerate,
onStop,
onPauseStream,
@@ -65,6 +78,7 @@ export default function GenerationControls({
serverStatus,
downloadProgress,
}: GenerationControlsProps) {
const [showAdvanced, setShowAdvanced] = useState(false);
const voices = availableVoices.length > 0 ? availableVoices : FALLBACK_VOICES;
const serverReady = serverStatus === "online";
const buttonDisabled = isGenerating || wordCount === 0 || !serverReady;
@@ -169,6 +183,108 @@ export default function GenerationControls({
</div>
</div>
{/* Advanced Buffering toggle */}
<div className="pt-2">
<button
type="button"
onClick={() => setShowAdvanced(!showAdvanced)}
aria-expanded={showAdvanced}
aria-controls="advanced-buffering-panel"
className="flex items-center gap-2 text-xs font-semibold uppercase tracking-wider cursor-pointer transition-colors"
style={{ color: showAdvanced ? "var(--accent-teal)" : "var(--muted)" }}
>
<svg
className={`w-3 h-3 transition-transform ${showAdvanced ? "rotate-90" : ""}`}
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
strokeWidth="3"
>
<polyline points="9 18 15 12 9 6" />
</svg>
Advanced Buffering
</button>
</div>
{showAdvanced && (
<div id="advanced-buffering-panel" className="flex flex-col gap-4 pl-2 border-l" style={{ borderColor: "var(--border)" }}>
{/* Pre-buffer */}
<div className="flex flex-col gap-2">
<div className="flex items-center justify-between">
<label className="text-xs font-medium" style={{ color: "var(--foreground)" }}>
Initial Pre-buffer
</label>
<span className="text-xs font-mono" style={{ color: "var(--accent-teal)" }}>
{prebufferSecs.toFixed(1)}s
</span>
</div>
<input
type="range"
min={0.5}
max={10.0}
step={0.5}
value={prebufferSecs}
onChange={(e) => onPrebufferSecsChange(parseFloat(e.target.value))}
className="w-full h-1"
/>
</div>
{/* Re-buffer threshold */}
<div className="flex flex-col gap-2">
<div className="flex items-center justify-between">
<label htmlFor="rebuffer-threshold" className="text-xs font-medium" style={{ color: "var(--foreground)" }}>
Re-buffer Threshold
</label>
<span className="text-xs font-mono" style={{ color: "var(--accent-teal)" }}>
{rebufferThresholdSecs.toFixed(1)}s
</span>
</div>
<input
id="rebuffer-threshold"
type="range"
min={0.1}
max={3.0}
step={0.1}
value={rebufferThresholdSecs}
onChange={(e) => {
const next = parseFloat(e.target.value);
onRebufferThresholdChange(next);
if (resumeThresholdSecs <= next) {
onResumeThresholdChange(parseFloat((next + 0.5).toFixed(1)));
}
}}
className="w-full h-1"
/>
</div>
{/* Resume threshold */}
<div className="flex flex-col gap-2">
<div className="flex items-center justify-between">
<label htmlFor="resume-threshold" className="text-xs font-medium" style={{ color: "var(--foreground)" }}>
Resume Threshold
</label>
<span className="text-xs font-mono" style={{ color: "var(--accent-teal)" }}>
{resumeThresholdSecs.toFixed(1)}s
</span>
</div>
<input
id="resume-threshold"
type="range"
min={0.5}
max={5.0}
step={0.1}
value={resumeThresholdSecs}
onChange={(e) => {
const next = parseFloat(e.target.value);
if (next <= rebufferThresholdSecs) return;
onResumeThresholdChange(next);
}}
className="w-full h-1"
/>
</div>
</div>
)}
{/* Server status banner */}
{!serverReady && (
<div
@@ -177,7 +293,7 @@ export default function GenerationControls({
>
<div className="flex items-center gap-2">
<span
className={`w-2 h-2 rounded-full flex-shrink-0 ${serverStatus === "offline" || serverStatus === "error" ? "" : "animate-pulse"}`}
className={`w-2 h-2 rounded-full shrink-0 ${serverStatus === "offline" || serverStatus === "error" ? "" : "animate-pulse"}`}
style={{ background: STATUS_CONFIG[serverStatus].color }}
/>
<span style={{ color: STATUS_CONFIG[serverStatus].color }}>
+24 -7
View File
@@ -3,9 +3,9 @@
import { useCallback, useEffect, useRef, useState } from "react";
const SAMPLE_RATE = 24_000;
const PREBUFFER_SECS = 2.0;
const REBUFFER_THRESHOLD_SECS = 0.4;
const RESUME_THRESHOLD_SECS = 1.5;
const DEFAULT_PREBUFFER_SECS = 2.0;
const DEFAULT_REBUFFER_THRESHOLD_SECS = 0.4;
const DEFAULT_RESUME_THRESHOLD_SECS = 1.5;
interface GenerateOptions {
text: string;
@@ -21,6 +21,12 @@ interface UseStreamingGenerationOptions {
onSuccess: (audioUrl: string) => void;
onCancel: () => void;
onError: () => void;
/** Seconds of audio to buffer before playback starts. */
prebufferSecs?: number;
/** Buffer lookahead (seconds) below which playback suspends to refill. */
rebufferThresholdSecs?: number;
/** Buffer lookahead (seconds) at or above which suspended playback resumes. Must be > rebufferThresholdSecs. */
resumeThresholdSecs?: number;
}
function mergeFloat32Arrays(chunks: Float32Array<ArrayBuffer>[]): Float32Array<ArrayBuffer> {
@@ -77,7 +83,18 @@ export function useStreamingGeneration({
onSuccess,
onCancel,
onError,
prebufferSecs = DEFAULT_PREBUFFER_SECS,
rebufferThresholdSecs: rawRebufferThresholdSecs = DEFAULT_REBUFFER_THRESHOLD_SECS,
resumeThresholdSecs: rawResumeThresholdSecs = DEFAULT_RESUME_THRESHOLD_SECS,
}: UseStreamingGenerationOptions) {
let rebufferThresholdSecs = rawRebufferThresholdSecs;
let resumeThresholdSecs = rawResumeThresholdSecs;
if (resumeThresholdSecs <= rebufferThresholdSecs) {
console.warn(
`[useStreamingGeneration] resumeThresholdSecs (${resumeThresholdSecs}) must be greater than rebufferThresholdSecs (${rebufferThresholdSecs}). Clamping resumeThresholdSecs to ${rebufferThresholdSecs + 0.5}.`,
);
resumeThresholdSecs = rebufferThresholdSecs + 0.5;
}
const [isStreamPaused, setIsStreamPaused] = useState(false);
const abortRef = useRef<AbortController | null>(null);
const audioCtxRef = useRef<AudioContext | null>(null);
@@ -144,7 +161,7 @@ export function useStreamingGeneration({
if (!hasStartedPlaybackRef.current) {
const bufferedSecs = chunksRef.current.reduce((sum, c) => sum + c.length, 0) / SAMPLE_RATE;
if (bufferedSecs >= PREBUFFER_SECS) {
if (bufferedSecs >= prebufferSecs) {
flushBufferedAudio();
}
return;
@@ -154,18 +171,18 @@ export function useStreamingGeneration({
if (isUserPausedRef.current) return;
const ahead = nextStartTimeRef.current - ctx.currentTime;
if (ctx.state === "running" && ahead < REBUFFER_THRESHOLD_SECS) {
if (ctx.state === "running" && ahead < rebufferThresholdSecs) {
ctx.suspend().catch(() => {});
isAutoBufferingRef.current = true;
} else if (
ctx.state === "suspended" &&
isAutoBufferingRef.current &&
ahead >= RESUME_THRESHOLD_SECS
ahead >= resumeThresholdSecs
) {
ctx.resume().catch(() => {});
isAutoBufferingRef.current = false;
}
}, [enqueue, flushBufferedAudio]);
}, [enqueue, flushBufferedAudio, prebufferSecs, rebufferThresholdSecs, resumeThresholdSecs]);
const generate = useCallback(async (options: GenerateOptions) => {
if (!options.text.trim()) return;