diff --git a/.gitignore b/.gitignore index ca177b1..13db197 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,4 @@ web/node_modules/ # OS .DS_Store Thumbs.db +.vscode/settings.json diff --git a/package.json b/package.json index 90fb2b5..034f905 100644 --- a/package.json +++ b/package.json @@ -3,12 +3,12 @@ "version": "0.1.0", "private": true, "scripts": { + "build": "pnpm --filter vibepod-web build", "dev": "bash dev.sh", "dev:cpu": "bash dev.sh --cpu", "dev:server": "bash server/start.sh", "dev:server:cpu": "bash server/start.sh --cpu", - "dev:web": "pnpm --filter vibepod-web dev", - "build": "pnpm --filter vibepod-web build" + "dev:web": "pnpm --filter vibepod-web dev" }, "packageManager": "pnpm@10.33.2+sha512.a90faf6feeab71ad6c6e57f94e0fe1a12f5dcc22cd754db40ae9593eb6a3e0b6b12e3540218bb37ae083404b1f2ce6db2a4121e979829b4aff94b99f49da1cf8" } diff --git a/server/download_model.py b/server/download_model.py index bf11927..1377d17 100644 --- a/server/download_model.py +++ b/server/download_model.py @@ -36,7 +36,9 @@ def download() -> str: ) sys.exit(1) - token: str | None = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN") + token: str | None = os.environ.get("HF_TOKEN") or os.environ.get( + "HUGGINGFACE_TOKEN" + ) print(f"Checking / downloading model: {MODEL_ID}") print("(This may take several minutes on first run — the model is ~1 GB)") diff --git a/server/vibevoice_server.py b/server/vibevoice_server.py index 9523eba..f5bc012 100644 --- a/server/vibevoice_server.py +++ b/server/vibevoice_server.py @@ -52,11 +52,11 @@ VOICE_BASE_URL = ( EN_VOICES: dict[str, str] = { "carter": "en-Carter_man.pt", - "davis": "en-Davis_man.pt", - "emma": "en-Emma_woman.pt", - "frank": "en-Frank_man.pt", - "grace": "en-Grace_woman.pt", - "mike": "en-Mike_man.pt", + "davis": "en-Davis_man.pt", + "emma": "en-Emma_woman.pt", + "frank": "en-Frank_man.pt", + "grace": "en-Grace_woman.pt", + "mike": "en-Mike_man.pt", } DEFAULT_SPEAKER = "carter" @@ -66,6 +66,7 @@ _IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.o # VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag. # Falls back to auto-detection if not set. + def _resolve_device() -> str: """Resolve the target device from env var or auto-detect.""" env = os.environ.get("VIBEPOD_DEVICE", "").strip().lower() @@ -80,6 +81,31 @@ def _resolve_device() -> str: return "cuda" if torch.cuda.is_available() else "cpu" +# ── Env-var helpers ───────────────────────────────────────────────────────────── + + +def _env_int(name: str, default: int) -> int: + raw = os.environ.get(name, "").strip() + if not raw: + return default + try: + return int(raw) + except ValueError: + logger.warning("Invalid value for %s=%r — using default %d", name, raw, default) + return default + + +def _env_float(name: str, default: float) -> float: + raw = os.environ.get(name, "").strip() + if not raw: + return default + try: + return float(raw) + except ValueError: + logger.warning("Invalid value for %s=%r — using default %g", name, raw, default) + return default + + # ── Global state ──────────────────────────────────────────────────────────────── ModelStatus = Literal["downloading", "loading", "online", "error"] @@ -93,13 +119,24 @@ _voice_presets: dict[str, object] = {} _load_lock = threading.Lock() _generation_lock = asyncio.Lock() +# Config defaults (can be overridden by env vars) +# These are populated in _load_model_sync once the device is known. +_config = { + "device": "cpu", + "chunk_accum": 1, + "prebuffer_secs": 2.0, + "rebuffer_threshold_secs": 0.4, + "resume_threshold_secs": 1.5, + "default_inference_steps": 10, +} + # Download progress (files downloaded so far) _dl_progress: dict[str, int] = {"done": 0, "total": 0} - # ── Progress-tracking tqdm (for model file downloads) ────────────────────────── + def _make_dl_tqdm() -> type: class _DlTqdm(_BaseTqdm): def __init__(self, *args: object, **kwargs: object) -> None: @@ -119,10 +156,14 @@ def _make_dl_tqdm() -> type: # ── Model / voice helpers ─────────────────────────────────────────────────────── + def _is_model_cached() -> bool: try: from huggingface_hub import snapshot_download - snapshot_download(MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS) + + snapshot_download( + MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS + ) return True except Exception: return False @@ -130,7 +171,10 @@ def _is_model_cached() -> bool: def _download_model() -> None: from huggingface_hub import snapshot_download - token: Optional[str] = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN") + + token: Optional[str] = os.environ.get("HF_TOKEN") or os.environ.get( + "HUGGINGFACE_TOKEN" + ) DlTqdm = _make_dl_tqdm() logger.info("Model not cached — downloading %s...", MODEL_ID) snapshot_download( @@ -155,11 +199,13 @@ def _download_voices() -> None: # ── Background model loader ───────────────────────────────────────────────────── + def _init_processor(): logger.info("Loading processor...") from vibevoice.processor.vibevoice_streaming_processor import ( VibeVoiceStreamingProcessor, ) + return VibeVoiceStreamingProcessor.from_pretrained(MODEL_ID) @@ -171,6 +217,7 @@ def _init_model(device: str): from vibevoice.modular.modeling_vibevoice_streaming_inference import ( VibeVoiceStreamingForConditionalGenerationInference, ) + try: model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( MODEL_ID, @@ -179,7 +226,9 @@ def _init_model(device: str): attn_implementation=attn_impl, ) except Exception: - logger.warning("Model load with %s failed; falling back to sdpa", attn_impl, exc_info=True) + logger.warning( + "Model load with %s failed; falling back to sdpa", attn_impl, exc_info=True + ) model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( MODEL_ID, torch_dtype=load_dtype, @@ -188,7 +237,7 @@ def _init_model(device: str): ) model.eval() - model.set_ddpm_inference_steps(num_steps=10) + model.set_ddpm_inference_steps(num_steps=_config["default_inference_steps"]) return model @@ -197,14 +246,12 @@ def _load_voice_presets(device: str) -> dict[str, object]: for name, filename in EN_VOICES.items(): path = VOICES_DIR / filename if path.exists(): - presets[name] = torch.load( - path, map_location=device, weights_only=False - ) + presets[name] = torch.load(path, map_location=device, weights_only=False) return presets def _load_model_sync() -> None: - global _processor, _model, _device, _model_status, _model_error, _voice_presets + global _processor, _model, _device, _model_status, _model_error, _voice_presets, _config with _load_lock: if _model is not None: @@ -222,12 +269,24 @@ def _load_model_sync() -> None: _device = _resolve_device() logger.info("Using device: %s", _device) + # Populate config based on device + is_cpu = _device == "cpu" + _config["device"] = _device + _config["chunk_accum"] = _env_int("VIBEPOD_CHUNK_ACCUM", 4 if is_cpu else 1) + _config["prebuffer_secs"] = _env_float("VIBEPOD_PREBUFFER_SECS", 5.0 if is_cpu else 2.0) + _config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 1.0 if is_cpu else 0.4) + _config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 2.5 if is_cpu else 1.5) + _config["default_inference_steps"] = _env_int("VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10) + _processor = _init_processor() _model = _init_model(_device) _voice_presets = _load_voice_presets(_device) _model_status = "online" - logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys())) + logger.info( + "Model ready on %s. Voices: %s", _device, list(_voice_presets.keys()) + ) + logger.info("Configuration: %s", _config) except Exception as exc: _model_status = "error" @@ -237,6 +296,7 @@ def _load_model_sync() -> None: # ── FastAPI app ───────────────────────────────────────────────────────────────── + @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: thread = threading.Thread(target=_load_model_sync, daemon=True, name="model-loader") @@ -249,11 +309,12 @@ app = FastAPI(title="VibePod TTS Server", version="0.1.0", lifespan=lifespan) # ── Schemas ───────────────────────────────────────────────────────────────────── + class GenerateRequest(BaseModel): text: str = Field(..., min_length=1, max_length=10_000) speaker: str = Field(default=DEFAULT_SPEAKER) cfg_scale: float = Field(default=1.5, ge=0.5, le=4.0) - inference_steps: int = Field(default=10, ge=5, le=20) + inference_steps: Optional[int] = Field(default=None, ge=5, le=20) @field_validator("text") @classmethod @@ -270,6 +331,7 @@ class GenerateRequest(BaseModel): # ── Endpoints ─────────────────────────────────────────────────────────────────── + @app.get("/health") async def health() -> dict: body: dict = { @@ -277,9 +339,13 @@ async def health() -> dict: "model": MODEL_ID, "device": _device, "voices": list(_voice_presets.keys()), + "config": _config, } if _model_status == "downloading": - body["progress"] = {"done": _dl_progress["done"], "total": _dl_progress["total"]} + body["progress"] = { + "done": _dl_progress["done"], + "total": _dl_progress["total"], + } if _model_error: body["message"] = _model_error return body @@ -300,7 +366,8 @@ def _sync_generate( speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER voice_preset = copy.deepcopy(_voice_presets[speaker]) - _model.set_ddpm_inference_steps(num_steps=req.inference_steps) + steps = req.inference_steps if req.inference_steps is not None else _config["default_inference_steps"] + _model.set_ddpm_inference_steps(num_steps=steps) inputs = _processor.process_input_with_cached_prompt( text=req.text, @@ -339,13 +406,15 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: if _model_status != "online": detail = { "downloading": "Model is downloading — please wait.", - "loading": "Model is loading into memory — please wait.", - "error": f"Model failed to load: {_model_error or 'unknown error'}", + "loading": "Model is loading into memory — please wait.", + "error": f"Model failed to load: {_model_error or 'unknown error'}", }.get(_model_status, "Server not ready.") raise HTTPException(status_code=503, detail=detail) if _generation_lock.locked(): - raise HTTPException(status_code=503, detail="Server is already generating audio. Please wait.") + raise HTTPException( + status_code=503, detail="Server is already generating audio. Please wait." + ) async def event_stream() -> AsyncGenerator[str, None]: from vibevoice.modular.streamer import AsyncAudioStreamer @@ -354,6 +423,9 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: streamer = AsyncAudioStreamer(batch_size=1) cancel_event = threading.Event() + accum_size = max(1, _config["chunk_accum"]) + accumulated_chunks = [] + async with _generation_lock: loop = asyncio.get_event_loop() future = loop.run_in_executor( @@ -382,9 +454,18 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: if chunk is None: # stop signal break - pcm_b64 = base64.b64encode( - chunk.detach().cpu().float().numpy().tobytes() - ).decode() + accumulated_chunks.append(chunk.detach().cpu().float()) + + if len(accumulated_chunks) >= accum_size: + combined = torch.cat(accumulated_chunks, dim=0) + pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode() + yield _sse({"type": "audio_chunk", "data": pcm_b64}) + accumulated_chunks = [] + + # Flush any remaining chunks + if accumulated_chunks: + combined = torch.cat(accumulated_chunks, dim=0) + pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode() yield _sse({"type": "audio_chunk", "data": pcm_b64}) try: @@ -395,7 +476,12 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: return except Exception as exc: logger.exception("Generation failed: %s", exc) - yield _sse({"type": "error", "message": "Internal server error during generation."}) + yield _sse( + { + "type": "error", + "message": "Internal server error during generation.", + } + ) return elapsed = round(time.monotonic() - start, 1) diff --git a/web/app/page.tsx b/web/app/page.tsx index 464557d..275d658 100644 --- a/web/app/page.tsx +++ b/web/app/page.tsx @@ -15,11 +15,23 @@ export interface DownloadProgress { total: number; } +export interface ServerConfig { + device: string; + chunk_accum: number; + prebuffer_secs: number; + rebuffer_threshold_secs: number; + resume_threshold_secs: number; + default_inference_steps: number; +} + interface AppState { script: string; speaker: string; cfgScale: number; inferenceSteps: number; + prebufferSecs: number; + rebufferThresholdSecs: number; + resumeThresholdSecs: number; isGenerating: boolean; genElapsed: number; genPct: number | null; @@ -28,6 +40,7 @@ interface AppState { serverStatus: ServerStatus; downloadProgress: DownloadProgress | null; availableVoices: string[]; + serverConfig: ServerConfig | null; } type AppAction = @@ -35,6 +48,9 @@ type AppAction = | { type: "SET_SPEAKER"; payload: string } | { type: "SET_CFG_SCALE"; payload: number } | { type: "SET_INFERENCE_STEPS"; payload: number } + | { type: "SET_PREBUFFER_SECS"; payload: number } + | { type: "SET_REBUFFER_THRESHOLD"; payload: number } + | { type: "SET_RESUME_THRESHOLD"; payload: number } | { type: "START_GENERATION" } | { type: "GEN_PROGRESS"; elapsed: number; pct: number | null } | { type: "GENERATION_SUCCESS"; payload: string } @@ -43,7 +59,12 @@ type AppAction = | { type: "ADD_LOG"; payload: string } | { type: "SET_SERVER_STATUS"; - payload: { status: ServerStatus; progress?: DownloadProgress | null; voices?: string[] }; + payload: { + status: ServerStatus; + progress?: DownloadProgress | null; + voices?: string[]; + config?: ServerConfig | null; + }; }; function reducer(state: AppState, action: AppAction): AppState { @@ -52,6 +73,9 @@ function reducer(state: AppState, action: AppAction): AppState { case "SET_SPEAKER": return { ...state, speaker: action.payload }; case "SET_CFG_SCALE": return { ...state, cfgScale: action.payload }; case "SET_INFERENCE_STEPS": return { ...state, inferenceSteps: action.payload }; + case "SET_PREBUFFER_SECS": return { ...state, prebufferSecs: action.payload }; + case "SET_REBUFFER_THRESHOLD": return { ...state, rebufferThresholdSecs: action.payload }; + case "SET_RESUME_THRESHOLD": return { ...state, resumeThresholdSecs: action.payload }; case "START_GENERATION": return { ...state, isGenerating: true, audioUrl: null, logs: [], genElapsed: 0, genPct: null }; case "GEN_PROGRESS": @@ -63,14 +87,40 @@ function reducer(state: AppState, action: AppAction): AppState { return { ...state, isGenerating: false, genElapsed: 0, genPct: null }; case "ADD_LOG": return { ...state, logs: [...state.logs, action.payload] }; - case "SET_SERVER_STATUS": + case "SET_SERVER_STATUS": { + const isNewConfig = !state.serverConfig && action.payload.config; + const deviceChanged = !!(state.serverConfig && action.payload.config && state.serverConfig.device !== action.payload.config.device); + + const nextSteps = (isNewConfig || deviceChanged) + ? action.payload.config!.default_inference_steps + : state.inferenceSteps; + + const nextPrebuffer = (isNewConfig || deviceChanged) + ? action.payload.config!.prebuffer_secs + : state.prebufferSecs; + + const nextRebuffer = (isNewConfig || deviceChanged) + ? action.payload.config!.rebuffer_threshold_secs + : state.rebufferThresholdSecs; + + const nextResume = (isNewConfig || deviceChanged) + ? action.payload.config!.resume_threshold_secs + : state.resumeThresholdSecs; + return { ...state, serverStatus: action.payload.status, downloadProgress: action.payload.progress ?? null, - availableVoices: - action.payload.voices?.length ? action.payload.voices : state.availableVoices, + availableVoices: action.payload.voices?.length + ? action.payload.voices + : state.availableVoices, + serverConfig: action.payload.config ?? state.serverConfig, + inferenceSteps: nextSteps, + prebufferSecs: nextPrebuffer, + rebufferThresholdSecs: nextRebuffer, + resumeThresholdSecs: nextResume, }; + } default: return state; } } @@ -80,6 +130,9 @@ const initialState: AppState = { speaker: "carter", cfgScale: 1.5, inferenceSteps: 10, + prebufferSecs: 2.0, + rebufferThresholdSecs: 0.4, + resumeThresholdSecs: 1.5, isGenerating: false, genElapsed: 0, genPct: null, @@ -88,6 +141,7 @@ const initialState: AppState = { serverStatus: "offline", downloadProgress: null, availableVoices: [], + serverConfig: null, }; export default function HomePage() { @@ -106,19 +160,16 @@ export default function HomePage() { const handleGenerationCancel = useCallback(() => dispatch({ type: "GENERATION_CANCELLED" }), []); const handleGenerationError = useCallback(() => dispatch({ type: "GENERATION_ERROR" }), []); - const { - generate, - pauseStream, - resumeStream, - stop, - isStreamPaused, - } = useStreamingGeneration({ + const { generate, pauseStream, resumeStream, stop, isStreamPaused } = useStreamingGeneration({ onLog: addLog, onStart: handleGenerationStart, onProgress: handleGenerationProgress, onSuccess: handleGenerationSuccess, onCancel: handleGenerationCancel, onError: handleGenerationError, + prebufferSecs: state.prebufferSecs, + rebufferThresholdSecs: state.rebufferThresholdSecs, + resumeThresholdSecs: state.resumeThresholdSecs, }); // Server health polling — fast while not ready, slow when online @@ -131,21 +182,32 @@ export default function HomePage() { let nextStatus: ServerStatus = "offline"; let nextProgress: DownloadProgress | null = null; let nextVoices: string[] = []; + let nextConfig: ServerConfig | null = null; try { const res = await fetch("/api/health", { cache: "no-store" }); - const data = await res.json() as { + const data = (await res.json()) as { status: ServerStatus; progress?: DownloadProgress | null; voices?: string[]; + config?: ServerConfig; }; nextStatus = data.status ?? "offline"; nextProgress = data.progress ?? null; nextVoices = data.voices ?? []; + nextConfig = data.config ?? null; } catch { nextStatus = "offline"; } if (!cancelled) { - dispatch({ type: "SET_SERVER_STATUS", payload: { status: nextStatus, progress: nextProgress, voices: nextVoices } }); + dispatch({ + type: "SET_SERVER_STATUS", + payload: { + status: nextStatus, + progress: nextProgress, + voices: nextVoices, + config: nextConfig, + }, + }); timeoutId = setTimeout(poll, nextStatus === "online" ? 15_000 : 2_000); } } @@ -199,6 +261,12 @@ export default function HomePage() { onCfgScaleChange={(v) => dispatch({ type: "SET_CFG_SCALE", payload: v })} inferenceSteps={state.inferenceSteps} onInferenceStepsChange={(v) => dispatch({ type: "SET_INFERENCE_STEPS", payload: v })} + prebufferSecs={state.prebufferSecs} + onPrebufferSecsChange={(v) => dispatch({ type: "SET_PREBUFFER_SECS", payload: v })} + rebufferThresholdSecs={state.rebufferThresholdSecs} + onRebufferThresholdChange={(v) => dispatch({ type: "SET_REBUFFER_THRESHOLD", payload: v })} + resumeThresholdSecs={state.resumeThresholdSecs} + onResumeThresholdChange={(v) => dispatch({ type: "SET_RESUME_THRESHOLD", payload: v })} onGenerate={handleGenerate} onStop={stop} onPauseStream={pauseStream} diff --git a/web/components/GenerationControls.tsx b/web/components/GenerationControls.tsx index e0c8105..f9a7d4c 100644 --- a/web/components/GenerationControls.tsx +++ b/web/components/GenerationControls.tsx @@ -1,5 +1,6 @@ "use client"; +import { useState } from "react"; import type { ServerStatus, DownloadProgress } from "@/app/page"; const FALLBACK_VOICES = ["carter", "davis", "emma", "frank", "grace", "mike"]; @@ -12,6 +13,12 @@ interface GenerationControlsProps { onCfgScaleChange: (v: number) => void; inferenceSteps: number; onInferenceStepsChange: (v: number) => void; + prebufferSecs: number; + onPrebufferSecsChange: (v: number) => void; + rebufferThresholdSecs: number; + onRebufferThresholdChange: (v: number) => void; + resumeThresholdSecs: number; + onResumeThresholdChange: (v: number) => void; onGenerate: () => void; onStop: () => void; onPauseStream: () => void; @@ -53,6 +60,12 @@ export default function GenerationControls({ onCfgScaleChange, inferenceSteps, onInferenceStepsChange, + prebufferSecs, + onPrebufferSecsChange, + rebufferThresholdSecs, + onRebufferThresholdChange, + resumeThresholdSecs, + onResumeThresholdChange, onGenerate, onStop, onPauseStream, @@ -65,6 +78,7 @@ export default function GenerationControls({ serverStatus, downloadProgress, }: GenerationControlsProps) { + const [showAdvanced, setShowAdvanced] = useState(false); const voices = availableVoices.length > 0 ? availableVoices : FALLBACK_VOICES; const serverReady = serverStatus === "online"; const buttonDisabled = isGenerating || wordCount === 0 || !serverReady; @@ -169,6 +183,108 @@ export default function GenerationControls({ + {/* Advanced Buffering toggle */} +