From 7591d15a5221b417a57a33e9bf7956e72b0628ac Mon Sep 17 00:00:00 2001 From: LyAhn Date: Thu, 30 Apr 2026 20:46:29 +0100 Subject: [PATCH] perf: CPU async pipeline overlap + INT8 quantization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Overlap acoustic_decode with forward_tts_lm calls using a background ThreadPoolExecutor, hiding ~72s of decode cost behind tts_lm work. Achieved 0.67x realtime (up from 0.43x, ~56% improvement). - vibevoice_generate_patch.py: patched generate() loop reordered to submit decode to thread before running connector + tts_lm×2, then resolve future. Installed as instance method via types.MethodType so uv sync reinstalling the package cannot revert the patch. - Dynamic INT8 quantization of Linear layers (VIBEPOD_QUANTIZE=1, default on CPU). prediction_head excluded — small fixed-size tensors regressed ~20% with INT8 due to pack/unpack overhead. - Auto-detect AVX512_BF16 and load model in bfloat16 if supported (VIBEPOD_CPU_BF16=auto, overridable with 0/1). - CPU thread count auto-configured from logical CPU count; OMP/MKL env vars set accordingly. Lock file preserved around uv sync --no-sources so CPU mode does not alter the shared uv.lock. - torch.compile retained as opt-in (VIBEPOD_COMPILE=1) but marked not recommended — dynamic KV cache shapes prevent kernel reuse. --- server/start.sh | 28 +- server/vibevoice_generate_patch.py | 463 +++++++++++++++++++++++++++++ server/vibevoice_server.py | 196 +++++++++++- 3 files changed, 685 insertions(+), 2 deletions(-) create mode 100644 server/vibevoice_generate_patch.py diff --git a/server/start.sh b/server/start.sh index 2daa340..995fbc1 100755 --- a/server/start.sh +++ b/server/start.sh @@ -79,7 +79,16 @@ echo "" if $CPU_MODE; then echo "--> Syncing CPU Python environment (.venv-cpu)..." 
export UV_PROJECT_ENVIRONMENT=".venv-cpu" + LOCK_BACKUP="" + if [[ -f uv.lock ]]; then + LOCK_BACKUP="$(mktemp)" + cp uv.lock "$LOCK_BACKUP" + fi uv sync --no-sources + if [[ -n "$LOCK_BACKUP" ]]; then + cp "$LOCK_BACKUP" uv.lock + rm -f "$LOCK_BACKUP" + fi else echo "--> Syncing CUDA Python environment (.venv)..." uv sync @@ -126,11 +135,28 @@ export PYTHONUTF8=1 if $CPU_MODE; then export VIBEPOD_DEVICE="cpu" export UV_PROJECT_ENVIRONMENT=".venv-cpu" + if [[ -z "${VIBEPOD_CPU_THREADS:-}" ]]; then + VIBEPOD_CPU_THREADS="$(uv run --no-sources python -c "import os; print(max(1, (os.cpu_count() or 2) // 2))")" + export VIBEPOD_CPU_THREADS + fi + export OMP_NUM_THREADS="${OMP_NUM_THREADS:-$VIBEPOD_CPU_THREADS}" + export MKL_NUM_THREADS="${MKL_NUM_THREADS:-$VIBEPOD_CPU_THREADS}" + # Dynamic INT8 quantization — on by default for CPU (~22% faster, prediction_head + # excluded automatically to avoid regression on small fixed-size tensors). + # Set VIBEPOD_QUANTIZE=0 to disable if you notice audio quality differences. + export VIBEPOD_QUANTIZE="${VIBEPOD_QUANTIZE:-1}" + # Optional CPU flags: + # VIBEPOD_ASYNC_DECODE=0 Disable async decode/tts_lm overlap (on by default) + # VIBEPOD_CPU_BF16=1 Force bfloat16 weights (auto-detected via AVX512_BF16) + # VIBEPOD_COMPILE=1 torch.compile hot paths (ineffective for autoregressive + # models on CPU — not recommended, kept for experimentation) + UV_RUN_ARGS=(--no-sync --no-sources) else export VIBEPOD_DEVICE="cuda" + UV_RUN_ARGS=() fi -exec uv run uvicorn vibevoice_server:app \ +exec uv run "${UV_RUN_ARGS[@]}" uvicorn vibevoice_server:app \ --host 127.0.0.1 \ --port 8000 \ --log-level info \ diff --git a/server/vibevoice_generate_patch.py b/server/vibevoice_generate_patch.py new file mode 100644 index 0000000..825577d --- /dev/null +++ b/server/vibevoice_generate_patch.py @@ -0,0 +1,463 @@ +""" +VibePod CPU pipeline optimisation — patched VibeVoice generate() loop. 
+ +WHY THIS FILE EXISTS +-------------------- +The VibeVoice inner speech-generation loop runs: + + decode(speech_latent) # 87 ms — VAE decode to audio waveform + audio_chunks.append(chunk) # store for final return value + audio_streamer.put(chunk) # stream to client + acoustic_connector(speech_latent) -> acoustic_embed # 1 ms + forward_tts_lm(acoustic_embed) # ~49 ms (positive) + forward_tts_lm(acoustic_embed) # ~49 ms (negative CFG) + +acoustic_connector and both forward_tts_lm calls depend only on speech_latent / +acoustic_embed — they are completely independent of the decoded audio waveform. +Running decode in a thread while connector + tts_lm run on the main thread hides +~87 ms of decode cost per token behind the ~99 ms of tts_lm work: + + Before: 87 + 1 + 49 + 49 = 186 ms / token + After: max(87, 1 + 49 + 49) = 99 ms / token (~47 % reduction) + +HOW IT WORKS +------------ +At model load time, _install_cpu_pipeline_optimizations() in vibevoice_server.py: + 1. Creates a single-worker ThreadPoolExecutor and attaches it to the model as + model._vibepod_decode_executor. + 2. Installs this module's `patched_generate` as a bound method on the model + instance via types.MethodType, shadowing the class-level generate(). + +Because the patch lives on the *instance*, uv sync reinstalling the VibeVoice +package has no effect — Python resolves instance attributes before class ones. + +MAINTENANCE +----------- +This is a verbatim copy of VibeVoice's generate() method (lines 574–910 of +modeling_vibevoice_streaming_inference.py) with the inner speech loop reordered. +The only changed region is marked with # [VibePod] comments. + +If VibeVoice updates its generate() method, diff the new version against this +file and merge carefully. The sentinel string "[VibePod]" marks every changed +line to make diffing easy. 
# [VibePod] NOTE(review): upstream generate() is presumed to be wrapped in
# @torch.no_grad() (confirm against modeling_vibevoice_streaming_inference.py).
# Installing this function on the instance bypasses any class-level decorator,
# so gradient tracking is disabled explicitly here; this is harmless for inference.
@torch.no_grad()
def patched_generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None,
    negative_prompt_ids: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    speech_tensors: Optional[torch.FloatTensor] = None,
    speech_masks: Optional[torch.BoolTensor] = None,
    speech_input_mask: Optional[torch.BoolTensor] = None,
    tts_text_ids: Optional[torch.LongTensor] = None,
    return_speech: bool = True,
    cfg_scale: float = 1.0,
    stop_check_fn: Optional[Callable[[], bool]] = None,
    **kwargs,
) -> Union[torch.LongTensor, VibeVoiceGenerationOutput]:
    """Streaming speech generation with the acoustic decode overlapped with tts_lm.

    Drop-in replacement for the model's ``generate()``. When the instance carries
    a ``_vibepod_decode_executor`` attribute, each speech latent is decoded on that
    executor while ``acoustic_connector`` and the two ``forward_tts_lm`` calls run on
    the calling thread; the decoded chunk is appended/streamed once both are done.
    Without the attribute the decode runs inline, in the same position as the
    executor submit.

    Required ``kwargs`` (popped before the rest is forwarded to
    ``_build_generate_config_model_kwargs``): ``tokenizer``, ``tts_lm_input_ids``,
    ``tts_lm_attention_mask``, ``all_prefilled_outputs`` (a mapping with the keys
    ``"lm"``, ``"tts_lm"``, ``"neg_lm"`` and ``"neg_tts_lm"``). ``kwargs`` must also
    contain ``input_ids``. Optional ``kwargs`` read here: ``max_new_tokens``,
    ``verbose``, ``show_progress_bar``.

    Args:
        tts_text_ids: Text token ids consumed ``TTS_TEXT_WINDOW_SIZE`` at a time.
        audio_streamer: If given, receives every decoded chunk via ``put`` and is
            closed via ``end``.
        return_speech: When False, ``speech_outputs`` in the result is None.
        cfg_scale: Forwarded to ``sample_speech_tokens``.
        stop_check_fn: Polled once per outer iteration; a truthy result stops
            generation.

    Returns:
        ``VibeVoiceGenerationOutput`` with ``sequences``, ``speech_outputs`` (one
        concatenated tensor, or None, per batch element) and
        ``reach_max_step_sample``.

    Raises:
        AssertionError: If the batch size is not 1.
    """
    # ── Setup (unchanged from original) ─────────────────────────────────────
    tokenizer = kwargs.pop("tokenizer", None)
    neg_text_input_id = tokenizer.convert_tokens_to_ids("<|image_pad|>")

    tts_lm_input_ids = kwargs.pop("tts_lm_input_ids", None)
    tts_lm_attention_mask = kwargs.pop("tts_lm_attention_mask", None)
    all_prefilled_outputs = kwargs.pop("all_prefilled_outputs", None)
    tts_text_ids = tts_text_ids.to(self.device)

    if kwargs.get("max_new_tokens", None) is None:
        kwargs["max_new_tokens"] = (
            self.config.decoder_config.max_position_embeddings - tts_lm_input_ids.shape[-1]
        )

    generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria = (
        self._build_generate_config_model_kwargs(
            generation_config, inputs, tokenizer, return_processors=True, **kwargs
        )
    )

    negative_kwargs = {
        "input_ids": torch.full(
            (kwargs["input_ids"].shape[0], 1),
            neg_text_input_id,
            dtype=torch.long,
            device=kwargs["input_ids"].device,
        ),
        "attention_mask": torch.ones(
            (kwargs["input_ids"].shape[0], 1),
            dtype=torch.long,
            device=kwargs["input_ids"].device,
        ),
        "max_new_tokens": kwargs.get("max_new_tokens", 100),
    }
    negative_generation_config, negative_model_kwargs, negative_input_ids = (
        self._build_generate_config_model_kwargs(
            None, None, tokenizer, return_processors=False, **negative_kwargs
        )
    )

    tts_lm_kwargs = {
        "input_ids": tts_lm_input_ids,
        "attention_mask": tts_lm_attention_mask,
        "max_new_tokens": kwargs.get("max_new_tokens", 100),
    }
    tts_lm_generation_config, tts_lm_model_kwargs, tts_lm_input_ids = (
        self._build_generate_config_model_kwargs(
            None, None, tokenizer, return_processors=False, **tts_lm_kwargs
        )
    )

    tts_lm_negative_kwargs = {
        "input_ids": torch.full(
            (kwargs["input_ids"].shape[0], 1),
            neg_text_input_id,
            dtype=torch.long,
            device=kwargs["input_ids"].device,
        ),
        "attention_mask": torch.ones(
            (kwargs["input_ids"].shape[0], 1),
            dtype=torch.long,
            device=kwargs["input_ids"].device,
        ),
        "max_new_tokens": kwargs.get("max_new_tokens", 100),
    }
    tts_lm_negative_generation_config, tts_lm_negative_model_kwargs, tts_lm_negative_input_ids = (
        self._build_generate_config_model_kwargs(
            None, None, tokenizer, return_processors=False, **tts_lm_negative_kwargs
        )
    )

    acoustic_cache = VibeVoiceTokenizerStreamingCache()
    batch_size = input_ids.shape[0]
    assert batch_size == 1, "Currently only supports batch size == 1"
    device = input_ids.device
    finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device)
    verbose = kwargs.get("verbose", False)

    audio_chunks = [[] for _ in range(batch_size)]
    tts_text_window_index = 0
    reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device)
    first_text_window_size = (
        TTS_TEXT_WINDOW_SIZE
        if tts_text_ids.shape[1] >= TTS_TEXT_WINDOW_SIZE
        else tts_text_ids.shape[1]
    )

    outputs = all_prefilled_outputs["lm"]
    tts_lm_outputs = all_prefilled_outputs["tts_lm"]
    negative_outputs = all_prefilled_outputs["neg_lm"]
    tts_lm_negative_outputs = all_prefilled_outputs["neg_tts_lm"]

    model_kwargs = _update_model_kwargs_for_generation(
        outputs, model_kwargs, num_new_tokens=first_text_window_size
    )
    tts_lm_model_kwargs = _update_model_kwargs_for_generation(
        tts_lm_outputs, tts_lm_model_kwargs, num_new_tokens=first_text_window_size
    )
    negative_model_kwargs = self._update_model_kwargs_for_generation(
        negative_outputs, negative_model_kwargs, is_encoder_decoder=False
    )
    tts_lm_negative_model_kwargs = self._update_model_kwargs_for_generation(
        tts_lm_negative_outputs, tts_lm_negative_model_kwargs, is_encoder_decoder=False
    )

    step = tts_lm_input_ids.shape[1]
    total_generated_speech_tokens = 0
    total_prefilled_text_tokens = 0
    if kwargs.get("show_progress_bar", True):
        progress_bar = tqdm(
            total=tts_lm_generation_config.max_length,
            desc=f"Prefilled {step} tokens, current step ({step} / {tts_lm_generation_config.max_length})",
            initial=step,
            leave=False,
        )
    else:
        progress_bar = None

    # [VibePod] Grab the executor once; None means standard sequential path.
    _vp_executor: Optional[concurrent.futures.ThreadPoolExecutor] = getattr(
        self, "_vibepod_decode_executor", None
    )

    # [VibePod] Grad mode and inference mode are thread-local in PyTorch: a
    # no_grad()/inference_mode() context active on this thread does NOT apply
    # inside the executor's worker thread. Capture the caller's modes here and
    # re-enter them in the worker so the async decode behaves exactly like the
    # sequential one (no autograd graph retained through acoustic_cache or the
    # accumulated audio chunks).
    _vp_grad_enabled = torch.is_grad_enabled()
    _vp_inference_mode = torch.is_inference_mode_enabled()

    def _vp_decode(latent: torch.Tensor, sample_indices: torch.Tensor) -> torch.Tensor:
        # [VibePod] Single definition of the decode call shared by both paths.
        return self.model.acoustic_tokenizer.decode(
            latent,
            cache=acoustic_cache,
            sample_indices=sample_indices,
            use_cache=True,
            debug=False,
        )

    def _vp_decode_in_worker(latent: torch.Tensor, sample_indices: torch.Tensor) -> torch.Tensor:
        # [VibePod] Runs on the executor thread; mirror the caller's autograd state.
        with torch.inference_mode(mode=_vp_inference_mode), torch.set_grad_enabled(
            _vp_grad_enabled
        ):
            return _vp_decode(latent, sample_indices)

    def _vp_emit(audio_chunk: torch.Tensor, sample_indices: torch.Tensor) -> None:
        # [VibePod] Append the chunk for every unfinished sample, then stream it.
        for i, sample_idx in enumerate(sample_indices):
            idx = sample_idx.item()
            if not finished_tags[idx]:
                audio_chunks[idx].append(audio_chunk[i])
        if audio_streamer is not None:
            audio_streamer.put(audio_chunk, sample_indices)

    # ── Main generation loop (unchanged from original) ───────────────────────
    while True:
        if stop_check_fn is not None and stop_check_fn():
            if verbose:
                print(f"Generation stopped externally at step {step + 1}")
            if audio_streamer is not None:
                audio_streamer.end()
            break

        if finished_tags.all():
            if hasattr(progress_bar, "set_description"):
                progress_bar.set_description("Generation complete")
            break

        cur_input_tts_text_ids = tts_text_ids[
            :,
            tts_text_window_index * TTS_TEXT_WINDOW_SIZE : (tts_text_window_index + 1)
            * TTS_TEXT_WINDOW_SIZE,
        ]
        next_text_window_size = tts_text_ids[
            :,
            (tts_text_window_index + 1)
            * TTS_TEXT_WINDOW_SIZE : (tts_text_window_index + 2)
            * TTS_TEXT_WINDOW_SIZE,
        ].shape[1]
        tts_text_window_index += 1

        if cur_input_tts_text_ids.shape[1] > 0:
            input_ids = torch.cat([input_ids, cur_input_tts_text_ids], dim=-1)
            tts_lm_input_ids = torch.cat([tts_lm_input_ids, cur_input_tts_text_ids], dim=-1)

            if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length:
                if verbose:
                    print(
                        f"Reached maximum generation length {generation_config.max_length}, stopped it."
                    )
                reached_samples = torch.arange(batch_size, device=device)[~finished_tags]
                if reached_samples.numel() > 0:
                    reach_max_step_sample[reached_samples] = True
                break

            step += cur_input_tts_text_ids.shape[1]
            total_prefilled_text_tokens += cur_input_tts_text_ids.shape[1]
            if progress_bar is not None:
                progress_bar.update(cur_input_tts_text_ids.shape[1])
                progress_bar.set_description(
                    f"Prefilled {total_prefilled_text_tokens} text tokens, "
                    f"generated {total_generated_speech_tokens} speech tokens, "
                    f"current step ({step} / {tts_lm_generation_config.max_length})"
                )

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self.forward_lm(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )
            model_kwargs = _update_model_kwargs_for_generation(
                outputs, model_kwargs, num_new_tokens=next_text_window_size
            )

            tts_lm_model_inputs = self.prepare_inputs_for_generation(
                tts_lm_input_ids, **tts_lm_model_kwargs
            )
            tts_lm_additional_inputs = {
                "tts_text_masks": torch.ones_like(tts_lm_input_ids[:, -1:]),
                "lm_last_hidden_state": outputs.last_hidden_state,
            }
            tts_lm_outputs = self.forward_tts_lm(
                **tts_lm_model_inputs,
                **tts_lm_additional_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )
            tts_lm_model_kwargs = self._update_model_kwargs_for_generation(
                tts_lm_outputs, tts_lm_model_kwargs, is_encoder_decoder=False
            )

        diffusion_indices = torch.LongTensor([0])

        # ── Inner speech loop ────────────────────────────────────────────────
        for cur_speech_index in range(TTS_SPEECH_WINDOW_SIZE):
            positive_condition = tts_lm_outputs.last_hidden_state[diffusion_indices, -1, :]
            negative_condition = tts_lm_negative_outputs.last_hidden_state[
                diffusion_indices, -1, :
            ]

            speech_latent = self.sample_speech_tokens(
                positive_condition,
                negative_condition,
                cfg_scale=cfg_scale,
            ).unsqueeze(1)

            scaled_latent = (
                speech_latent / self.model.speech_scaling_factor.to(speech_latent.device)
                - self.model.speech_bias_factor.to(speech_latent.device)
            )

            # [VibePod] If a decode executor is configured, submit decode to a
            # background thread so acoustic_connector and forward_tts_lm can run
            # concurrently on the main thread. The future is resolved after both
            # tts_lm calls complete, before appending/streaming the audio chunk.
            # Without the executor the decode runs inline at this point.
            # Device transfers stay on the calling thread in both paths.
            _vp_decode_latent = scaled_latent.to(self.model.acoustic_tokenizer.device)
            _vp_decode_indices = diffusion_indices.to(self.model.acoustic_tokenizer.device)
            _decode_future: Optional[concurrent.futures.Future] = None
            audio_chunk = None
            if _vp_executor is not None:
                _decode_future = _vp_executor.submit(
                    _vp_decode_in_worker, _vp_decode_latent, _vp_decode_indices
                )
            else:
                audio_chunk = _vp_decode(_vp_decode_latent, _vp_decode_indices)

            # [VibePod] connector + tts_lm run here while decode is in the thread.
            acoustic_embed = self.model.acoustic_connector(speech_latent)
            tts_lm_input_ids = torch.cat(
                [tts_lm_input_ids, torch.ones_like(tts_lm_input_ids[:, -1:])], dim=-1
            )

            if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length:
                # [VibePod] Resolve before break so audio_chunks stays consistent.
                if _decode_future is not None:
                    audio_chunk = _decode_future.result()
                _vp_emit(audio_chunk, diffusion_indices)
                break

            step += 1
            total_generated_speech_tokens += 1
            if progress_bar is not None:
                progress_bar.update(1)
                progress_bar.set_description(
                    f"Prefilled {total_prefilled_text_tokens} text tokens, "
                    f"generated {total_generated_speech_tokens} speech tokens, "
                    f"current step ({step} / {tts_lm_generation_config.max_length})"
                )

            tts_lm_model_inputs = self.prepare_inputs_for_generation(
                tts_lm_input_ids, **tts_lm_model_kwargs
            )
            tts_lm_additional_inputs = {
                "tts_text_masks": torch.zeros_like(tts_lm_input_ids[:, -1:]),
                "lm_last_hidden_state": acoustic_embed,
            }
            tts_lm_outputs = self.forward_tts_lm(
                **tts_lm_model_inputs,
                **tts_lm_additional_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )
            if cur_speech_index == TTS_SPEECH_WINDOW_SIZE - 1 and next_text_window_size > 0:
                tts_lm_model_kwargs = _update_model_kwargs_for_generation(
                    tts_lm_outputs, tts_lm_model_kwargs, num_new_tokens=next_text_window_size
                )
            else:
                tts_lm_model_kwargs = self._update_model_kwargs_for_generation(
                    tts_lm_outputs, tts_lm_model_kwargs, is_encoder_decoder=False
                )

            tts_lm_negative_input_ids = torch.cat(
                [tts_lm_negative_input_ids, torch.ones_like(tts_lm_input_ids[:, -1:])], dim=-1
            )
            tts_lm_negative_model_inputs = self.prepare_inputs_for_generation(
                tts_lm_negative_input_ids, **tts_lm_negative_model_kwargs
            )
            tts_lm_negative_additional_inputs = {
                "tts_text_masks": torch.zeros_like(tts_lm_negative_input_ids[:, -1:]),
                "lm_last_hidden_state": acoustic_embed,
            }
            tts_lm_negative_outputs = self.forward_tts_lm(
                **tts_lm_negative_model_inputs,
                **tts_lm_negative_additional_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )
            tts_lm_negative_model_kwargs = self._update_model_kwargs_for_generation(
                tts_lm_negative_outputs,
                tts_lm_negative_model_kwargs,
                is_encoder_decoder=False,
            )

            # [VibePod] Decode is done (or was never async). Resolve future,
            # then append + stream — moved here from before connector/tts_lm.
            if _decode_future is not None:
                audio_chunk = _decode_future.result()
            _vp_emit(audio_chunk, diffusion_indices)

            tts_eos_logits = torch.sigmoid(
                self.tts_eos_classifier(
                    tts_lm_outputs.last_hidden_state[diffusion_indices, -1, :]
                )
            )
            if tts_eos_logits[0].item() > 0.5:
                finished_tags[diffusion_indices] = True
                if audio_streamer is not None:
                    audio_streamer.end(diffusion_indices)

        if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length:
            if verbose:
                print(
                    f"Reached maximum generation length {tts_lm_generation_config.max_length}, stopped it."
                )
            reached_samples = torch.arange(batch_size, device=device)[~finished_tags]
            if reached_samples.numel() > 0:
                reach_max_step_sample[reached_samples] = True
            break

    if audio_streamer is not None:
        audio_streamer.end()

    # ── Audio finalisation (unchanged from original) ─────────────────────────
    final_audio_outputs = []
    for sample_chunks in audio_chunks:
        if sample_chunks:
            concatenated_audio = torch.cat(sample_chunks, dim=-1)
            final_audio_outputs.append(concatenated_audio)
        else:
            final_audio_outputs.append(None)

    if reach_max_step_sample is not None and reach_max_step_sample.any():
        print(
            f"Reached maximum generation length {tts_lm_generation_config.max_length}, stopped it."
        )

    return VibeVoiceGenerationOutput(
        sequences=tts_lm_input_ids,
        speech_outputs=final_audio_outputs if return_speech else None,
        reach_max_step_sample=reach_max_step_sample,
    )


