mirror of
https://github.com/JezzWTF/vibepod.git
synced 2026-06-01 15:22:14 +00:00
Add AMD ROCm GPU support
Introduces a third hardware mode alongside CUDA and CPU: ROCm (AMD GPU). AMD GPUs present as CUDA devices under PyTorch ROCm, so the existing GPU path is reused with minimal changes — the main additions are wheel management, device detection, and suppressing flash_attn (unsupported on ROCm). - server/vibevoice_server.py: extend _resolve_device() to recognise 'rocm' (auto-detected via torch.version.hip); add _torch_device() helper that maps 'rocm' → 'cuda' for all PyTorch API calls; apply GPU optimisations for both cuda and rocm in _init_model(); always use sdpa on ROCm; propagate _torch_device() to _load_voice_presets() map_location. - server/start.sh: add --rocm flag; sync .venv-rocm with uv sync --no-sources then replace torch with the ROCm 6.2 wheel via uv pip install; set VIBEPOD_DEVICE=rocm for uvicorn. - server/pyproject.toml: register pytorch-rocm62 index (explicit); add .venv-rocm to ruff excludes. - package.json: add dev:rocm and dev:server:rocm scripts. - README.md: document ROCm mode, prerequisites (RX 6000+, ROCm 6.2+, Linux), and new commands; expand CUDA vs CPU section to CUDA vs CPU vs ROCm. https://claude.ai/code/session_0168pSswiaoEf6LEx6UQWfBu
This commit is contained in:
@@ -37,6 +37,7 @@ cp .env.example .env.local
|
||||
# 4. Start everything
|
||||
pnpm dev # CUDA (requires NVIDIA GPU + driver >= 525.60)
|
||||
pnpm dev:cpu # CPU-only (no GPU required)
|
||||
pnpm dev:rocm # ROCm (requires AMD GPU + ROCm 6.2+, Linux only)
|
||||
```
|
||||
|
||||
`pnpm dev` / `pnpm dev:cpu` start both services concurrently:
|
||||
@@ -46,26 +47,31 @@ pnpm dev:cpu # CPU-only (no GPU required)
|
||||
|
||||
The frontend shows a loading indicator while the model downloads. Once the server reports `status: online`, generation is available.
|
||||
|
||||
## CUDA vs CPU
|
||||
## CUDA vs CPU vs ROCm
|
||||
|
||||
VibePod maintains two completely separate Python virtual environments so CUDA and CPU torch installs never conflict:
|
||||
VibePod maintains three completely separate Python virtual environments so torch installs never conflict:
|
||||
|
||||
| Mode | Command | venv | torch source |
|
||||
| -------------- | -------------- | ------------------ | ----------------------- |
|
||||
| CUDA (default) | `pnpm dev` | `server/.venv` | PyTorch CUDA 12.4 index |
|
||||
| CPU-only | `pnpm dev:cpu` | `server/.venv-cpu` | PyPI (CPU wheel) |
|
||||
| Mode | Command | venv | torch source |
|
||||
| -------------- | ---------------- | --------------------- | ------------------------- |
|
||||
| CUDA (default) | `pnpm dev` | `server/.venv` | PyTorch CUDA 12.4 index |
|
||||
| CPU-only | `pnpm dev:cpu` | `server/.venv-cpu` | PyPI (CPU wheel) |
|
||||
| ROCm (AMD GPU) | `pnpm dev:rocm` | `server/.venv-rocm` | PyTorch ROCm 6.2 index |
|
||||
|
||||
On first run, each mode creates its own venv automatically. You can switch between them freely — they are fully independent. The active device is reported by the `/health` endpoint as `"device": "cpu"` or `"device": "cuda"`.
|
||||
On first run, each mode creates its own venv automatically. You can switch between them freely — they are fully independent. The active device is reported by the `/health` endpoint as `"device": "cpu"`, `"device": "cuda"`, or `"device": "rocm"`.
|
||||
|
||||
> **CUDA requirement:** driver >= 525.60 (RTX 30/40 series all qualify). Run `nvidia-smi` to check.
|
||||
|
||||
> **ROCm requirement:** ROCm 6.2+ installed on Linux. Supported GPUs: AMD RX 6000 series (RDNA2) or newer, RX 7000 series (RDNA3), and Instinct accelerators. ROCm is not supported on Windows. Flash attention is not available on ROCm — SDPA is used instead.
|
||||
|
||||
## Individual commands
|
||||
|
||||
```bash
|
||||
pnpm dev # CUDA — server + web
|
||||
pnpm dev:cpu # CPU — server + web
|
||||
pnpm dev:rocm # ROCm — server + web
|
||||
pnpm dev:server # CUDA — Python server only
|
||||
pnpm dev:server:cpu # CPU — Python server only
|
||||
pnpm dev:server:rocm # ROCm — Python server only
|
||||
pnpm dev:web # Next.js only (no Python server)
|
||||
pnpm build # Production build of the frontend
|
||||
```
|
||||
@@ -133,4 +139,4 @@ cd server && uv add <package>
|
||||
cd server && uv lock --upgrade
|
||||
```
|
||||
|
||||
> **Note:** The `[tool.uv.sources]` block in `pyproject.toml` pulls torch from the PyTorch CUDA 12.4 index by default. Running with `--cpu` (or `uv sync --no-sources`) bypasses this and installs the standard PyPI CPU wheel instead.
|
||||
> **Note:** The `[tool.uv.sources]` block in `pyproject.toml` pulls torch from the PyTorch CUDA 12.4 index by default. Running with `--cpu` or `--rocm` (or `uv sync --no-sources`) bypasses this and installs the standard PyPI CPU wheel first; for ROCm, the torch wheel is then replaced with the PyTorch ROCm 6.2 build.
|
||||
|
||||
@@ -6,8 +6,10 @@
|
||||
"build": "pnpm --filter vibepod-web build",
|
||||
"dev": "bash dev.sh",
|
||||
"dev:cpu": "bash dev.sh --cpu",
|
||||
"dev:rocm": "bash dev.sh --rocm",
|
||||
"dev:server": "bash server/start.sh",
|
||||
"dev:server:cpu": "bash server/start.sh --cpu",
|
||||
"dev:server:rocm": "bash server/start.sh --rocm",
|
||||
"dev:web": "pnpm --filter vibepod-web dev",
|
||||
"format": "prettier --write . && cd server && uv run ruff format .",
|
||||
"format:check": "prettier --check . && cd server && uv run ruff format --check .",
|
||||
|
||||
@@ -32,7 +32,7 @@ dev = [
|
||||
line-length = 100
|
||||
indent-width = 4
|
||||
target-version = "py310"
|
||||
exclude = [".git", ".venv", ".venv-cpu", "__pycache__"]
|
||||
exclude = [".git", ".venv", ".venv-cpu", ".venv-rocm", "__pycache__"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "UP", "B", "SIM", "I"]
|
||||
@@ -54,6 +54,11 @@ name = "pytorch-cu124"
|
||||
url = "https://download.pytorch.org/whl/cu124"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-rocm62"
|
||||
url = "https://download.pytorch.org/whl/rocm6.2"
|
||||
explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
# Pull torch from the PyTorch CUDA 12.4 index instead of PyPI's CPU-only wheel.
|
||||
# CUDA 12.4 runs on any driver >= 525.60 (RTX 30/40 series all qualify).
|
||||
|
||||
+29
-4
@@ -6,15 +6,17 @@
|
||||
# Usage:
|
||||
# ./start.sh — CUDA mode (default, uses PyTorch CUDA 12.4 wheel, venv: .venv)
|
||||
# ./start.sh --cpu — CPU-only mode (uses PyPI CPU torch wheel, venv: .venv-cpu)
|
||||
# ./start.sh --rocm — ROCm mode (AMD GPU, uses PyTorch ROCm 6.2 wheel, venv: .venv-rocm)
|
||||
#
|
||||
# Optional CUDA acceleration:
|
||||
# VIBEPOD_ENABLE_FLASH_ATTN=1 ./start.sh
|
||||
# Installs a pre-built flash-attn wheel when the CUDA venv uses Python 3.12,
|
||||
# torch 2.6.0, and CUDA 12.4 on Windows. Other platforms fall back to SDPA.
|
||||
#
|
||||
# The two modes maintain completely separate virtual environments so their torch
|
||||
# The three modes maintain completely separate virtual environments so their torch
|
||||
# installations never conflict. UV_PROJECT_ENVIRONMENT tells uv which venv to use;
|
||||
# --no-sources skips [tool.uv.sources] so the CPU run pulls the default PyPI torch wheel.
|
||||
# --no-sources skips [tool.uv.sources] so the CPU/ROCm run pulls the default PyPI torch
|
||||
# wheel first, then torch is replaced with the appropriate wheel for that mode.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
@@ -25,12 +27,14 @@ cd "$SCRIPT_DIR"
|
||||
# Parse flags
|
||||
# ---------------------------------------------------------------------------
|
||||
CPU_MODE=false
|
||||
ROCM_MODE=false
|
||||
PASSTHROUGH_ARGS=()
|
||||
|
||||
for arg in "$@"; do
|
||||
case "$arg" in
|
||||
--cpu) CPU_MODE=true ;;
|
||||
*) PASSTHROUGH_ARGS+=("$arg") ;;
|
||||
--cpu) CPU_MODE=true ;;
|
||||
--rocm) ROCM_MODE=true ;;
|
||||
*) PASSTHROUGH_ARGS+=("$arg") ;;
|
||||
esac
|
||||
done
|
||||
|
||||
@@ -38,6 +42,8 @@ echo "================================================"
|
||||
echo " VibePod TTS Server"
|
||||
if $CPU_MODE; then
|
||||
echo " Mode : CPU-only"
|
||||
elif $ROCM_MODE; then
|
||||
echo " Mode : ROCm (AMD GPU)"
|
||||
else
|
||||
echo " Mode : CUDA (default)"
|
||||
fi
|
||||
@@ -89,6 +95,21 @@ if $CPU_MODE; then
|
||||
cp "$LOCK_BACKUP" uv.lock
|
||||
rm -f "$LOCK_BACKUP"
|
||||
fi
|
||||
elif $ROCM_MODE; then
|
||||
echo "--> Syncing ROCm Python environment (.venv-rocm)..."
|
||||
export UV_PROJECT_ENVIRONMENT=".venv-rocm"
|
||||
LOCK_BACKUP=""
|
||||
if [[ -f uv.lock ]]; then
|
||||
LOCK_BACKUP="$(mktemp)"
|
||||
cp uv.lock "$LOCK_BACKUP"
|
||||
fi
|
||||
uv sync --no-sources
|
||||
echo "--> Installing PyTorch ROCm 6.2 wheel..."
|
||||
uv pip install torch --index-url https://download.pytorch.org/whl/rocm6.2
|
||||
if [[ -n "$LOCK_BACKUP" ]]; then
|
||||
cp "$LOCK_BACKUP" uv.lock
|
||||
rm -f "$LOCK_BACKUP"
|
||||
fi
|
||||
else
|
||||
echo "--> Syncing CUDA Python environment (.venv)..."
|
||||
uv sync
|
||||
@@ -166,6 +187,10 @@ if $CPU_MODE; then
|
||||
# VIBEPOD_COMPILE=1 torch.compile hot paths (ineffective for autoregressive
|
||||
# models on CPU — not recommended, kept for experimentation)
|
||||
UV_RUN_ARGS=(--no-sync --no-sources)
|
||||
elif $ROCM_MODE; then
|
||||
export VIBEPOD_DEVICE="rocm"
|
||||
export UV_PROJECT_ENVIRONMENT=".venv-rocm"
|
||||
UV_RUN_ARGS=(--no-sync --no-sources)
|
||||
else
|
||||
export VIBEPOD_DEVICE="cuda"
|
||||
UV_RUN_ARGS=()
|
||||
|
||||
+51
-18
@@ -2,7 +2,7 @@
|
||||
VibePod — VibeVoice FastAPI TTS Server
|
||||
|
||||
This server provides a high-performance Text-to-Speech (TTS) interface for the VibeVoice model,
|
||||
optimized for real-time streaming on both CPU and NVIDIA GPU hardware.
|
||||
optimized for real-time streaming on CPU, NVIDIA GPU (CUDA), and AMD GPU (ROCm) hardware.
|
||||
|
||||
MAINTAINER GUIDE / FILE MAP:
|
||||
- Device & Env Configuration: Helpers for hardware detection and runtime tuning via env vars.
|
||||
@@ -13,12 +13,12 @@ MAINTAINER GUIDE / FILE MAP:
|
||||
- Audio Streaming: Async bridge (NonBlockingAudioStreamer) between inference and the network.
|
||||
|
||||
RUNTIME CONFIGURATION (Environment Variables):
|
||||
- VIBEPOD_DEVICE: 'cpu' or 'cuda' (auto-detected if unset).
|
||||
- VIBEPOD_DEVICE: 'cpu', 'cuda', or 'rocm' (auto-detected if unset).
|
||||
- VIBEPOD_CHUNK_ACCUM: Number of 20ms audio chunks to buffer before sending an SSE event (default: 4 for CPU).
|
||||
- VIBEPOD_PREBUFFER_SECS: Initial client-side buffer duration (hinted to frontend).
|
||||
- VIBEPOD_REBUFFER_THRESHOLD_SECS: Buffer level below which the client pauses to refill.
|
||||
- VIBEPOD_RESUME_THRESHOLD_SECS: Buffer level at which the client resumes playback.
|
||||
- VIBEPOD_DEFAULT_INFERENCE_STEPS: Default DDPM steps (default: 8 for CPU, 10 for CUDA).
|
||||
- VIBEPOD_DEFAULT_INFERENCE_STEPS: Default DDPM steps (default: 8 for CPU, 10 for CUDA/ROCm).
|
||||
- VIBEPOD_PROFILE_GENERATION: Set to '1' to enable detailed performance logging.
|
||||
|
||||
CPU-SPECIFIC OPTIMIZATIONS:
|
||||
@@ -30,9 +30,9 @@ CPU-SPECIFIC OPTIMIZATIONS:
|
||||
- VIBEPOD_QUANTIZE: Set to '1' to enable experimental dynamic INT8 quantization.
|
||||
- VIBEPOD_COMPILE: Set to '1' to enable experimental torch.compile (limited benefit for TTS).
|
||||
|
||||
CUDA-SPECIFIC OPTIMIZATIONS:
|
||||
- VIBEPOD_CUDA_DTYPE: 'bf16' (default) or 'fp16'.
|
||||
- VIBEPOD_ATTN_IMPL: 'auto', 'sdpa', 'eager', or 'flash_attention_2'.
|
||||
CUDA/ROCM OPTIMIZATIONS:
|
||||
- VIBEPOD_CUDA_DTYPE: 'bf16' (default) or 'fp16'. Applies to both CUDA and ROCm.
|
||||
- VIBEPOD_ATTN_IMPL: 'auto', 'sdpa', 'eager', or 'flash_attention_2'. ROCm always uses sdpa.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -93,19 +93,41 @@ _decode_executor: concurrent.futures.ThreadPoolExecutor | None = None
|
||||
|
||||
def _resolve_device() -> str:
|
||||
"""
|
||||
Resolve the target device (CPU or CUDA) by checking the VIBEPOD_DEVICE environment
|
||||
variable, falling back to CUDA if available, otherwise CPU.
|
||||
Resolve the target device (cpu, cuda, or rocm) by checking the VIBEPOD_DEVICE
|
||||
environment variable, falling back to auto-detection.
|
||||
|
||||
ROCm presents AMD GPUs as CUDA devices in PyTorch but sets torch.version.hip.
|
||||
Auto-detection returns 'rocm' when hip is detected, 'cuda' for NVIDIA, else 'cpu'.
|
||||
The string 'rocm' is used for logging/config; all PyTorch calls use 'cuda' as the
|
||||
actual device string (see _torch_device()).
|
||||
"""
|
||||
env = os.environ.get("VIBEPOD_DEVICE", "").strip().lower()
|
||||
if env in ("cpu", "cuda"):
|
||||
if env in ("cpu", "cuda", "rocm"):
|
||||
if env == "cuda" and not torch.cuda.is_available():
|
||||
logger.warning(
|
||||
"VIBEPOD_DEVICE=cuda requested but CUDA is not available — falling back to CPU."
|
||||
)
|
||||
return "cpu"
|
||||
if env == "rocm" and not (
|
||||
torch.cuda.is_available() and getattr(torch.version, "hip", None)
|
||||
):
|
||||
logger.warning(
|
||||
"VIBEPOD_DEVICE=rocm requested but ROCm is not available — falling back to CPU."
|
||||
)
|
||||
return "cpu"
|
||||
return env
|
||||
# Auto-detect
|
||||
return "cuda" if torch.cuda.is_available() else "cpu"
|
||||
# Auto-detect: ROCm sets torch.version.hip; NVIDIA leaves it None
|
||||
if torch.cuda.is_available():
|
||||
return "rocm" if getattr(torch.version, "hip", None) else "cuda"
|
||||
return "cpu"
|
||||
|
||||
|
||||
def _torch_device(device: str) -> str:
|
||||
"""Map the VibePod device string to the PyTorch device string.
|
||||
|
||||
ROCm GPUs are addressed as 'cuda' in all PyTorch APIs.
|
||||
"""
|
||||
return "cuda" if device == "rocm" else device
|
||||
|
||||
|
||||
# ── Env-var helpers ─────────────────────────────────────────────────────────────
|
||||
@@ -284,8 +306,10 @@ def _init_model(device: str):
|
||||
Load the VibeVoice model with appropriate precision (BF16/FP16/FP32) and
|
||||
apply VibePod-specific performance optimizations.
|
||||
"""
|
||||
torch_device = _torch_device(device)
|
||||
is_gpu = device in ("cuda", "rocm")
|
||||
logger.info("Loading model on %s...", device)
|
||||
if device == "cuda":
|
||||
if is_gpu:
|
||||
torch.set_float32_matmul_precision("high")
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
@@ -300,14 +324,18 @@ def _init_model(device: str):
|
||||
torch.backends.cuda.mem_efficient_sdp_enabled(),
|
||||
torch.backends.cuda.math_sdp_enabled(),
|
||||
)
|
||||
if device == "rocm":
|
||||
logger.info(
|
||||
"ROCm/HIP backend detected (torch.version.hip=%s)", torch.version.hip
|
||||
)
|
||||
elif device == "cpu":
|
||||
torch.set_float32_matmul_precision("medium")
|
||||
logger.info("CPU runtime configuration: %s", _configure_cpu_runtime())
|
||||
|
||||
cuda_dtype = os.environ.get("VIBEPOD_CUDA_DTYPE", "bf16").lower()
|
||||
if device == "cuda" and cuda_dtype == "fp16":
|
||||
if is_gpu and cuda_dtype == "fp16":
|
||||
load_dtype = torch.float16
|
||||
elif device == "cuda":
|
||||
elif is_gpu:
|
||||
load_dtype = torch.bfloat16
|
||||
else:
|
||||
cpu_bf16_env = os.environ.get("VIBEPOD_CPU_BF16", "auto").lower()
|
||||
@@ -328,7 +356,11 @@ def _init_model(device: str):
|
||||
logger.info("Loading model weights with dtype %s", load_dtype)
|
||||
requested_attn_impl = os.environ.get("VIBEPOD_ATTN_IMPL", "auto").lower()
|
||||
has_flash_attn = importlib.util.find_spec("flash_attn") is not None
|
||||
if requested_attn_impl in {"eager", "sdpa"}:
|
||||
if device == "rocm":
|
||||
# flash_attn is not reliably supported on ROCm; always use sdpa.
|
||||
attn_impl = "sdpa"
|
||||
logger.info("ROCm: using sdpa attention (flash_attn not supported on ROCm).")
|
||||
elif requested_attn_impl in {"eager", "sdpa"}:
|
||||
attn_impl = requested_attn_impl
|
||||
elif requested_attn_impl == "flash_attention_2":
|
||||
attn_impl = "flash_attention_2" if has_flash_attn else "sdpa"
|
||||
@@ -348,7 +380,7 @@ def _init_model(device: str):
|
||||
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
|
||||
MODEL_ID,
|
||||
torch_dtype=load_dtype,
|
||||
device_map=device,
|
||||
device_map=torch_device,
|
||||
attn_implementation=attn_impl,
|
||||
)
|
||||
except Exception as exc:
|
||||
@@ -362,7 +394,7 @@ def _init_model(device: str):
|
||||
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
|
||||
MODEL_ID,
|
||||
torch_dtype=load_dtype,
|
||||
device_map=device,
|
||||
device_map=torch_device,
|
||||
attn_implementation="sdpa",
|
||||
)
|
||||
|
||||
@@ -672,11 +704,12 @@ def _load_voice_presets(device: str) -> dict[str, object]:
|
||||
"""
|
||||
Load all pre-downloaded voice tensor files (.pt) from the voices directory.
|
||||
"""
|
||||
torch_device = _torch_device(device)
|
||||
presets = {}
|
||||
for name, filename in EN_VOICES.items():
|
||||
path = VOICES_DIR / filename
|
||||
if path.exists():
|
||||
presets[name] = torch.load(path, map_location=device, weights_only=False)
|
||||
presets[name] = torch.load(path, map_location=torch_device, weights_only=False)
|
||||
return presets
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user