feat: honour VIBEPOD_DEVICE env var for CPU/CUDA device selection

This commit is contained in:
2026-04-28 14:22:38 +01:00
parent 64cf431c2a
commit c8110ccdde
+28 -2
View File
@@ -11,6 +11,11 @@ Startup sequence (background thread):
Generation flow: Generation flow:
POST /generate -> SSE stream of audio_chunk events (base64 float32 PCM), POST /generate -> SSE stream of audio_chunk events (base64 float32 PCM),
ends with {type:"complete"} ends with {type:"complete"}
Device selection:
Set VIBEPOD_DEVICE=cpu to force CPU inference (e.g. via --cpu flag in start.sh).
Set VIBEPOD_DEVICE=cuda to request CUDA (the default when a GPU is available); if CUDA is not available, the server logs a warning and falls back to CPU.
If unset, the server auto-detects: CUDA if available, otherwise CPU.
""" """
import asyncio import asyncio
@@ -57,6 +62,24 @@ DEFAULT_SPEAKER = "carter"
_IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.ot"] _IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.ot"]
# ── Device selection ────────────────────────────────────────────────────────────
# VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag.
# Falls back to auto-detection if not set.
def _resolve_device() -> str:
    """Resolve the inference device from ``VIBEPOD_DEVICE`` or auto-detect.

    The environment variable is compared case-insensitively after stripping
    surrounding whitespace.

    Returns:
        ``"cpu"`` or ``"cuda"``:

        * ``VIBEPOD_DEVICE=cpu``  -> ``"cpu"`` (CUDA is never probed).
        * ``VIBEPOD_DEVICE=cuda`` -> ``"cuda"`` if CUDA is available,
          otherwise ``"cpu"`` after logging a warning.
        * unset / empty           -> ``"cuda"`` if available, else ``"cpu"``.
        * any other value         -> same as unset, after logging a warning
          so a typo is not silently ignored.
    """
    requested = os.environ.get("VIBEPOD_DEVICE", "").strip().lower()

    # Explicit CPU request: return early without touching the CUDA runtime.
    if requested == "cpu":
        return "cpu"

    if requested == "cuda":
        if torch.cuda.is_available():
            return "cuda"
        logger.warning(
            "VIBEPOD_DEVICE=cuda requested but CUDA is not available — falling back to CPU."
        )
        return "cpu"

    # A non-empty value that is neither "cpu" nor "cuda" is most likely a
    # typo; tell the operator before falling back to auto-detection.
    if requested:
        logger.warning(
            "Ignoring unrecognised VIBEPOD_DEVICE=%r (expected 'cpu' or 'cuda'); "
            "auto-detecting device.",
            requested,
        )

    # Auto-detect
    return "cuda" if torch.cuda.is_available() else "cpu"
# ── Global state ──────────────────────────────────────────────────────────────── # ── Global state ────────────────────────────────────────────────────────────────
ModelStatus = Literal["downloading", "loading", "online", "error"] ModelStatus = Literal["downloading", "loading", "online", "error"]
@@ -147,7 +170,10 @@ def _load_model_sync() -> None:
_model_status = "loading" _model_status = "loading"
_download_voices() _download_voices()
_device = "cuda" if torch.cuda.is_available() else "cpu" # Resolve device from env var (set by start.sh --cpu/--cuda) or auto-detect.
_device = _resolve_device()
logger.info("Using device: %s", _device)
load_dtype = torch.bfloat16 if _device == "cuda" else torch.float32 load_dtype = torch.bfloat16 if _device == "cuda" else torch.float32
attn_impl = "flash_attention_2" if _device == "cuda" else "sdpa" attn_impl = "flash_attention_2" if _device == "cuda" else "sdpa"
@@ -236,6 +262,7 @@ async def health() -> dict:
body: dict = { body: dict = {
"status": _model_status, "status": _model_status,
"model": MODEL_ID, "model": MODEL_ID,
"device": _device,
"voices": list(_voice_presets.keys()), "voices": list(_voice_presets.keys()),
} }
if _model_status == "downloading": if _model_status == "downloading":
@@ -367,4 +394,3 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
media_type="text/event-stream", media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
) )