From 52aa7cad06f9042e3bda9a3f8fb16df4832fa2c4 Mon Sep 17 00:00:00 2001 From: holger krekel Date: Sun, 21 Jul 2024 17:39:18 +0200 Subject: [PATCH] make doveauth also use generic dictproxy --- chatmaild/src/chatmaild/dictproxy.py | 32 +++++++++ chatmaild/src/chatmaild/doveauth.py | 67 ++++--------------- .../src/chatmaild/tests/test_doveauth.py | 16 +++-- 3 files changed, 56 insertions(+), 59 deletions(-) diff --git a/chatmaild/src/chatmaild/dictproxy.py b/chatmaild/src/chatmaild/dictproxy.py index 7191f12e..e2f299c0 100644 --- a/chatmaild/src/chatmaild/dictproxy.py +++ b/chatmaild/src/chatmaild/dictproxy.py @@ -1,4 +1,10 @@ import logging +import os +from socketserver import ( + StreamRequestHandler, + ThreadingMixIn, + UnixStreamServer, +) class DictProxy: @@ -66,3 +72,29 @@ class DictProxy: # because our dovecot config does not involve # multiple set-operations in a single commit return self.transactions.pop(transaction_id)["res"] + + def serve_forever_from_socket(self, socket): + dictproxy = self + + class Handler(StreamRequestHandler): + def handle(self): + try: + dictproxy.loop_forever(self.rfile, self.wfile) + except Exception: + logging.exception("Exception in the handler") + raise + + try: + os.unlink(socket) + except FileNotFoundError: + pass + + with ThreadedUnixStreamServer(socket, Handler) as server: + try: + server.serve_forever() + except KeyboardInterrupt: + pass + + +class ThreadedUnixStreamServer(ThreadingMixIn, UnixStreamServer): + request_queue_size = 100 diff --git a/chatmaild/src/chatmaild/doveauth.py b/chatmaild/src/chatmaild/doveauth.py index 7a0cbaac..e4da3b05 100644 --- a/chatmaild/src/chatmaild/doveauth.py +++ b/chatmaild/src/chatmaild/doveauth.py @@ -5,14 +5,10 @@ import os import sys import time from pathlib import Path -from socketserver import ( - StreamRequestHandler, - ThreadingMixIn, - UnixStreamServer, -) from .config import Config, read_config from .database import Database +from .dictproxy import DictProxy NOCREATE_FILE = "/etc/chatmail-nocreate" @@ -178,15 +174,13 @@ def split_and_unescape(s): yield out -def handle_dovecot_request(msg, db, config: Config): - # see https://doc.dovecot.org/3.0/developer_manual/design/dict_protocol/ - short_command = msg[0] - if short_command == "H": # HELLO - # we don't do any checking on versions and just return - return - elif short_command == "L": # LOOKUP - parts = msg[1:].split("\t") +class AuthDictProxy(DictProxy): + def __init__(self, db, config): + super().__init__() + self.db = db + self.config = config + def handle_lookup(self, parts): # Dovecot <2.3.17 has only one part, # do not attempt to read any other parts for compatibility. keyname = parts[0] @@ -194,6 +188,8 @@ def handle_dovecot_request(msg, db, config: Config): namespace, type, args = keyname.split("/", 2) args = list(split_and_unescape(args)) + config = self.config + db = self.db reply_command = "F" res = "" if namespace == "shared": @@ -215,55 +211,20 @@ def handle_dovecot_request(msg, db, config: Config): reply_command = "N" json_res = json.dumps(res) if res else "" return f"{reply_command}{json_res}\n" - elif short_command == "I": # ITERATE + + def handle_iterate(self, parts): # example: I0\t0\tshared/userdb/ - parts = msg[1:].split("\t") if parts[2] == "shared/userdb/": + db = self.db result = "".join(f"Oshared/userdb/{user}\t\n" for user in iter_userdb(db)) return f"{result}\n" - raise UnknownCommand(msg) - - -def handle_dovecot_protocol(rfile, wfile, db: Database, config: Config): - while True: - msg = rfile.readline().strip().decode() - if not msg: - break - try: - res = handle_dovecot_request(msg, db, config) - except UnknownCommand: - logging.warning(f"unknown command: {msg!r}") - else: - if res: - wfile.write(res.encode("ascii")) - wfile.flush() - - -class ThreadedUnixStreamServer(ThreadingMixIn, UnixStreamServer): - request_queue_size = 100 - def main(): socket, cfgpath = sys.argv[1:] config = read_config(cfgpath) db = Database(config.passdb_path) - class Handler(StreamRequestHandler): - def handle(self): - try: - handle_dovecot_protocol(self.rfile, self.wfile, db, config) - except Exception: - logging.exception("Exception in the handler") - raise + dictproxy = AuthDictProxy(db=db, config=config) - try: - os.unlink(socket) - except FileNotFoundError: - pass - - with ThreadedUnixStreamServer(socket, Handler) as server: - try: - server.serve_forever() - except KeyboardInterrupt: - pass + dictproxy.serve_forever_from_socket(socket) diff --git a/chatmaild/src/chatmaild/tests/test_doveauth.py b/chatmaild/src/chatmaild/tests/test_doveauth.py index 2ac9e239..9b12d1d5 100644 --- a/chatmaild/src/chatmaild/tests/test_doveauth.py +++ b/chatmaild/src/chatmaild/tests/test_doveauth.py @@ -8,9 +8,8 @@ import chatmaild.doveauth import pytest from chatmaild.database import DBError from chatmaild.doveauth import ( + AuthDictProxy, get_user_data, - handle_dovecot_protocol, - handle_dovecot_request, is_allowed_to_create, iter_userdb, iter_userdb_lastlogin_before, @@ -105,12 +104,14 @@ def test_too_high_db_version(db): def test_handle_dovecot_request(db, example_config): + dictproxy = AuthDictProxy(db=db, config=example_config) + # Test that password can contain ", ', \ and / msg = ( 'Lshared/passdb/laksjdlaksjdlak\\\\sjdlk\\"12j\\\'3l1/k2j3123"' "some42123@chat.example.org\tsome42123@chat.example.org" ) - res = handle_dovecot_request(msg, db, example_config) + res = dictproxy.handle_dovecot_request(msg) assert res assert res[0] == "O" and res.endswith("\n") userdata = json.loads(res[1:].strip()) @@ -120,28 +121,31 @@ def test_handle_dovecot_request(db, example_config): def test_handle_dovecot_protocol_hello_is_skipped(db, example_config, caplog): + dictproxy = AuthDictProxy(db=db, config=example_config) rfile = io.BytesIO(b"H3\t2\t0\t\tauth\n") wfile = io.BytesIO() - handle_dovecot_protocol(rfile, wfile, db, example_config) + dictproxy.loop_forever(rfile, wfile) assert wfile.getvalue() == b"" assert not caplog.messages def test_handle_dovecot_protocol(db, example_config): + dictproxy = AuthDictProxy(db=db, config=example_config) rfile = io.BytesIO( b"H3\t2\t0\t\tauth\nLshared/userdb/foobar@chat.example.org\tfoobar@chat.example.org\n" ) wfile = io.BytesIO() - handle_dovecot_protocol(rfile, wfile, db, example_config) + dictproxy.loop_forever(rfile, wfile) assert wfile.getvalue() == b"N\n" def test_handle_dovecot_protocol_iterate(db, gencreds, example_config): + dictproxy = AuthDictProxy(db=db, config=example_config) lookup_passdb(db, example_config, "asdf00000@chat.example.org", "q9mr3faue") lookup_passdb(db, example_config, "asdf11111@chat.example.org", "q9mr3faue") rfile = io.BytesIO(b"H3\t2\t0\t\tauth\nI0\t0\tshared/userdb/") wfile = io.BytesIO() - handle_dovecot_protocol(rfile, wfile, db, example_config) + dictproxy.loop_forever(rfile, wfile) lines = wfile.getvalue().decode("ascii").split("\n") assert lines[0] == "Oshared/userdb/asdf00000@chat.example.org\t" assert lines[1] == "Oshared/userdb/asdf11111@chat.example.org\t"