mirror of
https://github.com/JezzWTF/vibepod.git
synced 2026-06-13 03:58:07 +00:00
feat(phase-1): persistent generation library
- Save every completed generation to SQLite (generation_store.py) with WAV and waveform peaks written to data/generations/<id>/ - Deferred DB write until success — cancelled/errored generations never touch the DB and never appear in the library - Fixed cancel+regenerate IndexError: _reset_scheduler_caches() now directly zeros scheduler._step_index and running state in addition to clearing VibePod cache dicts; same explicit resets added in the fresh path of prepare_noise_scheduler as belt-and-suspenders - Added /library page with GenerationCard, WaveformPreview, waveform fetch, play/pause, download, delete, pagination, empty + error states - Added generation API routes (list, single, audio stream, waveform, delete) proxying to Python server - Added Library nav link to Header with active state - Persist script/speaker/CFG to localStorage so generate page state survives navigation - Updated build plan: Phase 0+1 ticked off, better-sqlite3 moved to Phase 2, architectural note on Python owning all persistence
This commit is contained in:
+178
-5
@@ -22,6 +22,7 @@ import asyncio
|
||||
import base64
|
||||
import concurrent.futures
|
||||
import copy
|
||||
import struct
|
||||
import functools
|
||||
import importlib.util
|
||||
import json
|
||||
@@ -37,10 +38,16 @@ from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
import generation_store
|
||||
import ids
|
||||
import waveform as waveform_module
|
||||
from tqdm import tqdm as _BaseTqdm
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
@@ -49,6 +56,32 @@ logger = logging.getLogger(__name__)
|
||||
MODEL_ID = "microsoft/VibeVoice-Realtime-0.5B"
|
||||
SAMPLE_RATE = 24_000
|
||||
|
||||
|
||||
def _write_float32_wav(path: Path, samples: np.ndarray, sample_rate: int) -> None:
|
||||
"""Write a mono float32 WAV without relying on libsndfile.
|
||||
|
||||
Uses the same RIFF/IEEE-float layout as the browser's buildWav(), so the
|
||||
file is playable by anything that understands IEEE-float WAV (codec tag 3).
|
||||
"""
|
||||
flat = samples.flatten().astype(np.float32)
|
||||
data = flat.tobytes()
|
||||
data_size = len(data)
|
||||
with open(path, "wb") as f:
|
||||
f.write(b"RIFF")
|
||||
f.write(struct.pack("<I", 36 + data_size))
|
||||
f.write(b"WAVE")
|
||||
f.write(b"fmt ")
|
||||
f.write(struct.pack("<I", 16)) # fmt chunk size
|
||||
f.write(struct.pack("<H", 3)) # codec: IEEE float
|
||||
f.write(struct.pack("<H", 1)) # channels: mono
|
||||
f.write(struct.pack("<I", sample_rate))
|
||||
f.write(struct.pack("<I", sample_rate * 4)) # byte rate
|
||||
f.write(struct.pack("<H", 4)) # block align
|
||||
f.write(struct.pack("<H", 32)) # bits per sample
|
||||
f.write(b"data")
|
||||
f.write(struct.pack("<I", data_size))
|
||||
f.write(data)
|
||||
|
||||
VOICES_DIR = Path(__file__).parent / "voices" / "streaming_model"
|
||||
VOICE_BASE_URL = (
|
||||
"https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model"
|
||||
@@ -160,6 +193,39 @@ _voice_presets: dict[str, object] = {}
|
||||
_load_lock = threading.Lock()
|
||||
_generation_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def _reset_scheduler_caches() -> None:
|
||||
"""Clear VibePod scheduler caches and reset all scheduler running state.
|
||||
|
||||
Called on every cancel/timeout path so the next generation starts from a
|
||||
completely clean slate. We do two things:
|
||||
|
||||
1. Clear the VibePod cache dicts so prepare_noise_scheduler takes the fresh
|
||||
path and calls set_timesteps(), which re-initialises sigmas/timesteps.
|
||||
|
||||
2. Directly zero out the scheduler's running counters (_step_index,
|
||||
model_outputs, lower_order_nums, _begin_index). This is belt-and-
|
||||
suspenders: VibeVoice's set_timesteps() *does* reset these fields, but
|
||||
if a cancelled thread left _step_index=N and the new generation's
|
||||
_init_step_index guard (``if self.step_index is None``) sees a non-None
|
||||
value it skips initialisation entirely, causing an out-of-bounds access
|
||||
on sigmas[step_index + 1] at the very first step.
|
||||
"""
|
||||
if _model is None:
|
||||
return
|
||||
for attr in ("_vibepod_scheduler_cache", "_vibepod_t_batch_cache"):
|
||||
if hasattr(_model, attr):
|
||||
setattr(_model, attr, {})
|
||||
try:
|
||||
scheduler = _model.model.noise_scheduler
|
||||
scheduler._step_index = None
|
||||
scheduler._begin_index = None
|
||||
scheduler.model_outputs = [None] * scheduler.config.solver_order
|
||||
scheduler.lower_order_nums = 0
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# Config defaults (can be overridden by env vars)
|
||||
# These are populated in _load_model_sync once the device is known.
|
||||
_config = {
|
||||
@@ -446,6 +512,14 @@ def _install_generation_optimizations(model: object) -> None:
|
||||
|
||||
if cached is None:
|
||||
scheduler.set_timesteps(self.ddpm_inference_steps)
|
||||
# Belt-and-suspenders: explicitly reset running state even though
|
||||
# set_timesteps() should do it, because a prior cancelled generation
|
||||
# may have left _step_index non-None, causing _init_step_index to
|
||||
# be skipped and triggering an out-of-bounds access in step().
|
||||
scheduler._step_index = None
|
||||
scheduler._begin_index = None
|
||||
scheduler.model_outputs = [None] * scheduler.config.solver_order
|
||||
scheduler.lower_order_nums = 0
|
||||
cached = {
|
||||
"num_inference_steps": scheduler.num_inference_steps,
|
||||
"timesteps": scheduler.timesteps,
|
||||
@@ -664,6 +738,7 @@ def _load_model_sync() -> None:
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
generation_store.init_db()
|
||||
thread = threading.Thread(target=_load_model_sync, daemon=True, name="model-loader")
|
||||
thread.start()
|
||||
yield
|
||||
@@ -839,12 +914,14 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
)
|
||||
self.finished_flags[idx] = True
|
||||
|
||||
job_id = ids.gen_id()
|
||||
start = time.monotonic()
|
||||
streamer = NonBlockingAudioStreamer(batch_size=1)
|
||||
cancel_event = threading.Event()
|
||||
|
||||
accum_size = max(1, _config["chunk_accum"])
|
||||
accumulated_chunks = []
|
||||
all_save_chunks: list[torch.Tensor] = []
|
||||
chunk_count = 0
|
||||
audio_samples = 0
|
||||
first_chunk_at: float | None = None
|
||||
@@ -866,14 +943,22 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
chunk = await asyncio.wait_for(streamer.audio_queues[0].get(), timeout=120.0)
|
||||
except asyncio.TimeoutError:
|
||||
cancel_event.set()
|
||||
future.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.wrap_future(future), timeout=15.0)
|
||||
except Exception:
|
||||
pass
|
||||
_reset_scheduler_caches()
|
||||
yield _sse({"type": "error", "message": "Generation timed out"})
|
||||
return
|
||||
|
||||
if await request.is_disconnected():
|
||||
cancel_event.set()
|
||||
future.cancel()
|
||||
logger.info("Generation client disconnected; stream cancelled.")
|
||||
logger.info("Client disconnected; waiting for inference thread to stop.")
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.wrap_future(future), timeout=15.0)
|
||||
except Exception:
|
||||
pass
|
||||
_reset_scheduler_caches()
|
||||
return
|
||||
|
||||
if chunk is None: # stop signal
|
||||
@@ -895,6 +980,7 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
.to("cpu", dtype=torch.float32)
|
||||
.contiguous()
|
||||
)
|
||||
all_save_chunks.append(combined)
|
||||
chunk_count += 1
|
||||
audio_samples += combined.numel()
|
||||
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
|
||||
@@ -916,6 +1002,7 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
.to("cpu", dtype=torch.float32)
|
||||
.contiguous()
|
||||
)
|
||||
all_save_chunks.append(combined)
|
||||
chunk_count += 1
|
||||
audio_samples += combined.numel()
|
||||
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
|
||||
@@ -924,7 +1011,13 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
try:
|
||||
speaker = await future
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Generation cancelled.")
|
||||
cancel_event.set()
|
||||
logger.info("Generation cancelled; waiting for inference thread to stop.")
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.wrap_future(future), timeout=15.0)
|
||||
except Exception:
|
||||
pass
|
||||
_reset_scheduler_caches()
|
||||
yield _sse({"type": "cancelled"})
|
||||
return
|
||||
except Exception as exc:
|
||||
@@ -944,8 +1037,38 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
if profile is not None:
|
||||
logger.info("Generation profile: %s", profile)
|
||||
logger.info("Generation complete in %.1fs", elapsed)
|
||||
|
||||
# Persist audio and waveform peaks after streaming is done.
|
||||
audio_path: str | None = None
|
||||
waveform_path: str | None = None
|
||||
try:
|
||||
out_dir = generation_store.job_dir(job_id)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
wav_path = out_dir / "audio.wav"
|
||||
peaks_path = out_dir / "waveform.json"
|
||||
if all_save_chunks:
|
||||
all_audio = torch.cat(all_save_chunks).numpy()
|
||||
_write_float32_wav(wav_path, all_audio, SAMPLE_RATE)
|
||||
waveform_module.write_peaks(wav_path, peaks_path)
|
||||
audio_path = str(wav_path)
|
||||
waveform_path = str(peaks_path)
|
||||
generation_store.save_completed_job(
|
||||
job_id,
|
||||
script=req.text,
|
||||
speaker=speaker,
|
||||
cfg_scale=req.cfg_scale,
|
||||
inference_steps=req.inference_steps,
|
||||
duration_secs=audio_secs,
|
||||
sample_rate=SAMPLE_RATE,
|
||||
audio_path=audio_path or "",
|
||||
waveform_path=waveform_path or "",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to persist generation %s", job_id)
|
||||
|
||||
complete_event = {
|
||||
"type": "complete",
|
||||
"job_id": job_id,
|
||||
"elapsed": elapsed,
|
||||
"speaker": speaker,
|
||||
"audio_secs": round(audio_secs, 2),
|
||||
@@ -969,3 +1092,53 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ── Generation library endpoints ────────────────────────────────────────────────
|
||||
|
||||
|
||||
@app.get("/generations")
|
||||
async def list_generations(limit: int = 50, offset: int = 0) -> dict:
|
||||
jobs = generation_store.list_jobs(limit=min(limit, 200), offset=offset)
|
||||
return {"items": jobs, "limit": limit, "offset": offset}
|
||||
|
||||
|
||||
@app.get("/generations/{job_id}")
|
||||
async def get_generation(job_id: str) -> dict:
|
||||
job = generation_store.get_job(job_id)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
return job
|
||||
|
||||
|
||||
@app.get("/generations/{job_id}/audio")
|
||||
async def get_generation_audio(job_id: str) -> FileResponse:
|
||||
job = generation_store.get_job(job_id)
|
||||
if not job or not job.get("audio_path"):
|
||||
raise HTTPException(status_code=404, detail="Audio not found")
|
||||
audio_path = Path(job["audio_path"])
|
||||
if not audio_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Audio file missing from disk")
|
||||
return FileResponse(
|
||||
str(audio_path),
|
||||
media_type="audio/wav",
|
||||
filename=f"{job_id}.wav",
|
||||
)
|
||||
|
||||
|
||||
@app.get("/generations/{job_id}/waveform")
|
||||
async def get_generation_waveform(job_id: str) -> dict:
|
||||
job = generation_store.get_job(job_id)
|
||||
if not job or not job.get("waveform_path"):
|
||||
raise HTTPException(status_code=404, detail="Waveform not found")
|
||||
peaks_path = Path(job["waveform_path"])
|
||||
if not peaks_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Waveform file missing from disk")
|
||||
return json.loads(peaks_path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
@app.delete("/generations/{job_id}", status_code=204)
|
||||
async def delete_generation(job_id: str) -> None:
|
||||
deleted = generation_store.delete_job(job_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
|
||||
Reference in New Issue
Block a user