diff --git a/server/start.sh b/server/start.sh index 060cde5..c50277d 100755 --- a/server/start.sh +++ b/server/start.sh @@ -9,8 +9,8 @@ # # Optional CUDA acceleration: # VIBEPOD_ENABLE_FLASH_ATTN=1 ./start.sh -# Installs a matching third-party Windows flash-attn wheel when the CUDA venv -# uses Python 3.12, torch 2.6.0, and CUDA 12.4. +# Installs a pre-built flash-attn wheel when the CUDA venv uses Python 3.12, +# torch 2.6.0, and CUDA 12.4 on Windows. Other platforms fall back to SDPA. # # The two modes maintain completely separate virtual environments so their torch # installations never conflict. UV_PROJECT_ENVIRONMENT tells uv which venv to use; @@ -106,20 +106,35 @@ else TORCH_TAG="$(uv run python -c "import torch; print(torch.__version__.split('+', 1)[0])")" CUDA_TAG="$(uv run python -c "import torch; print('cu' + torch.version.cuda.replace('.', ''))")" + FLASH_ATTN_WHEEL_URL="" if [[ "$PY_TAG" == "cp312" && "$TORCH_TAG" == "2.6.0" && "$CUDA_TAG" == "cu124" ]]; then - FLASH_ATTN_WHEEL_URL="https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4%2Bcu124torch2.6.0cxx11abiFALSE-cp312-cp312-win_amd64.whl" - echo " Installing flash-attn for Python 3.12, torch 2.6.0, CUDA 12.4..." - uv pip install "$FLASH_ATTN_WHEEL_URL" - if validate_flash_attn; then - echo " flash-attn import check passed." - else - echo " flash-attn import check failed; removing it and continuing with SDPA." - uv pip uninstall flash-attn - fi + case "$(uname -s)" in + MINGW*|CYGWIN*|MSYS*) + FLASH_ATTN_WHEEL_URL="https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4%2Bcu124torch2.6.0cxx11abiFALSE-cp312-cp312-win_amd64.whl" + echo " Installing flash-attn for Python 3.12, torch 2.6.0, CUDA 12.4 (Windows)..." + ;; + *) + echo " No pre-built flash-attn wheel available for this platform ($(uname -s))." + echo " Continuing with PyTorch SDPA attention." + ;; + esac else echo " No known wheel for Python tag $PY_TAG, torch $TORCH_TAG, CUDA $CUDA_TAG." echo " Continuing with PyTorch SDPA attention." fi + + if [[ -n "$FLASH_ATTN_WHEEL_URL" ]]; then + if uv pip install "$FLASH_ATTN_WHEEL_URL"; then + if validate_flash_attn; then + echo " flash-attn import check passed." + else + echo " flash-attn import check failed; removing it and continuing with SDPA." + uv pip uninstall flash-attn + fi + else + echo " flash-attn wheel install failed; continuing with SDPA." + fi + fi fi fi fi