feat(phase-1): persistent generation library

- Save every completed generation to SQLite (generation_store.py) with
  WAV and waveform peaks written to data/generations/<id>/
- Deferred DB write until success — cancelled/errored generations never
  touch the DB and never appear in the library
- Fixed cancel+regenerate IndexError: _reset_scheduler_caches() now
  directly zeros scheduler._step_index and running state in addition to
  clearing VibePod cache dicts; same explicit resets added in the fresh
  path of prepare_noise_scheduler as belt-and-suspenders
- Added /library page with GenerationCard, WaveformPreview, waveform
  fetch, play/pause, download, delete, pagination, empty + error states
- Added generation API routes (list, single, audio stream, waveform,
  delete) proxying to Python server
- Added Library nav link to Header with active state
- Persist script/speaker/CFG to localStorage so generate page state
  survives navigation
- Updated build plan: Phase 0+1 ticked off, better-sqlite3 moved to
  Phase 2, architectural note on Python owning all persistence
This commit is contained in:
2026-05-02 23:05:11 +01:00
parent 47e0c7e512
commit 13085166fb
13 changed files with 913 additions and 29 deletions
+178 -5
View File
@@ -22,6 +22,7 @@ import asyncio
import base64
import concurrent.futures
import copy
import struct
import functools
import importlib.util
import json
@@ -37,10 +38,16 @@ from contextlib import asynccontextmanager
from pathlib import Path
from typing import Literal
import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel, Field, field_validator
import generation_store
import ids
import waveform as waveform_module
from tqdm import tqdm as _BaseTqdm
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
@@ -49,6 +56,32 @@ logger = logging.getLogger(__name__)
MODEL_ID = "microsoft/VibeVoice-Realtime-0.5B"
SAMPLE_RATE = 24_000
def _write_float32_wav(path: Path, samples: np.ndarray, sample_rate: int) -> None:
"""Write a mono float32 WAV without relying on libsndfile.
Uses the same RIFF/IEEE-float layout as the browser's buildWav(), so the
file is playable by anything that understands IEEE-float WAV (codec tag 3).
"""
flat = samples.flatten().astype(np.float32)
data = flat.tobytes()
data_size = len(data)
with open(path, "wb") as f:
f.write(b"RIFF")
f.write(struct.pack("<I", 36 + data_size))
f.write(b"WAVE")
f.write(b"fmt ")
f.write(struct.pack("<I", 16)) # fmt chunk size
f.write(struct.pack("<H", 3)) # codec: IEEE float
f.write(struct.pack("<H", 1)) # channels: mono
f.write(struct.pack("<I", sample_rate))
f.write(struct.pack("<I", sample_rate * 4)) # byte rate
f.write(struct.pack("<H", 4)) # block align
f.write(struct.pack("<H", 32)) # bits per sample
f.write(b"data")
f.write(struct.pack("<I", data_size))
f.write(data)
VOICES_DIR = Path(__file__).parent / "voices" / "streaming_model"
VOICE_BASE_URL = (
"https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model"
@@ -160,6 +193,39 @@ _voice_presets: dict[str, object] = {}
_load_lock = threading.Lock()
_generation_lock = asyncio.Lock()
def _reset_scheduler_caches() -> None:
"""Clear VibePod scheduler caches and reset all scheduler running state.
Called on every cancel/timeout path so the next generation starts from a
completely clean slate. We do two things:
1. Clear the VibePod cache dicts so prepare_noise_scheduler takes the fresh
path and calls set_timesteps(), which re-initialises sigmas/timesteps.
2. Directly zero out the scheduler's running counters (_step_index,
model_outputs, lower_order_nums, _begin_index). This is belt-and-
suspenders: VibeVoice's set_timesteps() *does* reset these fields, but
if a cancelled thread left _step_index=N and the new generation's
_init_step_index guard (``if self.step_index is None``) sees a non-None
value it skips initialisation entirely, causing an out-of-bounds access
on sigmas[step_index + 1] at the very first step.
"""
if _model is None:
return
for attr in ("_vibepod_scheduler_cache", "_vibepod_t_batch_cache"):
if hasattr(_model, attr):
setattr(_model, attr, {})
try:
scheduler = _model.model.noise_scheduler
scheduler._step_index = None
scheduler._begin_index = None
scheduler.model_outputs = [None] * scheduler.config.solver_order
scheduler.lower_order_nums = 0
except Exception:
pass
# Config defaults (can be overridden by env vars)
# These are populated in _load_model_sync once the device is known.
_config = {
@@ -446,6 +512,14 @@ def _install_generation_optimizations(model: object) -> None:
if cached is None:
scheduler.set_timesteps(self.ddpm_inference_steps)
# Belt-and-suspenders: explicitly reset running state even though
# set_timesteps() should do it, because a prior cancelled generation
# may have left _step_index non-None, causing _init_step_index to
# be skipped and triggering an out-of-bounds access in step().
scheduler._step_index = None
scheduler._begin_index = None
scheduler.model_outputs = [None] * scheduler.config.solver_order
scheduler.lower_order_nums = 0
cached = {
"num_inference_steps": scheduler.num_inference_steps,
"timesteps": scheduler.timesteps,
@@ -664,6 +738,7 @@ def _load_model_sync() -> None:
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
generation_store.init_db()
thread = threading.Thread(target=_load_model_sync, daemon=True, name="model-loader")
thread.start()
yield
@@ -839,12 +914,14 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
)
self.finished_flags[idx] = True
job_id = ids.gen_id()
start = time.monotonic()
streamer = NonBlockingAudioStreamer(batch_size=1)
cancel_event = threading.Event()
accum_size = max(1, _config["chunk_accum"])
accumulated_chunks = []
all_save_chunks: list[torch.Tensor] = []
chunk_count = 0
audio_samples = 0
first_chunk_at: float | None = None
@@ -866,14 +943,22 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
chunk = await asyncio.wait_for(streamer.audio_queues[0].get(), timeout=120.0)
except asyncio.TimeoutError:
cancel_event.set()
future.cancel()
try:
await asyncio.wait_for(asyncio.wrap_future(future), timeout=15.0)
except Exception:
pass
_reset_scheduler_caches()
yield _sse({"type": "error", "message": "Generation timed out"})
return
if await request.is_disconnected():
cancel_event.set()
future.cancel()
logger.info("Generation client disconnected; stream cancelled.")
logger.info("Client disconnected; waiting for inference thread to stop.")
try:
await asyncio.wait_for(asyncio.wrap_future(future), timeout=15.0)
except Exception:
pass
_reset_scheduler_caches()
return
if chunk is None: # stop signal
@@ -895,6 +980,7 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
.to("cpu", dtype=torch.float32)
.contiguous()
)
all_save_chunks.append(combined)
chunk_count += 1
audio_samples += combined.numel()
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
@@ -916,6 +1002,7 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
.to("cpu", dtype=torch.float32)
.contiguous()
)
all_save_chunks.append(combined)
chunk_count += 1
audio_samples += combined.numel()
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
@@ -924,7 +1011,13 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
try:
speaker = await future
except asyncio.CancelledError:
logger.info("Generation cancelled.")
cancel_event.set()
logger.info("Generation cancelled; waiting for inference thread to stop.")
try:
await asyncio.wait_for(asyncio.wrap_future(future), timeout=15.0)
except Exception:
pass
_reset_scheduler_caches()
yield _sse({"type": "cancelled"})
return
except Exception as exc:
@@ -944,8 +1037,38 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
if profile is not None:
logger.info("Generation profile: %s", profile)
logger.info("Generation complete in %.1fs", elapsed)
# Persist audio and waveform peaks after streaming is done.
audio_path: str | None = None
waveform_path: str | None = None
try:
out_dir = generation_store.job_dir(job_id)
out_dir.mkdir(parents=True, exist_ok=True)
wav_path = out_dir / "audio.wav"
peaks_path = out_dir / "waveform.json"
if all_save_chunks:
all_audio = torch.cat(all_save_chunks).numpy()
_write_float32_wav(wav_path, all_audio, SAMPLE_RATE)
waveform_module.write_peaks(wav_path, peaks_path)
audio_path = str(wav_path)
waveform_path = str(peaks_path)
generation_store.save_completed_job(
job_id,
script=req.text,
speaker=speaker,
cfg_scale=req.cfg_scale,
inference_steps=req.inference_steps,
duration_secs=audio_secs,
sample_rate=SAMPLE_RATE,
audio_path=audio_path or "",
waveform_path=waveform_path or "",
)
except Exception:
logger.exception("Failed to persist generation %s", job_id)
complete_event = {
"type": "complete",
"job_id": job_id,
"elapsed": elapsed,
"speaker": speaker,
"audio_secs": round(audio_secs, 2),
@@ -969,3 +1092,53 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
"X-Content-Type-Options": "nosniff",
},
)
# ── Generation library endpoints ────────────────────────────────────────────────
@app.get("/generations")
async def list_generations(limit: int = 50, offset: int = 0) -> dict:
jobs = generation_store.list_jobs(limit=min(limit, 200), offset=offset)
return {"items": jobs, "limit": limit, "offset": offset}
@app.get("/generations/{job_id}")
async def get_generation(job_id: str) -> dict:
job = generation_store.get_job(job_id)
if not job:
raise HTTPException(status_code=404, detail="Generation not found")
return job
@app.get("/generations/{job_id}/audio")
async def get_generation_audio(job_id: str) -> FileResponse:
job = generation_store.get_job(job_id)
if not job or not job.get("audio_path"):
raise HTTPException(status_code=404, detail="Audio not found")
audio_path = Path(job["audio_path"])
if not audio_path.exists():
raise HTTPException(status_code=404, detail="Audio file missing from disk")
return FileResponse(
str(audio_path),
media_type="audio/wav",
filename=f"{job_id}.wav",
)
@app.get("/generations/{job_id}/waveform")
async def get_generation_waveform(job_id: str) -> dict:
job = generation_store.get_job(job_id)
if not job or not job.get("waveform_path"):
raise HTTPException(status_code=404, detail="Waveform not found")
peaks_path = Path(job["waveform_path"])
if not peaks_path.exists():
raise HTTPException(status_code=404, detail="Waveform file missing from disk")
return json.loads(peaks_path.read_text(encoding="utf-8"))
@app.delete("/generations/{job_id}", status_code=204)
async def delete_generation(job_id: str) -> None:
deleted = generation_store.delete_job(job_id)
if not deleted:
raise HTTPException(status_code=404, detail="Generation not found")