92 lines
3.5 KiB
Python
92 lines
3.5 KiB
Python
import sqlite3
|
|
from pathlib import Path
|
|
import math
|
|
import struct
|
|
|
|
def connect(db_path: str) -> sqlite3.Connection:
    """Open a SQLite connection to *db_path* with foreign-key enforcement on.

    Any missing parent directories of *db_path* are created first, so the
    database file can be created in a fresh location.

    Args:
        db_path: Filesystem path of the SQLite database file.

    Returns:
        The open ``sqlite3.Connection``.
    """
    parent_dir = Path(db_path).parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    connection = sqlite3.connect(db_path)
    # Turn on foreign-key constraint enforcement for this connection.
    connection.execute("PRAGMA foreign_keys = ON;")
    return connection
|
|
|
|
def init_db(con: sqlite3.Connection, schema_path: str = "scripts/schema.sql"):
    """Run the SQL script at *schema_path* against *con* and commit.

    Args:
        con: Open connection on which the script is executed.
        schema_path: Path to a UTF-8 encoded SQL script. A relative path is
            resolved against the process's current working directory.
    """
    with Path(schema_path).open(encoding="utf-8") as schema_file:
        ddl_script = schema_file.read()
    con.executescript(ddl_script)
    con.commit()
|
|
|
|
def pack_vec(vec):
    """Serialize a sequence of numbers as little-endian float32 bytes.

    Each element occupies 4 bytes (struct format ``f``), so values are stored
    at single precision. ``unpack_vec`` performs the reverse conversion.
    """
    fmt = "<" + str(len(vec)) + "f"
    return struct.pack(fmt, *vec)
|
|
|
|
def unpack_vec(blob):
    """Decode little-endian float32 bytes (as made by ``pack_vec``) to a list.

    Raises:
        struct.error: if ``len(blob)`` is not a multiple of 4.
    """
    return [value for (value,) in struct.iter_unpack("<f", blob)]
