Add AMD ROCm GPU support

Introduces a third hardware mode alongside CUDA and CPU: ROCm (AMD GPU).
AMD GPUs present as CUDA devices under PyTorch ROCm, so the existing GPU
path is reused with minimal changes — the main additions are wheel management,
device detection, and forcing sdpa attention on ROCm (flash_attn is not
reliably supported there, so VIBEPOD_ATTN_IMPL is ignored in ROCm mode).

- server/vibevoice_server.py: extend _resolve_device() to recognise 'rocm'
  (auto-detected via torch.version.hip); add _torch_device() helper that maps
  'rocm' → 'cuda' for all PyTorch API calls; apply GPU optimisations for both
  cuda and rocm in _init_model(); always use sdpa on ROCm; propagate
  _torch_device() to _load_voice_presets() map_location.
- server/start.sh: add --rocm flag; sync .venv-rocm with `uv sync --no-sources`,
  then replace torch with the ROCm 6.2 wheel via `uv pip install`; set
  VIBEPOD_DEVICE=rocm for uvicorn.
- server/pyproject.toml: register pytorch-rocm62 index (explicit); add
  .venv-rocm to ruff excludes.
- package.json: add dev:rocm and dev:server:rocm scripts.
- README.md: document ROCm mode, prerequisites (RX 6000+, ROCm 6.2+, Linux),
  and new commands; expand CUDA vs CPU section to CUDA vs CPU vs ROCm.

