perf: migrate to JezzWTF/VibeVoice fork, parallel CFG executors

Switch vibevoice dependency from microsoft/VibeVoice to JezzWTF/VibeVoice
fork (commit e76701f) which contains the async decode + parallel CFG
optimisations directly in generate(). Removes the instance-method
patching approach (vibevoice_generate_patch.py deleted).

server/vibevoice_server.py:
- Add _cfg_executor (ThreadPoolExecutor, 1 worker) alongside _decode_executor
- _install_cpu_pipeline_optimizations now sets both executors directly as
  model._vibepod_decode_executor and model._vibepod_cfg_executor
- Both executors shut down in lifespan on exit
- Remove vibevoice_generate_patch import/install (no longer needed)

server/pyproject.toml:
- vibevoice source changed to git+https://github.com/JezzWTF/VibeVoice.git
- No machine-local paths; works identically on any clone
This commit is contained in:
2026-04-30 21:30:07 +01:00
parent 7591d15a52
commit 98e2bf9237
4 changed files with 36 additions and 496 deletions
+2 -1
View File
@@ -8,7 +8,8 @@ dependencies = [
# To switch back to CPU-only, remove the [tool.uv.sources] torch entry below.
"torch>=2.0.0",
# VibeVoice custom model + processor classes (not yet in upstream transformers)
"vibevoice @ git+https://github.com/microsoft/VibeVoice.git",
# Uses JezzWTF/VibeVoice fork so VibePod-specific optimisations land here.
"vibevoice @ git+https://github.com/JezzWTF/VibeVoice.git",
# Exact version required by vibevoice's streaming TTS module
"transformers==4.51.3",
"fastapi>=0.111.0",
+8 -8
View File
@@ -1479,7 +1479,7 @@ name = "nvidia-cudnn-cu12"
version = "9.1.0.70"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12" },
{ name = "nvidia-cublas-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 },
@@ -1490,7 +1490,7 @@ name = "nvidia-cufft-cu12"
version = "11.2.1.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-nvjitlink-cu12" },
{ name = "nvidia-nvjitlink-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 },
@@ -1509,9 +1509,9 @@ name = "nvidia-cusolver-cu12"
version = "11.6.1.9"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12" },
{ name = "nvidia-cusparse-cu12" },
{ name = "nvidia-nvjitlink-cu12" },
{ name = "nvidia-cublas-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" },
{ name = "nvidia-cusparse-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057 },
@@ -1522,7 +1522,7 @@ name = "nvidia-cusparse-cu12"
version = "12.3.1.170"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-nvjitlink-cu12" },
{ name = "nvidia-nvjitlink-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 },
@@ -3109,13 +3109,13 @@ requires-dist = [
{ name = "torch", specifier = ">=2.0.0", index = "https://download.pytorch.org/whl/cu124" },
{ name = "transformers", specifier = "==4.51.3" },
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.29.0" },
{ name = "vibevoice", git = "https://github.com/microsoft/VibeVoice.git" },
{ name = "vibevoice", git = "https://github.com/JezzWTF/VibeVoice.git" },
]
[[package]]
name = "vibevoice"
version = "1.0.0"
source = { git = "https://github.com/microsoft/VibeVoice.git#e73d1e17c3754f046352014856a922f8208fb5d3" }
source = { git = "https://github.com/JezzWTF/VibeVoice.git#e76701f17a0d93cd612d56f0db5865a615c4d16e" }
dependencies = [
{ name = "absl-py" },
{ name = "accelerate" },
-463
View File
@@ -1,463 +0,0 @@
"""
VibePod CPU pipeline optimisation — patched VibeVoice generate() loop.
WHY THIS FILE EXISTS
--------------------
The VibeVoice inner speech-generation loop runs:
decode(speech_latent) # 87 ms — VAE decode to audio waveform
audio_chunks.append(chunk) # store for final return value
audio_streamer.put(chunk) # stream to client
acoustic_connector(speech_latent) -> acoustic_embed # 1 ms
forward_tts_lm(acoustic_embed) # ~49 ms (positive)
forward_tts_lm(acoustic_embed) # ~49 ms (negative CFG)
acoustic_connector and both forward_tts_lm calls depend only on speech_latent /
acoustic_embed — they are completely independent of the decoded audio waveform.
Running decode in a thread while connector + tts_lm run on the main thread hides
~87 ms of decode cost per token behind the ~99 ms of tts_lm work:
Before: 87 + 1 + 49 + 49 = 186 ms / token
After: max(87, 1 + 49 + 49) = 99 ms / token (~47 % reduction)
HOW IT WORKS
------------
At model load time, _install_cpu_pipeline_optimizations() in vibevoice_server.py:
1. Creates a single-worker ThreadPoolExecutor and attaches it to the model as
model._vibepod_decode_executor.
2. Installs this module's `patched_generate` as a bound method on the model
instance via types.MethodType, shadowing the class-level generate().
Because the patch lives on the *instance*, uv sync reinstalling the VibeVoice
package has no effect — Python resolves instance attributes before class ones.
MAINTENANCE
-----------
This is a verbatim copy of VibeVoice's generate() method (lines 574910 of
modeling_vibevoice_streaming_inference.py) with the inner speech loop reordered.
The only changed region is marked with # [VibePod] comments.
If VibeVoice updates its generate() method, diff the new version against this
file and merge carefully. The sentinel string "[VibePod]" marks every changed
line to make diffing easy.
"""
import concurrent.futures
import types
from typing import Callable, List, Optional, Union
import torch
from tqdm import tqdm
from transformers.generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers.modeling_utils import PreTrainedModel
from vibevoice.modular.modeling_vibevoice_streaming_inference import (
TTS_TEXT_WINDOW_SIZE,
TTS_SPEECH_WINDOW_SIZE,
VibeVoiceGenerationOutput,
_update_model_kwargs_for_generation,
)
from vibevoice.modular.modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache
from vibevoice.modular.streamer import AudioStreamer, AsyncAudioStreamer
def patched_generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
speech_tensors: Optional[torch.FloatTensor] = None,
speech_masks: Optional[torch.BoolTensor] = None,
speech_input_mask: Optional[torch.BoolTensor] = None,
tts_text_ids: Optional[torch.LongTensor] = None,
return_speech: bool = True,
cfg_scale: float = 1.0,
stop_check_fn: Optional[Callable[[], bool]] = None,
**kwargs,
) -> Union[torch.LongTensor, VibeVoiceGenerationOutput]:
# ── Setup (unchanged from original) ─────────────────────────────────────
tokenizer = kwargs.pop("tokenizer", None)
neg_text_input_id = tokenizer.convert_tokens_to_ids("<|image_pad|>")
tts_lm_input_ids = kwargs.pop("tts_lm_input_ids", None)
tts_lm_attention_mask = kwargs.pop("tts_lm_attention_mask", None)
all_prefilled_outputs = kwargs.pop("all_prefilled_outputs", None)
tts_text_ids = tts_text_ids.to(self.device)
if kwargs.get("max_new_tokens", None) is None:
kwargs["max_new_tokens"] = (
self.config.decoder_config.max_position_embeddings - tts_lm_input_ids.shape[-1]
)
generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria = (
self._build_generate_config_model_kwargs(
generation_config, inputs, tokenizer, return_processors=True, **kwargs
)
)
negative_kwargs = {
"input_ids": torch.full(
(kwargs["input_ids"].shape[0], 1),
neg_text_input_id,
dtype=torch.long,
device=kwargs["input_ids"].device,
),
"attention_mask": torch.ones(
(kwargs["input_ids"].shape[0], 1),
dtype=torch.long,
device=kwargs["input_ids"].device,
),
"max_new_tokens": kwargs.get("max_new_tokens", 100),
}
negative_generation_config, negative_model_kwargs, negative_input_ids = (
self._build_generate_config_model_kwargs(
None, None, tokenizer, return_processors=False, **negative_kwargs
)
)
tts_lm_kwargs = {
"input_ids": tts_lm_input_ids,
"attention_mask": tts_lm_attention_mask,
"max_new_tokens": kwargs.get("max_new_tokens", 100),
}
tts_lm_generation_config, tts_lm_model_kwargs, tts_lm_input_ids = (
self._build_generate_config_model_kwargs(
None, None, tokenizer, return_processors=False, **tts_lm_kwargs
)
)
tts_lm_negative_kwargs = {
"input_ids": torch.full(
(kwargs["input_ids"].shape[0], 1),
neg_text_input_id,
dtype=torch.long,
device=kwargs["input_ids"].device,
),
"attention_mask": torch.ones(
(kwargs["input_ids"].shape[0], 1),
dtype=torch.long,
device=kwargs["input_ids"].device,
),
"max_new_tokens": kwargs.get("max_new_tokens", 100),
}
tts_lm_negative_generation_config, tts_lm_negative_model_kwargs, tts_lm_negative_input_ids = (
self._build_generate_config_model_kwargs(
None, None, tokenizer, return_processors=False, **tts_lm_negative_kwargs
)
)
acoustic_cache = VibeVoiceTokenizerStreamingCache()
batch_size = input_ids.shape[0]
assert batch_size == 1, "Currently only supports batch size == 1"
device = input_ids.device
finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device)
verbose = kwargs.get("verbose", False)
audio_chunks = [[] for _ in range(batch_size)]
tts_text_window_index = 0
reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device)
first_text_window_size = (
TTS_TEXT_WINDOW_SIZE
if tts_text_ids.shape[1] >= TTS_TEXT_WINDOW_SIZE
else tts_text_ids.shape[1]
)
outputs = all_prefilled_outputs["lm"]
tts_lm_outputs = all_prefilled_outputs["tts_lm"]
negative_outputs = all_prefilled_outputs["neg_lm"]
tts_lm_negative_outputs = all_prefilled_outputs["neg_tts_lm"]
model_kwargs = _update_model_kwargs_for_generation(
outputs, model_kwargs, num_new_tokens=first_text_window_size
)
tts_lm_model_kwargs = _update_model_kwargs_for_generation(
tts_lm_outputs, tts_lm_model_kwargs, num_new_tokens=first_text_window_size
)
negative_model_kwargs = self._update_model_kwargs_for_generation(
negative_outputs, negative_model_kwargs, is_encoder_decoder=False
)
tts_lm_negative_model_kwargs = self._update_model_kwargs_for_generation(
tts_lm_negative_outputs, tts_lm_negative_model_kwargs, is_encoder_decoder=False
)
step = tts_lm_input_ids.shape[1]
total_generated_speech_tokens = 0
total_prefilled_text_tokens = 0
if kwargs.get("show_progress_bar", True):
progress_bar = tqdm(
total=tts_lm_generation_config.max_length,
desc=f"Prefilled {step} tokens, current step ({step} / {tts_lm_generation_config.max_length})",
initial=step,
leave=False,
)
else:
progress_bar = None
# [VibePod] Grab the executor once; None means standard sequential path.
_vp_executor: Optional[concurrent.futures.ThreadPoolExecutor] = getattr(
self, "_vibepod_decode_executor", None
)
# ── Main generation loop (unchanged from original) ───────────────────────
while True:
if stop_check_fn is not None and stop_check_fn():
if verbose:
print(f"Generation stopped externally at step {step + 1}")
if audio_streamer is not None:
audio_streamer.end()
break
if finished_tags.all():
if hasattr(progress_bar, "set_description"):
progress_bar.set_description("Generation complete")
break
cur_input_tts_text_ids = tts_text_ids[
:,
tts_text_window_index * TTS_TEXT_WINDOW_SIZE : (tts_text_window_index + 1)
* TTS_TEXT_WINDOW_SIZE,
]
next_text_window_size = tts_text_ids[
:,
(tts_text_window_index + 1)
* TTS_TEXT_WINDOW_SIZE : (tts_text_window_index + 2)
* TTS_TEXT_WINDOW_SIZE,
].shape[1]
tts_text_window_index += 1
if cur_input_tts_text_ids.shape[1] > 0:
input_ids = torch.cat([input_ids, cur_input_tts_text_ids], dim=-1)
tts_lm_input_ids = torch.cat([tts_lm_input_ids, cur_input_tts_text_ids], dim=-1)
if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length:
if verbose:
print(
f"Reached maximum generation length {generation_config.max_length}, stopped it."
)
reached_samples = torch.arange(batch_size, device=device)[~finished_tags]
if reached_samples.numel() > 0:
reach_max_step_sample[reached_samples] = True
break
step += cur_input_tts_text_ids.shape[1]
total_prefilled_text_tokens += cur_input_tts_text_ids.shape[1]
if progress_bar is not None:
progress_bar.update(cur_input_tts_text_ids.shape[1])
progress_bar.set_description(
f"Prefilled {total_prefilled_text_tokens} text tokens, "
f"generated {total_generated_speech_tokens} speech tokens, "
f"current step ({step} / {tts_lm_generation_config.max_length})"
)
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
outputs = self.forward_lm(
**model_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
model_kwargs = _update_model_kwargs_for_generation(
outputs, model_kwargs, num_new_tokens=next_text_window_size
)
tts_lm_model_inputs = self.prepare_inputs_for_generation(
tts_lm_input_ids, **tts_lm_model_kwargs
)
tts_lm_additional_inputs = {
"tts_text_masks": torch.ones_like(tts_lm_input_ids[:, -1:]),
"lm_last_hidden_state": outputs.last_hidden_state,
}
tts_lm_outputs = self.forward_tts_lm(
**tts_lm_model_inputs,
**tts_lm_additional_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
tts_lm_model_kwargs = self._update_model_kwargs_for_generation(
tts_lm_outputs, tts_lm_model_kwargs, is_encoder_decoder=False
)
diffusion_indices = torch.LongTensor([0])
# ── Inner speech loop ────────────────────────────────────────────────
for cur_speech_index in range(TTS_SPEECH_WINDOW_SIZE):
positive_condition = tts_lm_outputs.last_hidden_state[diffusion_indices, -1, :]
negative_condition = tts_lm_negative_outputs.last_hidden_state[
diffusion_indices, -1, :
]
speech_latent = self.sample_speech_tokens(
positive_condition,
negative_condition,
cfg_scale=cfg_scale,
).unsqueeze(1)
scaled_latent = (
speech_latent / self.model.speech_scaling_factor.to(speech_latent.device)
- self.model.speech_bias_factor.to(speech_latent.device)
)
# [VibePod] If a decode executor is configured, submit decode to a
# background thread so acoustic_connector and forward_tts_lm can run
# concurrently on the main thread. The future is resolved after both
# tts_lm calls complete, before appending/streaming the audio chunk.
# Without the executor, the original sequential path is used unchanged.
if _vp_executor is not None:
_decode_future: concurrent.futures.Future[torch.Tensor] = _vp_executor.submit(
self.model.acoustic_tokenizer.decode,
scaled_latent.to(self.model.acoustic_tokenizer.device),
cache=acoustic_cache,
sample_indices=diffusion_indices.to(self.model.acoustic_tokenizer.device),
use_cache=True,
debug=False,
)
else:
audio_chunk = self.model.acoustic_tokenizer.decode(
scaled_latent.to(self.model.acoustic_tokenizer.device),
cache=acoustic_cache,
sample_indices=diffusion_indices.to(self.model.acoustic_tokenizer.device),
use_cache=True,
debug=False,
)
# [VibePod] connector + tts_lm run here while decode is in the thread.
acoustic_embed = self.model.acoustic_connector(speech_latent)
tts_lm_input_ids = torch.cat(
[tts_lm_input_ids, torch.ones_like(tts_lm_input_ids[:, -1:])], dim=-1
)
if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length:
# [VibePod] Resolve before break so audio_chunks stays consistent.
if _vp_executor is not None:
audio_chunk = _decode_future.result()
for i, sample_idx in enumerate(diffusion_indices):
idx = sample_idx.item()
if not finished_tags[idx]:
audio_chunks[idx].append(audio_chunk[i])
if audio_streamer is not None:
audio_streamer.put(audio_chunk, diffusion_indices)
break
step += 1
total_generated_speech_tokens += 1
if progress_bar is not None:
progress_bar.update(1)
progress_bar.set_description(
f"Prefilled {total_prefilled_text_tokens} text tokens, "
f"generated {total_generated_speech_tokens} speech tokens, "
f"current step ({step} / {tts_lm_generation_config.max_length})"
)
tts_lm_model_inputs = self.prepare_inputs_for_generation(
tts_lm_input_ids, **tts_lm_model_kwargs
)
tts_lm_additional_inputs = {
"tts_text_masks": torch.zeros_like(tts_lm_input_ids[:, -1:]),
"lm_last_hidden_state": acoustic_embed,
}
tts_lm_outputs = self.forward_tts_lm(
**tts_lm_model_inputs,
**tts_lm_additional_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
if cur_speech_index == TTS_SPEECH_WINDOW_SIZE - 1 and next_text_window_size > 0:
tts_lm_model_kwargs = _update_model_kwargs_for_generation(
tts_lm_outputs, tts_lm_model_kwargs, num_new_tokens=next_text_window_size
)
else:
tts_lm_model_kwargs = self._update_model_kwargs_for_generation(
tts_lm_outputs, tts_lm_model_kwargs, is_encoder_decoder=False
)
tts_lm_negative_input_ids = torch.cat(
[tts_lm_negative_input_ids, torch.ones_like(tts_lm_input_ids[:, -1:])], dim=-1
)
tts_lm_negative_model_inputs = self.prepare_inputs_for_generation(
tts_lm_negative_input_ids, **tts_lm_negative_model_kwargs
)
tts_lm_negative_additional_inputs = {
"tts_text_masks": torch.zeros_like(tts_lm_negative_input_ids[:, -1:]),
"lm_last_hidden_state": acoustic_embed,
}
tts_lm_negative_outputs = self.forward_tts_lm(
**tts_lm_negative_model_inputs,
**tts_lm_negative_additional_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
tts_lm_negative_model_kwargs = self._update_model_kwargs_for_generation(
tts_lm_negative_outputs,
tts_lm_negative_model_kwargs,
is_encoder_decoder=False,
)
# [VibePod] Decode is done (or was never async). Resolve future,
# then append + stream — moved here from before connector/tts_lm.
if _vp_executor is not None:
audio_chunk = _decode_future.result()
for i, sample_idx in enumerate(diffusion_indices):
idx = sample_idx.item()
if not finished_tags[idx]:
audio_chunks[idx].append(audio_chunk[i])
if audio_streamer is not None:
audio_streamer.put(audio_chunk, diffusion_indices)
tts_eos_logits = torch.sigmoid(
self.tts_eos_classifier(
tts_lm_outputs.last_hidden_state[diffusion_indices, -1, :]
)
)
if tts_eos_logits[0].item() > 0.5:
finished_tags[diffusion_indices] = True
if audio_streamer is not None:
audio_streamer.end(diffusion_indices)
if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length:
if verbose:
print(
f"Reached maximum generation length {tts_lm_generation_config.max_length}, stopped it."
)
reached_samples = torch.arange(batch_size, device=device)[~finished_tags]
if reached_samples.numel() > 0:
reach_max_step_sample[reached_samples] = True
break
if audio_streamer is not None:
audio_streamer.end()
# ── Audio finalisation (unchanged from original) ─────────────────────────
final_audio_outputs = []
for sample_chunks in audio_chunks:
if sample_chunks:
concatenated_audio = torch.cat(sample_chunks, dim=-1)
final_audio_outputs.append(concatenated_audio)
else:
final_audio_outputs.append(None)
if reach_max_step_sample is not None and reach_max_step_sample.any():
print(
f"Reached maximum generation length {tts_lm_generation_config.max_length}, stopped it."
)
return VibeVoiceGenerationOutput(
sequences=tts_lm_input_ids,
speech_outputs=final_audio_outputs if return_speech else None,
reach_max_step_sample=reach_max_step_sample,
)
def install(model: object, executor: concurrent.futures.ThreadPoolExecutor) -> None:
"""Install the patched generate() on a model instance and attach the executor."""
model._vibepod_decode_executor = executor
model.generate = types.MethodType(patched_generate, model)
+26 -24
View File
@@ -66,9 +66,12 @@ DEFAULT_SPEAKER = "carter"
_IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.ot"]
# ── Decode pipeline executor ────────────────────────────────────────────────────
# ── Pipeline executors ─────────────────────────────────────────────────────────
# _decode_executor: overlaps acoustic_decode with forward_tts_lm (1 worker).
# _cfg_executor: runs positive + negative forward_tts_lm in parallel (1 worker).
_decode_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None
_cfg_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None
# ── Device selection ────────────────────────────────────────────────────────────
# VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag.
@@ -543,41 +546,38 @@ def _install_generation_optimizations(model: object) -> None:
def _install_cpu_pipeline_optimizations(model: object) -> None:
"""Install the async-decode generate() patch and its thread pool on the model instance.
"""Attach pipeline executors to the model for the optimised generate() loop.
The VibeVoice inner loop runs:
decode(speech_latent) → append → put → connector → tts_lm(pos) → tts_lm(neg)
The JezzWTF/VibeVoice fork's generate() checks for two optional attributes:
connector and both tts_lm calls only need speech_latent/acoustic_embed, not
audio_chunk. The patched generate() reorders this to:
submit decode to thread → connector → tts_lm(pos) → tts_lm(neg)
→ wait for decode future → append → put
model._vibepod_decode_executor — ThreadPoolExecutor (1 worker) used to
overlap acoustic_decode with acoustic_connector + forward_tts_lm.
The patch is applied as an instance method via types.MethodType, which shadows
the class-level generate() and is immune to uv sync reinstalling the package.
model._vibepod_cfg_executor — ThreadPoolExecutor (1 worker) used to
run the positive and negative forward_tts_lm calls in parallel, so
both CFG passes execute concurrently instead of sequentially.
Both are None by default, making the fork's generate() behave identically
to upstream on CUDA or any machine where these aren't set.
"""
global _decode_executor
global _decode_executor, _cfg_executor
if os.environ.get("VIBEPOD_ASYNC_DECODE", "1") != "1":
logger.info("CPU async decode disabled via VIBEPOD_ASYNC_DECODE=0.")
return
try:
import vibevoice_generate_patch
except ImportError:
logger.warning(
"vibevoice_generate_patch not found — async decode unavailable. "
"Ensure vibevoice_generate_patch.py is in the server directory."
)
logger.info("CPU async decode/CFG parallelism disabled via VIBEPOD_ASYNC_DECODE=0.")
return
_decode_executor = concurrent.futures.ThreadPoolExecutor(
max_workers=1, thread_name_prefix="vibepod-decode"
)
vibevoice_generate_patch.install(model, _decode_executor)
_cfg_executor = concurrent.futures.ThreadPoolExecutor(
max_workers=1, thread_name_prefix="vibepod-cfg"
)
model._vibepod_decode_executor = _decode_executor
model._vibepod_cfg_executor = _cfg_executor
logger.info(
"CPU pipeline: patched generate() installed (async decode enabled)"
"acoustic_decode overlaps forward_tts_lm. Disable with VIBEPOD_ASYNC_DECODE=0."
"CPU pipeline: decode executor and CFG executor attached — "
"acoustic_decode overlaps tts_lm, pos/neg CFG runs in parallel. "
"Disable with VIBEPOD_ASYNC_DECODE=0."
)
@@ -688,6 +688,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
yield
if _decode_executor is not None:
_decode_executor.shutdown(wait=False)
if _cfg_executor is not None:
_cfg_executor.shutdown(wait=False)
app = FastAPI(title="VibePod TTS Server", version="0.1.0", lifespan=lifespan)