mirror of
https://github.com/JezzWTF/vibepod.git
synced 2026-06-13 03:58:07 +00:00
Add AMD ROCm GPU support
Introduces a third hardware mode alongside CUDA and CPU: ROCm (AMD GPU). AMD GPUs present as CUDA devices under PyTorch ROCm, so the existing GPU path is reused with minimal changes — the main additions are wheel management, device detection, and suppressing flash_attn (unsupported on ROCm). - server/vibevoice_server.py: extend _resolve_device() to recognise 'rocm' (auto-detected via torch.version.hip); add _torch_device() helper that maps 'rocm' → 'cuda' for all PyTorch API calls; apply GPU optimisations for both cuda and rocm in _init_model(); always use sdpa on ROCm; propagate _torch_device() to _load_voice_presets() map_location. - server/start.sh: add --rocm flag; sync .venv-rocm with uv sync --no-sources then replace torch with the ROCm 6.2 wheel via uv pip install; set VIBEPOD_DEVICE=rocm for uvicorn. - server/pyproject.toml: register pytorch-rocm62 index (explicit); add .venv-rocm to ruff excludes. - package.json: add dev:rocm and dev:server:rocm scripts. - README.md: document ROCm mode, prerequisites (RX 6000+, ROCm 6.2+, Linux), and new commands; expand CUDA vs CPU section to CUDA vs CPU vs ROCm. https://claude.ai/code/session_0168pSswiaoEf6LEx6UQWfBu
This commit is contained in:
+29
-4
@@ -6,15 +6,17 @@
|
||||
# Usage:
|
||||
# ./start.sh — CUDA mode (default, uses PyTorch CUDA 12.4 wheel, venv: .venv)
|
||||
# ./start.sh --cpu — CPU-only mode (uses PyPI CPU torch wheel, venv: .venv-cpu)
|
||||
# ./start.sh --rocm — ROCm mode (AMD GPU, uses PyTorch ROCm 6.2 wheel, venv: .venv-rocm)
|
||||
#
|
||||
# Optional CUDA acceleration:
|
||||
# VIBEPOD_ENABLE_FLASH_ATTN=1 ./start.sh
|
||||
# Installs a pre-built flash-attn wheel when the CUDA venv uses Python 3.12,
|
||||
# torch 2.6.0, and CUDA 12.4 on Windows. Other platforms fall back to SDPA.
|
||||
#
|
||||
# The two modes maintain completely separate virtual environments so their torch
|
||||
# The three modes maintain completely separate virtual environments so their torch
|
||||
# installations never conflict. UV_PROJECT_ENVIRONMENT tells uv which venv to use;
|
||||
# --no-sources skips [tool.uv.sources] so the CPU run pulls the default PyPI torch wheel.
|
||||
# --no-sources skips [tool.uv.sources] so the CPU/ROCm run pulls the default PyPI torch
|
||||
# wheel first, then torch is replaced with the appropriate wheel for that mode.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
@@ -25,12 +27,14 @@ cd "$SCRIPT_DIR"
|
||||
# Parse flags
|
||||
# ---------------------------------------------------------------------------
|
||||
CPU_MODE=false
|
||||
ROCM_MODE=false
|
||||
PASSTHROUGH_ARGS=()
|
||||
|
||||
for arg in "$@"; do
|
||||
case "$arg" in
|
||||
--cpu) CPU_MODE=true ;;
|
||||
*) PASSTHROUGH_ARGS+=("$arg") ;;
|
||||
--cpu) CPU_MODE=true ;;
|
||||
--rocm) ROCM_MODE=true ;;
|
||||
*) PASSTHROUGH_ARGS+=("$arg") ;;
|
||||
esac
|
||||
done
|
||||
|
||||
@@ -38,6 +42,8 @@ echo "================================================"
|
||||
echo " VibePod TTS Server"
|
||||
if $CPU_MODE; then
|
||||
echo " Mode : CPU-only"
|
||||
elif $ROCM_MODE; then
|
||||
echo " Mode : ROCm (AMD GPU)"
|
||||
else
|
||||
echo " Mode : CUDA (default)"
|
||||
fi
|
||||
@@ -89,6 +95,21 @@ if $CPU_MODE; then
|
||||
cp "$LOCK_BACKUP" uv.lock
|
||||
rm -f "$LOCK_BACKUP"
|
||||
fi
|
||||
elif $ROCM_MODE; then
|
||||
echo "--> Syncing ROCm Python environment (.venv-rocm)..."
|
||||
export UV_PROJECT_ENVIRONMENT=".venv-rocm"
|
||||
LOCK_BACKUP=""
|
||||
if [[ -f uv.lock ]]; then
|
||||
LOCK_BACKUP="$(mktemp)"
|
||||
cp uv.lock "$LOCK_BACKUP"
|
||||
fi
|
||||
uv sync --no-sources
|
||||
echo "--> Installing PyTorch ROCm 6.2 wheel..."
|
||||
uv pip install torch --index-url https://download.pytorch.org/whl/rocm6.2
|
||||
if [[ -n "$LOCK_BACKUP" ]]; then
|
||||
cp "$LOCK_BACKUP" uv.lock
|
||||
rm -f "$LOCK_BACKUP"
|
||||
fi
|
||||
else
|
||||
echo "--> Syncing CUDA Python environment (.venv)..."
|
||||
uv sync
|
||||
@@ -166,6 +187,10 @@ if $CPU_MODE; then
|
||||
# VIBEPOD_COMPILE=1 torch.compile hot paths (ineffective for autoregressive
|
||||
# models on CPU — not recommended, kept for experimentation)
|
||||
UV_RUN_ARGS=(--no-sync --no-sources)
|
||||
elif $ROCM_MODE; then
|
||||
export VIBEPOD_DEVICE="rocm"
|
||||
export UV_PROJECT_ENVIRONMENT=".venv-rocm"
|
||||
UV_RUN_ARGS=(--no-sync --no-sources)
|
||||
else
|
||||
export VIBEPOD_DEVICE="cuda"
|
||||
UV_RUN_ARGS=()
|
||||
|
||||
Reference in New Issue
Block a user