fix(server): guard flash-attn install behind OS platform check

The win_amd64 wheel URL was attempted on any OS with matching Python/
torch/CUDA tags. On Linux CUDA setups with VIBEPOD_ENABLE_FLASH_ATTN=1
this caused `uv pip install` to fail with an incompatible wheel; with
set -e the script then exited before launch instead of falling back to
SDPA.

- Add uname -s case statement inside the version-tag match: only set
  the wheel URL on MINGW*/CYGWIN*/MSYS* (Windows/Git Bash); all other
  platforms print a clear message and leave FLASH_ATTN_WHEEL_URL empty
- Move the install step into a separate `if [[ -n "$FLASH_ATTN_WHEEL_URL" ]]`
  block so non-Windows platforms skip it entirely
- Wrap `uv pip install` in an `if` so a wheel failure is non-fatal and
  falls through to SDPA regardless of set -e
- Update header comment to reflect cross-platform behaviour
This commit is contained in:
2026-05-01 19:09:33 +01:00
parent 8d4b3f3af7
commit 0e17c60bdc
+26 -11
View File
@@ -9,8 +9,8 @@
# #
# Optional CUDA acceleration: # Optional CUDA acceleration:
# VIBEPOD_ENABLE_FLASH_ATTN=1 ./start.sh # VIBEPOD_ENABLE_FLASH_ATTN=1 ./start.sh
# Installs a matching third-party Windows flash-attn wheel when the CUDA venv # Installs a pre-built flash-attn wheel when the CUDA venv uses Python 3.12,
# uses Python 3.12, torch 2.6.0, and CUDA 12.4. # torch 2.6.0, and CUDA 12.4 on Windows. Other platforms fall back to SDPA.
# #
# The two modes maintain completely separate virtual environments so their torch # The two modes maintain completely separate virtual environments so their torch
# installations never conflict. UV_PROJECT_ENVIRONMENT tells uv which venv to use; # installations never conflict. UV_PROJECT_ENVIRONMENT tells uv which venv to use;
@@ -106,20 +106,35 @@ else
TORCH_TAG="$(uv run python -c "import torch; print(torch.__version__.split('+', 1)[0])")" TORCH_TAG="$(uv run python -c "import torch; print(torch.__version__.split('+', 1)[0])")"
CUDA_TAG="$(uv run python -c "import torch; print('cu' + torch.version.cuda.replace('.', ''))")" CUDA_TAG="$(uv run python -c "import torch; print('cu' + torch.version.cuda.replace('.', ''))")"
FLASH_ATTN_WHEEL_URL=""
if [[ "$PY_TAG" == "cp312" && "$TORCH_TAG" == "2.6.0" && "$CUDA_TAG" == "cu124" ]]; then if [[ "$PY_TAG" == "cp312" && "$TORCH_TAG" == "2.6.0" && "$CUDA_TAG" == "cu124" ]]; then
FLASH_ATTN_WHEEL_URL="https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4%2Bcu124torch2.6.0cxx11abiFALSE-cp312-cp312-win_amd64.whl" case "$(uname -s)" in
echo " Installing flash-attn for Python 3.12, torch 2.6.0, CUDA 12.4..." MINGW*|CYGWIN*|MSYS*)
uv pip install "$FLASH_ATTN_WHEEL_URL" FLASH_ATTN_WHEEL_URL="https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4%2Bcu124torch2.6.0cxx11abiFALSE-cp312-cp312-win_amd64.whl"
if validate_flash_attn; then echo " Installing flash-attn for Python 3.12, torch 2.6.0, CUDA 12.4 (Windows)..."
echo " flash-attn import check passed." ;;
else *)
echo " flash-attn import check failed; removing it and continuing with SDPA." echo " No pre-built flash-attn wheel available for this platform ($(uname -s))."
uv pip uninstall flash-attn echo " Continuing with PyTorch SDPA attention."
fi ;;
esac
else else
echo " No known wheel for Python tag $PY_TAG, torch $TORCH_TAG, CUDA $CUDA_TAG." echo " No known wheel for Python tag $PY_TAG, torch $TORCH_TAG, CUDA $CUDA_TAG."
echo " Continuing with PyTorch SDPA attention." echo " Continuing with PyTorch SDPA attention."
fi fi
if [[ -n "$FLASH_ATTN_WHEEL_URL" ]]; then
if uv pip install "$FLASH_ATTN_WHEEL_URL"; then
if validate_flash_attn; then
echo " flash-attn import check passed."
else
echo " flash-attn import check failed; removing it and continuing with SDPA."
uv pip uninstall flash-attn
fi
else
echo " flash-attn wheel install failed; continuing with SDPA."
fi
fi
fi fi
fi fi
fi fi