mirror of
https://github.com/JezzWTF/vibepod.git
synced 2026-06-13 03:58:07 +00:00
feat(phase-1): persistent generation library
- Save every completed generation to SQLite (generation_store.py) with WAV and waveform peaks written to data/generations/<id>/ - Deferred DB write until success — cancelled/errored generations never touch the DB and never appear in the library - Fixed cancel+regenerate IndexError: _reset_scheduler_caches() now directly zeros scheduler._step_index and running state in addition to clearing VibePod cache dicts; same explicit resets added in the fresh path of prepare_noise_scheduler as belt-and-suspenders - Added /library page with GenerationCard, WaveformPreview, waveform fetch, play/pause, download, delete, pagination, empty + error states - Added generation API routes (list, single, audio stream, waveform, delete) proxying to Python server - Added Library nav link to Header with active state - Persist script/speaker/CFG to localStorage so generate page state survives navigation - Updated build plan: Phase 0+1 ticked off, better-sqlite3 moved to Phase 2, architectural note on Python owning all persistence
This commit is contained in:
@@ -0,0 +1,133 @@
|
||||
"""SQLite persistence for VibePod generation jobs.
|
||||
|
||||
Schema lives here. The database is created on first use at:
|
||||
<repo_root>/data/db/vibepod.db
|
||||
|
||||
All writes go through this module. The Next.js layer reads the same file
|
||||
via better-sqlite3 for project-level data in later phases.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import sqlite3
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
# Paths relative to the repo root (one level up from this file's directory).
|
||||
_REPO_ROOT = Path(__file__).parent.parent
|
||||
DATA_DIR = _REPO_ROOT / "data"
|
||||
DB_PATH = DATA_DIR / "db" / "vibepod.db"
|
||||
GENERATIONS_DIR = DATA_DIR / "generations"
|
||||
|
||||
_CREATE_GENERATIONS = """
|
||||
CREATE TABLE IF NOT EXISTS generations (
|
||||
id TEXT PRIMARY KEY,
|
||||
created_at TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'generating',
|
||||
script TEXT NOT NULL,
|
||||
speaker TEXT NOT NULL,
|
||||
cfg_scale REAL NOT NULL,
|
||||
inference_steps INTEGER,
|
||||
duration_secs REAL,
|
||||
sample_rate INTEGER,
|
||||
audio_path TEXT,
|
||||
waveform_path TEXT,
|
||||
error_message TEXT
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
def _connect() -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(str(DB_PATH))
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA foreign_keys=ON")
|
||||
return conn
|
||||
|
||||
|
||||
def init_db() -> None:
|
||||
"""Create the database directory, database file, and tables if they don't exist."""
|
||||
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
GENERATIONS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
with _connect() as conn:
|
||||
conn.execute(_CREATE_GENERATIONS)
|
||||
|
||||
|
||||
def save_completed_job(
|
||||
job_id: str,
|
||||
script: str,
|
||||
speaker: str,
|
||||
cfg_scale: float,
|
||||
inference_steps: int | None,
|
||||
duration_secs: float,
|
||||
sample_rate: int,
|
||||
audio_path: str,
|
||||
waveform_path: str,
|
||||
) -> None:
|
||||
"""Insert a completed generation in a single write — no intermediate 'generating' row."""
|
||||
created_at = datetime.now(timezone.utc).isoformat()
|
||||
with _connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO generations
|
||||
(id, created_at, status, script, speaker, cfg_scale, inference_steps,
|
||||
duration_secs, sample_rate, audio_path, waveform_path)
|
||||
VALUES (?, ?, 'complete', ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
job_id, created_at, script, speaker, cfg_scale, inference_steps,
|
||||
round(duration_secs, 3), sample_rate, audio_path, waveform_path,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def cancel_job(job_id: str) -> None:
|
||||
with _connect() as conn:
|
||||
conn.execute(
|
||||
"UPDATE generations SET status = 'cancelled' WHERE id = ?",
|
||||
(job_id,),
|
||||
)
|
||||
|
||||
|
||||
def fail_job(job_id: str, error_message: str) -> None:
|
||||
with _connect() as conn:
|
||||
conn.execute(
|
||||
"UPDATE generations SET status = 'error', error_message = ? WHERE id = ?",
|
||||
(error_message[:2000], job_id),
|
||||
)
|
||||
|
||||
|
||||
def list_jobs(limit: int = 50, offset: int = 0) -> list[dict]:
|
||||
with _connect() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM generations ORDER BY created_at DESC LIMIT ? OFFSET ?",
|
||||
(limit, offset),
|
||||
).fetchall()
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
|
||||
def get_job(job_id: str) -> dict | None:
|
||||
with _connect() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM generations WHERE id = ?", (job_id,)
|
||||
).fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
|
||||
def delete_job(job_id: str) -> bool:
|
||||
"""Delete the job record and its files. Returns True if the record existed."""
|
||||
job_dir = GENERATIONS_DIR / job_id
|
||||
if job_dir.exists():
|
||||
shutil.rmtree(job_dir)
|
||||
|
||||
with _connect() as conn:
|
||||
result = conn.execute(
|
||||
"DELETE FROM generations WHERE id = ?", (job_id,)
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
|
||||
def job_dir(job_id: str) -> Path:
|
||||
return GENERATIONS_DIR / job_id
|
||||
+178
-5
@@ -22,6 +22,7 @@ import asyncio
|
||||
import base64
|
||||
import concurrent.futures
|
||||
import copy
|
||||
import struct
|
||||
import functools
|
||||
import importlib.util
|
||||
import json
|
||||
@@ -37,10 +38,16 @@ from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
import generation_store
|
||||
import ids
|
||||
import waveform as waveform_module
|
||||
from tqdm import tqdm as _BaseTqdm
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
@@ -49,6 +56,32 @@ logger = logging.getLogger(__name__)
|
||||
MODEL_ID = "microsoft/VibeVoice-Realtime-0.5B"
|
||||
SAMPLE_RATE = 24_000
|
||||
|
||||
|
||||
def _write_float32_wav(path: Path, samples: np.ndarray, sample_rate: int) -> None:
|
||||
"""Write a mono float32 WAV without relying on libsndfile.
|
||||
|
||||
Uses the same RIFF/IEEE-float layout as the browser's buildWav(), so the
|
||||
file is playable by anything that understands IEEE-float WAV (codec tag 3).
|
||||
"""
|
||||
flat = samples.flatten().astype(np.float32)
|
||||
data = flat.tobytes()
|
||||
data_size = len(data)
|
||||
with open(path, "wb") as f:
|
||||
f.write(b"RIFF")
|
||||
f.write(struct.pack("<I", 36 + data_size))
|
||||
f.write(b"WAVE")
|
||||
f.write(b"fmt ")
|
||||
f.write(struct.pack("<I", 16)) # fmt chunk size
|
||||
f.write(struct.pack("<H", 3)) # codec: IEEE float
|
||||
f.write(struct.pack("<H", 1)) # channels: mono
|
||||
f.write(struct.pack("<I", sample_rate))
|
||||
f.write(struct.pack("<I", sample_rate * 4)) # byte rate
|
||||
f.write(struct.pack("<H", 4)) # block align
|
||||
f.write(struct.pack("<H", 32)) # bits per sample
|
||||
f.write(b"data")
|
||||
f.write(struct.pack("<I", data_size))
|
||||
f.write(data)
|
||||
|
||||
VOICES_DIR = Path(__file__).parent / "voices" / "streaming_model"
|
||||
VOICE_BASE_URL = (
|
||||
"https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model"
|
||||
@@ -160,6 +193,39 @@ _voice_presets: dict[str, object] = {}
|
||||
_load_lock = threading.Lock()
|
||||
_generation_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def _reset_scheduler_caches() -> None:
|
||||
"""Clear VibePod scheduler caches and reset all scheduler running state.
|
||||
|
||||
Called on every cancel/timeout path so the next generation starts from a
|
||||
completely clean slate. We do two things:
|
||||
|
||||
1. Clear the VibePod cache dicts so prepare_noise_scheduler takes the fresh
|
||||
path and calls set_timesteps(), which re-initialises sigmas/timesteps.
|
||||
|
||||
2. Directly zero out the scheduler's running counters (_step_index,
|
||||
model_outputs, lower_order_nums, _begin_index). This is belt-and-
|
||||
suspenders: VibeVoice's set_timesteps() *does* reset these fields, but
|
||||
if a cancelled thread left _step_index=N and the new generation's
|
||||
_init_step_index guard (``if self.step_index is None``) sees a non-None
|
||||
value it skips initialisation entirely, causing an out-of-bounds access
|
||||
on sigmas[step_index + 1] at the very first step.
|
||||
"""
|
||||
if _model is None:
|
||||
return
|
||||
for attr in ("_vibepod_scheduler_cache", "_vibepod_t_batch_cache"):
|
||||
if hasattr(_model, attr):
|
||||
setattr(_model, attr, {})
|
||||
try:
|
||||
scheduler = _model.model.noise_scheduler
|
||||
scheduler._step_index = None
|
||||
scheduler._begin_index = None
|
||||
scheduler.model_outputs = [None] * scheduler.config.solver_order
|
||||
scheduler.lower_order_nums = 0
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# Config defaults (can be overridden by env vars)
|
||||
# These are populated in _load_model_sync once the device is known.
|
||||
_config = {
|
||||
@@ -446,6 +512,14 @@ def _install_generation_optimizations(model: object) -> None:
|
||||
|
||||
if cached is None:
|
||||
scheduler.set_timesteps(self.ddpm_inference_steps)
|
||||
# Belt-and-suspenders: explicitly reset running state even though
|
||||
# set_timesteps() should do it, because a prior cancelled generation
|
||||
# may have left _step_index non-None, causing _init_step_index to
|
||||
# be skipped and triggering an out-of-bounds access in step().
|
||||
scheduler._step_index = None
|
||||
scheduler._begin_index = None
|
||||
scheduler.model_outputs = [None] * scheduler.config.solver_order
|
||||
scheduler.lower_order_nums = 0
|
||||
cached = {
|
||||
"num_inference_steps": scheduler.num_inference_steps,
|
||||
"timesteps": scheduler.timesteps,
|
||||
@@ -664,6 +738,7 @@ def _load_model_sync() -> None:
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
generation_store.init_db()
|
||||
thread = threading.Thread(target=_load_model_sync, daemon=True, name="model-loader")
|
||||
thread.start()
|
||||
yield
|
||||
@@ -839,12 +914,14 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
)
|
||||
self.finished_flags[idx] = True
|
||||
|
||||
job_id = ids.gen_id()
|
||||
start = time.monotonic()
|
||||
streamer = NonBlockingAudioStreamer(batch_size=1)
|
||||
cancel_event = threading.Event()
|
||||
|
||||
accum_size = max(1, _config["chunk_accum"])
|
||||
accumulated_chunks = []
|
||||
all_save_chunks: list[torch.Tensor] = []
|
||||
chunk_count = 0
|
||||
audio_samples = 0
|
||||
first_chunk_at: float | None = None
|
||||
@@ -866,14 +943,22 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
chunk = await asyncio.wait_for(streamer.audio_queues[0].get(), timeout=120.0)
|
||||
except asyncio.TimeoutError:
|
||||
cancel_event.set()
|
||||
future.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.wrap_future(future), timeout=15.0)
|
||||
except Exception:
|
||||
pass
|
||||
_reset_scheduler_caches()
|
||||
yield _sse({"type": "error", "message": "Generation timed out"})
|
||||
return
|
||||
|
||||
if await request.is_disconnected():
|
||||
cancel_event.set()
|
||||
future.cancel()
|
||||
logger.info("Generation client disconnected; stream cancelled.")
|
||||
logger.info("Client disconnected; waiting for inference thread to stop.")
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.wrap_future(future), timeout=15.0)
|
||||
except Exception:
|
||||
pass
|
||||
_reset_scheduler_caches()
|
||||
return
|
||||
|
||||
if chunk is None: # stop signal
|
||||
@@ -895,6 +980,7 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
.to("cpu", dtype=torch.float32)
|
||||
.contiguous()
|
||||
)
|
||||
all_save_chunks.append(combined)
|
||||
chunk_count += 1
|
||||
audio_samples += combined.numel()
|
||||
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
|
||||
@@ -916,6 +1002,7 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
.to("cpu", dtype=torch.float32)
|
||||
.contiguous()
|
||||
)
|
||||
all_save_chunks.append(combined)
|
||||
chunk_count += 1
|
||||
audio_samples += combined.numel()
|
||||
pcm_b64 = base64.b64encode(combined.numpy().tobytes()).decode()
|
||||
@@ -924,7 +1011,13 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
try:
|
||||
speaker = await future
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Generation cancelled.")
|
||||
cancel_event.set()
|
||||
logger.info("Generation cancelled; waiting for inference thread to stop.")
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.wrap_future(future), timeout=15.0)
|
||||
except Exception:
|
||||
pass
|
||||
_reset_scheduler_caches()
|
||||
yield _sse({"type": "cancelled"})
|
||||
return
|
||||
except Exception as exc:
|
||||
@@ -944,8 +1037,38 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
if profile is not None:
|
||||
logger.info("Generation profile: %s", profile)
|
||||
logger.info("Generation complete in %.1fs", elapsed)
|
||||
|
||||
# Persist audio and waveform peaks after streaming is done.
|
||||
audio_path: str | None = None
|
||||
waveform_path: str | None = None
|
||||
try:
|
||||
out_dir = generation_store.job_dir(job_id)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
wav_path = out_dir / "audio.wav"
|
||||
peaks_path = out_dir / "waveform.json"
|
||||
if all_save_chunks:
|
||||
all_audio = torch.cat(all_save_chunks).numpy()
|
||||
_write_float32_wav(wav_path, all_audio, SAMPLE_RATE)
|
||||
waveform_module.write_peaks(wav_path, peaks_path)
|
||||
audio_path = str(wav_path)
|
||||
waveform_path = str(peaks_path)
|
||||
generation_store.save_completed_job(
|
||||
job_id,
|
||||
script=req.text,
|
||||
speaker=speaker,
|
||||
cfg_scale=req.cfg_scale,
|
||||
inference_steps=req.inference_steps,
|
||||
duration_secs=audio_secs,
|
||||
sample_rate=SAMPLE_RATE,
|
||||
audio_path=audio_path or "",
|
||||
waveform_path=waveform_path or "",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to persist generation %s", job_id)
|
||||
|
||||
complete_event = {
|
||||
"type": "complete",
|
||||
"job_id": job_id,
|
||||
"elapsed": elapsed,
|
||||
"speaker": speaker,
|
||||
"audio_secs": round(audio_secs, 2),
|
||||
@@ -969,3 +1092,53 @@ async def generate(req: GenerateRequest, request: Request) -> StreamingResponse:
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ── Generation library endpoints ────────────────────────────────────────────────
|
||||
|
||||
|
||||
@app.get("/generations")
|
||||
async def list_generations(limit: int = 50, offset: int = 0) -> dict:
|
||||
jobs = generation_store.list_jobs(limit=min(limit, 200), offset=offset)
|
||||
return {"items": jobs, "limit": limit, "offset": offset}
|
||||
|
||||
|
||||
@app.get("/generations/{job_id}")
|
||||
async def get_generation(job_id: str) -> dict:
|
||||
job = generation_store.get_job(job_id)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
return job
|
||||
|
||||
|
||||
@app.get("/generations/{job_id}/audio")
|
||||
async def get_generation_audio(job_id: str) -> FileResponse:
|
||||
job = generation_store.get_job(job_id)
|
||||
if not job or not job.get("audio_path"):
|
||||
raise HTTPException(status_code=404, detail="Audio not found")
|
||||
audio_path = Path(job["audio_path"])
|
||||
if not audio_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Audio file missing from disk")
|
||||
return FileResponse(
|
||||
str(audio_path),
|
||||
media_type="audio/wav",
|
||||
filename=f"{job_id}.wav",
|
||||
)
|
||||
|
||||
|
||||
@app.get("/generations/{job_id}/waveform")
|
||||
async def get_generation_waveform(job_id: str) -> dict:
|
||||
job = generation_store.get_job(job_id)
|
||||
if not job or not job.get("waveform_path"):
|
||||
raise HTTPException(status_code=404, detail="Waveform not found")
|
||||
peaks_path = Path(job["waveform_path"])
|
||||
if not peaks_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Waveform file missing from disk")
|
||||
return json.loads(peaks_path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
@app.delete("/generations/{job_id}", status_code=204)
|
||||
async def delete_generation(job_id: str) -> None:
|
||||
deleted = generation_store.delete_job(job_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
|
||||
Reference in New Issue
Block a user