perf: CPU async pipeline overlap + INT8 quantization

Overlap acoustic_decode with forward_tts_lm calls using a background
ThreadPoolExecutor, hiding ~72s of decode cost behind tts_lm work.
Achieved 0.67x realtime (up from 0.43x, ~56% improvement).

- vibevoice_generate_patch.py: patched generate() loop reordered to
  submit decode to thread before running connector + tts_lm×2, then
  resolve future. Installed as instance method via types.MethodType so
  uv sync reinstalling the package cannot revert the patch.
- Dynamic INT8 quantization of Linear layers (VIBEPOD_QUANTIZE=1,
  default on CPU). prediction_head excluded — small fixed-size tensors
  regressed ~20% with INT8 due to pack/unpack overhead.
- Auto-detect AVX512_BF16 and load model in bfloat16 if supported
  (VIBEPOD_CPU_BF16=auto, overridable with 0/1).
- CPU thread count auto-configured from logical CPU count; OMP/MKL env
  vars set accordingly. Lock file preserved around uv sync --no-sources
  so CPU mode does not alter the shared uv.lock.
- torch.compile retained as opt-in (VIBEPOD_COMPILE=1) but marked not
  recommended — dynamic KV cache shapes prevent kernel reuse.
This commit is contained in:
2026-04-30 20:46:29 +01:00
parent 75b84b211b
commit 7591d15a52
3 changed files with 685 additions and 2 deletions
+195 -1
View File
@@ -20,12 +20,14 @@ Device selection:
import asyncio
import base64
import concurrent.futures
import copy
import functools
import importlib.util
import json
import logging
import os
import platform
import threading
import time
import types
@@ -64,6 +66,10 @@ DEFAULT_SPEAKER = "carter"
_IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.ot"]
# ── Decode pipeline executor ────────────────────────────────────────────────────
# Thread pool handed to the patched generate() for background acoustic decoding.
# It stays None until _install_cpu_pipeline_optimizations() creates it, and the
# FastAPI lifespan handler shuts it down (without waiting) when the app stops.
_decode_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None
# ── Device selection ────────────────────────────────────────────────────────────
# VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag.
# Falls back to auto-detection if not set.
@@ -108,6 +114,40 @@ def _env_float(name: str, default: float) -> float:
return default
def _cpu_supports_bf16() -> bool:
    """Return True if the CPU has AVX512_BF16 hardware support.

    The answer comes from PyTorch's own CPU capability probe. Both the public
    spelling (``torch.cpu.is_avx512_bf16_supported``) and the private,
    underscore-prefixed spelling (``torch.cpu._is_avx512_bf16_supported``) are
    accepted, so detection works whichever one the installed build exposes.

    Returns:
        True when a probe exists and reports support; False when it reports no
        support or when the installed torch offers no such probe.
    """
    cpu_module = getattr(torch, "cpu", None)
    if cpu_module is None:
        return False

    # NOTE(review): to my knowledge current PyTorch releases only ship the
    # private name; the public one is kept first in case it is ever promoted.
    for probe_name in ("is_avx512_bf16_supported", "_is_avx512_bf16_supported"):
        probe = getattr(cpu_module, probe_name, None)
        if callable(probe):
            return bool(probe())
    return False
def _configure_cpu_runtime() -> dict[str, object]:
    """Apply CPU thread and MKLDNN settings to torch and report what took effect.

    Environment overrides:
        VIBEPOD_CPU_THREADS: intra-op thread count. Defaults to the logical CPU
            count, or half of it (at least 1) on Windows.
        VIBEPOD_CPU_INTEROP_THREADS: inter-op thread count, default 1.
        VIBEPOD_CPU_MKLDNN: any value other than "0" keeps MKLDNN enabled.

    Returns:
        A dict with the logical CPU count plus the thread counts and MKLDNN
        availability/enabled flags read back from torch after configuration.
    """
    cpu_count = os.cpu_count() or 1
    if platform.system() == "Windows":
        fallback_threads = max(1, cpu_count // 2)
    else:
        fallback_threads = cpu_count

    requested_intra = _env_int("VIBEPOD_CPU_THREADS", fallback_threads)
    requested_interop = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1)
    use_mkldnn = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"

    torch.set_num_threads(max(1, requested_intra))
    try:
        torch.set_num_interop_threads(max(1, requested_interop))
    except RuntimeError as exc:
        # Treated as non-fatal: keep whatever inter-op setting torch already has.
        logger.warning("Could not set CPU inter-op threads: %s", exc)
    torch.backends.mkldnn.enabled = use_mkldnn

    # Report the values torch actually holds rather than the requested ones.
    report: dict[str, object] = {}
    report["logical_cpus"] = cpu_count
    report["threads"] = torch.get_num_threads()
    report["interop_threads"] = torch.get_num_interop_threads()
    report["mkldnn_available"] = torch.backends.mkldnn.is_available()
    report["mkldnn_enabled"] = torch.backends.mkldnn.enabled
    return report
