Merge pull request #12 from JezzWTF/codex/gen-optim

Improve streaming generation flow and CPU startup handling
This commit is contained in:
2026-05-02 16:46:59 +01:00
committed by GitHub
32 changed files with 1251 additions and 383 deletions
-20
View File
@@ -1,20 +0,0 @@
{
"permissions": {
"allow": [
"Bash(mv podcast-forge/pnpm-lock.yaml /tmp/vibepod-pnpm-lock.yaml)",
"Bash(git mv *)",
"Bash(mv /tmp/vibepod-pnpm-lock.yaml web/pnpm-lock.yaml)",
"Bash(git rm *)",
"Bash(uv lock *)",
"Bash(pnpm install *)",
"Bash(git add *)",
"Bash(command -v uv)",
"Bash(uv --version)",
"Bash(uv sync *)",
"Bash(pnpm --filter vibepod-web exec tsc --noEmit)",
"Bash(xargs cat *)",
"Bash(.venv/Scripts/python.exe -c \"import torch; print\\('torch:', torch.__version__\\); print\\('CUDA available:', torch.cuda.is_available\\(\\)\\); print\\('CUDA version:', torch.version.cuda\\)\")",
"Bash(nvidia-smi)"
]
}
}
+37
View File
@@ -0,0 +1,37 @@
root = true
[*]
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
[*.{ts,tsx,js,jsx,mjs,cjs,mts,cts}]
indent_style = space
indent_size = 2
[*.{json,jsonc}]
indent_style = space
indent_size = 2
[*.{css,html}]
indent_style = space
indent_size = 2
[*.{yaml,yml}]
indent_style = space
indent_size = 2
[*.py]
indent_style = space
indent_size = 4
# TOML files (pyproject.toml etc.). A single-item brace pattern such as
# `*.{toml}` is matched literally by EditorConfig, so use a plain glob.
[*.toml]
indent_style = space
indent_size = 4
[*.md]
trim_trailing_whitespace = false
[Makefile]
indent_style = tab
+39
View File
@@ -8,3 +8,42 @@ HF_TOKEN=
# Override the HuggingFace model cache directory (optional)
# HF_HOME=/path/to/hf-cache
# ---------------------------------------------------------------------------
# Runtime tuning
# ---------------------------------------------------------------------------
# Force the Python server device. Usually set by `pnpm dev` / `pnpm dev:cpu`.
# VIBEPOD_DEVICE=cuda
# VIBEPOD_DEVICE=cpu
# CPU mode: keep async decode enabled. This overlaps acoustic decoding with
# language-model work and measured ~20% faster on an 8-thread CPU run.
VIBEPOD_ASYNC_DECODE=1
# CPU mode: thread tuning. On an 8-core / 16-thread Ryzen test system,
# 8 worker threads with 1 inter-op thread gave the best wall time, while 12
# over-subscribed and regressed.
# VIBEPOD_CPU_THREADS=8
# VIBEPOD_CPU_INTEROP_THREADS=1
# CPU mode: playback buffering. CPU generation is slower than realtime, so
# smooth streaming needs a larger initial buffer than CUDA. Lower these for
# faster startup if you are OK with occasional rebuffering.
# VIBEPOD_PREBUFFER_SECS=24
# VIBEPOD_REBUFFER_THRESHOLD_SECS=2
# VIBEPOD_RESUME_THRESHOLD_SECS=12
# CPU mode: dynamic INT8 quantization is enabled by default in start.sh.
# Set to 0 if you are comparing quality/performance or debugging.
# VIBEPOD_QUANTIZE=1
# CUDA mode: dtype and attention selection. Defaults are bf16 + SDPA unless
# optional FlashAttention is explicitly enabled and importable.
# VIBEPOD_CUDA_DTYPE=bf16
# VIBEPOD_ATTN_IMPL=sdpa
# VIBEPOD_ENABLE_FLASH_ATTN=0
# Debug/profiling. Keep disabled for benchmark timing; async CPU profiling
# double-counts overlapped decode work.
# VIBEPOD_PROFILE_GENERATION=0
+24
View File
@@ -0,0 +1,24 @@
* text=auto eol=lf
*.sh text eol=lf
*.py text eol=lf
*.ts text eol=lf
*.tsx text eol=lf
*.js text eol=lf
*.jsx text eol=lf
*.mjs text eol=lf
*.cjs text eol=lf
*.mts text eol=lf
*.cts text eol=lf
*.css text eol=lf
*.html text eol=lf
*.json text eol=lf
*.jsonc text eol=lf
*.yaml text eol=lf
*.yml text eol=lf
*.toml text eol=lf
*.md text eol=lf
*.mdx text eol=lf
*.lock text eol=lf
*.env text eol=lf
*.env.* text eol=lf
+1
View File
@@ -24,3 +24,4 @@ web/node_modules/
.DS_Store
Thumbs.db
.vscode/settings.json
.claude/settings.local.json
+18
View File
@@ -0,0 +1,18 @@
# Dependencies
node_modules/
web/node_modules/
# Build outputs
web/.next/
web/tsconfig.tsbuildinfo
web/next-env.d.ts
# Python / server
server/
# Lock files
pnpm-lock.yaml
web/pnpm-lock.yaml
# Generated
web/public/
+8
View File
@@ -0,0 +1,8 @@
{
"semi": true,
"singleQuote": false,
"tabWidth": 2,
"trailingComma": "es5",
"printWidth": 100,
"endOfLine": "lf"
}
+10 -2
View File
@@ -9,7 +9,7 @@ This file gives AI coding agents (Jules, Copilot, Claude Code, etc.) the context
VibePod is a text-to-speech web app. It has two services that must both run for the app to work:
| Service | Language | Entry point | Port |
|---------|----------|-------------|------|
| ---------- | ---------------------------------- | ------------------------------- | ---- |
| **server** | Python 3.10+ (FastAPI + VibeVoice) | `server/start.sh` | 8000 |
| **web** | TypeScript (Next.js 15, React 19) | `pnpm --filter vibepod-web dev` | 3000 |
@@ -52,7 +52,7 @@ pnpm build
The `--cpu` flag in `start.sh` sets `VIBEPOD_DEVICE=cpu` and uses a separate venv (`server/.venv-cpu`) so CUDA and CPU installs never conflict. `vibevoice_server.py` reads `VIBEPOD_DEVICE` at startup via `_resolve_device()` — do not remove or rename that function.
| Env var | Values | Set by |
|---------|--------|--------|
| ------------------------ | ----------------------- | --------------------------- |
| `VIBEPOD_DEVICE` | `cpu` \| `cuda` | `server/start.sh` |
| `UV_PROJECT_ENVIRONMENT` | `.venv-cpu` \| `.venv` | `server/start.sh` |
| `HF_TOKEN` | HuggingFace token | Jules secret / `.env.local` |
@@ -94,7 +94,9 @@ dev.sh Concurrent launcher (forwards flags to start.sh)
## API reference
### `GET /health`
Returns server status. Safe to poll.
```json
{
"status": "online",
@@ -103,13 +105,17 @@ Returns server status. Safe to poll.
"voices": ["carter", "davis", "emma", "frank", "grace", "mike"]
}
```
`status` values: `downloading` | `loading` | `online` | `error`
### `POST /generate`
Streams audio as SSE events.
```json
{ "text": "Hello world", "speaker": "carter", "cfg_scale": 1.5, "inference_steps": 10 }
```
Event types: `audio_chunk` (base64 float32 PCM) | `complete` | `error` | `cancelled`
---
@@ -117,12 +123,14 @@ Event types: `audio_chunk` (base64 float32 PCM) | `complete` | `error` | `cancel
## Do / Don't
**Do:**
- Use `pnpm dev:cpu` in Jules — never plain `pnpm dev`
- Run `git checkout server/uv.lock` if uv rewrites it during setup
- Keep `_resolve_device()` in `vibevoice_server.py` — it's the CPU/CUDA switching logic
- Test server changes against `GET /health` and `POST /generate`
**Don't:**
- Run `uv sync` without `UV_PROJECT_ENVIRONMENT=.venv-cpu` in the Jules sandbox
- Install Python packages with pip
- Modify `server/uv.lock` manually
+5
View File
@@ -173,16 +173,21 @@ The shape language is a hybrid of structural precision and tactile softness.
## Components
### Card Containers
The fundamental building block of the UI. Every distinct section (Script, Player, Controls, Logs) is housed in a card featuring the `card-bg`, a 1px `border`, and `rounded-xl` corners. The internal layout always features an uppercase teal header for immediate section identification.
### Primary Action Buttons
Used for high-leverage actions like "Generate Audio" and "Play/Pause." These buttons utilize the `gradient-primary-dim` background, bold white text, and emit a soft teal glow to draw the eye and signify their importance.
### Range Sliders
Custom-styled input ranges replace default browser styles. The tracks are muted and slim, while the thumbs are bright teal, fully rounded, and emit a glow that intensifies on hover, providing a premium, tactile scrubbing experience.
### Status Indicators & Logs
A critical component of the application. Status badges utilize a minimalist pill shape with a pulsing ring animation to indicate active server processing. The log panel explicitly uses monospace typography and color-codes messages (green for success, red for error, white for neutral) to provide a terminal-like readout of the backend systems.
### Gradients
Gradients are used purposefully to indicate progress, activity, or brand presence. The primary gradient (`135deg` from teal to violet) is used for branding (the logo icon and text) and primary buttons. Horizontal gradients (`90deg`) are used dynamically in progress bars to represent the flow of data over time (e.g., loading, downloading, and audio generation).
+4 -4
View File
@@ -15,7 +15,7 @@ The Next.js app proxies audio generation requests to the FastAPI server, keeping
## Prerequisites
| Tool | Install |
|------|---------|
| ---------------------------------- | ----------------------------------- |
| [Node.js 20+](https://nodejs.org) | `winget install OpenJS.NodeJS.LTS` |
| [pnpm](https://pnpm.io) | `npm i -g pnpm` |
| [Python 3.10+](https://python.org) | `winget install Python.Python.3.13` |
@@ -51,7 +51,7 @@ The frontend shows a loading indicator while the model downloads. Once the serve
VibePod maintains two completely separate Python virtual environments so CUDA and CPU torch installs never conflict:
| Mode | Command | venv | torch source |
|------|---------|------|--------------|
| -------------- | -------------- | ------------------ | ----------------------- |
| CUDA (default) | `pnpm dev` | `server/.venv` | PyTorch CUDA 12.4 index |
| CPU-only | `pnpm dev:cpu` | `server/.venv-cpu` | PyPI (CPU wheel) |
@@ -75,7 +75,7 @@ pnpm build # Production build of the frontend
Copy `.env.example` to `.env.local` and set:
| Variable | Default | Description |
|----------|---------|-------------|
| ---------------------- | ----------------------- | --------------------------------------------------------- |
| `VIBEVOICE_SERVER_URL` | `http://localhost:8000` | URL the Next.js API routes use to reach the Python server |
| `HF_TOKEN` | — | HuggingFace token (required if the model repo is gated) |
| `HF_HOME` | — | Override the HuggingFace model cache directory |
@@ -108,7 +108,7 @@ server/
## Generation parameters
| Parameter | Range | Default | Effect |
|-----------|-------|---------|--------|
| ----------------- | --------------------------------------------------- | -------- | ---------------------------------------------- |
| `speaker` | `carter`, `davis`, `emma`, `frank`, `grace`, `mike` | `carter` | Voice preset used for the generated audio |
| `cfg_scale` | 0.5 – 4.0 | 1.5 | Higher = more expressive guidance |
| `inference_steps` | 5 – 20 | 10 | More steps = higher quality, slower generation |
+8 -1
View File
@@ -8,7 +8,14 @@
"dev:cpu": "bash dev.sh --cpu",
"dev:server": "bash server/start.sh",
"dev:server:cpu": "bash server/start.sh --cpu",
"dev:web": "pnpm --filter vibepod-web dev"
"dev:web": "pnpm --filter vibepod-web dev",
"format": "prettier --write . && cd server && uv run ruff format .",
"format:check": "prettier --check . && cd server && uv run ruff format --check .",
"lint:server": "cd server && uv run ruff check .",
"lint:server:fix": "cd server && uv run ruff check --fix ."
},
"devDependencies": {
"prettier": "^3.5.3"
},
"packageManager": "pnpm@10.33.2+sha512.a90faf6feeab71ad6c6e57f94e0fe1a12f5dcc22cd754db40ae9593eb6a3e0b6b12e3540218bb37ae083404b1f2ce6db2a4121e979829b4aff94b99f49da1cf8"
}
+12 -1
View File
@@ -6,7 +6,11 @@ settings:
importers:
.: {}
.:
devDependencies:
prettier:
specifier: ^3.5.3
version: 3.8.3
web:
dependencies:
@@ -516,6 +520,11 @@ packages:
resolution: {integrity: sha512-W62t/Se6rA0Az3DfCL0AqJwXuKwBeYg6nOaIgzP+xZ7N5BFCI7DYi1qs6ygUYT6rvfi6t9k65UMLJC+PHZpDAA==}
engines: {node: ^10 || ^12 || >=14}
prettier@3.8.3:
resolution: {integrity: sha512-7igPTM53cGHMW8xWuVTydi2KO233VFiTNyF5hLJqpilHfmn8C8gPf+PS7dUT64YcXFbiMGZxS9pCSxL/Dxm/Jw==}
engines: {node: '>=14'}
hasBin: true
react-dom@19.1.0:
resolution: {integrity: sha512-Xs1hdnE+DyKgeHJeJznQmYMIBG3TKIHJJT95Q58nHLSrElKlGQqDTR2HQ9fx5CN/Gk6Vh/kupBTDLU11/nDk/g==}
peerDependencies:
@@ -917,6 +926,8 @@ snapshots:
picocolors: 1.1.1
source-map-js: 1.2.1
prettier@3.8.3: {}
react-dom@19.1.0(react@19.1.0):
dependencies:
react: 19.1.0
+1 -1
View File
@@ -1,2 +1,2 @@
packages:
- 'web'
- "web"
+7
View File
@@ -39,6 +39,13 @@ VibePod Studio will turn generated audio from a one-shot download into a reusabl
- Add project save/load, autosave, and recoverable render jobs.
- Prepare the audio pipeline for queueing longer renders outside the request lifecycle.
## Later: VibeVoice Performance Research
- Move the current VibePod hot-path monkey patches into the `JezzWTF/VibeVoice` fork once the feature direction has settled.
- Add clearer generation profiling for overlapped CPU work, especially decode wait time versus total acoustic decode time.
- Prototype batched positive/negative CFG TTS LM inference behind an opt-in flag and benchmark it against the current sequential path on CPU and CUDA.
- Keep experimental performance work isolated from user-facing feature work unless it shows a clear speedup without audio quality regressions.
## Foundation Work Needed First
- Persist generated outputs with stable IDs.
+1
View File
@@ -0,0 +1 @@
3.12
+2 -5
View File
@@ -30,15 +30,12 @@ def download() -> str:
from huggingface_hub import snapshot_download
except ImportError:
print(
"ERROR: huggingface_hub is not installed.\n"
"Run: pip install huggingface_hub",
"ERROR: huggingface_hub is not installed.\nRun: pip install huggingface_hub",
file=sys.stderr,
)
sys.exit(1)
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get(
"HUGGINGFACE_TOKEN"
)
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
print(f"Checking / downloading model: {MODEL_ID}")
print("(This may take several minutes on first run — the model is ~1 GB)")
+28 -1
View File
@@ -8,7 +8,8 @@ dependencies = [
# To switch back to CPU-only, remove the [tool.uv.sources] torch entry below.
"torch>=2.0.0",
# VibeVoice custom model + processor classes (not yet in upstream transformers)
"vibevoice @ git+https://github.com/microsoft/VibeVoice.git",
# Uses JezzWTF/VibeVoice fork so VibePod-specific optimisations land here.
"vibevoice @ git+https://github.com/JezzWTF/VibeVoice.git",
# Exact version required by vibevoice's streaming TTS module
"transformers==4.51.3",
"fastapi>=0.111.0",
@@ -18,10 +19,36 @@ dependencies = [
"huggingface_hub>=0.23.0",
]
[dependency-groups]
dev = [
"ruff>=0.11.0",
]
# No build-system — this is a scripts project, not an installable package.
# Lock file is committed so installs are reproducible.
# Run `uv lock --upgrade` to bump dependencies.
[tool.ruff]
line-length = 100
indent-width = 4
target-version = "py310"
exclude = [".git", ".venv", ".venv-cpu", "__pycache__"]
[tool.ruff.lint]
select = ["E", "F", "UP", "B", "SIM", "I"]
ignore = [
"E501", # line-too-long — handled by formatter
"B905", # zip() without strict= — existing code, lengths are known-correct
]
[tool.ruff.lint.per-file-ignores]
"download_model.py" = ["T201"] # allow print statements in CLI scripts
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
line-ending = "lf"
[[tool.uv.index]]
name = "pytorch-cu124"
url = "https://download.pytorch.org/whl/cu124"
+90 -1
View File
@@ -7,6 +7,11 @@
# ./start.sh — CUDA mode (default, uses PyTorch CUDA 12.4 wheel, venv: .venv)
# ./start.sh --cpu — CPU-only mode (uses PyPI CPU torch wheel, venv: .venv-cpu)
#
# Optional CUDA acceleration:
# VIBEPOD_ENABLE_FLASH_ATTN=1 ./start.sh
# Installs a pre-built flash-attn wheel when the CUDA venv uses Python 3.12,
# torch 2.6.0, and CUDA 12.4 on Windows. Other platforms fall back to SDPA.
#
# The two modes maintain completely separate virtual environments so their torch
# installations never conflict. UV_PROJECT_ENVIRONMENT tells uv which venv to use;
# --no-sources skips [tool.uv.sources] so the CPU run pulls the default PyPI torch wheel.
@@ -51,6 +56,19 @@ if ! command -v uv &>/dev/null; then
exit 1
fi
# Exit status 0 only when flash_attn, triton and transformers.modeling_utils
# can all be imported by the project interpreter; all output is discarded.
validate_flash_attn() {
  local import_probe="import flash_attn; import triton; import transformers.modeling_utils"
  uv run python -c "$import_probe" >/dev/null 2>&1
}
# Uninstall flash-attn when it is present in the environment but fails the
# import check in validate_flash_attn. Does nothing when the package is
# absent or already importable.
remove_broken_flash_attn() {
  local presence_probe="import importlib.util; raise SystemExit(0 if importlib.util.find_spec('flash_attn') else 1)"

  # Package not installed at all: nothing to clean up.
  if ! uv run python -c "$presence_probe" >/dev/null 2>&1; then
    return 0
  fi

  # Installed and importable: leave it alone.
  if validate_flash_attn; then
    return 0
  fi

  echo " Installed flash-attn is not usable in this environment; removing it."
  uv pip uninstall flash-attn
}
# ---------------------------------------------------------------------------
# 2. Sync Python environment
# CPU mode: use .venv-cpu and skip [tool.uv.sources] so uv pulls the
@@ -61,10 +79,64 @@ echo ""
# Sync the Python environment for the selected mode.
# CPU mode syncs .venv-cpu with --no-sources; CUDA mode syncs the default
# venv and then optionally installs a pre-built flash-attn wheel.
if $CPU_MODE; then
echo "--> Syncing CPU Python environment (.venv-cpu)..."
export UV_PROJECT_ENVIRONMENT=".venv-cpu"
# Snapshot uv.lock before syncing and restore it afterwards, so the committed
# lock file ends up unchanged by the CPU sync.
# NOTE(review): presumably `uv sync --no-sources` rewrites uv.lock — confirm.
LOCK_BACKUP=""
if [[ -f uv.lock ]]; then
LOCK_BACKUP="$(mktemp)"
cp uv.lock "$LOCK_BACKUP"
fi
uv sync --no-sources
if [[ -n "$LOCK_BACKUP" ]]; then
cp "$LOCK_BACKUP" uv.lock
rm -f "$LOCK_BACKUP"
fi
else
echo "--> Syncing CUDA Python environment (.venv)..."
uv sync
# Drop a flash-attn install that is present but no longer importable.
remove_broken_flash_attn
# FlashAttention is opt-in: only attempted when VIBEPOD_ENABLE_FLASH_ATTN=1.
if [[ "${VIBEPOD_ENABLE_FLASH_ATTN:-0}" == "1" ]]; then
echo ""
echo "--> Checking optional FlashAttention wheel..."
if validate_flash_attn; then
echo " flash-attn already installed and importable."
else
# Derive the interpreter tag (e.g. cp312), the torch version without its
# local "+..." suffix, and the CUDA tag (e.g. cu124) from the synced venv.
PY_TAG="$(uv run python -c "import sys; print(f'cp{sys.version_info.major}{sys.version_info.minor}')")"
TORCH_TAG="$(uv run python -c "import torch; print(torch.__version__.split('+', 1)[0])")"
CUDA_TAG="$(uv run python -c "import torch; print('cu' + torch.version.cuda.replace('.', ''))")"
# Stays empty unless a matching pre-built wheel is known for this platform.
FLASH_ATTN_WHEEL_URL=""
if [[ "$PY_TAG" == "cp312" && "$TORCH_TAG" == "2.6.0" && "$CUDA_TAG" == "cu124" ]]; then
case "$(uname -s)" in
# Windows shells (Git Bash / Cygwin / MSYS) are the only platform with a known wheel.
MINGW*|CYGWIN*|MSYS*)
FLASH_ATTN_WHEEL_URL="https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4%2Bcu124torch2.6.0cxx11abiFALSE-cp312-cp312-win_amd64.whl"
echo " Installing flash-attn for Python 3.12, torch 2.6.0, CUDA 12.4 (Windows)..."
;;
*)
echo " No pre-built flash-attn wheel available for this platform ($(uname -s))."
echo " Continuing with PyTorch SDPA attention."
;;
esac
else
echo " No known wheel for Python tag $PY_TAG, torch $TORCH_TAG, CUDA $CUDA_TAG."
echo " Continuing with PyTorch SDPA attention."
fi
# Install the wheel, then verify it imports; an install that fails the
# import check is uninstalled again so the server falls back to SDPA.
if [[ -n "$FLASH_ATTN_WHEEL_URL" ]]; then
if uv pip install "$FLASH_ATTN_WHEEL_URL"; then
if validate_flash_attn; then
echo " flash-attn import check passed."
else
echo " flash-attn import check failed; removing it and continuing with SDPA."
uv pip uninstall flash-attn
fi
else
echo " flash-attn wheel install failed; continuing with SDPA."
fi
fi
fi
fi
fi
# ---------------------------------------------------------------------------
@@ -78,11 +150,28 @@ export PYTHONUTF8=1
if $CPU_MODE; then
export VIBEPOD_DEVICE="cpu"
export UV_PROJECT_ENVIRONMENT=".venv-cpu"
if [[ -z "${VIBEPOD_CPU_THREADS:-}" ]]; then
VIBEPOD_CPU_THREADS="$(uv run --no-sync --no-sources python -c "import os; print(max(1, (os.cpu_count() or 2) // 2))")"
export VIBEPOD_CPU_THREADS
fi
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-$VIBEPOD_CPU_THREADS}"
export MKL_NUM_THREADS="${MKL_NUM_THREADS:-$VIBEPOD_CPU_THREADS}"
# Dynamic INT8 quantization — on by default for CPU (~22% faster, prediction_head
# excluded automatically to avoid regression on small fixed-size tensors).
# Set VIBEPOD_QUANTIZE=0 to disable if you notice audio quality differences.
export VIBEPOD_QUANTIZE="${VIBEPOD_QUANTIZE:-1}"
# Optional CPU flags:
# VIBEPOD_ASYNC_DECODE=0 Disable async decode/tts_lm overlap (on by default)
# VIBEPOD_CPU_BF16=1 Force bfloat16 weights (auto-detected via AVX512_BF16)
# VIBEPOD_COMPILE=1 torch.compile hot paths (ineffective for autoregressive
# models on CPU — not recommended, kept for experimentation)
UV_RUN_ARGS=(--no-sync --no-sources)
else
export VIBEPOD_DEVICE="cuda"
UV_RUN_ARGS=()
fi
exec uv run uvicorn vibevoice_server:app \
exec uv run "${UV_RUN_ARGS[@]}" uvicorn vibevoice_server:app \
--host 127.0.0.1 \
--port 8000 \
--log-level info \
+41 -8
View File
@@ -1479,7 +1479,7 @@ name = "nvidia-cudnn-cu12"
version = "9.1.0.70"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12" },
{ name = "nvidia-cublas-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 },
@@ -1490,7 +1490,7 @@ name = "nvidia-cufft-cu12"
version = "11.2.1.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-nvjitlink-cu12" },
{ name = "nvidia-nvjitlink-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 },
@@ -1509,9 +1509,9 @@ name = "nvidia-cusolver-cu12"
version = "11.6.1.9"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12" },
{ name = "nvidia-cusparse-cu12" },
{ name = "nvidia-nvjitlink-cu12" },
{ name = "nvidia-cublas-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" },
{ name = "nvidia-cusparse-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057 },
@@ -1522,7 +1522,7 @@ name = "nvidia-cusparse-cu12"
version = "12.3.1.170"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-nvjitlink-cu12" },
{ name = "nvidia-nvjitlink-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 },
@@ -2394,6 +2394,31 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl", hash = "sha256:33bd4ef74232fb73fe9279a257718407f169c09b78a87ad3d296f548e27de0bb", size = 310654 },
]
[[package]]
name = "ruff"
version = "0.15.12"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/99/43/3291f1cc9106f4c63bdce7a8d0df5047fe8422a75b091c16b5e9355e0b11/ruff-0.15.12.tar.gz", hash = "sha256:ecea26adb26b4232c0c2ca19ccbc0083a68344180bba2a600605538ce51a40a6", size = 4643852 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c3/6e/e78ffb61d4686f3d96ba3df2c801161843746dcbcbb17a1e927d4829312b/ruff-0.15.12-py3-none-linux_armv6l.whl", hash = "sha256:f86f176e188e94d6bdbc09f09bfd9dc729059ad93d0e7390b5a73efe19f8861c", size = 10640713 },
{ url = "https://files.pythonhosted.org/packages/ae/08/a317bc231fb9e7b93e4ef3089501e51922ff88d6936ce5cf870c4fe55419/ruff-0.15.12-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:e3bcd123364c3770b8e1b7baaf343cc99a35f197c5c6e8af79015c666c423a6c", size = 11069267 },
{ url = "https://files.pythonhosted.org/packages/aa/a4/f828e9718d3dce1f5f11c39c4f65afd32783c8b2aebb2e3d259e492c47bd/ruff-0.15.12-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fe87510d000220aa1ed530d4448a7c696a0cae1213e5ec30e5874287b66557b5", size = 10397182 },
{ url = "https://files.pythonhosted.org/packages/71/e0/3310fc6d1b5e1fdea22bf3b1b807c7e187b581021b0d7d4514cccdb5fb71/ruff-0.15.12-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:84a1630093121375a3e2a95b4a6dc7b59e2b4ee76216e32d81aae550a832d002", size = 10758012 },
{ url = "https://files.pythonhosted.org/packages/11/c1/a606911aee04c324ddaa883ae418f3569792fd3c4a10c50e0dd0a2311e1e/ruff-0.15.12-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fb129f40f114f089ebe0ca56c0d251cf2061b17651d464bb6478dc01e69f11f5", size = 10447479 },
{ url = "https://files.pythonhosted.org/packages/9d/68/4201e8444f0894f21ab4aeeaee68aa4f10b51613514a20d80bd628d57e88/ruff-0.15.12-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b0c862b172d695db7598426b8af465e7e9ac00a3ea2a3630ee67eb82e366aaa6", size = 11234040 },
{ url = "https://files.pythonhosted.org/packages/34/ff/8a6d6cf4ccc23fd67060874e832c18919d1557a0611ebef03fdb01fff11e/ruff-0.15.12-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2849ea9f3484c3aca43a82f484210370319e7170df4dfe4843395ddf6c57bc33", size = 12087377 },
{ url = "https://files.pythonhosted.org/packages/85/f6/c669cf73f5152f623d34e69866a46d5e6185816b19fcd5b6dd8a2d299922/ruff-0.15.12-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9e77c7e51c07fe396826d5969a5b846d9cd4c402535835fb6e21ce8b28fef847", size = 11367784 },
{ url = "https://files.pythonhosted.org/packages/e8/39/c61d193b8a1daaa8977f7dea9e8d8ba866e02ea7b65d32f6861693aa4c12/ruff-0.15.12-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83b2f4f2f3b1026b5fb449b467d9264bf22067b600f7b6f41fc5958909f449d0", size = 11344088 },
{ url = "https://files.pythonhosted.org/packages/c2/8d/49afab3645e31e12c590acb6d3b5b69d7aab5b81926dbaf7461f9441f37a/ruff-0.15.12-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:9ba3b8f1afd7e2e43d8943e55f249e13f9682fde09711644a6e7290eb4f3e339", size = 11271770 },
{ url = "https://files.pythonhosted.org/packages/46/06/33f41fe94403e2b755481cdfb9b7ef3e4e0ed031c4581124658d935d52b4/ruff-0.15.12-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e852ba9fdc890655e1d78f2df1499efbe0e54126bd405362154a75e2bde159c5", size = 10719355 },
{ url = "https://files.pythonhosted.org/packages/0d/59/18aa4e014debbf559670e4048e39260a85c7fcee84acfd761ac01e7b8d35/ruff-0.15.12-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:dd8aed930da53780d22fc70bdf84452c843cf64f8cb4eb38984319c24c5cd5fd", size = 10462758 },
{ url = "https://files.pythonhosted.org/packages/25/e7/cc9f16fd0f3b5fddcbd7ec3d6ae30c8f3fde1047f32a4093a98d633c6570/ruff-0.15.12-py3-none-musllinux_1_2_i686.whl", hash = "sha256:01da3988d225628b709493d7dc67c3b9b12c0210016b08690ef9bd27970b262b", size = 10953498 },
{ url = "https://files.pythonhosted.org/packages/72/7a/a9ba7f98c7a575978698f4230c5e8cc54bbc761af34f560818f933dafa0c/ruff-0.15.12-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:9cae0f92bd5700d1213188b31cd3bdd2b315361296d10b96b8e2337d3d11f53e", size = 11447765 },
{ url = "https://files.pythonhosted.org/packages/ea/f9/0ae446942c846b8266059ad8a30702a35afae55f5cdc54c5adf8d7afdc27/ruff-0.15.12-py3-none-win32.whl", hash = "sha256:d0185894e038d7043ba8fd6aee7499ece6462dc0ea9f1e260c7451807c714c20", size = 10657277 },
{ url = "https://files.pythonhosted.org/packages/33/f1/9614e03e1cdcbf9437570b5400ced8a720b5db22b28d8e0f1bda429f660d/ruff-0.15.12-py3-none-win_amd64.whl", hash = "sha256:c87a162d61ab3adca47c03f7f717c68672edec7d1b5499e652331780fe74950d", size = 11837758 },
{ url = "https://files.pythonhosted.org/packages/c0/98/6beb4b351e472e5f4c4613f7c35a5290b8be2497e183825310c4c3a3984b/ruff-0.15.12-py3-none-win_arm64.whl", hash = "sha256:a538f7a82d061cee7be55542aca1d86d1393d55d81d4fcc314370f4340930d4f", size = 11120821 },
]
[[package]]
name = "safehttpx"
version = "0.1.7"
@@ -3100,6 +3125,11 @@ dependencies = [
{ name = "vibevoice" },
]
[package.dev-dependencies]
dev = [
{ name = "ruff" },
]
[package.metadata]
requires-dist = [
{ name = "fastapi", specifier = ">=0.111.0" },
@@ -3109,13 +3139,16 @@ requires-dist = [
{ name = "torch", specifier = ">=2.0.0", index = "https://download.pytorch.org/whl/cu124" },
{ name = "transformers", specifier = "==4.51.3" },
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.29.0" },
{ name = "vibevoice", git = "https://github.com/microsoft/VibeVoice.git" },
{ name = "vibevoice", git = "https://github.com/JezzWTF/VibeVoice.git" },
]
[package.metadata.requires-dev]
dev = [{ name = "ruff", specifier = ">=0.11.0" }]
[[package]]
name = "vibevoice"
version = "1.0.0"
source = { git = "https://github.com/microsoft/VibeVoice.git#e73d1e17c3754f046352014856a922f8208fb5d3" }
source = { git = "https://github.com/JezzWTF/VibeVoice.git#fe832f20e3d1638594f551a08f02253f14408dbd" }
dependencies = [
{ name = "absl-py" },
{ name = "accelerate" },
+520 -44
View File
@@ -20,17 +20,22 @@ Device selection:
import asyncio
import base64
import concurrent.futures
import copy
import functools
import importlib.util
import json
import logging
import os
import platform
import threading
import time
import types
import urllib.request
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from pathlib import Path
from typing import AsyncGenerator, Literal, Optional
from typing import Literal
import torch
from fastapi import FastAPI, HTTPException, Request
@@ -46,8 +51,7 @@ SAMPLE_RATE = 24_000
VOICES_DIR = Path(__file__).parent / "voices" / "streaming_model"
VOICE_BASE_URL = (
"https://raw.githubusercontent.com/microsoft/VibeVoice/main"
"/demo/voices/streaming_model"
"https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model"
)
EN_VOICES: dict[str, str] = {
@@ -62,6 +66,11 @@ DEFAULT_SPEAKER = "carter"
_IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.ot"]
# ── Pipeline executor ──────────────────────────────────────────────────────────
# Overlaps acoustic_decode with forward_tts_lm on a background thread (1 worker).
_decode_executor: concurrent.futures.ThreadPoolExecutor | None = None
# ── Device selection ────────────────────────────────────────────────────────────
# VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag.
# Falls back to auto-detection if not set.
@@ -106,6 +115,38 @@ def _env_float(name: str, default: float) -> float:
return default
def _cpu_supports_bf16() -> bool:
"""Return True if the CPU has AVX512_BF16 hardware support."""
return (
hasattr(torch, "cpu")
and hasattr(torch.cpu, "is_avx512_bf16_supported")
and torch.cpu.is_avx512_bf16_supported()
)
def _configure_cpu_runtime() -> dict[str, object]:
    """Apply CPU thread and MKLDNN settings from the environment.

    Reads ``VIBEPOD_CPU_THREADS``, ``VIBEPOD_CPU_INTEROP_THREADS`` and
    ``VIBEPOD_CPU_MKLDNN``. The intra-op default is half the logical CPU
    count (minimum 1) on Windows and the full logical CPU count elsewhere;
    the inter-op default is 1. MKLDNN stays enabled unless the variable is
    exactly ``"0"`` after stripping whitespace.

    Returns:
        A dict describing the effective settings as reported by torch after
        they were applied: ``logical_cpus``, ``threads``, ``interop_threads``,
        ``mkldnn_available`` and ``mkldnn_enabled``.
    """
    cpu_count = os.cpu_count() or 1
    if platform.system() == "Windows":
        fallback_threads = max(1, cpu_count // 2)
    else:
        fallback_threads = cpu_count

    requested_intra = _env_int("VIBEPOD_CPU_THREADS", fallback_threads)
    requested_interop = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1)
    use_mkldnn = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"

    # Both thread counts are clamped so a zero or negative override becomes 1.
    torch.set_num_threads(max(1, requested_intra))
    try:
        torch.set_num_interop_threads(max(1, requested_interop))
    except RuntimeError as exc:
        # A RuntimeError here is treated as non-fatal: log it and keep going.
        logger.warning("Could not set CPU inter-op threads: %s", exc)
    torch.backends.mkldnn.enabled = use_mkldnn

    # Report what torch says is in effect, not what was requested.
    summary: dict[str, object] = {"logical_cpus": cpu_count}
    summary["threads"] = torch.get_num_threads()
    summary["interop_threads"] = torch.get_num_interop_threads()
    summary["mkldnn_available"] = torch.backends.mkldnn.is_available()
    summary["mkldnn_enabled"] = torch.backends.mkldnn.enabled
    return summary
# ── Global state ────────────────────────────────────────────────────────────────
ModelStatus = Literal["downloading", "loading", "online", "error"]
@@ -114,7 +155,7 @@ _processor = None
_model = None
_device: str = "cpu"
_model_status: ModelStatus = "loading"
_model_error: Optional[str] = None
_model_error: str | None = None
_voice_presets: dict[str, object] = {}
_load_lock = threading.Lock()
_generation_lock = asyncio.Lock()
@@ -161,9 +202,7 @@ def _is_model_cached() -> bool:
try:
from huggingface_hub import snapshot_download
snapshot_download(
MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS
)
snapshot_download(MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS)
return True
except Exception:
return False
@@ -172,9 +211,7 @@ def _is_model_cached() -> bool:
def _download_model() -> None:
from huggingface_hub import snapshot_download
token: Optional[str] = os.environ.get("HF_TOKEN") or os.environ.get(
"HUGGINGFACE_TOKEN"
)
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
DlTqdm = _make_dl_tqdm()
logger.info("Model not cached — downloading %s...", MODEL_ID)
snapshot_download(
@@ -188,7 +225,7 @@ def _download_model() -> None:
def _download_voices() -> None:
VOICES_DIR.mkdir(parents=True, exist_ok=True)
for name, filename in EN_VOICES.items():
for _name, filename in EN_VOICES.items():
dest = VOICES_DIR / filename
if not dest.exists():
url = f"{VOICE_BASE_URL}/{filename}"
@@ -211,8 +248,56 @@ def _init_processor():
def _init_model(device: str):
logger.info("Loading model on %s...", device)
load_dtype = torch.bfloat16 if device == "cuda" else torch.float32
attn_impl = "flash_attention_2" if device == "cuda" else "sdpa"
if device == "cuda":
torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
logger.info(
"PyTorch SDPA backends: flash=%s, mem_efficient=%s, math=%s",
torch.backends.cuda.flash_sdp_enabled(),
torch.backends.cuda.mem_efficient_sdp_enabled(),
torch.backends.cuda.math_sdp_enabled(),
)
elif device == "cpu":
torch.set_float32_matmul_precision("medium")
logger.info("CPU runtime configuration: %s", _configure_cpu_runtime())
cuda_dtype = os.environ.get("VIBEPOD_CUDA_DTYPE", "bf16").lower()
if device == "cuda" and cuda_dtype == "fp16":
load_dtype = torch.float16
elif device == "cuda":
load_dtype = torch.bfloat16
else:
cpu_bf16_env = os.environ.get("VIBEPOD_CPU_BF16", "auto").lower()
if cpu_bf16_env == "1":
load_dtype = torch.bfloat16
logger.info("CPU BF16 forced via VIBEPOD_CPU_BF16=1")
elif cpu_bf16_env == "0":
load_dtype = torch.float32
logger.info("CPU float32 forced via VIBEPOD_CPU_BF16=0")
elif _cpu_supports_bf16():
load_dtype = torch.bfloat16
logger.info("AVX512_BF16 detected — loading model in bfloat16")
else:
load_dtype = torch.float32
logger.info("No AVX512_BF16 — using float32 (set VIBEPOD_CPU_BF16=1 to override)")
logger.info("Loading model weights with dtype %s", load_dtype)
requested_attn_impl = os.environ.get("VIBEPOD_ATTN_IMPL", "auto").lower()
has_flash_attn = importlib.util.find_spec("flash_attn") is not None
if requested_attn_impl in {"eager", "sdpa"}:
attn_impl = requested_attn_impl
elif requested_attn_impl == "flash_attention_2":
attn_impl = "flash_attention_2" if has_flash_attn else "sdpa"
else:
attn_impl = "flash_attention_2" if device == "cuda" and has_flash_attn else "sdpa"
logger.info("Using Transformers attention implementation: %s", attn_impl)
if device == "cuda" and not has_flash_attn:
logger.info("flash_attn is not installed; using PyTorch SDPA attention.")
from vibevoice.modular.modeling_vibevoice_streaming_inference import (
VibeVoiceStreamingForConditionalGenerationInference,
@@ -225,9 +310,13 @@ def _init_model(device: str):
device_map=device,
attn_implementation=attn_impl,
)
except Exception:
except Exception as exc:
if attn_impl == "sdpa":
raise
logger.warning(
"Model load with %s failed; falling back to sdpa", attn_impl, exc_info=True
"Model load with %s failed (%s); falling back to sdpa",
attn_impl,
exc,
)
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
MODEL_ID,
@@ -237,10 +326,272 @@ def _init_model(device: str):
)
model.eval()
if device == "cpu":
model = _apply_cpu_optimizations(model)
model.set_ddpm_inference_steps(num_steps=_config["default_inference_steps"])
_install_generation_optimizations(model)
if device == "cpu":
# Must run after _install_generation_optimizations so the async wrapper
# sits outside the profiling wrapper (VibeVoice calls async → profiling → real decode).
_install_cpu_pipeline_optimizations(model)
return model
def _apply_cpu_optimizations(model: object) -> object:
"""Apply optional post-load CPU optimizations. Returns (possibly new) model object."""
do_quantize = os.environ.get("VIBEPOD_QUANTIZE", "0") == "1"
do_compile = os.environ.get("VIBEPOD_COMPILE", "0") == "1"
if do_quantize:
logger.info("Applying dynamic INT8 quantization to Linear layers...")
try:
import torch.ao.quantization
# The diffusion prediction_head operates on small fixed-size tensors where
# INT8 pack/unpack overhead exceeds the matmul savings (~+20% regression in
# testing). Save and restore it so it stays in float32.
saved_prediction_head = None
if hasattr(model, "model") and hasattr(model.model, "prediction_head"):
saved_prediction_head = model.model.prediction_head
del model.model.prediction_head
model = torch.ao.quantization.quantize_dynamic(
model, {torch.nn.Linear}, dtype=torch.qint8
)
if saved_prediction_head is not None:
model.model.prediction_head = saved_prediction_head
logger.info(
"Dynamic INT8 quantization applied (prediction_head excluded — stays float32)."
)
else:
logger.info("Dynamic INT8 quantization applied.")
except Exception as exc:
logger.warning("Dynamic quantization failed: %s — skipping", exc)
if do_compile:
# torch.compile with inductor on CPU is ineffective for autoregressive TTS:
# each token step produces a unique input shape, so every step triggers a new
# kernel compile event rather than reusing compiled code. Kept as an escape
# hatch but not recommended.
compile_mode = os.environ.get("VIBEPOD_COMPILE_MODE", "reduce-overhead")
logger.info(
"torch.compile enabled (mode=%s) — NOTE: limited benefit for autoregressive"
" models on CPU due to dynamic sequence lengths.",
compile_mode,
)
_compile_targets: list[tuple[str, object, str, bool]] = [
("forward_tts_lm", model, "forward_tts_lm", True),
]
if hasattr(model, "model"):
inner = model.model
if hasattr(inner, "prediction_head"):
_compile_targets.append(("prediction_head", inner, "prediction_head", False))
if hasattr(inner, "acoustic_tokenizer") and hasattr(inner.acoustic_tokenizer, "decode"):
_compile_targets.append(
("acoustic_tokenizer.decode", inner.acoustic_tokenizer, "decode", False)
)
for label, obj, attr, dynamic in _compile_targets:
try:
compiled = torch.compile(
getattr(obj, attr),
backend="inductor",
mode=compile_mode,
dynamic=dynamic,
)
setattr(obj, attr, compiled)
logger.info(" compiled: %s", label)
except Exception as exc:
logger.warning(" torch.compile failed for %s: %s — skipping", label, exc)
return model
def _install_generation_optimizations(model: object) -> None:
    """Patch VibeVoice hot paths without changing model quality settings.

    Mutates ``model`` in place:

    * ``model.forward_lm``, ``model.forward_tts_lm`` and
      ``model.model.acoustic_tokenizer.decode`` are replaced by wrappers that
      call the original and, when profiling is on, record call count and
      elapsed time in ``model._vibepod_profile``.
    * ``model.sample_speech_tokens`` is rebound to a version that reuses
      cached scheduler timesteps/sigmas and cached per-timestep batches.

    Profiling is active only while the environment variable
    ``VIBEPOD_PROFILE_GENERATION`` equals "1"; it is re-read on every call.
    """

    def profile_enabled() -> bool:
        # Read the env var each time so profiling can be toggled at runtime.
        return os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1"

    def profile_sync() -> None:
        # Synchronize CUDA (when available) around a timed region.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def profile_record(self, key: str, elapsed: float) -> None:
        # Accumulate into self._vibepod_profile[key] = {"count", "seconds"},
        # creating the stats dict on first use.
        stats = getattr(self, "_vibepod_profile", None)
        if stats is None:
            stats = {}
            self._vibepod_profile = stats
        bucket = stats.setdefault(key, {"count": 0, "seconds": 0.0})
        bucket["count"] += 1
        bucket["seconds"] += elapsed

    def timed_method(self, key: str, fn, *args, **kwargs):
        # Pass straight through when profiling is off; otherwise time the call
        # and record it under `key`.
        if not profile_enabled():
            return fn(*args, **kwargs)
        profile_sync()
        started = time.perf_counter()
        result = fn(*args, **kwargs)
        profile_sync()
        profile_record(self, key, time.perf_counter() - started)
        return result

    def prepare_noise_scheduler(self):
        # Return self.model.noise_scheduler ready for a new sampling run.
        # set_timesteps() is called only the first time a given
        # ddpm_inference_steps value is seen; later runs restore the cached
        # num_inference_steps / timesteps / sigmas instead.
        scheduler = self.model.noise_scheduler
        cache_key = self.ddpm_inference_steps
        cache = getattr(self, "_vibepod_scheduler_cache", {})
        cached = cache.get(cache_key)
        if cached is None:
            scheduler.set_timesteps(self.ddpm_inference_steps)
            cached = {
                "num_inference_steps": scheduler.num_inference_steps,
                "timesteps": scheduler.timesteps,
                "sigmas": scheduler.sigmas,
            }
            cache[cache_key] = cached
            self._vibepod_scheduler_cache = cache
        else:
            scheduler.num_inference_steps = cached["num_inference_steps"]
            scheduler.timesteps = cached["timesteps"]
            scheduler.sigmas = cached["sigmas"]
        # Reset the scheduler's per-run solver state so each sampling pass
        # starts from a clean history.
        scheduler.model_outputs = [None] * scheduler.config.solver_order
        scheduler.lower_order_nums = 0
        scheduler._step_index = None
        scheduler._begin_index = None
        return scheduler

    def sample_speech_tokens_optimized(self, condition, neg_condition, cfg_scale=3.0):
        # Replacement for model.sample_speech_tokens: runs the diffusion loop
        # with classifier-free guidance. Conditional and negative conditions
        # are stacked along dim 0 so the prediction head is called once per
        # timestep.
        scheduler = prepare_noise_scheduler(self)
        condition = torch.cat([condition, neg_condition], dim=0).to(
            self.model.prediction_head.device
        )
        # First half of the stacked batch is conditional, second half negative.
        batch_size = condition.shape[0] // 2
        speech = torch.randn(batch_size, self.config.acoustic_vae_dim).to(condition)

        # Per-timestep tensors depend only on these values, so build them once
        # per combination and reuse them across calls.
        t_batch_cache_key = (
            self.ddpm_inference_steps,
            condition.device.type,
            condition.device.index,
            condition.dtype,
            batch_size,
        )
        t_batch_cache = getattr(self, "_vibepod_t_batch_cache", {})
        t_batches = t_batch_cache.get(t_batch_cache_key)
        # Rebuild when missing or when the cached list no longer matches the
        # scheduler's timestep count.
        if t_batches is None or len(t_batches) != len(scheduler.timesteps):
            t_batches = [
                t.repeat(condition.shape[0]).to(device=condition.device, dtype=condition.dtype)
                for t in scheduler.timesteps
            ]
            t_batch_cache[t_batch_cache_key] = t_batches
            self._vibepod_t_batch_cache = t_batch_cache

        for t, t_batch in zip(scheduler.timesteps, t_batches):
            # Duplicate the current sample for the cond/uncond halves; with a
            # single sample, expand() is used instead of cat().
            if batch_size == 1:
                combined = speech.expand(condition.shape[0], -1)
            else:
                combined = torch.cat([speech, speech], dim=0)
            if profile_enabled():
                profile_sync()
                started = time.perf_counter()
            eps = self.model.prediction_head(combined, t_batch, condition=condition)
            if profile_enabled():
                profile_sync()
                profile_record(self, "diffusion_prediction_head", time.perf_counter() - started)
            # Classifier-free guidance: move the unconditional prediction
            # toward the conditional one by cfg_scale.
            cond_eps, uncond_eps = torch.split(eps, batch_size, dim=0)
            guided_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
            if profile_enabled():
                started = time.perf_counter()
            speech = scheduler.step(guided_eps, t, speech).prev_sample
            if profile_enabled():
                profile_record(self, "diffusion_scheduler_step", time.perf_counter() - started)
        return speech

    # Capture the original callables before they are replaced so the wrappers
    # below call the real implementations.
    forward_lm = model.forward_lm
    forward_tts_lm = model.forward_tts_lm
    acoustic_decode = model.model.acoustic_tokenizer.decode

    def forward_lm_profiled(*args, **kwargs):
        return timed_method(model, "forward_lm", forward_lm, *args, **kwargs)

    def forward_tts_lm_profiled(*args, **kwargs):
        return timed_method(model, "forward_tts_lm", forward_tts_lm, *args, **kwargs)

    def acoustic_decode_profiled(*args, **kwargs):
        return timed_method(model, "acoustic_decode", acoustic_decode, *args, **kwargs)

    model.forward_lm = forward_lm_profiled
    model.forward_tts_lm = forward_tts_lm_profiled
    model.model.acoustic_tokenizer.decode = acoustic_decode_profiled
    # Bind as a method so `self` inside the replacement is the model.
    model.sample_speech_tokens = types.MethodType(sample_speech_tokens_optimized, model)
    logger.info("Installed VibeVoice generation hot-path optimizations.")
def _install_cpu_pipeline_optimizations(model: object) -> None:
    """Give the model a single-worker decode executor for overlapped decoding.

    The JezzWTF/VibeVoice fork's generate() looks for two optional attributes
    on the model:

    model._vibepod_decode_executor — set here. A one-worker
        ThreadPoolExecutor that lets acoustic_decode run alongside
        acoustic_connector + forward_tts_lm. The author's profiling notes
        report roughly 72s of decode cost hidden behind tts_lm work, about
        96% of the theoretical overlap savings.

    model._vibepod_cfg_executor — deliberately left unset. Running the
        pos/neg forward_tts_lm calls on a second thread makes both compete
        for the same MKL OpenMP worker pool on CPU, which was measured as a
        ~6% regression. The fork keeps the hook for GPU or future use.

    Both attributes default to None in the fork, so generate() keeps its
    original sequential behaviour on CUDA or any non-VibePod install.

    The executor is also stored in the module-level ``_decode_executor``.
    Nothing is attached unless VIBEPOD_ASYNC_DECODE is "1" (the default).
    """
    global _decode_executor

    async_decode_on = os.environ.get("VIBEPOD_ASYNC_DECODE", "1") == "1"
    if async_decode_on:
        executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=1, thread_name_prefix="vibepod-decode"
        )
        _decode_executor = executor
        model._vibepod_decode_executor = executor
        logger.info(
            "CPU pipeline: decode executor attached — acoustic_decode overlaps "
            "tts_lm. Disable with VIBEPOD_ASYNC_DECODE=0."
        )
    else:
        logger.info("CPU async decode disabled via VIBEPOD_ASYNC_DECODE=0.")
def _model_float_dtype() -> torch.dtype:
try:
return next(_model.parameters()).dtype
except StopIteration:
return torch.float32
def _move_cached_prompt(value: object, device: str, dtype: torch.dtype) -> object:
if torch.is_tensor(value):
if torch.is_floating_point(value):
return value.to(device=device, dtype=dtype)
return value.to(device=device)
if isinstance(value, dict):
for k in list(value.keys()):
value[k] = _move_cached_prompt(value[k], device, dtype)
return value
if isinstance(value, list):
return [_move_cached_prompt(v, device, dtype) for v in value]
if isinstance(value, tuple):
return tuple(_move_cached_prompt(v, device, dtype) for v in value)
if hasattr(value, "key_cache") and hasattr(value, "value_cache"):
value.key_cache = [_move_cached_prompt(t, device, dtype) for t in value.key_cache]
value.value_cache = [_move_cached_prompt(t, device, dtype) for t in value.value_cache]
return value
def _load_voice_presets(device: str) -> dict[str, object]:
presets = {}
for name, filename in EN_VOICES.items():
@@ -273,19 +624,33 @@ def _load_model_sync() -> None:
is_cpu = _device == "cpu"
_config["device"] = _device
_config["chunk_accum"] = _env_int("VIBEPOD_CHUNK_ACCUM", 4 if is_cpu else 1)
_config["prebuffer_secs"] = _env_float("VIBEPOD_PREBUFFER_SECS", 5.0 if is_cpu else 2.0)
_config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 1.0 if is_cpu else 0.4)
_config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 2.5 if is_cpu else 1.5)
_config["default_inference_steps"] = _env_int("VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10)
_config["prebuffer_secs"] = _env_float(
"VIBEPOD_PREBUFFER_SECS", 24.0 if is_cpu else 5.0
)
_config["rebuffer_threshold_secs"] = _env_float(
"VIBEPOD_REBUFFER_THRESHOLD_SECS", 2.0 if is_cpu else 1.0
)
_config["resume_threshold_secs"] = _env_float(
"VIBEPOD_RESUME_THRESHOLD_SECS", 12.0 if is_cpu else 3.0
)
_config["default_inference_steps"] = _env_int(
"VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10
)
if is_cpu:
logical_cpus = os.cpu_count() or 1
_config["cpu_threads"] = _env_int(
"VIBEPOD_CPU_THREADS",
max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus,
)
_config["cpu_interop_threads"] = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1)
_config["cpu_mkldnn"] = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"
_processor = _init_processor()
_model = _init_model(_device)
_voice_presets = _load_voice_presets(_device)
_model_status = "online"
logger.info(
"Model ready on %s. Voices: %s", _device, list(_voice_presets.keys())
)
logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys()))
logger.info("Configuration: %s", _config)
except Exception as exc:
@@ -302,6 +667,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
thread = threading.Thread(target=_load_model_sync, daemon=True, name="model-loader")
thread.start()
yield
if _decode_executor is not None:
_decode_executor.shutdown(wait=False)
app = FastAPI(title="VibePod TTS Server", version="0.1.0", lifespan=lifespan)
@@ -314,7 +681,7 @@ class GenerateRequest(BaseModel):
text: str = Field(..., min_length=1, max_length=10_000)
speaker: str = Field(default=DEFAULT_SPEAKER)
cfg_scale: float = Field(default=1.5, ge=0.5, le=4.0)
inference_steps: Optional[int] = Field(default=None, ge=5, le=20)
inference_steps: int | None = Field(default=None, ge=5, le=20)
@field_validator("text")
@classmethod
@@ -353,8 +720,8 @@ async def health() -> dict:
def _sync_generate(
req: GenerateRequest,
streamer: Optional[object] = None,
cancel_event: Optional[threading.Event] = None,
streamer: object | None = None,
cancel_event: threading.Event | None = None,
) -> str:
"""Blocking inference. Returns the speaker used.
Runs in a thread-pool executor — do not call from the event loop directly.
@@ -364,10 +731,17 @@ def _sync_generate(
raise RuntimeError("Generation cancelled.")
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
voice_preset = copy.deepcopy(_voice_presets[speaker])
model_dtype = _model_float_dtype()
voice_preset = _move_cached_prompt(copy.deepcopy(_voice_presets[speaker]), _device, model_dtype)
steps = req.inference_steps if req.inference_steps is not None else _config["default_inference_steps"]
steps = (
req.inference_steps
if req.inference_steps is not None
else _config["default_inference_steps"]
)
_model.set_ddpm_inference_steps(num_steps=steps)
if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1":
_model._vibepod_profile = {}
inputs = _processor.process_input_with_cached_prompt(
text=req.text,
@@ -380,20 +754,21 @@ def _sync_generate(
if torch.is_tensor(v):
inputs[k] = v.to(_device)
outputs = _model.generate(
with torch.inference_mode():
_model.generate(
**inputs,
max_new_tokens=None,
cfg_scale=req.cfg_scale,
tokenizer=_processor.tokenizer,
generation_config={"do_sample": False},
verbose=True,
all_prefilled_outputs=copy.deepcopy(voice_preset),
verbose=False,
show_progress_bar=False,
return_speech=False,
stop_check_fn=cancel_event.is_set if cancel_event else None,
all_prefilled_outputs=voice_preset,
audio_streamer=streamer,
)
if not outputs.speech_outputs or outputs.speech_outputs[0] is None:
raise ValueError("Model returned no audio output.")
return speaker
@@ -401,6 +776,22 @@ def _sse(event: dict) -> str:
return f"data: {json.dumps(event)}\n\n"
def _generation_profile() -> dict[str, dict[str, float]] | None:
if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") != "1":
return None
stats = getattr(_model, "_vibepod_profile", None)
if not stats:
return {}
return {
key: {
"count": value["count"],
"seconds": round(value["seconds"], 3),
"avg_ms": round(value["seconds"] * 1000 / value["count"], 3) if value["count"] else 0.0,
}
for key, value in sorted(stats.items())
}
@app.post("/generate")
async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
if _model_status != "online":
@@ -417,28 +808,62 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
)
async def event_stream() -> AsyncGenerator[str, None]:
from vibevoice.modular.streamer import AsyncAudioStreamer
class NonBlockingAudioStreamer:
"""Async streamer that keeps GPU->CPU copies out of the model thread."""
def __init__(self, batch_size: int, stop_signal: object = None) -> None:
self.batch_size = batch_size
self.stop_signal = stop_signal
self.audio_queues = [asyncio.Queue() for _ in range(batch_size)]
self.finished_flags = [False for _ in range(batch_size)]
self.loop = asyncio.get_running_loop()
def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor) -> None:
for i, sample_idx in enumerate(sample_indices):
idx = sample_idx.item()
if idx < self.batch_size and not self.finished_flags[idx]:
self.loop.call_soon_threadsafe(
self.audio_queues[idx].put_nowait,
audio_chunks[i].detach(),
)
def end(self, sample_indices: torch.Tensor | None = None) -> None:
if sample_indices is None:
indices_to_end = range(self.batch_size)
else:
indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices]
for idx in indices_to_end:
if idx < self.batch_size and not self.finished_flags[idx]:
self.loop.call_soon_threadsafe(
self.audio_queues[idx].put_nowait, self.stop_signal
)
self.finished_flags[idx] = True
start = time.monotonic()
streamer = AsyncAudioStreamer(batch_size=1)
streamer = NonBlockingAudioStreamer(batch_size=1)
cancel_event = threading.Event()
accum_size = max(1, _config["chunk_accum"])
accumulated_chunks = []
chunk_count = 0
audio_samples = 0
first_chunk_at: float | None = None
last_chunk_at: float | None = None
max_chunk_gap = 0.0
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
async with _generation_lock:
loop = asyncio.get_event_loop()
future = loop.run_in_executor(
None, functools.partial(_sync_generate, req, streamer, cancel_event)
)
future.add_done_callback(lambda _: streamer.end())
# Drain audio chunks as they arrive from the diffusion head.
# stop_signal=None is the default sentinel that ends the queue.
while True:
try:
chunk = await asyncio.wait_for(
streamer.audio_queues[0].get(), timeout=120.0
)
chunk = await asyncio.wait_for(streamer.audio_queues[0].get(), timeout=120.0)
except asyncio.TimeoutError:
cancel_event.set()
future.cancel()
@@ -454,17 +879,45 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
if chunk is None: # stop signal
break
accumulated_chunks.append(chunk.detach().cpu().float())
accumulated_chunks.append(chunk.detach())
if len(accumulated_chunks) >= accum_size:
combined = torch.cat(accumulated_chunks, dim=0)
now = time.monotonic()
if first_chunk_at is None:
first_chunk_at = now
if last_chunk_at is not None:
max_chunk_gap = max(max_chunk_gap, now - last_chunk_at)
last_chunk_at = now
combined = (
torch.cat(accumulated_chunks, dim=0)
.detach()
.to("cpu", dtype=torch.float32)
.contiguous()
)
chunk_count += 1
audio_samples += combined.numel()
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
yield _sse({"type": "audio_chunk", "data": pcm_b64})
accumulated_chunks = []
# Flush any remaining chunks
if accumulated_chunks:
combined = torch.cat(accumulated_chunks, dim=0)
now = time.monotonic()
if first_chunk_at is None:
first_chunk_at = now
if last_chunk_at is not None:
max_chunk_gap = max(max_chunk_gap, now - last_chunk_at)
last_chunk_at = now
combined = (
torch.cat(accumulated_chunks, dim=0)
.detach()
.to("cpu", dtype=torch.float32)
.contiguous()
)
chunk_count += 1
audio_samples += combined.numel()
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
yield _sse({"type": "audio_chunk", "data": pcm_b64})
@@ -479,17 +932,40 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
yield _sse(
{
"type": "error",
"message": "Internal server error during generation.",
"message": f"Generation failed: {exc}",
}
)
return
elapsed = round(time.monotonic() - start, 1)
audio_secs = audio_samples / SAMPLE_RATE
realtime_factor = audio_secs / elapsed if elapsed > 0 else None
profile = _generation_profile()
if profile is not None:
logger.info("Generation profile: %s", profile)
logger.info("Generation complete in %.1fs", elapsed)
yield _sse({"type": "complete", "elapsed": elapsed, "speaker": speaker})
complete_event = {
"type": "complete",
"elapsed": elapsed,
"speaker": speaker,
"audio_secs": round(audio_secs, 2),
"realtime_factor": round(realtime_factor, 3) if realtime_factor is not None else None,
"chunks": chunk_count,
"first_chunk_secs": round(first_chunk_at - start, 2)
if first_chunk_at is not None
else None,
"max_chunk_gap_secs": round(max_chunk_gap, 2),
}
if profile is not None:
complete_event["profile"] = profile
yield _sse(complete_event)
return StreamingResponse(
event_stream(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
headers={
"Cache-Control": "no-cache, no-transform",
"X-Accel-Buffering": "no",
"X-Content-Type-Options": "nosniff",
},
)
+8 -3
View File
@@ -1,10 +1,13 @@
import { NextRequest, NextResponse } from "next/server";
export const dynamic = "force-dynamic";
export const runtime = "nodejs";
export async function POST(request: NextRequest) {
const pythonServerUrl = process.env.VIBEVOICE_SERVER_URL ?? "http://localhost:8000";
try {
const body = await request.json() as {
const body = (await request.json()) as {
text: string;
speaker?: string;
cfg_scale?: number;
@@ -24,6 +27,7 @@ export async function POST(request: NextRequest) {
cfg_scale: body.cfg_scale ?? 1.5,
inference_steps: body.inference_steps ?? 10,
}),
signal: request.signal,
});
if (!upstream.ok) {
@@ -36,8 +40,9 @@ export async function POST(request: NextRequest) {
status: 200,
headers: {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Cache-Control": "no-cache, no-transform",
Connection: "keep-alive",
"X-Content-Type-Options": "nosniff",
"X-Accel-Buffering": "no",
},
});
+2 -2
View File
@@ -4,8 +4,7 @@ const OFFLINE_RESPONSE = { status: "offline" };
const COMMON_OPTIONS = { headers: { "Cache-Control": "no-store" } };
export async function GET() {
const pythonServerUrl =
process.env.VIBEVOICE_SERVER_URL ?? "http://localhost:8000";
const pythonServerUrl = process.env.VIBEVOICE_SERVER_URL ?? "http://localhost:8000";
try {
const res = await fetch(`${pythonServerUrl}/health`, {
@@ -27,6 +26,7 @@ export async function GET() {
message: data.message,
progress: data.progress ?? null,
voices: data.voices ?? [],
config: data.config ?? null,
},
COMMON_OPTIONS
);
+4 -2
View File
@@ -12,8 +12,10 @@
--muted: #64748b;
--success: #22c55e;
--error: #ef4444;
--font-sans: ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
--font-mono: ui-monospace, SFMono-Regular, "SF Mono", Menlo, Consolas, "Liberation Mono", monospace;
--font-sans:
ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
--font-mono:
ui-monospace, SFMono-Regular, "SF Mono", Menlo, Consolas, "Liberation Mono", monospace;
}
@theme inline {
+57 -25
View File
@@ -69,19 +69,39 @@ type AppAction =
function reducer(state: AppState, action: AppAction): AppState {
switch (action.type) {
case "SET_SCRIPT": return { ...state, script: action.payload };
case "SET_SPEAKER": return { ...state, speaker: action.payload };
case "SET_CFG_SCALE": return { ...state, cfgScale: action.payload };
case "SET_INFERENCE_STEPS": return { ...state, inferenceSteps: action.payload };
case "SET_PREBUFFER_SECS": return { ...state, prebufferSecs: action.payload };
case "SET_REBUFFER_THRESHOLD": return { ...state, rebufferThresholdSecs: action.payload };
case "SET_RESUME_THRESHOLD": return { ...state, resumeThresholdSecs: action.payload };
case "SET_SCRIPT":
return { ...state, script: action.payload };
case "SET_SPEAKER":
return { ...state, speaker: action.payload };
case "SET_CFG_SCALE":
return { ...state, cfgScale: action.payload };
case "SET_INFERENCE_STEPS":
return { ...state, inferenceSteps: action.payload };
case "SET_PREBUFFER_SECS":
return { ...state, prebufferSecs: action.payload };
case "SET_REBUFFER_THRESHOLD":
return { ...state, rebufferThresholdSecs: action.payload };
case "SET_RESUME_THRESHOLD":
return { ...state, resumeThresholdSecs: action.payload };
case "START_GENERATION":
return { ...state, isGenerating: true, audioUrl: null, logs: [], genElapsed: 0, genPct: null };
return {
...state,
isGenerating: true,
audioUrl: null,
logs: [],
genElapsed: 0,
genPct: null,
};
case "GEN_PROGRESS":
return { ...state, genElapsed: action.elapsed, genPct: action.pct };
case "GENERATION_SUCCESS":
return { ...state, isGenerating: false, genElapsed: 0, genPct: null, audioUrl: action.payload };
return {
...state,
isGenerating: false,
genElapsed: 0,
genPct: null,
audioUrl: action.payload,
};
case "GENERATION_CANCELLED":
case "GENERATION_ERROR":
return { ...state, isGenerating: false, genElapsed: 0, genPct: null };
@@ -89,21 +109,27 @@ function reducer(state: AppState, action: AppAction): AppState {
return { ...state, logs: [...state.logs, action.payload] };
case "SET_SERVER_STATUS": {
const isNewConfig = !state.serverConfig && action.payload.config;
const deviceChanged = !!(state.serverConfig && action.payload.config && state.serverConfig.device !== action.payload.config.device);
const deviceChanged = !!(
state.serverConfig &&
action.payload.config &&
state.serverConfig.device !== action.payload.config.device
);
const nextSteps = (isNewConfig || deviceChanged)
const nextSteps =
isNewConfig || deviceChanged
? action.payload.config!.default_inference_steps
: state.inferenceSteps;
const nextPrebuffer = (isNewConfig || deviceChanged)
? action.payload.config!.prebuffer_secs
: state.prebufferSecs;
const nextPrebuffer =
isNewConfig || deviceChanged ? action.payload.config!.prebuffer_secs : state.prebufferSecs;
const nextRebuffer = (isNewConfig || deviceChanged)
const nextRebuffer =
isNewConfig || deviceChanged
? action.payload.config!.rebuffer_threshold_secs
: state.rebufferThresholdSecs;
const nextResume = (isNewConfig || deviceChanged)
const nextResume =
isNewConfig || deviceChanged
? action.payload.config!.resume_threshold_secs
: state.resumeThresholdSecs;
@@ -121,7 +147,8 @@ function reducer(state: AppState, action: AppAction): AppState {
resumeThresholdSecs: nextResume,
};
}
default: return state;
default:
return state;
}
}
@@ -130,9 +157,9 @@ const initialState: AppState = {
speaker: "carter",
cfgScale: 1.5,
inferenceSteps: 10,
prebufferSecs: 2.0,
rebufferThresholdSecs: 0.4,
resumeThresholdSecs: 1.5,
prebufferSecs: 5.0,
rebufferThresholdSecs: 1.0,
resumeThresholdSecs: 3.0,
isGenerating: false,
genElapsed: 0,
genPct: null,
@@ -213,7 +240,10 @@ export default function HomePage() {
}
poll();
return () => { cancelled = true; clearTimeout(timeoutId); };
return () => {
cancelled = true;
clearTimeout(timeoutId);
};
}, []);
const handleGenerate = useCallback(async () => {
@@ -241,7 +271,6 @@ export default function HomePage() {
<Header />
<main className="flex-1 container mx-auto px-4 py-6 max-w-6xl">
<div className="grid grid-cols-1 lg:grid-cols-3 gap-6">
{/* Left: script + audio player */}
<div className="lg:col-span-2 flex flex-col gap-6">
<TextInputPanel
@@ -264,9 +293,13 @@ export default function HomePage() {
prebufferSecs={state.prebufferSecs}
onPrebufferSecsChange={(v) => dispatch({ type: "SET_PREBUFFER_SECS", payload: v })}
rebufferThresholdSecs={state.rebufferThresholdSecs}
onRebufferThresholdChange={(v) => dispatch({ type: "SET_REBUFFER_THRESHOLD", payload: v })}
onRebufferThresholdChange={(v) =>
dispatch({ type: "SET_REBUFFER_THRESHOLD", payload: v })
}
resumeThresholdSecs={state.resumeThresholdSecs}
onResumeThresholdChange={(v) => dispatch({ type: "SET_RESUME_THRESHOLD", payload: v })}
onResumeThresholdChange={(v) =>
dispatch({ type: "SET_RESUME_THRESHOLD", payload: v })
}
onGenerate={handleGenerate}
onStop={stop}
onPauseStream={pauseStream}
@@ -281,7 +314,6 @@ export default function HomePage() {
/>
<StatusLog messages={state.logs} />
</div>
</div>
</main>
</div>
+8 -28
View File
@@ -14,15 +14,8 @@ function formatTime(seconds: number): string {
}
export default function AudioPlayer({ audioUrl }: AudioPlayerProps) {
const {
isPlaying,
currentTime,
duration,
volume,
toggle,
seek,
setVolume,
} = useAudioPlayer(audioUrl);
const { isPlaying, currentTime, duration, volume, toggle, seek, setVolume } =
useAudioPlayer(audioUrl);
if (!audioUrl) return null;
@@ -56,12 +49,10 @@ export default function AudioPlayer({ audioUrl }: AudioPlayerProps) {
background: "rgba(45, 212, 191, 0.05)",
}}
onMouseEnter={(e) => {
(e.currentTarget as HTMLButtonElement).style.background =
"rgba(45, 212, 191, 0.15)";
(e.currentTarget as HTMLButtonElement).style.background = "rgba(45, 212, 191, 0.15)";
}}
onMouseLeave={(e) => {
(e.currentTarget as HTMLButtonElement).style.background =
"rgba(45, 212, 191, 0.05)";
(e.currentTarget as HTMLButtonElement).style.background = "rgba(45, 212, 191, 0.05)";
}}
>
<svg
@@ -115,27 +106,18 @@ export default function AudioPlayer({ audioUrl }: AudioPlayerProps) {
onClick={toggle}
className="w-10 h-10 rounded-full flex items-center justify-center transition-transform active:scale-95 cursor-pointer"
style={{
background:
"linear-gradient(135deg, var(--accent-teal-dim), var(--accent-violet-dim))",
background: "linear-gradient(135deg, var(--accent-teal-dim), var(--accent-violet-dim))",
boxShadow: "0 4px 12px rgba(45, 212, 191, 0.3)",
}}
aria-label={isPlaying ? "Pause" : "Play"}
>
{isPlaying ? (
<svg
className="w-4 h-4 text-white"
viewBox="0 0 24 24"
fill="currentColor"
>
<svg className="w-4 h-4 text-white" viewBox="0 0 24 24" fill="currentColor">
<rect x="6" y="4" width="4" height="16" />
<rect x="14" y="4" width="4" height="16" />
</svg>
) : (
<svg
className="w-4 h-4 text-white"
viewBox="0 0 24 24"
fill="currentColor"
>
<svg className="w-4 h-4 text-white" viewBox="0 0 24 24" fill="currentColor">
<polygon points="5 3 19 12 5 21 5 3" />
</svg>
)}
@@ -143,9 +125,7 @@ export default function AudioPlayer({ audioUrl }: AudioPlayerProps) {
{/* Duration info */}
<div className="flex-1 flex items-center gap-1 text-sm">
<span style={{ color: "var(--foreground)" }}>
{formatTime(currentTime)}
</span>
<span style={{ color: "var(--foreground)" }}>{formatTime(currentTime)}</span>
<span style={{ color: "var(--muted)" }}>/</span>
<span style={{ color: "var(--muted)" }}>{formatTime(duration)}</span>
</div>
+67 -18
View File
@@ -37,17 +37,26 @@ const STATUS_CONFIG: Record<
{ color: string; label: (p: DownloadProgress | null) => string }
> = {
offline: { color: "var(--error)", label: () => "Server offline — waiting for connection..." },
downloading: { color: "#60a5fa", label: (p) => p && p.total > 0 ? `Downloading model... (${p.done} / ${p.total} files)` : "Downloading model (~1 GB)..." },
downloading: {
color: "#60a5fa",
label: (p) =>
p && p.total > 0
? `Downloading model... (${p.done} / ${p.total} files)`
: "Downloading model (~1 GB)...",
},
loading: { color: "#fbbf24", label: () => "Loading model into memory..." },
error: { color: "var(--error)", label: () => "Server error — check the terminal for details." },
};
/**
 * Inline SVG loading spinner.
 *
 * Draws a faint full circle (`opacity-25`) with a more opaque arc
 * (`opacity-75`) on a 24x24 viewBox; the `animate-spin` utility class on the
 * root `<svg>` is what spins it. Both shapes are painted with `currentColor`,
 * so the icon takes the text color of whatever element it is rendered in.
 * Takes no props.
 */
function SpinnerIcon() {
return (
<svg className="animate-spin w-4 h-4" viewBox="0 0 24 24" fill="none">
<circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4" />
<path className="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4z" />
<path
className="opacity-75"
fill="currentColor"
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4z"
/>
</svg>
);
}
@@ -146,7 +155,10 @@ export default function GenerationControls({
onChange={(e) => onCfgScaleChange(parseFloat(e.target.value))}
className="w-full"
/>
<div className="flex items-center justify-between text-xs" style={{ color: "var(--muted)" }}>
<div
className="flex items-center justify-between text-xs"
style={{ color: "var(--muted)" }}
>
<span>Flat (0.5)</span>
<span>CFG Scale</span>
<span>Expressive (4.0)</span>
@@ -157,7 +169,7 @@ export default function GenerationControls({
<div className="flex flex-col gap-2">
<div className="flex items-center justify-between">
<label className="text-sm font-medium" style={{ color: "var(--foreground)" }}>
Quality vs Speed
Speed vs Quality
</label>
<span
className="text-sm font-mono px-2 py-0.5 rounded"
@@ -176,7 +188,10 @@ export default function GenerationControls({
className="w-full"
style={{ "--thumb-color": "var(--accent-violet)" } as React.CSSProperties}
/>
<div className="flex items-center justify-between text-xs" style={{ color: "var(--muted)" }}>
<div
className="flex items-center justify-between text-xs"
style={{ color: "var(--muted)" }}
>
<span>Faster (5)</span>
<span>Diffusion Steps</span>
<span>Better (20)</span>
@@ -207,7 +222,11 @@ export default function GenerationControls({
</div>
{showAdvanced && (
<div id="advanced-buffering-panel" className="flex flex-col gap-4 pl-2 border-l" style={{ borderColor: "var(--border)" }}>
<div
id="advanced-buffering-panel"
className="flex flex-col gap-4 pl-2 border-l"
style={{ borderColor: "var(--border)" }}
>
{/* Pre-buffer */}
<div className="flex flex-col gap-2">
<div className="flex items-center justify-between">
@@ -221,7 +240,7 @@ export default function GenerationControls({
<input
type="range"
min={0.5}
max={10.0}
max={30.0}
step={0.5}
value={prebufferSecs}
onChange={(e) => onPrebufferSecsChange(parseFloat(e.target.value))}
@@ -232,7 +251,11 @@ export default function GenerationControls({
{/* Re-buffer threshold */}
<div className="flex flex-col gap-2">
<div className="flex items-center justify-between">
<label htmlFor="rebuffer-threshold" className="text-xs font-medium" style={{ color: "var(--foreground)" }}>
<label
htmlFor="rebuffer-threshold"
className="text-xs font-medium"
style={{ color: "var(--foreground)" }}
>
Re-buffer Threshold
</label>
<span className="text-xs font-mono" style={{ color: "var(--accent-teal)" }}>
@@ -260,7 +283,11 @@ export default function GenerationControls({
{/* Resume threshold */}
<div className="flex flex-col gap-2">
<div className="flex items-center justify-between">
<label htmlFor="resume-threshold" className="text-xs font-medium" style={{ color: "var(--foreground)" }}>
<label
htmlFor="resume-threshold"
className="text-xs font-medium"
style={{ color: "var(--foreground)" }}
>
Resume Threshold
</label>
<span className="text-xs font-mono" style={{ color: "var(--accent-teal)" }}>
@@ -271,7 +298,7 @@ export default function GenerationControls({
id="resume-threshold"
type="range"
min={0.5}
max={5.0}
max={30.0}
step={0.1}
value={resumeThresholdSecs}
onChange={(e) => {
@@ -302,7 +329,10 @@ export default function GenerationControls({
</div>
{serverStatus === "downloading" && (
<div className="w-full rounded-full h-1.5 overflow-hidden" style={{ background: "var(--border)" }}>
<div
className="w-full rounded-full h-1.5 overflow-hidden"
style={{ background: "var(--border)" }}
>
<div
className="h-1.5 rounded-full transition-all duration-500"
style={{
@@ -315,10 +345,16 @@ export default function GenerationControls({
)}
{serverStatus === "loading" && (
<div className="w-full rounded-full h-1.5 overflow-hidden" style={{ background: "var(--border)" }}>
<div
className="w-full rounded-full h-1.5 overflow-hidden"
style={{ background: "var(--border)" }}
>
<div
className="h-1.5 rounded-full animate-pulse"
style={{ width: "60%", background: "linear-gradient(90deg, #fbbf24, var(--accent-teal))" }}
style={{
width: "60%",
background: "linear-gradient(90deg, #fbbf24, var(--accent-teal))",
}}
/>
</div>
)}
@@ -328,11 +364,17 @@ export default function GenerationControls({
{/* Generation progress bar */}
{isGenerating && (
<div className="flex flex-col gap-1.5">
<div className="flex items-center justify-between text-xs" style={{ color: "var(--muted)" }}>
<div
className="flex items-center justify-between text-xs"
style={{ color: "var(--muted)" }}
>
<span>{genElapsed}s elapsed</span>
<span>{genPct !== null ? `${genPct}%` : "starting..."}</span>
</div>
<div className="w-full rounded-full h-1.5 overflow-hidden" style={{ background: "var(--border)" }}>
<div
className="w-full rounded-full h-1.5 overflow-hidden"
style={{ background: "var(--border)" }}
>
<div
className="h-1.5 rounded-full transition-all duration-500"
style={{
@@ -355,7 +397,8 @@ export default function GenerationControls({
buttonDisabled
? { background: "var(--border)", color: "var(--muted)" }
: {
background: "linear-gradient(135deg, var(--accent-teal-dim), var(--accent-violet-dim))",
background:
"linear-gradient(135deg, var(--accent-teal-dim), var(--accent-violet-dim))",
color: "#fff",
boxShadow: "0 4px 15px rgba(45, 212, 191, 0.2)",
}
@@ -373,7 +416,13 @@ export default function GenerationControls({
</>
) : (
<>
<svg className="w-4 h-4" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="2">
<svg
className="w-4 h-4"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
strokeWidth="2"
>
<polygon points="5 3 19 12 5 21 5 3" />
</svg>
Generate Audio
+10 -13
View File
@@ -31,7 +31,10 @@ export default function Header() {
intervalRef.current = setInterval(checkHealth, SLOW_INTERVAL_MS);
}
// Switch to fast polling if we detect the server went offline/loading
if ((newStatus === "offline" || newStatus === "downloading" || newStatus === "loading") && intervalRef.current) {
if (
(newStatus === "offline" || newStatus === "downloading" || newStatus === "loading") &&
intervalRef.current
) {
clearInterval(intervalRef.current);
intervalRef.current = setInterval(checkHealth, FAST_INTERVAL_MS);
}
@@ -95,16 +98,13 @@ export default function Header() {
const cfg = statusConfig[status];
// Device badge — only shown once the server is online and device is known
const deviceBadge = status === "online" && device ? (
const deviceBadge =
status === "online" && device ? (
<span
className="px-2 py-0.5 rounded-full text-xs font-semibold tracking-wide uppercase"
style={{
background: device === "cuda"
? "var(--accent-violet-dim)"
: "var(--accent-teal-dim)",
color: device === "cuda"
? "var(--accent-violet)"
: "var(--accent-teal)",
background: device === "cuda" ? "var(--accent-violet-dim)" : "var(--accent-teal-dim)",
color: device === "cuda" ? "var(--accent-violet)" : "var(--accent-teal)",
border: `1px solid ${device === "cuda" ? "var(--accent-violet-dim)" : "var(--accent-teal-dim)"}`,
}}
title={device === "cuda" ? "Running on NVIDIA GPU" : "Running on CPU"}
@@ -136,8 +136,7 @@ export default function Header() {
<h1
className="text-xl font-bold tracking-tight"
style={{
background:
"linear-gradient(135deg, var(--accent-teal), var(--accent-violet))",
background: "linear-gradient(135deg, var(--accent-teal), var(--accent-violet))",
WebkitBackgroundClip: "text",
WebkitTextFillColor: "transparent",
}}
@@ -167,9 +166,7 @@ export default function Header() {
className={`animate-ping absolute inline-flex h-full w-full rounded-full opacity-75 ${cfg.color}`}
/>
)}
<span
className={`relative inline-flex rounded-full h-2 w-2 ${cfg.color}`}
/>
<span className={`relative inline-flex rounded-full h-2 w-2 ${cfg.color}`} />
</span>
<span style={{ color: "var(--foreground)" }}>{cfg.label}</span>
</div>
+1 -2
View File
@@ -47,8 +47,7 @@ export default function StatusLog({ messages }: StatusLogProps) {
) : (
messages.map((msg, i) => {
const isError =
msg.toLowerCase().includes("error") ||
msg.toLowerCase().includes("failed");
msg.toLowerCase().includes("error") || msg.toLowerCase().includes("failed");
const isSuccess =
msg.toLowerCase().includes("done") ||
msg.toLowerCase().includes("complete") ||
+6 -16
View File
@@ -15,10 +15,7 @@ interface TextInputPanelProps {
onChange: (text: string) => void;
}
export default function TextInputPanel({
value,
onChange,
}: TextInputPanelProps) {
export default function TextInputPanel({ value, onChange }: TextInputPanelProps) {
const charCount = value.length;
const wordCount = value.trim() === "" ? 0 : value.trim().split(/\s+/).length;
@@ -43,15 +40,12 @@ export default function TextInputPanel({
color: "var(--muted)",
}}
onMouseEnter={(e) => {
(e.target as HTMLButtonElement).style.color =
"var(--accent-violet)";
(e.target as HTMLButtonElement).style.borderColor =
"var(--accent-violet)";
(e.target as HTMLButtonElement).style.color = "var(--accent-violet)";
(e.target as HTMLButtonElement).style.borderColor = "var(--accent-violet)";
}}
onMouseLeave={(e) => {
(e.target as HTMLButtonElement).style.color = "var(--muted)";
(e.target as HTMLButtonElement).style.borderColor =
"var(--border)";
(e.target as HTMLButtonElement).style.borderColor = "var(--border)";
}}
>
Load sample script
@@ -69,8 +63,7 @@ export default function TextInputPanel({
}}
onMouseLeave={(e) => {
(e.target as HTMLButtonElement).style.color = "var(--muted)";
(e.target as HTMLButtonElement).style.borderColor =
"var(--border)";
(e.target as HTMLButtonElement).style.borderColor = "var(--border)";
}}
>
Clear
@@ -98,10 +91,7 @@ export default function TextInputPanel({
}}
/>
<div
className="flex items-center justify-between text-xs"
style={{ color: "var(--muted)" }}
>
<div className="flex items-center justify-between text-xs" style={{ color: "var(--muted)" }}>
<span>
{wordCount} word{wordCount !== 1 ? "s" : ""}
</span>
+6 -10
View File
@@ -55,16 +55,12 @@ export function useAudioPlayer(audioUrl: string | null) {
() => setState((prev) => ({ ...prev, isPlaying: false, currentTime: 0 })),
{ signal }
);
audio.addEventListener(
"play",
() => setState((prev) => ({ ...prev, isPlaying: true })),
{ signal }
);
audio.addEventListener(
"pause",
() => setState((prev) => ({ ...prev, isPlaying: false })),
{ signal }
);
audio.addEventListener("play", () => setState((prev) => ({ ...prev, isPlaying: true })), {
signal,
});
audio.addEventListener("pause", () => setState((prev) => ({ ...prev, isPlaying: false })), {
signal,
});
return () => {
audio.pause();
+68 -20
View File
@@ -3,9 +3,10 @@
import { useCallback, useEffect, useRef, useState } from "react";
// Samples per second of the streamed audio. Sample counts are divided by this
// value to get seconds of audio, and it is passed to buildWav for the final file.
const SAMPLE_RATE = 24_000;
// NOTE(review): each DEFAULT_* name below is declared twice in this listing
// (2.0 / 0.4 / 1.5, then 5.0 / 1.0 / 3.0). This looks like the removed and added
// sides of the change shown together — confirm which set is live.
const DEFAULT_PREBUFFER_SECS = 2.0;
const DEFAULT_REBUFFER_THRESHOLD_SECS = 0.4;
const DEFAULT_RESUME_THRESHOLD_SECS = 1.5;
// Seconds of audio to accumulate before playback first starts.
// Presumably the fallback for the hook's prebufferSecs option — confirm at the hook signature.
const DEFAULT_PREBUFFER_SECS = 5.0;
// Queued-audio level (seconds ahead of the AudioContext clock) below which
// playback is suspended to refill. Presumably the fallback for rebufferThresholdSecs — confirm.
const DEFAULT_REBUFFER_THRESHOLD_SECS = 1.0;
// Initial value of the adaptive resume target: the queued-audio level (seconds)
// that must be reached before a suspended stream resumes.
const DEFAULT_RESUME_THRESHOLD_SECS = 3.0;
// Upper bound on the adaptive resume target, which grows by 2s per buffer underrun.
const MAX_ADAPTIVE_RESUME_SECS = 30.0;
interface GenerateOptions {
text: string;
@@ -91,7 +92,7 @@ export function useStreamingGeneration({
let resumeThresholdSecs = rawResumeThresholdSecs;
if (resumeThresholdSecs <= rebufferThresholdSecs) {
console.warn(
`[useStreamingGeneration] resumeThresholdSecs (${resumeThresholdSecs}) must be greater than rebufferThresholdSecs (${rebufferThresholdSecs}). Clamping resumeThresholdSecs to ${rebufferThresholdSecs + 0.5}.`,
`[useStreamingGeneration] resumeThresholdSecs (${resumeThresholdSecs}) must be greater than rebufferThresholdSecs (${rebufferThresholdSecs}). Clamping resumeThresholdSecs to ${rebufferThresholdSecs + 0.5}.`
);
resumeThresholdSecs = rebufferThresholdSecs + 0.5;
}
@@ -104,6 +105,10 @@ export function useStreamingGeneration({
const isAutoBufferingRef = useRef(false);
const isUserPausedRef = useRef(false);
const audioUrlRef = useRef<string | null>(null);
const firstChunkSeenRef = useRef(false);
const underrunCountRef = useRef(0);
const totalAudioSamplesRef = useRef(0);
const adaptiveResumeSecsRef = useRef(DEFAULT_RESUME_THRESHOLD_SECS);
const revokeCurrentUrl = useCallback(() => {
if (audioUrlRef.current) {
@@ -122,8 +127,12 @@ export function useStreamingGeneration({
hasStartedPlaybackRef.current = false;
isAutoBufferingRef.current = false;
isUserPausedRef.current = false;
firstChunkSeenRef.current = false;
underrunCountRef.current = 0;
totalAudioSamplesRef.current = 0;
adaptiveResumeSecsRef.current = resumeThresholdSecs;
setIsStreamPaused(false);
}, []);
}, [resumeThresholdSecs]);
useEffect(() => {
return () => {
@@ -153,15 +162,23 @@ export function useStreamingGeneration({
hasStartedPlaybackRef.current = true;
}, [enqueue]);
const handleAudioChunk = useCallback((chunk: Float32Array<ArrayBuffer>) => {
const handleAudioChunk = useCallback(
(chunk: Float32Array<ArrayBuffer>) => {
const ctx = audioCtxRef.current;
if (!ctx) return;
chunksRef.current.push(chunk);
totalAudioSamplesRef.current += chunk.length;
if (!firstChunkSeenRef.current) {
firstChunkSeenRef.current = true;
onLog("First audio chunk received");
}
if (!hasStartedPlaybackRef.current) {
const bufferedSecs = chunksRef.current.reduce((sum, c) => sum + c.length, 0) / SAMPLE_RATE;
if (bufferedSecs >= prebufferSecs) {
onLog(`Playback started after ${bufferedSecs.toFixed(1)}s buffered`);
flushBufferedAudio();
}
return;
@@ -171,20 +188,28 @@ export function useStreamingGeneration({
if (isUserPausedRef.current) return;
const ahead = nextStartTimeRef.current - ctx.currentTime;
if (ctx.state === "running" && ahead < rebufferThresholdSecs) {
ctx.suspend().catch(() => {});
if (ctx.state === "running" && !isAutoBufferingRef.current && ahead < rebufferThresholdSecs) {
isAutoBufferingRef.current = true;
} else if (
ctx.state === "suspended" &&
isAutoBufferingRef.current &&
ahead >= resumeThresholdSecs
) {
ctx.resume().catch(() => {});
underrunCountRef.current += 1;
adaptiveResumeSecsRef.current = Math.min(
MAX_ADAPTIVE_RESUME_SECS,
Math.max(resumeThresholdSecs, prebufferSecs + underrunCountRef.current * 2)
);
ctx.suspend().catch(() => {});
onLog(
`Buffer underrun ${underrunCountRef.current}; refilling to ${adaptiveResumeSecsRef.current.toFixed(1)}s`
);
} else if (isAutoBufferingRef.current && ahead >= adaptiveResumeSecsRef.current) {
isAutoBufferingRef.current = false;
ctx.resume().catch(() => {});
onLog(`Buffer recovered with ${ahead.toFixed(1)}s queued`);
}
}, [enqueue, flushBufferedAudio, prebufferSecs, rebufferThresholdSecs, resumeThresholdSecs]);
},
[enqueue, flushBufferedAudio, onLog, prebufferSecs, rebufferThresholdSecs, resumeThresholdSecs]
);
const generate = useCallback(async (options: GenerateOptions) => {
const generate = useCallback(
async (options: GenerateOptions) => {
if (!options.text.trim()) return;
resetPlayback();
@@ -217,7 +242,7 @@ export function useStreamingGeneration({
});
if (!res.ok || !res.body) {
const err = await res.json().catch(() => ({})) as { error?: string };
const err = (await res.json().catch(() => ({}))) as { error?: string };
throw new Error(err.error ?? `HTTP ${res.status}`);
}
@@ -239,6 +264,11 @@ export function useStreamingGeneration({
type: "audio_chunk" | "complete" | "error" | "cancelled";
data?: string;
elapsed?: number;
audio_secs?: number;
realtime_factor?: number | null;
chunks?: number;
first_chunk_secs?: number | null;
max_chunk_gap_secs?: number;
message?: string;
};
@@ -247,12 +277,28 @@ export function useStreamingGeneration({
} else if (event.type === "complete") {
if (!hasStartedPlaybackRef.current) {
flushBufferedAudio();
} else if (isAutoBufferingRef.current) {
isAutoBufferingRef.current = false;
audioCtxRef.current?.resume().catch(() => {});
}
const wavBlob = buildWav(mergeFloat32Arrays(chunksRef.current), SAMPLE_RATE);
const audioUrl = URL.createObjectURL(wavBlob);
audioUrlRef.current = audioUrl;
const kb = (wavBlob.size / 1024).toFixed(0);
onLog(`Done in ${event.elapsed}s - ${kb} KB`);
const audioSecs = event.audio_secs ?? totalAudioSamplesRef.current / SAMPLE_RATE;
const realtimeFactor =
event.realtime_factor ??
(event.elapsed && event.elapsed > 0 ? audioSecs / event.elapsed : null);
const speedText =
realtimeFactor === null ? "" : ` - ${realtimeFactor.toFixed(2)}x realtime`;
onLog(
`Done in ${event.elapsed}s - ${audioSecs.toFixed(1)}s audio${speedText} - ${kb} KB`
);
if (event.chunks && event.first_chunk_secs !== undefined) {
onLog(
`Stream: first chunk ${event.first_chunk_secs}s, ${event.chunks} chunks, max gap ${event.max_chunk_gap_secs}s`
);
}
onSuccess(audioUrl);
} else if (event.type === "cancelled") {
throw new DOMException("Generation cancelled", "AbortError");
@@ -274,7 +320,8 @@ export function useStreamingGeneration({
window.clearInterval(timerId);
abortRef.current = null;
}
}, [
},
[
flushBufferedAudio,
handleAudioChunk,
onCancel,
@@ -285,7 +332,8 @@ export function useStreamingGeneration({
onSuccess,
resetPlayback,
revokeCurrentUrl,
]);
]
);
const pauseStream = useCallback(() => {
isUserPausedRef.current = true;
+4 -2
View File
@@ -4,8 +4,10 @@
"private": true,
"scripts": {
"dev": "next dev --turbopack",
"build": "next build --turbopack",
"start": "next start"
"build": "next build",
"start": "next start",
"format": "prettier --write .",
"format:check": "prettier --check ."
},
"dependencies": {
"next": "15.5.15",