From 0e17c60bdc49c348ce00297a11436d3b5dbe3f63 Mon Sep 17 00:00:00 2001 From: LyAhn Date: Fri, 1 May 2026 19:09:33 +0100 Subject: [PATCH] fix(server): guard flash-attn install behind OS platform check The win_amd64 wheel URL was attempted on any OS with matching Python/ torch/CUDA tags. On Linux CUDA setups with VIBEPOD_ENABLE_FLASH_ATTN=1 this caused `uv pip install` to fail with an incompatible wheel; with set -e the script then exited before launch instead of falling back to SDPA. - Add uname -s case statement inside the version-tag match: only set the wheel URL on MINGW*/CYGWIN*/MSYS* (Windows/Git Bash); all other platforms print a clear message and leave FLASH_ATTN_WHEEL_URL empty - Move the install step into a separate `if [[ -n "$FLASH_ATTN_WHEEL_URL" ]]` block so non-Windows platforms skip it entirely - Wrap `uv pip install` in an `if` so a wheel failure is non-fatal and falls through to SDPA regardless of set -e - Update header comment to reflect cross-platform behaviour --- server/start.sh | 37 ++++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/server/start.sh b/server/start.sh index 060cde5..c50277d 100755 --- a/server/start.sh +++ b/server/start.sh @@ -9,8 +9,8 @@ # # Optional CUDA acceleration: # VIBEPOD_ENABLE_FLASH_ATTN=1 ./start.sh -# Installs a matching third-party Windows flash-attn wheel when the CUDA venv -# uses Python 3.12, torch 2.6.0, and CUDA 12.4. +# Installs a pre-built flash-attn wheel when the CUDA venv uses Python 3.12, +# torch 2.6.0, and CUDA 12.4 on Windows. Other platforms fall back to SDPA. # # The two modes maintain completely separate virtual environments so their torch # installations never conflict. UV_PROJECT_ENVIRONMENT tells uv which venv to use; @@ -106,20 +106,35 @@ else TORCH_TAG="$(uv run python -c "import torch; print(torch.__version__.split('+', 1)[0])")" CUDA_TAG="$(uv run python -c "import torch; print('cu' + torch.version.cuda.replace('.', ''))")" + FLASH_ATTN_WHEEL_URL="" if [[ "$PY_TAG" == "cp312" && "$TORCH_TAG" == "2.6.0" && "$CUDA_TAG" == "cu124" ]]; then - FLASH_ATTN_WHEEL_URL="https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4%2Bcu124torch2.6.0cxx11abiFALSE-cp312-cp312-win_amd64.whl" - echo " Installing flash-attn for Python 3.12, torch 2.6.0, CUDA 12.4..." - uv pip install "$FLASH_ATTN_WHEEL_URL" - if validate_flash_attn; then - echo " flash-attn import check passed." - else - echo " flash-attn import check failed; removing it and continuing with SDPA." - uv pip uninstall flash-attn - fi + case "$(uname -s)" in + MINGW*|CYGWIN*|MSYS*) + FLASH_ATTN_WHEEL_URL="https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4%2Bcu124torch2.6.0cxx11abiFALSE-cp312-cp312-win_amd64.whl" + echo " Installing flash-attn for Python 3.12, torch 2.6.0, CUDA 12.4 (Windows)..." + ;; + *) + echo " No pre-built flash-attn wheel available for this platform ($(uname -s))." + echo " Continuing with PyTorch SDPA attention." + ;; + esac else echo " No known wheel for Python tag $PY_TAG, torch $TORCH_TAG, CUDA $CUDA_TAG." echo " Continuing with PyTorch SDPA attention." fi + + if [[ -n "$FLASH_ATTN_WHEEL_URL" ]]; then + if uv pip install "$FLASH_ATTN_WHEEL_URL"; then + if validate_flash_attn; then + echo " flash-attn import check passed." + else + echo " flash-attn import check failed; removing it and continuing with SDPA." + uv pip uninstall flash-attn + fi + else + echo " flash-attn wheel install failed; continuing with SDPA." + fi + fi fi fi fi