|
|
|
|
def cosine(a, b):
    """Return the cosine similarity of vectors *a* and *b*.

    Returns 0.0 when either vector has zero Euclidean norm. The dot product
    pairs elements with ``zip``, so when lengths differ the surplus elements
    are ignored in the dot product but still contribute to that vector's norm.
    """
    norm_a = math.sqrt(sum(v * v for v in a))
    norm_b = math.sqrt(sum(v * v for v in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    dot = sum(p * q for p, q in zip(a, b))
    return dot / (norm_a * norm_b)
|
|
|
|
def upsert_source(con, url=None, title=None, publisher=None, date_published=None, content=None, tags=None):
    """Insert or update a ``sources`` row, attach *tags*, commit, return its id.

    If a row with the same *url* already exists, only arguments passed as
    non-None overwrite the stored columns; ``COALESCE`` keeps the old value
    otherwise. Each tag name is inserted into ``tags`` if missing and linked
    to the source through ``source_tags``. Links that already exist are left
    unchanged.

    Args:
        con: Open ``sqlite3`` connection.
        url: Conflict key for the upsert. May be ``None``: a NULL url never
            matches an existing row, so a new row is inserted each time.
            This presumes the schema allows NULL urls; if ``url`` is
            ``NOT NULL``, the INSERT raises ``sqlite3.IntegrityError``.
        title, publisher, date_published, content: Column values. ``None``
            means "keep the stored value" when updating.
        tags: Optional iterable of tag names.

    Returns:
        The ``id`` of the inserted or updated ``sources`` row.
    """
    cur = con.execute(
        """INSERT INTO sources(url, title, publisher, date_published, content)
        VALUES(?,?,?,?,?)
        ON CONFLICT(url) DO UPDATE SET
        title=COALESCE(excluded.title, title),
        publisher=COALESCE(excluded.publisher, publisher),
        date_published=COALESCE(excluded.date_published, date_published),
        content=COALESCE(excluded.content, content)
        """,
        (url, title, publisher, date_published, content),
    )
    if url is None:
        # ``url = NULL`` is never true in SQL, so the row cannot be found by
        # url. A NULL url never conflicts, so this statement was a plain
        # INSERT and lastrowid identifies the new row.
        sid = con.execute(
            "SELECT id FROM sources WHERE rowid=?", (cur.lastrowid,)
        ).fetchone()[0]
    else:
        # lastrowid is not reliable when the DO UPDATE branch ran, so look
        # the row up by its conflict key instead.
        sid = con.execute("SELECT id FROM sources WHERE url=?", (url,)).fetchone()[0]
    for tag in tags or ():
        con.execute("INSERT OR IGNORE INTO tags(name) VALUES(?)", (tag,))
        tid = con.execute("SELECT id FROM tags WHERE name=?", (tag,)).fetchone()[0]
        con.execute(
            "INSERT OR IGNORE INTO source_tags(source_id, tag_id) VALUES(?,?)",
            (sid, tid),
        )
    con.commit()
    return sid
|
|
|
|
def insert_summary(con, source_id, title, summary, newsletter_date=None, tone_version=None):
    """Insert one row into ``summaries``, commit, and return its rowid.

    Args:
        con: Open ``sqlite3`` connection.
        source_id: Value for the ``source_id`` column (presumably the id of a
            ``sources`` row; confirm against the schema).
        title, summary: Text stored with the summary.
        newsletter_date, tone_version: Optional values; ``None`` is stored
            as SQL ``NULL``.

    Returns:
        ``lastrowid`` of the cursor that ran the INSERT.
    """
    params = (source_id, title, summary, newsletter_date, tone_version)
    insert_cursor = con.execute(
        """INSERT INTO summaries(source_id, title, summary, newsletter_date, tone_version)
        VALUES (?,?,?,?,?)""",
        params,
    )
    con.commit()
    return insert_cursor.lastrowid
|
|
|
|
def upsert_embedding(con, ref_table, ref_id, model, vec):
    """Store *vec* as the *model* embedding of row *ref_id* in *ref_table*.

    The vector is serialized with ``pack_vec`` and its length is written to
    the ``dim`` column. If a row with the same ``(ref_table, ref_id, model)``
    already exists, its ``vec`` and ``dim`` are overwritten. Commits.
    """
    row = (ref_table, ref_id, model, len(vec), pack_vec(vec))
    con.execute(
        """INSERT INTO embeddings(ref_table, ref_id, model, dim, vec)
        VALUES (?,?,?,?,?)
        ON CONFLICT(ref_table, ref_id, model) DO UPDATE SET vec=excluded.vec, dim=excluded.dim""",
        row,
    )
    con.commit()
|
|
|
|
def topk_similar(con, model, query_vec, ref_table="summaries", k=3, min_sim=0.78):
    """Return up to *k* rows most similar to *query_vec*, best match first.

    Every embedding stored for (*ref_table*, *model*) is decoded and scored
    against *query_vec* by cosine similarity. Embeddings whose length differs
    from ``len(query_vec)`` are skipped, as are scores below *min_sim*.

    Args:
        con: Open ``sqlite3`` connection.
        model: Value matched against ``embeddings.model``.
        query_vec: Query vector (sequence of floats).
        ref_table: ``"summaries"`` returns summary rows; any other value
            returns rows from ``sources``.
        k: Maximum number of rows to return. ``None`` means no limit;
            ``k <= 0`` returns ``[]``.
        min_sim: Minimum cosine similarity a candidate must reach.

    Returns:
        A list of rows in descending similarity order (equal similarities are
        ordered by descending ``ref_id``). Summary rows are
        ``(id, title, summary, newsletter_date)``; source rows are
        ``(id, title, url, date_published)``. Embeddings whose referenced row
        no longer exists are dropped from the result.
    """
    if k is not None and k <= 0:
        # A negative k would otherwise slice off the tail instead of limiting.
        return []
    rows = con.execute(
        "SELECT ref_id, dim, vec FROM embeddings WHERE ref_table=? AND model=?;",
        (ref_table, model),
    ).fetchall()
    scored = []
    for ref_id, _dim, blob in rows:
        vec = unpack_vec(blob)
        if len(vec) != len(query_vec):
            continue
        sim = cosine(query_vec, vec)
        if sim >= min_sim:
            scored.append((sim, ref_id))
    # Tuples sort by similarity first, then by ref_id; reverse=True gives best first.
    scored.sort(reverse=True)
    ref_ids = [rid for _, rid in scored[:k]]
    if not ref_ids:
        return []
    placeholders = ",".join(["?"] * len(ref_ids))
    if ref_table == "summaries":
        q = "SELECT id, title, summary, newsletter_date FROM summaries WHERE id IN (%s)" % placeholders
    else:
        q = "SELECT id, title, url, date_published FROM sources WHERE id IN (%s)" % placeholders
    # Rows from ``IN (...)`` come back in no guaranteed order (there is no
    # ORDER BY), so re-apply the similarity ranking computed above.
    by_id = {row[0]: row for row in con.execute(q, ref_ids).fetchall()}
    return [by_id[rid] for rid in ref_ids if rid in by_id]
|