perf: CPU async pipeline overlap + INT8 quantization

Overlap acoustic_decode with forward_tts_lm calls using a background
ThreadPoolExecutor, hiding ~72s of decode cost behind tts_lm work.
Achieved 0.67x realtime (up from 0.43x, ~56% improvement).

- vibevoice_generate_patch.py: patched generate() loop reordered to
  submit decode to thread before running connector + tts_lm×2, then
  resolve future. Installed as instance method via types.MethodType so
  uv sync reinstalling the package cannot revert the patch.
- Dynamic INT8 quantization of Linear layers (VIBEPOD_QUANTIZE=1,
  default on CPU). prediction_head excluded — small fixed-size tensors
  regressed ~20% with INT8 due to pack/unpack overhead.
- Auto-detect AVX512_BF16 and load model in bfloat16 if supported
  (VIBEPOD_CPU_BF16=auto, overridable with 0/1).
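- CPU thread count auto-configured from the logical CPU count unless
  VIBEPOD_CPU_THREADS is set (the launcher script defaults to half the
  logical CPUs; the server defaults to half on Windows and all logical
  CPUs elsewhere). OMP/MKL env vars set accordingly. Lock file preserved
  around uv sync --no-sources so CPU mode does not alter the shared
  uv.lock.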
- torch.compile retained as opt-in (VIBEPOD_COMPILE=1) but marked not
  recommended — dynamic KV cache shapes prevent kernel reuse.
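
The reordering described in the first bullet can be sketched in isolation. This is a minimal illustration, not the patched generate() itself: `decode`, `connector`, `tts_lm`, and `generate_step` are hypothetical stand-ins for the real model calls, and the sleep durations only mimic their relative cost.

```python
import concurrent.futures
import time


def decode(latent):
    """Stand-in for the acoustic decode (the expensive, independent step)."""
    time.sleep(0.02)
    return f"audio({latent})"


def connector(latent):
    """Stand-in for the acoustic connector (cheap)."""
    return f"embed({latent})"


def tts_lm(embed, negative=False):
    """Stand-in for one forward_tts_lm call."""
    time.sleep(0.01)
    return ("neg:" if negative else "pos:") + embed


def generate_step(latent, executor=None):
    """One inner-loop step; overlaps decode with tts_lm when an executor is given."""
    if executor is not None:
        # Submit decode first so it runs while the main thread continues.
        future = executor.submit(decode, latent)
    else:
        chunk = decode(latent)

    embed = connector(latent)
    pos = tts_lm(embed)
    neg = tts_lm(embed, negative=True)

    if executor is not None:
        # Blocks only if decode has not finished yet.
        chunk = future.result()
    return chunk, pos, neg


with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
    overlapped = generate_step("z0", executor=pool)
sequential = generate_step("z0")
```

Both paths return the same values; only the wall-clock time differs. With the per-token figures quoted in the patch module's docstring (87 ms decode, 1 + 49 + 49 ms for the connector and the two tts_lm calls), the overlapped step is bounded by max(87, 99) = 99 ms instead of 186 ms.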
2026-04-30 20:46:29 +01:00
parent 75b84b211b
commit 7591d15a52
3 changed files with 685 additions and 2 deletions
+27 -1
@@ -79,7 +79,16 @@ echo ""
if $CPU_MODE; then
echo "--> Syncing CPU Python environment (.venv-cpu)..."
export UV_PROJECT_ENVIRONMENT=".venv-cpu"
LOCK_BACKUP=""
if [[ -f uv.lock ]]; then
LOCK_BACKUP="$(mktemp)"
cp uv.lock "$LOCK_BACKUP"
fi
uv sync --no-sources
if [[ -n "$LOCK_BACKUP" ]]; then
cp "$LOCK_BACKUP" uv.lock
rm -f "$LOCK_BACKUP"
fi
else
echo "--> Syncing CUDA Python environment (.venv)..."
uv sync
@@ -126,11 +135,28 @@ export PYTHONUTF8=1
if $CPU_MODE; then
export VIBEPOD_DEVICE="cpu"
export UV_PROJECT_ENVIRONMENT=".venv-cpu"
if [[ -z "${VIBEPOD_CPU_THREADS:-}" ]]; then
VIBEPOD_CPU_THREADS="$(uv run --no-sources python -c "import os; print(max(1, (os.cpu_count() or 2) // 2))")"
export VIBEPOD_CPU_THREADS
fi
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-$VIBEPOD_CPU_THREADS}"
export MKL_NUM_THREADS="${MKL_NUM_THREADS:-$VIBEPOD_CPU_THREADS}"
# Dynamic INT8 quantization — on by default for CPU (~22% faster, prediction_head
# excluded automatically to avoid regression on small fixed-size tensors).
# Set VIBEPOD_QUANTIZE=0 to disable if you notice audio quality differences.
export VIBEPOD_QUANTIZE="${VIBEPOD_QUANTIZE:-1}"
# Optional CPU flags:
# VIBEPOD_ASYNC_DECODE=0 Disable async decode/tts_lm overlap (on by default)
# VIBEPOD_CPU_BF16=1 Force bfloat16 weights (auto-detected via AVX512_BF16)
# VIBEPOD_COMPILE=1 torch.compile hot paths (ineffective for autoregressive
# models on CPU — not recommended, kept for experimentation)
UV_RUN_ARGS=(--no-sync --no-sources)
else
export VIBEPOD_DEVICE="cuda"
UV_RUN_ARGS=()
fi
exec uv run "${UV_RUN_ARGS[@]}" uvicorn vibevoice_server:app \
--host 127.0.0.1 \
--port 8000 \
--log-level info \
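
The environment defaulting in the CPU branch above can be restated as a small pure function, which makes the precedence easier to check. This is a sketch under assumptions: `resolve_thread_env` is a hypothetical helper that is not part of the commit, and it only mirrors the `${VAR:-default}` behaviour of the shell code (an unset or empty value falls back to the default).

```python
def resolve_thread_env(env, logical_cpus):
    """Return a copy of env with the launcher's CPU defaults filled in.

    Hypothetical helper: mirrors the shell logic, where an explicit value
    wins and an unset or empty value falls back to a computed default.
    """
    resolved = dict(env)

    # VIBEPOD_CPU_THREADS defaults to half the logical CPUs, never below 1.
    threads = resolved.get("VIBEPOD_CPU_THREADS") or str(
        max(1, (logical_cpus or 2) // 2)
    )
    resolved["VIBEPOD_CPU_THREADS"] = threads

    # OMP/MKL follow VIBEPOD_CPU_THREADS unless the user set them explicitly.
    for name in ("OMP_NUM_THREADS", "MKL_NUM_THREADS"):
        if not resolved.get(name):
            resolved[name] = threads

    # INT8 quantization is on by default in CPU mode.
    resolved["VIBEPOD_QUANTIZE"] = resolved.get("VIBEPOD_QUANTIZE") or "1"
    return resolved


defaults = resolve_thread_env({}, 16)
pinned = resolve_thread_env({"OMP_NUM_THREADS": "4", "VIBEPOD_QUANTIZE": "0"}, 16)
```

Here `defaults` ends up with 8 threads for all three variables and quantization enabled, while `pinned` keeps the user's `OMP_NUM_THREADS=4` and `VIBEPOD_QUANTIZE=0` and only fills in the remaining values.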
+463
@@ -0,0 +1,463 @@
"""
VibePod CPU pipeline optimisation — patched VibeVoice generate() loop.
WHY THIS FILE EXISTS
--------------------
The VibeVoice inner speech-generation loop runs:

    decode(speech_latent)                                # 87 ms, VAE decode to audio waveform
    audio_chunks.append(chunk)                           # store for final return value
    audio_streamer.put(chunk)                            # stream to client
    acoustic_connector(speech_latent) -> acoustic_embed  # 1 ms
    forward_tts_lm(acoustic_embed)                       # ~49 ms (positive)
    forward_tts_lm(acoustic_embed)                       # ~49 ms (negative CFG)

acoustic_connector and both forward_tts_lm calls depend only on speech_latent /
acoustic_embed; they are completely independent of the decoded audio waveform.
Running decode in a thread while connector + tts_lm run on the main thread hides
~87 ms of decode cost per token behind the ~99 ms of tts_lm work:

    Before:  87 + 1 + 49 + 49 = 186 ms / token
    After:   max(87, 1 + 49 + 49) = 99 ms / token  (~47 % reduction)

HOW IT WORKS
------------
At model load time, _install_cpu_pipeline_optimizations() in vibevoice_server.py:
1. Creates a single-worker ThreadPoolExecutor and attaches it to the model as
model._vibepod_decode_executor.
2. Installs this module's `patched_generate` as a bound method on the model
instance via types.MethodType, shadowing the class-level generate().
Because the patch lives on the *instance*, uv sync reinstalling the VibeVoice
package has no effect — Python resolves instance attributes before class ones.
MAINTENANCE
-----------
This is a verbatim copy of VibeVoice's generate() method (lines 574-910 of
modeling_vibevoice_streaming_inference.py) with the inner speech loop reordered.
The only changed region is marked with # [VibePod] comments.
If VibeVoice updates its generate() method, diff the new version against this
file and merge carefully. The sentinel string "[VibePod]" marks every changed
line to make diffing easy.
"""
import concurrent.futures
import types
from typing import Callable, List, Optional, Union
import torch
from tqdm import tqdm
from transformers.generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers.modeling_utils import PreTrainedModel
from vibevoice.modular.modeling_vibevoice_streaming_inference import (
TTS_TEXT_WINDOW_SIZE,
TTS_SPEECH_WINDOW_SIZE,
VibeVoiceGenerationOutput,
_update_model_kwargs_for_generation,
)
from vibevoice.modular.modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache
from vibevoice.modular.streamer import AudioStreamer, AsyncAudioStreamer
def patched_generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
speech_tensors: Optional[torch.FloatTensor] = None,
speech_masks: Optional[torch.BoolTensor] = None,
speech_input_mask: Optional[torch.BoolTensor] = None,
tts_text_ids: Optional[torch.LongTensor] = None,
return_speech: bool = True,
cfg_scale: float = 1.0,
stop_check_fn: Optional[Callable[[], bool]] = None,
**kwargs,
) -> Union[torch.LongTensor, VibeVoiceGenerationOutput]:
# ── Setup (unchanged from original) ─────────────────────────────────────
tokenizer = kwargs.pop("tokenizer", None)
neg_text_input_id = tokenizer.convert_tokens_to_ids("<|image_pad|>")
tts_lm_input_ids = kwargs.pop("tts_lm_input_ids", None)
tts_lm_attention_mask = kwargs.pop("tts_lm_attention_mask", None)
all_prefilled_outputs = kwargs.pop("all_prefilled_outputs", None)
tts_text_ids = tts_text_ids.to(self.device)
if kwargs.get("max_new_tokens", None) is None:
kwargs["max_new_tokens"] = (
self.config.decoder_config.max_position_embeddings - tts_lm_input_ids.shape[-1]
)
generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria = (
self._build_generate_config_model_kwargs(
generation_config, inputs, tokenizer, return_processors=True, **kwargs
)
)
negative_kwargs = {
"input_ids": torch.full(
(kwargs["input_ids"].shape[0], 1),
neg_text_input_id,
dtype=torch.long,
device=kwargs["input_ids"].device,
),
"attention_mask": torch.ones(
(kwargs["input_ids"].shape[0], 1),
dtype=torch.long,
device=kwargs["input_ids"].device,
),
"max_new_tokens": kwargs.get("max_new_tokens", 100),
}
negative_generation_config, negative_model_kwargs, negative_input_ids = (
self._build_generate_config_model_kwargs(
None, None, tokenizer, return_processors=False, **negative_kwargs
)
)
tts_lm_kwargs = {
"input_ids": tts_lm_input_ids,
"attention_mask": tts_lm_attention_mask,
"max_new_tokens": kwargs.get("max_new_tokens", 100),
}
tts_lm_generation_config, tts_lm_model_kwargs, tts_lm_input_ids = (
self._build_generate_config_model_kwargs(
None, None, tokenizer, return_processors=False, **tts_lm_kwargs
)
)
tts_lm_negative_kwargs = {
"input_ids": torch.full(
(kwargs["input_ids"].shape[0], 1),
neg_text_input_id,
dtype=torch.long,
device=kwargs["input_ids"].device,
),
"attention_mask": torch.ones(
(kwargs["input_ids"].shape[0], 1),
dtype=torch.long,
device=kwargs["input_ids"].device,
),
"max_new_tokens": kwargs.get("max_new_tokens", 100),
}
tts_lm_negative_generation_config, tts_lm_negative_model_kwargs, tts_lm_negative_input_ids = (
self._build_generate_config_model_kwargs(
None, None, tokenizer, return_processors=False, **tts_lm_negative_kwargs
)
)
acoustic_cache = VibeVoiceTokenizerStreamingCache()
batch_size = input_ids.shape[0]
assert batch_size == 1, "Currently only supports batch size == 1"
device = input_ids.device
finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device)
verbose = kwargs.get("verbose", False)
audio_chunks = [[] for _ in range(batch_size)]
tts_text_window_index = 0
reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device)
first_text_window_size = (
TTS_TEXT_WINDOW_SIZE
if tts_text_ids.shape[1] >= TTS_TEXT_WINDOW_SIZE
else tts_text_ids.shape[1]
)
outputs = all_prefilled_outputs["lm"]
tts_lm_outputs = all_prefilled_outputs["tts_lm"]
negative_outputs = all_prefilled_outputs["neg_lm"]
tts_lm_negative_outputs = all_prefilled_outputs["neg_tts_lm"]
model_kwargs = _update_model_kwargs_for_generation(
outputs, model_kwargs, num_new_tokens=first_text_window_size
)
tts_lm_model_kwargs = _update_model_kwargs_for_generation(
tts_lm_outputs, tts_lm_model_kwargs, num_new_tokens=first_text_window_size
)
negative_model_kwargs = self._update_model_kwargs_for_generation(
negative_outputs, negative_model_kwargs, is_encoder_decoder=False
)
tts_lm_negative_model_kwargs = self._update_model_kwargs_for_generation(
tts_lm_negative_outputs, tts_lm_negative_model_kwargs, is_encoder_decoder=False
)
step = tts_lm_input_ids.shape[1]
total_generated_speech_tokens = 0
total_prefilled_text_tokens = 0
if kwargs.get("show_progress_bar", True):
progress_bar = tqdm(
total=tts_lm_generation_config.max_length,
desc=f"Prefilled {step} tokens, current step ({step} / {tts_lm_generation_config.max_length})",
initial=step,
leave=False,
)
else:
progress_bar = None
# [VibePod] Grab the executor once; None means standard sequential path.
_vp_executor: Optional[concurrent.futures.ThreadPoolExecutor] = getattr(
self, "_vibepod_decode_executor", None
)
# ── Main generation loop (unchanged from original) ───────────────────────
while True:
if stop_check_fn is not None and stop_check_fn():
if verbose:
print(f"Generation stopped externally at step {step + 1}")
if audio_streamer is not None:
audio_streamer.end()
break
if finished_tags.all():
if hasattr(progress_bar, "set_description"):
progress_bar.set_description("Generation complete")
break
cur_input_tts_text_ids = tts_text_ids[
:,
tts_text_window_index * TTS_TEXT_WINDOW_SIZE : (tts_text_window_index + 1)
* TTS_TEXT_WINDOW_SIZE,
]
next_text_window_size = tts_text_ids[
:,
(tts_text_window_index + 1)
* TTS_TEXT_WINDOW_SIZE : (tts_text_window_index + 2)
* TTS_TEXT_WINDOW_SIZE,
].shape[1]
tts_text_window_index += 1
if cur_input_tts_text_ids.shape[1] > 0:
input_ids = torch.cat([input_ids, cur_input_tts_text_ids], dim=-1)
tts_lm_input_ids = torch.cat([tts_lm_input_ids, cur_input_tts_text_ids], dim=-1)
if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length:
if verbose:
print(
f"Reached maximum generation length {generation_config.max_length}, stopped it."
)
reached_samples = torch.arange(batch_size, device=device)[~finished_tags]
if reached_samples.numel() > 0:
reach_max_step_sample[reached_samples] = True
break
step += cur_input_tts_text_ids.shape[1]
total_prefilled_text_tokens += cur_input_tts_text_ids.shape[1]
if progress_bar is not None:
progress_bar.update(cur_input_tts_text_ids.shape[1])
progress_bar.set_description(
f"Prefilled {total_prefilled_text_tokens} text tokens, "
f"generated {total_generated_speech_tokens} speech tokens, "
f"current step ({step} / {tts_lm_generation_config.max_length})"
)
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
outputs = self.forward_lm(
**model_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
model_kwargs = _update_model_kwargs_for_generation(
outputs, model_kwargs, num_new_tokens=next_text_window_size
)
tts_lm_model_inputs = self.prepare_inputs_for_generation(
tts_lm_input_ids, **tts_lm_model_kwargs
)
tts_lm_additional_inputs = {
"tts_text_masks": torch.ones_like(tts_lm_input_ids[:, -1:]),
"lm_last_hidden_state": outputs.last_hidden_state,
}
tts_lm_outputs = self.forward_tts_lm(
**tts_lm_model_inputs,
**tts_lm_additional_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
tts_lm_model_kwargs = self._update_model_kwargs_for_generation(
tts_lm_outputs, tts_lm_model_kwargs, is_encoder_decoder=False
)
diffusion_indices = torch.LongTensor([0])
# ── Inner speech loop ────────────────────────────────────────────────
for cur_speech_index in range(TTS_SPEECH_WINDOW_SIZE):
positive_condition = tts_lm_outputs.last_hidden_state[diffusion_indices, -1, :]
negative_condition = tts_lm_negative_outputs.last_hidden_state[
diffusion_indices, -1, :
]
speech_latent = self.sample_speech_tokens(
positive_condition,
negative_condition,
cfg_scale=cfg_scale,
).unsqueeze(1)
scaled_latent = (
speech_latent / self.model.speech_scaling_factor.to(speech_latent.device)
- self.model.speech_bias_factor.to(speech_latent.device)
)
# [VibePod] If a decode executor is configured, submit decode to a
# background thread so acoustic_connector and forward_tts_lm can run
# concurrently on the main thread. The future is resolved after both
# tts_lm calls complete, before appending/streaming the audio chunk.
# Without the executor, the original sequential path is used unchanged.
if _vp_executor is not None:
_decode_future: concurrent.futures.Future[torch.Tensor] = _vp_executor.submit(
self.model.acoustic_tokenizer.decode,
scaled_latent.to(self.model.acoustic_tokenizer.device),
cache=acoustic_cache,
sample_indices=diffusion_indices.to(self.model.acoustic_tokenizer.device),
use_cache=True,
debug=False,
)
else:
audio_chunk = self.model.acoustic_tokenizer.decode(
scaled_latent.to(self.model.acoustic_tokenizer.device),
cache=acoustic_cache,
sample_indices=diffusion_indices.to(self.model.acoustic_tokenizer.device),
use_cache=True,
debug=False,
)
# [VibePod] connector + tts_lm run here while decode is in the thread.
acoustic_embed = self.model.acoustic_connector(speech_latent)
tts_lm_input_ids = torch.cat(
[tts_lm_input_ids, torch.ones_like(tts_lm_input_ids[:, -1:])], dim=-1
)
if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length:
# [VibePod] Resolve before break so audio_chunks stays consistent.
if _vp_executor is not None:
audio_chunk = _decode_future.result()
for i, sample_idx in enumerate(diffusion_indices):
idx = sample_idx.item()
if not finished_tags[idx]:
audio_chunks[idx].append(audio_chunk[i])
if audio_streamer is not None:
audio_streamer.put(audio_chunk, diffusion_indices)
break
step += 1
total_generated_speech_tokens += 1
if progress_bar is not None:
progress_bar.update(1)
progress_bar.set_description(
f"Prefilled {total_prefilled_text_tokens} text tokens, "
f"generated {total_generated_speech_tokens} speech tokens, "
f"current step ({step} / {tts_lm_generation_config.max_length})"
)
tts_lm_model_inputs = self.prepare_inputs_for_generation(
tts_lm_input_ids, **tts_lm_model_kwargs
)
tts_lm_additional_inputs = {
"tts_text_masks": torch.zeros_like(tts_lm_input_ids[:, -1:]),
"lm_last_hidden_state": acoustic_embed,
}
tts_lm_outputs = self.forward_tts_lm(
**tts_lm_model_inputs,
**tts_lm_additional_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
if cur_speech_index == TTS_SPEECH_WINDOW_SIZE - 1 and next_text_window_size > 0:
tts_lm_model_kwargs = _update_model_kwargs_for_generation(
tts_lm_outputs, tts_lm_model_kwargs, num_new_tokens=next_text_window_size
)
else:
tts_lm_model_kwargs = self._update_model_kwargs_for_generation(
tts_lm_outputs, tts_lm_model_kwargs, is_encoder_decoder=False
)
tts_lm_negative_input_ids = torch.cat(
[tts_lm_negative_input_ids, torch.ones_like(tts_lm_input_ids[:, -1:])], dim=-1
)
tts_lm_negative_model_inputs = self.prepare_inputs_for_generation(
tts_lm_negative_input_ids, **tts_lm_negative_model_kwargs
)
tts_lm_negative_additional_inputs = {
"tts_text_masks": torch.zeros_like(tts_lm_negative_input_ids[:, -1:]),
"lm_last_hidden_state": acoustic_embed,
}
tts_lm_negative_outputs = self.forward_tts_lm(
**tts_lm_negative_model_inputs,
**tts_lm_negative_additional_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
tts_lm_negative_model_kwargs = self._update_model_kwargs_for_generation(
tts_lm_negative_outputs,
tts_lm_negative_model_kwargs,
is_encoder_decoder=False,
)
# [VibePod] Block until the background decode finishes (no-op on the
# sequential path), then append + stream. Moved here from before
# connector/tts_lm.
if _vp_executor is not None:
audio_chunk = _decode_future.result()
for i, sample_idx in enumerate(diffusion_indices):
idx = sample_idx.item()
if not finished_tags[idx]:
audio_chunks[idx].append(audio_chunk[i])
if audio_streamer is not None:
audio_streamer.put(audio_chunk, diffusion_indices)
tts_eos_logits = torch.sigmoid(
self.tts_eos_classifier(
tts_lm_outputs.last_hidden_state[diffusion_indices, -1, :]
)
)
if tts_eos_logits[0].item() > 0.5:
finished_tags[diffusion_indices] = True
if audio_streamer is not None:
audio_streamer.end(diffusion_indices)
if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length:
if verbose:
print(
f"Reached maximum generation length {tts_lm_generation_config.max_length}, stopped it."
)
reached_samples = torch.arange(batch_size, device=device)[~finished_tags]
if reached_samples.numel() > 0:
reach_max_step_sample[reached_samples] = True
break
if audio_streamer is not None:
audio_streamer.end()
# ── Audio finalisation (unchanged from original) ─────────────────────────
final_audio_outputs = []
for sample_chunks in audio_chunks:
if sample_chunks:
concatenated_audio = torch.cat(sample_chunks, dim=-1)
final_audio_outputs.append(concatenated_audio)
else:
final_audio_outputs.append(None)
if reach_max_step_sample is not None and reach_max_step_sample.any():
print(
f"Reached maximum generation length {tts_lm_generation_config.max_length}, stopped it."
)
return VibeVoiceGenerationOutput(
sequences=tts_lm_input_ids,
speech_outputs=final_audio_outputs if return_speech else None,
reach_max_step_sample=reach_max_step_sample,
)
def install(model: object, executor: concurrent.futures.ThreadPoolExecutor) -> None:
"""Install the patched generate() on a model instance and attach the executor."""
model._vibepod_decode_executor = executor
model.generate = types.MethodType(patched_generate, model)
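
The instance-level patching used by install() relies on ordinary Python attribute lookup: for a plain method, an attribute stored on the instance is found before the one defined on the class. A minimal sketch of that behaviour, using a hypothetical `Model` class rather than the real VibeVoice model:

```python
import types


class Model:
    """Hypothetical stand-in for the model class shipped by the package."""

    def generate(self, text):
        return f"original:{text}"


def patched_generate(self, text):
    """Replacement that will be bound to a single instance."""
    return f"patched:{text}"


patched = Model()
# Bind the function to this instance; the class itself is left untouched.
patched.generate = types.MethodType(patched_generate, patched)

untouched = Model()
```

Only the patched instance sees the new behaviour: `patched.generate("hi")` returns `"patched:hi"`, while `untouched.generate("hi")` and `Model.generate(patched, "hi")` still return `"original:hi"`. Since the binding lives on the instance, it has to be installed again for every model object that is created, which is why the server does it at model load time.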
+195 -1
@@ -20,12 +20,14 @@ Device selection:
import asyncio
import base64
import concurrent.futures
import copy
import functools
import importlib.util
import json
import logging
import os
import platform
import threading
import time
import types
@@ -64,6 +66,10 @@ DEFAULT_SPEAKER = "carter"
_IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.ot"]
# ── Decode pipeline executor ────────────────────────────────────────────────────
_decode_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None
# ── Device selection ────────────────────────────────────────────────────────────
# VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag.
# Falls back to auto-detection if not set.
@@ -108,6 +114,40 @@ def _env_float(name: str, default: float) -> float:
return default
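def _cpu_supports_bf16() -> bool:
    """Return True if the CPU reports AVX512_BF16 hardware support."""
    cpu_mod = getattr(torch, "cpu", None)
    # Probe both spellings: some PyTorch releases expose this check only as a
    # private helper (_is_avx512_bf16_supported), in which case a lookup of the
    # public name alone would silently report "unsupported".
    for name in ("is_avx512_bf16_supported", "_is_avx512_bf16_supported"):
        probe = getattr(cpu_mod, name, None)
        if callable(probe):
            return bool(probe())
    return False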
def _configure_cpu_runtime() -> dict[str, object]:
logical_cpus = os.cpu_count() or 1
default_threads = (
max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus
)
intra_threads = _env_int("VIBEPOD_CPU_THREADS", default_threads)
interop_threads = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1)
mkldnn_enabled = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0"
torch.set_num_threads(max(1, intra_threads))
try:
torch.set_num_interop_threads(max(1, interop_threads))
except RuntimeError as exc:
logger.warning("Could not set CPU inter-op threads: %s", exc)
torch.backends.mkldnn.enabled = mkldnn_enabled
return {
"logical_cpus": logical_cpus,
"threads": torch.get_num_threads(),
"interop_threads": torch.get_num_interop_threads(),
"mkldnn_available": torch.backends.mkldnn.is_available(),
"mkldnn_enabled": torch.backends.mkldnn.enabled,
}
# ── Global state ────────────────────────────────────────────────────────────────
ModelStatus = Literal["downloading", "loading", "online", "error"]
@@ -228,12 +268,29 @@ def _init_model(device: str):
torch.backends.cuda.mem_efficient_sdp_enabled(),
torch.backends.cuda.math_sdp_enabled(),
)
elif device == "cpu":
torch.set_float32_matmul_precision("medium")
logger.info("CPU runtime configuration: %s", _configure_cpu_runtime())
cuda_dtype = os.environ.get("VIBEPOD_CUDA_DTYPE", "bf16").lower()
if device == "cuda" and cuda_dtype == "fp16":
load_dtype = torch.float16
elif device == "cuda":
load_dtype = torch.bfloat16
else:
cpu_bf16_env = os.environ.get("VIBEPOD_CPU_BF16", "auto").lower()
if cpu_bf16_env == "1":
load_dtype = torch.bfloat16
logger.info("CPU BF16 forced via VIBEPOD_CPU_BF16=1")
elif cpu_bf16_env == "0":
load_dtype = torch.float32
logger.info("CPU float32 forced via VIBEPOD_CPU_BF16=0")
elif _cpu_supports_bf16():
load_dtype = torch.bfloat16
logger.info("AVX512_BF16 detected — loading model in bfloat16")
else:
load_dtype = torch.float32
logger.info("No AVX512_BF16 — using float32 (set VIBEPOD_CPU_BF16=1 to override)")
logger.info("Loading model weights with dtype %s", load_dtype)
requested_attn_impl = os.environ.get("VIBEPOD_ATTN_IMPL", "auto").lower()
has_flash_attn = importlib.util.find_spec("flash_attn") is not None
@@ -274,8 +331,90 @@ def _init_model(device: str):
)
model.eval()
if device == "cpu":
model = _apply_cpu_optimizations(model)
model.set_ddpm_inference_steps(num_steps=_config["default_inference_steps"])
_install_generation_optimizations(model)
if device == "cpu":
# Must run after _install_generation_optimizations so the async wrapper
# sits outside the profiling wrapper (VibeVoice calls async → profiling → real decode).
_install_cpu_pipeline_optimizations(model)
return model
def _apply_cpu_optimizations(model: object) -> object:
"""Apply optional post-load CPU optimizations. Returns (possibly new) model object."""
do_quantize = os.environ.get("VIBEPOD_QUANTIZE", "0") == "1"
do_compile = os.environ.get("VIBEPOD_COMPILE", "0") == "1"
    if do_quantize:
        logger.info("Applying dynamic INT8 quantization to Linear layers...")
        # The diffusion prediction_head operates on small fixed-size tensors where
        # INT8 pack/unpack overhead exceeds the matmul savings (~20% slower in
        # testing). Detach it before quantizing and reattach it afterwards so it
        # stays unquantized.
        inner = getattr(model, "model", None)
        saved_prediction_head = getattr(inner, "prediction_head", None)
        try:
            import torch.ao.quantization

            if saved_prediction_head is not None:
                del inner.prediction_head
            model = torch.ao.quantization.quantize_dynamic(
                model, {torch.nn.Linear}, dtype=torch.qint8
            )
            if saved_prediction_head is not None:
                logger.info(
                    "Dynamic INT8 quantization applied (prediction_head excluded, stays unquantized)."
                )
            else:
                logger.info("Dynamic INT8 quantization applied.")
        except Exception as exc:
            logger.warning("Dynamic quantization failed: %s; skipping", exc)
        finally:
            # Reattach on success and on failure, so a failed quantization pass
            # does not leave the model without its prediction_head.
            if saved_prediction_head is not None:
                model.model.prediction_head = saved_prediction_head
if do_compile:
# torch.compile with inductor on CPU is ineffective for autoregressive TTS:
# each token step produces a unique input shape, so every step triggers a new
# kernel compile event rather than reusing compiled code. Kept as an escape
# hatch but not recommended.
compile_mode = os.environ.get("VIBEPOD_COMPILE_MODE", "reduce-overhead")
logger.info(
"torch.compile enabled (mode=%s) — NOTE: limited benefit for autoregressive"
" models on CPU due to dynamic sequence lengths.",
compile_mode,
)
_compile_targets: list[tuple[str, object, str, bool]] = [
("forward_tts_lm", model, "forward_tts_lm", True),
]
if hasattr(model, "model"):
inner = model.model
if hasattr(inner, "prediction_head"):
_compile_targets.append(
("prediction_head", inner, "prediction_head", False)
)
if hasattr(inner, "acoustic_tokenizer") and hasattr(
inner.acoustic_tokenizer, "decode"
):
_compile_targets.append(
("acoustic_tokenizer.decode", inner.acoustic_tokenizer, "decode", False)
)
for label, obj, attr, dynamic in _compile_targets:
try:
compiled = torch.compile(
getattr(obj, attr),
backend="inductor",
mode=compile_mode,
dynamic=dynamic,
)
setattr(obj, attr, compiled)
logger.info(" compiled: %s", label)
except Exception as exc:
logger.warning(" torch.compile failed for %s: %s — skipping", label, exc)
return model
@@ -403,6 +542,45 @@ def _install_generation_optimizations(model: object) -> None:
logger.info("Installed VibeVoice generation hot-path optimizations.")
def _install_cpu_pipeline_optimizations(model: object) -> None:
"""Install the async-decode generate() patch and its thread pool on the model instance.
The VibeVoice inner loop runs:
decode(speech_latent) → append → put → connector → tts_lm(pos) → tts_lm(neg)
connector and both tts_lm calls only need speech_latent/acoustic_embed, not
audio_chunk. The patched generate() reorders this to:
submit decode to thread → connector → tts_lm(pos) → tts_lm(neg)
→ wait for decode future → append → put
The patch is applied as an instance method via types.MethodType, which shadows
the class-level generate() and is immune to uv sync reinstalling the package.
"""
global _decode_executor
if os.environ.get("VIBEPOD_ASYNC_DECODE", "1") != "1":
logger.info("CPU async decode disabled via VIBEPOD_ASYNC_DECODE=0.")
return
try:
import vibevoice_generate_patch
except ImportError:
logger.warning(
"vibevoice_generate_patch not found — async decode unavailable. "
"Ensure vibevoice_generate_patch.py is in the server directory."
)
return
_decode_executor = concurrent.futures.ThreadPoolExecutor(
max_workers=1, thread_name_prefix="vibepod-decode"
)
vibevoice_generate_patch.install(model, _decode_executor)
logger.info(
"CPU pipeline: patched generate() installed (async decode enabled) — "
"acoustic_decode overlaps forward_tts_lm. Disable with VIBEPOD_ASYNC_DECODE=0."
)
def _model_float_dtype() -> torch.dtype:
try:
return next(_model.parameters()).dtype
@@ -469,6 +647,20 @@ def _load_model_sync() -> None:
_config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 1.5 if is_cpu else 1.0)
_config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 4.0 if is_cpu else 3.0)
_config["default_inference_steps"] = _env_int("VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10)
if is_cpu:
logical_cpus = os.cpu_count() or 1
_config["cpu_threads"] = _env_int(
"VIBEPOD_CPU_THREADS",
max(1, logical_cpus // 2)
if platform.system() == "Windows"
else logical_cpus,
)
_config["cpu_interop_threads"] = _env_int(
"VIBEPOD_CPU_INTEROP_THREADS", 1
)
_config["cpu_mkldnn"] = os.environ.get(
"VIBEPOD_CPU_MKLDNN", "1"
).strip() != "0"
_processor = _init_processor()
_model = _init_model(_device)
@@ -494,6 +686,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
thread = threading.Thread(target=_load_model_sync, daemon=True, name="model-loader")
thread.start()
yield
if _decode_executor is not None:
_decode_executor.shutdown(wait=False)
app = FastAPI(title="VibePod TTS Server", version="0.1.0", lifespan=lifespan)
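
The executor lifecycle (one single-worker pool, shut down when the application exits) can be sketched without FastAPI. This is a simplified illustration under assumptions: `lifespan_sketch` and the `state` dict are hypothetical, and the pool is created inside the context manager here, whereas the server creates it during model loading and only shuts it down in lifespan().

```python
import asyncio
import concurrent.futures
from contextlib import asynccontextmanager


@asynccontextmanager
async def lifespan_sketch(state):
    """Create a single-worker decode pool and release it on exit."""
    state["executor"] = concurrent.futures.ThreadPoolExecutor(
        max_workers=1, thread_name_prefix="decode-sketch"
    )
    try:
        yield state
    finally:
        # wait=False returns immediately; work already running still completes.
        state["executor"].shutdown(wait=False)


async def main():
    state = {}
    async with lifespan_sketch(state):
        result = state["executor"].submit(sum, [1, 2, 3]).result()

    # After shutdown the pool refuses new work.
    try:
        state["executor"].submit(sum, [1])
        rejected = False
    except RuntimeError:
        rejected = True
    return result, rejected


outcome = asyncio.run(main())
```

A single worker keeps decode calls strictly ordered, which matters when they share a streaming cache as the patched loop does. Wrapping the shutdown in `try`/`finally` is a choice made for this sketch so the pool is also released if the body raises.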