From 98e2bf9237a07b2f79af43e3905e06e71bf8523f Mon Sep 17 00:00:00 2001 From: LyAhn Date: Thu, 30 Apr 2026 21:30:07 +0100 Subject: [PATCH] perf: migrate to JezzWTF/VibeVoice fork, parallel CFG executors Switch the vibevoice dependency from microsoft/VibeVoice to the JezzWTF/VibeVoice fork (commit e76701f), which implements the async-decode and parallel-CFG optimisations directly in generate(). This removes the instance-method patching approach (vibevoice_generate_patch.py is deleted). server/vibevoice_server.py: - Add _cfg_executor (ThreadPoolExecutor, 1 worker) alongside _decode_executor - _install_cpu_pipeline_optimizations now sets both executors directly as model._vibepod_decode_executor and model._vibepod_cfg_executor - Both executors are shut down in lifespan on exit - Remove vibevoice_generate_patch import/install (no longer needed) server/pyproject.toml: - vibevoice source changed to git+https://github.com/JezzWTF/VibeVoice.git (unpinned here; server/uv.lock locks it to e76701f) - No machine-local paths; works identically on any clone --- server/pyproject.toml | 3 +- server/uv.lock | 16 +- server/vibevoice_generate_patch.py | 463 ----------------------------- server/vibevoice_server.py | 50 ++-- 4 files changed, 36 insertions(+), 496 deletions(-) delete mode 100644 server/vibevoice_generate_patch.py diff --git a/server/pyproject.toml b/server/pyproject.toml index 3756ed3..5099808 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -8,7 +8,8 @@ dependencies = [ # To switch back to CPU-only, remove the [tool.uv.sources] torch entry below. "torch>=2.0.0", # VibeVoice custom model + processor classes (not yet in upstream transformers) - "vibevoice @ git+https://github.com/microsoft/VibeVoice.git", + # Uses JezzWTF/VibeVoice fork so VibePod-specific optimisations land here. 
+ "vibevoice @ git+https://github.com/JezzWTF/VibeVoice.git", # Exact version required by vibevoice's streaming TTS module "transformers==4.51.3", "fastapi>=0.111.0", diff --git a/server/uv.lock b/server/uv.lock index 7fc34c0..187530e 100644 --- a/server/uv.lock +++ b/server/uv.lock @@ -1479,7 +1479,7 @@ name = "nvidia-cudnn-cu12" version = "9.1.0.70" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, + { name = "nvidia-cublas-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, @@ -1490,7 +1490,7 @@ name = "nvidia-cufft-cu12" version = "11.2.1.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 }, @@ -1509,9 +1509,9 @@ name = "nvidia-cusolver-cu12" version = "11.6.1.9" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, - { name = "nvidia-cusparse-cu12" }, - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-cublas-cu12", marker = 
"(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, + { name = "nvidia-cusparse-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057 }, @@ -1522,7 +1522,7 @@ name = "nvidia-cusparse-cu12" version = "12.3.1.170" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 }, @@ -3109,13 +3109,13 @@ requires-dist = [ { name = "torch", specifier = ">=2.0.0", index = "https://download.pytorch.org/whl/cu124" }, { name = "transformers", specifier = "==4.51.3" }, { name = "uvicorn", extras = ["standard"], specifier = ">=0.29.0" }, - { name = "vibevoice", git = 
"https://github.com/microsoft/VibeVoice.git" }, + { name = "vibevoice", git = "https://github.com/JezzWTF/VibeVoice.git" }, ] [[package]] name = "vibevoice" version = "1.0.0" -source = { git = "https://github.com/microsoft/VibeVoice.git#e73d1e17c3754f046352014856a922f8208fb5d3" } +source = { git = "https://github.com/JezzWTF/VibeVoice.git#e76701f17a0d93cd612d56f0db5865a615c4d16e" } dependencies = [ { name = "absl-py" }, { name = "accelerate" }, diff --git a/server/vibevoice_generate_patch.py b/server/vibevoice_generate_patch.py deleted file mode 100644 index 825577d..0000000 --- a/server/vibevoice_generate_patch.py +++ /dev/null @@ -1,463 +0,0 @@ -""" -VibePod CPU pipeline optimisation — patched VibeVoice generate() loop. - -WHY THIS FILE EXISTS --------------------- -The VibeVoice inner speech-generation loop runs: - - decode(speech_latent) # 87 ms — VAE decode to audio waveform - audio_chunks.append(chunk) # store for final return value - audio_streamer.put(chunk) # stream to client - acoustic_connector(speech_latent) -> acoustic_embed # 1 ms - forward_tts_lm(acoustic_embed) # ~49 ms (positive) - forward_tts_lm(acoustic_embed) # ~49 ms (negative CFG) - -acoustic_connector and both forward_tts_lm calls depend only on speech_latent / -acoustic_embed — they are completely independent of the decoded audio waveform. -Running decode in a thread while connector + tts_lm run on the main thread hides -~87 ms of decode cost per token behind the ~99 ms of tts_lm work: - - Before: 87 + 1 + 49 + 49 = 186 ms / token - After: max(87, 1 + 49 + 49) = 99 ms / token (~47 % reduction) - -HOW IT WORKS ------------- -At model load time, _install_cpu_pipeline_optimizations() in vibevoice_server.py: - 1. Creates a single-worker ThreadPoolExecutor and attaches it to the model as - model._vibepod_decode_executor. - 2. Installs this module's `patched_generate` as a bound method on the model - instance via types.MethodType, shadowing the class-level generate(). 
- -Because the patch lives on the *instance*, uv sync reinstalling the VibeVoice -package has no effect — Python resolves instance attributes before class ones. - -MAINTENANCE ------------ -This is a verbatim copy of VibeVoice's generate() method (lines 574–910 of -modeling_vibevoice_streaming_inference.py) with the inner speech loop reordered. -The only changed region is marked with # [VibePod] comments. - -If VibeVoice updates its generate() method, diff the new version against this -file and merge carefully. The sentinel string "[VibePod]" marks every changed -line to make diffing easy. -""" - -import concurrent.futures -import types -from typing import Callable, List, Optional, Union - -import torch -from tqdm import tqdm -from transformers.generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList -from transformers.modeling_utils import PreTrainedModel - -from vibevoice.modular.modeling_vibevoice_streaming_inference import ( - TTS_TEXT_WINDOW_SIZE, - TTS_SPEECH_WINDOW_SIZE, - VibeVoiceGenerationOutput, - _update_model_kwargs_for_generation, -) -from vibevoice.modular.modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache -from vibevoice.modular.streamer import AudioStreamer, AsyncAudioStreamer - - -def patched_generate( - self, - inputs: Optional[torch.Tensor] = None, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - synced_gpus: Optional[bool] = None, - assistant_model: Optional["PreTrainedModel"] = None, - audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None, - negative_prompt_ids: Optional[torch.Tensor] = None, - negative_prompt_attention_mask: Optional[torch.Tensor] = None, - speech_tensors: Optional[torch.FloatTensor] = None, - speech_masks: Optional[torch.BoolTensor] = None, - 
speech_input_mask: Optional[torch.BoolTensor] = None, - tts_text_ids: Optional[torch.LongTensor] = None, - return_speech: bool = True, - cfg_scale: float = 1.0, - stop_check_fn: Optional[Callable[[], bool]] = None, - **kwargs, -) -> Union[torch.LongTensor, VibeVoiceGenerationOutput]: - # ── Setup (unchanged from original) ───────────────────────────────────── - tokenizer = kwargs.pop("tokenizer", None) - neg_text_input_id = tokenizer.convert_tokens_to_ids("<|image_pad|>") - - tts_lm_input_ids = kwargs.pop("tts_lm_input_ids", None) - tts_lm_attention_mask = kwargs.pop("tts_lm_attention_mask", None) - all_prefilled_outputs = kwargs.pop("all_prefilled_outputs", None) - tts_text_ids = tts_text_ids.to(self.device) - - if kwargs.get("max_new_tokens", None) is None: - kwargs["max_new_tokens"] = ( - self.config.decoder_config.max_position_embeddings - tts_lm_input_ids.shape[-1] - ) - - generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria = ( - self._build_generate_config_model_kwargs( - generation_config, inputs, tokenizer, return_processors=True, **kwargs - ) - ) - - negative_kwargs = { - "input_ids": torch.full( - (kwargs["input_ids"].shape[0], 1), - neg_text_input_id, - dtype=torch.long, - device=kwargs["input_ids"].device, - ), - "attention_mask": torch.ones( - (kwargs["input_ids"].shape[0], 1), - dtype=torch.long, - device=kwargs["input_ids"].device, - ), - "max_new_tokens": kwargs.get("max_new_tokens", 100), - } - negative_generation_config, negative_model_kwargs, negative_input_ids = ( - self._build_generate_config_model_kwargs( - None, None, tokenizer, return_processors=False, **negative_kwargs - ) - ) - - tts_lm_kwargs = { - "input_ids": tts_lm_input_ids, - "attention_mask": tts_lm_attention_mask, - "max_new_tokens": kwargs.get("max_new_tokens", 100), - } - tts_lm_generation_config, tts_lm_model_kwargs, tts_lm_input_ids = ( - self._build_generate_config_model_kwargs( - None, None, tokenizer, return_processors=False, **tts_lm_kwargs - ) - 
) - - tts_lm_negative_kwargs = { - "input_ids": torch.full( - (kwargs["input_ids"].shape[0], 1), - neg_text_input_id, - dtype=torch.long, - device=kwargs["input_ids"].device, - ), - "attention_mask": torch.ones( - (kwargs["input_ids"].shape[0], 1), - dtype=torch.long, - device=kwargs["input_ids"].device, - ), - "max_new_tokens": kwargs.get("max_new_tokens", 100), - } - tts_lm_negative_generation_config, tts_lm_negative_model_kwargs, tts_lm_negative_input_ids = ( - self._build_generate_config_model_kwargs( - None, None, tokenizer, return_processors=False, **tts_lm_negative_kwargs - ) - ) - - acoustic_cache = VibeVoiceTokenizerStreamingCache() - batch_size = input_ids.shape[0] - assert batch_size == 1, "Currently only supports batch size == 1" - device = input_ids.device - finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device) - verbose = kwargs.get("verbose", False) - - audio_chunks = [[] for _ in range(batch_size)] - tts_text_window_index = 0 - reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device) - first_text_window_size = ( - TTS_TEXT_WINDOW_SIZE - if tts_text_ids.shape[1] >= TTS_TEXT_WINDOW_SIZE - else tts_text_ids.shape[1] - ) - - outputs = all_prefilled_outputs["lm"] - tts_lm_outputs = all_prefilled_outputs["tts_lm"] - negative_outputs = all_prefilled_outputs["neg_lm"] - tts_lm_negative_outputs = all_prefilled_outputs["neg_tts_lm"] - - model_kwargs = _update_model_kwargs_for_generation( - outputs, model_kwargs, num_new_tokens=first_text_window_size - ) - tts_lm_model_kwargs = _update_model_kwargs_for_generation( - tts_lm_outputs, tts_lm_model_kwargs, num_new_tokens=first_text_window_size - ) - negative_model_kwargs = self._update_model_kwargs_for_generation( - negative_outputs, negative_model_kwargs, is_encoder_decoder=False - ) - tts_lm_negative_model_kwargs = self._update_model_kwargs_for_generation( - tts_lm_negative_outputs, tts_lm_negative_model_kwargs, is_encoder_decoder=False - ) - - step = 
tts_lm_input_ids.shape[1] - total_generated_speech_tokens = 0 - total_prefilled_text_tokens = 0 - if kwargs.get("show_progress_bar", True): - progress_bar = tqdm( - total=tts_lm_generation_config.max_length, - desc=f"Prefilled {step} tokens, current step ({step} / {tts_lm_generation_config.max_length})", - initial=step, - leave=False, - ) - else: - progress_bar = None - - # [VibePod] Grab the executor once; None means standard sequential path. - _vp_executor: Optional[concurrent.futures.ThreadPoolExecutor] = getattr( - self, "_vibepod_decode_executor", None - ) - - # ── Main generation loop (unchanged from original) ─────────────────────── - while True: - if stop_check_fn is not None and stop_check_fn(): - if verbose: - print(f"Generation stopped externally at step {step + 1}") - if audio_streamer is not None: - audio_streamer.end() - break - - if finished_tags.all(): - if hasattr(progress_bar, "set_description"): - progress_bar.set_description("Generation complete") - break - - cur_input_tts_text_ids = tts_text_ids[ - :, - tts_text_window_index * TTS_TEXT_WINDOW_SIZE : (tts_text_window_index + 1) - * TTS_TEXT_WINDOW_SIZE, - ] - next_text_window_size = tts_text_ids[ - :, - (tts_text_window_index + 1) - * TTS_TEXT_WINDOW_SIZE : (tts_text_window_index + 2) - * TTS_TEXT_WINDOW_SIZE, - ].shape[1] - tts_text_window_index += 1 - - if cur_input_tts_text_ids.shape[1] > 0: - input_ids = torch.cat([input_ids, cur_input_tts_text_ids], dim=-1) - tts_lm_input_ids = torch.cat([tts_lm_input_ids, cur_input_tts_text_ids], dim=-1) - - if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length: - if verbose: - print( - f"Reached maximum generation length {generation_config.max_length}, stopped it." 
- ) - reached_samples = torch.arange(batch_size, device=device)[~finished_tags] - if reached_samples.numel() > 0: - reach_max_step_sample[reached_samples] = True - break - - step += cur_input_tts_text_ids.shape[1] - total_prefilled_text_tokens += cur_input_tts_text_ids.shape[1] - if progress_bar is not None: - progress_bar.update(cur_input_tts_text_ids.shape[1]) - progress_bar.set_description( - f"Prefilled {total_prefilled_text_tokens} text tokens, " - f"generated {total_generated_speech_tokens} speech tokens, " - f"current step ({step} / {tts_lm_generation_config.max_length})" - ) - - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - outputs = self.forward_lm( - **model_inputs, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - model_kwargs = _update_model_kwargs_for_generation( - outputs, model_kwargs, num_new_tokens=next_text_window_size - ) - - tts_lm_model_inputs = self.prepare_inputs_for_generation( - tts_lm_input_ids, **tts_lm_model_kwargs - ) - tts_lm_additional_inputs = { - "tts_text_masks": torch.ones_like(tts_lm_input_ids[:, -1:]), - "lm_last_hidden_state": outputs.last_hidden_state, - } - tts_lm_outputs = self.forward_tts_lm( - **tts_lm_model_inputs, - **tts_lm_additional_inputs, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - tts_lm_model_kwargs = self._update_model_kwargs_for_generation( - tts_lm_outputs, tts_lm_model_kwargs, is_encoder_decoder=False - ) - - diffusion_indices = torch.LongTensor([0]) - - # ── Inner speech loop ──────────────────────────────────────────────── - for cur_speech_index in range(TTS_SPEECH_WINDOW_SIZE): - positive_condition = tts_lm_outputs.last_hidden_state[diffusion_indices, -1, :] - negative_condition = tts_lm_negative_outputs.last_hidden_state[ - diffusion_indices, -1, : - ] - - speech_latent = self.sample_speech_tokens( - positive_condition, - negative_condition, - cfg_scale=cfg_scale, - ).unsqueeze(1) - - scaled_latent = ( 
- speech_latent / self.model.speech_scaling_factor.to(speech_latent.device) - - self.model.speech_bias_factor.to(speech_latent.device) - ) - - # [VibePod] If a decode executor is configured, submit decode to a - # background thread so acoustic_connector and forward_tts_lm can run - # concurrently on the main thread. The future is resolved after both - # tts_lm calls complete, before appending/streaming the audio chunk. - # Without the executor, the original sequential path is used unchanged. - if _vp_executor is not None: - _decode_future: concurrent.futures.Future[torch.Tensor] = _vp_executor.submit( - self.model.acoustic_tokenizer.decode, - scaled_latent.to(self.model.acoustic_tokenizer.device), - cache=acoustic_cache, - sample_indices=diffusion_indices.to(self.model.acoustic_tokenizer.device), - use_cache=True, - debug=False, - ) - else: - audio_chunk = self.model.acoustic_tokenizer.decode( - scaled_latent.to(self.model.acoustic_tokenizer.device), - cache=acoustic_cache, - sample_indices=diffusion_indices.to(self.model.acoustic_tokenizer.device), - use_cache=True, - debug=False, - ) - - # [VibePod] connector + tts_lm run here while decode is in the thread. - acoustic_embed = self.model.acoustic_connector(speech_latent) - tts_lm_input_ids = torch.cat( - [tts_lm_input_ids, torch.ones_like(tts_lm_input_ids[:, -1:])], dim=-1 - ) - - if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length: - # [VibePod] Resolve before break so audio_chunks stays consistent. 
- if _vp_executor is not None: - audio_chunk = _decode_future.result() - for i, sample_idx in enumerate(diffusion_indices): - idx = sample_idx.item() - if not finished_tags[idx]: - audio_chunks[idx].append(audio_chunk[i]) - if audio_streamer is not None: - audio_streamer.put(audio_chunk, diffusion_indices) - break - - step += 1 - total_generated_speech_tokens += 1 - if progress_bar is not None: - progress_bar.update(1) - progress_bar.set_description( - f"Prefilled {total_prefilled_text_tokens} text tokens, " - f"generated {total_generated_speech_tokens} speech tokens, " - f"current step ({step} / {tts_lm_generation_config.max_length})" - ) - - tts_lm_model_inputs = self.prepare_inputs_for_generation( - tts_lm_input_ids, **tts_lm_model_kwargs - ) - tts_lm_additional_inputs = { - "tts_text_masks": torch.zeros_like(tts_lm_input_ids[:, -1:]), - "lm_last_hidden_state": acoustic_embed, - } - tts_lm_outputs = self.forward_tts_lm( - **tts_lm_model_inputs, - **tts_lm_additional_inputs, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - if cur_speech_index == TTS_SPEECH_WINDOW_SIZE - 1 and next_text_window_size > 0: - tts_lm_model_kwargs = _update_model_kwargs_for_generation( - tts_lm_outputs, tts_lm_model_kwargs, num_new_tokens=next_text_window_size - ) - else: - tts_lm_model_kwargs = self._update_model_kwargs_for_generation( - tts_lm_outputs, tts_lm_model_kwargs, is_encoder_decoder=False - ) - - tts_lm_negative_input_ids = torch.cat( - [tts_lm_negative_input_ids, torch.ones_like(tts_lm_input_ids[:, -1:])], dim=-1 - ) - tts_lm_negative_model_inputs = self.prepare_inputs_for_generation( - tts_lm_negative_input_ids, **tts_lm_negative_model_kwargs - ) - tts_lm_negative_additional_inputs = { - "tts_text_masks": torch.zeros_like(tts_lm_negative_input_ids[:, -1:]), - "lm_last_hidden_state": acoustic_embed, - } - tts_lm_negative_outputs = self.forward_tts_lm( - **tts_lm_negative_model_inputs, - **tts_lm_negative_additional_inputs, - 
return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - tts_lm_negative_model_kwargs = self._update_model_kwargs_for_generation( - tts_lm_negative_outputs, - tts_lm_negative_model_kwargs, - is_encoder_decoder=False, - ) - - # [VibePod] Decode is done (or was never async). Resolve future, - # then append + stream — moved here from before connector/tts_lm. - if _vp_executor is not None: - audio_chunk = _decode_future.result() - for i, sample_idx in enumerate(diffusion_indices): - idx = sample_idx.item() - if not finished_tags[idx]: - audio_chunks[idx].append(audio_chunk[i]) - if audio_streamer is not None: - audio_streamer.put(audio_chunk, diffusion_indices) - - tts_eos_logits = torch.sigmoid( - self.tts_eos_classifier( - tts_lm_outputs.last_hidden_state[diffusion_indices, -1, :] - ) - ) - if tts_eos_logits[0].item() > 0.5: - finished_tags[diffusion_indices] = True - if audio_streamer is not None: - audio_streamer.end(diffusion_indices) - - if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length: - if verbose: - print( - f"Reached maximum generation length {tts_lm_generation_config.max_length}, stopped it." - ) - reached_samples = torch.arange(batch_size, device=device)[~finished_tags] - if reached_samples.numel() > 0: - reach_max_step_sample[reached_samples] = True - break - - if audio_streamer is not None: - audio_streamer.end() - - # ── Audio finalisation (unchanged from original) ───────────────────────── - final_audio_outputs = [] - for sample_chunks in audio_chunks: - if sample_chunks: - concatenated_audio = torch.cat(sample_chunks, dim=-1) - final_audio_outputs.append(concatenated_audio) - else: - final_audio_outputs.append(None) - - if reach_max_step_sample is not None and reach_max_step_sample.any(): - print( - f"Reached maximum generation length {tts_lm_generation_config.max_length}, stopped it." 
- ) - - return VibeVoiceGenerationOutput( - sequences=tts_lm_input_ids, - speech_outputs=final_audio_outputs if return_speech else None, - reach_max_step_sample=reach_max_step_sample, - ) - - -def install(model: object, executor: concurrent.futures.ThreadPoolExecutor) -> None: - """Install the patched generate() on a model instance and attach the executor.""" - model._vibepod_decode_executor = executor - model.generate = types.MethodType(patched_generate, model) diff --git a/server/vibevoice_server.py b/server/vibevoice_server.py index 14ccb36..2516a59 100644 --- a/server/vibevoice_server.py +++ b/server/vibevoice_server.py @@ -66,9 +66,12 @@ DEFAULT_SPEAKER = "carter" _IGNORE_PATTERNS = ["*.msgpack", "flax_model*", "tf_model*", "rust_model*", "*.ot"] -# ── Decode pipeline executor ──────────────────────────────────────────────────── +# ── Pipeline executors ───────────────────────────────────────────────────────── +# _decode_executor: overlaps acoustic_decode with forward_tts_lm (1 worker). +# _cfg_executor: runs positive + negative forward_tts_lm in parallel (1 worker). _decode_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None +_cfg_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None # ── Device selection ──────────────────────────────────────────────────────────── # VIBEPOD_DEVICE env var is set by start.sh based on the --cpu / --cuda flag. @@ -543,41 +546,38 @@ def _install_generation_optimizations(model: object) -> None: def _install_cpu_pipeline_optimizations(model: object) -> None: - """Install the async-decode generate() patch and its thread pool on the model instance. + """Attach pipeline executors to the model for the optimised generate() loop. - The VibeVoice inner loop runs: - decode(speech_latent) → append → put → connector → tts_lm(pos) → tts_lm(neg) + The JezzWTF/VibeVoice fork's generate() checks for two optional attributes: - connector and both tts_lm calls only need speech_latent/acoustic_embed, not - audio_chunk. 
The patched generate() reorders this to: - submit decode to thread → connector → tts_lm(pos) → tts_lm(neg) - → wait for decode future → append → put + model._vibepod_decode_executor — ThreadPoolExecutor (1 worker) used to + overlap acoustic_decode with acoustic_connector + forward_tts_lm. - The patch is applied as an instance method via types.MethodType, which shadows - the class-level generate() and is immune to uv sync reinstalling the package. + model._vibepod_cfg_executor — ThreadPoolExecutor (1 worker) used to + run the positive and negative forward_tts_lm calls in parallel, so + both CFG passes execute concurrently instead of sequentially. + + Both are None by default, making the fork's generate() behave identically + to upstream on CUDA or any machine where these aren't set. """ - global _decode_executor + global _decode_executor, _cfg_executor if os.environ.get("VIBEPOD_ASYNC_DECODE", "1") != "1": - logger.info("CPU async decode disabled via VIBEPOD_ASYNC_DECODE=0.") - return - - try: - import vibevoice_generate_patch - except ImportError: - logger.warning( - "vibevoice_generate_patch not found — async decode unavailable. " - "Ensure vibevoice_generate_patch.py is in the server directory." - ) + logger.info("CPU async decode/CFG parallelism disabled via VIBEPOD_ASYNC_DECODE=0.") return _decode_executor = concurrent.futures.ThreadPoolExecutor( max_workers=1, thread_name_prefix="vibepod-decode" ) - vibevoice_generate_patch.install(model, _decode_executor) + _cfg_executor = concurrent.futures.ThreadPoolExecutor( + max_workers=1, thread_name_prefix="vibepod-cfg" + ) + model._vibepod_decode_executor = _decode_executor + model._vibepod_cfg_executor = _cfg_executor logger.info( - "CPU pipeline: patched generate() installed (async decode enabled) — " - "acoustic_decode overlaps forward_tts_lm. Disable with VIBEPOD_ASYNC_DECODE=0." + "CPU pipeline: decode executor and CFG executor attached — " + "acoustic_decode overlaps tts_lm, pos/neg CFG runs in parallel. 
" + "Disable with VIBEPOD_ASYNC_DECODE=0." ) @@ -688,6 +688,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: yield if _decode_executor is not None: _decode_executor.shutdown(wait=False) + if _cfg_executor is not None: + _cfg_executor.shutdown(wait=False) app = FastAPI(title="VibePod TTS Server", version="0.1.0", lifespan=lifespan)