Add AMD ROCm GPU support

Introduces a third hardware mode alongside CUDA and CPU: ROCm (AMD GPU).
PyTorch's ROCm builds expose AMD GPUs through the torch.cuda API, so the
existing GPU path is reused. The main additions are wheel management, device
detection, and forcing SDPA attention on ROCm, where flash_attn is not
reliably supported.

- server/vibevoice_server.py: extend _resolve_device() to recognise 'rocm'
  (auto-detected via torch.version.hip); add _torch_device() helper that maps
  'rocm' → 'cuda' for all PyTorch API calls; apply GPU optimisations for both
  cuda and rocm in _init_model(); always use sdpa on ROCm; propagate
  _torch_device() to _load_voice_presets() map_location.
- server/start.sh: add --rocm flag; sync .venv-rocm with uv sync --no-sources
  then replace torch with the ROCm 6.2 wheel via uv pip install; set
  VIBEPOD_DEVICE=rocm for uvicorn.
- server/pyproject.toml: register pytorch-rocm62 index (explicit); add
  .venv-rocm to ruff excludes.
- package.json: add dev:rocm and dev:server:rocm scripts.
- README.md: document ROCm mode, prerequisites (RX 6000+, ROCm 6.2+, Linux),
  and new commands; expand CUDA vs CPU section to CUDA vs CPU vs ROCm.
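
The device rules described above can be sketched as a torch-free decision
function. This is an illustrative restatement, not the code in this commit:
the names resolve_device and torch_device and their parameters are
hypothetical stand-ins for _resolve_device() and _torch_device(), with
torch.cuda.is_available() and torch.version.hip passed in as plain arguments
so the logic can be exercised without a GPU.

```python
from typing import Optional


def resolve_device(requested: str, cuda_available: bool, hip_version: Optional[str]) -> str:
    """Return 'cpu', 'cuda', or 'rocm' (hypothetical helper, mirrors the commit's rules).

    requested      -- value of VIBEPOD_DEVICE ('' when unset)
    cuda_available -- what torch.cuda.is_available() would report
    hip_version    -- what torch.version.hip would hold (None on non-ROCm builds)
    """
    requested = requested.strip().lower()
    if requested in ("cpu", "cuda", "rocm"):
        # An explicit GPU request falls back to CPU when the backend is missing.
        if requested == "cuda" and not cuda_available:
            return "cpu"
        if requested == "rocm" and not (cuda_available and hip_version):
            return "cpu"
        return requested
    # Auto-detect: a ROCm build reports a GPU *and* a HIP version string.
    if cuda_available:
        return "rocm" if hip_version else "cuda"
    return "cpu"


def torch_device(device: str) -> str:
    """Map the logical device name to the string PyTorch APIs accept."""
    return "cuda" if device == "rocm" else device


print(resolve_device("", True, "6.2"), torch_device("rocm"))
```

One consequence of these rules, as sketched: an explicit VIBEPOD_DEVICE=cuda
on a ROCm build is honoured as 'cuda', so the ROCm-only sdpa override would
not apply in that case.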

https://claude.ai/code/session_0168pSswiaoEf6LEx6UQWfBu
Claude
2026-05-04 01:54:57 +00:00
parent f4d759c385
commit bb6da662de
5 changed files with 102 additions and 31 deletions
README.md (+14 -8)
@@ -37,6 +37,7 @@ cp .env.example .env.local
 # 4. Start everything
 pnpm dev # CUDA (requires NVIDIA GPU + driver >= 525.60)
 pnpm dev:cpu # CPU-only (no GPU required)
+pnpm dev:rocm # ROCm (requires AMD GPU + ROCm 6.2+, Linux only)
 ```
 `pnpm dev` / `pnpm dev:cpu` start both services concurrently:
@@ -46,26 +47,31 @@ pnpm dev:cpu # CPU-only (no GPU required)
 The frontend shows a loading indicator while the model downloads. Once the server reports `status: online`, generation is available.
-## CUDA vs CPU
+## CUDA vs CPU vs ROCm
-VibePod maintains two completely separate Python virtual environments so CUDA and CPU torch installs never conflict:
+VibePod maintains three completely separate Python virtual environments so torch installs never conflict:
-| Mode | Command | venv | torch source |
-| -------------- | -------------- | ------------------ | ----------------------- |
-| CUDA (default) | `pnpm dev` | `server/.venv` | PyTorch CUDA 12.4 index |
-| CPU-only | `pnpm dev:cpu` | `server/.venv-cpu` | PyPI (CPU wheel) |
+| Mode | Command | venv | torch source |
+| -------------- | ---------------- | --------------------- | ------------------------- |
+| CUDA (default) | `pnpm dev` | `server/.venv` | PyTorch CUDA 12.4 index |
+| CPU-only | `pnpm dev:cpu` | `server/.venv-cpu` | PyPI (CPU wheel) |
+| ROCm (AMD GPU) | `pnpm dev:rocm` | `server/.venv-rocm` | PyTorch ROCm 6.2 index |
-On first run, each mode creates its own venv automatically. You can switch between them freely — they are fully independent. The active device is reported by the `/health` endpoint as `"device": "cpu"` or `"device": "cuda"`.
+On first run, each mode creates its own venv automatically. You can switch between them freely — they are fully independent. The active device is reported by the `/health` endpoint as `"device": "cpu"`, `"device": "cuda"`, or `"device": "rocm"`.
 > **CUDA requirement:** driver >= 525.60 (RTX 30/40 series all qualify). Run `nvidia-smi` to check.
+> **ROCm requirement:** ROCm 6.2+ installed on Linux. Supported GPUs: AMD RX 6000 series (RDNA2) or newer (including RX 7000 series, RDNA3) and Instinct accelerators. ROCm mode is not supported on Windows. VibePod does not use flash-attn on ROCm; SDPA is used instead.
 ## Individual commands
 ```bash
 pnpm dev # CUDA — server + web
 pnpm dev:cpu # CPU — server + web
+pnpm dev:rocm # ROCm — server + web
 pnpm dev:server # CUDA — Python server only
 pnpm dev:server:cpu # CPU — Python server only
+pnpm dev:server:rocm # ROCm — Python server only
 pnpm dev:web # Next.js only (no Python server)
 pnpm build # Production build of the frontend
 ```
@@ -133,4 +139,4 @@ cd server && uv add <package>
 cd server && uv lock --upgrade
 ```
-> **Note:** The `[tool.uv.sources]` block in `pyproject.toml` pulls torch from the PyTorch CUDA 12.4 index by default. Running with `--cpu` (or `uv sync --no-sources`) bypasses this and installs the standard PyPI CPU wheel instead.
+> **Note:** The `[tool.uv.sources]` block in `pyproject.toml` pulls torch from the PyTorch CUDA 12.4 index by default. Running with `--cpu` or `--rocm` (or `uv sync --no-sources`) bypasses this and installs the standard PyPI CPU wheel first; for ROCm, the torch wheel is then replaced with the PyTorch ROCm 6.2 build.
package.json (+2)
@@ -6,8 +6,10 @@
     "build": "pnpm --filter vibepod-web build",
     "dev": "bash dev.sh",
     "dev:cpu": "bash dev.sh --cpu",
+    "dev:rocm": "bash dev.sh --rocm",
     "dev:server": "bash server/start.sh",
     "dev:server:cpu": "bash server/start.sh --cpu",
+    "dev:server:rocm": "bash server/start.sh --rocm",
     "dev:web": "pnpm --filter vibepod-web dev",
     "format": "prettier --write . && cd server && uv run ruff format .",
     "format:check": "prettier --check . && cd server && uv run ruff format --check .",
server/pyproject.toml (+6 -1)
@@ -32,7 +32,7 @@ dev = [
 line-length = 100
 indent-width = 4
 target-version = "py310"
-exclude = [".git", ".venv", ".venv-cpu", "__pycache__"]
+exclude = [".git", ".venv", ".venv-cpu", ".venv-rocm", "__pycache__"]
 [tool.ruff.lint]
 select = ["E", "F", "UP", "B", "SIM", "I"]
@@ -54,6 +54,11 @@ name = "pytorch-cu124"
 url = "https://download.pytorch.org/whl/cu124"
 explicit = true
+[[tool.uv.index]]
+name = "pytorch-rocm62"
+url = "https://download.pytorch.org/whl/rocm6.2"
+explicit = true
 [tool.uv.sources]
 # Pull torch from the PyTorch CUDA 12.4 index instead of PyPI's CPU-only wheel.
 # CUDA 12.4 runs on any driver >= 525.60 (RTX 30/40 series all qualify).
server/start.sh (+29 -4)
@@ -6,15 +6,17 @@
 # Usage:
 # ./start.sh — CUDA mode (default, uses PyTorch CUDA 12.4 wheel, venv: .venv)
 # ./start.sh --cpu — CPU-only mode (uses PyPI CPU torch wheel, venv: .venv-cpu)
+# ./start.sh --rocm — ROCm mode (AMD GPU, uses PyTorch ROCm 6.2 wheel, venv: .venv-rocm)
 #
 # Optional CUDA acceleration:
 # VIBEPOD_ENABLE_FLASH_ATTN=1 ./start.sh
 # Installs a pre-built flash-attn wheel when the CUDA venv uses Python 3.12,
 # torch 2.6.0, and CUDA 12.4 on Windows. Other platforms fall back to SDPA.
 #
-# The two modes maintain completely separate virtual environments so their torch
+# The three modes maintain completely separate virtual environments so their torch
 # installations never conflict. UV_PROJECT_ENVIRONMENT tells uv which venv to use;
-# --no-sources skips [tool.uv.sources] so the CPU run pulls the default PyPI torch wheel.
+# --no-sources skips [tool.uv.sources] so the CPU/ROCm run pulls the default PyPI torch
+# wheel first, then torch is replaced with the appropriate wheel for that mode.
 set -euo pipefail
@@ -25,12 +27,14 @@ cd "$SCRIPT_DIR"
 # Parse flags
 # ---------------------------------------------------------------------------
 CPU_MODE=false
+ROCM_MODE=false
 PASSTHROUGH_ARGS=()
 for arg in "$@"; do
   case "$arg" in
-    --cpu) CPU_MODE=true ;;
-    *) PASSTHROUGH_ARGS+=("$arg") ;;
+    --cpu)  CPU_MODE=true ;;
+    --rocm) ROCM_MODE=true ;;
+    *)      PASSTHROUGH_ARGS+=("$arg") ;;
   esac
 done
@@ -38,6 +42,8 @@ echo "================================================"
 echo " VibePod TTS Server"
 if $CPU_MODE; then
   echo " Mode : CPU-only"
+elif $ROCM_MODE; then
+  echo " Mode : ROCm (AMD GPU)"
 else
   echo " Mode : CUDA (default)"
 fi
@@ -89,6 +95,21 @@ if $CPU_MODE; then
     cp "$LOCK_BACKUP" uv.lock
     rm -f "$LOCK_BACKUP"
   fi
+elif $ROCM_MODE; then
+  echo "--> Syncing ROCm Python environment (.venv-rocm)..."
+  export UV_PROJECT_ENVIRONMENT=".venv-rocm"
+  LOCK_BACKUP=""
+  if [[ -f uv.lock ]]; then
+    LOCK_BACKUP="$(mktemp)"
+    cp uv.lock "$LOCK_BACKUP"
+  fi
+  uv sync --no-sources
+  echo "--> Installing PyTorch ROCm 6.2 wheel..."
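+  # uv pip does not read UV_PROJECT_ENVIRONMENT, so target the ROCm venv explicitly;
+  # --reinstall-package forces the swap because a PyPI torch is already installed.
+  uv pip install --python .venv-rocm/bin/python --reinstall-package torch torch --index-url https://download.pytorch.org/whl/rocm6.2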
+  if [[ -n "$LOCK_BACKUP" ]]; then
+    cp "$LOCK_BACKUP" uv.lock
+    rm -f "$LOCK_BACKUP"
+  fi
 else
   echo "--> Syncing CUDA Python environment (.venv)..."
   uv sync
@@ -166,6 +187,10 @@ if $CPU_MODE; then
   # VIBEPOD_COMPILE=1 torch.compile hot paths (ineffective for autoregressive
   # models on CPU — not recommended, kept for experimentation)
   UV_RUN_ARGS=(--no-sync --no-sources)
+elif $ROCM_MODE; then
+  export VIBEPOD_DEVICE="rocm"
+  export UV_PROJECT_ENVIRONMENT=".venv-rocm"
+  UV_RUN_ARGS=(--no-sync --no-sources)
 else
   export VIBEPOD_DEVICE="cuda"
   UV_RUN_ARGS=()
server/vibevoice_server.py (+51 -18)
@@ -2,7 +2,7 @@
 VibePod — VibeVoice FastAPI TTS Server
 This server provides a high-performance Text-to-Speech (TTS) interface for the VibeVoice model,
-optimized for real-time streaming on both CPU and NVIDIA GPU hardware.
+optimized for real-time streaming on CPU, NVIDIA GPU (CUDA), and AMD GPU (ROCm) hardware.
 MAINTAINER GUIDE / FILE MAP:
 - Device & Env Configuration: Helpers for hardware detection and runtime tuning via env vars.
@@ -13,12 +13,12 @@ MAINTAINER GUIDE / FILE MAP:
 - Audio Streaming: Async bridge (NonBlockingAudioStreamer) between inference and the network.
 RUNTIME CONFIGURATION (Environment Variables):
-- VIBEPOD_DEVICE: 'cpu' or 'cuda' (auto-detected if unset).
+- VIBEPOD_DEVICE: 'cpu', 'cuda', or 'rocm' (auto-detected if unset).
 - VIBEPOD_CHUNK_ACCUM: Number of 20ms audio chunks to buffer before sending an SSE event (default: 4 for CPU).
 - VIBEPOD_PREBUFFER_SECS: Initial client-side buffer duration (hinted to frontend).
 - VIBEPOD_REBUFFER_THRESHOLD_SECS: Buffer level below which the client pauses to refill.
 - VIBEPOD_RESUME_THRESHOLD_SECS: Buffer level at which the client resumes playback.
-- VIBEPOD_DEFAULT_INFERENCE_STEPS: Default DDPM steps (default: 8 for CPU, 10 for CUDA).
+- VIBEPOD_DEFAULT_INFERENCE_STEPS: Default DDPM steps (default: 8 for CPU, 10 for CUDA/ROCm).
 - VIBEPOD_PROFILE_GENERATION: Set to '1' to enable detailed performance logging.
CPU-SPECIFIC OPTIMIZATIONS:
@@ -30,9 +30,9 @@ CPU-SPECIFIC OPTIMIZATIONS:
 - VIBEPOD_QUANTIZE: Set to '1' to enable experimental dynamic INT8 quantization.
 - VIBEPOD_COMPILE: Set to '1' to enable experimental torch.compile (limited benefit for TTS).
-CUDA-SPECIFIC OPTIMIZATIONS:
-- VIBEPOD_CUDA_DTYPE: 'bf16' (default) or 'fp16'.
-- VIBEPOD_ATTN_IMPL: 'auto', 'sdpa', 'eager', or 'flash_attention_2'.
+CUDA/ROCM OPTIMIZATIONS:
+- VIBEPOD_CUDA_DTYPE: 'bf16' (default) or 'fp16'. Applies to both CUDA and ROCm.
+- VIBEPOD_ATTN_IMPL: 'auto', 'sdpa', 'eager', or 'flash_attention_2'. ROCm always uses sdpa.
 """
 import asyncio
@@ -93,19 +93,41 @@ _decode_executor: concurrent.futures.ThreadPoolExecutor | None = None
 def _resolve_device() -> str:
     """
-    Resolve the target device (CPU or CUDA) by checking the VIBEPOD_DEVICE environment
-    variable, falling back to CUDA if available, otherwise CPU.
+    Resolve the target device (cpu, cuda, or rocm) by checking the VIBEPOD_DEVICE
+    environment variable, falling back to auto-detection.
+    ROCm presents AMD GPUs as CUDA devices in PyTorch but sets torch.version.hip.
+    Auto-detection returns 'rocm' when hip is detected, 'cuda' for NVIDIA, else 'cpu'.
+    The string 'rocm' is used for logging/config; all PyTorch calls use 'cuda' as the
+    actual device string (see _torch_device()).
     """
     env = os.environ.get("VIBEPOD_DEVICE", "").strip().lower()
-    if env in ("cpu", "cuda"):
+    if env in ("cpu", "cuda", "rocm"):
         if env == "cuda" and not torch.cuda.is_available():
             logger.warning(
                 "VIBEPOD_DEVICE=cuda requested but CUDA is not available — falling back to CPU."
             )
             return "cpu"
+        if env == "rocm" and not (
+            torch.cuda.is_available() and getattr(torch.version, "hip", None)
+        ):
+            logger.warning(
+                "VIBEPOD_DEVICE=rocm requested but ROCm is not available — falling back to CPU."
+            )
+            return "cpu"
         return env
-    # Auto-detect
-    return "cuda" if torch.cuda.is_available() else "cpu"
+    # Auto-detect: ROCm sets torch.version.hip; NVIDIA leaves it None
+    if torch.cuda.is_available():
+        return "rocm" if getattr(torch.version, "hip", None) else "cuda"
+    return "cpu"
+def _torch_device(device: str) -> str:
+    """Map the VibePod device string to the PyTorch device string.
+    ROCm GPUs are addressed as 'cuda' in all PyTorch APIs.
+    """
+    return "cuda" if device == "rocm" else device
 # ── Env-var helpers ─────────────────────────────────────────────────────────────
@@ -284,8 +306,10 @@ def _init_model(device: str):
     Load the VibeVoice model with appropriate precision (BF16/FP16/FP32) and
     apply VibePod-specific performance optimizations.
     """
+    torch_device = _torch_device(device)
+    is_gpu = device in ("cuda", "rocm")
     logger.info("Loading model on %s...", device)
-    if device == "cuda":
+    if is_gpu:
         torch.set_float32_matmul_precision("high")
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.allow_tf32 = True
@@ -300,14 +324,18 @@ def _init_model(device: str):
             torch.backends.cuda.mem_efficient_sdp_enabled(),
             torch.backends.cuda.math_sdp_enabled(),
         )
+        if device == "rocm":
+            logger.info(
+                "ROCm/HIP backend detected (torch.version.hip=%s)", torch.version.hip
+            )
     elif device == "cpu":
         torch.set_float32_matmul_precision("medium")
         logger.info("CPU runtime configuration: %s", _configure_cpu_runtime())
     cuda_dtype = os.environ.get("VIBEPOD_CUDA_DTYPE", "bf16").lower()
-    if device == "cuda" and cuda_dtype == "fp16":
+    if is_gpu and cuda_dtype == "fp16":
         load_dtype = torch.float16
-    elif device == "cuda":
+    elif is_gpu:
         load_dtype = torch.bfloat16
     else:
         cpu_bf16_env = os.environ.get("VIBEPOD_CPU_BF16", "auto").lower()
@@ -328,7 +356,11 @@ def _init_model(device: str):
     logger.info("Loading model weights with dtype %s", load_dtype)
     requested_attn_impl = os.environ.get("VIBEPOD_ATTN_IMPL", "auto").lower()
     has_flash_attn = importlib.util.find_spec("flash_attn") is not None
-    if requested_attn_impl in {"eager", "sdpa"}:
+    if device == "rocm":
+        # flash_attn is not reliably supported on ROCm; always use sdpa.
+        attn_impl = "sdpa"
+        logger.info("ROCm: using sdpa attention (flash_attn not supported on ROCm).")
+    elif requested_attn_impl in {"eager", "sdpa"}:
         attn_impl = requested_attn_impl
     elif requested_attn_impl == "flash_attention_2":
         attn_impl = "flash_attention_2" if has_flash_attn else "sdpa"
@@ -348,7 +380,7 @@ def _init_model(device: str):
         model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
             MODEL_ID,
             torch_dtype=load_dtype,
-            device_map=device,
+            device_map=torch_device,
             attn_implementation=attn_impl,
         )
     except Exception as exc:
@@ -362,7 +394,7 @@ def _init_model(device: str):
         model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
             MODEL_ID,
             torch_dtype=load_dtype,
-            device_map=device,
+            device_map=torch_device,
             attn_implementation="sdpa",
         )
@@ -672,11 +704,12 @@ def _load_voice_presets(device: str) -> dict[str, object]:
     """
     Load all pre-downloaded voice tensor files (.pt) from the voices directory.
     """
+    torch_device = _torch_device(device)
     presets = {}
     for name, filename in EN_VOICES.items():
         path = VOICES_DIR / filename
         if path.exists():
-            presets[name] = torch.load(path, map_location=device, weights_only=False)
+            presets[name] = torch.load(path, map_location=torch_device, weights_only=False)
     return presets