mirror of
https://github.com/JezzWTF/vibepod.git
synced 2026-06-13 03:58:07 +00:00
Add AMD ROCm GPU support
Introduces a third hardware mode alongside CUDA and CPU: ROCm (AMD GPU). AMD GPUs present as CUDA devices under PyTorch ROCm, so the existing GPU path is reused with minimal changes — the main additions are wheel management, device detection, and suppressing flash_attn (unsupported on ROCm). - server/vibevoice_server.py: extend _resolve_device() to recognise 'rocm' (auto-detected via torch.version.hip); add _torch_device() helper that maps 'rocm' → 'cuda' for all PyTorch API calls; apply GPU optimisations for both cuda and rocm in _init_model(); always use sdpa on ROCm; propagate _torch_device() to _load_voice_presets() map_location. - server/start.sh: add --rocm flag; sync .venv-rocm with uv sync --no-sources then replace torch with the ROCm 6.2 wheel via uv pip install; set VIBEPOD_DEVICE=rocm for uvicorn. - server/pyproject.toml: register pytorch-rocm62 index (explicit); add .venv-rocm to ruff excludes. - package.json: add dev:rocm and dev:server:rocm scripts. - README.md: document ROCm mode, prerequisites (RX 6000+, ROCm 6.2+, Linux), and new commands; expand CUDA vs CPU section to CUDA vs CPU vs ROCm. https://claude.ai/code/session_0168pSswiaoEf6LEx6UQWfBu
This commit is contained in:
@@ -32,7 +32,7 @@ dev = [
|
||||
line-length = 100
|
||||
indent-width = 4
|
||||
target-version = "py310"
|
||||
exclude = [".git", ".venv", ".venv-cpu", "__pycache__"]
|
||||
exclude = [".git", ".venv", ".venv-cpu", ".venv-rocm", "__pycache__"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "UP", "B", "SIM", "I"]
|
||||
@@ -54,6 +54,11 @@ name = "pytorch-cu124"
|
||||
url = "https://download.pytorch.org/whl/cu124"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-rocm62"
|
||||
url = "https://download.pytorch.org/whl/rocm6.2"
|
||||
explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
# Pull torch from the PyTorch CUDA 12.4 index instead of PyPI's CPU-only wheel.
|
||||
# CUDA 12.4 runs on any driver >= 525.60 (RTX 30/40 series all qualify).
|
||||
|
||||
Reference in New Issue
Block a user