mirror of
https://github.com/JezzWTF/vibepod.git
synced 2026-06-01 15:22:14 +00:00
fix(server): guard flash-attn install behind OS platform check
The win_amd64 wheel URL was attempted on any OS with matching Python/ torch/CUDA tags. On Linux CUDA setups with VIBEPOD_ENABLE_FLASH_ATTN=1 this caused `uv pip install` to fail with an incompatible wheel; with set -e the script then exited before launch instead of falling back to SDPA. - Add uname -s case statement inside the version-tag match: only set the wheel URL on MINGW*/CYGWIN*/MSYS* (Windows/Git Bash); all other platforms print a clear message and leave FLASH_ATTN_WHEEL_URL empty - Move the install step into a separate `if [[ -n "$FLASH_ATTN_WHEEL_URL" ]]` block so non-Windows platforms skip it entirely - Wrap `uv pip install` in an `if` so a wheel failure is non-fatal and falls through to SDPA regardless of set -e - Update header comment to reflect cross-platform behaviour
This commit is contained in:
+21
-6
@@ -9,8 +9,8 @@
|
||||
#
|
||||
# Optional CUDA acceleration:
|
||||
# VIBEPOD_ENABLE_FLASH_ATTN=1 ./start.sh
|
||||
# Installs a matching third-party Windows flash-attn wheel when the CUDA venv
|
||||
# uses Python 3.12, torch 2.6.0, and CUDA 12.4.
|
||||
# Installs a pre-built flash-attn wheel when the CUDA venv uses Python 3.12,
|
||||
# torch 2.6.0, and CUDA 12.4 on Windows. Other platforms fall back to SDPA.
|
||||
#
|
||||
# The two modes maintain completely separate virtual environments so their torch
|
||||
# installations never conflict. UV_PROJECT_ENVIRONMENT tells uv which venv to use;
|
||||
@@ -106,10 +106,25 @@ else
|
||||
TORCH_TAG="$(uv run python -c "import torch; print(torch.__version__.split('+', 1)[0])")"
|
||||
CUDA_TAG="$(uv run python -c "import torch; print('cu' + torch.version.cuda.replace('.', ''))")"
|
||||
|
||||
FLASH_ATTN_WHEEL_URL=""
|
||||
if [[ "$PY_TAG" == "cp312" && "$TORCH_TAG" == "2.6.0" && "$CUDA_TAG" == "cu124" ]]; then
|
||||
case "$(uname -s)" in
|
||||
MINGW*|CYGWIN*|MSYS*)
|
||||
FLASH_ATTN_WHEEL_URL="https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4%2Bcu124torch2.6.0cxx11abiFALSE-cp312-cp312-win_amd64.whl"
|
||||
echo " Installing flash-attn for Python 3.12, torch 2.6.0, CUDA 12.4..."
|
||||
uv pip install "$FLASH_ATTN_WHEEL_URL"
|
||||
echo " Installing flash-attn for Python 3.12, torch 2.6.0, CUDA 12.4 (Windows)..."
|
||||
;;
|
||||
*)
|
||||
echo " No pre-built flash-attn wheel available for this platform ($(uname -s))."
|
||||
echo " Continuing with PyTorch SDPA attention."
|
||||
;;
|
||||
esac
|
||||
else
|
||||
echo " No known wheel for Python tag $PY_TAG, torch $TORCH_TAG, CUDA $CUDA_TAG."
|
||||
echo " Continuing with PyTorch SDPA attention."
|
||||
fi
|
||||
|
||||
if [[ -n "$FLASH_ATTN_WHEEL_URL" ]]; then
|
||||
if uv pip install "$FLASH_ATTN_WHEEL_URL"; then
|
||||
if validate_flash_attn; then
|
||||
echo " flash-attn import check passed."
|
||||
else
|
||||
@@ -117,8 +132,8 @@ else
|
||||
uv pip uninstall flash-attn
|
||||
fi
|
||||
else
|
||||
echo " No known wheel for Python tag $PY_TAG, torch $TORCH_TAG, CUDA $CUDA_TAG."
|
||||
echo " Continuing with PyTorch SDPA attention."
|
||||
echo " flash-attn wheel install failed; continuing with SDPA."
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
Reference in New Issue
Block a user