mirror of
https://github.com/JezzWTF/vibepod.git
synced 2026-06-01 15:22:14 +00:00
Improve CPU Inference Stability: Adaptive Buffering & Chunk Accumulation (#11)
* Improve CPU Inference Stability: Implement Adaptive Buffering and Chunk Accumulation

This change addresses audio stuttering issues when running on CPU-only hardware by:
- Implementing server-side audio chunk accumulation to reduce SSE overhead.
- Introducing device-aware default configurations for buffering and inference steps.
- Exposing key performance parameters as environment variables.
- Enabling the frontend to adaptively adjust its buffering thresholds based on the server's configuration.

Changes:
- Modified `server/vibevoice_server.py` to support accumulation and provide config via `/health`.
- Updated `web/hooks/useStreamingGeneration.ts` to accept configurable buffering parameters.
- Updated `web/app/page.tsx` to fetch and apply server-side configuration.

Verified on CPU mode in the development environment.

Co-authored-by: LyAhn <27559362+LyAhn@users.noreply.github.com>

* Improve CPU Inference Stability: Implement Adaptive Buffering and Chunk Accumulation

This change addresses audio stuttering issues when running on CPU-only hardware by:
- Implementing server-side audio chunk accumulation to reduce SSE overhead.
- Introducing device-aware default configurations for buffering and inference steps.
- Exposing key performance parameters as environment variables.
- Enabling the frontend to adaptively adjust its buffering thresholds based on the server's configuration.

Changes:
- Modified `server/vibevoice_server.py` to support accumulation and provide config via `/health`.
- Updated `web/hooks/useStreamingGeneration.ts` to accept configurable buffering parameters.
- Updated `web/app/page.tsx` to fetch and apply server-side configuration.

Verified on CPU mode in the development environment.

Co-authored-by: LyAhn <27559362+LyAhn@users.noreply.github.com>

* Improve CPU Inference Stability: Adaptive Buffering UI & Logic

This change enhances the initial CPU stability fix by:
- Exposing adaptive buffering settings (Pre-buffer, Re-buffer Threshold, Resume Threshold) in a new "Advanced Buffering" UI section.
- Managing buffering settings in the application state to allow for manual overrides.
- Implementing robust re-initialization of buffering and inference defaults whenever the server's device (CPU/CUDA) changes.
- Including the active device in the server's config object for reliable client-side detection.

Verified with frontend screenshots and full build. Responds to PR feedback regarding actioning the adaptive logic.

Co-authored-by: LyAhn <27559362+LyAhn@users.noreply.github.com>

* Refine adaptive buffering: env helpers, threshold validation, a11y fixes

- Extract _env_int/_env_float helpers in server to validate env-var config with graceful fallback instead of bare int/float casts
- Fix inference_steps falsy-check (0 is valid) to use explicit None guard
- Enforce rebufferThresholdSecs < resumeThresholdSecs in both the hook (with console.warn + clamp) and the GenerationControls UI (sliders block invalid states by auto-bumping or ignoring the drag)
- Add type="button", aria-expanded, aria-controls, htmlFor, and input id attributes to GenerationControls for accessibility
- Add .vscode/settings.json to .gitignore; sort package.json scripts

---------

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
This commit is contained in:
@@ -23,3 +23,4 @@ web/node_modules/
|
|||||||
# OS
|
# OS
|
||||||
.DS_Store
|
.DS_Store
|
||||||
Thumbs.db
|
Thumbs.db
|
||||||
|
.vscode/settings.json
|
||||||
|
|||||||
+2
-2
@@ -3,12 +3,12 @@
|
|||||||
"version": "0.1.0",
|
"version": "0.1.0",
|
||||||
"private": true,
|
"private": true,
|
||||||
"scripts": {
|
"scripts": {
|
||||||
|
"build": "pnpm --filter vibepod-web build",
|
||||||
"dev": "bash dev.sh",
|
"dev": "bash dev.sh",
|
||||||
"dev:cpu": "bash dev.sh --cpu",
|
"dev:cpu": "bash dev.sh --cpu",
|
||||||
"dev:server": "bash server/start.sh",
|
"dev:server": "bash server/start.sh",
|
||||||
"dev:server:cpu": "bash server/start.sh --cpu",
|
"dev:server:cpu": "bash server/start.sh --cpu",
|
||||||
"dev:web": "pnpm --filter vibepod-web dev",
|
"dev:web": "pnpm --filter vibepod-web dev"
|
||||||
"build": "pnpm --filter vibepod-web build"
|
|
||||||
},
|
},
|
||||||
"packageManager": "pnpm@10.33.2+sha512.a90faf6feeab71ad6c6e57f94e0fe1a12f5dcc22cd754db40ae9593eb6a3e0b6b12e3540218bb37ae083404b1f2ce6db2a4121e979829b4aff94b99f49da1cf8"
|
"packageManager": "pnpm@10.33.2+sha512.a90faf6feeab71ad6c6e57f94e0fe1a12f5dcc22cd754db40ae9593eb6a3e0b6b12e3540218bb37ae083404b1f2ce6db2a4121e979829b4aff94b99f49da1cf8"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -36,7 +36,9 @@ def download() -> str:
|
|||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
|
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get(
|
||||||
|
"HUGGINGFACE_TOKEN"
|
||||||
|
)
|
||||||
|
|
||||||
print(f"Checking / downloading model: {MODEL_ID}")
|
print(f"Checking / downloading model: {MODEL_ID}")
|
||||||
print("(This may take several minutes on first run — the model is ~1 GB)")
|
print("(This may take several minutes on first run — the model is ~1 GB)")
|
||||||
|
|||||||
+111
-25
@@ -52,11 +52,11 @@ VOICE_BASE_URL = (
|
|||||||
|
|
||||||
EN_VOICES: dict[str, str] = {
|
EN_VOICES: dict[str, str] = {
|
||||||
"carter": "en-Carter_man.pt",
|
"carter": "en-Carter_man.pt",
|
||||||
"davis": "en-Davis_man.pt",
|
"davis": "en-Davis_man.pt",
|
||||||
"emma": "en-Emma_woman.pt",
|
"emma": "en-Emma_woman.pt",
|
||||||
"frank": "en-Frank_man.pt",
|
"frank": "en-Frank_man.pt",
|
||||||
"grace": "en-Grace_woman.pt",
|
"grace": "en-Grace_woman.pt",
|
||||||
"mike": "en-Mike_man.pt",
|
"mike": "en-Mike_man.pt",
|
||||||
}
|
}
|
||||||
DEFAULT_SPEAKER = "carter"
|
DEFAULT_SPEAKER = "carter"
|
||||||
|
|
||||||
@@ -66,6 +66,7 @@ _IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.o
|
|||||||
# VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag.
|
# VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag.
|
||||||
# Falls back to auto-detection if not set.
|
# Falls back to auto-detection if not set.
|
||||||
|
|
||||||
|
|
||||||
def _resolve_device() -> str:
|
def _resolve_device() -> str:
|
||||||
"""Resolve the target device from env var or auto-detect."""
|
"""Resolve the target device from env var or auto-detect."""
|
||||||
env = os.environ.get("VIBEPOD_DEVICE", "").strip().lower()
|
env = os.environ.get("VIBEPOD_DEVICE", "").strip().lower()
|
||||||
@@ -80,6 +81,31 @@ def _resolve_device() -> str:
|
|||||||
return "cuda" if torch.cuda.is_available() else "cpu"
|
return "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Env-var helpers ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _env_int(name: str, default: int) -> int:
|
||||||
|
raw = os.environ.get(name, "").strip()
|
||||||
|
if not raw:
|
||||||
|
return default
|
||||||
|
try:
|
||||||
|
return int(raw)
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("Invalid value for %s=%r — using default %d", name, raw, default)
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def _env_float(name: str, default: float) -> float:
|
||||||
|
raw = os.environ.get(name, "").strip()
|
||||||
|
if not raw:
|
||||||
|
return default
|
||||||
|
try:
|
||||||
|
return float(raw)
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("Invalid value for %s=%r — using default %g", name, raw, default)
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
# ── Global state ────────────────────────────────────────────────────────────────
|
# ── Global state ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
ModelStatus = Literal["downloading", "loading", "online", "error"]
|
ModelStatus = Literal["downloading", "loading", "online", "error"]
|
||||||
@@ -93,13 +119,24 @@ _voice_presets: dict[str, object] = {}
|
|||||||
_load_lock = threading.Lock()
|
_load_lock = threading.Lock()
|
||||||
_generation_lock = asyncio.Lock()
|
_generation_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
# Config defaults (can be overridden by env vars)
|
||||||
|
# These are populated in _load_model_sync once the device is known.
|
||||||
|
_config = {
|
||||||
|
"device": "cpu",
|
||||||
|
"chunk_accum": 1,
|
||||||
|
"prebuffer_secs": 2.0,
|
||||||
|
"rebuffer_threshold_secs": 0.4,
|
||||||
|
"resume_threshold_secs": 1.5,
|
||||||
|
"default_inference_steps": 10,
|
||||||
|
}
|
||||||
|
|
||||||
# Download progress (files downloaded so far)
|
# Download progress (files downloaded so far)
|
||||||
_dl_progress: dict[str, int] = {"done": 0, "total": 0}
|
_dl_progress: dict[str, int] = {"done": 0, "total": 0}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ── Progress-tracking tqdm (for model file downloads) ──────────────────────────
|
# ── Progress-tracking tqdm (for model file downloads) ──────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def _make_dl_tqdm() -> type:
|
def _make_dl_tqdm() -> type:
|
||||||
class _DlTqdm(_BaseTqdm):
|
class _DlTqdm(_BaseTqdm):
|
||||||
def __init__(self, *args: object, **kwargs: object) -> None:
|
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||||
@@ -119,10 +156,14 @@ def _make_dl_tqdm() -> type:
|
|||||||
|
|
||||||
# ── Model / voice helpers ───────────────────────────────────────────────────────
|
# ── Model / voice helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def _is_model_cached() -> bool:
|
def _is_model_cached() -> bool:
|
||||||
try:
|
try:
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
snapshot_download(MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS)
|
|
||||||
|
snapshot_download(
|
||||||
|
MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
@@ -130,7 +171,10 @@ def _is_model_cached() -> bool:
|
|||||||
|
|
||||||
def _download_model() -> None:
|
def _download_model() -> None:
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
token: Optional[str] = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
|
|
||||||
|
token: Optional[str] = os.environ.get("HF_TOKEN") or os.environ.get(
|
||||||
|
"HUGGINGFACE_TOKEN"
|
||||||
|
)
|
||||||
DlTqdm = _make_dl_tqdm()
|
DlTqdm = _make_dl_tqdm()
|
||||||
logger.info("Model not cached — downloading %s...", MODEL_ID)
|
logger.info("Model not cached — downloading %s...", MODEL_ID)
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
@@ -155,11 +199,13 @@ def _download_voices() -> None:
|
|||||||
|
|
||||||
# ── Background model loader ─────────────────────────────────────────────────────
|
# ── Background model loader ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def _init_processor():
|
def _init_processor():
|
||||||
logger.info("Loading processor...")
|
logger.info("Loading processor...")
|
||||||
from vibevoice.processor.vibevoice_streaming_processor import (
|
from vibevoice.processor.vibevoice_streaming_processor import (
|
||||||
VibeVoiceStreamingProcessor,
|
VibeVoiceStreamingProcessor,
|
||||||
)
|
)
|
||||||
|
|
||||||
return VibeVoiceStreamingProcessor.from_pretrained(MODEL_ID)
|
return VibeVoiceStreamingProcessor.from_pretrained(MODEL_ID)
|
||||||
|
|
||||||
|
|
||||||
@@ -171,6 +217,7 @@ def _init_model(device: str):
|
|||||||
from vibevoice.modular.modeling_vibevoice_streaming_inference import (
|
from vibevoice.modular.modeling_vibevoice_streaming_inference import (
|
||||||
VibeVoiceStreamingForConditionalGenerationInference,
|
VibeVoiceStreamingForConditionalGenerationInference,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
|
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
|
||||||
MODEL_ID,
|
MODEL_ID,
|
||||||
@@ -179,7 +226,9 @@ def _init_model(device: str):
|
|||||||
attn_implementation=attn_impl,
|
attn_implementation=attn_impl,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Model load with %s failed; falling back to sdpa", attn_impl, exc_info=True)
|
logger.warning(
|
||||||
|
"Model load with %s failed; falling back to sdpa", attn_impl, exc_info=True
|
||||||
|
)
|
||||||
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
|
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
|
||||||
MODEL_ID,
|
MODEL_ID,
|
||||||
torch_dtype=load_dtype,
|
torch_dtype=load_dtype,
|
||||||
@@ -188,7 +237,7 @@ def _init_model(device: str):
|
|||||||
)
|
)
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
model.set_ddpm_inference_steps(num_steps=10)
|
model.set_ddpm_inference_steps(num_steps=_config["default_inference_steps"])
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@@ -197,14 +246,12 @@ def _load_voice_presets(device: str) -> dict[str, object]:
|
|||||||
for name, filename in EN_VOICES.items():
|
for name, filename in EN_VOICES.items():
|
||||||
path = VOICES_DIR / filename
|
path = VOICES_DIR / filename
|
||||||
if path.exists():
|
if path.exists():
|
||||||
presets[name] = torch.load(
|
presets[name] = torch.load(path, map_location=device, weights_only=False)
|
||||||
path, map_location=device, weights_only=False
|
|
||||||
)
|
|
||||||
return presets
|
return presets
|
||||||
|
|
||||||
|
|
||||||
def _load_model_sync() -> None:
|
def _load_model_sync() -> None:
|
||||||
global _processor, _model, _device, _model_status, _model_error, _voice_presets
|
global _processor, _model, _device, _model_status, _model_error, _voice_presets, _config
|
||||||
|
|
||||||
with _load_lock:
|
with _load_lock:
|
||||||
if _model is not None:
|
if _model is not None:
|
||||||
@@ -222,12 +269,24 @@ def _load_model_sync() -> None:
|
|||||||
_device = _resolve_device()
|
_device = _resolve_device()
|
||||||
logger.info("Using device: %s", _device)
|
logger.info("Using device: %s", _device)
|
||||||
|
|
||||||
|
# Populate config based on device
|
||||||
|
is_cpu = _device == "cpu"
|
||||||
|
_config["device"] = _device
|
||||||
|
_config["chunk_accum"] = _env_int("VIBEPOD_CHUNK_ACCUM", 4 if is_cpu else 1)
|
||||||
|
_config["prebuffer_secs"] = _env_float("VIBEPOD_PREBUFFER_SECS", 5.0 if is_cpu else 2.0)
|
||||||
|
_config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 1.0 if is_cpu else 0.4)
|
||||||
|
_config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 2.5 if is_cpu else 1.5)
|
||||||
|
_config["default_inference_steps"] = _env_int("VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10)
|
||||||
|
|
||||||
_processor = _init_processor()
|
_processor = _init_processor()
|
||||||
_model = _init_model(_device)
|
_model = _init_model(_device)
|
||||||
_voice_presets = _load_voice_presets(_device)
|
_voice_presets = _load_voice_presets(_device)
|
||||||
|
|
||||||
_model_status = "online"
|
_model_status = "online"
|
||||||
logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys()))
|
logger.info(
|
||||||
|
"Model ready on %s. Voices: %s", _device, list(_voice_presets.keys())
|
||||||
|
)
|
||||||
|
logger.info("Configuration: %s", _config)
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
_model_status = "error"
|
_model_status = "error"
|
||||||
@@ -237,6 +296,7 @@ def _load_model_sync() -> None:
|
|||||||
|
|
||||||
# ── FastAPI app ─────────────────────────────────────────────────────────────────
|
# ── FastAPI app ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
thread = threading.Thread(target=_load_model_sync, daemon=True, name="model-loader")
|
thread = threading.Thread(target=_load_model_sync, daemon=True, name="model-loader")
|
||||||
@@ -249,11 +309,12 @@ app = FastAPI(title="VibePod TTS Server", version="0.1.0", lifespan=lifespan)
|
|||||||
|
|
||||||
# ── Schemas ─────────────────────────────────────────────────────────────────────
|
# ── Schemas ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class GenerateRequest(BaseModel):
|
class GenerateRequest(BaseModel):
|
||||||
text: str = Field(..., min_length=1, max_length=10_000)
|
text: str = Field(..., min_length=1, max_length=10_000)
|
||||||
speaker: str = Field(default=DEFAULT_SPEAKER)
|
speaker: str = Field(default=DEFAULT_SPEAKER)
|
||||||
cfg_scale: float = Field(default=1.5, ge=0.5, le=4.0)
|
cfg_scale: float = Field(default=1.5, ge=0.5, le=4.0)
|
||||||
inference_steps: int = Field(default=10, ge=5, le=20)
|
inference_steps: Optional[int] = Field(default=None, ge=5, le=20)
|
||||||
|
|
||||||
@field_validator("text")
|
@field_validator("text")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -270,6 +331,7 @@ class GenerateRequest(BaseModel):
|
|||||||
|
|
||||||
# ── Endpoints ───────────────────────────────────────────────────────────────────
|
# ── Endpoints ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health() -> dict:
|
async def health() -> dict:
|
||||||
body: dict = {
|
body: dict = {
|
||||||
@@ -277,9 +339,13 @@ async def health() -> dict:
|
|||||||
"model": MODEL_ID,
|
"model": MODEL_ID,
|
||||||
"device": _device,
|
"device": _device,
|
||||||
"voices": list(_voice_presets.keys()),
|
"voices": list(_voice_presets.keys()),
|
||||||
|
"config": _config,
|
||||||
}
|
}
|
||||||
if _model_status == "downloading":
|
if _model_status == "downloading":
|
||||||
body["progress"] = {"done": _dl_progress["done"], "total": _dl_progress["total"]}
|
body["progress"] = {
|
||||||
|
"done": _dl_progress["done"],
|
||||||
|
"total": _dl_progress["total"],
|
||||||
|
}
|
||||||
if _model_error:
|
if _model_error:
|
||||||
body["message"] = _model_error
|
body["message"] = _model_error
|
||||||
return body
|
return body
|
||||||
@@ -300,7 +366,8 @@ def _sync_generate(
|
|||||||
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
|
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
|
||||||
voice_preset = copy.deepcopy(_voice_presets[speaker])
|
voice_preset = copy.deepcopy(_voice_presets[speaker])
|
||||||
|
|
||||||
_model.set_ddpm_inference_steps(num_steps=req.inference_steps)
|
steps = req.inference_steps if req.inference_steps is not None else _config["default_inference_steps"]
|
||||||
|
_model.set_ddpm_inference_steps(num_steps=steps)
|
||||||
|
|
||||||
inputs = _processor.process_input_with_cached_prompt(
|
inputs = _processor.process_input_with_cached_prompt(
|
||||||
text=req.text,
|
text=req.text,
|
||||||
@@ -339,13 +406,15 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
|||||||
if _model_status != "online":
|
if _model_status != "online":
|
||||||
detail = {
|
detail = {
|
||||||
"downloading": "Model is downloading — please wait.",
|
"downloading": "Model is downloading — please wait.",
|
||||||
"loading": "Model is loading into memory — please wait.",
|
"loading": "Model is loading into memory — please wait.",
|
||||||
"error": f"Model failed to load: {_model_error or 'unknown error'}",
|
"error": f"Model failed to load: {_model_error or 'unknown error'}",
|
||||||
}.get(_model_status, "Server not ready.")
|
}.get(_model_status, "Server not ready.")
|
||||||
raise HTTPException(status_code=503, detail=detail)
|
raise HTTPException(status_code=503, detail=detail)
|
||||||
|
|
||||||
if _generation_lock.locked():
|
if _generation_lock.locked():
|
||||||
raise HTTPException(status_code=503, detail="Server is already generating audio. Please wait.")
|
raise HTTPException(
|
||||||
|
status_code=503, detail="Server is already generating audio. Please wait."
|
||||||
|
)
|
||||||
|
|
||||||
async def event_stream() -> AsyncGenerator[str, None]:
|
async def event_stream() -> AsyncGenerator[str, None]:
|
||||||
from vibevoice.modular.streamer import AsyncAudioStreamer
|
from vibevoice.modular.streamer import AsyncAudioStreamer
|
||||||
@@ -354,6 +423,9 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
|||||||
streamer = AsyncAudioStreamer(batch_size=1)
|
streamer = AsyncAudioStreamer(batch_size=1)
|
||||||
cancel_event = threading.Event()
|
cancel_event = threading.Event()
|
||||||
|
|
||||||
|
accum_size = max(1, _config["chunk_accum"])
|
||||||
|
accumulated_chunks = []
|
||||||
|
|
||||||
async with _generation_lock:
|
async with _generation_lock:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
future = loop.run_in_executor(
|
future = loop.run_in_executor(
|
||||||
@@ -382,9 +454,18 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
|||||||
if chunk is None: # stop signal
|
if chunk is None: # stop signal
|
||||||
break
|
break
|
||||||
|
|
||||||
pcm_b64 = base64.b64encode(
|
accumulated_chunks.append(chunk.detach().cpu().float())
|
||||||
chunk.detach().cpu().float().numpy().tobytes()
|
|
||||||
).decode()
|
if len(accumulated_chunks) >= accum_size:
|
||||||
|
combined = torch.cat(accumulated_chunks, dim=0)
|
||||||
|
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
|
||||||
|
yield _sse({"type": "audio_chunk", "data": pcm_b64})
|
||||||
|
accumulated_chunks = []
|
||||||
|
|
||||||
|
# Flush any remaining chunks
|
||||||
|
if accumulated_chunks:
|
||||||
|
combined = torch.cat(accumulated_chunks, dim=0)
|
||||||
|
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
|
||||||
yield _sse({"type": "audio_chunk", "data": pcm_b64})
|
yield _sse({"type": "audio_chunk", "data": pcm_b64})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -395,7 +476,12 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
|||||||
return
|
return
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.exception("Generation failed: %s", exc)
|
logger.exception("Generation failed: %s", exc)
|
||||||
yield _sse({"type": "error", "message": "Internal server error during generation."})
|
yield _sse(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"message": "Internal server error during generation.",
|
||||||
|
}
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
elapsed = round(time.monotonic() - start, 1)
|
elapsed = round(time.monotonic() - start, 1)
|
||||||
|
|||||||
+81
-13
@@ -15,11 +15,23 @@ export interface DownloadProgress {
|
|||||||
total: number;
|
total: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface ServerConfig {
|
||||||
|
device: string;
|
||||||
|
chunk_accum: number;
|
||||||
|
prebuffer_secs: number;
|
||||||
|
rebuffer_threshold_secs: number;
|
||||||
|
resume_threshold_secs: number;
|
||||||
|
default_inference_steps: number;
|
||||||
|
}
|
||||||
|
|
||||||
interface AppState {
|
interface AppState {
|
||||||
script: string;
|
script: string;
|
||||||
speaker: string;
|
speaker: string;
|
||||||
cfgScale: number;
|
cfgScale: number;
|
||||||
inferenceSteps: number;
|
inferenceSteps: number;
|
||||||
|
prebufferSecs: number;
|
||||||
|
rebufferThresholdSecs: number;
|
||||||
|
resumeThresholdSecs: number;
|
||||||
isGenerating: boolean;
|
isGenerating: boolean;
|
||||||
genElapsed: number;
|
genElapsed: number;
|
||||||
genPct: number | null;
|
genPct: number | null;
|
||||||
@@ -28,6 +40,7 @@ interface AppState {
|
|||||||
serverStatus: ServerStatus;
|
serverStatus: ServerStatus;
|
||||||
downloadProgress: DownloadProgress | null;
|
downloadProgress: DownloadProgress | null;
|
||||||
availableVoices: string[];
|
availableVoices: string[];
|
||||||
|
serverConfig: ServerConfig | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
type AppAction =
|
type AppAction =
|
||||||
@@ -35,6 +48,9 @@ type AppAction =
|
|||||||
| { type: "SET_SPEAKER"; payload: string }
|
| { type: "SET_SPEAKER"; payload: string }
|
||||||
| { type: "SET_CFG_SCALE"; payload: number }
|
| { type: "SET_CFG_SCALE"; payload: number }
|
||||||
| { type: "SET_INFERENCE_STEPS"; payload: number }
|
| { type: "SET_INFERENCE_STEPS"; payload: number }
|
||||||
|
| { type: "SET_PREBUFFER_SECS"; payload: number }
|
||||||
|
| { type: "SET_REBUFFER_THRESHOLD"; payload: number }
|
||||||
|
| { type: "SET_RESUME_THRESHOLD"; payload: number }
|
||||||
| { type: "START_GENERATION" }
|
| { type: "START_GENERATION" }
|
||||||
| { type: "GEN_PROGRESS"; elapsed: number; pct: number | null }
|
| { type: "GEN_PROGRESS"; elapsed: number; pct: number | null }
|
||||||
| { type: "GENERATION_SUCCESS"; payload: string }
|
| { type: "GENERATION_SUCCESS"; payload: string }
|
||||||
@@ -43,7 +59,12 @@ type AppAction =
|
|||||||
| { type: "ADD_LOG"; payload: string }
|
| { type: "ADD_LOG"; payload: string }
|
||||||
| {
|
| {
|
||||||
type: "SET_SERVER_STATUS";
|
type: "SET_SERVER_STATUS";
|
||||||
payload: { status: ServerStatus; progress?: DownloadProgress | null; voices?: string[] };
|
payload: {
|
||||||
|
status: ServerStatus;
|
||||||
|
progress?: DownloadProgress | null;
|
||||||
|
voices?: string[];
|
||||||
|
config?: ServerConfig | null;
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
function reducer(state: AppState, action: AppAction): AppState {
|
function reducer(state: AppState, action: AppAction): AppState {
|
||||||
@@ -52,6 +73,9 @@ function reducer(state: AppState, action: AppAction): AppState {
|
|||||||
case "SET_SPEAKER": return { ...state, speaker: action.payload };
|
case "SET_SPEAKER": return { ...state, speaker: action.payload };
|
||||||
case "SET_CFG_SCALE": return { ...state, cfgScale: action.payload };
|
case "SET_CFG_SCALE": return { ...state, cfgScale: action.payload };
|
||||||
case "SET_INFERENCE_STEPS": return { ...state, inferenceSteps: action.payload };
|
case "SET_INFERENCE_STEPS": return { ...state, inferenceSteps: action.payload };
|
||||||
|
case "SET_PREBUFFER_SECS": return { ...state, prebufferSecs: action.payload };
|
||||||
|
case "SET_REBUFFER_THRESHOLD": return { ...state, rebufferThresholdSecs: action.payload };
|
||||||
|
case "SET_RESUME_THRESHOLD": return { ...state, resumeThresholdSecs: action.payload };
|
||||||
case "START_GENERATION":
|
case "START_GENERATION":
|
||||||
return { ...state, isGenerating: true, audioUrl: null, logs: [], genElapsed: 0, genPct: null };
|
return { ...state, isGenerating: true, audioUrl: null, logs: [], genElapsed: 0, genPct: null };
|
||||||
case "GEN_PROGRESS":
|
case "GEN_PROGRESS":
|
||||||
@@ -63,14 +87,40 @@ function reducer(state: AppState, action: AppAction): AppState {
|
|||||||
return { ...state, isGenerating: false, genElapsed: 0, genPct: null };
|
return { ...state, isGenerating: false, genElapsed: 0, genPct: null };
|
||||||
case "ADD_LOG":
|
case "ADD_LOG":
|
||||||
return { ...state, logs: [...state.logs, action.payload] };
|
return { ...state, logs: [...state.logs, action.payload] };
|
||||||
case "SET_SERVER_STATUS":
|
case "SET_SERVER_STATUS": {
|
||||||
|
const isNewConfig = !state.serverConfig && action.payload.config;
|
||||||
|
const deviceChanged = !!(state.serverConfig && action.payload.config && state.serverConfig.device !== action.payload.config.device);
|
||||||
|
|
||||||
|
const nextSteps = (isNewConfig || deviceChanged)
|
||||||
|
? action.payload.config!.default_inference_steps
|
||||||
|
: state.inferenceSteps;
|
||||||
|
|
||||||
|
const nextPrebuffer = (isNewConfig || deviceChanged)
|
||||||
|
? action.payload.config!.prebuffer_secs
|
||||||
|
: state.prebufferSecs;
|
||||||
|
|
||||||
|
const nextRebuffer = (isNewConfig || deviceChanged)
|
||||||
|
? action.payload.config!.rebuffer_threshold_secs
|
||||||
|
: state.rebufferThresholdSecs;
|
||||||
|
|
||||||
|
const nextResume = (isNewConfig || deviceChanged)
|
||||||
|
? action.payload.config!.resume_threshold_secs
|
||||||
|
: state.resumeThresholdSecs;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
...state,
|
...state,
|
||||||
serverStatus: action.payload.status,
|
serverStatus: action.payload.status,
|
||||||
downloadProgress: action.payload.progress ?? null,
|
downloadProgress: action.payload.progress ?? null,
|
||||||
availableVoices:
|
availableVoices: action.payload.voices?.length
|
||||||
action.payload.voices?.length ? action.payload.voices : state.availableVoices,
|
? action.payload.voices
|
||||||
|
: state.availableVoices,
|
||||||
|
serverConfig: action.payload.config ?? state.serverConfig,
|
||||||
|
inferenceSteps: nextSteps,
|
||||||
|
prebufferSecs: nextPrebuffer,
|
||||||
|
rebufferThresholdSecs: nextRebuffer,
|
||||||
|
resumeThresholdSecs: nextResume,
|
||||||
};
|
};
|
||||||
|
}
|
||||||
default: return state;
|
default: return state;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -80,6 +130,9 @@ const initialState: AppState = {
|
|||||||
speaker: "carter",
|
speaker: "carter",
|
||||||
cfgScale: 1.5,
|
cfgScale: 1.5,
|
||||||
inferenceSteps: 10,
|
inferenceSteps: 10,
|
||||||
|
prebufferSecs: 2.0,
|
||||||
|
rebufferThresholdSecs: 0.4,
|
||||||
|
resumeThresholdSecs: 1.5,
|
||||||
isGenerating: false,
|
isGenerating: false,
|
||||||
genElapsed: 0,
|
genElapsed: 0,
|
||||||
genPct: null,
|
genPct: null,
|
||||||
@@ -88,6 +141,7 @@ const initialState: AppState = {
|
|||||||
serverStatus: "offline",
|
serverStatus: "offline",
|
||||||
downloadProgress: null,
|
downloadProgress: null,
|
||||||
availableVoices: [],
|
availableVoices: [],
|
||||||
|
serverConfig: null,
|
||||||
};
|
};
|
||||||
|
|
||||||
export default function HomePage() {
|
export default function HomePage() {
|
||||||
@@ -106,19 +160,16 @@ export default function HomePage() {
|
|||||||
const handleGenerationCancel = useCallback(() => dispatch({ type: "GENERATION_CANCELLED" }), []);
|
const handleGenerationCancel = useCallback(() => dispatch({ type: "GENERATION_CANCELLED" }), []);
|
||||||
const handleGenerationError = useCallback(() => dispatch({ type: "GENERATION_ERROR" }), []);
|
const handleGenerationError = useCallback(() => dispatch({ type: "GENERATION_ERROR" }), []);
|
||||||
|
|
||||||
const {
|
const { generate, pauseStream, resumeStream, stop, isStreamPaused } = useStreamingGeneration({
|
||||||
generate,
|
|
||||||
pauseStream,
|
|
||||||
resumeStream,
|
|
||||||
stop,
|
|
||||||
isStreamPaused,
|
|
||||||
} = useStreamingGeneration({
|
|
||||||
onLog: addLog,
|
onLog: addLog,
|
||||||
onStart: handleGenerationStart,
|
onStart: handleGenerationStart,
|
||||||
onProgress: handleGenerationProgress,
|
onProgress: handleGenerationProgress,
|
||||||
onSuccess: handleGenerationSuccess,
|
onSuccess: handleGenerationSuccess,
|
||||||
onCancel: handleGenerationCancel,
|
onCancel: handleGenerationCancel,
|
||||||
onError: handleGenerationError,
|
onError: handleGenerationError,
|
||||||
|
prebufferSecs: state.prebufferSecs,
|
||||||
|
rebufferThresholdSecs: state.rebufferThresholdSecs,
|
||||||
|
resumeThresholdSecs: state.resumeThresholdSecs,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Server health polling — fast while not ready, slow when online
|
// Server health polling — fast while not ready, slow when online
|
||||||
@@ -131,21 +182,32 @@ export default function HomePage() {
|
|||||||
let nextStatus: ServerStatus = "offline";
|
let nextStatus: ServerStatus = "offline";
|
||||||
let nextProgress: DownloadProgress | null = null;
|
let nextProgress: DownloadProgress | null = null;
|
||||||
let nextVoices: string[] = [];
|
let nextVoices: string[] = [];
|
||||||
|
let nextConfig: ServerConfig | null = null;
|
||||||
try {
|
try {
|
||||||
const res = await fetch("/api/health", { cache: "no-store" });
|
const res = await fetch("/api/health", { cache: "no-store" });
|
||||||
const data = await res.json() as {
|
const data = (await res.json()) as {
|
||||||
status: ServerStatus;
|
status: ServerStatus;
|
||||||
progress?: DownloadProgress | null;
|
progress?: DownloadProgress | null;
|
||||||
voices?: string[];
|
voices?: string[];
|
||||||
|
config?: ServerConfig;
|
||||||
};
|
};
|
||||||
nextStatus = data.status ?? "offline";
|
nextStatus = data.status ?? "offline";
|
||||||
nextProgress = data.progress ?? null;
|
nextProgress = data.progress ?? null;
|
||||||
nextVoices = data.voices ?? [];
|
nextVoices = data.voices ?? [];
|
||||||
|
nextConfig = data.config ?? null;
|
||||||
} catch {
|
} catch {
|
||||||
nextStatus = "offline";
|
nextStatus = "offline";
|
||||||
}
|
}
|
||||||
if (!cancelled) {
|
if (!cancelled) {
|
||||||
dispatch({ type: "SET_SERVER_STATUS", payload: { status: nextStatus, progress: nextProgress, voices: nextVoices } });
|
dispatch({
|
||||||
|
type: "SET_SERVER_STATUS",
|
||||||
|
payload: {
|
||||||
|
status: nextStatus,
|
||||||
|
progress: nextProgress,
|
||||||
|
voices: nextVoices,
|
||||||
|
config: nextConfig,
|
||||||
|
},
|
||||||
|
});
|
||||||
timeoutId = setTimeout(poll, nextStatus === "online" ? 15_000 : 2_000);
|
timeoutId = setTimeout(poll, nextStatus === "online" ? 15_000 : 2_000);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -199,6 +261,12 @@ export default function HomePage() {
|
|||||||
onCfgScaleChange={(v) => dispatch({ type: "SET_CFG_SCALE", payload: v })}
|
onCfgScaleChange={(v) => dispatch({ type: "SET_CFG_SCALE", payload: v })}
|
||||||
inferenceSteps={state.inferenceSteps}
|
inferenceSteps={state.inferenceSteps}
|
||||||
onInferenceStepsChange={(v) => dispatch({ type: "SET_INFERENCE_STEPS", payload: v })}
|
onInferenceStepsChange={(v) => dispatch({ type: "SET_INFERENCE_STEPS", payload: v })}
|
||||||
|
prebufferSecs={state.prebufferSecs}
|
||||||
|
onPrebufferSecsChange={(v) => dispatch({ type: "SET_PREBUFFER_SECS", payload: v })}
|
||||||
|
rebufferThresholdSecs={state.rebufferThresholdSecs}
|
||||||
|
onRebufferThresholdChange={(v) => dispatch({ type: "SET_REBUFFER_THRESHOLD", payload: v })}
|
||||||
|
resumeThresholdSecs={state.resumeThresholdSecs}
|
||||||
|
onResumeThresholdChange={(v) => dispatch({ type: "SET_RESUME_THRESHOLD", payload: v })}
|
||||||
onGenerate={handleGenerate}
|
onGenerate={handleGenerate}
|
||||||
onStop={stop}
|
onStop={stop}
|
||||||
onPauseStream={pauseStream}
|
onPauseStream={pauseStream}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
|
import { useState } from "react";
|
||||||
import type { ServerStatus, DownloadProgress } from "@/app/page";
|
import type { ServerStatus, DownloadProgress } from "@/app/page";
|
||||||
|
|
||||||
const FALLBACK_VOICES = ["carter", "davis", "emma", "frank", "grace", "mike"];
|
const FALLBACK_VOICES = ["carter", "davis", "emma", "frank", "grace", "mike"];
|
||||||
@@ -12,6 +13,12 @@ interface GenerationControlsProps {
|
|||||||
onCfgScaleChange: (v: number) => void;
|
onCfgScaleChange: (v: number) => void;
|
||||||
inferenceSteps: number;
|
inferenceSteps: number;
|
||||||
onInferenceStepsChange: (v: number) => void;
|
onInferenceStepsChange: (v: number) => void;
|
||||||
|
prebufferSecs: number;
|
||||||
|
onPrebufferSecsChange: (v: number) => void;
|
||||||
|
rebufferThresholdSecs: number;
|
||||||
|
onRebufferThresholdChange: (v: number) => void;
|
||||||
|
resumeThresholdSecs: number;
|
||||||
|
onResumeThresholdChange: (v: number) => void;
|
||||||
onGenerate: () => void;
|
onGenerate: () => void;
|
||||||
onStop: () => void;
|
onStop: () => void;
|
||||||
onPauseStream: () => void;
|
onPauseStream: () => void;
|
||||||
@@ -53,6 +60,12 @@ export default function GenerationControls({
|
|||||||
onCfgScaleChange,
|
onCfgScaleChange,
|
||||||
inferenceSteps,
|
inferenceSteps,
|
||||||
onInferenceStepsChange,
|
onInferenceStepsChange,
|
||||||
|
prebufferSecs,
|
||||||
|
onPrebufferSecsChange,
|
||||||
|
rebufferThresholdSecs,
|
||||||
|
onRebufferThresholdChange,
|
||||||
|
resumeThresholdSecs,
|
||||||
|
onResumeThresholdChange,
|
||||||
onGenerate,
|
onGenerate,
|
||||||
onStop,
|
onStop,
|
||||||
onPauseStream,
|
onPauseStream,
|
||||||
@@ -65,6 +78,7 @@ export default function GenerationControls({
|
|||||||
serverStatus,
|
serverStatus,
|
||||||
downloadProgress,
|
downloadProgress,
|
||||||
}: GenerationControlsProps) {
|
}: GenerationControlsProps) {
|
||||||
|
const [showAdvanced, setShowAdvanced] = useState(false);
|
||||||
const voices = availableVoices.length > 0 ? availableVoices : FALLBACK_VOICES;
|
const voices = availableVoices.length > 0 ? availableVoices : FALLBACK_VOICES;
|
||||||
const serverReady = serverStatus === "online";
|
const serverReady = serverStatus === "online";
|
||||||
const buttonDisabled = isGenerating || wordCount === 0 || !serverReady;
|
const buttonDisabled = isGenerating || wordCount === 0 || !serverReady;
|
||||||
@@ -169,6 +183,108 @@ export default function GenerationControls({
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
{/* Advanced Buffering toggle */}
|
||||||
|
<div className="pt-2">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
onClick={() => setShowAdvanced(!showAdvanced)}
|
||||||
|
aria-expanded={showAdvanced}
|
||||||
|
aria-controls="advanced-buffering-panel"
|
||||||
|
className="flex items-center gap-2 text-xs font-semibold uppercase tracking-wider cursor-pointer transition-colors"
|
||||||
|
style={{ color: showAdvanced ? "var(--accent-teal)" : "var(--muted)" }}
|
||||||
|
>
|
||||||
|
<svg
|
||||||
|
className={`w-3 h-3 transition-transform ${showAdvanced ? "rotate-90" : ""}`}
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
fill="none"
|
||||||
|
stroke="currentColor"
|
||||||
|
strokeWidth="3"
|
||||||
|
>
|
||||||
|
<polyline points="9 18 15 12 9 6" />
|
||||||
|
</svg>
|
||||||
|
Advanced Buffering
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{showAdvanced && (
|
||||||
|
<div id="advanced-buffering-panel" className="flex flex-col gap-4 pl-2 border-l" style={{ borderColor: "var(--border)" }}>
|
||||||
|
{/* Pre-buffer */}
|
||||||
|
<div className="flex flex-col gap-2">
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<label className="text-xs font-medium" style={{ color: "var(--foreground)" }}>
|
||||||
|
Initial Pre-buffer
|
||||||
|
</label>
|
||||||
|
<span className="text-xs font-mono" style={{ color: "var(--accent-teal)" }}>
|
||||||
|
{prebufferSecs.toFixed(1)}s
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
<input
|
||||||
|
type="range"
|
||||||
|
min={0.5}
|
||||||
|
max={10.0}
|
||||||
|
step={0.5}
|
||||||
|
value={prebufferSecs}
|
||||||
|
onChange={(e) => onPrebufferSecsChange(parseFloat(e.target.value))}
|
||||||
|
className="w-full h-1"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Re-buffer threshold */}
|
||||||
|
<div className="flex flex-col gap-2">
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<label htmlFor="rebuffer-threshold" className="text-xs font-medium" style={{ color: "var(--foreground)" }}>
|
||||||
|
Re-buffer Threshold
|
||||||
|
</label>
|
||||||
|
<span className="text-xs font-mono" style={{ color: "var(--accent-teal)" }}>
|
||||||
|
{rebufferThresholdSecs.toFixed(1)}s
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
<input
|
||||||
|
id="rebuffer-threshold"
|
||||||
|
type="range"
|
||||||
|
min={0.1}
|
||||||
|
max={3.0}
|
||||||
|
step={0.1}
|
||||||
|
value={rebufferThresholdSecs}
|
||||||
|
onChange={(e) => {
|
||||||
|
const next = parseFloat(e.target.value);
|
||||||
|
onRebufferThresholdChange(next);
|
||||||
|
if (resumeThresholdSecs <= next) {
|
||||||
|
onResumeThresholdChange(parseFloat((next + 0.5).toFixed(1)));
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
className="w-full h-1"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Resume threshold */}
|
||||||
|
<div className="flex flex-col gap-2">
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<label htmlFor="resume-threshold" className="text-xs font-medium" style={{ color: "var(--foreground)" }}>
|
||||||
|
Resume Threshold
|
||||||
|
</label>
|
||||||
|
<span className="text-xs font-mono" style={{ color: "var(--accent-teal)" }}>
|
||||||
|
{resumeThresholdSecs.toFixed(1)}s
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
<input
|
||||||
|
id="resume-threshold"
|
||||||
|
type="range"
|
||||||
|
min={0.5}
|
||||||
|
max={5.0}
|
||||||
|
step={0.1}
|
||||||
|
value={resumeThresholdSecs}
|
||||||
|
onChange={(e) => {
|
||||||
|
const next = parseFloat(e.target.value);
|
||||||
|
if (next <= rebufferThresholdSecs) return;
|
||||||
|
onResumeThresholdChange(next);
|
||||||
|
}}
|
||||||
|
className="w-full h-1"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
{/* Server status banner */}
|
{/* Server status banner */}
|
||||||
{!serverReady && (
|
{!serverReady && (
|
||||||
<div
|
<div
|
||||||
@@ -177,7 +293,7 @@ export default function GenerationControls({
|
|||||||
>
|
>
|
||||||
<div className="flex items-center gap-2">
|
<div className="flex items-center gap-2">
|
||||||
<span
|
<span
|
||||||
className={`w-2 h-2 rounded-full flex-shrink-0 ${serverStatus === "offline" || serverStatus === "error" ? "" : "animate-pulse"}`}
|
className={`w-2 h-2 rounded-full shrink-0 ${serverStatus === "offline" || serverStatus === "error" ? "" : "animate-pulse"}`}
|
||||||
style={{ background: STATUS_CONFIG[serverStatus].color }}
|
style={{ background: STATUS_CONFIG[serverStatus].color }}
|
||||||
/>
|
/>
|
||||||
<span style={{ color: STATUS_CONFIG[serverStatus].color }}>
|
<span style={{ color: STATUS_CONFIG[serverStatus].color }}>
|
||||||
|
|||||||
@@ -3,9 +3,9 @@
|
|||||||
import { useCallback, useEffect, useRef, useState } from "react";
|
import { useCallback, useEffect, useRef, useState } from "react";
|
||||||
|
|
||||||
const SAMPLE_RATE = 24_000;
|
const SAMPLE_RATE = 24_000;
|
||||||
const PREBUFFER_SECS = 2.0;
|
const DEFAULT_PREBUFFER_SECS = 2.0;
|
||||||
const REBUFFER_THRESHOLD_SECS = 0.4;
|
const DEFAULT_REBUFFER_THRESHOLD_SECS = 0.4;
|
||||||
const RESUME_THRESHOLD_SECS = 1.5;
|
const DEFAULT_RESUME_THRESHOLD_SECS = 1.5;
|
||||||
|
|
||||||
interface GenerateOptions {
|
interface GenerateOptions {
|
||||||
text: string;
|
text: string;
|
||||||
@@ -21,6 +21,12 @@ interface UseStreamingGenerationOptions {
|
|||||||
onSuccess: (audioUrl: string) => void;
|
onSuccess: (audioUrl: string) => void;
|
||||||
onCancel: () => void;
|
onCancel: () => void;
|
||||||
onError: () => void;
|
onError: () => void;
|
||||||
|
/** Seconds of audio to buffer before playback starts. */
|
||||||
|
prebufferSecs?: number;
|
||||||
|
/** Buffer lookahead (seconds) below which playback suspends to refill. */
|
||||||
|
rebufferThresholdSecs?: number;
|
||||||
|
/** Buffer lookahead (seconds) at or above which suspended playback resumes. Must be > rebufferThresholdSecs. */
|
||||||
|
resumeThresholdSecs?: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
function mergeFloat32Arrays(chunks: Float32Array<ArrayBuffer>[]): Float32Array<ArrayBuffer> {
|
function mergeFloat32Arrays(chunks: Float32Array<ArrayBuffer>[]): Float32Array<ArrayBuffer> {
|
||||||
@@ -77,7 +83,18 @@ export function useStreamingGeneration({
|
|||||||
onSuccess,
|
onSuccess,
|
||||||
onCancel,
|
onCancel,
|
||||||
onError,
|
onError,
|
||||||
|
prebufferSecs = DEFAULT_PREBUFFER_SECS,
|
||||||
|
rebufferThresholdSecs: rawRebufferThresholdSecs = DEFAULT_REBUFFER_THRESHOLD_SECS,
|
||||||
|
resumeThresholdSecs: rawResumeThresholdSecs = DEFAULT_RESUME_THRESHOLD_SECS,
|
||||||
}: UseStreamingGenerationOptions) {
|
}: UseStreamingGenerationOptions) {
|
||||||
|
let rebufferThresholdSecs = rawRebufferThresholdSecs;
|
||||||
|
let resumeThresholdSecs = rawResumeThresholdSecs;
|
||||||
|
if (resumeThresholdSecs <= rebufferThresholdSecs) {
|
||||||
|
console.warn(
|
||||||
|
`[useStreamingGeneration] resumeThresholdSecs (${resumeThresholdSecs}) must be greater than rebufferThresholdSecs (${rebufferThresholdSecs}). Clamping resumeThresholdSecs to ${rebufferThresholdSecs + 0.5}.`,
|
||||||
|
);
|
||||||
|
resumeThresholdSecs = rebufferThresholdSecs + 0.5;
|
||||||
|
}
|
||||||
const [isStreamPaused, setIsStreamPaused] = useState(false);
|
const [isStreamPaused, setIsStreamPaused] = useState(false);
|
||||||
const abortRef = useRef<AbortController | null>(null);
|
const abortRef = useRef<AbortController | null>(null);
|
||||||
const audioCtxRef = useRef<AudioContext | null>(null);
|
const audioCtxRef = useRef<AudioContext | null>(null);
|
||||||
@@ -144,7 +161,7 @@ export function useStreamingGeneration({
|
|||||||
|
|
||||||
if (!hasStartedPlaybackRef.current) {
|
if (!hasStartedPlaybackRef.current) {
|
||||||
const bufferedSecs = chunksRef.current.reduce((sum, c) => sum + c.length, 0) / SAMPLE_RATE;
|
const bufferedSecs = chunksRef.current.reduce((sum, c) => sum + c.length, 0) / SAMPLE_RATE;
|
||||||
if (bufferedSecs >= PREBUFFER_SECS) {
|
if (bufferedSecs >= prebufferSecs) {
|
||||||
flushBufferedAudio();
|
flushBufferedAudio();
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
@@ -154,18 +171,18 @@ export function useStreamingGeneration({
|
|||||||
if (isUserPausedRef.current) return;
|
if (isUserPausedRef.current) return;
|
||||||
|
|
||||||
const ahead = nextStartTimeRef.current - ctx.currentTime;
|
const ahead = nextStartTimeRef.current - ctx.currentTime;
|
||||||
if (ctx.state === "running" && ahead < REBUFFER_THRESHOLD_SECS) {
|
if (ctx.state === "running" && ahead < rebufferThresholdSecs) {
|
||||||
ctx.suspend().catch(() => {});
|
ctx.suspend().catch(() => {});
|
||||||
isAutoBufferingRef.current = true;
|
isAutoBufferingRef.current = true;
|
||||||
} else if (
|
} else if (
|
||||||
ctx.state === "suspended" &&
|
ctx.state === "suspended" &&
|
||||||
isAutoBufferingRef.current &&
|
isAutoBufferingRef.current &&
|
||||||
ahead >= RESUME_THRESHOLD_SECS
|
ahead >= resumeThresholdSecs
|
||||||
) {
|
) {
|
||||||
ctx.resume().catch(() => {});
|
ctx.resume().catch(() => {});
|
||||||
isAutoBufferingRef.current = false;
|
isAutoBufferingRef.current = false;
|
||||||
}
|
}
|
||||||
}, [enqueue, flushBufferedAudio]);
|
}, [enqueue, flushBufferedAudio, prebufferSecs, rebufferThresholdSecs, resumeThresholdSecs]);
|
||||||
|
|
||||||
const generate = useCallback(async (options: GenerateOptions) => {
|
const generate = useCallback(async (options: GenerateOptions) => {
|
||||||
if (!options.text.trim()) return;
|
if (!options.text.trim()) return;
|
||||||
|
|||||||
Reference in New Issue
Block a user