diff --git a/server/vibevoice_server.py b/server/vibevoice_server.py index e005c57..d273933 100644 --- a/server/vibevoice_server.py +++ b/server/vibevoice_server.py @@ -11,6 +11,11 @@ Startup sequence (background thread): Generation flow: POST /generate -> SSE stream of audio_chunk events (base64 float32 PCM), ends with {type:"complete"} + +Device selection: + Set VIBEPOD_DEVICE=cpu to force CPU inference (e.g. via --cpu flag in start.sh). + Set VIBEPOD_DEVICE=cuda to force CUDA (default when a GPU is available). + If unset, the server auto-detects: CUDA if available, otherwise CPU. """ import asyncio @@ -57,6 +62,24 @@ DEFAULT_SPEAKER = "carter" _IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.ot"] +# ── Device selection ──────────────────────────────────────────────────────────── +# VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag. +# Falls back to auto-detection if not set. + +def _resolve_device() -> str: + """Resolve the target device from env var or auto-detect.""" + env = os.environ.get("VIBEPOD_DEVICE", "").strip().lower() + if env in ("cpu", "cuda"): + if env == "cuda" and not torch.cuda.is_available(): + logger.warning( + "VIBEPOD_DEVICE=cuda requested but CUDA is not available — falling back to CPU." + ) + return "cpu" + return env + # Auto-detect + return "cuda" if torch.cuda.is_available() else "cpu" + + # ── Global state ──────────────────────────────────────────────────────────────── ModelStatus = Literal["downloading", "loading", "online", "error"] @@ -147,7 +170,10 @@ def _load_model_sync() -> None: _model_status = "loading" _download_voices() - _device = "cuda" if torch.cuda.is_available() else "cpu" + # Resolve device from env var (set by start.sh --cpu/--cuda) or auto-detect. + _device = _resolve_device() + logger.info("Using device: %s", _device) + load_dtype = torch.bfloat16 if _device == "cuda" else torch.float32 attn_impl = "flash_attention_2" if _device == "cuda" else "sdpa" @@ -236,6 +262,7 @@ async def health() -> dict: body: dict = { "status": _model_status, "model": MODEL_ID, + "device": _device, "voices": list(_voice_presets.keys()), } if _model_status == "downloading": @@ -367,4 +394,3 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: media_type="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, ) -