style: apply ruff formatting and lint fixes to server

This commit is contained in:
2026-05-01 19:06:13 +01:00
parent acb615b918
commit 8d4b3f3af7
2 changed files with 49 additions and 74 deletions
+2 -5
View File
@@ -30,15 +30,12 @@ def download() -> str:
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
except ImportError: except ImportError:
print( print(
"ERROR: huggingface_hub is not installed.\n" "ERROR: huggingface_hub is not installed.\nRun: pip install huggingface_hub",
"Run: pip install huggingface_hub",
file=sys.stderr, file=sys.stderr,
) )
sys.exit(1) sys.exit(1)
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get( token: str | None = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
"HUGGINGFACE_TOKEN"
)
print(f"Checking / downloading model: {MODEL_ID}") print(f"Checking / downloading model: {MODEL_ID}")
print("(This may take several minutes on first run — the model is ~1 GB)") print("(This may take several minutes on first run — the model is ~1 GB)")
+47 -69
View File
@@ -32,9 +32,10 @@ import threading
import time import time
import types import types
import urllib.request import urllib.request
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
from typing import AsyncGenerator, Literal, Optional from typing import Literal
import torch import torch
from fastapi import FastAPI, HTTPException, Request from fastapi import FastAPI, HTTPException, Request
@@ -50,8 +51,7 @@ SAMPLE_RATE = 24_000
VOICES_DIR = Path(__file__).parent / "voices" / "streaming_model" VOICES_DIR = Path(__file__).parent / "voices" / "streaming_model"
VOICE_BASE_URL = ( VOICE_BASE_URL = (
"https://raw.githubusercontent.com/microsoft/VibeVoice/main" "https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model"
"/demo/voices/streaming_model"
) )
EN_VOICES: dict[str, str] = { EN_VOICES: dict[str, str] = {
@@ -69,7 +69,7 @@ _IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.o
# ── Pipeline executor ────────────────────────────────────────────────────────── # ── Pipeline executor ──────────────────────────────────────────────────────────
# Overlaps acoustic_decode with forward_tts_lm on a background thread (1 worker). # Overlaps acoustic_decode with forward_tts_lm on a background thread (1 worker).
_decode_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None _decode_executor: concurrent.futures.ThreadPoolExecutor | None = None
# ── Device selection ──────────────────────────────────────────────────────────── # ── Device selection ────────────────────────────────────────────────────────────
# VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag. # VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag.
@@ -126,9 +126,7 @@ def _cpu_supports_bf16() -> bool:
def _configure_cpu_runtime() -> dict[str, object]: def _configure_cpu_runtime() -> dict[str, object]:
logical_cpus = os.cpu_count() or 1 logical_cpus = os.cpu_count() or 1
default_threads = ( default_threads = max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus
max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus
)
intra_threads = _env_int("VIBEPOD_CPU_THREADS", default_threads) intra_threads = _env_int("VIBEPOD_CPU_THREADS", default_threads)
interop_threads = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1) interop_threads = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1)
mkldnn_enabled = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0" mkldnn_enabled = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"
@@ -157,7 +155,7 @@ _processor = None
_model = None _model = None
_device: str = "cpu" _device: str = "cpu"
_model_status: ModelStatus = "loading" _model_status: ModelStatus = "loading"
_model_error: Optional[str] = None _model_error: str | None = None
_voice_presets: dict[str, object] = {} _voice_presets: dict[str, object] = {}
_load_lock = threading.Lock() _load_lock = threading.Lock()
_generation_lock = asyncio.Lock() _generation_lock = asyncio.Lock()
@@ -204,9 +202,7 @@ def _is_model_cached() -> bool:
try: try:
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
snapshot_download( snapshot_download(MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS)
MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS
)
return True return True
except Exception: except Exception:
return False return False
@@ -215,9 +211,7 @@ def _is_model_cached() -> bool:
def _download_model() -> None: def _download_model() -> None:
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
token: Optional[str] = os.environ.get("HF_TOKEN") or os.environ.get( token: str | None = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
"HUGGINGFACE_TOKEN"
)
DlTqdm = _make_dl_tqdm() DlTqdm = _make_dl_tqdm()
logger.info("Model not cached — downloading %s...", MODEL_ID) logger.info("Model not cached — downloading %s...", MODEL_ID)
snapshot_download( snapshot_download(
@@ -231,7 +225,7 @@ def _download_model() -> None:
def _download_voices() -> None: def _download_voices() -> None:
VOICES_DIR.mkdir(parents=True, exist_ok=True) VOICES_DIR.mkdir(parents=True, exist_ok=True)
for name, filename in EN_VOICES.items(): for _name, filename in EN_VOICES.items():
dest = VOICES_DIR / filename dest = VOICES_DIR / filename
if not dest.exists(): if not dest.exists():
url = f"{VOICE_BASE_URL}/{filename}" url = f"{VOICE_BASE_URL}/{filename}"
@@ -393,12 +387,8 @@ def _apply_cpu_optimizations(model: object) -> object:
if hasattr(model, "model"): if hasattr(model, "model"):
inner = model.model inner = model.model
if hasattr(inner, "prediction_head"): if hasattr(inner, "prediction_head"):
_compile_targets.append( _compile_targets.append(("prediction_head", inner, "prediction_head", False))
("prediction_head", inner, "prediction_head", False) if hasattr(inner, "acoustic_tokenizer") and hasattr(inner.acoustic_tokenizer, "decode"):
)
if hasattr(inner, "acoustic_tokenizer") and hasattr(
inner.acoustic_tokenizer, "decode"
):
_compile_targets.append( _compile_targets.append(
("acoustic_tokenizer.decode", inner.acoustic_tokenizer, "decode", False) ("acoustic_tokenizer.decode", inner.acoustic_tokenizer, "decode", False)
) )
@@ -493,9 +483,7 @@ def _install_generation_optimizations(model: object) -> None:
t_batches = t_batch_cache.get(t_batch_cache_key) t_batches = t_batch_cache.get(t_batch_cache_key)
if t_batches is None or len(t_batches) != len(scheduler.timesteps): if t_batches is None or len(t_batches) != len(scheduler.timesteps):
t_batches = [ t_batches = [
t.repeat(condition.shape[0]).to( t.repeat(condition.shape[0]).to(device=condition.device, dtype=condition.dtype)
device=condition.device, dtype=condition.dtype
)
for t in scheduler.timesteps for t in scheduler.timesteps
] ]
t_batch_cache[t_batch_cache_key] = t_batches t_batch_cache[t_batch_cache_key] = t_batches
@@ -599,12 +587,8 @@ def _move_cached_prompt(value: object, device: str, dtype: torch.dtype) -> objec
if isinstance(value, tuple): if isinstance(value, tuple):
return tuple(_move_cached_prompt(v, device, dtype) for v in value) return tuple(_move_cached_prompt(v, device, dtype) for v in value)
if hasattr(value, "key_cache") and hasattr(value, "value_cache"): if hasattr(value, "key_cache") and hasattr(value, "value_cache"):
value.key_cache = [ value.key_cache = [_move_cached_prompt(t, device, dtype) for t in value.key_cache]
_move_cached_prompt(t, device, dtype) for t in value.key_cache value.value_cache = [_move_cached_prompt(t, device, dtype) for t in value.value_cache]
]
value.value_cache = [
_move_cached_prompt(t, device, dtype) for t in value.value_cache
]
return value return value
@@ -640,33 +624,33 @@ def _load_model_sync() -> None:
is_cpu = _device == "cpu" is_cpu = _device == "cpu"
_config["device"] = _device _config["device"] = _device
_config["chunk_accum"] = _env_int("VIBEPOD_CHUNK_ACCUM", 4 if is_cpu else 1) _config["chunk_accum"] = _env_int("VIBEPOD_CHUNK_ACCUM", 4 if is_cpu else 1)
_config["prebuffer_secs"] = _env_float("VIBEPOD_PREBUFFER_SECS", 24.0 if is_cpu else 5.0) _config["prebuffer_secs"] = _env_float(
_config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 2.0 if is_cpu else 1.0) "VIBEPOD_PREBUFFER_SECS", 24.0 if is_cpu else 5.0
_config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 12.0 if is_cpu else 3.0) )
_config["default_inference_steps"] = _env_int("VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10) _config["rebuffer_threshold_secs"] = _env_float(
"VIBEPOD_REBUFFER_THRESHOLD_SECS", 2.0 if is_cpu else 1.0
)
_config["resume_threshold_secs"] = _env_float(
"VIBEPOD_RESUME_THRESHOLD_SECS", 12.0 if is_cpu else 3.0
)
_config["default_inference_steps"] = _env_int(
"VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10
)
if is_cpu: if is_cpu:
logical_cpus = os.cpu_count() or 1 logical_cpus = os.cpu_count() or 1
_config["cpu_threads"] = _env_int( _config["cpu_threads"] = _env_int(
"VIBEPOD_CPU_THREADS", "VIBEPOD_CPU_THREADS",
max(1, logical_cpus // 2) max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus,
if platform.system() == "Windows"
else logical_cpus,
) )
_config["cpu_interop_threads"] = _env_int( _config["cpu_interop_threads"] = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1)
"VIBEPOD_CPU_INTEROP_THREADS", 1 _config["cpu_mkldnn"] = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"
)
_config["cpu_mkldnn"] = os.environ.get(
"VIBEPOD_CPU_MKLDNN", "1"
).strip() != "0"
_processor = _init_processor() _processor = _init_processor()
_model = _init_model(_device) _model = _init_model(_device)
_voice_presets = _load_voice_presets(_device) _voice_presets = _load_voice_presets(_device)
_model_status = "online" _model_status = "online"
logger.info( logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys()))
"Model ready on %s. Voices: %s", _device, list(_voice_presets.keys())
)
logger.info("Configuration: %s", _config) logger.info("Configuration: %s", _config)
except Exception as exc: except Exception as exc:
@@ -697,7 +681,7 @@ class GenerateRequest(BaseModel):
text: str = Field(..., min_length=1, max_length=10_000) text: str = Field(..., min_length=1, max_length=10_000)
speaker: str = Field(default=DEFAULT_SPEAKER) speaker: str = Field(default=DEFAULT_SPEAKER)
cfg_scale: float = Field(default=1.5, ge=0.5, le=4.0) cfg_scale: float = Field(default=1.5, ge=0.5, le=4.0)
inference_steps: Optional[int] = Field(default=None, ge=5, le=20) inference_steps: int | None = Field(default=None, ge=5, le=20)
@field_validator("text") @field_validator("text")
@classmethod @classmethod
@@ -736,8 +720,8 @@ async def health() -> dict:
def _sync_generate( def _sync_generate(
req: GenerateRequest, req: GenerateRequest,
streamer: Optional[object] = None, streamer: object | None = None,
cancel_event: Optional[threading.Event] = None, cancel_event: threading.Event | None = None,
) -> str: ) -> str:
"""Blocking inference. Returns the speaker used. """Blocking inference. Returns the speaker used.
Runs in a thread-pool executor — do not call from the event loop directly. Runs in a thread-pool executor — do not call from the event loop directly.
@@ -748,11 +732,13 @@ def _sync_generate(
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
model_dtype = _model_float_dtype() model_dtype = _model_float_dtype()
voice_preset = _move_cached_prompt( voice_preset = _move_cached_prompt(copy.deepcopy(_voice_presets[speaker]), _device, model_dtype)
copy.deepcopy(_voice_presets[speaker]), _device, model_dtype
)
steps = req.inference_steps if req.inference_steps is not None else _config["default_inference_steps"] steps = (
req.inference_steps
if req.inference_steps is not None
else _config["default_inference_steps"]
)
_model.set_ddpm_inference_steps(num_steps=steps) _model.set_ddpm_inference_steps(num_steps=steps)
if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1": if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1":
_model._vibepod_profile = {} _model._vibepod_profile = {}
@@ -790,7 +776,7 @@ def _sse(event: dict) -> str:
return f"data: {json.dumps(event)}\n\n" return f"data: {json.dumps(event)}\n\n"
def _generation_profile() -> Optional[dict[str, dict[str, float]]]: def _generation_profile() -> dict[str, dict[str, float]] | None:
if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") != "1": if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") != "1":
return None return None
stats = getattr(_model, "_vibepod_profile", None) stats = getattr(_model, "_vibepod_profile", None)
@@ -800,9 +786,7 @@ def _generation_profile() -> Optional[dict[str, dict[str, float]]]:
key: { key: {
"count": value["count"], "count": value["count"],
"seconds": round(value["seconds"], 3), "seconds": round(value["seconds"], 3),
"avg_ms": round(value["seconds"] * 1000 / value["count"], 3) "avg_ms": round(value["seconds"] * 1000 / value["count"], 3) if value["count"] else 0.0,
if value["count"]
else 0.0,
} }
for key, value in sorted(stats.items()) for key, value in sorted(stats.items())
} }
@@ -843,13 +827,11 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
audio_chunks[i].detach(), audio_chunks[i].detach(),
) )
def end(self, sample_indices: Optional[torch.Tensor] = None) -> None: def end(self, sample_indices: torch.Tensor | None = None) -> None:
if sample_indices is None: if sample_indices is None:
indices_to_end = range(self.batch_size) indices_to_end = range(self.batch_size)
else: else:
indices_to_end = [ indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices]
s.item() if torch.is_tensor(s) else s for s in sample_indices
]
for idx in indices_to_end: for idx in indices_to_end:
if idx < self.batch_size and not self.finished_flags[idx]: if idx < self.batch_size and not self.finished_flags[idx]:
self.loop.call_soon_threadsafe( self.loop.call_soon_threadsafe(
@@ -865,8 +847,8 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
accumulated_chunks = [] accumulated_chunks = []
chunk_count = 0 chunk_count = 0
audio_samples = 0 audio_samples = 0
first_chunk_at: Optional[float] = None first_chunk_at: float | None = None
last_chunk_at: Optional[float] = None last_chunk_at: float | None = None
max_chunk_gap = 0.0 max_chunk_gap = 0.0
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
@@ -881,9 +863,7 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
# stop_signal=None is the default sentinel that ends the queue. # stop_signal=None is the default sentinel that ends the queue.
while True: while True:
try: try:
chunk = await asyncio.wait_for( chunk = await asyncio.wait_for(streamer.audio_queues[0].get(), timeout=120.0)
streamer.audio_queues[0].get(), timeout=120.0
)
except asyncio.TimeoutError: except asyncio.TimeoutError:
cancel_event.set() cancel_event.set()
future.cancel() future.cancel()
@@ -969,9 +949,7 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
"elapsed": elapsed, "elapsed": elapsed,
"speaker": speaker, "speaker": speaker,
"audio_secs": round(audio_secs, 2), "audio_secs": round(audio_secs, 2),
"realtime_factor": round(realtime_factor, 3) "realtime_factor": round(realtime_factor, 3) if realtime_factor is not None else None,
if realtime_factor is not None
else None,
"chunks": chunk_count, "chunks": chunk_count,
"first_chunk_secs": round(first_chunk_at - start, 2) "first_chunk_secs": round(first_chunk_at - start, 2)
if first_chunk_at is not None if first_chunk_at is not None