# ── Global state ────────────────────────────────────────────────────────────────
ModelStatus = Literal["downloading", "loading", "online", "error"]
@@ -228,12 +268,29 @@ def _init_model(device: str):
torch.backends.cuda.mem_efficient_sdp_enabled(),
torch.backends.cuda.math_sdp_enabled(),
)
elif device == "cpu":
torch.set_float32_matmul_precision("medium")
logger.info("CPU runtime configuration: %s", _configure_cpu_runtime())
cuda_dtype = os.environ.get("VIBEPOD_CUDA_DTYPE", "bf16").lower()
if device == "cuda" and cuda_dtype == "fp16":
load_dtype = torch.float16
elif device == "cuda":
load_dtype = torch.bfloat16
else:
load_dtype = torch.bfloat16 if device == "cuda" else torch.float32
cpu_bf16_env = os.environ.get("VIBEPOD_CPU_BF16", "auto").lower()
if cpu_bf16_env == "1":
load_dtype = torch.bfloat16
logger.info("CPU BF16 forced via VIBEPOD_CPU_BF16=1")
elif cpu_bf16_env == "0":
load_dtype = torch.float32
logger.info("CPU float32 forced via VIBEPOD_CPU_BF16=0")
elif _cpu_supports_bf16():
load_dtype = torch.bfloat16
logger.info("AVX512_BF16 detected — loading model in bfloat16")
else:
load_dtype = torch.float32
logger.info("No AVX512_BF16 — using float32 (set VIBEPOD_CPU_BF16=1 to override)")
logger.info("Loading model weights with dtype %s", load_dtype)
requested_attn_impl = os.environ.get("VIBEPOD_ATTN_IMPL", "auto").lower()
has_flash_attn = importlib.util.find_spec("flash_attn") is not None
@@ -274,8 +331,90 @@ def _init_model(device: str):
)
model.eval()
if device == "cpu":
model = _apply_cpu_optimizations(model)
model.set_ddpm_inference_steps(num_steps=_config["default_inference_steps"])
_install_generation_optimizations(model)
if device == "cpu":
# Must run after _install_generation_optimizations so the async wrapper
# sits outside the profiling wrapper (VibeVoice calls async → profiling → real decode).
_install_cpu_pipeline_optimizations(model)
return model
def _apply_cpu_optimizations(model: object) -> object:
    """Apply the opt-in post-load CPU optimizations and return the model to use.

    Two independent switches are read from the environment:

    * ``VIBEPOD_QUANTIZE=1`` — dynamic INT8 quantization of ``torch.nn.Linear``
      layers. ``model.model.prediction_head`` (when present) is detached before
      quantizing and re-attached afterwards so it is left unquantized.
    * ``VIBEPOD_COMPILE=1`` — wrap ``forward_tts_lm`` and, when they exist,
      ``prediction_head`` and ``acoustic_tokenizer.decode`` with
      ``torch.compile`` (inductor backend, mode from ``VIBEPOD_COMPILE_MODE``,
      default ``"reduce-overhead"``).

    A failure in either step is logged as a warning and that step is skipped.

    Args:
        model: The loaded model object.

    Returns:
        The (possibly new) model object: the quantized model when quantization
        succeeded, otherwise the object that was passed in.
    """
    do_quantize = os.environ.get("VIBEPOD_QUANTIZE", "0") == "1"
    do_compile = os.environ.get("VIBEPOD_COMPILE", "0") == "1"

    if do_quantize:
        logger.info("Applying dynamic INT8 quantization to Linear layers...")
        try:
            # Import the function by name. A bare `import torch.ao.quantization`
            # here would bind `torch` as a *local* of this function, leaving it
            # unbound (UnboundLocalError) in the compile branch below whenever
            # quantization is switched off.
            from torch.ao.quantization import quantize_dynamic

            # The diffusion prediction_head operates on small fixed-size tensors where
            # INT8 pack/unpack overhead exceeds the matmul savings (~+20% regression in
            # testing). Save and restore it so it stays in float32.
            saved_prediction_head = None
            if hasattr(model, "model") and hasattr(model.model, "prediction_head"):
                saved_prediction_head = model.model.prediction_head
                del model.model.prediction_head

            try:
                model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
            finally:
                # Re-attach on success (model is now the quantized object) and on
                # failure (model is still the original), so the head is never lost.
                if saved_prediction_head is not None:
                    model.model.prediction_head = saved_prediction_head

            if saved_prediction_head is not None:
                logger.info(
                    "Dynamic INT8 quantization applied (prediction_head excluded — stays float32)."
                )
            else:
                logger.info("Dynamic INT8 quantization applied.")
        except Exception as exc:
            logger.warning("Dynamic quantization failed: %s — skipping", exc)

    if do_compile:
        # torch.compile with inductor on CPU is ineffective for autoregressive TTS:
        # each token step produces a unique input shape, so every step triggers a new
        # kernel compile event rather than reusing compiled code. Kept as an escape
        # hatch but not recommended.
        compile_mode = os.environ.get("VIBEPOD_COMPILE_MODE", "reduce-overhead")
        logger.info(
            "torch.compile enabled (mode=%s) — NOTE: limited benefit for autoregressive"
            " models on CPU due to dynamic sequence lengths.",
            compile_mode,
        )

        # Each entry: (log label, owning object, attribute name, dynamic-shapes flag).
        _compile_targets: list[tuple[str, object, str, bool]] = [
            ("forward_tts_lm", model, "forward_tts_lm", True),
        ]
        if hasattr(model, "model"):
            inner = model.model
            if hasattr(inner, "prediction_head"):
                _compile_targets.append(
                    ("prediction_head", inner, "prediction_head", False)
                )
            if hasattr(inner, "acoustic_tokenizer") and hasattr(
                inner.acoustic_tokenizer, "decode"
            ):
                _compile_targets.append(
                    ("acoustic_tokenizer.decode", inner.acoustic_tokenizer, "decode", False)
                )

        for label, obj, attr, dynamic in _compile_targets:
            # Targets are compiled independently so one failure does not block the rest.
            try:
                compiled = torch.compile(
                    getattr(obj, attr),
                    backend="inductor",
                    mode=compile_mode,
                    dynamic=dynamic,
                )
                setattr(obj, attr, compiled)
                logger.info("  compiled: %s", label)
            except Exception as exc:
                logger.warning("  torch.compile failed for %s: %s — skipping", label, exc)

    return model
