mirror of
https://github.com/JezzWTF/vibepod.git
synced 2026-06-01 15:22:14 +00:00
style: apply ruff formatting and lint fixes to server
This commit is contained in:
@@ -30,15 +30,12 @@ def download() -> str:
|
||||
from huggingface_hub import snapshot_download
|
||||
except ImportError:
|
||||
print(
|
||||
"ERROR: huggingface_hub is not installed.\n"
|
||||
"Run: pip install huggingface_hub",
|
||||
"ERROR: huggingface_hub is not installed.\nRun: pip install huggingface_hub",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get(
|
||||
"HUGGINGFACE_TOKEN"
|
||||
)
|
||||
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
|
||||
|
||||
print(f"Checking / downloading model: {MODEL_ID}")
|
||||
print("(This may take several minutes on first run — the model is ~1 GB)")
|
||||
|
||||
+47
-69
@@ -32,9 +32,10 @@ import threading
|
||||
import time
|
||||
import types
|
||||
import urllib.request
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator, Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
@@ -50,8 +51,7 @@ SAMPLE_RATE = 24_000
|
||||
|
||||
VOICES_DIR = Path(__file__).parent / "voices" / "streaming_model"
|
||||
VOICE_BASE_URL = (
|
||||
"https://raw.githubusercontent.com/microsoft/VibeVoice/main"
|
||||
"/demo/voices/streaming_model"
|
||||
"https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model"
|
||||
)
|
||||
|
||||
EN_VOICES: dict[str, str] = {
|
||||
@@ -69,7 +69,7 @@ _IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.o
|
||||
# ── Pipeline executor ──────────────────────────────────────────────────────────
|
||||
# Overlaps acoustic_decode with forward_tts_lm on a background thread (1 worker).
|
||||
|
||||
_decode_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None
|
||||
_decode_executor: concurrent.futures.ThreadPoolExecutor | None = None
|
||||
|
||||
# ── Device selection ────────────────────────────────────────────────────────────
|
||||
# VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag.
|
||||
@@ -126,9 +126,7 @@ def _cpu_supports_bf16() -> bool:
|
||||
|
||||
def _configure_cpu_runtime() -> dict[str, object]:
|
||||
logical_cpus = os.cpu_count() or 1
|
||||
default_threads = (
|
||||
max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus
|
||||
)
|
||||
default_threads = max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus
|
||||
intra_threads = _env_int("VIBEPOD_CPU_THREADS", default_threads)
|
||||
interop_threads = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1)
|
||||
mkldnn_enabled = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"
|
||||
@@ -157,7 +155,7 @@ _processor = None
|
||||
_model = None
|
||||
_device: str = "cpu"
|
||||
_model_status: ModelStatus = "loading"
|
||||
_model_error: Optional[str] = None
|
||||
_model_error: str | None = None
|
||||
_voice_presets: dict[str, object] = {}
|
||||
_load_lock = threading.Lock()
|
||||
_generation_lock = asyncio.Lock()
|
||||
@@ -204,9 +202,7 @@ def _is_model_cached() -> bool:
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
snapshot_download(
|
||||
MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS
|
||||
)
|
||||
snapshot_download(MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
@@ -215,9 +211,7 @@ def _is_model_cached() -> bool:
|
||||
def _download_model() -> None:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
token: Optional[str] = os.environ.get("HF_TOKEN") or os.environ.get(
|
||||
"HUGGINGFACE_TOKEN"
|
||||
)
|
||||
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
|
||||
DlTqdm = _make_dl_tqdm()
|
||||
logger.info("Model not cached — downloading %s...", MODEL_ID)
|
||||
snapshot_download(
|
||||
@@ -231,7 +225,7 @@ def _download_model() -> None:
|
||||
|
||||
def _download_voices() -> None:
|
||||
VOICES_DIR.mkdir(parents=True, exist_ok=True)
|
||||
for name, filename in EN_VOICES.items():
|
||||
for _name, filename in EN_VOICES.items():
|
||||
dest = VOICES_DIR / filename
|
||||
if not dest.exists():
|
||||
url = f"{VOICE_BASE_URL}/{filename}"
|
||||
@@ -393,12 +387,8 @@ def _apply_cpu_optimizations(model: object) -> object:
|
||||
if hasattr(model, "model"):
|
||||
inner = model.model
|
||||
if hasattr(inner, "prediction_head"):
|
||||
_compile_targets.append(
|
||||
("prediction_head", inner, "prediction_head", False)
|
||||
)
|
||||
if hasattr(inner, "acoustic_tokenizer") and hasattr(
|
||||
inner.acoustic_tokenizer, "decode"
|
||||
):
|
||||
_compile_targets.append(("prediction_head", inner, "prediction_head", False))
|
||||
if hasattr(inner, "acoustic_tokenizer") and hasattr(inner.acoustic_tokenizer, "decode"):
|
||||
_compile_targets.append(
|
||||
("acoustic_tokenizer.decode", inner.acoustic_tokenizer, "decode", False)
|
||||
)
|
||||
@@ -493,9 +483,7 @@ def _install_generation_optimizations(model: object) -> None:
|
||||
t_batches = t_batch_cache.get(t_batch_cache_key)
|
||||
if t_batches is None or len(t_batches) != len(scheduler.timesteps):
|
||||
t_batches = [
|
||||
t.repeat(condition.shape[0]).to(
|
||||
device=condition.device, dtype=condition.dtype
|
||||
)
|
||||
t.repeat(condition.shape[0]).to(device=condition.device, dtype=condition.dtype)
|
||||
for t in scheduler.timesteps
|
||||
]
|
||||
t_batch_cache[t_batch_cache_key] = t_batches
|
||||
@@ -599,12 +587,8 @@ def _move_cached_prompt(value: object, device: str, dtype: torch.dtype) -> objec
|
||||
if isinstance(value, tuple):
|
||||
return tuple(_move_cached_prompt(v, device, dtype) for v in value)
|
||||
if hasattr(value, "key_cache") and hasattr(value, "value_cache"):
|
||||
value.key_cache = [
|
||||
_move_cached_prompt(t, device, dtype) for t in value.key_cache
|
||||
]
|
||||
value.value_cache = [
|
||||
_move_cached_prompt(t, device, dtype) for t in value.value_cache
|
||||
]
|
||||
value.key_cache = [_move_cached_prompt(t, device, dtype) for t in value.key_cache]
|
||||
value.value_cache = [_move_cached_prompt(t, device, dtype) for t in value.value_cache]
|
||||
return value
|
||||
|
||||
|
||||
@@ -640,33 +624,33 @@ def _load_model_sync() -> None:
|
||||
is_cpu = _device == "cpu"
|
||||
_config["device"] = _device
|
||||
_config["chunk_accum"] = _env_int("VIBEPOD_CHUNK_ACCUM", 4 if is_cpu else 1)
|
||||
_config["prebuffer_secs"] = _env_float("VIBEPOD_PREBUFFER_SECS", 24.0 if is_cpu else 5.0)
|
||||
_config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 2.0 if is_cpu else 1.0)
|
||||
_config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 12.0 if is_cpu else 3.0)
|
||||
_config["default_inference_steps"] = _env_int("VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10)
|
||||
_config["prebuffer_secs"] = _env_float(
|
||||
"VIBEPOD_PREBUFFER_SECS", 24.0 if is_cpu else 5.0
|
||||
)
|
||||
_config["rebuffer_threshold_secs"] = _env_float(
|
||||
"VIBEPOD_REBUFFER_THRESHOLD_SECS", 2.0 if is_cpu else 1.0
|
||||
)
|
||||
_config["resume_threshold_secs"] = _env_float(
|
||||
"VIBEPOD_RESUME_THRESHOLD_SECS", 12.0 if is_cpu else 3.0
|
||||
)
|
||||
_config["default_inference_steps"] = _env_int(
|
||||
"VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10
|
||||
)
|
||||
if is_cpu:
|
||||
logical_cpus = os.cpu_count() or 1
|
||||
_config["cpu_threads"] = _env_int(
|
||||
"VIBEPOD_CPU_THREADS",
|
||||
max(1, logical_cpus // 2)
|
||||
if platform.system() == "Windows"
|
||||
else logical_cpus,
|
||||
max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus,
|
||||
)
|
||||
_config["cpu_interop_threads"] = _env_int(
|
||||
"VIBEPOD_CPU_INTEROP_THREADS", 1
|
||||
)
|
||||
_config["cpu_mkldnn"] = os.environ.get(
|
||||
"VIBEPOD_CPU_MKLDNN", "1"
|
||||
).strip() != "0"
|
||||
_config["cpu_interop_threads"] = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1)
|
||||
_config["cpu_mkldnn"] = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"
|
||||
|
||||
_processor = _init_processor()
|
||||
_model = _init_model(_device)
|
||||
_voice_presets = _load_voice_presets(_device)
|
||||
|
||||
_model_status = "online"
|
||||
logger.info(
|
||||
"Model ready on %s. Voices: %s", _device, list(_voice_presets.keys())
|
||||
)
|
||||
logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys()))
|
||||
logger.info("Configuration: %s", _config)
|
||||
|
||||
except Exception as exc:
|
||||
@@ -697,7 +681,7 @@ class GenerateRequest(BaseModel):
|
||||
text: str = Field(..., min_length=1, max_length=10_000)
|
||||
speaker: str = Field(default=DEFAULT_SPEAKER)
|
||||
cfg_scale: float = Field(default=1.5, ge=0.5, le=4.0)
|
||||
inference_steps: Optional[int] = Field(default=None, ge=5, le=20)
|
||||
inference_steps: int | None = Field(default=None, ge=5, le=20)
|
||||
|
||||
@field_validator("text")
|
||||
@classmethod
|
||||
@@ -736,8 +720,8 @@ async def health() -> dict:
|
||||
|
||||
def _sync_generate(
|
||||
req: GenerateRequest,
|
||||
streamer: Optional[object] = None,
|
||||
cancel_event: Optional[threading.Event] = None,
|
||||
streamer: object | None = None,
|
||||
cancel_event: threading.Event | None = None,
|
||||
) -> str:
|
||||
"""Blocking inference. Returns the speaker used.
|
||||
Runs in a thread-pool executor — do not call from the event loop directly.
|
||||
@@ -748,11 +732,13 @@ def _sync_generate(
|
||||
|
||||
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
|
||||
model_dtype = _model_float_dtype()
|
||||
voice_preset = _move_cached_prompt(
|
||||
copy.deepcopy(_voice_presets[speaker]), _device, model_dtype
|
||||
)
|
||||
voice_preset = _move_cached_prompt(copy.deepcopy(_voice_presets[speaker]), _device, model_dtype)
|
||||
|
||||
steps = req.inference_steps if req.inference_steps is not None else _config["default_inference_steps"]
|
||||
steps = (
|
||||
req.inference_steps
|
||||
if req.inference_steps is not None
|
||||
else _config["default_inference_steps"]
|
||||
)
|
||||
_model.set_ddpm_inference_steps(num_steps=steps)
|
||||
if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1":
|
||||
_model._vibepod_profile = {}
|
||||
@@ -790,7 +776,7 @@ def _sse(event: dict) -> str:
|
||||
return f"data: {json.dumps(event)}\n\n"
|
||||
|
||||
|
||||
def _generation_profile() -> Optional[dict[str, dict[str, float]]]:
|
||||
def _generation_profile() -> dict[str, dict[str, float]] | None:
|
||||
if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") != "1":
|
||||
return None
|
||||
stats = getattr(_model, "_vibepod_profile", None)
|
||||
@@ -800,9 +786,7 @@ def _generation_profile() -> Optional[dict[str, dict[str, float]]]:
|
||||
key: {
|
||||
"count": value["count"],
|
||||
"seconds": round(value["seconds"], 3),
|
||||
"avg_ms": round(value["seconds"] * 1000 / value["count"], 3)
|
||||
if value["count"]
|
||||
else 0.0,
|
||||
"avg_ms": round(value["seconds"] * 1000 / value["count"], 3) if value["count"] else 0.0,
|
||||
}
|
||||
for key, value in sorted(stats.items())
|
||||
}
|
||||
@@ -843,13 +827,11 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
audio_chunks[i].detach(),
|
||||
)
|
||||
|
||||
def end(self, sample_indices: Optional[torch.Tensor] = None) -> None:
|
||||
def end(self, sample_indices: torch.Tensor | None = None) -> None:
|
||||
if sample_indices is None:
|
||||
indices_to_end = range(self.batch_size)
|
||||
else:
|
||||
indices_to_end = [
|
||||
s.item() if torch.is_tensor(s) else s for s in sample_indices
|
||||
]
|
||||
indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices]
|
||||
for idx in indices_to_end:
|
||||
if idx < self.batch_size and not self.finished_flags[idx]:
|
||||
self.loop.call_soon_threadsafe(
|
||||
@@ -865,8 +847,8 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
accumulated_chunks = []
|
||||
chunk_count = 0
|
||||
audio_samples = 0
|
||||
first_chunk_at: Optional[float] = None
|
||||
last_chunk_at: Optional[float] = None
|
||||
first_chunk_at: float | None = None
|
||||
last_chunk_at: float | None = None
|
||||
max_chunk_gap = 0.0
|
||||
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
|
||||
|
||||
@@ -881,9 +863,7 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
# stop_signal=None is the default sentinel that ends the queue.
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(
|
||||
streamer.audio_queues[0].get(), timeout=120.0
|
||||
)
|
||||
chunk = await asyncio.wait_for(streamer.audio_queues[0].get(), timeout=120.0)
|
||||
except asyncio.TimeoutError:
|
||||
cancel_event.set()
|
||||
future.cancel()
|
||||
@@ -969,9 +949,7 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
"elapsed": elapsed,
|
||||
"speaker": speaker,
|
||||
"audio_secs": round(audio_secs, 2),
|
||||
"realtime_factor": round(realtime_factor, 3)
|
||||
if realtime_factor is not None
|
||||
else None,
|
||||
"realtime_factor": round(realtime_factor, 3) if realtime_factor is not None else None,
|
||||
"chunks": chunk_count,
|
||||
"first_chunk_secs": round(first_chunk_at - start, 2)
|
||||
if first_chunk_at is not None
|
||||
|
||||
Reference in New Issue
Block a user