Merge doveauth and filtermail into chatmaild

This commit is contained in:
link2xt
2023-10-15 15:45:35 +00:00
parent 262eb36a5c
commit b548a8ddbd
16 changed files with 29 additions and 75 deletions

7
chatmaild/README.md Normal file
View File

@@ -0,0 +1,7 @@
# doveauth
doveauth is a python tool
to create dovecot users on login.
It is called by the
[dovecot lua authentication module](https://doc.dovecot.org/configuration_manual/authentication/lua_based_authentication/)

35
chatmaild/pyproject.toml Normal file
View File

@@ -0,0 +1,35 @@
[build-system]
requires = ["setuptools>=45"]
build-backend = "setuptools.build_meta"
[project]
name = "chatmaild"
version = "0.1"
dependencies = [
"aiosmtpd"
]
[project.scripts]
doveauth = "doveauth.doveauth:main"
doveauth-dictproxy = "doveauth.dictproxy:main"
filtermail = "filtermail.filtermail:main"
[tool.pytest.ini_options]
addopts = "-v -ra --strict-markers"
[tool.tox]
legacy_tox_ini = """
[tox]
isolated_build = true
envlist = lint
[testenv:lint]
skipdist = True
skip_install = True
deps =
ruff
black
commands =
black --quiet --check --diff src/
ruff src/
"""

View File

View File

@@ -0,0 +1,140 @@
import sqlite3
import contextlib
import time
from pathlib import Path
class DBError(Exception):
"""error during an operation on the database."""
class Connection:
def __init__(self, sqlconn, write):
self._sqlconn = sqlconn
self._write = write
def close(self):
self._sqlconn.close()
def commit(self):
self._sqlconn.commit()
def rollback(self):
self._sqlconn.rollback()
def execute(self, query, params=()):
cur = self.cursor()
try:
cur.execute(query, params)
except sqlite3.IntegrityError as e:
raise DBError(e)
return cur
def cursor(self):
return self._sqlconn.cursor()
def create_user(self, addr: str, password: str):
"""Create a row in the users table."""
self.execute("PRAGMA foreign_keys=on")
q = """INSERT INTO users (addr, password, last_login)
VALUES (?, ?, ?)"""
self.execute(q, (addr, password, int(time.time())))
def get_user(self, addr: str) -> {}:
"""Get a row from the users table."""
q = "SELECT addr, password, last_login from users WHERE addr = ?"
row = self._sqlconn.execute(q, (addr,)).fetchone()
result = {}
if row:
result = dict(
user=row[0],
password=row[1],
last_login=row[2],
)
return result
class Database:
def __init__(self, path: str):
self.path = Path(path)
self.ensure_tables()
def _get_connection(
self, write=False, transaction=False, closing=False
) -> Connection:
# we let the database serialize all writers at connection time
# to play it very safe (we don't have massive amounts of writes).
mode = "ro"
if write:
mode = "rw"
if not self.path.exists():
mode = "rwc"
uri = "file:%s?mode=%s" % (self.path, mode)
sqlconn = sqlite3.connect(
uri,
timeout=60,
isolation_level=None if transaction else "DEFERRED",
uri=True,
)
# Enable Write-Ahead Logging to avoid readers blocking writers and vice versa.
if write:
sqlconn.execute("PRAGMA journal_mode=wal")
if transaction:
start_time = time.time()
while 1:
try:
sqlconn.execute("begin immediate")
break
except sqlite3.OperationalError:
# another thread may be writing, give it a chance to finish
time.sleep(0.1)
if time.time() - start_time > 5:
# if it takes this long, something is wrong
raise
conn = Connection(sqlconn, write=write)
if closing:
conn = contextlib.closing(conn)
return conn
@contextlib.contextmanager
def write_transaction(self):
conn = self._get_connection(closing=False, write=True, transaction=True)
try:
yield conn
except Exception:
conn.rollback()
conn.close()
raise
else:
conn.commit()
conn.close()
def read_connection(self, closing=True) -> Connection:
return self._get_connection(closing=closing, write=False)
def get_schema_version(self) -> int:
with self.read_connection() as conn:
dbversion = conn.execute("PRAGMA user_version").fetchone()[0]
return dbversion
CURRENT_DBVERSION = 1
def ensure_tables(self):
with self.write_transaction() as conn:
if self.get_schema_version() > 1:
raise DBError(
"version is %s; downgrading schema is not supported"
% (self.get_schema_version(),)
)
conn.execute(
"""
CREATE TABLE IF NOT EXISTS users (
addr TEXT PRIMARY KEY,
password TEXT,
last_login INTEGER
)
""",
)
conn.execute("PRAGMA user_version=%s" % (self.CURRENT_DBVERSION,))

View File

@@ -0,0 +1,115 @@
import os
import sys
import json
from socketserver import (
UnixStreamServer,
StreamRequestHandler,
ThreadingMixIn,
)
import pwd
import subprocess
from .database import Database
def encrypt_password(password: str):
password = password.encode("ascii")
# https://doc.dovecot.org/configuration_manual/authentication/password_schemes/
process = subprocess.Popen(
["doveadm", "pw", "-s", "BLF-CRYPT"],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
)
stdout_data, _stderr_data = process.communicate(
input=password + b"\n" + password + b"\n"
)
return stdout_data.decode("ascii").strip()
def create_user(db, user, password):
with db.write_transaction() as conn:
conn.create_user(user, password)
return dict(home=f"/home/vmail/{user}", uid="vmail", gid="vmail", password=password)
def get_user_data(db, user):
with db.read_connection() as conn:
result = conn.get_user(user)
if result:
result["uid"] = "vmail"
result["gid"] = "vmail"
return result
def lookup_userdb(db, user):
return get_user_data(db, user)
def lookup_passdb(db, user, password):
userdata = get_user_data(db, user)
if not userdata:
return create_user(db, user, encrypt_password(password))
userdata["password"] = userdata["password"].strip()
return userdata
def handle_dovecot_request(msg, db):
print(f"received msg: {msg!r}", file=sys.stderr)
short_command = msg[0]
if short_command == "L": # LOOKUP
parts = msg[1:].split("\t")
keyname, user = parts[:2]
namespace, type, arg = keyname.split("/", 3)
reply_command = "F"
res = ""
if namespace == "shared":
if type == "userdb":
res = lookup_userdb(db, user)
if res:
reply_command = "O"
else:
reply_command = "N"
elif type == "passdb":
res = lookup_passdb(db, user, password=arg)
if res:
reply_command = "O"
else:
reply_command = "N"
print(f"res: {res!r}", file=sys.stderr)
json_res = json.dumps(res) if res else ""
return f"{reply_command}{json_res}\n"
return None
class ThreadedUnixStreamServer(ThreadingMixIn, UnixStreamServer):
pass
def main():
socket = sys.argv[1]
passwd_entry = pwd.getpwnam(sys.argv[2])
db = Database(sys.argv[3])
class Handler(StreamRequestHandler):
def handle(self):
while True:
msg = self.rfile.readline().strip().decode()
if not msg:
continue
res = handle_dovecot_request(msg, db)
if res:
print(f"sending result: {res!r}", file=sys.stderr)
self.wfile.write(res.encode("ascii"))
self.wfile.flush()
try:
os.unlink(socket)
except FileNotFoundError:
pass
with ThreadedUnixStreamServer(socket, Handler) as server:
os.chown(socket, uid=passwd_entry.pw_uid, gid=passwd_entry.pw_gid)
try:
server.serve_forever()
except KeyboardInterrupt:
pass

View File

@@ -0,0 +1,10 @@
[Unit]
Description=Dict authentication proxy for dovecot
[Service]
ExecStart=/usr/local/bin/doveauth-dictproxy /run/dovecot/doveauth.socket vmail /home/vmail/passdb.sqlite
Restart=always
RestartSec=30
[Install]
WantedBy=multi-user.target

View File

@@ -0,0 +1,65 @@
#!/usr/bin/env python3
import base64
import sys
from .database import Database
def get_user_data(db, user):
with db.read_connection() as conn:
result = conn.get_user(user)
if result:
result["uid"] = "vmail"
result["gid"] = "vmail"
return result
def create_user(db, user, password):
with db.write_transaction() as conn:
conn.create_user(user, password)
return dict(home=f"/home/vmail/{user}", uid="vmail", gid="vmail", password=password)
def verify_user(db, user, password):
userdata = get_user_data(db, user)
if userdata:
if userdata.get("password") == password:
userdata["status"] = "ok"
else:
userdata["status"] = "fail"
else:
userdata = create_user(db, user, password)
userdata["status"] = "ok"
return userdata
def lookup_user(db, user):
userdata = get_user_data(db, user)
if userdata:
userdata["status"] = "ok"
else:
userdata["status"] = "fail"
return userdata
def dump_result(res):
for key, value in res.items():
print(f"{key}={value}")
def main():
db = Database("/home/vmail/passdb.sqlite")
if sys.argv[1] == "hexauth":
login = base64.b16decode(sys.argv[2]).decode()
password = base64.b16decode(sys.argv[3]).decode()
res = verify_user(db, login, password)
dump_result(res)
elif sys.argv[1] == "hexlookup":
login = base64.b16decode(sys.argv[2]).decode()
res = lookup_user(db, login)
dump_result(res)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,36 @@
import pytest
from .dictproxy import get_user_data
from .doveauth import verify_user
from .database import Database, DBError
@pytest.fixture()
def db(tmpdir):
db_path = tmpdir / "passdb.sqlite"
print("database path:", db_path)
return Database(db_path)
def test_basic(db):
verify_user(db, "link2xt@c1.testrun.org", "asdf")
data = get_user_data(db, "link2xt@c1.testrun.org")
assert data
def test_verify_or_create(db):
res = verify_user(db, "newuser1@something.org", "kajdlkajsldk12l3kj1983")
assert res["status"] == "ok"
res = verify_user(db, "newuser1@something.org", "kajdlqweqwe")
assert res["status"] == "fail"
def test_db_version(db):
assert db.get_schema_version() == 1
def test_too_high_db_version(db):
with db.write_transaction() as conn:
conn.execute("PRAGMA user_version=%s;" % (999,))
with pytest.raises(DBError):
db.ensure_tables()

View File

View File

@@ -0,0 +1,85 @@
#!/usr/bin/env python3
import asyncio
import logging
from aiosmtpd.lmtp import LMTP
from aiosmtpd.controller import UnixSocketController
from smtplib import SMTP as SMTPClient
def check_encrypted(envelope):
"""https://xkcd.com/1181/"""
return "-----BEGIN PGP MESSAGE-----" in envelope.content.decode(
"utf8", errors="replace"
)
class ExampleController(UnixSocketController):
def factory(self):
return LMTP(self.handler, **self.SMTP_kwargs)
class ExampleHandler:
async def handle_RCPT(self, server, session, envelope, address, rcpt_options):
envelope.rcpt_tos.append(address)
return "250 OK"
async def handle_DATA(self, server, session, envelope):
logging.info("Processing DATA message from %s", envelope.mail_from)
valid_recipients = []
mail_encrypted = check_encrypted(envelope)
res = []
for recipient in envelope.rcpt_tos:
my_local_domain = envelope.mail_from.split("@")
if len(my_local_domain) != 2:
res += [f"500 Invalid from address <{envelope.mail_from}>"]
continue
if envelope.mail_from == recipient:
# Always allow sending emails to self.
valid_recipients += [recipient]
res += ["250 OK"]
continue
recipient_local_domain = recipient.split("@")
if len(recipient_local_domain) != 2:
res += [f"500 Invalid address <{recipient}>"]
continue
is_outgoing = recipient_local_domain[1] != my_local_domain[1]
if is_outgoing and not mail_encrypted:
res += ["500 Outgoing mail must be encrypted"]
continue
valid_recipients += [recipient]
res += ["250 OK"]
# Reinject the mail back into Postfix.
if valid_recipients:
logging.info("Reinjecting the mail")
client = SMTPClient("localhost", "10026")
client.sendmail(envelope.mail_from, valid_recipients, envelope.content)
return "\r\n".join(res)
async def asyncmain(loop):
controller = ExampleController(
ExampleHandler(), unix_socket="/var/spool/postfix/private/filtermail"
)
controller.start()
def main():
logging.basicConfig(level=logging.INFO)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.create_task(asyncmain(loop=loop))
loop.run_forever()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,10 @@
[Unit]
Description=Email filter for chatmail servers
[Service]
ExecStart=/usr/local/bin/filtermail
Restart=always
RestartSec=30
[Install]
WantedBy=multi-user.target