https://claude.ai/code/session_0168pSswiaoEf6LEx6UQWfBu
This commit is contained in:
Claude
2026-05-04 01:54:57 +00:00
parent f4d759c385
commit bb6da662de
5 changed files with 102 additions and 31 deletions
+51 -18
View File
@@ -2,7 +2,7 @@
VibePod — VibeVoice FastAPI TTS Server
This server provides a high-performance Text-to-Speech (TTS) interface for the VibeVoice model,
optimized for real-time streaming on both CPU and NVIDIA GPU hardware.
optimized for real-time streaming on CPU, NVIDIA GPU (CUDA), and AMD GPU (ROCm) hardware.
MAINTAINER GUIDE / FILE MAP:
- Device & Env Configuration: Helpers for hardware detection and runtime tuning via env vars.
@@ -13,12 +13,12 @@ MAINTAINER GUIDE / FILE MAP:
- Audio Streaming: Async bridge (NonBlockingAudioStreamer) between inference and the network.
RUNTIME CONFIGURATION (Environment Variables):
- VIBEPOD_DEVICE: 'cpu' or 'cuda' (auto-detected if unset).
- VIBEPOD_DEVICE: 'cpu', 'cuda', or 'rocm' (auto-detected if unset).
- VIBEPOD_CHUNK_ACCUM: Number of 20ms audio chunks to buffer before sending an SSE event (default: 4 for CPU).
- VIBEPOD_PREBUFFER_SECS: Initial client-side buffer duration (hinted to frontend).
- VIBEPOD_REBUFFER_THRESHOLD_SECS: Buffer level below which the client pauses to refill.
- VIBEPOD_RESUME_THRESHOLD_SECS: Buffer level at which the client resumes playback.
- VIBEPOD_DEFAULT_INFERENCE_STEPS: Default DDPM steps (default: 8 for CPU, 10 for CUDA).
- VIBEPOD_DEFAULT_INFERENCE_STEPS: Default DDPM steps (default: 8 for CPU, 10 for CUDA/ROCm).
- VIBEPOD_PROFILE_GENERATION: Set to '1' to enable detailed performance logging.
CPU-SPECIFIC OPTIMIZATIONS:
@@ -30,9 +30,9 @@ CPU-SPECIFIC OPTIMIZATIONS:
- VIBEPOD_QUANTIZE: Set to '1' to enable experimental dynamic INT8 quantization.
- VIBEPOD_COMPILE: Set to '1' to enable experimental torch.compile (limited benefit for TTS).
CUDA-SPECIFIC OPTIMIZATIONS:
- VIBEPOD_CUDA_DTYPE: 'bf16' (default) or 'fp16'.
- VIBEPOD_ATTN_IMPL: 'auto', 'sdpa', 'eager', or 'flash_attention_2'.
CUDA/ROCM OPTIMIZATIONS:
- VIBEPOD_CUDA_DTYPE: 'bf16' (default) or 'fp16'. Applies to both CUDA and ROCm.
- VIBEPOD_ATTN_IMPL: 'auto', 'sdpa', 'eager', or 'flash_attention_2'. ROCm always uses sdpa.
"""
import asyncio
@@ -93,19 +93,41 @@ _decode_executor: concurrent.futures.ThreadPoolExecutor | None = None
def _resolve_device() -> str:
"""
Resolve the target device (CPU or CUDA) by checking the VIBEPOD_DEVICE environment
variable, falling back to CUDA if available, otherwise CPU.
Resolve the target device (cpu, cuda, or rocm) by checking the VIBEPOD_DEVICE
environment variable, falling back to auto-detection.
ROCm presents AMD GPUs as CUDA devices in PyTorch but sets torch.version.hip.
Auto-detection returns 'rocm' when hip is detected, 'cuda' for NVIDIA, else 'cpu'.
The string 'rocm' is used for logging/config; all PyTorch calls use 'cuda' as the
actual device string (see _torch_device()).
"""
env = os.environ.get("VIBEPOD_DEVICE", "").strip().lower()
if env in ("cpu", "cuda"):
if env in ("cpu", "cuda", "rocm"):
if env == "cuda" and not torch.cuda.is_available():
logger.warning(
"VIBEPOD_DEVICE=cuda requested but CUDA is not available — falling back to CPU."
)
return "cpu"
if env == "rocm" and not (
torch.cuda.is_available() and getattr(torch.version, "hip", None)
):
logger.warning(
"VIBEPOD_DEVICE=rocm requested but ROCm is not available — falling back to CPU."
)
return "cpu"
return env
# Auto-detect
return "cuda" if torch.cuda.is_available() else "cpu"
# Auto-detect: ROCm sets torch.version.hip; NVIDIA leaves it None
if torch.cuda.is_available():
return "rocm" if getattr(torch.version, "hip", None) else "cuda"
return "cpu"
def _torch_device(device: str) -> str:
"""Map the VibePod device string to the PyTorch device string.
ROCm GPUs are addressed as 'cuda' in all PyTorch APIs.
"""
return "cuda" if device == "rocm" else device
# ── Env-var helpers ─────────────────────────────────────────────────────────────
@@ -284,8 +306,10 @@ def _init_model(device: str):
Load the VibeVoice model with appropriate precision (BF16/FP16/FP32) and
apply VibePod-specific performance optimizations.
"""
torch_device = _torch_device(device)
is_gpu = device in ("cuda", "rocm")
logger.info("Loading model on %s...", device)
if device == "cuda":
if is_gpu:
torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
@@ -300,14 +324,18 @@ def _init_model(device: str):
torch.backends.cuda.mem_efficient_sdp_enabled(),
torch.backends.cuda.math_sdp_enabled(),
)
if device == "rocm":
logger.info(
"ROCm/HIP backend detected (torch.version.hip=%s)", torch.version.hip
)
elif device == "cpu":
torch.set_float32_matmul_precision("medium")
logger.info("CPU runtime configuration: %s", _configure_cpu_runtime())
cuda_dtype = os.environ.get("VIBEPOD_CUDA_DTYPE", "bf16").lower()
if device == "cuda" and cuda_dtype == "fp16":
if is_gpu and cuda_dtype == "fp16":
load_dtype = torch.float16
elif device == "cuda":
elif is_gpu:
load_dtype = torch.bfloat16
else:
cpu_bf16_env = os.environ.get("VIBEPOD_CPU_BF16", "auto").lower()
@@ -328,7 +356,11 @@ def _init_model(device: str):
logger.info("Loading model weights with dtype %s", load_dtype)
requested_attn_impl = os.environ.get("VIBEPOD_ATTN_IMPL", "auto").lower()
has_flash_attn = importlib.util.find_spec("flash_attn") is not None
if requested_attn_impl in {"eager", "sdpa"}:
if device == "rocm":
# flash_attn is not reliably supported on ROCm; always use sdpa.
attn_impl = "sdpa"
logger.info("ROCm: using sdpa attention (flash_attn not supported on ROCm).")
elif requested_attn_impl in {"eager", "sdpa"}:
attn_impl = requested_attn_impl
elif requested_attn_impl == "flash_attention_2":
attn_impl = "flash_attention_2" if has_flash_attn else "sdpa"
@@ -348,7 +380,7 @@ def _init_model(device: str):
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
MODEL_ID,
torch_dtype=load_dtype,
device_map=device,
device_map=torch_device,
attn_implementation=attn_impl,
)
except Exception as exc:
@@ -362,7 +394,7 @@ def _init_model(device: str):
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
MODEL_ID,
torch_dtype=load_dtype,
device_map=device,
device_map=torch_device,
attn_implementation="sdpa",
)
@@ -672,11 +704,12 @@ def _load_voice_presets(device: str) -> dict[str, object]:
"""
Load all pre-downloaded voice tensor files (.pt) from the voices directory.
"""
torch_device = _torch_device(device)
presets = {}
for name, filename in EN_VOICES.items():
path = VOICES_DIR / filename
if path.exists():
presets[name] = torch.load(path, map_location=device, weights_only=False)
presets[name] = torch.load(path, map_location=torch_device, weights_only=False)
return presets