mirror of
https://github.com/JezzWTF/vibepod.git
synced 2026-06-13 03:58:07 +00:00
perf: improve streaming generation pipeline
Add CUDA inference hot-path optimizations, safer attention fallback handling, and generation profiling hooks. Improve SSE streaming, browser buffering telemetry, and playback recovery while preserving default audio quality settings.
This commit is contained in:
@@ -7,6 +7,11 @@
|
||||
# ./start.sh — CUDA mode (default, uses PyTorch CUDA 12.4 wheel, venv: .venv)
|
||||
# ./start.sh --cpu — CPU-only mode (uses PyPI CPU torch wheel, venv: .venv-cpu)
|
||||
#
|
||||
# Optional CUDA acceleration:
|
||||
# VIBEPOD_ENABLE_FLASH_ATTN=1 ./start.sh
|
||||
# Installs a matching third-party Windows flash-attn wheel when the CUDA venv
|
||||
# uses Python 3.12, torch 2.6.0, and CUDA 12.4.
|
||||
#
|
||||
# The two modes maintain completely separate virtual environments so their torch
|
||||
# installations never conflict. UV_PROJECT_ENVIRONMENT tells uv which venv to use;
|
||||
# --no-sources skips [tool.uv.sources] so the CPU run pulls the default PyPI torch wheel.
|
||||
@@ -51,6 +56,19 @@ if ! command -v uv &>/dev/null; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check whether flash-attn is usable in the active uv environment.
#
# Runs a Python one-liner through `uv run` that imports flash_attn, triton,
# and transformers.modeling_utils. All output (stdout and stderr) is discarded
# so callers can use this purely as a yes/no probe.
#
# Arguments: none
# Returns:   the exit status of `uv run` (0 when the import one-liner succeeds).
validate_flash_attn() {
    uv run python -c "import flash_attn; import triton; import transformers.modeling_utils" &>/dev/null
}
|
||||
|
||||
# Remove flash-attn from the active uv environment when it is installed but
# fails the validate_flash_attn import check.
#
# Steps:
#   1. Ask Python (importlib.util.find_spec) whether a flash_attn package is
#      present at all; if not, there is nothing to clean up.
#   2. If it is present and validate_flash_attn succeeds, keep it.
#   3. Otherwise announce the removal and run `uv pip uninstall flash-attn`.
#
# The uninstall is best-effort: flash-attn is an optional accelerator, so a
# failed uninstall only produces a warning on stderr instead of becoming this
# function's exit status (which would abort the launcher if the script runs
# under errexit -- NOTE(review): shell options are not visible in this hunk).
#
# Arguments: none
# Returns:   0 always.
remove_broken_flash_attn() {
    # Guard 1: package not installed -> nothing to do.
    if ! uv run python -c "import importlib.util; raise SystemExit(0 if importlib.util.find_spec('flash_attn') else 1)" &>/dev/null; then
        return 0
    fi

    # Guard 2: installed and importable -> leave it alone.
    if validate_flash_attn; then
        return 0
    fi

    echo " Installed flash-attn is not usable in this environment; removing it."
    if ! uv pip uninstall flash-attn; then
        echo " Warning: could not uninstall flash-attn; continuing without removing it." >&2
    fi
    return 0
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Sync Python environment
|
||||
# CPU mode: use .venv-cpu and skip [tool.uv.sources] so uv pulls the
|
||||
@@ -65,6 +83,36 @@ if $CPU_MODE; then
|
||||
else
|
||||
echo "--> Syncing CUDA Python environment (.venv)..."
|
||||
uv sync
|
||||
|
||||
remove_broken_flash_attn
|
||||
|
||||
# Optional FlashAttention install, enabled with VIBEPOD_ENABLE_FLASH_ATTN=1.
#
# If flash-attn is not already importable, probe the environment's Python tag,
# torch version, CUDA version and platform, and install the one known prebuilt
# wheel only when all four match it. Every failure path falls back to PyTorch
# SDPA attention instead of stopping the launcher, because this step is
# optional.
if [[ "${VIBEPOD_ENABLE_FLASH_ATTN:-0}" == "1" ]]; then
    echo ""
    echo "--> Checking optional FlashAttention wheel..."

    if validate_flash_attn; then
        echo " flash-attn already installed and importable."
    else
        # Gather all tags with a single interpreter launch. `torch.version.cuda`
        # is None on builds without CUDA, so substitute a placeholder rather
        # than letting `.replace` raise. A failed probe leaves the tags empty,
        # which simply routes to the "no known wheel" branch below.
        FLASH_ATTN_PROBE="$(uv run python -c "import sys, sysconfig, torch; print(f'cp{sys.version_info.major}{sys.version_info.minor}', torch.__version__.split('+', 1)[0], 'cu' + (torch.version.cuda or 'none').replace('.', ''), sysconfig.get_platform())")" || FLASH_ATTN_PROBE=""
        # Drop carriage returns in case the interpreter emitted CRLF line
        # endings (harmless when the shell already strips them), then keep only
        # the last line of output.
        FLASH_ATTN_PROBE="${FLASH_ATTN_PROBE//$'\r'/}"
        FLASH_ATTN_PROBE="${FLASH_ATTN_PROBE##*$'\n'}"
        read -r PY_TAG TORCH_TAG CUDA_TAG PLATFORM_TAG <<<"$FLASH_ATTN_PROBE" || true

        # The only known wheel is tagged cp312 / win_amd64 and built for
        # torch 2.6.0 + CUDA 12.4, so require the platform to match as well.
        if [[ "$PY_TAG" == "cp312" && "$TORCH_TAG" == "2.6.0" && "$CUDA_TAG" == "cu124" && "$PLATFORM_TAG" == "win-amd64" ]]; then
            FLASH_ATTN_WHEEL_URL="https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4%2Bcu124torch2.6.0cxx11abiFALSE-cp312-cp312-win_amd64.whl"
            echo " Installing flash-attn for Python 3.12, torch 2.6.0, CUDA 12.4..."
            if ! uv pip install "$FLASH_ATTN_WHEEL_URL"; then
                echo " flash-attn install failed; continuing with SDPA."
            elif validate_flash_attn; then
                echo " flash-attn import check passed."
            else
                echo " flash-attn import check failed; removing it and continuing with SDPA."
                # Best-effort cleanup: a failed uninstall must not stop startup.
                if ! uv pip uninstall flash-attn; then
                    echo " Warning: could not uninstall flash-attn; continuing without removing it." >&2
                fi
            fi
        else
            echo " No known wheel for Python tag ${PY_TAG:-unknown}, torch ${TORCH_TAG:-unknown}, CUDA ${CUDA_TAG:-unknown}, platform ${PLATFORM_TAG:-unknown}."
            echo " Continuing with PyTorch SDPA attention."
        fi
    fi
fi
|
||||
fi
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user