From bb6da662deb43b22f0746067e22e44f57f552542 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 4 May 2026 01:54:57 +0000 Subject: [PATCH] Add AMD ROCm GPU support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces a third hardware mode alongside CUDA and CPU: ROCm (AMD GPU). AMD GPUs present as CUDA devices under PyTorch ROCm, so the existing GPU path is reused with minimal changes — the main additions are wheel management, device detection, and suppressing flash_attn (unsupported on ROCm). - server/vibevoice_server.py: extend _resolve_device() to recognise 'rocm' (auto-detected via torch.version.hip); add _torch_device() helper that maps 'rocm' → 'cuda' for all PyTorch API calls; apply GPU optimisations for both cuda and rocm in _init_model(); always use sdpa on ROCm; propagate _torch_device() to _load_voice_presets() map_location. - server/start.sh: add --rocm flag; sync .venv-rocm with uv sync --no-sources then replace torch with the ROCm 6.2 wheel via uv pip install; set VIBEPOD_DEVICE=rocm for uvicorn. - server/pyproject.toml: register pytorch-rocm62 index (explicit); add .venv-rocm to ruff excludes. - package.json: add dev:rocm and dev:server:rocm scripts. - README.md: document ROCm mode, prerequisites (RX 6000+, ROCm 6.2+, Linux), and new commands; expand CUDA vs CPU section to CUDA vs CPU vs ROCm. https://claude.ai/code/session_0168pSswiaoEf6LEx6UQWfBu --- README.md | 22 +++++++----- package.json | 2 ++ server/pyproject.toml | 7 +++- server/start.sh | 33 +++++++++++++++--- server/vibevoice_server.py | 69 ++++++++++++++++++++++++++++---------- 5 files changed, 102 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index ab76d3c..511b5f4 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,7 @@ cp .env.example .env.local # 4. 
Start everything pnpm dev # CUDA (requires NVIDIA GPU + driver >= 525.60) pnpm dev:cpu # CPU-only (no GPU required) +pnpm dev:rocm # ROCm (requires AMD GPU + ROCm 6.2+, Linux only) ``` `pnpm dev` / `pnpm dev:cpu` start both services concurrently: @@ -46,26 +47,31 @@ pnpm dev:cpu # CPU-only (no GPU required) The frontend shows a loading indicator while the model downloads. Once the server reports `status: online`, generation is available. -## CUDA vs CPU +## CUDA vs CPU vs ROCm -VibePod maintains two completely separate Python virtual environments so CUDA and CPU torch installs never conflict: +VibePod maintains three completely separate Python virtual environments so torch installs never conflict: -| Mode | Command | venv | torch source | -| -------------- | -------------- | ------------------ | ----------------------- | -| CUDA (default) | `pnpm dev` | `server/.venv` | PyTorch CUDA 12.4 index | -| CPU-only | `pnpm dev:cpu` | `server/.venv-cpu` | PyPI (CPU wheel) | +| Mode | Command | venv | torch source | +| -------------- | ---------------- | --------------------- | ------------------------- | +| CUDA (default) | `pnpm dev` | `server/.venv` | PyTorch CUDA 12.4 index | +| CPU-only | `pnpm dev:cpu` | `server/.venv-cpu` | PyPI (CPU wheel) | +| ROCm (AMD GPU) | `pnpm dev:rocm` | `server/.venv-rocm` | PyTorch ROCm 6.2 index | -On first run, each mode creates its own venv automatically. You can switch between them freely — they are fully independent. The active device is reported by the `/health` endpoint as `"device": "cpu"` or `"device": "cuda"`. +On first run, each mode creates its own venv automatically. You can switch between them freely — they are fully independent. The active device is reported by the `/health` endpoint as `"device": "cpu"`, `"device": "cuda"`, or `"device": "rocm"`. > **CUDA requirement:** driver >= 525.60 (RTX 30/40 series all qualify). Run `nvidia-smi` to check. +> **ROCm requirement:** ROCm 6.2+ installed on Linux. 
Supported GPUs: AMD RX 6000 series (RDNA2) or newer, RX 7000 series (RDNA3), and Instinct accelerators. ROCm is not supported on Windows. Flash attention is not available on ROCm — SDPA is used instead. + ## Individual commands ```bash pnpm dev # CUDA — server + web pnpm dev:cpu # CPU — server + web +pnpm dev:rocm # ROCm — server + web pnpm dev:server # CUDA — Python server only pnpm dev:server:cpu # CPU — Python server only +pnpm dev:server:rocm # ROCm — Python server only pnpm dev:web # Next.js only (no Python server) pnpm build # Production build of the frontend ``` @@ -133,4 +139,4 @@ cd server && uv add cd server && uv lock --upgrade ``` -> **Note:** The `[tool.uv.sources]` block in `pyproject.toml` pulls torch from the PyTorch CUDA 12.4 index by default. Running with `--cpu` (or `uv sync --no-sources`) bypasses this and installs the standard PyPI CPU wheel instead. +> **Note:** The `[tool.uv.sources]` block in `pyproject.toml` pulls torch from the PyTorch CUDA 12.4 index by default. Running with `--cpu` or `--rocm` (or `uv sync --no-sources`) bypasses this and installs the standard PyPI CPU wheel first; for ROCm, the torch wheel is then replaced with the PyTorch ROCm 6.2 build. diff --git a/package.json b/package.json index 8769a40..aafd514 100644 --- a/package.json +++ b/package.json @@ -6,8 +6,10 @@ "build": "pnpm --filter vibepod-web build", "dev": "bash dev.sh", "dev:cpu": "bash dev.sh --cpu", + "dev:rocm": "bash dev.sh --rocm", "dev:server": "bash server/start.sh", "dev:server:cpu": "bash server/start.sh --cpu", + "dev:server:rocm": "bash server/start.sh --rocm", "dev:web": "pnpm --filter vibepod-web dev", "format": "prettier --write . && cd server && uv run ruff format .", "format:check": "prettier --check . 
&& cd server && uv run ruff format --check .", diff --git a/server/pyproject.toml b/server/pyproject.toml index a45dad9..186a285 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -32,7 +32,7 @@ dev = [ line-length = 100 indent-width = 4 target-version = "py310" -exclude = [".git", ".venv", ".venv-cpu", "__pycache__"] +exclude = [".git", ".venv", ".venv-cpu", ".venv-rocm", "__pycache__"] [tool.ruff.lint] select = ["E", "F", "UP", "B", "SIM", "I"] @@ -54,6 +54,11 @@ name = "pytorch-cu124" url = "https://download.pytorch.org/whl/cu124" explicit = true +[[tool.uv.index]] +name = "pytorch-rocm62" +url = "https://download.pytorch.org/whl/rocm6.2" +explicit = true + [tool.uv.sources] # Pull torch from the PyTorch CUDA 12.4 index instead of PyPI's CPU-only wheel. # CUDA 12.4 runs on any driver >= 525.60 (RTX 30/40 series all qualify). diff --git a/server/start.sh b/server/start.sh index c50277d..fd94f70 100755 --- a/server/start.sh +++ b/server/start.sh @@ -6,15 +6,17 @@ # Usage: # ./start.sh — CUDA mode (default, uses PyTorch CUDA 12.4 wheel, venv: .venv) # ./start.sh --cpu — CPU-only mode (uses PyPI CPU torch wheel, venv: .venv-cpu) +# ./start.sh --rocm — ROCm mode (AMD GPU, uses PyTorch ROCm 6.2 wheel, venv: .venv-rocm) # # Optional CUDA acceleration: # VIBEPOD_ENABLE_FLASH_ATTN=1 ./start.sh # Installs a pre-built flash-attn wheel when the CUDA venv uses Python 3.12, # torch 2.6.0, and CUDA 12.4 on Windows. Other platforms fall back to SDPA. # -# The two modes maintain completely separate virtual environments so their torch +# The three modes maintain completely separate virtual environments so their torch # installations never conflict. UV_PROJECT_ENVIRONMENT tells uv which venv to use; -# --no-sources skips [tool.uv.sources] so the CPU run pulls the default PyPI torch wheel. +# --no-sources skips [tool.uv.sources] so the CPU/ROCm run pulls the default PyPI torch +# wheel; in ROCm mode, torch is then replaced with the PyTorch ROCm 6.2 wheel.
set -euo pipefail @@ -25,12 +27,14 @@ cd "$SCRIPT_DIR" # Parse flags # --------------------------------------------------------------------------- CPU_MODE=false +ROCM_MODE=false PASSTHROUGH_ARGS=() for arg in "$@"; do case "$arg" in - --cpu) CPU_MODE=true ;; - *) PASSTHROUGH_ARGS+=("$arg") ;; + --cpu) CPU_MODE=true ;; + --rocm) ROCM_MODE=true ;; + *) PASSTHROUGH_ARGS+=("$arg") ;; esac done @@ -38,6 +42,8 @@ echo "================================================" echo " VibePod TTS Server" if $CPU_MODE; then echo " Mode : CPU-only" +elif $ROCM_MODE; then + echo " Mode : ROCm (AMD GPU)" else echo " Mode : CUDA (default)" fi @@ -89,6 +95,21 @@ if $CPU_MODE; then cp "$LOCK_BACKUP" uv.lock rm -f "$LOCK_BACKUP" fi +elif $ROCM_MODE; then + echo "--> Syncing ROCm Python environment (.venv-rocm)..." + export UV_PROJECT_ENVIRONMENT=".venv-rocm" + LOCK_BACKUP="" + if [[ -f uv.lock ]]; then + LOCK_BACKUP="$(mktemp)" + cp uv.lock "$LOCK_BACKUP" + fi + uv sync --no-sources + echo "--> Installing PyTorch ROCm 6.2 wheel..." + uv pip install --python "$UV_PROJECT_ENVIRONMENT" --reinstall-package torch torch --index-url https://download.pytorch.org/whl/rocm6.2 # uv pip ignores UV_PROJECT_ENVIRONMENT; target the venv explicitly and force-replace the PyPI torch + if [[ -n "$LOCK_BACKUP" ]]; then + cp "$LOCK_BACKUP" uv.lock + rm -f "$LOCK_BACKUP" + fi else echo "--> Syncing CUDA Python environment (.venv)..." 
uv sync @@ -166,6 +187,10 @@ if $CPU_MODE; then # VIBEPOD_COMPILE=1 torch.compile hot paths (ineffective for autoregressive # models on CPU — not recommended, kept for experimentation) UV_RUN_ARGS=(--no-sync --no-sources) +elif $ROCM_MODE; then + export VIBEPOD_DEVICE="rocm" + export UV_PROJECT_ENVIRONMENT=".venv-rocm" + UV_RUN_ARGS=(--no-sync --no-sources) else export VIBEPOD_DEVICE="cuda" UV_RUN_ARGS=() diff --git a/server/vibevoice_server.py b/server/vibevoice_server.py index b5f7880..4862c7a 100644 --- a/server/vibevoice_server.py +++ b/server/vibevoice_server.py @@ -2,7 +2,7 @@ VibePod — VibeVoice FastAPI TTS Server This server provides a high-performance Text-to-Speech (TTS) interface for the VibeVoice model, -optimized for real-time streaming on both CPU and NVIDIA GPU hardware. +optimized for real-time streaming on CPU, NVIDIA GPU (CUDA), and AMD GPU (ROCm) hardware. MAINTAINER GUIDE / FILE MAP: - Device & Env Configuration: Helpers for hardware detection and runtime tuning via env vars. @@ -13,12 +13,12 @@ MAINTAINER GUIDE / FILE MAP: - Audio Streaming: Async bridge (NonBlockingAudioStreamer) between inference and the network. RUNTIME CONFIGURATION (Environment Variables): -- VIBEPOD_DEVICE: 'cpu' or 'cuda' (auto-detected if unset). +- VIBEPOD_DEVICE: 'cpu', 'cuda', or 'rocm' (auto-detected if unset). - VIBEPOD_CHUNK_ACCUM: Number of 20ms audio chunks to buffer before sending an SSE event (default: 4 for CPU). - VIBEPOD_PREBUFFER_SECS: Initial client-side buffer duration (hinted to frontend). - VIBEPOD_REBUFFER_THRESHOLD_SECS: Buffer level below which the client pauses to refill. - VIBEPOD_RESUME_THRESHOLD_SECS: Buffer level at which the client resumes playback. -- VIBEPOD_DEFAULT_INFERENCE_STEPS: Default DDPM steps (default: 8 for CPU, 10 for CUDA). +- VIBEPOD_DEFAULT_INFERENCE_STEPS: Default DDPM steps (default: 8 for CPU, 10 for CUDA/ROCm). - VIBEPOD_PROFILE_GENERATION: Set to '1' to enable detailed performance logging. 
CPU-SPECIFIC OPTIMIZATIONS: @@ -30,9 +30,9 @@ CPU-SPECIFIC OPTIMIZATIONS: - VIBEPOD_QUANTIZE: Set to '1' to enable experimental dynamic INT8 quantization. - VIBEPOD_COMPILE: Set to '1' to enable experimental torch.compile (limited benefit for TTS). -CUDA-SPECIFIC OPTIMIZATIONS: -- VIBEPOD_CUDA_DTYPE: 'bf16' (default) or 'fp16'. -- VIBEPOD_ATTN_IMPL: 'auto', 'sdpa', 'eager', or 'flash_attention_2'. +CUDA/ROCM OPTIMIZATIONS: +- VIBEPOD_CUDA_DTYPE: 'bf16' (default) or 'fp16'. Applies to both CUDA and ROCm. +- VIBEPOD_ATTN_IMPL: 'auto', 'sdpa', 'eager', or 'flash_attention_2'. ROCm always uses sdpa. """ import asyncio @@ -93,19 +93,41 @@ _decode_executor: concurrent.futures.ThreadPoolExecutor | None = None def _resolve_device() -> str: """ - Resolve the target device (CPU or CUDA) by checking the VIBEPOD_DEVICE environment - variable, falling back to CUDA if available, otherwise CPU. + Resolve the target device (cpu, cuda, or rocm) by checking the VIBEPOD_DEVICE + environment variable, falling back to auto-detection. + + ROCm presents AMD GPUs as CUDA devices in PyTorch but sets torch.version.hip. + Auto-detection returns 'rocm' when hip is detected, 'cuda' for NVIDIA, else 'cpu'. + The string 'rocm' is used for logging/config; all PyTorch calls use 'cuda' as the + actual device string (see _torch_device()). """ env = os.environ.get("VIBEPOD_DEVICE", "").strip().lower() - if env in ("cpu", "cuda"): + if env in ("cpu", "cuda", "rocm"): if env == "cuda" and not torch.cuda.is_available(): logger.warning( "VIBEPOD_DEVICE=cuda requested but CUDA is not available — falling back to CPU." ) return "cpu" + if env == "rocm" and not ( + torch.cuda.is_available() and getattr(torch.version, "hip", None) + ): + logger.warning( + "VIBEPOD_DEVICE=rocm requested but ROCm is not available — falling back to CPU." 
+ ) + return "cpu" return env - # Auto-detect - return "cuda" if torch.cuda.is_available() else "cpu" + # Auto-detect: ROCm sets torch.version.hip; NVIDIA leaves it None + if torch.cuda.is_available(): + return "rocm" if getattr(torch.version, "hip", None) else "cuda" + return "cpu" + + +def _torch_device(device: str) -> str: + """Map the VibePod device string to the PyTorch device string. + + ROCm GPUs are addressed as 'cuda' in all PyTorch APIs. + """ + return "cuda" if device == "rocm" else device # ── Env-var helpers ───────────────────────────────────────────────────────────── @@ -284,8 +306,10 @@ def _init_model(device: str): Load the VibeVoice model with appropriate precision (BF16/FP16/FP32) and apply VibePod-specific performance optimizations. """ + torch_device = _torch_device(device) + is_gpu = device in ("cuda", "rocm") logger.info("Loading model on %s...", device) - if device == "cuda": + if is_gpu: torch.set_float32_matmul_precision("high") torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True @@ -300,14 +324,18 @@ def _init_model(device: str): torch.backends.cuda.mem_efficient_sdp_enabled(), torch.backends.cuda.math_sdp_enabled(), ) + if device == "rocm": + logger.info( + "ROCm/HIP backend detected (torch.version.hip=%s)", torch.version.hip + ) elif device == "cpu": torch.set_float32_matmul_precision("medium") logger.info("CPU runtime configuration: %s", _configure_cpu_runtime()) cuda_dtype = os.environ.get("VIBEPOD_CUDA_DTYPE", "bf16").lower() - if device == "cuda" and cuda_dtype == "fp16": + if is_gpu and cuda_dtype == "fp16": load_dtype = torch.float16 - elif device == "cuda": + elif is_gpu: load_dtype = torch.bfloat16 else: cpu_bf16_env = os.environ.get("VIBEPOD_CPU_BF16", "auto").lower() @@ -328,7 +356,11 @@ def _init_model(device: str): logger.info("Loading model weights with dtype %s", load_dtype) requested_attn_impl = os.environ.get("VIBEPOD_ATTN_IMPL", "auto").lower() has_flash_attn = 
importlib.util.find_spec("flash_attn") is not None - if requested_attn_impl in {"eager", "sdpa"}: + if device == "rocm": + # flash_attn is not reliably supported on ROCm; always use sdpa. + attn_impl = "sdpa" + logger.info("ROCm: using sdpa attention (flash_attn not supported on ROCm).") + elif requested_attn_impl in {"eager", "sdpa"}: attn_impl = requested_attn_impl elif requested_attn_impl == "flash_attention_2": attn_impl = "flash_attention_2" if has_flash_attn else "sdpa" @@ -348,7 +380,7 @@ def _init_model(device: str): model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( MODEL_ID, torch_dtype=load_dtype, - device_map=device, + device_map=torch_device, attn_implementation=attn_impl, ) except Exception as exc: @@ -362,7 +394,7 @@ def _init_model(device: str): model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( MODEL_ID, torch_dtype=load_dtype, - device_map=device, + device_map=torch_device, attn_implementation="sdpa", ) @@ -672,11 +704,12 @@ def _load_voice_presets(device: str) -> dict[str, object]: """ Load all pre-downloaded voice tensor files (.pt) from the voices directory. """ + torch_device = _torch_device(device) presets = {} for name, filename in EN_VOICES.items(): path = VOICES_DIR / filename if path.exists(): - presets[name] = torch.load(path, map_location=device, weights_only=False) + presets[name] = torch.load(path, map_location=torch_device, weights_only=False) return presets