mirror of
https://github.com/JezzWTF/vibepod.git
synced 2026-06-01 15:22:14 +00:00
Improve CPU Inference Stability: Adaptive Buffering & Chunk Accumulation (#11)
* Improve CPU Inference Stability: Implement Adaptive Buffering and Chunk Accumulation This change addresses audio stuttering issues when running on CPU-only hardware by: - Implementing server-side audio chunk accumulation to reduce SSE overhead. - Introducing device-aware default configurations for buffering and inference steps. - Exposing key performance parameters as environment variables. - Enabling the frontend to adaptively adjust its buffering thresholds based on the server's configuration. Changes: - Modified `server/vibevoice_server.py` to support accumulation and provide config via `/health`. - Updated `web/hooks/useStreamingGeneration.ts` to accept configurable buffering parameters. - Updated `web/app/page.tsx` to fetch and apply server-side configuration. Verified on CPU mode in the development environment. Co-authored-by: LyAhn <27559362+LyAhn@users.noreply.github.com> * Improve CPU Inference Stability: Implement Adaptive Buffering and Chunk Accumulation This change addresses audio stuttering issues when running on CPU-only hardware by: - Implementing server-side audio chunk accumulation to reduce SSE overhead. - Introducing device-aware default configurations for buffering and inference steps. - Exposing key performance parameters as environment variables. - Enabling the frontend to adaptively adjust its buffering thresholds based on the server's configuration. Changes: - Modified `server/vibevoice_server.py` to support accumulation and provide config via `/health`. - Updated `web/hooks/useStreamingGeneration.ts` to accept configurable buffering parameters. - Updated `web/app/page.tsx` to fetch and apply server-side configuration. Verified on CPU mode in the development environment. Co-authored-by: LyAhn <27559362+LyAhn@users.noreply.github.com> * Improve CPU Inference Stability: Adaptive Buffering UI & Logic This change enhances the initial CPU stability fix by: - Exposing adaptive buffering settings (Pre-buffer, Re-buffer Threshold, Resume Threshold) in a new "Advanced Buffering" UI section. - Managing buffering settings in the application state to allow for manual overrides. - Implementing robust re-initialization of buffering and inference defaults whenever the server's device (CPU/CUDA) changes. - Including the active device in the server's config object for reliable client-side detection. Verified with frontend screenshots and full build. Responds to PR feedback regarding actioning the adaptive logic. Co-authored-by: LyAhn <27559362+LyAhn@users.noreply.github.com> * Refine adaptive buffering: env helpers, threshold validation, a11y fixes - Extract _env_int/_env_float helpers in server to validate env-var config with graceful fallback instead of bare int/float casts - Fix inference_steps falsy-check (0 is valid) to use explicit None guard - Enforce rebufferThresholdSecs < resumeThresholdSecs in both the hook (with console.warn + clamp) and the GenerationControls UI (sliders block invalid states by auto-bumping or ignoring the drag) - Add type="button", aria-expanded, aria-controls, htmlFor, and input id attributes to GenerationControls for accessibility - Add .vscode/settings.json to .gitignore; sort package.json scripts --------- Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
This commit is contained in:
@@ -23,3 +23,4 @@ web/node_modules/
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
.vscode/settings.json
|
||||
|
||||
+2
-2
@@ -3,12 +3,12 @@
|
||||
"version": "0.1.0",
|
||||
"private": true,
|
||||
"scripts": {
|
||||
"build": "pnpm --filter vibepod-web build",
|
||||
"dev": "bash dev.sh",
|
||||
"dev:cpu": "bash dev.sh --cpu",
|
||||
"dev:server": "bash server/start.sh",
|
||||
"dev:server:cpu": "bash server/start.sh --cpu",
|
||||
"dev:web": "pnpm --filter vibepod-web dev",
|
||||
"build": "pnpm --filter vibepod-web build"
|
||||
"dev:web": "pnpm --filter vibepod-web dev"
|
||||
},
|
||||
"packageManager": "pnpm@10.33.2+sha512.a90faf6feeab71ad6c6e57f94e0fe1a12f5dcc22cd754db40ae9593eb6a3e0b6b12e3540218bb37ae083404b1f2ce6db2a4121e979829b4aff94b99f49da1cf8"
|
||||
}
|
||||
|
||||
@@ -36,7 +36,9 @@ def download() -> str:
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
|
||||
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get(
|
||||
"HUGGINGFACE_TOKEN"
|
||||
)
|
||||
|
||||
print(f"Checking / downloading model: {MODEL_ID}")
|
||||
print("(This may take several minutes on first run — the model is ~1 GB)")
|
||||
|
||||
+111
-25
@@ -52,11 +52,11 @@ VOICE_BASE_URL = (
|
||||
|
||||
EN_VOICES: dict[str, str] = {
|
||||
"carter": "en-Carter_man.pt",
|
||||
"davis": "en-Davis_man.pt",
|
||||
"emma": "en-Emma_woman.pt",
|
||||
"frank": "en-Frank_man.pt",
|
||||
"grace": "en-Grace_woman.pt",
|
||||
"mike": "en-Mike_man.pt",
|
||||
"davis": "en-Davis_man.pt",
|
||||
"emma": "en-Emma_woman.pt",
|
||||
"frank": "en-Frank_man.pt",
|
||||
"grace": "en-Grace_woman.pt",
|
||||
"mike": "en-Mike_man.pt",
|
||||
}
|
||||
DEFAULT_SPEAKER = "carter"
|
||||
|
||||
@@ -66,6 +66,7 @@ _IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.o
|
||||
# VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag.
|
||||
# Falls back to auto-detection if not set.
|
||||
|
||||
|
||||
def _resolve_device() -> str:
|
||||
"""Resolve the target device from env var or auto-detect."""
|
||||
env = os.environ.get("VIBEPOD_DEVICE", "").strip().lower()
|
||||
@@ -80,6 +81,31 @@ def _resolve_device() -> str:
|
||||
return "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
# ── Env-var helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _env_int(name: str, default: int) -> int:
|
||||
raw = os.environ.get(name, "").strip()
|
||||
if not raw:
|
||||
return default
|
||||
try:
|
||||
return int(raw)
|
||||
except ValueError:
|
||||
logger.warning("Invalid value for %s=%r — using default %d", name, raw, default)
|
||||
return default
|
||||
|
||||
|
||||
def _env_float(name: str, default: float) -> float:
|
||||
raw = os.environ.get(name, "").strip()
|
||||
if not raw:
|
||||
return default
|
||||
try:
|
||||
return float(raw)
|
||||
except ValueError:
|
||||
logger.warning("Invalid value for %s=%r — using default %g", name, raw, default)
|
||||
return default
|
||||
|
||||
|
||||
# ── Global state ────────────────────────────────────────────────────────────────
|
||||
|
||||
ModelStatus = Literal["downloading", "loading", "online", "error"]
|
||||
@@ -93,13 +119,24 @@ _voice_presets: dict[str, object] = {}
|
||||
_load_lock = threading.Lock()
|
||||
_generation_lock = asyncio.Lock()
|
||||
|
||||
# Config defaults (can be overridden by env vars)
|
||||
# These are populated in _load_model_sync once the device is known.
|
||||
_config = {
|
||||
"device": "cpu",
|
||||
"chunk_accum": 1,
|
||||
"prebuffer_secs": 2.0,
|
||||
"rebuffer_threshold_secs": 0.4,
|
||||
"resume_threshold_secs": 1.5,
|
||||
"default_inference_steps": 10,
|
||||
}
|
||||
|
||||
# Download progress (files downloaded so far)
|
||||
_dl_progress: dict[str, int] = {"done": 0, "total": 0}
|
||||
|
||||
|
||||
|
||||
# ── Progress-tracking tqdm (for model file downloads) ──────────────────────────
|
||||
|
||||
|
||||
def _make_dl_tqdm() -> type:
|
||||
class _DlTqdm(_BaseTqdm):
|
||||
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||
@@ -119,10 +156,14 @@ def _make_dl_tqdm() -> type:
|
||||
|
||||
# ── Model / voice helpers ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _is_model_cached() -> bool:
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
snapshot_download(MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS)
|
||||
|
||||
snapshot_download(
|
||||
MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
@@ -130,7 +171,10 @@ def _is_model_cached() -> bool:
|
||||
|
||||
def _download_model() -> None:
|
||||
from huggingface_hub import snapshot_download
|
||||
token: Optional[str] = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
|
||||
|
||||
token: Optional[str] = os.environ.get("HF_TOKEN") or os.environ.get(
|
||||
"HUGGINGFACE_TOKEN"
|
||||
)
|
||||
DlTqdm = _make_dl_tqdm()
|
||||
logger.info("Model not cached — downloading %s...", MODEL_ID)
|
||||
snapshot_download(
|
||||
@@ -155,11 +199,13 @@ def _download_voices() -> None:
|
||||
|
||||
# ── Background model loader ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _init_processor():
|
||||
logger.info("Loading processor...")
|
||||
from vibevoice.processor.vibevoice_streaming_processor import (
|
||||
VibeVoiceStreamingProcessor,
|
||||
)
|
||||
|
||||
return VibeVoiceStreamingProcessor.from_pretrained(MODEL_ID)
|
||||
|
||||
|
||||
@@ -171,6 +217,7 @@ def _init_model(device: str):
|
||||
from vibevoice.modular.modeling_vibevoice_streaming_inference import (
|
||||
VibeVoiceStreamingForConditionalGenerationInference,
|
||||
)
|
||||
|
||||
try:
|
||||
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
|
||||
MODEL_ID,
|
||||
@@ -179,7 +226,9 @@ def _init_model(device: str):
|
||||
attn_implementation=attn_impl,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Model load with %s failed; falling back to sdpa", attn_impl, exc_info=True)
|
||||
logger.warning(
|
||||
"Model load with %s failed; falling back to sdpa", attn_impl, exc_info=True
|
||||
)
|
||||
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
|
||||
MODEL_ID,
|
||||
torch_dtype=load_dtype,
|
||||
@@ -188,7 +237,7 @@ def _init_model(device: str):
|
||||
)
|
||||
|
||||
model.eval()
|
||||
model.set_ddpm_inference_steps(num_steps=10)
|
||||
model.set_ddpm_inference_steps(num_steps=_config["default_inference_steps"])
|
||||
return model
|
||||
|
||||
|
||||
@@ -197,14 +246,12 @@ def _load_voice_presets(device: str) -> dict[str, object]:
|
||||
for name, filename in EN_VOICES.items():
|
||||
path = VOICES_DIR / filename
|
||||
if path.exists():
|
||||
presets[name] = torch.load(
|
||||
path, map_location=device, weights_only=False
|
||||
)
|
||||
presets[name] = torch.load(path, map_location=device, weights_only=False)
|
||||
return presets
|
||||
|
||||
|
||||
def _load_model_sync() -> None:
|
||||
global _processor, _model, _device, _model_status, _model_error, _voice_presets
|
||||
global _processor, _model, _device, _model_status, _model_error, _voice_presets, _config
|
||||
|
||||
with _load_lock:
|
||||
if _model is not None:
|
||||
@@ -222,12 +269,24 @@ def _load_model_sync() -> None:
|
||||
_device = _resolve_device()
|
||||
logger.info("Using device: %s", _device)
|
||||
|
||||
# Populate config based on device
|
||||
is_cpu = _device == "cpu"
|
||||
_config["device"] = _device
|
||||
_config["chunk_accum"] = _env_int("VIBEPOD_CHUNK_ACCUM", 4 if is_cpu else 1)
|
||||
_config["prebuffer_secs"] = _env_float("VIBEPOD_PREBUFFER_SECS", 5.0 if is_cpu else 2.0)
|
||||
_config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 1.0 if is_cpu else 0.4)
|
||||
_config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 2.5 if is_cpu else 1.5)
|
||||
_config["default_inference_steps"] = _env_int("VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10)
|
||||
|
||||
_processor = _init_processor()
|
||||
_model = _init_model(_device)
|
||||
_voice_presets = _load_voice_presets(_device)
|
||||
|
||||
_model_status = "online"
|
||||
logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys()))
|
||||
logger.info(
|
||||
"Model ready on %s. Voices: %s", _device, list(_voice_presets.keys())
|
||||
)
|
||||
logger.info("Configuration: %s", _config)
|
||||
|
||||
except Exception as exc:
|
||||
_model_status = "error"
|
||||
@@ -237,6 +296,7 @@ def _load_model_sync() -> None:
|
||||
|
||||
# ── FastAPI app ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
thread = threading.Thread(target=_load_model_sync, daemon=True, name="model-loader")
|
||||
@@ -249,11 +309,12 @@ app = FastAPI(title="VibePod TTS Server", version="0.1.0", lifespan=lifespan)
|
||||
|
||||
# ── Schemas ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
text: str = Field(..., min_length=1, max_length=10_000)
|
||||
speaker: str = Field(default=DEFAULT_SPEAKER)
|
||||
cfg_scale: float = Field(default=1.5, ge=0.5, le=4.0)
|
||||
inference_steps: int = Field(default=10, ge=5, le=20)
|
||||
inference_steps: Optional[int] = Field(default=None, ge=5, le=20)
|
||||
|
||||
@field_validator("text")
|
||||
@classmethod
|
||||
@@ -270,6 +331,7 @@ class GenerateRequest(BaseModel):
|
||||
|
||||
# ── Endpoints ───────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> dict:
|
||||
body: dict = {
|
||||
@@ -277,9 +339,13 @@ async def health() -> dict:
|
||||
"model": MODEL_ID,
|
||||
"device": _device,
|
||||
"voices": list(_voice_presets.keys()),
|
||||
"config": _config,
|
||||
}
|
||||
if _model_status == "downloading":
|
||||
body["progress"] = {"done": _dl_progress["done"], "total": _dl_progress["total"]}
|
||||
body["progress"] = {
|
||||
"done": _dl_progress["done"],
|
||||
"total": _dl_progress["total"],
|
||||
}
|
||||
if _model_error:
|
||||
body["message"] = _model_error
|
||||
return body
|
||||
@@ -300,7 +366,8 @@ def _sync_generate(
|
||||
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
|
||||
voice_preset = copy.deepcopy(_voice_presets[speaker])
|
||||
|
||||
_model.set_ddpm_inference_steps(num_steps=req.inference_steps)
|
||||
steps = req.inference_steps if req.inference_steps is not None else _config["default_inference_steps"]
|
||||
_model.set_ddpm_inference_steps(num_steps=steps)
|
||||
|
||||
inputs = _processor.process_input_with_cached_prompt(
|
||||
text=req.text,
|
||||
@@ -339,13 +406,15 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
if _model_status != "online":
|
||||
detail = {
|
||||
"downloading": "Model is downloading — please wait.",
|
||||
"loading": "Model is loading into memory — please wait.",
|
||||
"error": f"Model failed to load: {_model_error or 'unknown error'}",
|
||||
"loading": "Model is loading into memory — please wait.",
|
||||
"error": f"Model failed to load: {_model_error or 'unknown error'}",
|
||||
}.get(_model_status, "Server not ready.")
|
||||
raise HTTPException(status_code=503, detail=detail)
|
||||
|
||||
if _generation_lock.locked():
|
||||
raise HTTPException(status_code=503, detail="Server is already generating audio. Please wait.")
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Server is already generating audio. Please wait."
|
||||
)
|
||||
|
||||
async def event_stream() -> AsyncGenerator[str, None]:
|
||||
from vibevoice.modular.streamer import AsyncAudioStreamer
|
||||
@@ -354,6 +423,9 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
streamer = AsyncAudioStreamer(batch_size=1)
|
||||
cancel_event = threading.Event()
|
||||
|
||||
accum_size = max(1, _config["chunk_accum"])
|
||||
accumulated_chunks = []
|
||||
|
||||
async with _generation_lock:
|
||||
loop = asyncio.get_event_loop()
|
||||
future = loop.run_in_executor(
|
||||
@@ -382,9 +454,18 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
if chunk is None: # stop signal
|
||||
break
|
||||
|
||||
pcm_b64 = base64.b64encode(
|
||||
chunk.detach().cpu().float().numpy().tobytes()
|
||||
).decode()
|
||||
accumulated_chunks.append(chunk.detach().cpu().float())
|
||||
|
||||
if len(accumulated_chunks) >= accum_size:
|
||||
combined = torch.cat(accumulated_chunks, dim=0)
|
||||
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
|
||||
yield _sse({"type": "audio_chunk", "data": pcm_b64})
|
||||
accumulated_chunks = []
|
||||
|
||||
# Flush any remaining chunks
|
||||
if accumulated_chunks:
|
||||
combined = torch.cat(accumulated_chunks, dim=0)
|
||||
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
|
||||
yield _sse({"type": "audio_chunk", "data": pcm_b64})
|
||||
|
||||
try:
|
||||
@@ -395,7 +476,12 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
return
|
||||
except Exception as exc:
|
||||
logger.exception("Generation failed: %s", exc)
|
||||
yield _sse({"type": "error", "message": "Internal server error during generation."})
|
||||
yield _sse(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "Internal server error during generation.",
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
elapsed = round(time.monotonic() - start, 1)
|
||||
|
||||
+81
-13
@@ -15,11 +15,23 @@ export interface DownloadProgress {
|
||||
total: number;
|
||||
}
|
||||
|
||||
export interface ServerConfig {
|
||||
device: string;
|
||||
chunk_accum: number;
|
||||
prebuffer_secs: number;
|
||||
rebuffer_threshold_secs: number;
|
||||
resume_threshold_secs: number;
|
||||
default_inference_steps: number;
|
||||
}
|
||||
|
||||
interface AppState {
|
||||
script: string;
|
||||
speaker: string;
|
||||
cfgScale: number;
|
||||
inferenceSteps: number;
|
||||
prebufferSecs: number;
|
||||
rebufferThresholdSecs: number;
|
||||
resumeThresholdSecs: number;
|
||||
isGenerating: boolean;
|
||||
genElapsed: number;
|
||||
genPct: number | null;
|
||||
@@ -28,6 +40,7 @@ interface AppState {
|
||||
serverStatus: ServerStatus;
|
||||
downloadProgress: DownloadProgress | null;
|
||||
availableVoices: string[];
|
||||
serverConfig: ServerConfig | null;
|
||||
}
|
||||
|
||||
type AppAction =
|
||||
@@ -35,6 +48,9 @@ type AppAction =
|
||||
| { type: "SET_SPEAKER"; payload: string }
|
||||
| { type: "SET_CFG_SCALE"; payload: number }
|
||||
| { type: "SET_INFERENCE_STEPS"; payload: number }
|
||||
| { type: "SET_PREBUFFER_SECS"; payload: number }
|
||||
| { type: "SET_REBUFFER_THRESHOLD"; payload: number }
|
||||
| { type: "SET_RESUME_THRESHOLD"; payload: number }
|
||||
| { type: "START_GENERATION" }
|
||||
| { type: "GEN_PROGRESS"; elapsed: number; pct: number | null }
|
||||
| { type: "GENERATION_SUCCESS"; payload: string }
|
||||
@@ -43,7 +59,12 @@ type AppAction =
|
||||
| { type: "ADD_LOG"; payload: string }
|
||||
| {
|
||||
type: "SET_SERVER_STATUS";
|
||||
payload: { status: ServerStatus; progress?: DownloadProgress | null; voices?: string[] };
|
||||
payload: {
|
||||
status: ServerStatus;
|
||||
progress?: DownloadProgress | null;
|
||||
voices?: string[];
|
||||
config?: ServerConfig | null;
|
||||
};
|
||||
};
|
||||
|
||||
function reducer(state: AppState, action: AppAction): AppState {
|
||||
@@ -52,6 +73,9 @@ function reducer(state: AppState, action: AppAction): AppState {
|
||||
case "SET_SPEAKER": return { ...state, speaker: action.payload };
|
||||
case "SET_CFG_SCALE": return { ...state, cfgScale: action.payload };
|
||||
case "SET_INFERENCE_STEPS": return { ...state, inferenceSteps: action.payload };
|
||||
case "SET_PREBUFFER_SECS": return { ...state, prebufferSecs: action.payload };
|
||||
case "SET_REBUFFER_THRESHOLD": return { ...state, rebufferThresholdSecs: action.payload };
|
||||
case "SET_RESUME_THRESHOLD": return { ...state, resumeThresholdSecs: action.payload };
|
||||
case "START_GENERATION":
|
||||
return { ...state, isGenerating: true, audioUrl: null, logs: [], genElapsed: 0, genPct: null };
|
||||
case "GEN_PROGRESS":
|
||||
@@ -63,14 +87,40 @@ function reducer(state: AppState, action: AppAction): AppState {
|
||||
return { ...state, isGenerating: false, genElapsed: 0, genPct: null };
|
||||
case "ADD_LOG":
|
||||
return { ...state, logs: [...state.logs, action.payload] };
|
||||
case "SET_SERVER_STATUS":
|
||||
case "SET_SERVER_STATUS": {
|
||||
const isNewConfig = !state.serverConfig && action.payload.config;
|
||||
const deviceChanged = !!(state.serverConfig && action.payload.config && state.serverConfig.device !== action.payload.config.device);
|
||||
|
||||
const nextSteps = (isNewConfig || deviceChanged)
|
||||
? action.payload.config!.default_inference_steps
|
||||
: state.inferenceSteps;
|
||||
|
||||
const nextPrebuffer = (isNewConfig || deviceChanged)
|
||||
? action.payload.config!.prebuffer_secs
|
||||
: state.prebufferSecs;
|
||||
|
||||
const nextRebuffer = (isNewConfig || deviceChanged)
|
||||
? action.payload.config!.rebuffer_threshold_secs
|
||||
: state.rebufferThresholdSecs;
|
||||
|
||||
const nextResume = (isNewConfig || deviceChanged)
|
||||
? action.payload.config!.resume_threshold_secs
|
||||
: state.resumeThresholdSecs;
|
||||
|
||||
return {
|
||||
...state,
|
||||
serverStatus: action.payload.status,
|
||||
downloadProgress: action.payload.progress ?? null,
|
||||
availableVoices:
|
||||
action.payload.voices?.length ? action.payload.voices : state.availableVoices,
|
||||
availableVoices: action.payload.voices?.length
|
||||
? action.payload.voices
|
||||
: state.availableVoices,
|
||||
serverConfig: action.payload.config ?? state.serverConfig,
|
||||
inferenceSteps: nextSteps,
|
||||
prebufferSecs: nextPrebuffer,
|
||||
rebufferThresholdSecs: nextRebuffer,
|
||||
resumeThresholdSecs: nextResume,
|
||||
};
|
||||
}
|
||||
default: return state;
|
||||
}
|
||||
}
|
||||
@@ -80,6 +130,9 @@ const initialState: AppState = {
|
||||
speaker: "carter",
|
||||
cfgScale: 1.5,
|
||||
inferenceSteps: 10,
|
||||
prebufferSecs: 2.0,
|
||||
rebufferThresholdSecs: 0.4,
|
||||
resumeThresholdSecs: 1.5,
|
||||
isGenerating: false,
|
||||
genElapsed: 0,
|
||||
genPct: null,
|
||||
@@ -88,6 +141,7 @@ const initialState: AppState = {
|
||||
serverStatus: "offline",
|
||||
downloadProgress: null,
|
||||
availableVoices: [],
|
||||
serverConfig: null,
|
||||
};
|
||||
|
||||
export default function HomePage() {
|
||||
@@ -106,19 +160,16 @@ export default function HomePage() {
|
||||
const handleGenerationCancel = useCallback(() => dispatch({ type: "GENERATION_CANCELLED" }), []);
|
||||
const handleGenerationError = useCallback(() => dispatch({ type: "GENERATION_ERROR" }), []);
|
||||
|
||||
const {
|
||||
generate,
|
||||
pauseStream,
|
||||
resumeStream,
|
||||
stop,
|
||||
isStreamPaused,
|
||||
} = useStreamingGeneration({
|
||||
const { generate, pauseStream, resumeStream, stop, isStreamPaused } = useStreamingGeneration({
|
||||
onLog: addLog,
|
||||
onStart: handleGenerationStart,
|
||||
onProgress: handleGenerationProgress,
|
||||
onSuccess: handleGenerationSuccess,
|
||||
onCancel: handleGenerationCancel,
|
||||
onError: handleGenerationError,
|
||||
prebufferSecs: state.prebufferSecs,
|
||||
rebufferThresholdSecs: state.rebufferThresholdSecs,
|
||||
resumeThresholdSecs: state.resumeThresholdSecs,
|
||||
});
|
||||
|
||||
// Server health polling — fast while not ready, slow when online
|
||||
@@ -131,21 +182,32 @@ export default function HomePage() {
|
||||
let nextStatus: ServerStatus = "offline";
|
||||
let nextProgress: DownloadProgress | null = null;
|
||||
let nextVoices: string[] = [];
|
||||
let nextConfig: ServerConfig | null = null;
|
||||
try {
|
||||
const res = await fetch("/api/health", { cache: "no-store" });
|
||||
const data = await res.json() as {
|
||||
const data = (await res.json()) as {
|
||||
status: ServerStatus;
|
||||
progress?: DownloadProgress | null;
|
||||
voices?: string[];
|
||||
config?: ServerConfig;
|
||||
};
|
||||
nextStatus = data.status ?? "offline";
|
||||
nextProgress = data.progress ?? null;
|
||||
nextVoices = data.voices ?? [];
|
||||
nextConfig = data.config ?? null;
|
||||
} catch {
|
||||
nextStatus = "offline";
|
||||
}
|
||||
if (!cancelled) {
|
||||
dispatch({ type: "SET_SERVER_STATUS", payload: { status: nextStatus, progress: nextProgress, voices: nextVoices } });
|
||||
dispatch({
|
||||
type: "SET_SERVER_STATUS",
|
||||
payload: {
|
||||
status: nextStatus,
|
||||
progress: nextProgress,
|
||||
voices: nextVoices,
|
||||
config: nextConfig,
|
||||
},
|
||||
});
|
||||
timeoutId = setTimeout(poll, nextStatus === "online" ? 15_000 : 2_000);
|
||||
}
|
||||
}
|
||||
@@ -199,6 +261,12 @@ export default function HomePage() {
|
||||
onCfgScaleChange={(v) => dispatch({ type: "SET_CFG_SCALE", payload: v })}
|
||||
inferenceSteps={state.inferenceSteps}
|
||||
onInferenceStepsChange={(v) => dispatch({ type: "SET_INFERENCE_STEPS", payload: v })}
|
||||
prebufferSecs={state.prebufferSecs}
|
||||
onPrebufferSecsChange={(v) => dispatch({ type: "SET_PREBUFFER_SECS", payload: v })}
|
||||
rebufferThresholdSecs={state.rebufferThresholdSecs}
|
||||
onRebufferThresholdChange={(v) => dispatch({ type: "SET_REBUFFER_THRESHOLD", payload: v })}
|
||||
resumeThresholdSecs={state.resumeThresholdSecs}
|
||||
onResumeThresholdChange={(v) => dispatch({ type: "SET_RESUME_THRESHOLD", payload: v })}
|
||||
onGenerate={handleGenerate}
|
||||
onStop={stop}
|
||||
onPauseStream={pauseStream}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import type { ServerStatus, DownloadProgress } from "@/app/page";
|
||||
|
||||
const FALLBACK_VOICES = ["carter", "davis", "emma", "frank", "grace", "mike"];
|
||||
@@ -12,6 +13,12 @@ interface GenerationControlsProps {
|
||||
onCfgScaleChange: (v: number) => void;
|
||||
inferenceSteps: number;
|
||||
onInferenceStepsChange: (v: number) => void;
|
||||
prebufferSecs: number;
|
||||
onPrebufferSecsChange: (v: number) => void;
|
||||
rebufferThresholdSecs: number;
|
||||
onRebufferThresholdChange: (v: number) => void;
|
||||
resumeThresholdSecs: number;
|
||||
onResumeThresholdChange: (v: number) => void;
|
||||
onGenerate: () => void;
|
||||
onStop: () => void;
|
||||
onPauseStream: () => void;
|
||||
@@ -53,6 +60,12 @@ export default function GenerationControls({
|
||||
onCfgScaleChange,
|
||||
inferenceSteps,
|
||||
onInferenceStepsChange,
|
||||
prebufferSecs,
|
||||
onPrebufferSecsChange,
|
||||
rebufferThresholdSecs,
|
||||
onRebufferThresholdChange,
|
||||
resumeThresholdSecs,
|
||||
onResumeThresholdChange,
|
||||
onGenerate,
|
||||
onStop,
|
||||
onPauseStream,
|
||||
@@ -65,6 +78,7 @@ export default function GenerationControls({
|
||||
serverStatus,
|
||||
downloadProgress,
|
||||
}: GenerationControlsProps) {
|
||||
const [showAdvanced, setShowAdvanced] = useState(false);
|
||||
const voices = availableVoices.length > 0 ? availableVoices : FALLBACK_VOICES;
|
||||
const serverReady = serverStatus === "online";
|
||||
const buttonDisabled = isGenerating || wordCount === 0 || !serverReady;
|
||||
@@ -169,6 +183,108 @@ export default function GenerationControls({
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Advanced Buffering toggle */}
|
||||
<div className="pt-2">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setShowAdvanced(!showAdvanced)}
|
||||
aria-expanded={showAdvanced}
|
||||
aria-controls="advanced-buffering-panel"
|
||||
className="flex items-center gap-2 text-xs font-semibold uppercase tracking-wider cursor-pointer transition-colors"
|
||||
style={{ color: showAdvanced ? "var(--accent-teal)" : "var(--muted)" }}
|
||||
>
|
||||
<svg
|
||||
className={`w-3 h-3 transition-transform ${showAdvanced ? "rotate-90" : ""}`}
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth="3"
|
||||
>
|
||||
<polyline points="9 18 15 12 9 6" />
|
||||
</svg>
|
||||
Advanced Buffering
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{showAdvanced && (
|
||||
<div id="advanced-buffering-panel" className="flex flex-col gap-4 pl-2 border-l" style={{ borderColor: "var(--border)" }}>
|
||||
{/* Pre-buffer */}
|
||||
<div className="flex flex-col gap-2">
|
||||
<div className="flex items-center justify-between">
|
||||
<label className="text-xs font-medium" style={{ color: "var(--foreground)" }}>
|
||||
Initial Pre-buffer
|
||||
</label>
|
||||
<span className="text-xs font-mono" style={{ color: "var(--accent-teal)" }}>
|
||||
{prebufferSecs.toFixed(1)}s
|
||||
</span>
|
||||
</div>
|
||||
<input
|
||||
type="range"
|
||||
min={0.5}
|
||||
max={10.0}
|
||||
step={0.5}
|
||||
value={prebufferSecs}
|
||||
onChange={(e) => onPrebufferSecsChange(parseFloat(e.target.value))}
|
||||
className="w-full h-1"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Re-buffer threshold */}
|
||||
<div className="flex flex-col gap-2">
|
||||
<div className="flex items-center justify-between">
|
||||
<label htmlFor="rebuffer-threshold" className="text-xs font-medium" style={{ color: "var(--foreground)" }}>
|
||||
Re-buffer Threshold
|
||||
</label>
|
||||
<span className="text-xs font-mono" style={{ color: "var(--accent-teal)" }}>
|
||||
{rebufferThresholdSecs.toFixed(1)}s
|
||||
</span>
|
||||
</div>
|
||||
<input
|
||||
id="rebuffer-threshold"
|
||||
type="range"
|
||||
min={0.1}
|
||||
max={3.0}
|
||||
step={0.1}
|
||||
value={rebufferThresholdSecs}
|
||||
onChange={(e) => {
|
||||
const next = parseFloat(e.target.value);
|
||||
onRebufferThresholdChange(next);
|
||||
if (resumeThresholdSecs <= next) {
|
||||
onResumeThresholdChange(parseFloat((next + 0.5).toFixed(1)));
|
||||
}
|
||||
}}
|
||||
className="w-full h-1"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Resume threshold */}
|
||||
<div className="flex flex-col gap-2">
|
||||
<div className="flex items-center justify-between">
|
||||
<label htmlFor="resume-threshold" className="text-xs font-medium" style={{ color: "var(--foreground)" }}>
|
||||
Resume Threshold
|
||||
</label>
|
||||
<span className="text-xs font-mono" style={{ color: "var(--accent-teal)" }}>
|
||||
{resumeThresholdSecs.toFixed(1)}s
|
||||
</span>
|
||||
</div>
|
||||
<input
|
||||
id="resume-threshold"
|
||||
type="range"
|
||||
min={0.5}
|
||||
max={5.0}
|
||||
step={0.1}
|
||||
value={resumeThresholdSecs}
|
||||
onChange={(e) => {
|
||||
const next = parseFloat(e.target.value);
|
||||
if (next <= rebufferThresholdSecs) return;
|
||||
onResumeThresholdChange(next);
|
||||
}}
|
||||
className="w-full h-1"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Server status banner */}
|
||||
{!serverReady && (
|
||||
<div
|
||||
@@ -177,7 +293,7 @@ export default function GenerationControls({
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<span
|
||||
className={`w-2 h-2 rounded-full flex-shrink-0 ${serverStatus === "offline" || serverStatus === "error" ? "" : "animate-pulse"}`}
|
||||
className={`w-2 h-2 rounded-full shrink-0 ${serverStatus === "offline" || serverStatus === "error" ? "" : "animate-pulse"}`}
|
||||
style={{ background: STATUS_CONFIG[serverStatus].color }}
|
||||
/>
|
||||
<span style={{ color: STATUS_CONFIG[serverStatus].color }}>
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
|
||||
const SAMPLE_RATE = 24_000;
|
||||
const PREBUFFER_SECS = 2.0;
|
||||
const REBUFFER_THRESHOLD_SECS = 0.4;
|
||||
const RESUME_THRESHOLD_SECS = 1.5;
|
||||
const DEFAULT_PREBUFFER_SECS = 2.0;
|
||||
const DEFAULT_REBUFFER_THRESHOLD_SECS = 0.4;
|
||||
const DEFAULT_RESUME_THRESHOLD_SECS = 1.5;
|
||||
|
||||
interface GenerateOptions {
|
||||
text: string;
|
||||
@@ -21,6 +21,12 @@ interface UseStreamingGenerationOptions {
|
||||
onSuccess: (audioUrl: string) => void;
|
||||
onCancel: () => void;
|
||||
onError: () => void;
|
||||
/** Seconds of audio to buffer before playback starts. */
|
||||
prebufferSecs?: number;
|
||||
/** Buffer lookahead (seconds) below which playback suspends to refill. */
|
||||
rebufferThresholdSecs?: number;
|
||||
/** Buffer lookahead (seconds) at or above which suspended playback resumes. Must be > rebufferThresholdSecs. */
|
||||
resumeThresholdSecs?: number;
|
||||
}
|
||||
|
||||
function mergeFloat32Arrays(chunks: Float32Array<ArrayBuffer>[]): Float32Array<ArrayBuffer> {
|
||||
@@ -77,7 +83,18 @@ export function useStreamingGeneration({
|
||||
onSuccess,
|
||||
onCancel,
|
||||
onError,
|
||||
prebufferSecs = DEFAULT_PREBUFFER_SECS,
|
||||
rebufferThresholdSecs: rawRebufferThresholdSecs = DEFAULT_REBUFFER_THRESHOLD_SECS,
|
||||
resumeThresholdSecs: rawResumeThresholdSecs = DEFAULT_RESUME_THRESHOLD_SECS,
|
||||
}: UseStreamingGenerationOptions) {
|
||||
let rebufferThresholdSecs = rawRebufferThresholdSecs;
|
||||
let resumeThresholdSecs = rawResumeThresholdSecs;
|
||||
if (resumeThresholdSecs <= rebufferThresholdSecs) {
|
||||
console.warn(
|
||||
`[useStreamingGeneration] resumeThresholdSecs (${resumeThresholdSecs}) must be greater than rebufferThresholdSecs (${rebufferThresholdSecs}). Clamping resumeThresholdSecs to ${rebufferThresholdSecs + 0.5}.`,
|
||||
);
|
||||
resumeThresholdSecs = rebufferThresholdSecs + 0.5;
|
||||
}
|
||||
const [isStreamPaused, setIsStreamPaused] = useState(false);
|
||||
const abortRef = useRef<AbortController | null>(null);
|
||||
const audioCtxRef = useRef<AudioContext | null>(null);
|
||||
@@ -144,7 +161,7 @@ export function useStreamingGeneration({
|
||||
|
||||
if (!hasStartedPlaybackRef.current) {
|
||||
const bufferedSecs = chunksRef.current.reduce((sum, c) => sum + c.length, 0) / SAMPLE_RATE;
|
||||
if (bufferedSecs >= PREBUFFER_SECS) {
|
||||
if (bufferedSecs >= prebufferSecs) {
|
||||
flushBufferedAudio();
|
||||
}
|
||||
return;
|
||||
@@ -154,18 +171,18 @@ export function useStreamingGeneration({
|
||||
if (isUserPausedRef.current) return;
|
||||
|
||||
const ahead = nextStartTimeRef.current - ctx.currentTime;
|
||||
if (ctx.state === "running" && ahead < REBUFFER_THRESHOLD_SECS) {
|
||||
if (ctx.state === "running" && ahead < rebufferThresholdSecs) {
|
||||
ctx.suspend().catch(() => {});
|
||||
isAutoBufferingRef.current = true;
|
||||
} else if (
|
||||
ctx.state === "suspended" &&
|
||||
isAutoBufferingRef.current &&
|
||||
ahead >= RESUME_THRESHOLD_SECS
|
||||
ahead >= resumeThresholdSecs
|
||||
) {
|
||||
ctx.resume().catch(() => {});
|
||||
isAutoBufferingRef.current = false;
|
||||
}
|
||||
}, [enqueue, flushBufferedAudio]);
|
||||
}, [enqueue, flushBufferedAudio, prebufferSecs, rebufferThresholdSecs, resumeThresholdSecs]);
|
||||
|
||||
const generate = useCallback(async (options: GenerateOptions) => {
|
||||
if (!options.text.trim()) return;
|
||||
|
||||
Reference in New Issue
Block a user