First commit
This commit is contained in:
91
db.py
Normal file
91
db.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
import math
|
||||
import struct
|
||||
|
||||
def connect(db_path: str) -> sqlite3.Connection:
|
||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
con = sqlite3.connect(db_path)
|
||||
con.execute("PRAGMA foreign_keys = ON;")
|
||||
return con
|
||||
|
||||
def init_db(con: sqlite3.Connection, schema_path: str = "scripts/schema.sql"):
|
||||
schema = Path(schema_path).read_text(encoding="utf-8")
|
||||
con.executescript(schema)
|
||||
con.commit()
|
||||
|
||||
def pack_vec(vec):
|
||||
return struct.pack("<%sf" % len(vec), *vec)
|
||||
|
||||
def unpack_vec(blob):
|
||||
fcount = len(blob)//4
|
||||
return list(struct.unpack("<%sf" % fcount, blob))
|
||||
|
||||
def cosine(a, b):
|
||||
na = math.sqrt(sum(x*x for x in a)); nb = math.sqrt(sum(x*x for x in b))
|
||||
if na == 0 or nb == 0: return 0.0
|
||||
return sum(x*y for x,y in zip(a,b)) / (na*nb)
|
||||
|
||||
def upsert_source(con, url=None, title=None, publisher=None, date_published=None, content=None, tags=None):
|
||||
con.execute(
|
||||
"""INSERT INTO sources(url, title, publisher, date_published, content)
|
||||
VALUES(?,?,?,?,?)
|
||||
ON CONFLICT(url) DO UPDATE SET
|
||||
title=COALESCE(excluded.title, title),
|
||||
publisher=COALESCE(excluded.publisher, publisher),
|
||||
date_published=COALESCE(excluded.date_published, date_published),
|
||||
content=COALESCE(excluded.content, content)
|
||||
""", (url, title, publisher, date_published, content)
|
||||
)
|
||||
sid = con.execute("SELECT id FROM sources WHERE url=?", (url,)).fetchone()[0]
|
||||
if tags:
|
||||
for t in tags:
|
||||
con.execute("INSERT OR IGNORE INTO tags(name) VALUES(?)", (t,))
|
||||
tid = con.execute("SELECT id FROM tags WHERE name=?", (t,)).fetchone()[0]
|
||||
con.execute("INSERT OR IGNORE INTO source_tags(source_id, tag_id) VALUES(?,?)", (sid, tid))
|
||||
con.commit()
|
||||
return sid
|
||||
|
||||
def insert_summary(con, source_id, title, summary, newsletter_date=None, tone_version=None):
|
||||
cur = con.cursor()
|
||||
cur.execute(
|
||||
"""INSERT INTO summaries(source_id, title, summary, newsletter_date, tone_version)
|
||||
VALUES (?,?,?,?,?)""",
|
||||
(source_id, title, summary, newsletter_date, tone_version)
|
||||
)
|
||||
con.commit()
|
||||
return cur.lastrowid
|
||||
|
||||
def upsert_embedding(con, ref_table, ref_id, model, vec):
|
||||
dim = len(vec)
|
||||
blob = pack_vec(vec)
|
||||
con.execute(
|
||||
"""INSERT INTO embeddings(ref_table, ref_id, model, dim, vec)
|
||||
VALUES (?,?,?,?,?)
|
||||
ON CONFLICT(ref_table, ref_id, model) DO UPDATE SET vec=excluded.vec, dim=excluded.dim""",
|
||||
(ref_table, ref_id, model, dim, blob)
|
||||
)
|
||||
con.commit()
|
||||
|
||||
def topk_similar(con, model, query_vec, ref_table="summaries", k=3, min_sim=0.78):
|
||||
rows = con.execute(
|
||||
"SELECT ref_id, dim, vec FROM embeddings WHERE ref_table=? AND model=?;",
|
||||
(ref_table, model)
|
||||
).fetchall()
|
||||
scored = []
|
||||
for ref_id, dim, blob in rows:
|
||||
vec = unpack_vec(blob)
|
||||
if len(vec) != len(query_vec):
|
||||
continue
|
||||
sim = cosine(query_vec, vec)
|
||||
if sim >= min_sim:
|
||||
scored.append((sim, ref_id))
|
||||
scored.sort(reverse=True)
|
||||
ref_ids = [rid for _, rid in scored[:k]]
|
||||
if not ref_ids: return []
|
||||
if ref_table == "summaries":
|
||||
q = "SELECT id, title, summary, newsletter_date FROM summaries WHERE id IN (%s)" % ",".join("?"*len(ref_ids))
|
||||
return con.execute(q, ref_ids).fetchall()
|
||||
else:
|
||||
q = "SELECT id, title, url, date_published FROM sources WHERE id IN (%s)" % ",".join("?"*len(ref_ids))
|
||||
return con.execute(q, ref_ids).fetchall()
|
Reference in New Issue
Block a user