make doveauth also use generic dictproxy

This commit is contained in:
holger krekel
2024-07-21 17:39:18 +02:00
parent 22d77f4680
commit 52aa7cad06
3 changed files with 56 additions and 59 deletions

View File

@@ -1,4 +1,10 @@
import logging
import os
from socketserver import (
StreamRequestHandler,
ThreadingMixIn,
UnixStreamServer,
)
class DictProxy:
@@ -66,3 +72,29 @@ class DictProxy:
# because our dovecot config does not involve
# multiple set-operations in a single commit
return self.transactions.pop(transaction_id)["res"]
def serve_forever_from_socket(self, socket):
dictproxy = self
class Handler(StreamRequestHandler):
def handle(self):
try:
dictproxy.loop_forever(self.rfile, self.wfile)
except Exception:
logging.exception("Exception in the handler")
raise
try:
os.unlink(socket)
except FileNotFoundError:
pass
with ThreadedUnixStreamServer(socket, Handler) as server:
try:
server.serve_forever()
except KeyboardInterrupt:
pass
class ThreadedUnixStreamServer(ThreadingMixIn, UnixStreamServer):
request_queue_size = 100

View File

@@ -5,14 +5,10 @@ import os
import sys
import time
from pathlib import Path
from socketserver import (
StreamRequestHandler,
ThreadingMixIn,
UnixStreamServer,
)
from .config import Config, read_config
from .database import Database
from .dictproxy import DictProxy
NOCREATE_FILE = "/etc/chatmail-nocreate"
@@ -178,15 +174,13 @@ def split_and_unescape(s):
yield out
def handle_dovecot_request(msg, db, config: Config):
# see https://doc.dovecot.org/3.0/developer_manual/design/dict_protocol/
short_command = msg[0]
if short_command == "H": # HELLO
# we don't do any checking on versions and just return
return
elif short_command == "L": # LOOKUP
parts = msg[1:].split("\t")
class AuthDictProxy(DictProxy):
def __init__(self, db, config):
super().__init__()
self.db = db
self.config = config
def handle_lookup(self, parts):
# Dovecot <2.3.17 has only one part,
# do not attempt to read any other parts for compatibility.
keyname = parts[0]
@@ -194,6 +188,8 @@ def handle_dovecot_request(msg, db, config: Config):
namespace, type, args = keyname.split("/", 2)
args = list(split_and_unescape(args))
config = self.config
db = self.db
reply_command = "F"
res = ""
if namespace == "shared":
@@ -215,55 +211,20 @@ def handle_dovecot_request(msg, db, config: Config):
reply_command = "N"
json_res = json.dumps(res) if res else ""
return f"{reply_command}{json_res}\n"
elif short_command == "I": # ITERATE
def handle_iterate(self, parts):
# example: I0\t0\tshared/userdb/
parts = msg[1:].split("\t")
if parts[2] == "shared/userdb/":
db = self.db
result = "".join(f"Oshared/userdb/{user}\t\n" for user in iter_userdb(db))
return f"{result}\n"
raise UnknownCommand(msg)
def handle_dovecot_protocol(rfile, wfile, db: Database, config: Config):
while True:
msg = rfile.readline().strip().decode()
if not msg:
break
try:
res = handle_dovecot_request(msg, db, config)
except UnknownCommand:
logging.warning(f"unknown command: {msg!r}")
else:
if res:
wfile.write(res.encode("ascii"))
wfile.flush()
class ThreadedUnixStreamServer(ThreadingMixIn, UnixStreamServer):
request_queue_size = 100
def main():
socket, cfgpath = sys.argv[1:]
config = read_config(cfgpath)
db = Database(config.passdb_path)
class Handler(StreamRequestHandler):
def handle(self):
try:
handle_dovecot_protocol(self.rfile, self.wfile, db, config)
except Exception:
logging.exception("Exception in the handler")
raise
dictproxy = AuthDictProxy(db=db, config=config)
try:
os.unlink(socket)
except FileNotFoundError:
pass
with ThreadedUnixStreamServer(socket, Handler) as server:
try:
server.serve_forever()
except KeyboardInterrupt:
pass
dictproxy.serve_forever_from_socket(socket)

View File

@@ -8,9 +8,8 @@ import chatmaild.doveauth
import pytest
from chatmaild.database import DBError
from chatmaild.doveauth import (
AuthDictProxy,
get_user_data,
handle_dovecot_protocol,
handle_dovecot_request,
is_allowed_to_create,
iter_userdb,
iter_userdb_lastlogin_before,
@@ -105,12 +104,14 @@ def test_too_high_db_version(db):
def test_handle_dovecot_request(db, example_config):
dictproxy = AuthDictProxy(db=db, config=example_config)
# Test that password can contain ", ', \ and /
msg = (
'Lshared/passdb/laksjdlaksjdlak\\\\sjdlk\\"12j\\\'3l1/k2j3123"'
"some42123@chat.example.org\tsome42123@chat.example.org"
)
res = handle_dovecot_request(msg, db, example_config)
res = dictproxy.handle_dovecot_request(msg)
assert res
assert res[0] == "O" and res.endswith("\n")
userdata = json.loads(res[1:].strip())
@@ -120,28 +121,31 @@ def test_handle_dovecot_request(db, example_config):
def test_handle_dovecot_protocol_hello_is_skipped(db, example_config, caplog):
dictproxy = AuthDictProxy(db=db, config=example_config)
rfile = io.BytesIO(b"H3\t2\t0\t\tauth\n")
wfile = io.BytesIO()
handle_dovecot_protocol(rfile, wfile, db, example_config)
dictproxy.loop_forever(rfile, wfile)
assert wfile.getvalue() == b""
assert not caplog.messages
def test_handle_dovecot_protocol(db, example_config):
dictproxy = AuthDictProxy(db=db, config=example_config)
rfile = io.BytesIO(
b"H3\t2\t0\t\tauth\nLshared/userdb/foobar@chat.example.org\tfoobar@chat.example.org\n"
)
wfile = io.BytesIO()
handle_dovecot_protocol(rfile, wfile, db, example_config)
dictproxy.loop_forever(rfile, wfile)
assert wfile.getvalue() == b"N\n"
def test_handle_dovecot_protocol_iterate(db, gencreds, example_config):
dictproxy = AuthDictProxy(db=db, config=example_config)
lookup_passdb(db, example_config, "asdf00000@chat.example.org", "q9mr3faue")
lookup_passdb(db, example_config, "asdf11111@chat.example.org", "q9mr3faue")
rfile = io.BytesIO(b"H3\t2\t0\t\tauth\nI0\t0\tshared/userdb/")
wfile = io.BytesIO()
handle_dovecot_protocol(rfile, wfile, db, example_config)
dictproxy.loop_forever(rfile, wfile)
lines = wfile.getvalue().decode("ascii").split("\n")
assert lines[0] == "Oshared/userdb/asdf00000@chat.example.org\t"
assert lines[1] == "Oshared/userdb/asdf11111@chat.example.org\t"