diff --git a/.claude/settings.local.json b/.claude/settings.local.json deleted file mode 100644 index cf36061..0000000 --- a/.claude/settings.local.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "permissions": { - "allow": [ - "Bash(mv podcast-forge/pnpm-lock.yaml /tmp/vibepod-pnpm-lock.yaml)", - "Bash(git mv *)", - "Bash(mv /tmp/vibepod-pnpm-lock.yaml web/pnpm-lock.yaml)", - "Bash(git rm *)", - "Bash(uv lock *)", - "Bash(pnpm install *)", - "Bash(git add *)", - "Bash(command -v uv)", - "Bash(uv --version)", - "Bash(uv sync *)", - "Bash(pnpm --filter vibepod-web exec tsc --noEmit)", - "Bash(xargs cat *)", - "Bash(.venv/Scripts/python.exe -c \"import torch; print\\('torch:', torch.__version__\\); print\\('CUDA available:', torch.cuda.is_available\\(\\)\\); print\\('CUDA version:', torch.version.cuda\\)\")", - "Bash(nvidia-smi)" - ] - } -} diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..d37e399 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,37 @@ +root = true + +[*] +end_of_line = lf +charset = utf-8 +trim_trailing_whitespace = true +insert_final_newline = true + +[*.{ts,tsx,js,jsx,mjs,cjs,mts,cts}] +indent_style = space +indent_size = 2 + +[*.{json,jsonc}] +indent_style = space +indent_size = 2 + +[*.{css,html}] +indent_style = space +indent_size = 2 + +[*.{yaml,yml}] +indent_style = space +indent_size = 2 + +[*.py] +indent_style = space +indent_size = 4 + +[*.toml] +indent_style = space +indent_size = 4 + +[*.md] +trim_trailing_whitespace = false + +[Makefile] +indent_style = tab diff --git a/.env.example b/.env.example index 537770d..dc32c2b 100644 --- a/.env.example +++ b/.env.example @@ -8,3 +8,42 @@ HF_TOKEN= # Override the HuggingFace model cache directory (optional) # HF_HOME=/path/to/hf-cache + +# --------------------------------------------------------------------------- +# Runtime tuning +# --------------------------------------------------------------------------- + +# Force the Python server device. 
Usually set by `pnpm dev` / `pnpm dev:cpu`. +# VIBEPOD_DEVICE=cuda +# VIBEPOD_DEVICE=cpu + +# CPU mode: keep async decode enabled. This overlaps acoustic decoding with +# language-model work and measured ~20% faster on an 8-thread CPU run. +VIBEPOD_ASYNC_DECODE=1 + +# CPU mode: thread tuning. On an 8-core / 16-thread Ryzen test system, +# 8 worker threads with 1 inter-op thread gave the best wall time, while 12 +# over-subscribed and regressed. +# VIBEPOD_CPU_THREADS=8 +# VIBEPOD_CPU_INTEROP_THREADS=1 + +# CPU mode: playback buffering. CPU generation is slower than realtime, so +# smooth streaming needs a larger initial buffer than CUDA. Lower these for +# faster startup if you are OK with occasional rebuffering. +# VIBEPOD_PREBUFFER_SECS=24 +# VIBEPOD_REBUFFER_THRESHOLD_SECS=2 +# VIBEPOD_RESUME_THRESHOLD_SECS=12 + +# CPU mode: dynamic INT8 quantization is enabled by default in start.sh. +# Set to 0 if you are comparing quality/performance or debugging. +# VIBEPOD_QUANTIZE=1 + +# CUDA mode: dtype and attention selection. Defaults are bf16 + SDPA unless +# optional FlashAttention is explicitly enabled and importable. +# VIBEPOD_CUDA_DTYPE=bf16 +# VIBEPOD_ATTN_IMPL=sdpa +# VIBEPOD_ENABLE_FLASH_ATTN=0 + +# Debug/profiling. Keep disabled for benchmark timing; async CPU profiling +# double-counts overlapped decode work. 
+# VIBEPOD_PROFILE_GENERATION=0 diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..629ff9c --- /dev/null +++ b/.gitattributes @@ -0,0 +1,24 @@ +* text=auto eol=lf + +*.sh text eol=lf +*.py text eol=lf +*.ts text eol=lf +*.tsx text eol=lf +*.js text eol=lf +*.jsx text eol=lf +*.mjs text eol=lf +*.cjs text eol=lf +*.mts text eol=lf +*.cts text eol=lf +*.css text eol=lf +*.html text eol=lf +*.json text eol=lf +*.jsonc text eol=lf +*.yaml text eol=lf +*.yml text eol=lf +*.toml text eol=lf +*.md text eol=lf +*.mdx text eol=lf +*.lock text eol=lf +*.env text eol=lf +*.env.* text eol=lf diff --git a/.gitignore b/.gitignore index 13db197..040007f 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,4 @@ web/node_modules/ .DS_Store Thumbs.db .vscode/settings.json +.claude/settings.local.json diff --git a/.prettierignore b/.prettierignore new file mode 100644 index 0000000..08615cc --- /dev/null +++ b/.prettierignore @@ -0,0 +1,18 @@ +# Dependencies +node_modules/ +web/node_modules/ + +# Build outputs +web/.next/ +web/tsconfig.tsbuildinfo +web/next-env.d.ts + +# Python / server +server/ + +# Lock files +pnpm-lock.yaml +web/pnpm-lock.yaml + +# Generated +web/public/ diff --git a/.prettierrc b/.prettierrc new file mode 100644 index 0000000..2aa8717 --- /dev/null +++ b/.prettierrc @@ -0,0 +1,8 @@ +{ + "semi": true, + "singleQuote": false, + "tabWidth": 2, + "trailingComma": "es5", + "printWidth": 100, + "endOfLine": "lf" +} diff --git a/AGENTS.md b/AGENTS.md index ed18570..7ca403c 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -8,10 +8,10 @@ This file gives AI coding agents (Jules, Copilot, Claude Code, etc.) the context VibePod is a text-to-speech web app. 
It has two services that must both run for the app to work: -| Service | Language | Entry point | Port | -|---------|----------|-------------|------| -| **server** | Python 3.10+ (FastAPI + VibeVoice) | `server/start.sh` | 8000 | -| **web** | TypeScript (Next.js 15, React 19) | `pnpm --filter vibepod-web dev` | 3000 | +| Service | Language | Entry point | Port | +| ---------- | ---------------------------------- | ------------------------------- | ---- | +| **server** | Python 3.10+ (FastAPI + VibeVoice) | `server/start.sh` | 8000 | +| **web** | TypeScript (Next.js 15, React 19) | `pnpm --filter vibepod-web dev` | 3000 | The Next.js frontend proxies all model requests through its own API routes to the FastAPI server — it never calls the Python server directly from the browser. @@ -51,12 +51,12 @@ pnpm build The `--cpu` flag in `start.sh` sets `VIBEPOD_DEVICE=cpu` and uses a separate venv (`server/.venv-cpu`) so CUDA and CPU installs never conflict. `vibevoice_server.py` reads `VIBEPOD_DEVICE` at startup via `_resolve_device()` — do not remove or rename that function. -| Env var | Values | Set by | -|---------|--------|--------| -| `VIBEPOD_DEVICE` | `cpu` \| `cuda` | `server/start.sh` | -| `UV_PROJECT_ENVIRONMENT` | `.venv-cpu` \| `.venv` | `server/start.sh` | -| `HF_TOKEN` | HuggingFace token | Jules secret / `.env.local` | -| `VIBEVOICE_SERVER_URL` | `http://localhost:8000` | `.env.local` | +| Env var | Values | Set by | +| ------------------------ | ----------------------- | --------------------------- | +| `VIBEPOD_DEVICE` | `cpu` \| `cuda` | `server/start.sh` | +| `UV_PROJECT_ENVIRONMENT` | `.venv-cpu` \| `.venv` | `server/start.sh` | +| `HF_TOKEN` | HuggingFace token | Jules secret / `.env.local` | +| `VIBEVOICE_SERVER_URL` | `http://localhost:8000` | `.env.local` | --- @@ -94,7 +94,9 @@ dev.sh Concurrent launcher (forwards flags to start.sh) ## API reference ### `GET /health` + Returns server status. Safe to poll. 
+ ```json { "status": "online", @@ -103,13 +105,17 @@ Returns server status. Safe to poll. "voices": ["carter", "davis", "emma", "frank", "grace", "mike"] } ``` + `status` values: `downloading` | `loading` | `online` | `error` ### `POST /generate` + Streams audio as SSE events. + ```json { "text": "Hello world", "speaker": "carter", "cfg_scale": 1.5, "inference_steps": 10 } ``` + Event types: `audio_chunk` (base64 float32 PCM) | `complete` | `error` | `cancelled` --- @@ -117,12 +123,14 @@ Event types: `audio_chunk` (base64 float32 PCM) | `complete` | `error` | `cancel ## Do / Don't **Do:** + - Use `pnpm dev:cpu` in Jules — never plain `pnpm dev` - Run `git checkout server/uv.lock` if uv rewrites it during setup - Keep `_resolve_device()` in `vibevoice_server.py` — it's the CPU/CUDA switching logic - Test server changes against `GET /health` and `POST /generate` **Don't:** + - Run `uv sync` without `UV_PROJECT_ENVIRONMENT=.venv-cpu` in the Jules sandbox - Install Python packages with pip - Modify `server/uv.lock` manually diff --git a/DESIGN.md b/DESIGN.md index 2654734..42a00df 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -173,16 +173,21 @@ The shape language is a hybrid of structural precision and tactile softness. ## Components ### Card Containers + The fundamental building block of the UI. Every distinct section (Script, Player, Controls, Logs) is housed in a card featuring the `card-bg`, a 1px `border`, and `rounded-xl` corners. The internal layout always features an uppercase teal header for immediate section identification. ### Primary Action Buttons + Used for high-leverage actions like "Generate Audio" and "Play/Pause." These buttons utilize the `gradient-primary-dim` background, bold white text, and emit a soft teal glow to draw the eye and signify their importance. ### Range Sliders + Custom-styled input ranges replace default browser styles. 
The tracks are muted and slim, while the thumbs are bright teal, fully rounded, and emit a glow that intensifies on hover, providing a premium, tactile scrubbing experience. ### Status Indicators & Logs + A critical component of the application. Status badges utilize a minimalist pill shape with a pulsing ring animation to indicate active server processing. The log panel explicitly uses monospace typography and color-codes messages (green for success, red for error, white for neutral) to provide a terminal-like readout of the backend systems. ### Gradients + Gradients are used purposefully to indicate progress, activity, or brand presence. The primary gradient (`135deg` from teal to violet) is used for branding (the logo icon and text) and primary buttons. Horizontal gradients (`90deg`) are used dynamically in progress bars to represent the flow of data over time (e.g., loading, downloading, and audio generation). diff --git a/README.md b/README.md index f8202f5..ab76d3c 100644 --- a/README.md +++ b/README.md @@ -14,12 +14,12 @@ The Next.js app proxies audio generation requests to the FastAPI server, keeping ## Prerequisites -| Tool | Install | -|------|---------| -| [Node.js 20+](https://nodejs.org) | `winget install OpenJS.NodeJS.LTS` | -| [pnpm](https://pnpm.io) | `npm i -g pnpm` | +| Tool | Install | +| ---------------------------------- | ----------------------------------- | +| [Node.js 20+](https://nodejs.org) | `winget install OpenJS.NodeJS.LTS` | +| [pnpm](https://pnpm.io) | `npm i -g pnpm` | | [Python 3.10+](https://python.org) | `winget install Python.Python.3.13` | -| [uv](https://docs.astral.sh/uv/) | `winget install astral-sh.uv` | +| [uv](https://docs.astral.sh/uv/) | `winget install astral-sh.uv` | ## Getting started @@ -50,10 +50,10 @@ The frontend shows a loading indicator while the model downloads. 
Once the serve VibePod maintains two completely separate Python virtual environments so CUDA and CPU torch installs never conflict: -| Mode | Command | venv | torch source | -|------|---------|------|--------------| -| CUDA (default) | `pnpm dev` | `server/.venv` | PyTorch CUDA 12.4 index | -| CPU-only | `pnpm dev:cpu` | `server/.venv-cpu` | PyPI (CPU wheel) | +| Mode | Command | venv | torch source | +| -------------- | -------------- | ------------------ | ----------------------- | +| CUDA (default) | `pnpm dev` | `server/.venv` | PyTorch CUDA 12.4 index | +| CPU-only | `pnpm dev:cpu` | `server/.venv-cpu` | PyPI (CPU wheel) | On first run, each mode creates its own venv automatically. You can switch between them freely — they are fully independent. The active device is reported by the `/health` endpoint as `"device": "cpu"` or `"device": "cuda"`. @@ -74,11 +74,11 @@ pnpm build # Production build of the frontend Copy `.env.example` to `.env.local` and set: -| Variable | Default | Description | -|----------|---------|-------------| +| Variable | Default | Description | +| ---------------------- | ----------------------- | --------------------------------------------------------- | | `VIBEVOICE_SERVER_URL` | `http://localhost:8000` | URL the Next.js API routes use to reach the Python server | -| `HF_TOKEN` | — | HuggingFace token (required if the model repo is gated) | -| `HF_HOME` | — | Override the HuggingFace model cache directory | +| `HF_TOKEN` | — | HuggingFace token (required if the model repo is gated) | +| `HF_HOME` | — | Override the HuggingFace model cache directory | ## Project structure @@ -107,11 +107,11 @@ server/ ## Generation parameters -| Parameter | Range | Default | Effect | -|-----------|-------|---------|--------| -| `speaker` | `carter`, `davis`, `emma`, `frank`, `grace`, `mike` | `carter` | Voice preset used for the generated audio | -| `cfg_scale` | 0.5 – 4.0 | 1.5 | Higher = more expressive guidance | -| `inference_steps` | 5 – 20 | 10 | 
More steps = higher quality, slower generation | +| Parameter | Range | Default | Effect | +| ----------------- | --------------------------------------------------- | -------- | ---------------------------------------------- | +| `speaker` | `carter`, `davis`, `emma`, `frank`, `grace`, `mike` | `carter` | Voice preset used for the generated audio | +| `cfg_scale` | 0.5 – 4.0 | 1.5 | Higher = more expressive guidance | +| `inference_steps` | 5 – 20 | 10 | More steps = higher quality, slower generation | ## How it works diff --git a/package.json b/package.json index 034f905..8769a40 100644 --- a/package.json +++ b/package.json @@ -8,7 +8,14 @@ "dev:cpu": "bash dev.sh --cpu", "dev:server": "bash server/start.sh", "dev:server:cpu": "bash server/start.sh --cpu", - "dev:web": "pnpm --filter vibepod-web dev" + "dev:web": "pnpm --filter vibepod-web dev", + "format": "prettier --write . && cd server && uv run ruff format .", + "format:check": "prettier --check . && cd server && uv run ruff format --check .", + "lint:server": "cd server && uv run ruff check .", + "lint:server:fix": "cd server && uv run ruff check --fix ." 
+ }, + "devDependencies": { + "prettier": "^3.5.3" }, "packageManager": "pnpm@10.33.2+sha512.a90faf6feeab71ad6c6e57f94e0fe1a12f5dcc22cd754db40ae9593eb6a3e0b6b12e3540218bb37ae083404b1f2ce6db2a4121e979829b4aff94b99f49da1cf8" } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 07d0a9a..b89acfe 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -6,7 +6,11 @@ settings: importers: - .: {} + .: + devDependencies: + prettier: + specifier: ^3.5.3 + version: 3.8.3 web: dependencies: @@ -516,6 +520,11 @@ packages: resolution: {integrity: sha512-W62t/Se6rA0Az3DfCL0AqJwXuKwBeYg6nOaIgzP+xZ7N5BFCI7DYi1qs6ygUYT6rvfi6t9k65UMLJC+PHZpDAA==} engines: {node: ^10 || ^12 || >=14} + prettier@3.8.3: + resolution: {integrity: sha512-7igPTM53cGHMW8xWuVTydi2KO233VFiTNyF5hLJqpilHfmn8C8gPf+PS7dUT64YcXFbiMGZxS9pCSxL/Dxm/Jw==} + engines: {node: '>=14'} + hasBin: true + react-dom@19.1.0: resolution: {integrity: sha512-Xs1hdnE+DyKgeHJeJznQmYMIBG3TKIHJJT95Q58nHLSrElKlGQqDTR2HQ9fx5CN/Gk6Vh/kupBTDLU11/nDk/g==} peerDependencies: @@ -917,6 +926,8 @@ snapshots: picocolors: 1.1.1 source-map-js: 1.2.1 + prettier@3.8.3: {} + react-dom@19.1.0(react@19.1.0): dependencies: react: 19.1.0 diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml index b1cedb5..92a7e8b 100644 --- a/pnpm-workspace.yaml +++ b/pnpm-workspace.yaml @@ -1,2 +1,2 @@ packages: - - 'web' + - "web" diff --git a/roadmap.md b/roadmap.md index b2035fe..5825102 100644 --- a/roadmap.md +++ b/roadmap.md @@ -39,6 +39,13 @@ VibePod Studio will turn generated audio from a one-shot download into a reusabl - Add project save/load, autosave, and recoverable render jobs. - Prepare the audio pipeline for queueing longer renders outside the request lifecycle. +## Later: VibeVoice Performance Research + +- Move the current VibePod hot-path monkey patches into the `JezzWTF/VibeVoice` fork once the feature direction has settled. +- Add clearer generation profiling for overlapped CPU work, especially decode wait time versus total acoustic decode time. 
+- Prototype batched positive/negative CFG TTS LM inference behind an opt-in flag and benchmark it against the current sequential path on CPU and CUDA. +- Keep experimental performance work isolated from user-facing feature work unless it shows a clear speedup without audio quality regressions. + ## Foundation Work Needed First - Persist generated outputs with stable IDs. diff --git a/server/.python-version b/server/.python-version new file mode 100644 index 0000000..e4fba21 --- /dev/null +++ b/server/.python-version @@ -0,0 +1 @@ +3.12 diff --git a/server/download_model.py b/server/download_model.py index 1377d17..369a0fb 100644 --- a/server/download_model.py +++ b/server/download_model.py @@ -30,15 +30,12 @@ def download() -> str: from huggingface_hub import snapshot_download except ImportError: print( - "ERROR: huggingface_hub is not installed.\n" - "Run: pip install huggingface_hub", + "ERROR: huggingface_hub is not installed.\nRun: pip install huggingface_hub", file=sys.stderr, ) sys.exit(1) - token: str | None = os.environ.get("HF_TOKEN") or os.environ.get( - "HUGGINGFACE_TOKEN" - ) + token: str | None = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN") print(f"Checking / downloading model: {MODEL_ID}") print("(This may take several minutes on first run — the model is ~1 GB)") diff --git a/server/pyproject.toml b/server/pyproject.toml index 3756ed3..a45dad9 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -8,7 +8,8 @@ dependencies = [ # To switch back to CPU-only, remove the [tool.uv.sources] torch entry below. "torch>=2.0.0", # VibeVoice custom model + processor classes (not yet in upstream transformers) - "vibevoice @ git+https://github.com/microsoft/VibeVoice.git", + # Uses JezzWTF/VibeVoice fork so VibePod-specific optimisations land here. 
+ "vibevoice @ git+https://github.com/JezzWTF/VibeVoice.git", # Exact version required by vibevoice's streaming TTS module "transformers==4.51.3", "fastapi>=0.111.0", @@ -18,10 +19,36 @@ dependencies = [ "huggingface_hub>=0.23.0", ] +[dependency-groups] +dev = [ + "ruff>=0.11.0", +] + # No build-system — this is a scripts project, not an installable package. # Lock file is committed so installs are reproducible. # Run `uv lock --upgrade` to bump dependencies. +[tool.ruff] +line-length = 100 +indent-width = 4 +target-version = "py310" +exclude = [".git", ".venv", ".venv-cpu", "__pycache__"] + +[tool.ruff.lint] +select = ["E", "F", "UP", "B", "SIM", "I"] +ignore = [ + "E501", # line-too-long — handled by formatter + "B905", # zip() without strict= — existing code, lengths are known-correct +] + +[tool.ruff.lint.per-file-ignores] +"download_model.py" = ["T201"] # allow print statements in CLI scripts + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +line-ending = "lf" + [[tool.uv.index]] name = "pytorch-cu124" url = "https://download.pytorch.org/whl/cu124" diff --git a/server/start.sh b/server/start.sh index befa203..c50277d 100755 --- a/server/start.sh +++ b/server/start.sh @@ -7,6 +7,11 @@ # ./start.sh — CUDA mode (default, uses PyTorch CUDA 12.4 wheel, venv: .venv) # ./start.sh --cpu — CPU-only mode (uses PyPI CPU torch wheel, venv: .venv-cpu) # +# Optional CUDA acceleration: +# VIBEPOD_ENABLE_FLASH_ATTN=1 ./start.sh +# Installs a pre-built flash-attn wheel when the CUDA venv uses Python 3.12, +# torch 2.6.0, and CUDA 12.4 on Windows. Other platforms fall back to SDPA. +# # The two modes maintain completely separate virtual environments so their torch # installations never conflict. UV_PROJECT_ENVIRONMENT tells uv which venv to use; # --no-sources skips [tool.uv.sources] so the CPU run pulls the default PyPI torch wheel. @@ -51,6 +56,19 @@ if ! 
command -v uv &>/dev/null; then exit 1 fi +validate_flash_attn() { + uv run python -c "import flash_attn; import triton; import transformers.modeling_utils" &>/dev/null +} + +remove_broken_flash_attn() { + if uv run python -c "import importlib.util; raise SystemExit(0 if importlib.util.find_spec('flash_attn') else 1)" &>/dev/null; then + if ! validate_flash_attn; then + echo " Installed flash-attn is not usable in this environment; removing it." + uv pip uninstall flash-attn + fi + fi +} + # --------------------------------------------------------------------------- # 2. Sync Python environment # CPU mode: use .venv-cpu and skip [tool.uv.sources] so uv pulls the @@ -61,10 +79,64 @@ echo "" if $CPU_MODE; then echo "--> Syncing CPU Python environment (.venv-cpu)..." export UV_PROJECT_ENVIRONMENT=".venv-cpu" + LOCK_BACKUP="" + if [[ -f uv.lock ]]; then + LOCK_BACKUP="$(mktemp)" + cp uv.lock "$LOCK_BACKUP" + fi uv sync --no-sources + if [[ -n "$LOCK_BACKUP" ]]; then + cp "$LOCK_BACKUP" uv.lock + rm -f "$LOCK_BACKUP" + fi else echo "--> Syncing CUDA Python environment (.venv)..." uv sync + + remove_broken_flash_attn + + if [[ "${VIBEPOD_ENABLE_FLASH_ATTN:-0}" == "1" ]]; then + echo "" + echo "--> Checking optional FlashAttention wheel..." + + if validate_flash_attn; then + echo " flash-attn already installed and importable." 
+ else + PY_TAG="$(uv run python -c "import sys; print(f'cp{sys.version_info.major}{sys.version_info.minor}')")" + TORCH_TAG="$(uv run python -c "import torch; print(torch.__version__.split('+', 1)[0])")" + CUDA_TAG="$(uv run python -c "import torch; print('cu' + torch.version.cuda.replace('.', ''))")" + + FLASH_ATTN_WHEEL_URL="" + if [[ "$PY_TAG" == "cp312" && "$TORCH_TAG" == "2.6.0" && "$CUDA_TAG" == "cu124" ]]; then + case "$(uname -s)" in + MINGW*|CYGWIN*|MSYS*) + FLASH_ATTN_WHEEL_URL="https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4%2Bcu124torch2.6.0cxx11abiFALSE-cp312-cp312-win_amd64.whl" + echo " Installing flash-attn for Python 3.12, torch 2.6.0, CUDA 12.4 (Windows)..." + ;; + *) + echo " No pre-built flash-attn wheel available for this platform ($(uname -s))." + echo " Continuing with PyTorch SDPA attention." + ;; + esac + else + echo " No known wheel for Python tag $PY_TAG, torch $TORCH_TAG, CUDA $CUDA_TAG." + echo " Continuing with PyTorch SDPA attention." + fi + + if [[ -n "$FLASH_ATTN_WHEEL_URL" ]]; then + if uv pip install "$FLASH_ATTN_WHEEL_URL"; then + if validate_flash_attn; then + echo " flash-attn import check passed." + else + echo " flash-attn import check failed; removing it and continuing with SDPA." + uv pip uninstall flash-attn + fi + else + echo " flash-attn wheel install failed; continuing with SDPA." 
+ fi + fi + fi + fi fi # --------------------------------------------------------------------------- @@ -78,11 +150,28 @@ export PYTHONUTF8=1 if $CPU_MODE; then export VIBEPOD_DEVICE="cpu" export UV_PROJECT_ENVIRONMENT=".venv-cpu" + if [[ -z "${VIBEPOD_CPU_THREADS:-}" ]]; then + VIBEPOD_CPU_THREADS="$(uv run --no-sync --no-sources python -c "import os; print(max(1, (os.cpu_count() or 2) // 2))")" + export VIBEPOD_CPU_THREADS + fi + export OMP_NUM_THREADS="${OMP_NUM_THREADS:-$VIBEPOD_CPU_THREADS}" + export MKL_NUM_THREADS="${MKL_NUM_THREADS:-$VIBEPOD_CPU_THREADS}" + # Dynamic INT8 quantization — on by default for CPU (~22% faster, prediction_head + # excluded automatically to avoid regression on small fixed-size tensors). + # Set VIBEPOD_QUANTIZE=0 to disable if you notice audio quality differences. + export VIBEPOD_QUANTIZE="${VIBEPOD_QUANTIZE:-1}" + # Optional CPU flags: + # VIBEPOD_ASYNC_DECODE=0 Disable async decode/tts_lm overlap (on by default) + # VIBEPOD_CPU_BF16=1 Force bfloat16 weights (auto-detected via AVX512_BF16) + # VIBEPOD_COMPILE=1 torch.compile hot paths (ineffective for autoregressive + # models on CPU — not recommended, kept for experimentation) + UV_RUN_ARGS=(--no-sync --no-sources) else export VIBEPOD_DEVICE="cuda" + UV_RUN_ARGS=() fi -exec uv run uvicorn vibevoice_server:app \ +exec uv run "${UV_RUN_ARGS[@]}" uvicorn vibevoice_server:app \ --host 127.0.0.1 \ --port 8000 \ --log-level info \ diff --git a/server/uv.lock b/server/uv.lock index 7fc34c0..c347132 100644 --- a/server/uv.lock +++ b/server/uv.lock @@ -1479,7 +1479,7 @@ name = "nvidia-cudnn-cu12" version = "9.1.0.70" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, + { name = "nvidia-cublas-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, ] wheels = [ { url = 
"https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, @@ -1490,7 +1490,7 @@ name = "nvidia-cufft-cu12" version = "11.2.1.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 }, @@ -1509,9 +1509,9 @@ name = "nvidia-cusolver-cu12" version = "11.6.1.9" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, - { name = "nvidia-cusparse-cu12" }, - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-cublas-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, + { name = "nvidia-cusparse-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, ] wheels = [ { url = 
"https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057 }, @@ -1522,7 +1522,7 @@ name = "nvidia-cusparse-cu12" version = "12.3.1.170" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 }, @@ -2394,6 +2394,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl", hash = "sha256:33bd4ef74232fb73fe9279a257718407f169c09b78a87ad3d296f548e27de0bb", size = 310654 }, ] +[[package]] +name = "ruff" +version = "0.15.12" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/99/43/3291f1cc9106f4c63bdce7a8d0df5047fe8422a75b091c16b5e9355e0b11/ruff-0.15.12.tar.gz", hash = "sha256:ecea26adb26b4232c0c2ca19ccbc0083a68344180bba2a600605538ce51a40a6", size = 4643852 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/6e/e78ffb61d4686f3d96ba3df2c801161843746dcbcbb17a1e927d4829312b/ruff-0.15.12-py3-none-linux_armv6l.whl", hash = "sha256:f86f176e188e94d6bdbc09f09bfd9dc729059ad93d0e7390b5a73efe19f8861c", size = 10640713 }, + { url = 
"https://files.pythonhosted.org/packages/ae/08/a317bc231fb9e7b93e4ef3089501e51922ff88d6936ce5cf870c4fe55419/ruff-0.15.12-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:e3bcd123364c3770b8e1b7baaf343cc99a35f197c5c6e8af79015c666c423a6c", size = 11069267 }, + { url = "https://files.pythonhosted.org/packages/aa/a4/f828e9718d3dce1f5f11c39c4f65afd32783c8b2aebb2e3d259e492c47bd/ruff-0.15.12-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fe87510d000220aa1ed530d4448a7c696a0cae1213e5ec30e5874287b66557b5", size = 10397182 }, + { url = "https://files.pythonhosted.org/packages/71/e0/3310fc6d1b5e1fdea22bf3b1b807c7e187b581021b0d7d4514cccdb5fb71/ruff-0.15.12-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:84a1630093121375a3e2a95b4a6dc7b59e2b4ee76216e32d81aae550a832d002", size = 10758012 }, + { url = "https://files.pythonhosted.org/packages/11/c1/a606911aee04c324ddaa883ae418f3569792fd3c4a10c50e0dd0a2311e1e/ruff-0.15.12-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fb129f40f114f089ebe0ca56c0d251cf2061b17651d464bb6478dc01e69f11f5", size = 10447479 }, + { url = "https://files.pythonhosted.org/packages/9d/68/4201e8444f0894f21ab4aeeaee68aa4f10b51613514a20d80bd628d57e88/ruff-0.15.12-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b0c862b172d695db7598426b8af465e7e9ac00a3ea2a3630ee67eb82e366aaa6", size = 11234040 }, + { url = "https://files.pythonhosted.org/packages/34/ff/8a6d6cf4ccc23fd67060874e832c18919d1557a0611ebef03fdb01fff11e/ruff-0.15.12-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2849ea9f3484c3aca43a82f484210370319e7170df4dfe4843395ddf6c57bc33", size = 12087377 }, + { url = "https://files.pythonhosted.org/packages/85/f6/c669cf73f5152f623d34e69866a46d5e6185816b19fcd5b6dd8a2d299922/ruff-0.15.12-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9e77c7e51c07fe396826d5969a5b846d9cd4c402535835fb6e21ce8b28fef847", size = 11367784 }, + { url = 
"https://files.pythonhosted.org/packages/e8/39/c61d193b8a1daaa8977f7dea9e8d8ba866e02ea7b65d32f6861693aa4c12/ruff-0.15.12-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83b2f4f2f3b1026b5fb449b467d9264bf22067b600f7b6f41fc5958909f449d0", size = 11344088 }, + { url = "https://files.pythonhosted.org/packages/c2/8d/49afab3645e31e12c590acb6d3b5b69d7aab5b81926dbaf7461f9441f37a/ruff-0.15.12-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:9ba3b8f1afd7e2e43d8943e55f249e13f9682fde09711644a6e7290eb4f3e339", size = 11271770 }, + { url = "https://files.pythonhosted.org/packages/46/06/33f41fe94403e2b755481cdfb9b7ef3e4e0ed031c4581124658d935d52b4/ruff-0.15.12-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e852ba9fdc890655e1d78f2df1499efbe0e54126bd405362154a75e2bde159c5", size = 10719355 }, + { url = "https://files.pythonhosted.org/packages/0d/59/18aa4e014debbf559670e4048e39260a85c7fcee84acfd761ac01e7b8d35/ruff-0.15.12-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:dd8aed930da53780d22fc70bdf84452c843cf64f8cb4eb38984319c24c5cd5fd", size = 10462758 }, + { url = "https://files.pythonhosted.org/packages/25/e7/cc9f16fd0f3b5fddcbd7ec3d6ae30c8f3fde1047f32a4093a98d633c6570/ruff-0.15.12-py3-none-musllinux_1_2_i686.whl", hash = "sha256:01da3988d225628b709493d7dc67c3b9b12c0210016b08690ef9bd27970b262b", size = 10953498 }, + { url = "https://files.pythonhosted.org/packages/72/7a/a9ba7f98c7a575978698f4230c5e8cc54bbc761af34f560818f933dafa0c/ruff-0.15.12-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:9cae0f92bd5700d1213188b31cd3bdd2b315361296d10b96b8e2337d3d11f53e", size = 11447765 }, + { url = "https://files.pythonhosted.org/packages/ea/f9/0ae446942c846b8266059ad8a30702a35afae55f5cdc54c5adf8d7afdc27/ruff-0.15.12-py3-none-win32.whl", hash = "sha256:d0185894e038d7043ba8fd6aee7499ece6462dc0ea9f1e260c7451807c714c20", size = 10657277 }, + { url = 
"https://files.pythonhosted.org/packages/33/f1/9614e03e1cdcbf9437570b5400ced8a720b5db22b28d8e0f1bda429f660d/ruff-0.15.12-py3-none-win_amd64.whl", hash = "sha256:c87a162d61ab3adca47c03f7f717c68672edec7d1b5499e652331780fe74950d", size = 11837758 }, + { url = "https://files.pythonhosted.org/packages/c0/98/6beb4b351e472e5f4c4613f7c35a5290b8be2497e183825310c4c3a3984b/ruff-0.15.12-py3-none-win_arm64.whl", hash = "sha256:a538f7a82d061cee7be55542aca1d86d1393d55d81d4fcc314370f4340930d4f", size = 11120821 }, +] + [[package]] name = "safehttpx" version = "0.1.7" @@ -3100,6 +3125,11 @@ dependencies = [ { name = "vibevoice" }, ] +[package.dev-dependencies] +dev = [ + { name = "ruff" }, +] + [package.metadata] requires-dist = [ { name = "fastapi", specifier = ">=0.111.0" }, @@ -3109,13 +3139,16 @@ requires-dist = [ { name = "torch", specifier = ">=2.0.0", index = "https://download.pytorch.org/whl/cu124" }, { name = "transformers", specifier = "==4.51.3" }, { name = "uvicorn", extras = ["standard"], specifier = ">=0.29.0" }, - { name = "vibevoice", git = "https://github.com/microsoft/VibeVoice.git" }, + { name = "vibevoice", git = "https://github.com/JezzWTF/VibeVoice.git" }, ] +[package.metadata.requires-dev] +dev = [{ name = "ruff", specifier = ">=0.11.0" }] + [[package]] name = "vibevoice" version = "1.0.0" -source = { git = "https://github.com/microsoft/VibeVoice.git#e73d1e17c3754f046352014856a922f8208fb5d3" } +source = { git = "https://github.com/JezzWTF/VibeVoice.git#fe832f20e3d1638594f551a08f02253f14408dbd" } dependencies = [ { name = "absl-py" }, { name = "accelerate" }, diff --git a/server/vibevoice_server.py b/server/vibevoice_server.py index f5bc012..0c32e6c 100644 --- a/server/vibevoice_server.py +++ b/server/vibevoice_server.py @@ -20,17 +20,22 @@ Device selection: import asyncio import base64 +import concurrent.futures import copy import functools +import importlib.util import json import logging import os +import platform import threading import time +import types 
import urllib.request +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from pathlib import Path -from typing import AsyncGenerator, Literal, Optional +from typing import Literal import torch from fastapi import FastAPI, HTTPException, Request @@ -46,8 +51,7 @@ SAMPLE_RATE = 24_000 VOICES_DIR = Path(__file__).parent / "voices" / "streaming_model" VOICE_BASE_URL = ( - "https://raw.githubusercontent.com/microsoft/VibeVoice/main" - "/demo/voices/streaming_model" + "https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model" ) EN_VOICES: dict[str, str] = { @@ -62,6 +66,11 @@ DEFAULT_SPEAKER = "carter" _IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.ot"] +# ── Pipeline executor ────────────────────────────────────────────────────────── +# Overlaps acoustic_decode with forward_tts_lm on a background thread (1 worker). + +_decode_executor: concurrent.futures.ThreadPoolExecutor | None = None + # ── Device selection ──────────────────────────────────────────────────────────── # VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag. # Falls back to auto-detection if not set. 
@@ -106,6 +115,38 @@ def _env_float(name: str, default: float) -> float: return default +def _cpu_supports_bf16() -> bool: + """Return True if the CPU has AVX512_BF16 hardware support.""" + return ( + hasattr(torch, "cpu") + and hasattr(torch.cpu, "is_avx512_bf16_supported") + and torch.cpu.is_avx512_bf16_supported() + ) + + +def _configure_cpu_runtime() -> dict[str, object]: + logical_cpus = os.cpu_count() or 1 + default_threads = max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus + intra_threads = _env_int("VIBEPOD_CPU_THREADS", default_threads) + interop_threads = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1) + mkldnn_enabled = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0" + + torch.set_num_threads(max(1, intra_threads)) + try: + torch.set_num_interop_threads(max(1, interop_threads)) + except RuntimeError as exc: + logger.warning("Could not set CPU inter-op threads: %s", exc) + + torch.backends.mkldnn.enabled = mkldnn_enabled + return { + "logical_cpus": logical_cpus, + "threads": torch.get_num_threads(), + "interop_threads": torch.get_num_interop_threads(), + "mkldnn_available": torch.backends.mkldnn.is_available(), + "mkldnn_enabled": torch.backends.mkldnn.enabled, + } + + # ── Global state ──────────────────────────────────────────────────────────────── ModelStatus = Literal["downloading", "loading", "online", "error"] @@ -114,7 +155,7 @@ _processor = None _model = None _device: str = "cpu" _model_status: ModelStatus = "loading" -_model_error: Optional[str] = None +_model_error: str | None = None _voice_presets: dict[str, object] = {} _load_lock = threading.Lock() _generation_lock = asyncio.Lock() @@ -161,9 +202,7 @@ def _is_model_cached() -> bool: try: from huggingface_hub import snapshot_download - snapshot_download( - MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS - ) + snapshot_download(MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS) return True except Exception: return False @@ 
-172,9 +211,7 @@ def _is_model_cached() -> bool: def _download_model() -> None: from huggingface_hub import snapshot_download - token: Optional[str] = os.environ.get("HF_TOKEN") or os.environ.get( - "HUGGINGFACE_TOKEN" - ) + token: str | None = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN") DlTqdm = _make_dl_tqdm() logger.info("Model not cached — downloading %s...", MODEL_ID) snapshot_download( @@ -188,7 +225,7 @@ def _download_model() -> None: def _download_voices() -> None: VOICES_DIR.mkdir(parents=True, exist_ok=True) - for name, filename in EN_VOICES.items(): + for _name, filename in EN_VOICES.items(): dest = VOICES_DIR / filename if not dest.exists(): url = f"{VOICE_BASE_URL}/{filename}" @@ -211,8 +248,56 @@ def _init_processor(): def _init_model(device: str): logger.info("Loading model on %s...", device) - load_dtype = torch.bfloat16 if device == "cuda" else torch.float32 - attn_impl = "flash_attention_2" if device == "cuda" else "sdpa" + if device == "cuda": + torch.set_float32_matmul_precision("high") + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cuda.enable_flash_sdp(True) + torch.backends.cuda.enable_mem_efficient_sdp(True) + torch.backends.cuda.enable_math_sdp(True) + torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True) + logger.info( + "PyTorch SDPA backends: flash=%s, mem_efficient=%s, math=%s", + torch.backends.cuda.flash_sdp_enabled(), + torch.backends.cuda.mem_efficient_sdp_enabled(), + torch.backends.cuda.math_sdp_enabled(), + ) + elif device == "cpu": + torch.set_float32_matmul_precision("medium") + logger.info("CPU runtime configuration: %s", _configure_cpu_runtime()) + + cuda_dtype = os.environ.get("VIBEPOD_CUDA_DTYPE", "bf16").lower() + if device == "cuda" and cuda_dtype == "fp16": + load_dtype = torch.float16 + elif device == "cuda": + load_dtype = torch.bfloat16 + else: + cpu_bf16_env = 
os.environ.get("VIBEPOD_CPU_BF16", "auto").lower() + if cpu_bf16_env == "1": + load_dtype = torch.bfloat16 + logger.info("CPU BF16 forced via VIBEPOD_CPU_BF16=1") + elif cpu_bf16_env == "0": + load_dtype = torch.float32 + logger.info("CPU float32 forced via VIBEPOD_CPU_BF16=0") + elif _cpu_supports_bf16(): + load_dtype = torch.bfloat16 + logger.info("AVX512_BF16 detected — loading model in bfloat16") + else: + load_dtype = torch.float32 + logger.info("No AVX512_BF16 — using float32 (set VIBEPOD_CPU_BF16=1 to override)") + logger.info("Loading model weights with dtype %s", load_dtype) + requested_attn_impl = os.environ.get("VIBEPOD_ATTN_IMPL", "auto").lower() + has_flash_attn = importlib.util.find_spec("flash_attn") is not None + if requested_attn_impl in {"eager", "sdpa"}: + attn_impl = requested_attn_impl + elif requested_attn_impl == "flash_attention_2": + attn_impl = "flash_attention_2" if has_flash_attn else "sdpa" + else: + attn_impl = "flash_attention_2" if device == "cuda" and has_flash_attn else "sdpa" + logger.info("Using Transformers attention implementation: %s", attn_impl) + if device == "cuda" and not has_flash_attn: + logger.info("flash_attn is not installed; using PyTorch SDPA attention.") from vibevoice.modular.modeling_vibevoice_streaming_inference import ( VibeVoiceStreamingForConditionalGenerationInference, @@ -225,9 +310,13 @@ def _init_model(device: str): device_map=device, attn_implementation=attn_impl, ) - except Exception: + except Exception as exc: + if attn_impl == "sdpa": + raise logger.warning( - "Model load with %s failed; falling back to sdpa", attn_impl, exc_info=True + "Model load with %s failed (%s); falling back to sdpa", + attn_impl, + exc, ) model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( MODEL_ID, @@ -237,10 +326,272 @@ def _init_model(device: str): ) model.eval() + if device == "cpu": + model = _apply_cpu_optimizations(model) 
model.set_ddpm_inference_steps(num_steps=_config["default_inference_steps"]) + _install_generation_optimizations(model) + if device == "cpu": + # Must run after _install_generation_optimizations so the async wrapper + # sits outside the profiling wrapper (VibeVoice calls async → profiling → real decode). + _install_cpu_pipeline_optimizations(model) return model +def _apply_cpu_optimizations(model: object) -> object: + """Apply optional post-load CPU optimizations. Returns (possibly new) model object.""" + + do_quantize = os.environ.get("VIBEPOD_QUANTIZE", "0") == "1" + do_compile = os.environ.get("VIBEPOD_COMPILE", "0") == "1" + + if do_quantize: + logger.info("Applying dynamic INT8 quantization to Linear layers...") + try: + import torch.ao.quantization + + # The diffusion prediction_head operates on small fixed-size tensors where + # INT8 pack/unpack overhead exceeds the matmul savings (~+20% regression in + # testing). Save and restore it so it stays in float32. + saved_prediction_head = None + if hasattr(model, "model") and hasattr(model.model, "prediction_head"): + saved_prediction_head = model.model.prediction_head + del model.model.prediction_head + + model = torch.ao.quantization.quantize_dynamic( + model, {torch.nn.Linear}, dtype=torch.qint8 + ) + + if saved_prediction_head is not None: + model.model.prediction_head = saved_prediction_head + logger.info( + "Dynamic INT8 quantization applied (prediction_head excluded — stays float32)." + ) + else: + logger.info("Dynamic INT8 quantization applied.") + except Exception as exc: + logger.warning("Dynamic quantization failed: %s — skipping", exc) + + if do_compile: + # torch.compile with inductor on CPU is ineffective for autoregressive TTS: + # each token step produces a unique input shape, so every step triggers a new + # kernel compile event rather than reusing compiled code. Kept as an escape + # hatch but not recommended. 
+ compile_mode = os.environ.get("VIBEPOD_COMPILE_MODE", "reduce-overhead") + logger.info( + "torch.compile enabled (mode=%s) — NOTE: limited benefit for autoregressive" + " models on CPU due to dynamic sequence lengths.", + compile_mode, + ) + _compile_targets: list[tuple[str, object, str, bool]] = [ + ("forward_tts_lm", model, "forward_tts_lm", True), + ] + if hasattr(model, "model"): + inner = model.model + if hasattr(inner, "prediction_head"): + _compile_targets.append(("prediction_head", inner, "prediction_head", False)) + if hasattr(inner, "acoustic_tokenizer") and hasattr(inner.acoustic_tokenizer, "decode"): + _compile_targets.append( + ("acoustic_tokenizer.decode", inner.acoustic_tokenizer, "decode", False) + ) + + for label, obj, attr, dynamic in _compile_targets: + try: + compiled = torch.compile( + getattr(obj, attr), + backend="inductor", + mode=compile_mode, + dynamic=dynamic, + ) + setattr(obj, attr, compiled) + logger.info(" compiled: %s", label) + except Exception as exc: + logger.warning(" torch.compile failed for %s: %s — skipping", label, exc) + + return model + + +def _install_generation_optimizations(model: object) -> None: + """Patch VibeVoice hot paths without changing model quality settings.""" + + def profile_enabled() -> bool: + return os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1" + + def profile_sync() -> None: + if torch.cuda.is_available(): + torch.cuda.synchronize() + + def profile_record(self, key: str, elapsed: float) -> None: + stats = getattr(self, "_vibepod_profile", None) + if stats is None: + stats = {} + self._vibepod_profile = stats + bucket = stats.setdefault(key, {"count": 0, "seconds": 0.0}) + bucket["count"] += 1 + bucket["seconds"] += elapsed + + def timed_method(self, key: str, fn, *args, **kwargs): + if not profile_enabled(): + return fn(*args, **kwargs) + profile_sync() + started = time.perf_counter() + result = fn(*args, **kwargs) + profile_sync() + profile_record(self, key, time.perf_counter() - started) + 
return result + + def prepare_noise_scheduler(self): + scheduler = self.model.noise_scheduler + cache_key = self.ddpm_inference_steps + cache = getattr(self, "_vibepod_scheduler_cache", {}) + cached = cache.get(cache_key) + + if cached is None: + scheduler.set_timesteps(self.ddpm_inference_steps) + cached = { + "num_inference_steps": scheduler.num_inference_steps, + "timesteps": scheduler.timesteps, + "sigmas": scheduler.sigmas, + } + cache[cache_key] = cached + self._vibepod_scheduler_cache = cache + else: + scheduler.num_inference_steps = cached["num_inference_steps"] + scheduler.timesteps = cached["timesteps"] + scheduler.sigmas = cached["sigmas"] + scheduler.model_outputs = [None] * scheduler.config.solver_order + scheduler.lower_order_nums = 0 + scheduler._step_index = None + scheduler._begin_index = None + + return scheduler + + def sample_speech_tokens_optimized(self, condition, neg_condition, cfg_scale=3.0): + scheduler = prepare_noise_scheduler(self) + + condition = torch.cat([condition, neg_condition], dim=0).to( + self.model.prediction_head.device + ) + batch_size = condition.shape[0] // 2 + speech = torch.randn(batch_size, self.config.acoustic_vae_dim).to(condition) + t_batch_cache_key = ( + self.ddpm_inference_steps, + condition.device.type, + condition.device.index, + condition.dtype, + batch_size, + ) + t_batch_cache = getattr(self, "_vibepod_t_batch_cache", {}) + t_batches = t_batch_cache.get(t_batch_cache_key) + if t_batches is None or len(t_batches) != len(scheduler.timesteps): + t_batches = [ + t.repeat(condition.shape[0]).to(device=condition.device, dtype=condition.dtype) + for t in scheduler.timesteps + ] + t_batch_cache[t_batch_cache_key] = t_batches + self._vibepod_t_batch_cache = t_batch_cache + + for t, t_batch in zip(scheduler.timesteps, t_batches): + if batch_size == 1: + combined = speech.expand(condition.shape[0], -1) + else: + combined = torch.cat([speech, speech], dim=0) + if profile_enabled(): + profile_sync() + started = 
time.perf_counter() + eps = self.model.prediction_head(combined, t_batch, condition=condition) + if profile_enabled(): + profile_sync() + profile_record(self, "diffusion_prediction_head", time.perf_counter() - started) + cond_eps, uncond_eps = torch.split(eps, batch_size, dim=0) + guided_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + if profile_enabled(): + started = time.perf_counter() + speech = scheduler.step(guided_eps, t, speech).prev_sample + if profile_enabled(): + profile_record(self, "diffusion_scheduler_step", time.perf_counter() - started) + + return speech + + forward_lm = model.forward_lm + forward_tts_lm = model.forward_tts_lm + acoustic_decode = model.model.acoustic_tokenizer.decode + + def forward_lm_profiled(*args, **kwargs): + return timed_method(model, "forward_lm", forward_lm, *args, **kwargs) + + def forward_tts_lm_profiled(*args, **kwargs): + return timed_method(model, "forward_tts_lm", forward_tts_lm, *args, **kwargs) + + def acoustic_decode_profiled(*args, **kwargs): + return timed_method(model, "acoustic_decode", acoustic_decode, *args, **kwargs) + + model.forward_lm = forward_lm_profiled + model.forward_tts_lm = forward_tts_lm_profiled + model.model.acoustic_tokenizer.decode = acoustic_decode_profiled + model.sample_speech_tokens = types.MethodType(sample_speech_tokens_optimized, model) + logger.info("Installed VibeVoice generation hot-path optimizations.") + + +def _install_cpu_pipeline_optimizations(model: object) -> None: + """Attach the decode executor to the model for the optimised generate() loop. + + The JezzWTF/VibeVoice fork's generate() checks for two optional attributes: + + model._vibepod_decode_executor — ThreadPoolExecutor (1 worker) that + overlaps acoustic_decode with acoustic_connector + forward_tts_lm. + Profiling showed this hides ~72s of decode cost behind tts_lm work, + capturing ~96% of the theoretical overlap savings. + + model._vibepod_cfg_executor — intentionally NOT set. 
Parallel pos/neg + forward_tts_lm via a second thread causes MKL OpenMP thread-pool + contention on CPU: both threads compete for the same OMP worker pool, + making each call slower rather than faster. Net effect: ~6% regression. + The hook remains in the fork for potential GPU or future use. + + Attributes default to None, so the fork's generate() falls back to the + original sequential behaviour on CUDA or any non-VibePod install. + """ + global _decode_executor + + if os.environ.get("VIBEPOD_ASYNC_DECODE", "1") != "1": + logger.info("CPU async decode disabled via VIBEPOD_ASYNC_DECODE=0.") + return + + _decode_executor = concurrent.futures.ThreadPoolExecutor( + max_workers=1, thread_name_prefix="vibepod-decode" + ) + model._vibepod_decode_executor = _decode_executor + logger.info( + "CPU pipeline: decode executor attached — acoustic_decode overlaps " + "tts_lm. Disable with VIBEPOD_ASYNC_DECODE=0." + ) + + +def _model_float_dtype() -> torch.dtype: + try: + return next(_model.parameters()).dtype + except StopIteration: + return torch.float32 + + +def _move_cached_prompt(value: object, device: str, dtype: torch.dtype) -> object: + if torch.is_tensor(value): + if torch.is_floating_point(value): + return value.to(device=device, dtype=dtype) + return value.to(device=device) + if isinstance(value, dict): + for k in list(value.keys()): + value[k] = _move_cached_prompt(value[k], device, dtype) + return value + if isinstance(value, list): + return [_move_cached_prompt(v, device, dtype) for v in value] + if isinstance(value, tuple): + return tuple(_move_cached_prompt(v, device, dtype) for v in value) + if hasattr(value, "key_cache") and hasattr(value, "value_cache"): + value.key_cache = [_move_cached_prompt(t, device, dtype) for t in value.key_cache] + value.value_cache = [_move_cached_prompt(t, device, dtype) for t in value.value_cache] + return value + + def _load_voice_presets(device: str) -> dict[str, object]: presets = {} for name, filename in EN_VOICES.items(): @@ 
-273,19 +624,33 @@ def _load_model_sync() -> None: is_cpu = _device == "cpu" _config["device"] = _device _config["chunk_accum"] = _env_int("VIBEPOD_CHUNK_ACCUM", 4 if is_cpu else 1) - _config["prebuffer_secs"] = _env_float("VIBEPOD_PREBUFFER_SECS", 5.0 if is_cpu else 2.0) - _config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 1.0 if is_cpu else 0.4) - _config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 2.5 if is_cpu else 1.5) - _config["default_inference_steps"] = _env_int("VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10) + _config["prebuffer_secs"] = _env_float( + "VIBEPOD_PREBUFFER_SECS", 24.0 if is_cpu else 5.0 + ) + _config["rebuffer_threshold_secs"] = _env_float( + "VIBEPOD_REBUFFER_THRESHOLD_SECS", 2.0 if is_cpu else 1.0 + ) + _config["resume_threshold_secs"] = _env_float( + "VIBEPOD_RESUME_THRESHOLD_SECS", 12.0 if is_cpu else 3.0 + ) + _config["default_inference_steps"] = _env_int( + "VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10 + ) + if is_cpu: + logical_cpus = os.cpu_count() or 1 + _config["cpu_threads"] = _env_int( + "VIBEPOD_CPU_THREADS", + max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus, + ) + _config["cpu_interop_threads"] = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1) + _config["cpu_mkldnn"] = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0" _processor = _init_processor() _model = _init_model(_device) _voice_presets = _load_voice_presets(_device) _model_status = "online" - logger.info( - "Model ready on %s. Voices: %s", _device, list(_voice_presets.keys()) - ) + logger.info("Model ready on %s. 
Voices: %s", _device, list(_voice_presets.keys())) logger.info("Configuration: %s", _config) except Exception as exc: @@ -302,6 +667,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: thread = threading.Thread(target=_load_model_sync, daemon=True, name="model-loader") thread.start() yield + if _decode_executor is not None: + _decode_executor.shutdown(wait=False) app = FastAPI(title="VibePod TTS Server", version="0.1.0", lifespan=lifespan) @@ -314,7 +681,7 @@ class GenerateRequest(BaseModel): text: str = Field(..., min_length=1, max_length=10_000) speaker: str = Field(default=DEFAULT_SPEAKER) cfg_scale: float = Field(default=1.5, ge=0.5, le=4.0) - inference_steps: Optional[int] = Field(default=None, ge=5, le=20) + inference_steps: int | None = Field(default=None, ge=5, le=20) @field_validator("text") @classmethod @@ -353,8 +720,8 @@ async def health() -> dict: def _sync_generate( req: GenerateRequest, - streamer: Optional[object] = None, - cancel_event: Optional[threading.Event] = None, + streamer: object | None = None, + cancel_event: threading.Event | None = None, ) -> str: """Blocking inference. Returns the speaker used. Runs in a thread-pool executor — do not call from the event loop directly. 
@@ -364,10 +731,17 @@ def _sync_generate( raise RuntimeError("Generation cancelled.") speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER - voice_preset = copy.deepcopy(_voice_presets[speaker]) + model_dtype = _model_float_dtype() + voice_preset = _move_cached_prompt(copy.deepcopy(_voice_presets[speaker]), _device, model_dtype) - steps = req.inference_steps if req.inference_steps is not None else _config["default_inference_steps"] + steps = ( + req.inference_steps + if req.inference_steps is not None + else _config["default_inference_steps"] + ) _model.set_ddpm_inference_steps(num_steps=steps) + if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1": + _model._vibepod_profile = {} inputs = _processor.process_input_with_cached_prompt( text=req.text, @@ -380,19 +754,20 @@ def _sync_generate( if torch.is_tensor(v): inputs[k] = v.to(_device) - outputs = _model.generate( - **inputs, - max_new_tokens=None, - cfg_scale=req.cfg_scale, - tokenizer=_processor.tokenizer, - generation_config={"do_sample": False}, - verbose=True, - all_prefilled_outputs=copy.deepcopy(voice_preset), - audio_streamer=streamer, - ) - - if not outputs.speech_outputs or outputs.speech_outputs[0] is None: - raise ValueError("Model returned no audio output.") + with torch.inference_mode(): + _model.generate( + **inputs, + max_new_tokens=None, + cfg_scale=req.cfg_scale, + tokenizer=_processor.tokenizer, + generation_config={"do_sample": False}, + verbose=False, + show_progress_bar=False, + return_speech=False, + stop_check_fn=cancel_event.is_set if cancel_event else None, + all_prefilled_outputs=voice_preset, + audio_streamer=streamer, + ) return speaker @@ -401,6 +776,22 @@ def _sse(event: dict) -> str: return f"data: {json.dumps(event)}\n\n" +def _generation_profile() -> dict[str, dict[str, float]] | None: + if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") != "1": + return None + stats = getattr(_model, "_vibepod_profile", None) + if not stats: + return {} + return { 
+ key: { + "count": value["count"], + "seconds": round(value["seconds"], 3), + "avg_ms": round(value["seconds"] * 1000 / value["count"], 3) if value["count"] else 0.0, + } + for key, value in sorted(stats.items()) + } + + @app.post("/generate") async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: if _model_status != "online": @@ -417,28 +808,62 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: ) async def event_stream() -> AsyncGenerator[str, None]: - from vibevoice.modular.streamer import AsyncAudioStreamer + class NonBlockingAudioStreamer: + """Async streamer that keeps GPU->CPU copies out of the model thread.""" + + def __init__(self, batch_size: int, stop_signal: object = None) -> None: + self.batch_size = batch_size + self.stop_signal = stop_signal + self.audio_queues = [asyncio.Queue() for _ in range(batch_size)] + self.finished_flags = [False for _ in range(batch_size)] + self.loop = asyncio.get_running_loop() + + def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor) -> None: + for i, sample_idx in enumerate(sample_indices): + idx = sample_idx.item() + if idx < self.batch_size and not self.finished_flags[idx]: + self.loop.call_soon_threadsafe( + self.audio_queues[idx].put_nowait, + audio_chunks[i].detach(), + ) + + def end(self, sample_indices: torch.Tensor | None = None) -> None: + if sample_indices is None: + indices_to_end = range(self.batch_size) + else: + indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices] + for idx in indices_to_end: + if idx < self.batch_size and not self.finished_flags[idx]: + self.loop.call_soon_threadsafe( + self.audio_queues[idx].put_nowait, self.stop_signal + ) + self.finished_flags[idx] = True start = time.monotonic() - streamer = AsyncAudioStreamer(batch_size=1) + streamer = NonBlockingAudioStreamer(batch_size=1) cancel_event = threading.Event() accum_size = max(1, _config["chunk_accum"]) accumulated_chunks = [] + 
chunk_count = 0 + audio_samples = 0 + first_chunk_at: float | None = None + last_chunk_at: float | None = None + max_chunk_gap = 0.0 + speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER async with _generation_lock: loop = asyncio.get_event_loop() future = loop.run_in_executor( None, functools.partial(_sync_generate, req, streamer, cancel_event) ) + future.add_done_callback(lambda _: streamer.end()) # Drain audio chunks as they arrive from the diffusion head. # stop_signal=None is the default sentinel that ends the queue. while True: try: - chunk = await asyncio.wait_for( - streamer.audio_queues[0].get(), timeout=120.0 - ) + chunk = await asyncio.wait_for(streamer.audio_queues[0].get(), timeout=120.0) except asyncio.TimeoutError: cancel_event.set() future.cancel() @@ -454,17 +879,45 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: if chunk is None: # stop signal break - accumulated_chunks.append(chunk.detach().cpu().float()) + accumulated_chunks.append(chunk.detach()) if len(accumulated_chunks) >= accum_size: - combined = torch.cat(accumulated_chunks, dim=0) + now = time.monotonic() + if first_chunk_at is None: + first_chunk_at = now + if last_chunk_at is not None: + max_chunk_gap = max(max_chunk_gap, now - last_chunk_at) + last_chunk_at = now + + combined = ( + torch.cat(accumulated_chunks, dim=0) + .detach() + .to("cpu", dtype=torch.float32) + .contiguous() + ) + chunk_count += 1 + audio_samples += combined.numel() pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode() yield _sse({"type": "audio_chunk", "data": pcm_b64}) accumulated_chunks = [] # Flush any remaining chunks if accumulated_chunks: - combined = torch.cat(accumulated_chunks, dim=0) + now = time.monotonic() + if first_chunk_at is None: + first_chunk_at = now + if last_chunk_at is not None: + max_chunk_gap = max(max_chunk_gap, now - last_chunk_at) + last_chunk_at = now + + combined = ( + torch.cat(accumulated_chunks, dim=0) + .detach() 
+ .to("cpu", dtype=torch.float32) + .contiguous() + ) + chunk_count += 1 + audio_samples += combined.numel() pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode() yield _sse({"type": "audio_chunk", "data": pcm_b64}) @@ -479,17 +932,40 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: yield _sse( { "type": "error", - "message": "Internal server error during generation.", + "message": f"Generation failed: {exc}", } ) return elapsed = round(time.monotonic() - start, 1) + audio_secs = audio_samples / SAMPLE_RATE + realtime_factor = audio_secs / elapsed if elapsed > 0 else None + profile = _generation_profile() + if profile is not None: + logger.info("Generation profile: %s", profile) logger.info("Generation complete in %.1fs", elapsed) - yield _sse({"type": "complete", "elapsed": elapsed, "speaker": speaker}) + complete_event = { + "type": "complete", + "elapsed": elapsed, + "speaker": speaker, + "audio_secs": round(audio_secs, 2), + "realtime_factor": round(realtime_factor, 3) if realtime_factor is not None else None, + "chunks": chunk_count, + "first_chunk_secs": round(first_chunk_at - start, 2) + if first_chunk_at is not None + else None, + "max_chunk_gap_secs": round(max_chunk_gap, 2), + } + if profile is not None: + complete_event["profile"] = profile + yield _sse(complete_event) return StreamingResponse( event_stream(), media_type="text/event-stream", - headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, + headers={ + "Cache-Control": "no-cache, no-transform", + "X-Accel-Buffering": "no", + "X-Content-Type-Options": "nosniff", + }, ) diff --git a/web/app/api/generate/route.ts b/web/app/api/generate/route.ts index 8bd1b94..310bb01 100644 --- a/web/app/api/generate/route.ts +++ b/web/app/api/generate/route.ts @@ -1,10 +1,13 @@ import { NextRequest, NextResponse } from "next/server"; +export const dynamic = "force-dynamic"; +export const runtime = "nodejs"; + export async function POST(request: NextRequest) { 
const pythonServerUrl = process.env.VIBEVOICE_SERVER_URL ?? "http://localhost:8000"; try { - const body = await request.json() as { + const body = (await request.json()) as { text: string; speaker?: string; cfg_scale?: number; @@ -24,6 +27,7 @@ export async function POST(request: NextRequest) { cfg_scale: body.cfg_scale ?? 1.5, inference_steps: body.inference_steps ?? 10, }), + signal: request.signal, }); if (!upstream.ok) { @@ -36,8 +40,9 @@ export async function POST(request: NextRequest) { status: 200, headers: { "Content-Type": "text/event-stream", - "Cache-Control": "no-cache", - "Connection": "keep-alive", + "Cache-Control": "no-cache, no-transform", + Connection: "keep-alive", + "X-Content-Type-Options": "nosniff", "X-Accel-Buffering": "no", }, }); diff --git a/web/app/api/health/route.ts b/web/app/api/health/route.ts index dffb3f8..e4d3506 100644 --- a/web/app/api/health/route.ts +++ b/web/app/api/health/route.ts @@ -4,8 +4,7 @@ const OFFLINE_RESPONSE = { status: "offline" }; const COMMON_OPTIONS = { headers: { "Cache-Control": "no-store" } }; export async function GET() { - const pythonServerUrl = - process.env.VIBEVOICE_SERVER_URL ?? "http://localhost:8000"; + const pythonServerUrl = process.env.VIBEVOICE_SERVER_URL ?? "http://localhost:8000"; try { const res = await fetch(`${pythonServerUrl}/health`, { @@ -27,6 +26,7 @@ export async function GET() { message: data.message, progress: data.progress ?? null, voices: data.voices ?? [], + config: data.config ?? 
null, }, COMMON_OPTIONS ); diff --git a/web/app/globals.css b/web/app/globals.css index 9388e7f..d4569ee 100644 --- a/web/app/globals.css +++ b/web/app/globals.css @@ -12,8 +12,10 @@ --muted: #64748b; --success: #22c55e; --error: #ef4444; - --font-sans: ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif; - --font-mono: ui-monospace, SFMono-Regular, "SF Mono", Menlo, Consolas, "Liberation Mono", monospace; + --font-sans: + ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif; + --font-mono: + ui-monospace, SFMono-Regular, "SF Mono", Menlo, Consolas, "Liberation Mono", monospace; } @theme inline { diff --git a/web/app/page.tsx b/web/app/page.tsx index 275d658..128824a 100644 --- a/web/app/page.tsx +++ b/web/app/page.tsx @@ -69,19 +69,39 @@ type AppAction = function reducer(state: AppState, action: AppAction): AppState { switch (action.type) { - case "SET_SCRIPT": return { ...state, script: action.payload }; - case "SET_SPEAKER": return { ...state, speaker: action.payload }; - case "SET_CFG_SCALE": return { ...state, cfgScale: action.payload }; - case "SET_INFERENCE_STEPS": return { ...state, inferenceSteps: action.payload }; - case "SET_PREBUFFER_SECS": return { ...state, prebufferSecs: action.payload }; - case "SET_REBUFFER_THRESHOLD": return { ...state, rebufferThresholdSecs: action.payload }; - case "SET_RESUME_THRESHOLD": return { ...state, resumeThresholdSecs: action.payload }; + case "SET_SCRIPT": + return { ...state, script: action.payload }; + case "SET_SPEAKER": + return { ...state, speaker: action.payload }; + case "SET_CFG_SCALE": + return { ...state, cfgScale: action.payload }; + case "SET_INFERENCE_STEPS": + return { ...state, inferenceSteps: action.payload }; + case "SET_PREBUFFER_SECS": + return { ...state, prebufferSecs: action.payload }; + case "SET_REBUFFER_THRESHOLD": + return { ...state, rebufferThresholdSecs: action.payload }; + case "SET_RESUME_THRESHOLD": + return { 
...state, resumeThresholdSecs: action.payload }; case "START_GENERATION": - return { ...state, isGenerating: true, audioUrl: null, logs: [], genElapsed: 0, genPct: null }; + return { + ...state, + isGenerating: true, + audioUrl: null, + logs: [], + genElapsed: 0, + genPct: null, + }; case "GEN_PROGRESS": return { ...state, genElapsed: action.elapsed, genPct: action.pct }; case "GENERATION_SUCCESS": - return { ...state, isGenerating: false, genElapsed: 0, genPct: null, audioUrl: action.payload }; + return { + ...state, + isGenerating: false, + genElapsed: 0, + genPct: null, + audioUrl: action.payload, + }; case "GENERATION_CANCELLED": case "GENERATION_ERROR": return { ...state, isGenerating: false, genElapsed: 0, genPct: null }; @@ -89,21 +109,27 @@ function reducer(state: AppState, action: AppAction): AppState { return { ...state, logs: [...state.logs, action.payload] }; case "SET_SERVER_STATUS": { const isNewConfig = !state.serverConfig && action.payload.config; - const deviceChanged = !!(state.serverConfig && action.payload.config && state.serverConfig.device !== action.payload.config.device); + const deviceChanged = !!( + state.serverConfig && + action.payload.config && + state.serverConfig.device !== action.payload.config.device + ); - const nextSteps = (isNewConfig || deviceChanged) + const nextSteps = + isNewConfig || deviceChanged ? action.payload.config!.default_inference_steps : state.inferenceSteps; - const nextPrebuffer = (isNewConfig || deviceChanged) - ? action.payload.config!.prebuffer_secs - : state.prebufferSecs; + const nextPrebuffer = + isNewConfig || deviceChanged ? action.payload.config!.prebuffer_secs : state.prebufferSecs; - const nextRebuffer = (isNewConfig || deviceChanged) + const nextRebuffer = + isNewConfig || deviceChanged ? action.payload.config!.rebuffer_threshold_secs : state.rebufferThresholdSecs; - const nextResume = (isNewConfig || deviceChanged) + const nextResume = + isNewConfig || deviceChanged ? 
action.payload.config!.resume_threshold_secs : state.resumeThresholdSecs; @@ -121,7 +147,8 @@ function reducer(state: AppState, action: AppAction): AppState { resumeThresholdSecs: nextResume, }; } - default: return state; + default: + return state; } } @@ -130,9 +157,9 @@ const initialState: AppState = { speaker: "carter", cfgScale: 1.5, inferenceSteps: 10, - prebufferSecs: 2.0, - rebufferThresholdSecs: 0.4, - resumeThresholdSecs: 1.5, + prebufferSecs: 5.0, + rebufferThresholdSecs: 1.0, + resumeThresholdSecs: 3.0, isGenerating: false, genElapsed: 0, genPct: null, @@ -213,7 +240,10 @@ export default function HomePage() { } poll(); - return () => { cancelled = true; clearTimeout(timeoutId); }; + return () => { + cancelled = true; + clearTimeout(timeoutId); + }; }, []); const handleGenerate = useCallback(async () => { @@ -241,7 +271,6 @@ export default function HomePage() {
- {/* Left: script + audio player */}
dispatch({ type: "SET_CFG_SCALE", payload: v })} inferenceSteps={state.inferenceSteps} onInferenceStepsChange={(v) => dispatch({ type: "SET_INFERENCE_STEPS", payload: v })} - prebufferSecs={state.prebufferSecs} - onPrebufferSecsChange={(v) => dispatch({ type: "SET_PREBUFFER_SECS", payload: v })} - rebufferThresholdSecs={state.rebufferThresholdSecs} - onRebufferThresholdChange={(v) => dispatch({ type: "SET_REBUFFER_THRESHOLD", payload: v })} - resumeThresholdSecs={state.resumeThresholdSecs} - onResumeThresholdChange={(v) => dispatch({ type: "SET_RESUME_THRESHOLD", payload: v })} + prebufferSecs={state.prebufferSecs} + onPrebufferSecsChange={(v) => dispatch({ type: "SET_PREBUFFER_SECS", payload: v })} + rebufferThresholdSecs={state.rebufferThresholdSecs} + onRebufferThresholdChange={(v) => + dispatch({ type: "SET_REBUFFER_THRESHOLD", payload: v }) + } + resumeThresholdSecs={state.resumeThresholdSecs} + onResumeThresholdChange={(v) => + dispatch({ type: "SET_RESUME_THRESHOLD", payload: v }) + } onGenerate={handleGenerate} onStop={stop} onPauseStream={pauseStream} @@ -281,7 +314,6 @@ export default function HomePage() { />
-
diff --git a/web/components/AudioPlayer.tsx b/web/components/AudioPlayer.tsx index f54f25e..36ca5b1 100644 --- a/web/components/AudioPlayer.tsx +++ b/web/components/AudioPlayer.tsx @@ -14,15 +14,8 @@ function formatTime(seconds: number): string { } export default function AudioPlayer({ audioUrl }: AudioPlayerProps) { - const { - isPlaying, - currentTime, - duration, - volume, - toggle, - seek, - setVolume, - } = useAudioPlayer(audioUrl); + const { isPlaying, currentTime, duration, volume, toggle, seek, setVolume } = + useAudioPlayer(audioUrl); if (!audioUrl) return null; @@ -56,12 +49,10 @@ export default function AudioPlayer({ audioUrl }: AudioPlayerProps) { background: "rgba(45, 212, 191, 0.05)", }} onMouseEnter={(e) => { - (e.currentTarget as HTMLButtonElement).style.background = - "rgba(45, 212, 191, 0.15)"; + (e.currentTarget as HTMLButtonElement).style.background = "rgba(45, 212, 191, 0.15)"; }} onMouseLeave={(e) => { - (e.currentTarget as HTMLButtonElement).style.background = - "rgba(45, 212, 191, 0.05)"; + (e.currentTarget as HTMLButtonElement).style.background = "rgba(45, 212, 191, 0.05)"; }} > {isPlaying ? ( - + ) : ( - + )} @@ -143,9 +125,7 @@ export default function AudioPlayer({ audioUrl }: AudioPlayerProps) { {/* Duration info */}
- - {formatTime(currentTime)} - + {formatTime(currentTime)} / {formatTime(duration)}
diff --git a/web/components/GenerationControls.tsx b/web/components/GenerationControls.tsx index f9a7d4c..373fe0d 100644 --- a/web/components/GenerationControls.tsx +++ b/web/components/GenerationControls.tsx @@ -36,18 +36,27 @@ const STATUS_CONFIG: Record< Exclude, { color: string; label: (p: DownloadProgress | null) => string } > = { - offline: { color: "var(--error)", label: () => "Server offline — waiting for connection..." }, - downloading: { color: "#60a5fa", label: (p) => p && p.total > 0 ? `Downloading model... (${p.done} / ${p.total} files)` : "Downloading model (~1 GB)..." }, - loading: { color: "#fbbf24", label: () => "Loading model into memory..." }, - error: { color: "var(--error)", label: () => "Server error — check the terminal for details." }, + offline: { color: "var(--error)", label: () => "Server offline — waiting for connection..." }, + downloading: { + color: "#60a5fa", + label: (p) => + p && p.total > 0 + ? `Downloading model... (${p.done} / ${p.total} files)` + : "Downloading model (~1 GB)...", + }, + loading: { color: "#fbbf24", label: () => "Loading model into memory..." }, + error: { color: "var(--error)", label: () => "Server error — check the terminal for details." }, }; - function SpinnerIcon() { return ( - + ); } @@ -146,7 +155,10 @@ export default function GenerationControls({ onChange={(e) => onCfgScaleChange(parseFloat(e.target.value))} className="w-full" /> -
+
Flat (0.5) CFG Scale Expressive (4.0) @@ -157,7 +169,7 @@ export default function GenerationControls({
-
+
Faster (5) Diffusion Steps Better (20) @@ -207,7 +222,11 @@ export default function GenerationControls({
{showAdvanced && ( -
+
{/* Pre-buffer */}
@@ -221,7 +240,7 @@ export default function GenerationControls({ onPrebufferSecsChange(parseFloat(e.target.value))} @@ -232,7 +251,11 @@ export default function GenerationControls({ {/* Re-buffer threshold */}
-