def install(model: object, executor: concurrent.futures.ThreadPoolExecutor) -> None:
    """Install the patched generate() on a model instance and attach the executor.

    Args:
        model: Model instance to patch. ``generate`` is bound on the instance
            itself, so it shadows the class-level method for this object only.
        executor: Executor that ``patched_generate`` submits the acoustic decode
            to; stored on the model as ``_vibepod_decode_executor``.
    """
    model._vibepod_decode_executor = executor
    model.generate = types.MethodType(patched_generate, model)
@@ -108,6 +114,40 @@ def _env_float(name: str, default: float) -> float: return default +def _cpu_supports_bf16() -> bool: + """Return True if the CPU has AVX512_BF16 hardware support.""" + return ( + hasattr(torch, "cpu") + and hasattr(torch.cpu, "is_avx512_bf16_supported") + and torch.cpu.is_avx512_bf16_supported() + ) + + +def _configure_cpu_runtime() -> dict[str, object]: + logical_cpus = os.cpu_count() or 1 + default_threads = ( + max(1, logical_cpus // 2) if platform.system() == "Windows" else logical_cpus + ) + intra_threads = _env_int("VIBEPOD_CPU_THREADS", default_threads) + interop_threads = _env_int("VIBEPOD_CPU_INTEROP_THREADS", 1) + mkldnn_enabled = os.environ.get("VIBEPOD_CPU_MKLDNN", "1").strip() != "0" + + torch.set_num_threads(max(1, intra_threads)) + try: + torch.set_num_interop_threads(max(1, interop_threads)) + except RuntimeError as exc: + logger.warning("Could not set CPU inter-op threads: %s", exc) + + torch.backends.mkldnn.enabled = mkldnn_enabled + return { + "logical_cpus": logical_cpus, + "threads": torch.get_num_threads(), + "interop_threads": torch.get_num_interop_threads(), + "mkldnn_available": torch.backends.mkldnn.is_available(), + "mkldnn_enabled": torch.backends.mkldnn.enabled, + } + + # ── Global state ──────────────────────────────────────────────────────────────── ModelStatus = Literal["downloading", "loading", "online", "error"] @@ -228,12 +268,29 @@ def _init_model(device: str): torch.backends.cuda.mem_efficient_sdp_enabled(), torch.backends.cuda.math_sdp_enabled(), ) + elif device == "cpu": + torch.set_float32_matmul_precision("medium") + logger.info("CPU runtime configuration: %s", _configure_cpu_runtime()) cuda_dtype = os.environ.get("VIBEPOD_CUDA_DTYPE", "bf16").lower() if device == "cuda" and cuda_dtype == "fp16": load_dtype = torch.float16 + elif device == "cuda": + load_dtype = torch.bfloat16 else: - load_dtype = torch.bfloat16 if device == "cuda" else torch.float32 + cpu_bf16_env = os.environ.get("VIBEPOD_CPU_BF16", 
"auto").lower() + if cpu_bf16_env == "1": + load_dtype = torch.bfloat16 + logger.info("CPU BF16 forced via VIBEPOD_CPU_BF16=1") + elif cpu_bf16_env == "0": + load_dtype = torch.float32 + logger.info("CPU float32 forced via VIBEPOD_CPU_BF16=0") + elif _cpu_supports_bf16(): + load_dtype = torch.bfloat16 + logger.info("AVX512_BF16 detected — loading model in bfloat16") + else: + load_dtype = torch.float32 + logger.info("No AVX512_BF16 — using float32 (set VIBEPOD_CPU_BF16=1 to override)") logger.info("Loading model weights with dtype %s", load_dtype) requested_attn_impl = os.environ.get("VIBEPOD_ATTN_IMPL", "auto").lower() has_flash_attn = importlib.util.find_spec("flash_attn") is not None @@ -274,8 +331,90 @@ def _init_model(device: str): ) model.eval() + if device == "cpu": + model = _apply_cpu_optimizations(model) model.set_ddpm_inference_steps(num_steps=_config["default_inference_steps"]) _install_generation_optimizations(model) + if device == "cpu": + # Must run after _install_generation_optimizations so the async wrapper + # sits outside the profiling wrapper (VibeVoice calls async → profiling → real decode). + _install_cpu_pipeline_optimizations(model) + return model + + +def _apply_cpu_optimizations(model: object) -> object: + """Apply optional post-load CPU optimizations. Returns (possibly new) model object.""" + + do_quantize = os.environ.get("VIBEPOD_QUANTIZE", "0") == "1" + do_compile = os.environ.get("VIBEPOD_COMPILE", "0") == "1" + + if do_quantize: + logger.info("Applying dynamic INT8 quantization to Linear layers...") + try: + import torch.ao.quantization + + # The diffusion prediction_head operates on small fixed-size tensors where + # INT8 pack/unpack overhead exceeds the matmul savings (~+20% regression in + # testing). Save and restore it so it stays in float32. 
+ saved_prediction_head = None + if hasattr(model, "model") and hasattr(model.model, "prediction_head"): + saved_prediction_head = model.model.prediction_head + del model.model.prediction_head + + model = torch.ao.quantization.quantize_dynamic( + model, {torch.nn.Linear}, dtype=torch.qint8 + ) + + if saved_prediction_head is not None: + model.model.prediction_head = saved_prediction_head + logger.info( + "Dynamic INT8 quantization applied (prediction_head excluded — stays float32)." + ) + else: + logger.info("Dynamic INT8 quantization applied.") + except Exception as exc: + logger.warning("Dynamic quantization failed: %s — skipping", exc) + + if do_compile: + # torch.compile with inductor on CPU is ineffective for autoregressive TTS: + # each token step produces a unique input shape, so every step triggers a new + # kernel compile event rather than reusing compiled code. Kept as an escape + # hatch but not recommended. + compile_mode = os.environ.get("VIBEPOD_COMPILE_MODE", "reduce-overhead") + logger.info( + "torch.compile enabled (mode=%s) — NOTE: limited benefit for autoregressive" + " models on CPU due to dynamic sequence lengths.", + compile_mode, + ) + _compile_targets: list[tuple[str, object, str, bool]] = [ + ("forward_tts_lm", model, "forward_tts_lm", True), + ] + if hasattr(model, "model"): + inner = model.model + if hasattr(inner, "prediction_head"): + _compile_targets.append( + ("prediction_head", inner, "prediction_head", False) + ) + if hasattr(inner, "acoustic_tokenizer") and hasattr( + inner.acoustic_tokenizer, "decode" + ): + _compile_targets.append( + ("acoustic_tokenizer.decode", inner.acoustic_tokenizer, "decode", False) + ) + + for label, obj, attr, dynamic in _compile_targets: + try: + compiled = torch.compile( + getattr(obj, attr), + backend="inductor", + mode=compile_mode, + dynamic=dynamic, + ) + setattr(obj, attr, compiled) + logger.info(" compiled: %s", label) + except Exception as exc: + logger.warning(" torch.compile failed for %s: %s — 
skipping", label, exc) + return model @@ -403,6 +542,45 @@ def _install_generation_optimizations(model: object) -> None: logger.info("Installed VibeVoice generation hot-path optimizations.") +def _install_cpu_pipeline_optimizations(model: object) -> None: + """Install the async-decode generate() patch and its thread pool on the model instance. + + The VibeVoice inner loop runs: + decode(speech_latent) → append → put → connector → tts_lm(pos) → tts_lm(neg) + + connector and both tts_lm calls only need speech_latent/acoustic_embed, not + audio_chunk. The patched generate() reorders this to: + submit decode to thread → connector → tts_lm(pos) → tts_lm(neg) + → wait for decode future → append → put + + The patch is applied as an instance method via types.MethodType, which shadows + the class-level generate() and is immune to uv sync reinstalling the package. + """ + global _decode_executor + + if os.environ.get("VIBEPOD_ASYNC_DECODE", "1") != "1": + logger.info("CPU async decode disabled via VIBEPOD_ASYNC_DECODE=0.") + return + + try: + import vibevoice_generate_patch + except ImportError: + logger.warning( + "vibevoice_generate_patch not found — async decode unavailable. " + "Ensure vibevoice_generate_patch.py is in the server directory." + ) + return + + _decode_executor = concurrent.futures.ThreadPoolExecutor( + max_workers=1, thread_name_prefix="vibepod-decode" + ) + vibevoice_generate_patch.install(model, _decode_executor) + logger.info( + "CPU pipeline: patched generate() installed (async decode enabled) — " + "acoustic_decode overlaps forward_tts_lm. Disable with VIBEPOD_ASYNC_DECODE=0." 
+ ) + + def _model_float_dtype() -> torch.dtype: try: return next(_model.parameters()).dtype @@ -469,6 +647,20 @@ def _load_model_sync() -> None: _config["rebuffer_threshold_secs"] = _env_float("VIBEPOD_REBUFFER_THRESHOLD_SECS", 1.5 if is_cpu else 1.0) _config["resume_threshold_secs"] = _env_float("VIBEPOD_RESUME_THRESHOLD_SECS", 4.0 if is_cpu else 3.0) _config["default_inference_steps"] = _env_int("VIBEPOD_DEFAULT_INFERENCE_STEPS", 8 if is_cpu else 10) + if is_cpu: + logical_cpus = os.cpu_count() or 1 + _config["cpu_threads"] = _env_int( + "VIBEPOD_CPU_THREADS", + max(1, logical_cpus // 2) + if platform.system() == "Windows" + else logical_cpus, + ) + _config["cpu_interop_threads"] = _env_int( + "VIBEPOD_CPU_INTEROP_THREADS", 1 + ) + _config["cpu_mkldnn"] = os.environ.get( + "VIBEPOD_CPU_MKLDNN", "1" + ).strip() != "0" _processor = _init_processor() _model = _init_model(_device) @@ -494,6 +686,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: thread = threading.Thread(target=_load_model_sync, daemon=True, name="model-loader") thread.start() yield + if _decode_executor is not None: + _decode_executor.shutdown(wait=False) app = FastAPI(title="VibePod TTS Server", version="0.1.0", lifespan=lifespan)