Add AMD ROCm GPU support

Introduces a third hardware mode alongside CUDA and CPU: ROCm (AMD GPU).
AMD GPUs present as CUDA devices under PyTorch ROCm, so the existing GPU
path is reused with minimal changes — the main additions are wheel management,
device detection, and suppressing flash_attn (unsupported on ROCm).

- server/vibevoice_server.py: extend _resolve_device() to recognise 'rocm'
  (auto-detected via torch.version.hip); add _torch_device() helper that maps
  'rocm' → 'cuda' for all PyTorch API calls; apply GPU optimisations for both
  cuda and rocm in _init_model(); always use sdpa on ROCm; propagate
  _torch_device() to _load_voice_presets() map_location.
- server/start.sh: add --rocm flag; sync .venv-rocm with uv sync --no-sources
  then replace torch with the ROCm 6.2 wheel via uv pip install; set
  VIBEPOD_DEVICE=rocm for uvicorn.
- server/pyproject.toml: register pytorch-rocm62 index (explicit); add
  .venv-rocm to ruff excludes.
- package.json: add dev:rocm and dev:server:rocm scripts.
- README.md: document ROCm mode, prerequisites (RX 6000+, ROCm 6.2+, Linux),
  and new commands; expand CUDA vs CPU section to CUDA vs CPU vs ROCm.

https://claude.ai/code/session_0168pSswiaoEf6LEx6UQWfBu
This commit is contained in:
Claude
2026-05-04 01:54:57 +00:00
parent f4d759c385
commit bb6da662de
5 changed files with 102 additions and 31 deletions
+29 -4
View File
@@ -6,15 +6,17 @@
# Usage:
# ./start.sh — CUDA mode (default, uses PyTorch CUDA 12.4 wheel, venv: .venv)
# ./start.sh --cpu — CPU-only mode (uses PyPI CPU torch wheel, venv: .venv-cpu)
# ./start.sh --rocm — ROCm mode (AMD GPU, uses PyTorch ROCm 6.2 wheel, venv: .venv-rocm)
#
# Optional CUDA acceleration:
# VIBEPOD_ENABLE_FLASH_ATTN=1 ./start.sh
# Installs a pre-built flash-attn wheel when the CUDA venv uses Python 3.12,
# torch 2.6.0, and CUDA 12.4 on Windows. Other platforms fall back to SDPA.
#
# The two modes maintain completely separate virtual environments so their torch
# The three modes maintain completely separate virtual environments so their torch
# installations never conflict. UV_PROJECT_ENVIRONMENT tells uv which venv to use;
# --no-sources skips [tool.uv.sources] so the CPU run pulls the default PyPI torch wheel.
# --no-sources skips [tool.uv.sources] so the CPU/ROCm run pulls the default PyPI torch
# wheel first, then torch is replaced with the appropriate wheel for that mode.
set -euo pipefail
@@ -25,12 +27,14 @@ cd "$SCRIPT_DIR"
# Parse flags
# ---------------------------------------------------------------------------
CPU_MODE=false
ROCM_MODE=false
PASSTHROUGH_ARGS=()
for arg in "$@"; do
case "$arg" in
--cpu) CPU_MODE=true ;;
*) PASSTHROUGH_ARGS+=("$arg") ;;
--cpu) CPU_MODE=true ;;
--rocm) ROCM_MODE=true ;;
*) PASSTHROUGH_ARGS+=("$arg") ;;
esac
done
@@ -38,6 +42,8 @@ echo "================================================"
echo " VibePod TTS Server"
if $CPU_MODE; then
echo " Mode : CPU-only"
elif $ROCM_MODE; then
echo " Mode : ROCm (AMD GPU)"
else
echo " Mode : CUDA (default)"
fi
@@ -89,6 +95,21 @@ if $CPU_MODE; then
cp "$LOCK_BACKUP" uv.lock
rm -f "$LOCK_BACKUP"
fi
elif $ROCM_MODE; then
echo "--> Syncing ROCm Python environment (.venv-rocm)..."
export UV_PROJECT_ENVIRONMENT=".venv-rocm"
LOCK_BACKUP=""
if [[ -f uv.lock ]]; then
LOCK_BACKUP="$(mktemp)"
cp uv.lock "$LOCK_BACKUP"
fi
uv sync --no-sources
echo "--> Installing PyTorch ROCm 6.2 wheel..."
uv pip install torch --index-url https://download.pytorch.org/whl/rocm6.2
if [[ -n "$LOCK_BACKUP" ]]; then
cp "$LOCK_BACKUP" uv.lock
rm -f "$LOCK_BACKUP"
fi
else
echo "--> Syncing CUDA Python environment (.venv)..."
uv sync
@@ -166,6 +187,10 @@ if $CPU_MODE; then
# VIBEPOD_COMPILE=1 torch.compile hot paths (ineffective for autoregressive
# models on CPU — not recommended, kept for experimentation)
UV_RUN_ARGS=(--no-sync --no-sources)
elif $ROCM_MODE; then
export VIBEPOD_DEVICE="rocm"
export UV_PROJECT_ENVIRONMENT=".venv-rocm"
UV_RUN_ARGS=(--no-sync --no-sources)
else
export VIBEPOD_DEVICE="cuda"
UV_RUN_ARGS=()