refactor test and filtermail to prepare it for BeforeQueue handling

This commit is contained in:
holger krekel
2023-10-18 21:43:06 +02:00
parent 410bc50a8b
commit bbd2773506
5 changed files with 75 additions and 185 deletions

View File

@@ -1,11 +1,14 @@
#!/usr/bin/env python3
import asyncio
import logging
import time
import sys
from email.parser import BytesParser
from email import policy
from email.utils import parseaddr
from aiosmtpd.lmtp import LMTP
from aiosmtpd.smtp import SMTP
from aiosmtpd.controller import UnixSocketController
from smtplib import SMTP as SMTPClient
@@ -32,12 +35,40 @@ def check_encrypted(message):
return True
class ExampleController(UnixSocketController):
def factory(self):
return LMTP(self.handler, **self.SMTP_kwargs)
class BeforeQueueHandler:
transport_class = SMTP
def __init__(self):
self.send_rate_limiter = SendRateLimiter()
async def handle_MAIL(self, server, session, envelope, address, mail_options):
logging.info(f"handle_MAIL from {address}")
if self.send_rate_limiter.is_sending_allowed(address):
envelope.mail_from = address
return "250 OK"
return "400 per-user ratelimit exceeded"
class ExampleHandler:
class SendRateLimiter:
MAX_USER_SEND_PER_MINUTE = 80
def __init__(self):
self.addr2timestamps = {}
def is_sending_allowed(self, mail_from):
last = self.addr2timestamps.setdefault(mail_from, [])
now = time.time()
last[:] = [ts for ts in last if ts >= (now - 60)]
if len(last) <= self.MAX_USER_SEND_PER_MINUTE:
last.append(now)
return True
return False
class AfterQueueHandler:
transport_class = LMTP
async def handle_RCPT(self, server, session, envelope, address, rcpt_options):
envelope.rcpt_tos.append(address)
return "250 OK"
@@ -55,13 +86,6 @@ class ExampleHandler:
return "\r\n".join(res)
async def asyncmain(loop):
controller = ExampleController(
ExampleHandler(), unix_socket="/var/spool/postfix/private/filtermail"
)
controller.start()
def lmtp_handle_DATA(envelope):
"""the central filtering function for e-mails."""
logging.info(f"Processing DATA message from {envelope.mail_from}")
@@ -113,13 +137,25 @@ def lmtp_handle_DATA(envelope):
return valid_recipients, res
class Controller(UnixSocketController):
def factory(self):
return self.handler.transport_class(self.handler, **self.SMTP_kwargs)
async def asyncmain(loop, handler, unix_socket_fn):
Controller(handler, unix_socket=unix_socket_fn).start()
name2Handler = {"beforequeue": BeforeQueueHandler, "afterqueue": AfterQueueHandler}
def main():
args = sys.argv[1:]
assert len(args) == 2
handler = name2Handler[args[0]]()
unix_socket_fn = args[1]
logging.basicConfig(level=logging.INFO)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.create_task(asyncmain(loop=loop))
loop.create_task(asyncmain(loop, handler, unix_socket_fn))
loop.run_forever()
if __name__ == "__main__":
main()

View File

@@ -2,7 +2,7 @@
Description=Email filter for chatmail servers
[Service]
ExecStart=/usr/local/bin/filtermail
ExecStart=/usr/local/bin/filtermail afterqueue /var/spool/postfix/private/filtermail
Restart=always
RestartSec=30

View File

@@ -1,6 +1,7 @@
from .filtermail import check_encrypted, lmtp_handle_DATA
from .filtermail import check_encrypted, lmtp_handle_DATA, SendRateLimiter
from email.parser import BytesParser
from email import policy
import pytest
def test_reject_forged_from():
@@ -326,3 +327,16 @@ def test_filtermail():
]
).encode()
)
def test_send_rate_limiter():
limiter = SendRateLimiter()
for i in range(100):
if limiter.is_sending_allowed("some@example.org"):
if i <= SendRateLimiter.MAX_USER_SEND_PER_MINUTE:
continue
pytest.fail("limiter didn't work")
else:
assert i == SendRateLimiter.MAX_USER_SEND_PER_MINUTE + 1
break