@@ -403,6 +542,45 @@ def _install_generation_optimizations(model: object) -> None:
logger.info("Installed VibeVoice generation hot-path optimizations.")
def _install_cpu_pipeline_optimizations(model: object) -> None:
    """Patch generate() on this model instance so acoustic decode runs in a worker thread.

    Stock VibeVoice inner loop:
        decode(speech_latent) → append → put → connector → tts_lm(pos) → tts_lm(neg)

    The connector and both tts_lm calls depend only on speech_latent/acoustic_embed,
    never on audio_chunk, so the patched generate() uses this order instead:
        submit decode to thread → connector → tts_lm(pos) → tts_lm(neg)
            → wait for decode future → append → put

    The patch is bound to the instance with types.MethodType; it shadows the
    class-level generate() and therefore survives uv sync reinstalling the package.

    Does nothing (apart from logging) when VIBEPOD_ASYNC_DECODE is not "1" or when
    the vibevoice_generate_patch module cannot be imported. Otherwise a
    single-worker thread pool is created, stored in the module-level
    _decode_executor, and passed to the patch installer.
    """
    global _decode_executor

    async_decode_enabled = os.environ.get("VIBEPOD_ASYNC_DECODE", "1") == "1"
    if not async_decode_enabled:
        logger.info("CPU async decode disabled via VIBEPOD_ASYNC_DECODE=0.")
        return

    try:
        import vibevoice_generate_patch as generate_patch
    except ImportError:
        logger.warning(
            "vibevoice_generate_patch not found — async decode unavailable. "
            "Ensure vibevoice_generate_patch.py is in the server directory."
        )
        return

    # Publish the pool globally before installing so shutdown code can find it.
    decode_pool = concurrent.futures.ThreadPoolExecutor(
        thread_name_prefix="vibepod-decode", max_workers=1
    )
    _decode_executor = decode_pool
    generate_patch.install(model, decode_pool)

    logger.info(
        "CPU pipeline: patched generate() installed (async decode enabled) — "
        "acoustic_decode overlaps forward_tts_lm. Disable with VIBEPOD_ASYNC_DECODE=0."
    )
def _model_float_dtype() -> torch.dtype:
try:
return next(_model.parameters()).dtype
@@ -469,6 +647,20 @@ def _load_model_sync() -> None:
_config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 1.5 if is_cpu else 1.0)
_config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 4.0 if is_cpu else 3.0)
_config["default_inference_steps"] = _env_int("VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10)
if is_cpu:
logical_cpus = os.cpu_count() or 1
_config["cpu_threads"] = _env_int(
"VIBEPOD_CPU_THREADS",
max(1, logical_cpus // 2)
if platform.system() == "Windows"
else logical_cpus,
)
_config["cpu_interop_threads"] = _env_int(
"VIBEPOD_CPU_INTEROP_THREADS", 1
)
_config["cpu_mkldnn"] = os.environ.get(
"VIBEPOD_CPU_MKLDNN", "1"
).strip() != "0"
_processor = _init_processor()
_model = _init_model(_device)
@@ -494,6 +686,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
thread = threading.Thread(target=_load_model_sync, daemon=True, name="model-loader")
thread.start()
yield
if _decode_executor is not None:
_decode_executor.shutdown(wait=False)
app = FastAPI(title="VibePod TTS Server", version="0.1.0", lifespan=lifespan)