From 31e08832a6718d4f7d47cbcfe8defa75489b654d Mon Sep 17 00:00:00 2001 From: holger krekel Date: Sat, 21 Oct 2023 01:40:58 +0200 Subject: [PATCH] shift functions to a DictProxy class --- chatmaild/src/chatmaild/dictproxy.py | 115 ++++++++++++++------------- scripts/get_imap_capabilities.py | 1 - scripts/measure_tls_and_logins.py | 4 +- tests/chatmaild/conftest.py | 2 - tests/chatmaild/test_dictproxy.py | 30 ++++--- 5 files changed, 79 insertions(+), 73 deletions(-) diff --git a/chatmaild/src/chatmaild/dictproxy.py b/chatmaild/src/chatmaild/dictproxy.py index 068c3703..5b571369 100644 --- a/chatmaild/src/chatmaild/dictproxy.py +++ b/chatmaild/src/chatmaild/dictproxy.py @@ -21,66 +21,68 @@ def encrypt_password(password: str): return "{SHA512-CRYPT}" + passhash -def create_user(db, user, password): - if os.path.exists(NOCREATE_FILE): - logging.warning( - f"Didn't create account: {NOCREATE_FILE} exists. Delete the file to enable account creation." - ) - return - with db.write_transaction() as conn: - conn.create_user(user, password) - return dict(home=f"/home/vmail/{user}", uid="vmail", gid="vmail", password=password) +class DictProxy: + def __init__(self, db, mail_domain): + self.db = db + self.mail_domain = mail_domain + + def create_user(self, user, password): + if os.path.exists(NOCREATE_FILE): + logging.warning(f"Didn't create account: {NOCREATE_FILE} exists.") + return + with self.db.write_transaction() as conn: + conn.create_user(user, password) + return dict(home=f"/home/vmail/{user}", uid="vmail", gid="vmail", password=password) + + def get_user_data(self, user): + with self.db.read_connection() as conn: + result = conn.get_user(user) + if result: + result["uid"] = "vmail" + result["gid"] = "vmail" + return result -def get_user_data(db, user): - with db.read_connection() as conn: - result = conn.get_user(user) - if result: - result["uid"] = "vmail" - result["gid"] = "vmail" - return result + def lookup_userdb(self, user): + return self.get_user_data(user) -def lookup_userdb(db, user): - return get_user_data(db, user) + def lookup_passdb(self, user, password): + userdata = self.get_user_data(user) + if not userdata: + return self.create_user(user, encrypt_password(password)) + userdata["password"] = userdata["password"].strip() + return userdata -def lookup_passdb(db, user, password): - userdata = get_user_data(db, user) - if not userdata: - return create_user(db, user, encrypt_password(password)) - userdata["password"] = userdata["password"].strip() - return userdata - - -def handle_dovecot_request(msg, db, mail_domain): - print(f"received msg: {msg!r}", file=sys.stderr) - short_command = msg[0] - if short_command == "L": # LOOKUP - parts = msg[1:].split("\t") - keyname, user = parts[:2] - namespace, type, *args = keyname.split("/") - reply_command = "F" - res = "" - if namespace == "shared": - if type == "userdb": - if user.endswith(f"@{mail_domain}"): - res = lookup_userdb(db, user) - if res: - reply_command = "O" - else: - reply_command = "N" - elif type == "passdb": - if user.endswith(f"@{mail_domain}"): - res = lookup_passdb(db, user, password=args[0]) - if res: - reply_command = "O" - else: - reply_command = "N" - print(f"res: {res!r}", file=sys.stderr) - json_res = json.dumps(res) if res else "" - return f"{reply_command}{json_res}\n" - return None + def handle_dovecot_request(self, msg): + print(f"received msg: {msg!r}", file=sys.stderr) + short_command = msg[0] + if short_command == "L": # LOOKUP + parts = msg[1:].split("\t") + keyname, user = parts[:2] + namespace, type, *args = keyname.split("/") + reply_command = "F" + res = "" + if namespace == "shared": + if type == "userdb": + if user.endswith(f"@{self.mail_domain}"): + res = lookup_userdb(db, user) + if res: + reply_command = "O" + else: + reply_command = "N" + elif type == "passdb": + if user.endswith(f"@{self.mail_domain}"): + res = lookup_passdb(db, user, password=args[0]) + if res: + reply_command = "O" + else: + reply_command = "N" + print(f"res: {res!r}", file=sys.stderr) + json_res = json.dumps(res) if res else "" + return f"{reply_command}{json_res}\n" + return None class ThreadedUnixStreamServer(ThreadingMixIn, UnixStreamServer): @@ -90,17 +92,18 @@ class ThreadedUnixStreamServer(ThreadingMixIn, UnixStreamServer): def main(): socket = sys.argv[1] passwd_entry = pwd.getpwnam(sys.argv[2]) - db = Database(sys.argv[3]) with open("/etc/mailname", "r") as fp: mail_domain = fp.read().strip() + db = Database(sys.argv[3]) + dictproxy = DictProxy(db, mail_domain) class Handler(StreamRequestHandler): def handle(self): while True: msg = self.rfile.readline().strip().decode() if not msg: break - res = handle_dovecot_request(msg, db, mail_domain) + res = dictproxy.handle_dovecot_request(msg) if res: print(f"sending result: {res!r}", file=sys.stderr) self.wfile.write(res.encode("ascii")) diff --git a/scripts/get_imap_capabilities.py b/scripts/get_imap_capabilities.py index 6b8e171a..a300366a 100644 --- a/scripts/get_imap_capabilities.py +++ b/scripts/get_imap_capabilities.py @@ -11,4 +11,3 @@ conn.login(f"imapcapa", "pass") status, res = conn.capability() for capa in sorted(res[0].decode().split()): print(capa) - diff --git a/scripts/measure_tls_and_logins.py b/scripts/measure_tls_and_logins.py index fc0b6d1c..11e2c0ee 100755 --- a/scripts/measure_tls_and_logins.py +++ b/scripts/measure_tls_and_logins.py @@ -5,7 +5,7 @@ import imaplib domain = os.environ.get("CHATMAIL_DOMAIN", "c3.testrun.org") -NUM_CONNECTIONS=10 +NUM_CONNECTIONS = 10 conns = [] @@ -16,7 +16,7 @@ for i in range(NUM_CONNECTIONS): conns.append(conn) tlsdone = time.time() -duration = tlsdone-start +duration = tlsdone - start print(f"{duration}: TLS connections opening TLS connections") for i, conn in enumerate(conns): diff --git a/tests/chatmaild/conftest.py b/tests/chatmaild/conftest.py index 5e22e19f..c75f7510 100644 --- a/tests/chatmaild/conftest.py +++ b/tests/chatmaild/conftest.py @@ -1,4 +1,3 @@ - import pytest from chatmaild.database import Database @@ -8,4 +7,3 @@ def db(tmpdir): db_path = tmpdir / "passdb.sqlite" print("database path:", db_path) return Database(db_path) - diff --git a/tests/chatmaild/test_dictproxy.py b/tests/chatmaild/test_dictproxy.py index ee88f13a..babc09e2 100644 --- a/tests/chatmaild/test_dictproxy.py +++ b/tests/chatmaild/test_dictproxy.py @@ -3,32 +3,38 @@ import os import pytest import chatmaild.dictproxy -from chatmaild.dictproxy import get_user_data, lookup_passdb +from chatmaild.dictproxy import DictProxy from chatmaild.database import DBError -def test_basic(db, tmpdir, monkeypatch): - monkeypatch.setattr(chatmaild.dictproxy, "NOCREATE_FILE", tmpdir.join("nocreate").strpath) - lookup_passdb(db, "link2xt@c1.testrun.org", "asdf") - data = get_user_data(db, "link2xt@c1.testrun.org") - assert data +@pytest.fixture +def dictproxy(db, maildomain): + return DictProxy(db, maildomain) -def test_dont_overwrite_password_on_wrong_login(db): +def test_basic(dictproxy, tmpdir, monkeypatch): + monkeypatch.setattr( + chatmaild.dictproxy, "NOCREATE_FILE", tmpdir.join("nocreate").strpath + ) + dictproxy.lookup_passdb("link2xt@c1.testrun.org", "asdf") + assert dictproxy.get_user_data("link2xt@c1.testrun.org") + + +def test_dont_overwrite_password_on_wrong_login(dictproxy): """Test that logging in with a different password doesn't create a new user""" - res = lookup_passdb(db, "newuser1@something.org", "kajdlkajsldk12l3kj1983") + res = dictproxy.lookup_passdb("newuser1@something.org", "kajdlkajsldk12l3kj1983") assert res["password"] - res2 = lookup_passdb(db, "newuser1@something.org", "kajdlqweqwe") + res2 = dictproxy.lookup_passdb("newuser1@something.org", "kajdlqweqwe") # this function always returns a password hash, which is actually compared by dovecot. assert res["password"] == res2["password"] -def test_nocreate_file(db, tmpdir, monkeypatch): +def test_nocreate_file(dictproxy, tmpdir, monkeypatch): nocreate = tmpdir.join("nocreate") monkeypatch.setattr(chatmaild.dictproxy, "NOCREATE_FILE", str(nocreate)) nocreate.write("") - lookup_passdb(db, "newuser1@something.org", "kajdlqweqwe") - assert not get_user_data(db, "newuser1@something.org") + dictproxy.lookup_passdb("newuser1@something.org", "kajdlqweqwe") + assert not dictproxy.get_user_data("newuser1@something.org") def test_db_version(db):