mirror of
https://github.com/JezzWTF/vibepod.git
synced 2026-06-01 15:22:14 +00:00
fix(server): guard flash-attn install behind OS platform check
The win_amd64 wheel URL was attempted on any OS with matching Python/torch/CUDA tags. On Linux CUDA setups with VIBEPOD_ENABLE_FLASH_ATTN=1 this caused `uv pip install` to fail with an incompatible wheel; with `set -e` the script then exited before launch instead of falling back to SDPA.

- Add a `uname -s` case statement inside the version-tag match: only set the wheel URL on MINGW*/CYGWIN*/MSYS* (Windows/Git Bash); all other platforms print a clear message and leave FLASH_ATTN_WHEEL_URL empty
- Move the install step into a separate `if [[ -n "$FLASH_ATTN_WHEEL_URL" ]]` block so non-Windows platforms skip it entirely
- Wrap `uv pip install` in an `if` so a wheel failure is non-fatal and falls through to SDPA regardless of `set -e`
- Update header comment to reflect cross-platform behaviour
This commit is contained in:
+26
-11
@@ -9,8 +9,8 @@
|
|||||||
#
|
#
|
||||||
# Optional CUDA acceleration:
|
# Optional CUDA acceleration:
|
||||||
# VIBEPOD_ENABLE_FLASH_ATTN=1 ./start.sh
|
# VIBEPOD_ENABLE_FLASH_ATTN=1 ./start.sh
|
||||||
# Installs a matching third-party Windows flash-attn wheel when the CUDA venv
|
# Installs a pre-built flash-attn wheel when the CUDA venv uses Python 3.12,
|
||||||
# uses Python 3.12, torch 2.6.0, and CUDA 12.4.
|
# torch 2.6.0, and CUDA 12.4 on Windows. Other platforms fall back to SDPA.
|
||||||
#
|
#
|
||||||
# The two modes maintain completely separate virtual environments so their torch
|
# The two modes maintain completely separate virtual environments so their torch
|
||||||
# installations never conflict. UV_PROJECT_ENVIRONMENT tells uv which venv to use;
|
# installations never conflict. UV_PROJECT_ENVIRONMENT tells uv which venv to use;
|
||||||
@@ -106,20 +106,35 @@ else
|
|||||||
TORCH_TAG="$(uv run python -c "import torch; print(torch.__version__.split('+', 1)[0])")"
|
TORCH_TAG="$(uv run python -c "import torch; print(torch.__version__.split('+', 1)[0])")"
|
||||||
CUDA_TAG="$(uv run python -c "import torch; print('cu' + torch.version.cuda.replace('.', ''))")"
|
CUDA_TAG="$(uv run python -c "import torch; print('cu' + torch.version.cuda.replace('.', ''))")"
|
||||||
|
|
||||||
|
FLASH_ATTN_WHEEL_URL=""
|
||||||
if [[ "$PY_TAG" == "cp312" && "$TORCH_TAG" == "2.6.0" && "$CUDA_TAG" == "cu124" ]]; then
|
if [[ "$PY_TAG" == "cp312" && "$TORCH_TAG" == "2.6.0" && "$CUDA_TAG" == "cu124" ]]; then
|
||||||
FLASH_ATTN_WHEEL_URL="https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4%2Bcu124torch2.6.0cxx11abiFALSE-cp312-cp312-win_amd64.whl"
|
case "$(uname -s)" in
|
||||||
echo " Installing flash-attn for Python 3.12, torch 2.6.0, CUDA 12.4..."
|
MINGW*|CYGWIN*|MSYS*)
|
||||||
uv pip install "$FLASH_ATTN_WHEEL_URL"
|
FLASH_ATTN_WHEEL_URL="https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4%2Bcu124torch2.6.0cxx11abiFALSE-cp312-cp312-win_amd64.whl"
|
||||||
if validate_flash_attn; then
|
echo " Installing flash-attn for Python 3.12, torch 2.6.0, CUDA 12.4 (Windows)..."
|
||||||
echo " flash-attn import check passed."
|
;;
|
||||||
else
|
*)
|
||||||
echo " flash-attn import check failed; removing it and continuing with SDPA."
|
echo " No pre-built flash-attn wheel available for this platform ($(uname -s))."
|
||||||
uv pip uninstall flash-attn
|
echo " Continuing with PyTorch SDPA attention."
|
||||||
fi
|
;;
|
||||||
|
esac
|
||||||
else
|
else
|
||||||
echo " No known wheel for Python tag $PY_TAG, torch $TORCH_TAG, CUDA $CUDA_TAG."
|
echo " No known wheel for Python tag $PY_TAG, torch $TORCH_TAG, CUDA $CUDA_TAG."
|
||||||
echo " Continuing with PyTorch SDPA attention."
|
echo " Continuing with PyTorch SDPA attention."
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if [[ -n "$FLASH_ATTN_WHEEL_URL" ]]; then
|
||||||
|
if uv pip install "$FLASH_ATTN_WHEEL_URL"; then
|
||||||
|
if validate_flash_attn; then
|
||||||
|
echo " flash-attn import check passed."
|
||||||
|
else
|
||||||
|
echo " flash-attn import check failed; removing it and continuing with SDPA."
|
||||||
|
uv pip uninstall flash-attn
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo " flash-attn wheel install failed; continuing with SDPA."
|
||||||
|
fi
|
||||||
|
fi
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|||||||
Reference in New Issue
Block a user