mirror of
https://github.com/JezzWTF/vibepod.git
synced 2026-06-01 15:22:14 +00:00
feat: honour VIBEPOD_DEVICE env var for CPU/CUDA device selection
This commit is contained in:
@@ -11,6 +11,11 @@ Startup sequence (background thread):
|
|||||||
Generation flow:
|
Generation flow:
|
||||||
POST /generate -> SSE stream of audio_chunk events (base64 float32 PCM),
|
POST /generate -> SSE stream of audio_chunk events (base64 float32 PCM),
|
||||||
ends with {type:"complete"}
|
ends with {type:"complete"}
|
||||||
|
|
||||||
|
Device selection:
|
||||||
|
Set VIBEPOD_DEVICE=cpu to force CPU inference (e.g. via --cpu flag in start.sh).
|
||||||
|
Set VIBEPOD_DEVICE=cuda to force CUDA (default when a GPU is available).
|
||||||
|
If unset, the server auto-detects: CUDA if available, otherwise CPU.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -57,6 +62,24 @@ DEFAULT_SPEAKER = "carter"
|
|||||||
|
|
||||||
_IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.ot"]
|
_IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.ot"]
|
||||||
|
|
||||||
|
# ── Device selection ────────────────────────────────────────────────────────────
|
||||||
|
# VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag.
|
||||||
|
# Falls back to auto-detection if not set.
|
||||||
|
|
||||||
|
def _resolve_device() -> str:
|
||||||
|
"""Resolve the target device from env var or auto-detect."""
|
||||||
|
env = os.environ.get("VIBEPOD_DEVICE", "").strip().lower()
|
||||||
|
if env in ("cpu", "cuda"):
|
||||||
|
if env == "cuda" and not torch.cuda.is_available():
|
||||||
|
logger.warning(
|
||||||
|
"VIBEPOD_DEVICE=cuda requested but CUDA is not available — falling back to CPU."
|
||||||
|
)
|
||||||
|
return "cpu"
|
||||||
|
return env
|
||||||
|
# Auto-detect
|
||||||
|
return "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
|
||||||
# ── Global state ────────────────────────────────────────────────────────────────
|
# ── Global state ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
ModelStatus = Literal["downloading", "loading", "online", "error"]
|
ModelStatus = Literal["downloading", "loading", "online", "error"]
|
||||||
@@ -147,7 +170,10 @@ def _load_model_sync() -> None:
|
|||||||
_model_status = "loading"
|
_model_status = "loading"
|
||||||
_download_voices()
|
_download_voices()
|
||||||
|
|
||||||
_device = "cuda" if torch.cuda.is_available() else "cpu"
|
# Resolve device from env var (set by start.sh --cpu/--cuda) or auto-detect.
|
||||||
|
_device = _resolve_device()
|
||||||
|
logger.info("Using device: %s", _device)
|
||||||
|
|
||||||
load_dtype = torch.bfloat16 if _device == "cuda" else torch.float32
|
load_dtype = torch.bfloat16 if _device == "cuda" else torch.float32
|
||||||
attn_impl = "flash_attention_2" if _device == "cuda" else "sdpa"
|
attn_impl = "flash_attention_2" if _device == "cuda" else "sdpa"
|
||||||
|
|
||||||
@@ -236,6 +262,7 @@ async def health() -> dict:
|
|||||||
body: dict = {
|
body: dict = {
|
||||||
"status": _model_status,
|
"status": _model_status,
|
||||||
"model": MODEL_ID,
|
"model": MODEL_ID,
|
||||||
|
"device": _device,
|
||||||
"voices": list(_voice_presets.keys()),
|
"voices": list(_voice_presets.keys()),
|
||||||
}
|
}
|
||||||
if _model_status == "downloading":
|
if _model_status == "downloading":
|
||||||
@@ -367,4 +394,3 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
|||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user