style: apply ruff formatting and lint fixes to server

This commit is contained in:
2026-05-01 19:06:13 +01:00
parent acb615b918
commit 8d4b3f3af7
2 changed files with 49 additions and 74 deletions
+2 -5
View File
@@ -30,15 +30,12 @@ def download() -> str:
from huggingface_hub import snapshot_download
except ImportError:
print(
"ERROR: huggingface_hub is not installed.\n"
"Run: pip install huggingface_hub",
"ERROR: huggingface_hub is not installed.\nRun: pip install huggingface_hub",
file=sys.stderr,
)
sys.exit(1)
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get(
"HUGGINGFACE_TOKEN"
)
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
print(f"Checking / downloading model: {MODEL_ID}")
print("(This may take several minutes on first run — the model is ~1 GB)")
+47 -69
View File
@@ -32,9 +32,10 @@ import threading
import time
import types
import urllib.request
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from pathlib import Path
from typing import AsyncGenerator, Literal, Optional
from typing import Literal
import torch
from fastapi import FastAPI, HTTPException, Request
@@ -50,8 +51,7 @@ SAMPLE_RATE = 24_000
VOICES_DIR = Path(__file__).parent / "voices" / "streaming_model"
VOICE_BASE_URL = (
"https://raw.githubusercontent.com/microsoft/VibeVoice/main"
"/demo/voices/streaming_model"
"https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model"
)
EN_VOICES: dict[str, str] = {
@@ -69,7 +69,7 @@ _IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.o
# ── Pipeline executor ──────────────────────────────────────────────────────────
# Overlaps acoustic_decode with forward_tts_lm on a background thread (1 worker).
_decode_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None
_decode_executor: concurrent.futures.ThreadPoolExecutor | None = None
# ── Device selection ────────────────────────────────────────────────────────────
# VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag.
@@ -126,9 +126,7 @@ def _cpu_supports_bf16() -> bool:
def _configure_cpu_runtime() -> dict[str, object]:
logical_cpus = os.cpu_count() or 1
default_threads = (
max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus
)
default_threads = max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus
intra_threads = _env_int("VIBEPOD_CPU_THREADS", default_threads)
interop_threads = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1)
mkldnn_enabled = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"
@@ -157,7 +155,7 @@ _processor = None
_model = None
_device: str = "cpu"
_model_status: ModelStatus = "loading"
_model_error: Optional[str] = None
_model_error: str | None = None
_voice_presets: dict[str, object] = {}
_load_lock = threading.Lock()
_generation_lock = asyncio.Lock()
@@ -204,9 +202,7 @@ def _is_model_cached() -> bool:
try:
from huggingface_hub import snapshot_download
snapshot_download(
MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS
)
snapshot_download(MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS)
return True
except Exception:
return False
@@ -215,9 +211,7 @@ def _is_model_cached() -> bool:
def _download_model() -> None:
from huggingface_hub import snapshot_download
token: Optional[str] = os.environ.get("HF_TOKEN") or os.environ.get(
"HUGGINGFACE_TOKEN"
)
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
DlTqdm = _make_dl_tqdm()
logger.info("Model not cached — downloading %s...", MODEL_ID)
snapshot_download(
@@ -231,7 +225,7 @@ def _download_model() -> None:
def _download_voices() -> None:
VOICES_DIR.mkdir(parents=True, exist_ok=True)
for name, filename in EN_VOICES.items():
for _name, filename in EN_VOICES.items():
dest = VOICES_DIR / filename
if not dest.exists():
url = f"{VOICE_BASE_URL}/{filename}"
@@ -393,12 +387,8 @@ def _apply_cpu_optimizations(model: object) -> object:
if hasattr(model, "model"):
inner = model.model
if hasattr(inner, "prediction_head"):
_compile_targets.append(
("prediction_head", inner, "prediction_head", False)
)
if hasattr(inner, "acoustic_tokenizer") and hasattr(
inner.acoustic_tokenizer, "decode"
):
_compile_targets.append(("prediction_head", inner, "prediction_head", False))
if hasattr(inner, "acoustic_tokenizer") and hasattr(inner.acoustic_tokenizer, "decode"):
_compile_targets.append(
("acoustic_tokenizer.decode", inner.acoustic_tokenizer, "decode", False)
)
@@ -493,9 +483,7 @@ def _install_generation_optimizations(model: object) -> None:
t_batches = t_batch_cache.get(t_batch_cache_key)
if t_batches is None or len(t_batches) != len(scheduler.timesteps):
t_batches = [
t.repeat(condition.shape[0]).to(
device=condition.device, dtype=condition.dtype
)
t.repeat(condition.shape[0]).to(device=condition.device, dtype=condition.dtype)
for t in scheduler.timesteps
]
t_batch_cache[t_batch_cache_key] = t_batches
@@ -599,12 +587,8 @@ def _move_cached_prompt(value: object, device: str, dtype: torch.dtype) -> objec
if isinstance(value, tuple):
return tuple(_move_cached_prompt(v, device, dtype) for v in value)
if hasattr(value, "key_cache") and hasattr(value, "value_cache"):
value.key_cache = [
_move_cached_prompt(t, device, dtype) for t in value.key_cache
]
value.value_cache = [
_move_cached_prompt(t, device, dtype) for t in value.value_cache
]
value.key_cache = [_move_cached_prompt(t, device, dtype) for t in value.key_cache]
value.value_cache = [_move_cached_prompt(t, device, dtype) for t in value.value_cache]
return value
@@ -640,33 +624,33 @@ def _load_model_sync() -> None:
is_cpu = _device == "cpu"
_config["device"] = _device
_config["chunk_accum"] = _env_int("VIBEPOD_CHUNK_ACCUM", 4 if is_cpu else 1)
_config["prebuffer_secs"] = _env_float("VIBEPOD_PREBUFFER_SECS", 24.0 if is_cpu else 5.0)
_config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 2.0 if is_cpu else 1.0)
_config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 12.0 if is_cpu else 3.0)
_config["default_inference_steps"] = _env_int("VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10)
_config["prebuffer_secs"] = _env_float(
"VIBEPOD_PREBUFFER_SECS", 24.0 if is_cpu else 5.0
)
_config["rebuffer_threshold_secs"] = _env_float(
"VIBEPOD_REBUFFER_THRESHOLD_SECS", 2.0 if is_cpu else 1.0
)
_config["resume_threshold_secs"] = _env_float(
"VIBEPOD_RESUME_THRESHOLD_SECS", 12.0 if is_cpu else 3.0
)
_config["default_inference_steps"] = _env_int(
"VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10
)
if is_cpu:
logical_cpus = os.cpu_count() or 1
_config["cpu_threads"] = _env_int(
"VIBEPOD_CPU_THREADS",
max(1, logical_cpus // 2)
if platform.system() == "Windows"
else logical_cpus,
max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus,
)
_config["cpu_interop_threads"] = _env_int(
"VIBEPOD_CPU_INTEROP_THREADS", 1
)
_config["cpu_mkldnn"] = os.environ.get(
"VIBEPOD_CPU_MKLDNN", "1"
).strip() != "0"
_config["cpu_interop_threads"] = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1)
_config["cpu_mkldnn"] = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"
_processor = _init_processor()
_model = _init_model(_device)
_voice_presets = _load_voice_presets(_device)
_model_status = "online"
logger.info(
"Model ready on %s. Voices: %s", _device, list(_voice_presets.keys())
)
logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys()))
logger.info("Configuration: %s", _config)
except Exception as exc:
@@ -697,7 +681,7 @@ class GenerateRequest(BaseModel):
text: str = Field(..., min_length=1, max_length=10_000)
speaker: str = Field(default=DEFAULT_SPEAKER)
cfg_scale: float = Field(default=1.5, ge=0.5, le=4.0)
inference_steps: Optional[int] = Field(default=None, ge=5, le=20)
inference_steps: int | None = Field(default=None, ge=5, le=20)
@field_validator("text")
@classmethod
@@ -736,8 +720,8 @@ async def health() -> dict:
def _sync_generate(
req: GenerateRequest,
streamer: Optional[object] = None,
cancel_event: Optional[threading.Event] = None,
streamer: object | None = None,
cancel_event: threading.Event | None = None,
) -> str:
"""Blocking inference. Returns the speaker used.
Runs in a thread-pool executor — do not call from the event loop directly.
@@ -748,11 +732,13 @@ def _sync_generate(
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
model_dtype = _model_float_dtype()
voice_preset = _move_cached_prompt(
copy.deepcopy(_voice_presets[speaker]), _device, model_dtype
)
voice_preset = _move_cached_prompt(copy.deepcopy(_voice_presets[speaker]), _device, model_dtype)
steps = req.inference_steps if req.inference_steps is not None else _config["default_inference_steps"]
steps = (
req.inference_steps
if req.inference_steps is not None
else _config["default_inference_steps"]
)
_model.set_ddpm_inference_steps(num_steps=steps)
if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1":
_model._vibepod_profile = {}
@@ -790,7 +776,7 @@ def _sse(event: dict) -> str:
return f"data: {json.dumps(event)}\n\n"
def _generation_profile() -> Optional[dict[str, dict[str, float]]]:
def _generation_profile() -> dict[str, dict[str, float]] | None:
if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") != "1":
return None
stats = getattr(_model, "_vibepod_profile", None)
@@ -800,9 +786,7 @@ def _generation_profile() -> Optional[dict[str, dict[str, float]]]:
key: {
"count": value["count"],
"seconds": round(value["seconds"], 3),
"avg_ms": round(value["seconds"] * 1000 / value["count"], 3)
if value["count"]
else 0.0,
"avg_ms": round(value["seconds"] * 1000 / value["count"], 3) if value["count"] else 0.0,
}
for key, value in sorted(stats.items())
}
@@ -843,13 +827,11 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
audio_chunks[i].detach(),
)
def end(self, sample_indices: Optional[torch.Tensor] = None) -> None:
def end(self, sample_indices: torch.Tensor | None = None) -> None:
if sample_indices is None:
indices_to_end = range(self.batch_size)
else:
indices_to_end = [
s.item() if torch.is_tensor(s) else s for s in sample_indices
]
indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices]
for idx in indices_to_end:
if idx < self.batch_size and not self.finished_flags[idx]:
self.loop.call_soon_threadsafe(
@@ -865,8 +847,8 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
accumulated_chunks = []
chunk_count = 0
audio_samples = 0
first_chunk_at: Optional[float] = None
last_chunk_at: Optional[float] = None
first_chunk_at: float | None = None
last_chunk_at: float | None = None
max_chunk_gap = 0.0
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
@@ -881,9 +863,7 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
# stop_signal=None is the default sentinel that ends the queue.
while True:
try:
chunk = await asyncio.wait_for(
streamer.audio_queues[0].get(), timeout=120.0
)
chunk = await asyncio.wait_for(streamer.audio_queues[0].get(), timeout=120.0)
except asyncio.TimeoutError:
cancel_event.set()
future.cancel()
@@ -969,9 +949,7 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
"elapsed": elapsed,
"speaker": speaker,
"audio_secs": round(audio_secs, 2),
"realtime_factor": round(realtime_factor, 3)
if realtime_factor is not None
else None,
"realtime_factor": round(realtime_factor, 3) if realtime_factor is not None else None,
"chunks": chunk_count,
"first_chunk_secs": round(first_chunk_at - start, 2)
if first_chunk_at is not None