feat(phase-1): persistent generation library

- Save every completed generation to SQLite (generation_store.py) with
  WAV and waveform peaks written to data/generations/<id>/
- Deferred DB write until success — cancelled/errored generations never
  touch the DB and never appear in the library
- Fixed cancel+regenerate IndexError: _reset_scheduler_caches() now
  directly zeros scheduler._step_index and running state in addition to
  clearing VibePod cache dicts; same explicit resets added in the fresh
  path of prepare_noise_scheduler as belt-and-suspenders
- Added /library page with GenerationCard, WaveformPreview, waveform
  fetch, play/pause, download, delete, pagination, empty + error states
- Added generation API routes (list, single, audio stream, waveform,
  delete) proxying to Python server
- Added Library nav link to Header with active state
- Persist script/speaker/CFG to localStorage so generate page state
  survives navigation
- Updated build plan: Phase 0+1 ticked off, better-sqlite3 moved to
  Phase 2, architectural note on Python owning all persistence
This commit is contained in:
2026-05-02 23:05:11 +01:00
parent 47e0c7e512
commit 13085166fb
13 changed files with 913 additions and 29 deletions
+133
View File
@@ -0,0 +1,133 @@
"""SQLite persistence for VibePod generation jobs.
Schema lives here. The database is created on first use at:
<repo_root>/data/db/vibepod.db
All writes go through this module. The Next.js layer reads the same file
via better-sqlite3 for project-level data in later phases.
"""
from __future__ import annotations
import json
import shutil
import sqlite3
from datetime import datetime, timezone
from pathlib import Path
# Paths relative to the repo root (one level up from this file's directory).
_REPO_ROOT = Path(__file__).parent.parent
# Root of everything persisted on disk: the database and the per-job files.
DATA_DIR = _REPO_ROOT / "data"
# SQLite database file opened by _connect(); its parent is created by init_db().
DB_PATH = DATA_DIR / "db" / "vibepod.db"
# Parent of the per-job directories (one sub-directory per generation id).
GENERATIONS_DIR = DATA_DIR / "generations"
# Schema of the single `generations` table, executed by init_db().
# - `status` defaults to 'generating'; this module also writes 'complete'
#   (save_completed_job), 'cancelled' (cancel_job) and 'error' (fail_job).
# - `created_at` is stored as text; save_completed_job() writes an ISO-8601
#   UTC timestamp, which is what list_jobs() orders by.
# - `audio_path` / `waveform_path` hold filesystem paths as text.
_CREATE_GENERATIONS = """
CREATE TABLE IF NOT EXISTS generations (
id TEXT PRIMARY KEY,
created_at TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'generating',
script TEXT NOT NULL,
speaker TEXT NOT NULL,
cfg_scale REAL NOT NULL,
inference_steps INTEGER,
duration_secs REAL,
sample_rate INTEGER,
audio_path TEXT,
waveform_path TEXT,
error_message TEXT
)
"""
def _connect() -> sqlite3.Connection:
    """Open a new connection to the database at ``DB_PATH``.

    Rows come back as ``sqlite3.Row`` objects, and the connection is put into
    WAL journal mode with foreign-key enforcement switched on before it is
    handed to the caller.
    """
    connection = sqlite3.connect(str(DB_PATH))
    connection.row_factory = sqlite3.Row
    for pragma in ("PRAGMA journal_mode=WAL", "PRAGMA foreign_keys=ON"):
        connection.execute(pragma)
    return connection
def init_db() -> None:
    """Create the data directories, the database file, and the schema if missing.

    Every step is idempotent (``exist_ok=True`` / ``CREATE TABLE IF NOT
    EXISTS``), so calling this on each start-up is harmless.
    """
    DB_PATH.parent.mkdir(parents=True, exist_ok=True)
    GENERATIONS_DIR.mkdir(parents=True, exist_ok=True)
    conn = _connect()
    try:
        # `with conn` only scopes the transaction (commit on success, rollback
        # on error); it does not close the connection, so do that explicitly.
        with conn:
            conn.execute(_CREATE_GENERATIONS)
    finally:
        conn.close()
