mirror of
https://github.com/chatmail/relay.git
synced 2026-05-14 18:04:38 +00:00
refactor test and filtermail to prepare it for BeforeQueue handling
This commit is contained in:
@@ -1,11 +1,14 @@
|
||||
#!/usr/bin/env python3
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import sys
|
||||
from email.parser import BytesParser
|
||||
from email import policy
|
||||
from email.utils import parseaddr
|
||||
|
||||
from aiosmtpd.lmtp import LMTP
|
||||
from aiosmtpd.smtp import SMTP
|
||||
from aiosmtpd.controller import UnixSocketController
|
||||
from smtplib import SMTP as SMTPClient
|
||||
|
||||
@@ -32,12 +35,40 @@ def check_encrypted(message):
|
||||
return True
|
||||
|
||||
|
||||
class ExampleController(UnixSocketController):
|
||||
def factory(self):
|
||||
return LMTP(self.handler, **self.SMTP_kwargs)
|
||||
|
||||
class BeforeQueueHandler:
|
||||
transport_class = SMTP
|
||||
|
||||
def __init__(self):
|
||||
self.send_rate_limiter = SendRateLimiter()
|
||||
|
||||
async def handle_MAIL(self, server, session, envelope, address, mail_options):
|
||||
logging.info(f"handle_MAIL from {address}")
|
||||
if self.send_rate_limiter.is_sending_allowed(address):
|
||||
envelope.mail_from = address
|
||||
return "250 OK"
|
||||
return "400 per-user ratelimit exceeded"
|
||||
|
||||
|
||||
class ExampleHandler:
|
||||
class SendRateLimiter:
|
||||
MAX_USER_SEND_PER_MINUTE = 80
|
||||
|
||||
def __init__(self):
|
||||
self.addr2timestamps = {}
|
||||
|
||||
def is_sending_allowed(self, mail_from):
|
||||
last = self.addr2timestamps.setdefault(mail_from, [])
|
||||
now = time.time()
|
||||
last[:] = [ts for ts in last if ts >= (now - 60)]
|
||||
if len(last) <= self.MAX_USER_SEND_PER_MINUTE:
|
||||
last.append(now)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class AfterQueueHandler:
|
||||
transport_class = LMTP
|
||||
|
||||
async def handle_RCPT(self, server, session, envelope, address, rcpt_options):
|
||||
envelope.rcpt_tos.append(address)
|
||||
return "250 OK"
|
||||
@@ -55,13 +86,6 @@ class ExampleHandler:
|
||||
return "\r\n".join(res)
|
||||
|
||||
|
||||
async def asyncmain(loop):
|
||||
controller = ExampleController(
|
||||
ExampleHandler(), unix_socket="/var/spool/postfix/private/filtermail"
|
||||
)
|
||||
controller.start()
|
||||
|
||||
|
||||
def lmtp_handle_DATA(envelope):
|
||||
"""the central filtering function for e-mails."""
|
||||
logging.info(f"Processing DATA message from {envelope.mail_from}")
|
||||
@@ -113,13 +137,25 @@ def lmtp_handle_DATA(envelope):
|
||||
return valid_recipients, res
|
||||
|
||||
|
||||
class Controller(UnixSocketController):
|
||||
def factory(self):
|
||||
return self.handler.transport_class(self.handler, **self.SMTP_kwargs)
|
||||
|
||||
|
||||
async def asyncmain(loop, handler, unix_socket_fn):
|
||||
Controller(handler, unix_socket=unix_socket_fn).start()
|
||||
|
||||
|
||||
name2Handler = {"beforequeue": BeforeQueueHandler, "afterqueue": AfterQueueHandler}
|
||||
|
||||
|
||||
def main():
|
||||
args = sys.argv[1:]
|
||||
assert len(args) == 2
|
||||
handler = name2Handler[args[0]]()
|
||||
unix_socket_fn = args[1]
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.create_task(asyncmain(loop=loop))
|
||||
loop.create_task(asyncmain(loop, handler, unix_socket_fn))
|
||||
loop.run_forever()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
Description=Email filter for chatmail servers
|
||||
|
||||
[Service]
|
||||
ExecStart=/usr/local/bin/filtermail
|
||||
ExecStart=/usr/local/bin/filtermail afterqueue /var/spool/postfix/private/filtermail
|
||||
Restart=always
|
||||
RestartSec=30
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from .filtermail import check_encrypted, lmtp_handle_DATA
|
||||
from .filtermail import check_encrypted, lmtp_handle_DATA, SendRateLimiter
|
||||
from email.parser import BytesParser
|
||||
from email import policy
|
||||
import pytest
|
||||
|
||||
|
||||
def test_reject_forged_from():
|
||||
@@ -326,3 +327,16 @@ def test_filtermail():
|
||||
]
|
||||
).encode()
|
||||
)
|
||||
|
||||
|
||||
def test_send_rate_limiter():
|
||||
limiter = SendRateLimiter()
|
||||
for i in range(100):
|
||||
if limiter.is_sending_allowed("some@example.org"):
|
||||
if i <= SendRateLimiter.MAX_USER_SEND_PER_MINUTE:
|
||||
continue
|
||||
pytest.fail("limiter didn't work")
|
||||
else:
|
||||
assert i == SendRateLimiter.MAX_USER_SEND_PER_MINUTE + 1
|
||||
break
|
||||
|
||||
|
||||
Reference in New Issue
Block a user