diff --git a/server/download_model.py b/server/download_model.py index 1377d17..369a0fb 100644 --- a/server/download_model.py +++ b/server/download_model.py @@ -30,15 +30,12 @@ def download() -> str: from huggingface_hub import snapshot_download except ImportError: print( - "ERROR: huggingface_hub is not installed.\n" - "Run: pip install huggingface_hub", + "ERROR: huggingface_hub is not installed.\nRun: pip install huggingface_hub", file=sys.stderr, ) sys.exit(1) - token: str | None = os.environ.get("HF_TOKEN") or os.environ.get( - "HUGGINGFACE_TOKEN" - ) + token: str | None = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN") print(f"Checking / downloading model: {MODEL_ID}") print("(This may take several minutes on first run — the model is ~1 GB)") diff --git a/server/vibevoice_server.py b/server/vibevoice_server.py index 74fbd94..0c32e6c 100644 --- a/server/vibevoice_server.py +++ b/server/vibevoice_server.py @@ -32,9 +32,10 @@ import threading import time import types import urllib.request +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from pathlib import Path -from typing import AsyncGenerator, Literal, Optional +from typing import Literal import torch from fastapi import FastAPI, HTTPException, Request @@ -50,8 +51,7 @@ SAMPLE_RATE = 24_000 VOICES_DIR = Path(__file__).parent / "voices" / "streaming_model" VOICE_BASE_URL = ( - "https://raw.githubusercontent.com/microsoft/VibeVoice/main" - "/demo/voices/streaming_model" + "https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model" ) EN_VOICES: dict[str, str] = { @@ -69,7 +69,7 @@ _IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.o # ── Pipeline executor ────────────────────────────────────────────────────────── # Overlaps acoustic_decode with forward_tts_lm on a background thread (1 worker). -_decode_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None +_decode_executor: concurrent.futures.ThreadPoolExecutor | None = None # ── Device selection ──────────────────────────────────────────────────────────── # VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag. @@ -126,9 +126,7 @@ def _cpu_supports_bf16() -> bool: def _configure_cpu_runtime() -> dict[str, object]: logical_cpus = os.cpu_count() or 1 - default_threads = ( - max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus - ) + default_threads = max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus intra_threads = _env_int("VIBEPOD_CPU_THREADS", default_threads) interop_threads = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1) mkldnn_enabled = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0" @@ -157,7 +155,7 @@ _processor = None _model = None _device: str = "cpu" _model_status: ModelStatus = "loading" -_model_error: Optional[str] = None +_model_error: str | None = None _voice_presets: dict[str, object] = {} _load_lock = threading.Lock() _generation_lock = asyncio.Lock() @@ -204,9 +202,7 @@ def _is_model_cached() -> bool: try: from huggingface_hub import snapshot_download - snapshot_download( - MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS - ) + snapshot_download(MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS) return True except Exception: return False @@ -215,9 +211,7 @@ def _is_model_cached() -> bool: def _download_model() -> None: from huggingface_hub import snapshot_download - token: Optional[str] = os.environ.get("HF_TOKEN") or os.environ.get( - "HUGGINGFACE_TOKEN" - ) + token: str | None = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN") DlTqdm = _make_dl_tqdm() logger.info("Model not cached — downloading %s...", MODEL_ID) snapshot_download( @@ -231,7 +225,7 @@ def _download_model() -> None: def _download_voices() -> None: VOICES_DIR.mkdir(parents=True, exist_ok=True) - for name, filename in EN_VOICES.items(): + for _name, filename in EN_VOICES.items(): dest = VOICES_DIR / filename if not dest.exists(): url = f"{VOICE_BASE_URL}/{filename}" @@ -393,12 +387,8 @@ def _apply_cpu_optimizations(model: object) -> object: if hasattr(model, "model"): inner = model.model if hasattr(inner, "prediction_head"): - _compile_targets.append( - ("prediction_head", inner, "prediction_head", False) - ) - if hasattr(inner, "acoustic_tokenizer") and hasattr( - inner.acoustic_tokenizer, "decode" - ): + _compile_targets.append(("prediction_head", inner, "prediction_head", False)) + if hasattr(inner, "acoustic_tokenizer") and hasattr(inner.acoustic_tokenizer, "decode"): _compile_targets.append( ("acoustic_tokenizer.decode", inner.acoustic_tokenizer, "decode", False) ) @@ -493,9 +483,7 @@ def _install_generation_optimizations(model: object) -> None: t_batches = t_batch_cache.get(t_batch_cache_key) if t_batches is None or len(t_batches) != len(scheduler.timesteps): t_batches = [ - t.repeat(condition.shape[0]).to( - device=condition.device, dtype=condition.dtype - ) + t.repeat(condition.shape[0]).to(device=condition.device, dtype=condition.dtype) for t in scheduler.timesteps ] t_batch_cache[t_batch_cache_key] = t_batches @@ -599,12 +587,8 @@ def _move_cached_prompt(value: object, device: str, dtype: torch.dtype) -> objec if isinstance(value, tuple): return tuple(_move_cached_prompt(v, device, dtype) for v in value) if hasattr(value, "key_cache") and hasattr(value, "value_cache"): - value.key_cache = [ - _move_cached_prompt(t, device, dtype) for t in value.key_cache - ] - value.value_cache = [ - _move_cached_prompt(t, device, dtype) for t in value.value_cache - ] + value.key_cache = [_move_cached_prompt(t, device, dtype) for t in value.key_cache] + value.value_cache = [_move_cached_prompt(t, device, dtype) for t in value.value_cache] return value @@ -640,33 +624,33 @@ def _load_model_sync() -> None: is_cpu = _device == "cpu" _config["device"] = _device _config["chunk_accum"] = _env_int("VIBEPOD_CHUNK_ACCUM", 4 if is_cpu else 1) - _config["prebuffer_secs"] = _env_float("VIBEPOD_PREBUFFER_SECS", 24.0 if is_cpu else 5.0) - _config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 2.0 if is_cpu else 1.0) - _config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 12.0 if is_cpu else 3.0) - _config["default_inference_steps"] = _env_int("VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10) + _config["prebuffer_secs"] = _env_float( + "VIBEPOD_PREBUFFER_SECS", 24.0 if is_cpu else 5.0 + ) + _config["rebuffer_threshold_secs"] = _env_float( + "VIBEPOD_REBUFFER_THRESHOLD_SECS", 2.0 if is_cpu else 1.0 + ) + _config["resume_threshold_secs"] = _env_float( + "VIBEPOD_RESUME_THRESHOLD_SECS", 12.0 if is_cpu else 3.0 + ) + _config["default_inference_steps"] = _env_int( + "VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10 + ) if is_cpu: logical_cpus = os.cpu_count() or 1 _config["cpu_threads"] = _env_int( "VIBEPOD_CPU_THREADS", - max(1, logical_cpus // 2) - if platform.system() == "Windows" - else logical_cpus, + max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus, ) - _config["cpu_interop_threads"] = _env_int( - "VIBEPOD_CPU_INTEROP_THREADS", 1 - ) - _config["cpu_mkldnn"] = os.environ.get( - "VIBEPOD_CPU_MKLDNN", "1" - ).strip() != "0" + _config["cpu_interop_threads"] = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1) + _config["cpu_mkldnn"] = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0" _processor = _init_processor() _model = _init_model(_device) _voice_presets = _load_voice_presets(_device) _model_status = "online" - logger.info( - "Model ready on %s. Voices: %s", _device, list(_voice_presets.keys()) - ) + logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys())) logger.info("Configuration: %s", _config) except Exception as exc: @@ -697,7 +681,7 @@ class GenerateRequest(BaseModel): text: str = Field(..., min_length=1, max_length=10_000) speaker: str = Field(default=DEFAULT_SPEAKER) cfg_scale: float = Field(default=1.5, ge=0.5, le=4.0) - inference_steps: Optional[int] = Field(default=None, ge=5, le=20) + inference_steps: int | None = Field(default=None, ge=5, le=20) @field_validator("text") @classmethod @@ -736,8 +720,8 @@ async def health() -> dict: def _sync_generate( req: GenerateRequest, - streamer: Optional[object] = None, - cancel_event: Optional[threading.Event] = None, + streamer: object | None = None, + cancel_event: threading.Event | None = None, ) -> str: """Blocking inference. Returns the speaker used. Runs in a thread-pool executor — do not call from the event loop directly. @@ -748,11 +732,13 @@ def _sync_generate( speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER model_dtype = _model_float_dtype() - voice_preset = _move_cached_prompt( - copy.deepcopy(_voice_presets[speaker]), _device, model_dtype - ) + voice_preset = _move_cached_prompt(copy.deepcopy(_voice_presets[speaker]), _device, model_dtype) - steps = req.inference_steps if req.inference_steps is not None else _config["default_inference_steps"] + steps = ( + req.inference_steps + if req.inference_steps is not None + else _config["default_inference_steps"] + ) _model.set_ddpm_inference_steps(num_steps=steps) if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1": _model._vibepod_profile = {} @@ -790,7 +776,7 @@ def _sse(event: dict) -> str: return f"data: {json.dumps(event)}\n\n" -def _generation_profile() -> Optional[dict[str, dict[str, float]]]: +def _generation_profile() -> dict[str, dict[str, float]] | None: if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") != "1": return None stats = getattr(_model, "_vibepod_profile", None) @@ -800,9 +786,7 @@ def _generation_profile() -> Optional[dict[str, dict[str, float]]]: key: { "count": value["count"], "seconds": round(value["seconds"], 3), - "avg_ms": round(value["seconds"] * 1000 / value["count"], 3) - if value["count"] - else 0.0, + "avg_ms": round(value["seconds"] * 1000 / value["count"], 3) if value["count"] else 0.0, } for key, value in sorted(stats.items()) } @@ -843,13 +827,11 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: audio_chunks[i].detach(), ) - def end(self, sample_indices: Optional[torch.Tensor] = None) -> None: + def end(self, sample_indices: torch.Tensor | None = None) -> None: if sample_indices is None: indices_to_end = range(self.batch_size) else: - indices_to_end = [ - s.item() if torch.is_tensor(s) else s for s in sample_indices - ] + indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices] for idx in indices_to_end: if idx < self.batch_size and not self.finished_flags[idx]: self.loop.call_soon_threadsafe( @@ -865,8 +847,8 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: accumulated_chunks = [] chunk_count = 0 audio_samples = 0 - first_chunk_at: Optional[float] = None - last_chunk_at: Optional[float] = None + first_chunk_at: float | None = None + last_chunk_at: float | None = None max_chunk_gap = 0.0 speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER @@ -881,9 +863,7 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: # stop_signal=None is the default sentinel that ends the queue. while True: try: - chunk = await asyncio.wait_for( - streamer.audio_queues[0].get(), timeout=120.0 - ) + chunk = await asyncio.wait_for(streamer.audio_queues[0].get(), timeout=120.0) except asyncio.TimeoutError: cancel_event.set() future.cancel() @@ -969,9 +949,7 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse: "elapsed": elapsed, "speaker": speaker, "audio_secs": round(audio_secs, 2), - "realtime_factor": round(realtime_factor, 3) - if realtime_factor is not None - else None, + "realtime_factor": round(realtime_factor, 3) if realtime_factor is not None else None, "chunks": chunk_count, "first_chunk_secs": round(first_chunk_at - start, 2) if first_chunk_at is not None