mirror of
https://github.com/chatmail/relay.git
synced 2026-05-10 16:04:37 +00:00
implement and test migration from sqlite to storing password in userdir
This commit is contained in:
@@ -1,5 +1,3 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import iniconfig
|
||||
@@ -53,19 +51,6 @@ class Config:
|
||||
addr=addr, home=str(home), uid="vmail", gid="vmail", password=enc_password
|
||||
)
|
||||
|
||||
def set_user_password(self, addr, enc_password):
|
||||
# reading and writing user data needs to be atomic
|
||||
# to allow concurrent logins to succeed.
|
||||
assert not addr.startswith("echo@"), addr
|
||||
userdir = self.get_user_maildir(addr)
|
||||
userdir.mkdir(exist_ok=True)
|
||||
password_path = userdir.joinpath("password")
|
||||
password_path_tmp = userdir.joinpath("password.tmp")
|
||||
password_path_tmp.write_text(enc_password)
|
||||
os.rename(password_path_tmp, password_path)
|
||||
print(f"Created address: {addr}", file=sys.stderr)
|
||||
return self.get_user_dict(addr=addr, enc_password=enc_password)
|
||||
|
||||
|
||||
def write_initial_config(inipath, mail_domain, overrides):
|
||||
"""Write out default config file, using the specified config value overrides."""
|
||||
|
||||
@@ -1,126 +0,0 @@
|
||||
import contextlib
|
||||
import sqlite3
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class DBError(Exception):
|
||||
"""error during an operation on the database."""
|
||||
|
||||
|
||||
class Connection:
|
||||
def __init__(self, sqlconn, write):
|
||||
self._sqlconn = sqlconn
|
||||
self._write = write
|
||||
|
||||
def close(self):
|
||||
self._sqlconn.close()
|
||||
|
||||
def commit(self):
|
||||
self._sqlconn.commit()
|
||||
|
||||
def rollback(self):
|
||||
self._sqlconn.rollback()
|
||||
|
||||
def execute(self, query, params=()):
|
||||
cur = self.cursor()
|
||||
try:
|
||||
cur.execute(query, params)
|
||||
except sqlite3.IntegrityError as e:
|
||||
raise DBError(e)
|
||||
return cur
|
||||
|
||||
def cursor(self):
|
||||
return self._sqlconn.cursor()
|
||||
|
||||
def get_user(self, addr: str) -> {}:
|
||||
"""Get a row from the users table."""
|
||||
q = "SELECT addr, password from users WHERE addr = ?"
|
||||
row = self._sqlconn.execute(q, (addr,)).fetchone()
|
||||
return dict(user=row[0], password=row[1]) if row else {}
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self, path: str):
|
||||
self.path = Path(path)
|
||||
self.ensure_tables()
|
||||
|
||||
def _get_connection(
|
||||
self, write=False, transaction=False, closing=False
|
||||
) -> Connection:
|
||||
# we let the database serialize all writers at connection time
|
||||
# to play it very safe (we don't have massive amounts of writes).
|
||||
mode = "ro"
|
||||
if write:
|
||||
mode = "rw"
|
||||
if not self.path.exists():
|
||||
mode = "rwc"
|
||||
uri = "file:%s?mode=%s" % (self.path, mode)
|
||||
sqlconn = sqlite3.connect(
|
||||
uri,
|
||||
timeout=60,
|
||||
isolation_level=None if transaction else "DEFERRED",
|
||||
uri=True,
|
||||
)
|
||||
|
||||
# Enable Write-Ahead Logging to avoid readers blocking writers and vice versa.
|
||||
if write:
|
||||
sqlconn.execute("PRAGMA journal_mode=wal")
|
||||
|
||||
if transaction:
|
||||
start_time = time.time()
|
||||
while 1:
|
||||
try:
|
||||
sqlconn.execute("begin immediate")
|
||||
break
|
||||
except sqlite3.OperationalError:
|
||||
# another thread may be writing, give it a chance to finish
|
||||
time.sleep(0.1)
|
||||
if time.time() - start_time > 5:
|
||||
# if it takes this long, something is wrong
|
||||
raise
|
||||
conn = Connection(sqlconn, write=write)
|
||||
if closing:
|
||||
conn = contextlib.closing(conn)
|
||||
return conn
|
||||
|
||||
@contextlib.contextmanager
|
||||
def write_transaction(self):
|
||||
conn = self._get_connection(closing=False, write=True, transaction=True)
|
||||
try:
|
||||
yield conn
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
conn.close()
|
||||
raise
|
||||
else:
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def read_connection(self, closing=True) -> Connection:
|
||||
return self._get_connection(closing=closing, write=False)
|
||||
|
||||
def get_schema_version(self) -> int:
|
||||
with self.read_connection() as conn:
|
||||
dbversion = conn.execute("PRAGMA user_version").fetchone()[0]
|
||||
return dbversion
|
||||
|
||||
CURRENT_DBVERSION = 1
|
||||
|
||||
def ensure_tables(self):
|
||||
with self.write_transaction() as conn:
|
||||
if self.get_schema_version() > 1:
|
||||
raise DBError(
|
||||
"version is %s; downgrading schema is not supported"
|
||||
% (self.get_schema_version(),)
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
addr TEXT PRIMARY KEY,
|
||||
password TEXT,
|
||||
last_login INTEGER
|
||||
)
|
||||
""",
|
||||
)
|
||||
conn.execute("PRAGMA user_version=%s" % (self.CURRENT_DBVERSION,))
|
||||
@@ -6,6 +6,8 @@ import sys
|
||||
|
||||
from .config import Config, echobot_password_path, read_config
|
||||
from .dictproxy import DictProxy
|
||||
from .lastlogin import set_user_password
|
||||
from .migrate_db import migrate_from_db_to_maildir
|
||||
|
||||
NOCREATE_FILE = "/etc/chatmail-nocreate"
|
||||
|
||||
@@ -141,6 +143,9 @@ class AuthDictProxy(DictProxy):
|
||||
except FileNotFoundError:
|
||||
return {}
|
||||
else:
|
||||
if not enc_password:
|
||||
# writing the password might have crashed and file is empty
|
||||
return {}
|
||||
return self.config.get_user_dict(user, enc_password=enc_password)
|
||||
|
||||
def lookup_passdb(self, user, cleartext_password):
|
||||
@@ -151,7 +156,8 @@ class AuthDictProxy(DictProxy):
|
||||
return
|
||||
|
||||
enc_password = encrypt_password(cleartext_password)
|
||||
self.config.set_user_password(user, enc_password=enc_password)
|
||||
set_user_password(self.config, user, enc_password=enc_password)
|
||||
print(f"Created address: {user}", file=sys.stderr)
|
||||
return self.config.get_user_dict(user, enc_password=enc_password)
|
||||
|
||||
|
||||
@@ -159,6 +165,8 @@ def main():
|
||||
socket, cfgpath = sys.argv[1:]
|
||||
config = read_config(cfgpath)
|
||||
|
||||
migrate_from_db_to_maildir(config)
|
||||
|
||||
dictproxy = AuthDictProxy(config=config)
|
||||
|
||||
dictproxy.serve_forever_from_socket(socket)
|
||||
|
||||
@@ -2,6 +2,8 @@ import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import filelock
|
||||
|
||||
from .config import read_config
|
||||
from .dictproxy import DictProxy
|
||||
|
||||
@@ -16,18 +18,32 @@ def get_daytimestamp(timestamp) -> int:
|
||||
def write_last_login_to_userdir(userdir, timestamp):
|
||||
target = userdir.joinpath(LAST_LOGIN)
|
||||
timestamp = get_daytimestamp(timestamp)
|
||||
st = target.stat()
|
||||
if int(st.st_mtime) != timestamp:
|
||||
os.utime(target, (timestamp, timestamp))
|
||||
try:
|
||||
s = os.stat(target)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
else:
|
||||
if int(s.st_mtime) != timestamp:
|
||||
os.utime(target, (timestamp, timestamp))
|
||||
|
||||
|
||||
def get_last_login_from_userdir(userdir) -> int:
|
||||
if "@" not in userdir.name:
|
||||
if "@" not in userdir.name or userdir.name.startswith("echo@"):
|
||||
return get_daytimestamp(time.time())
|
||||
target = userdir.joinpath(LAST_LOGIN)
|
||||
return int(target.stat().st_mtime)
|
||||
|
||||
|
||||
def set_user_password(config, addr, enc_password):
|
||||
assert not addr.startswith("echo@"), addr
|
||||
userdir = config.get_user_maildir(addr)
|
||||
userdir.mkdir(exist_ok=True)
|
||||
password_path = userdir.joinpath("password")
|
||||
lock_path = password_path.with_suffix(".lock")
|
||||
with filelock.FileLock(lock_path):
|
||||
password_path.write_bytes(enc_password.encode("ascii"))
|
||||
|
||||
|
||||
class LastLoginDictProxy(DictProxy):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
@@ -38,6 +54,8 @@ class LastLoginDictProxy(DictProxy):
|
||||
value = parts[2] if len(parts) > 2 else ""
|
||||
addr = self.transactions[transaction_id]["addr"]
|
||||
if keyname[0] == "shared" and keyname[1] == "last-login":
|
||||
if addr.startswith("echo@"):
|
||||
return
|
||||
addr = keyname[2]
|
||||
timestamp = int(value)
|
||||
userdir = self.config.get_user_maildir(addr)
|
||||
|
||||
64
chatmaild/src/chatmaild/migrate_db.py
Normal file
64
chatmaild/src/chatmaild/migrate_db.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
migration code from old sqlite databases into per-maildir "password" files
|
||||
where mtime reflects and is updated to be the "last-login" time.
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import sys
|
||||
|
||||
from chatmaild.config import read_config
|
||||
from chatmaild.lastlogin import set_user_password, write_last_login_to_userdir
|
||||
|
||||
|
||||
def get_all_rows(path):
|
||||
assert path.exists()
|
||||
uri = f"file:{path}?mode=ro"
|
||||
sqlconn = sqlite3.connect(uri, timeout=60, isolation_level="DEFERRED", uri=True)
|
||||
cur = sqlconn.cursor()
|
||||
cur.execute("SELECT * from users")
|
||||
rows = cur.fetchall()
|
||||
sqlconn.close()
|
||||
return rows
|
||||
|
||||
|
||||
def migrate_from_db_to_maildir(config, chunking=10000):
|
||||
path = config.passdb_path
|
||||
if not path.exists():
|
||||
return
|
||||
|
||||
all_rows = get_all_rows(path)
|
||||
|
||||
rows = [row for row in all_rows if row[0][:3] not in ("ci-", "ac_")]
|
||||
|
||||
logging.info(f"ignoring {len(all_rows)-len(rows)} CI accounts")
|
||||
logging.info(f"migrating {len(rows)} sqlite database passwords to user dirs")
|
||||
|
||||
for i, row in enumerate(rows):
|
||||
addr = row[0]
|
||||
# don't transfer special/CI accounts (IOLO)
|
||||
if addr.startswith("echo@"):
|
||||
continue
|
||||
enc_password = row[1]
|
||||
set_user_password(config, addr, enc_password=enc_password)
|
||||
if len(row) == 3 and row[2]:
|
||||
homedir = config.mailboxes_dir.joinpath(addr)
|
||||
timestamp = int(row[2])
|
||||
write_last_login_to_userdir(homedir, timestamp)
|
||||
|
||||
if i > 0 and i % chunking == 0:
|
||||
logging.info(f"migration-progress: {i} passwords transferred")
|
||||
|
||||
logging.info("migration: all passwords migrated")
|
||||
oldpath = config.passdb_path.with_suffix(config.passdb_path.suffix + ".old")
|
||||
os.rename(config.passdb_path, oldpath)
|
||||
for path in config.passdb_path.parent.iterdir():
|
||||
if path.name.startswith(config.passdb_path.name + "-"):
|
||||
path.unlink()
|
||||
logging.info(f"migration: moved database to {oldpath!r}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = read_config(sys.argv[1])
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
migrate_from_db_to_maildir(config)
|
||||
@@ -19,7 +19,7 @@ def test_login_timestamps(tmp_path):
|
||||
assert get_last_login_from_userdir(userdir) == 86400 * 2
|
||||
|
||||
|
||||
def test_delete_skips_non_email_dir(db, example_config):
|
||||
def test_delete_skips_non_email_dir(example_config):
|
||||
userdir = example_config.get_user_maildir("something")
|
||||
userdir.mkdir()
|
||||
get_last_login_from_userdir(userdir)
|
||||
|
||||
@@ -43,6 +43,30 @@ def test_handle_dovecot_request_last_login(testaddr, example_config):
|
||||
assert len(dictproxy.transactions) == 0
|
||||
|
||||
|
||||
def test_handle_dovecot_request_last_login_echobot(example_config):
|
||||
dictproxy = LastLoginDictProxy(config=example_config)
|
||||
|
||||
authproxy = AuthDictProxy(config=example_config)
|
||||
testaddr = f"echo@{example_config.mail_domain}"
|
||||
authproxy.lookup_passdb(testaddr, "ignore")
|
||||
userdir = dictproxy.config.get_user_maildir(testaddr)
|
||||
|
||||
# set last-login info for user
|
||||
tx = "1111"
|
||||
msg = f"B{tx}\t{testaddr}"
|
||||
res = dictproxy.handle_dovecot_request(msg)
|
||||
assert not res
|
||||
assert dictproxy.transactions == {tx: dict(addr=testaddr, res="O\n")}
|
||||
|
||||
timestamp = int(time.time())
|
||||
msg = f"S{tx}\tshared/last-login/{testaddr}\t{timestamp}"
|
||||
res = dictproxy.handle_dovecot_request(msg)
|
||||
assert not res
|
||||
assert len(dictproxy.transactions) == 1
|
||||
read_timestamp = get_last_login_from_userdir(userdir)
|
||||
assert read_timestamp == time.time() // 86400 * 86400
|
||||
|
||||
|
||||
def test_login_timestamp(testaddr, example_config):
|
||||
dictproxy = LastLoginDictProxy(config=example_config)
|
||||
authproxy = AuthDictProxy(config=example_config)
|
||||
|
||||
67
chatmaild/src/chatmaild/tests/test_migrate_db.py
Normal file
67
chatmaild/src/chatmaild/tests/test_migrate_db.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import sqlite3
|
||||
|
||||
from chatmaild.lastlogin import get_last_login_from_userdir
|
||||
from chatmaild.migrate_db import migrate_from_db_to_maildir
|
||||
|
||||
|
||||
def test_migration_not_exists(tmp_path, example_config):
|
||||
example_config.passdb_path = tmp_path.joinpath("sqlite")
|
||||
|
||||
|
||||
def test_migration(tmp_path, example_config, caplog):
|
||||
passdb_path = tmp_path.joinpath("passdb.sqlite")
|
||||
uri = f"file:{passdb_path}?mode=rwc"
|
||||
sqlconn = sqlite3.connect(uri, timeout=60, uri=True)
|
||||
sqlconn.execute(
|
||||
"""
|
||||
CREATE TABLE users (
|
||||
addr TEXT PRIMARY KEY,
|
||||
password TEXT,
|
||||
last_login INTEGER
|
||||
)
|
||||
"""
|
||||
)
|
||||
all = {}
|
||||
|
||||
for i in range(500):
|
||||
values = (f"somsom{i:03}@example.org", f"passwo{i:03}", i * 86400)
|
||||
sqlconn.execute(
|
||||
"""
|
||||
INSERT INTO users (addr, password, last_login)
|
||||
VALUES (?, ?, ?)""",
|
||||
values,
|
||||
)
|
||||
all[values[0]] = values[1:]
|
||||
|
||||
for i in range(500):
|
||||
values = (f"pompom{i:03}@example.org", f"wopass{i:03}", "")
|
||||
sqlconn.execute(
|
||||
"""
|
||||
INSERT INTO users (addr, password, last_login)
|
||||
VALUES (?, ?, ?)""",
|
||||
values,
|
||||
)
|
||||
all[values[0]] = values[1:]
|
||||
|
||||
sqlconn.commit()
|
||||
sqlconn.close()
|
||||
|
||||
assert passdb_path.stat().st_size > 10000
|
||||
|
||||
example_config.passdb_path = passdb_path
|
||||
|
||||
assert not caplog.records
|
||||
|
||||
migrate_from_db_to_maildir(example_config, chunking=500)
|
||||
assert len(caplog.records) > 3
|
||||
|
||||
for path in example_config.mailboxes_dir.iterdir():
|
||||
if "@" not in path.name:
|
||||
continue
|
||||
password, last_login = all.pop(path.name)
|
||||
if last_login:
|
||||
assert get_last_login_from_userdir(path) == last_login
|
||||
assert password == path.joinpath("password").read_text()
|
||||
|
||||
assert not all
|
||||
assert not example_config.passdb_path.exists()
|
||||
Reference in New Issue
Block a user