def save_completed_job(
    job_id: str,
    script: str,
    speaker: str,
    cfg_scale: float,
    inference_steps: int | None,
    duration_secs: float,
    sample_rate: int,
    audio_path: str,
    waveform_path: str,
) -> None:
    """Insert a completed generation in a single write — no intermediate 'generating' row.

    The row is stamped with the current UTC time (ISO-8601) and status
    ``'complete'``; ``duration_secs`` is rounded to millisecond precision.

    Raises:
        sqlite3.Error: if the insert fails, e.g. because ``job_id`` collides
            with an existing row (``id`` is the table's primary key).
    """
    created_at = datetime.now(timezone.utc).isoformat()
    row = (
        job_id,
        created_at,
        script,
        speaker,
        cfg_scale,
        inference_steps,
        round(duration_secs, 3),
        sample_rate,
        audio_path,
        waveform_path,
    )
    conn = _connect()
    try:
        # `with conn` commits or rolls back the transaction but leaves the
        # connection open; close it ourselves so handles are not leaked.
        with conn:
            conn.execute(
                """
                INSERT INTO generations
                    (id, created_at, status, script, speaker, cfg_scale, inference_steps,
                     duration_secs, sample_rate, audio_path, waveform_path)
                VALUES (?, ?, 'complete', ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                row,
            )
    finally:
        conn.close()
def cancel_job(job_id: str) -> None:
    """Set the status of the row with id ``job_id`` to ``'cancelled'``.

    The update is keyed on ``id`` only, so nothing changes when no such row
    exists.
    """
    conn = _connect()
    try:
        # The connection's context manager handles commit/rollback only;
        # closing is our responsibility.
        with conn:
            conn.execute(
                "UPDATE generations SET status = 'cancelled' WHERE id = ?",
                (job_id,),
            )
    finally:
        conn.close()
def fail_job(job_id: str, error_message: str) -> None:
    """Set the row with id ``job_id`` to status ``'error'`` and record why.

    ``error_message`` is truncated to its first 2000 characters before being
    stored. Nothing changes when no row has that id.
    """
    conn = _connect()
    try:
        # The connection's context manager handles commit/rollback only;
        # closing is our responsibility.
        with conn:
            conn.execute(
                "UPDATE generations SET status = 'error', error_message = ? WHERE id = ?",
                (error_message[:2000], job_id),
            )
    finally:
        conn.close()
def list_jobs(limit: int = 50, offset: int = 0) -> list[dict]:
    """Return one page of generation rows, newest first.

    Args:
        limit: maximum number of rows to return.
        offset: number of rows to skip from the start of the ordering.

    Returns:
        A list of plain dicts (one per row, keyed by column name), ordered by
        ``created_at`` descending.
    """
    conn = _connect()
    try:
        # `with conn` does not close the connection; the rows are fully
        # fetched inside the block so closing afterwards is safe.
        with conn:
            rows = conn.execute(
                "SELECT * FROM generations ORDER BY created_at DESC LIMIT ? OFFSET ?",
                (limit, offset),
            ).fetchall()
    finally:
        conn.close()
    return [dict(row) for row in rows]
def get_job(job_id: str) -> dict | None:
    """Return the row with id ``job_id`` as a dict, or ``None`` if there is none."""
    conn = _connect()
    try:
        # `with conn` does not close the connection; the row is fetched
        # inside the block so closing afterwards is safe.
        with conn:
            row = conn.execute(
                "SELECT * FROM generations WHERE id = ?", (job_id,)
            ).fetchone()
    finally:
        conn.close()
    return dict(row) if row else None
def delete_job(job_id: str) -> bool:
    """Delete the job record and its files. Returns True if the record existed.

    The job's directory is removed only when ``job_id`` resolves to a direct
    child directory of ``GENERATIONS_DIR``. Ids such as ``".."``, ``"."``,
    ``""`` or an absolute path would otherwise make ``shutil.rmtree`` wipe
    the whole data tree, so they are treated as having no files. The database
    delete still runs for them (it is parameterised and therefore harmless).
    """
    generations_root = GENERATIONS_DIR.resolve()
    target_dir = (GENERATIONS_DIR / job_id).resolve()
    # Containment check first, then is_dir(): rmtree on a stray regular file
    # would raise and leave the database row undeletable.
    if target_dir.parent == generations_root and target_dir.is_dir():
        shutil.rmtree(target_dir)

    conn = _connect()
    try:
        # `with conn` commits/rolls back but does not close the connection.
        with conn:
            deleted_rows = conn.execute(
                "DELETE FROM generations WHERE id = ?", (job_id,)
            ).rowcount
    finally:
        conn.close()
    return deleted_rows > 0
def job_dir(job_id: str) -> Path:
    """Return the directory under ``GENERATIONS_DIR`` that belongs to ``job_id``.

    The path is only computed here; nothing is created on disk.
    """
    return GENERATIONS_DIR.joinpath(job_id)
+178 -5
View File
@@ -22,6 +22,7 @@ import asyncio
import base64
import concurrent.futures
import copy
import struct
import functools
import importlib.util
import json
@@ -37,10 +38,16 @@ from contextlib import asynccontextmanager
from pathlib import Path
from typing import Literal
import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel, Field, field_validator
import generation_store
import ids
import waveform as waveform_module
from tqdm import tqdm as _BaseTqdm
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
@@ -49,6 +56,32 @@ logger = logging.getLogger(__name__)
MODEL_ID = "microsoft/VibeVoice-Realtime-0.5B"
SAMPLE_RATE = 24_000
def _write_float32_wav(path: Path, samples: np.ndarray, sample_rate: int) -> None:
    """Write a mono float32 WAV without relying on libsndfile.

    Uses the same RIFF/IEEE-float layout as the browser's buildWav(), so the
    file is playable by anything that understands IEEE-float WAV (codec tag 3).

    Args:
        path: destination file; it is created or overwritten.
        samples: audio samples of any shape; they are flattened and converted
            to 32-bit floats.
        sample_rate: sampling rate in Hz, stored in the header.

    Raises:
        struct.error: if the payload or sample rate does not fit the 32-bit
            header fields. The header is built before the file is opened, so
            no truncated file is left behind in that case.
    """
    channels = 1
    bytes_per_sample = 4
    # RIFF is little-endian by definition, so request "<f4" explicitly rather
    # than the host's native float32 byte order.
    payload = samples.flatten().astype("<f4").tobytes()
    payload_size = len(payload)
    header = struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF",
        36 + payload_size,  # bytes remaining after this size field
        b"WAVE",
        b"fmt ",
        16,  # fmt chunk size
        3,  # codec: IEEE float
        channels,
        sample_rate,
        sample_rate * channels * bytes_per_sample,  # byte rate
        channels * bytes_per_sample,  # block align
        8 * bytes_per_sample,  # bits per sample
        b"data",
        payload_size,
    )
    with open(path, "wb") as f:
        f.write(header)
        f.write(payload)
VOICES_DIR = Path(__file__).parent / "voices" / "streaming_model"
VOICE_BASE_URL = (
"https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model"
@@ -160,6 +193,39 @@ _voice_presets: dict[str, object] = {}
_load_lock = threading.Lock()
_generation_lock = asyncio.Lock()
def _reset_scheduler_caches() -> None:
    """Clear VibePod scheduler caches and reset all scheduler running state.

    Called on every cancel/timeout path so the next generation starts from a
    completely clean slate. We do two things:

    1. Clear the VibePod cache dicts so prepare_noise_scheduler takes the fresh
       path and calls set_timesteps(), which re-initialises sigmas/timesteps.
    2. Directly zero out the scheduler's running counters (_step_index,
       model_outputs, lower_order_nums, _begin_index). This is belt-and-
       suspenders: VibeVoice's set_timesteps() *does* reset these fields, but
       if a cancelled thread left _step_index=N and the new generation's
       _init_step_index guard (``if self.step_index is None``) sees a non-None
       value it skips initialisation entirely, causing an out-of-bounds access
       on sigmas[step_index + 1] at the very first step.

    Does nothing when no model is loaded. The scheduler reset is best-effort:
    a failure is logged and never propagated to the caller.
    """
    if _model is None:
        return
    for attr in ("_vibepod_scheduler_cache", "_vibepod_t_batch_cache"):
        if hasattr(_model, attr):
            setattr(_model, attr, {})
    try:
        scheduler = _model.model.noise_scheduler
        scheduler._step_index = None
        scheduler._begin_index = None
        scheduler.model_outputs = [None] * scheduler.config.solver_order
        scheduler.lower_order_nums = 0
    except Exception:
        # Stay best-effort (this runs on cancel/timeout paths), but leave a
        # trace: a failed reset is the precondition for the IndexError
        # described above, so it must not vanish silently.
        logger.warning(
            "Could not reset noise scheduler running state", exc_info=True
        )
# Config defaults (can be overridden by env vars)
# These are populated in _load_model_sync once the device is known.
_config = {
@@ -446,6 +512,14 @@ def _install_generation_optimizations(model: object) -> None:
if cached is None:
scheduler.set_timesteps(self.ddpm_inference_steps)
# Belt-and-suspenders: explicitly reset running state even though
# set_timesteps() should do it, because a prior cancelled generation
# may have left _step_index non-None, causing _init_step_index to
# be skipped and triggering an out-of-bounds access in step().
scheduler._step_index = None
scheduler._begin_index = None
scheduler.model_outputs = [None] * scheduler.config.solver_order
scheduler.lower_order_nums = 0
cached = {
"num_inference_steps": scheduler.num_inference_steps,
"timesteps": scheduler.timesteps,
@@ -664,6 +738,7 @@ def _load_model_sync() -> None:
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
generation_store.init_db()
thread = threading.Thread(target=_load_model_sync, daemon=True, name="model-loader")
thread.start()
yield
@@ -839,12 +914,14 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
)
self.finished_flags[idx] = True
job_id = ids.gen_id()
start = time.monotonic()
streamer = NonBlockingAudioStreamer(batch_size=1)
cancel_event = threading.Event()
accum_size = max(1, _config["chunk_accum"])
accumulated_chunks = []
all_save_chunks: list[torch.Tensor] = []
chunk_count = 0
audio_samples = 0
first_chunk_at: float | None = None
@@ -866,14 +943,22 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
chunk = await asyncio.wait_for(streamer.audio_queues[0].get(), timeout=120.0)
except asyncio.TimeoutError:
cancel_event.set()
future.cancel()
try:
await asyncio.wait_for(asyncio.wrap_future(future), timeout=15.0)
except Exception:
pass
_reset_scheduler_caches()
yield _sse({"type": "error", "message": "Generation timed out"})
return
if await request.is_disconnected():
cancel_event.set()
future.cancel()
logger.info("Generation client disconnected; stream cancelled.")
logger.info("Client disconnected; waiting for inference thread to stop.")
try:
await asyncio.wait_for(asyncio.wrap_future(future), timeout=15.0)
except Exception:
pass
_reset_scheduler_caches()
return
if chunk is None: # stop signal
@@ -895,6 +980,7 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
.to("cpu", dtype=torch.float32)
.contiguous()
)
all_save_chunks.append(combined)
chunk_count += 1
audio_samples += combined.numel()
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
@@ -916,6 +1002,7 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
.to("cpu", dtype=torch.float32)
.contiguous()
)
all_save_chunks.append(combined)
chunk_count += 1
audio_samples += combined.numel()
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
@@ -924,7 +1011,13 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
try:
speaker = await future
except asyncio.CancelledError:
logger.info("Generation cancelled.")
cancel_event.set()
logger.info("Generation cancelled; waiting for inference thread to stop.")
try:
await asyncio.wait_for(asyncio.wrap_future(future), timeout=15.0)
except Exception:
pass
_reset_scheduler_caches()
yield _sse({"type": "cancelled"})
return
except Exception as exc:
@@ -944,8 +1037,38 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
if profile is not None:
logger.info("Generation profile: %s", profile)
logger.info("Generation complete in %.1fs", elapsed)
# Persist audio and waveform peaks after streaming is done.
audio_path: str | None = None
waveform_path: str | None = None
try:
out_dir = generation_store.job_dir(job_id)
out_dir.mkdir(parents=True, exist_ok=True)
wav_path = out_dir / "audio.wav"
peaks_path = out_dir / "waveform.json"
if all_save_chunks:
all_audio = torch.cat(all_save_chunks).numpy()
_write_float32_wav(wav_path, all_audio, SAMPLE_RATE)
waveform_module.write_peaks(wav_path, peaks_path)
audio_path = str(wav_path)
waveform_path = str(peaks_path)
generation_store.save_completed_job(
job_id,
script=req.text,
speaker=speaker,
cfg_scale=req.cfg_scale,
inference_steps=req.inference_steps,
duration_secs=audio_secs,
sample_rate=SAMPLE_RATE,
audio_path=audio_path or "",
waveform_path=waveform_path or "",
)
except Exception:
logger.exception("Failed to persist generation %s", job_id)
complete_event = {
"type": "complete",
"job_id": job_id,
"elapsed": elapsed,
"speaker": speaker,
"audio_secs": round(audio_secs, 2),
@@ -969,3 +1092,53 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
"X-Content-Type-Options": "nosniff",
},
)
# ── Generation library endpoints ────────────────────────────────────────────────
@app.get("/generations")
async def list_generations(limit: int = 50, offset: int = 0) -> dict:
    """Return one page of stored generations, newest first.

    ``limit`` is clamped to the range 0..200 and ``offset`` to >= 0. The
    lower bound matters: SQLite treats a negative LIMIT as "no limit", so an
    unclamped negative value would return the entire table. The response
    reports the values that were actually applied, so clients can page with
    them.
    """
    page_size = max(0, min(limit, 200))
    start = max(0, offset)
    jobs = generation_store.list_jobs(limit=page_size, offset=start)
    return {"items": jobs, "limit": page_size, "offset": start}
@app.get("/generations/{job_id}")
async def get_generation(job_id: str) -> dict:
    """Return the stored record for one generation; respond 404 if there is none."""
    record = generation_store.get_job(job_id)
    if record:
        return record
    raise HTTPException(status_code=404, detail="Generation not found")
@app.get("/generations/{job_id}/audio")
async def get_generation_audio(job_id: str) -> FileResponse:
    """Serve the WAV file recorded for a generation.

    Responds 404 when the record is unknown, has no audio path stored, or the
    stored path no longer exists on disk.
    """
    record = generation_store.get_job(job_id)
    stored_path = record.get("audio_path") if record else None
    if not stored_path:
        raise HTTPException(status_code=404, detail="Audio not found")

    wav_file = Path(stored_path)
    if not wav_file.exists():
        raise HTTPException(status_code=404, detail="Audio file missing from disk")

    return FileResponse(
        str(wav_file),
        media_type="audio/wav",
        filename=f"{job_id}.wav",
    )
@app.get("/generations/{job_id}/waveform")
async def get_generation_waveform(job_id: str) -> dict:
    """Return the parsed waveform-peaks JSON stored for a generation.

    Responds 404 when the record is unknown, has no waveform path stored, or
    the stored path no longer exists on disk.
    """
    record = generation_store.get_job(job_id)
    stored_path = record.get("waveform_path") if record else None
    if not stored_path:
        raise HTTPException(status_code=404, detail="Waveform not found")

    peaks_file = Path(stored_path)
    if not peaks_file.exists():
        raise HTTPException(status_code=404, detail="Waveform file missing from disk")

    raw_json = peaks_file.read_text(encoding="utf-8")
    return json.loads(raw_json)
@app.delete("/generations/{job_id}", status_code=204)
async def delete_generation(job_id: str) -> None:
    """Delete a generation via the store; respond 404 if no record existed."""
    if generation_store.delete_job(job_id):
        return
    raise HTTPException(status_code=404, detail="Generation not found")