mirror of
https://github.com/chatmail/relay.git
synced 2026-05-19 12:28:06 +00:00
make doveauth also use generic dictproxy
This commit is contained in:
@@ -1,4 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
from socketserver import (
|
||||||
|
StreamRequestHandler,
|
||||||
|
ThreadingMixIn,
|
||||||
|
UnixStreamServer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DictProxy:
|
class DictProxy:
|
||||||
@@ -66,3 +72,29 @@ class DictProxy:
|
|||||||
# because our dovecot config does not involve
|
# because our dovecot config does not involve
|
||||||
# multiple set-operations in a single commit
|
# multiple set-operations in a single commit
|
||||||
return self.transactions.pop(transaction_id)["res"]
|
return self.transactions.pop(transaction_id)["res"]
|
||||||
|
|
||||||
|
def serve_forever_from_socket(self, socket):
|
||||||
|
dictproxy = self
|
||||||
|
|
||||||
|
class Handler(StreamRequestHandler):
|
||||||
|
def handle(self):
|
||||||
|
try:
|
||||||
|
dictproxy.loop_forever(self.rfile, self.wfile)
|
||||||
|
except Exception:
|
||||||
|
logging.exception("Exception in the handler")
|
||||||
|
raise
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.unlink(socket)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
with ThreadedUnixStreamServer(socket, Handler) as server:
|
||||||
|
try:
|
||||||
|
server.serve_forever()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadedUnixStreamServer(ThreadingMixIn, UnixStreamServer):
|
||||||
|
request_queue_size = 100
|
||||||
|
|||||||
@@ -5,14 +5,10 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from socketserver import (
|
|
||||||
StreamRequestHandler,
|
|
||||||
ThreadingMixIn,
|
|
||||||
UnixStreamServer,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .config import Config, read_config
|
from .config import Config, read_config
|
||||||
from .database import Database
|
from .database import Database
|
||||||
|
from .dictproxy import DictProxy
|
||||||
|
|
||||||
NOCREATE_FILE = "/etc/chatmail-nocreate"
|
NOCREATE_FILE = "/etc/chatmail-nocreate"
|
||||||
|
|
||||||
@@ -178,15 +174,13 @@ def split_and_unescape(s):
|
|||||||
yield out
|
yield out
|
||||||
|
|
||||||
|
|
||||||
def handle_dovecot_request(msg, db, config: Config):
|
class AuthDictProxy(DictProxy):
|
||||||
# see https://doc.dovecot.org/3.0/developer_manual/design/dict_protocol/
|
def __init__(self, db, config):
|
||||||
short_command = msg[0]
|
super().__init__()
|
||||||
if short_command == "H": # HELLO
|
self.db = db
|
||||||
# we don't do any checking on versions and just return
|
self.config = config
|
||||||
return
|
|
||||||
elif short_command == "L": # LOOKUP
|
|
||||||
parts = msg[1:].split("\t")
|
|
||||||
|
|
||||||
|
def handle_lookup(self, parts):
|
||||||
# Dovecot <2.3.17 has only one part,
|
# Dovecot <2.3.17 has only one part,
|
||||||
# do not attempt to read any other parts for compatibility.
|
# do not attempt to read any other parts for compatibility.
|
||||||
keyname = parts[0]
|
keyname = parts[0]
|
||||||
@@ -194,6 +188,8 @@ def handle_dovecot_request(msg, db, config: Config):
|
|||||||
namespace, type, args = keyname.split("/", 2)
|
namespace, type, args = keyname.split("/", 2)
|
||||||
args = list(split_and_unescape(args))
|
args = list(split_and_unescape(args))
|
||||||
|
|
||||||
|
config = self.config
|
||||||
|
db = self.db
|
||||||
reply_command = "F"
|
reply_command = "F"
|
||||||
res = ""
|
res = ""
|
||||||
if namespace == "shared":
|
if namespace == "shared":
|
||||||
@@ -215,55 +211,20 @@ def handle_dovecot_request(msg, db, config: Config):
|
|||||||
reply_command = "N"
|
reply_command = "N"
|
||||||
json_res = json.dumps(res) if res else ""
|
json_res = json.dumps(res) if res else ""
|
||||||
return f"{reply_command}{json_res}\n"
|
return f"{reply_command}{json_res}\n"
|
||||||
elif short_command == "I": # ITERATE
|
|
||||||
|
def handle_iterate(self, parts):
|
||||||
# example: I0\t0\tshared/userdb/
|
# example: I0\t0\tshared/userdb/
|
||||||
parts = msg[1:].split("\t")
|
|
||||||
if parts[2] == "shared/userdb/":
|
if parts[2] == "shared/userdb/":
|
||||||
|
db = self.db
|
||||||
result = "".join(f"Oshared/userdb/{user}\t\n" for user in iter_userdb(db))
|
result = "".join(f"Oshared/userdb/{user}\t\n" for user in iter_userdb(db))
|
||||||
return f"{result}\n"
|
return f"{result}\n"
|
||||||
|
|
||||||
raise UnknownCommand(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def handle_dovecot_protocol(rfile, wfile, db: Database, config: Config):
|
|
||||||
while True:
|
|
||||||
msg = rfile.readline().strip().decode()
|
|
||||||
if not msg:
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
res = handle_dovecot_request(msg, db, config)
|
|
||||||
except UnknownCommand:
|
|
||||||
logging.warning(f"unknown command: {msg!r}")
|
|
||||||
else:
|
|
||||||
if res:
|
|
||||||
wfile.write(res.encode("ascii"))
|
|
||||||
wfile.flush()
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadedUnixStreamServer(ThreadingMixIn, UnixStreamServer):
|
|
||||||
request_queue_size = 100
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
socket, cfgpath = sys.argv[1:]
|
socket, cfgpath = sys.argv[1:]
|
||||||
config = read_config(cfgpath)
|
config = read_config(cfgpath)
|
||||||
db = Database(config.passdb_path)
|
db = Database(config.passdb_path)
|
||||||
|
|
||||||
class Handler(StreamRequestHandler):
|
dictproxy = AuthDictProxy(db=db, config=config)
|
||||||
def handle(self):
|
|
||||||
try:
|
|
||||||
handle_dovecot_protocol(self.rfile, self.wfile, db, config)
|
|
||||||
except Exception:
|
|
||||||
logging.exception("Exception in the handler")
|
|
||||||
raise
|
|
||||||
|
|
||||||
try:
|
dictproxy.serve_forever_from_socket(socket)
|
||||||
os.unlink(socket)
|
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
with ThreadedUnixStreamServer(socket, Handler) as server:
|
|
||||||
try:
|
|
||||||
server.serve_forever()
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
pass
|
|
||||||
|
|||||||
@@ -8,9 +8,8 @@ import chatmaild.doveauth
|
|||||||
import pytest
|
import pytest
|
||||||
from chatmaild.database import DBError
|
from chatmaild.database import DBError
|
||||||
from chatmaild.doveauth import (
|
from chatmaild.doveauth import (
|
||||||
|
AuthDictProxy,
|
||||||
get_user_data,
|
get_user_data,
|
||||||
handle_dovecot_protocol,
|
|
||||||
handle_dovecot_request,
|
|
||||||
is_allowed_to_create,
|
is_allowed_to_create,
|
||||||
iter_userdb,
|
iter_userdb,
|
||||||
iter_userdb_lastlogin_before,
|
iter_userdb_lastlogin_before,
|
||||||
@@ -105,12 +104,14 @@ def test_too_high_db_version(db):
|
|||||||
|
|
||||||
|
|
||||||
def test_handle_dovecot_request(db, example_config):
|
def test_handle_dovecot_request(db, example_config):
|
||||||
|
dictproxy = AuthDictProxy(db=db, config=example_config)
|
||||||
|
|
||||||
# Test that password can contain ", ', \ and /
|
# Test that password can contain ", ', \ and /
|
||||||
msg = (
|
msg = (
|
||||||
'Lshared/passdb/laksjdlaksjdlak\\\\sjdlk\\"12j\\\'3l1/k2j3123"'
|
'Lshared/passdb/laksjdlaksjdlak\\\\sjdlk\\"12j\\\'3l1/k2j3123"'
|
||||||
"some42123@chat.example.org\tsome42123@chat.example.org"
|
"some42123@chat.example.org\tsome42123@chat.example.org"
|
||||||
)
|
)
|
||||||
res = handle_dovecot_request(msg, db, example_config)
|
res = dictproxy.handle_dovecot_request(msg)
|
||||||
assert res
|
assert res
|
||||||
assert res[0] == "O" and res.endswith("\n")
|
assert res[0] == "O" and res.endswith("\n")
|
||||||
userdata = json.loads(res[1:].strip())
|
userdata = json.loads(res[1:].strip())
|
||||||
@@ -120,28 +121,31 @@ def test_handle_dovecot_request(db, example_config):
|
|||||||
|
|
||||||
|
|
||||||
def test_handle_dovecot_protocol_hello_is_skipped(db, example_config, caplog):
|
def test_handle_dovecot_protocol_hello_is_skipped(db, example_config, caplog):
|
||||||
|
dictproxy = AuthDictProxy(db=db, config=example_config)
|
||||||
rfile = io.BytesIO(b"H3\t2\t0\t\tauth\n")
|
rfile = io.BytesIO(b"H3\t2\t0\t\tauth\n")
|
||||||
wfile = io.BytesIO()
|
wfile = io.BytesIO()
|
||||||
handle_dovecot_protocol(rfile, wfile, db, example_config)
|
dictproxy.loop_forever(rfile, wfile)
|
||||||
assert wfile.getvalue() == b""
|
assert wfile.getvalue() == b""
|
||||||
assert not caplog.messages
|
assert not caplog.messages
|
||||||
|
|
||||||
|
|
||||||
def test_handle_dovecot_protocol(db, example_config):
|
def test_handle_dovecot_protocol(db, example_config):
|
||||||
|
dictproxy = AuthDictProxy(db=db, config=example_config)
|
||||||
rfile = io.BytesIO(
|
rfile = io.BytesIO(
|
||||||
b"H3\t2\t0\t\tauth\nLshared/userdb/foobar@chat.example.org\tfoobar@chat.example.org\n"
|
b"H3\t2\t0\t\tauth\nLshared/userdb/foobar@chat.example.org\tfoobar@chat.example.org\n"
|
||||||
)
|
)
|
||||||
wfile = io.BytesIO()
|
wfile = io.BytesIO()
|
||||||
handle_dovecot_protocol(rfile, wfile, db, example_config)
|
dictproxy.loop_forever(rfile, wfile)
|
||||||
assert wfile.getvalue() == b"N\n"
|
assert wfile.getvalue() == b"N\n"
|
||||||
|
|
||||||
|
|
||||||
def test_handle_dovecot_protocol_iterate(db, gencreds, example_config):
|
def test_handle_dovecot_protocol_iterate(db, gencreds, example_config):
|
||||||
|
dictproxy = AuthDictProxy(db=db, config=example_config)
|
||||||
lookup_passdb(db, example_config, "asdf00000@chat.example.org", "q9mr3faue")
|
lookup_passdb(db, example_config, "asdf00000@chat.example.org", "q9mr3faue")
|
||||||
lookup_passdb(db, example_config, "asdf11111@chat.example.org", "q9mr3faue")
|
lookup_passdb(db, example_config, "asdf11111@chat.example.org", "q9mr3faue")
|
||||||
rfile = io.BytesIO(b"H3\t2\t0\t\tauth\nI0\t0\tshared/userdb/")
|
rfile = io.BytesIO(b"H3\t2\t0\t\tauth\nI0\t0\tshared/userdb/")
|
||||||
wfile = io.BytesIO()
|
wfile = io.BytesIO()
|
||||||
handle_dovecot_protocol(rfile, wfile, db, example_config)
|
dictproxy.loop_forever(rfile, wfile)
|
||||||
lines = wfile.getvalue().decode("ascii").split("\n")
|
lines = wfile.getvalue().decode("ascii").split("\n")
|
||||||
assert lines[0] == "Oshared/userdb/asdf00000@chat.example.org\t"
|
assert lines[0] == "Oshared/userdb/asdf00000@chat.example.org\t"
|
||||||
assert lines[1] == "Oshared/userdb/asdf11111@chat.example.org\t"
|
assert lines[1] == "Oshared/userdb/asdf11111@chat.example.org\t"
|
||||||
|
|||||||
Reference in New Issue
Block a user