Mirror of https://github.com/JezzWTF/vibepod.git (synced 2026-06-01 15:22:14 +00:00)
style: apply ruff formatting and lint fixes to server
This commit is contained in:
@@ -30,15 +30,12 @@ def download() -> str:
|
|||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print(
|
print(
|
||||||
"ERROR: huggingface_hub is not installed.\n"
|
"ERROR: huggingface_hub is not installed.\nRun: pip install huggingface_hub",
|
||||||
"Run: pip install huggingface_hub",
|
|
||||||
file=sys.stderr,
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get(
|
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
|
||||||
"HUGGINGFACE_TOKEN"
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Checking / downloading model: {MODEL_ID}")
|
print(f"Checking / downloading model: {MODEL_ID}")
|
||||||
print("(This may take several minutes on first run — the model is ~1 GB)")
|
print("(This may take several minutes on first run — the model is ~1 GB)")
|
||||||
|
|||||||
+47
-69
@@ -32,9 +32,10 @@ import threading
|
|||||||
import time
|
import time
|
||||||
import types
|
import types
|
||||||
import urllib.request
|
import urllib.request
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import AsyncGenerator, Literal, Optional
|
from typing import Literal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from fastapi import FastAPI, HTTPException, Request
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
@@ -50,8 +51,7 @@ SAMPLE_RATE = 24_000
|
|||||||
|
|
||||||
VOICES_DIR = Path(__file__).parent / "voices" / "streaming_model"
|
VOICES_DIR = Path(__file__).parent / "voices" / "streaming_model"
|
||||||
VOICE_BASE_URL = (
|
VOICE_BASE_URL = (
|
||||||
"https://raw.githubusercontent.com/microsoft/VibeVoice/main"
|
"https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model"
|
||||||
"/demo/voices/streaming_model"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
EN_VOICES: dict[str, str] = {
|
EN_VOICES: dict[str, str] = {
|
||||||
@@ -69,7 +69,7 @@ _IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.o
|
|||||||
# ── Pipeline executor ──────────────────────────────────────────────────────────
|
# ── Pipeline executor ──────────────────────────────────────────────────────────
|
||||||
# Overlaps acoustic_decode with forward_tts_lm on a background thread (1 worker).
|
# Overlaps acoustic_decode with forward_tts_lm on a background thread (1 worker).
|
||||||
|
|
||||||
_decode_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None
|
_decode_executor: concurrent.futures.ThreadPoolExecutor | None = None
|
||||||
|
|
||||||
# ── Device selection ────────────────────────────────────────────────────────────
|
# ── Device selection ────────────────────────────────────────────────────────────
|
||||||
# VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag.
|
# VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag.
|
||||||
@@ -126,9 +126,7 @@ def _cpu_supports_bf16() -> bool:
|
|||||||
|
|
||||||
def _configure_cpu_runtime() -> dict[str, object]:
|
def _configure_cpu_runtime() -> dict[str, object]:
|
||||||
logical_cpus = os.cpu_count() or 1
|
logical_cpus = os.cpu_count() or 1
|
||||||
default_threads = (
|
default_threads = max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus
|
||||||
max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus
|
|
||||||
)
|
|
||||||
intra_threads = _env_int("VIBEPOD_CPU_THREADS", default_threads)
|
intra_threads = _env_int("VIBEPOD_CPU_THREADS", default_threads)
|
||||||
interop_threads = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1)
|
interop_threads = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1)
|
||||||
mkldnn_enabled = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"
|
mkldnn_enabled = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"
|
||||||
@@ -157,7 +155,7 @@ _processor = None
|
|||||||
_model = None
|
_model = None
|
||||||
_device: str = "cpu"
|
_device: str = "cpu"
|
||||||
_model_status: ModelStatus = "loading"
|
_model_status: ModelStatus = "loading"
|
||||||
_model_error: Optional[str] = None
|
_model_error: str | None = None
|
||||||
_voice_presets: dict[str, object] = {}
|
_voice_presets: dict[str, object] = {}
|
||||||
_load_lock = threading.Lock()
|
_load_lock = threading.Lock()
|
||||||
_generation_lock = asyncio.Lock()
|
_generation_lock = asyncio.Lock()
|
||||||
@@ -204,9 +202,7 @@ def _is_model_cached() -> bool:
|
|||||||
try:
|
try:
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
snapshot_download(
|
snapshot_download(MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS)
|
||||||
MODEL_ID, local_files_only=True, ignore_patterns=_IGNORE_PATTERNS
|
|
||||||
)
|
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
@@ -215,9 +211,7 @@ def _is_model_cached() -> bool:
|
|||||||
def _download_model() -> None:
|
def _download_model() -> None:
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
token: Optional[str] = os.environ.get("HF_TOKEN") or os.environ.get(
|
token: str | None = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
|
||||||
"HUGGINGFACE_TOKEN"
|
|
||||||
)
|
|
||||||
DlTqdm = _make_dl_tqdm()
|
DlTqdm = _make_dl_tqdm()
|
||||||
logger.info("Model not cached — downloading %s...", MODEL_ID)
|
logger.info("Model not cached — downloading %s...", MODEL_ID)
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
@@ -231,7 +225,7 @@ def _download_model() -> None:
|
|||||||
|
|
||||||
def _download_voices() -> None:
|
def _download_voices() -> None:
|
||||||
VOICES_DIR.mkdir(parents=True, exist_ok=True)
|
VOICES_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
for name, filename in EN_VOICES.items():
|
for _name, filename in EN_VOICES.items():
|
||||||
dest = VOICES_DIR / filename
|
dest = VOICES_DIR / filename
|
||||||
if not dest.exists():
|
if not dest.exists():
|
||||||
url = f"{VOICE_BASE_URL}/{filename}"
|
url = f"{VOICE_BASE_URL}/{filename}"
|
||||||
@@ -393,12 +387,8 @@ def _apply_cpu_optimizations(model: object) -> object:
|
|||||||
if hasattr(model, "model"):
|
if hasattr(model, "model"):
|
||||||
inner = model.model
|
inner = model.model
|
||||||
if hasattr(inner, "prediction_head"):
|
if hasattr(inner, "prediction_head"):
|
||||||
_compile_targets.append(
|
_compile_targets.append(("prediction_head", inner, "prediction_head", False))
|
||||||
("prediction_head", inner, "prediction_head", False)
|
if hasattr(inner, "acoustic_tokenizer") and hasattr(inner.acoustic_tokenizer, "decode"):
|
||||||
)
|
|
||||||
if hasattr(inner, "acoustic_tokenizer") and hasattr(
|
|
||||||
inner.acoustic_tokenizer, "decode"
|
|
||||||
):
|
|
||||||
_compile_targets.append(
|
_compile_targets.append(
|
||||||
("acoustic_tokenizer.decode", inner.acoustic_tokenizer, "decode", False)
|
("acoustic_tokenizer.decode", inner.acoustic_tokenizer, "decode", False)
|
||||||
)
|
)
|
||||||
@@ -493,9 +483,7 @@ def _install_generation_optimizations(model: object) -> None:
|
|||||||
t_batches = t_batch_cache.get(t_batch_cache_key)
|
t_batches = t_batch_cache.get(t_batch_cache_key)
|
||||||
if t_batches is None or len(t_batches) != len(scheduler.timesteps):
|
if t_batches is None or len(t_batches) != len(scheduler.timesteps):
|
||||||
t_batches = [
|
t_batches = [
|
||||||
t.repeat(condition.shape[0]).to(
|
t.repeat(condition.shape[0]).to(device=condition.device, dtype=condition.dtype)
|
||||||
device=condition.device, dtype=condition.dtype
|
|
||||||
)
|
|
||||||
for t in scheduler.timesteps
|
for t in scheduler.timesteps
|
||||||
]
|
]
|
||||||
t_batch_cache[t_batch_cache_key] = t_batches
|
t_batch_cache[t_batch_cache_key] = t_batches
|
||||||
@@ -599,12 +587,8 @@ def _move_cached_prompt(value: object, device: str, dtype: torch.dtype) -> objec
|
|||||||
if isinstance(value, tuple):
|
if isinstance(value, tuple):
|
||||||
return tuple(_move_cached_prompt(v, device, dtype) for v in value)
|
return tuple(_move_cached_prompt(v, device, dtype) for v in value)
|
||||||
if hasattr(value, "key_cache") and hasattr(value, "value_cache"):
|
if hasattr(value, "key_cache") and hasattr(value, "value_cache"):
|
||||||
value.key_cache = [
|
value.key_cache = [_move_cached_prompt(t, device, dtype) for t in value.key_cache]
|
||||||
_move_cached_prompt(t, device, dtype) for t in value.key_cache
|
value.value_cache = [_move_cached_prompt(t, device, dtype) for t in value.value_cache]
|
||||||
]
|
|
||||||
value.value_cache = [
|
|
||||||
_move_cached_prompt(t, device, dtype) for t in value.value_cache
|
|
||||||
]
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
@@ -640,33 +624,33 @@ def _load_model_sync() -> None:
|
|||||||
is_cpu = _device == "cpu"
|
is_cpu = _device == "cpu"
|
||||||
_config["device"] = _device
|
_config["device"] = _device
|
||||||
_config["chunk_accum"] = _env_int("VIBEPOD_CHUNK_ACCUM", 4 if is_cpu else 1)
|
_config["chunk_accum"] = _env_int("VIBEPOD_CHUNK_ACCUM", 4 if is_cpu else 1)
|
||||||
_config["prebuffer_secs"] = _env_float("VIBEPOD_PREBUFFER_SECS", 24.0 if is_cpu else 5.0)
|
_config["prebuffer_secs"] = _env_float(
|
||||||
_config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 2.0 if is_cpu else 1.0)
|
"VIBEPOD_PREBUFFER_SECS", 24.0 if is_cpu else 5.0
|
||||||
_config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 12.0 if is_cpu else 3.0)
|
)
|
||||||
_config["default_inference_steps"] = _env_int("VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10)
|
_config["rebuffer_threshold_secs"] = _env_float(
|
||||||
|
"VIBEPOD_REBUFFER_THRESHOLD_SECS", 2.0 if is_cpu else 1.0
|
||||||
|
)
|
||||||
|
_config["resume_threshold_secs"] = _env_float(
|
||||||
|
"VIBEPOD_RESUME_THRESHOLD_SECS", 12.0 if is_cpu else 3.0
|
||||||
|
)
|
||||||
|
_config["default_inference_steps"] = _env_int(
|
||||||
|
"VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10
|
||||||
|
)
|
||||||
if is_cpu:
|
if is_cpu:
|
||||||
logical_cpus = os.cpu_count() or 1
|
logical_cpus = os.cpu_count() or 1
|
||||||
_config["cpu_threads"] = _env_int(
|
_config["cpu_threads"] = _env_int(
|
||||||
"VIBEPOD_CPU_THREADS",
|
"VIBEPOD_CPU_THREADS",
|
||||||
max(1, logical_cpus // 2)
|
max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus,
|
||||||
if platform.system() == "Windows"
|
|
||||||
else logical_cpus,
|
|
||||||
)
|
)
|
||||||
_config["cpu_interop_threads"] = _env_int(
|
_config["cpu_interop_threads"] = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1)
|
||||||
"VIBEPOD_CPU_INTEROP_THREADS", 1
|
_config["cpu_mkldnn"] = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"
|
||||||
)
|
|
||||||
_config["cpu_mkldnn"] = os.environ.get(
|
|
||||||
"VIBEPOD_CPU_MKLDNN", "1"
|
|
||||||
).strip() != "0"
|
|
||||||
|
|
||||||
_processor = _init_processor()
|
_processor = _init_processor()
|
||||||
_model = _init_model(_device)
|
_model = _init_model(_device)
|
||||||
_voice_presets = _load_voice_presets(_device)
|
_voice_presets = _load_voice_presets(_device)
|
||||||
|
|
||||||
_model_status = "online"
|
_model_status = "online"
|
||||||
logger.info(
|
logger.info("Model ready on %s. Voices: %s", _device, list(_voice_presets.keys()))
|
||||||
"Model ready on %s. Voices: %s", _device, list(_voice_presets.keys())
|
|
||||||
)
|
|
||||||
logger.info("Configuration: %s", _config)
|
logger.info("Configuration: %s", _config)
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
@@ -697,7 +681,7 @@ class GenerateRequest(BaseModel):
|
|||||||
text: str = Field(..., min_length=1, max_length=10_000)
|
text: str = Field(..., min_length=1, max_length=10_000)
|
||||||
speaker: str = Field(default=DEFAULT_SPEAKER)
|
speaker: str = Field(default=DEFAULT_SPEAKER)
|
||||||
cfg_scale: float = Field(default=1.5, ge=0.5, le=4.0)
|
cfg_scale: float = Field(default=1.5, ge=0.5, le=4.0)
|
||||||
inference_steps: Optional[int] = Field(default=None, ge=5, le=20)
|
inference_steps: int | None = Field(default=None, ge=5, le=20)
|
||||||
|
|
||||||
@field_validator("text")
|
@field_validator("text")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -736,8 +720,8 @@ async def health() -> dict:
|
|||||||
|
|
||||||
def _sync_generate(
|
def _sync_generate(
|
||||||
req: GenerateRequest,
|
req: GenerateRequest,
|
||||||
streamer: Optional[object] = None,
|
streamer: object | None = None,
|
||||||
cancel_event: Optional[threading.Event] = None,
|
cancel_event: threading.Event | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Blocking inference. Returns the speaker used.
|
"""Blocking inference. Returns the speaker used.
|
||||||
Runs in a thread-pool executor — do not call from the event loop directly.
|
Runs in a thread-pool executor — do not call from the event loop directly.
|
||||||
@@ -748,11 +732,13 @@ def _sync_generate(
|
|||||||
|
|
||||||
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
|
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
|
||||||
model_dtype = _model_float_dtype()
|
model_dtype = _model_float_dtype()
|
||||||
voice_preset = _move_cached_prompt(
|
voice_preset = _move_cached_prompt(copy.deepcopy(_voice_presets[speaker]), _device, model_dtype)
|
||||||
copy.deepcopy(_voice_presets[speaker]), _device, model_dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
steps = req.inference_steps if req.inference_steps is not None else _config["default_inference_steps"]
|
steps = (
|
||||||
|
req.inference_steps
|
||||||
|
if req.inference_steps is not None
|
||||||
|
else _config["default_inference_steps"]
|
||||||
|
)
|
||||||
_model.set_ddpm_inference_steps(num_steps=steps)
|
_model.set_ddpm_inference_steps(num_steps=steps)
|
||||||
if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1":
|
if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") == "1":
|
||||||
_model._vibepod_profile = {}
|
_model._vibepod_profile = {}
|
||||||
@@ -790,7 +776,7 @@ def _sse(event: dict) -> str:
|
|||||||
return f"data: {json.dumps(event)}\n\n"
|
return f"data: {json.dumps(event)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
def _generation_profile() -> Optional[dict[str, dict[str, float]]]:
|
def _generation_profile() -> dict[str, dict[str, float]] | None:
|
||||||
if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") != "1":
|
if os.environ.get("VIBEPOD_PROFILE_GENERATION", "0") != "1":
|
||||||
return None
|
return None
|
||||||
stats = getattr(_model, "_vibepod_profile", None)
|
stats = getattr(_model, "_vibepod_profile", None)
|
||||||
@@ -800,9 +786,7 @@ def _generation_profile() -> Optional[dict[str, dict[str, float]]]:
|
|||||||
key: {
|
key: {
|
||||||
"count": value["count"],
|
"count": value["count"],
|
||||||
"seconds": round(value["seconds"], 3),
|
"seconds": round(value["seconds"], 3),
|
||||||
"avg_ms": round(value["seconds"] * 1000 / value["count"], 3)
|
"avg_ms": round(value["seconds"] * 1000 / value["count"], 3) if value["count"] else 0.0,
|
||||||
if value["count"]
|
|
||||||
else 0.0,
|
|
||||||
}
|
}
|
||||||
for key, value in sorted(stats.items())
|
for key, value in sorted(stats.items())
|
||||||
}
|
}
|
||||||
@@ -843,13 +827,11 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
|||||||
audio_chunks[i].detach(),
|
audio_chunks[i].detach(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def end(self, sample_indices: Optional[torch.Tensor] = None) -> None:
|
def end(self, sample_indices: torch.Tensor | None = None) -> None:
|
||||||
if sample_indices is None:
|
if sample_indices is None:
|
||||||
indices_to_end = range(self.batch_size)
|
indices_to_end = range(self.batch_size)
|
||||||
else:
|
else:
|
||||||
indices_to_end = [
|
indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices]
|
||||||
s.item() if torch.is_tensor(s) else s for s in sample_indices
|
|
||||||
]
|
|
||||||
for idx in indices_to_end:
|
for idx in indices_to_end:
|
||||||
if idx < self.batch_size and not self.finished_flags[idx]:
|
if idx < self.batch_size and not self.finished_flags[idx]:
|
||||||
self.loop.call_soon_threadsafe(
|
self.loop.call_soon_threadsafe(
|
||||||
@@ -865,8 +847,8 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
|||||||
accumulated_chunks = []
|
accumulated_chunks = []
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
audio_samples = 0
|
audio_samples = 0
|
||||||
first_chunk_at: Optional[float] = None
|
first_chunk_at: float | None = None
|
||||||
last_chunk_at: Optional[float] = None
|
last_chunk_at: float | None = None
|
||||||
max_chunk_gap = 0.0
|
max_chunk_gap = 0.0
|
||||||
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
|
speaker = req.speaker if req.speaker in _voice_presets else DEFAULT_SPEAKER
|
||||||
|
|
||||||
@@ -881,9 +863,7 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
|||||||
# stop_signal=None is the default sentinel that ends the queue.
|
# stop_signal=None is the default sentinel that ends the queue.
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
chunk = await asyncio.wait_for(
|
chunk = await asyncio.wait_for(streamer.audio_queues[0].get(), timeout=120.0)
|
||||||
streamer.audio_queues[0].get(), timeout=120.0
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
cancel_event.set()
|
cancel_event.set()
|
||||||
future.cancel()
|
future.cancel()
|
||||||
@@ -969,9 +949,7 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
|||||||
"elapsed": elapsed,
|
"elapsed": elapsed,
|
||||||
"speaker": speaker,
|
"speaker": speaker,
|
||||||
"audio_secs": round(audio_secs, 2),
|
"audio_secs": round(audio_secs, 2),
|
||||||
"realtime_factor": round(realtime_factor, 3)
|
"realtime_factor": round(realtime_factor, 3) if realtime_factor is not None else None,
|
||||||
if realtime_factor is not None
|
|
||||||
else None,
|
|
||||||
"chunks": chunk_count,
|
"chunks": chunk_count,
|
||||||
"first_chunk_secs": round(first_chunk_at - start, 2)
|
"first_chunk_secs": round(first_chunk_at - start, 2)
|
||||||
if first_chunk_at is not None
|
if first_chunk_at is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user