mirror of
https://github.com/chatmail/relay.git
synced 2026-05-12 09:04:36 +00:00
Compare commits
8 Commits
hpk/cidebu
...
hpk/debug3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4f175eec94 | ||
|
|
1cb64b4777 | ||
|
|
f88bc86c54 | ||
|
|
db1054f4bd | ||
|
|
134f498778 | ||
|
|
c4f46dc499 | ||
|
|
c1fd573de2 | ||
|
|
c6b083472f |
@@ -111,7 +111,7 @@ def check_encrypted(message):
|
|||||||
"""
|
"""
|
||||||
if not message.is_multipart():
|
if not message.is_multipart():
|
||||||
return False
|
return False
|
||||||
if message.get("subject") != "...":
|
if message.get("subject") not in {"...", "[...]"}:
|
||||||
return False
|
return False
|
||||||
if message.get_content_type() != "multipart/encrypted":
|
if message.get_content_type() != "multipart/encrypted":
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
From: {from_addr}
|
From: {from_addr}
|
||||||
To: {to_addr}
|
To: {to_addr}
|
||||||
Subject: ...
|
Subject: {subject}
|
||||||
Date: Sun, 15 Oct 2023 16:43:21 +0000
|
Date: Sun, 15 Oct 2023 16:43:21 +0000
|
||||||
Message-ID: <Mr.UVyJWZmkCKM.hGzNc6glBE_@c2.testrun.org>
|
Message-ID: <Mr.UVyJWZmkCKM.hGzNc6glBE_@c2.testrun.org>
|
||||||
In-Reply-To: <Mr.MvmCz-GQbi_.6FGRkhDf05c@c2.testrun.org>
|
In-Reply-To: <Mr.MvmCz-GQbi_.6FGRkhDf05c@c2.testrun.org>
|
||||||
|
|||||||
@@ -71,11 +71,11 @@ def maildata(request):
|
|||||||
|
|
||||||
assert datadir.exists(), datadir
|
assert datadir.exists(), datadir
|
||||||
|
|
||||||
def maildata(name, from_addr, to_addr):
|
def maildata(name, from_addr, to_addr, subject="..."):
|
||||||
# Using `.read_bytes().decode()` instead of `.read_text()` to preserve newlines.
|
# Using `.read_bytes().decode()` instead of `.read_text()` to preserve newlines.
|
||||||
data = datadir.joinpath(name).read_bytes().decode()
|
data = datadir.joinpath(name).read_bytes().decode()
|
||||||
|
|
||||||
text = data.format(from_addr=from_addr, to_addr=to_addr)
|
text = data.format(from_addr=from_addr, to_addr=to_addr, subject=subject)
|
||||||
return BytesParser(policy=policy.default).parsebytes(text.encode())
|
return BytesParser(policy=policy.default).parsebytes(text.encode())
|
||||||
|
|
||||||
return maildata
|
return maildata
|
||||||
|
|||||||
@@ -54,10 +54,16 @@ def test_filtermail_no_encryption_detection(maildata):
|
|||||||
|
|
||||||
|
|
||||||
def test_filtermail_encryption_detection(maildata):
|
def test_filtermail_encryption_detection(maildata):
|
||||||
msg = maildata("encrypted.eml", from_addr="1@example.org", to_addr="2@example.org")
|
for subject in ("...", "[...]"):
|
||||||
assert check_encrypted(msg)
|
msg = maildata(
|
||||||
|
"encrypted.eml",
|
||||||
|
from_addr="1@example.org",
|
||||||
|
to_addr="2@example.org",
|
||||||
|
subject=subject,
|
||||||
|
)
|
||||||
|
assert check_encrypted(msg)
|
||||||
|
|
||||||
# if the subject is not "..." it is not considered ac-encrypted
|
# if the subject is not a known encrypted subject value, it is not considered ac-encrypted
|
||||||
msg.replace_header("Subject", "Click this link")
|
msg.replace_header("Subject", "Click this link")
|
||||||
assert not check_encrypted(msg)
|
assert not check_encrypted(msg)
|
||||||
|
|
||||||
@@ -72,7 +78,7 @@ def test_filtermail_unencrypted_mdn(maildata, gencreds):
|
|||||||
"""Unencrypted MDNs should not pass."""
|
"""Unencrypted MDNs should not pass."""
|
||||||
from_addr = gencreds()[0]
|
from_addr = gencreds()[0]
|
||||||
to_addr = gencreds()[0] + ".other"
|
to_addr = gencreds()[0] + ".other"
|
||||||
msg = maildata("mdn.eml", from_addr, to_addr)
|
msg = maildata("mdn.eml", from_addr=from_addr, to_addr=to_addr)
|
||||||
|
|
||||||
assert not check_encrypted(msg)
|
assert not check_encrypted(msg)
|
||||||
|
|
||||||
@@ -95,7 +101,7 @@ def test_excempt_privacy(maildata, gencreds, handler):
|
|||||||
handler.config.passthrough_recipients = [to_addr]
|
handler.config.passthrough_recipients = [to_addr]
|
||||||
false_to = "privacy@something.org"
|
false_to = "privacy@something.org"
|
||||||
|
|
||||||
msg = maildata("plain.eml", from_addr, to_addr)
|
msg = maildata("plain.eml", from_addr=from_addr, to_addr=to_addr)
|
||||||
|
|
||||||
class env:
|
class env:
|
||||||
mail_from = from_addr
|
mail_from = from_addr
|
||||||
@@ -118,7 +124,7 @@ def test_passthrough_senders(gencreds, handler, maildata):
|
|||||||
to_addr = "recipient@something.org"
|
to_addr = "recipient@something.org"
|
||||||
handler.config.passthrough_senders = [acc1]
|
handler.config.passthrough_senders = [acc1]
|
||||||
|
|
||||||
msg = maildata("plain.eml", acc1, to_addr)
|
msg = maildata("plain.eml", from_addr=acc1, to_addr=to_addr)
|
||||||
|
|
||||||
class env:
|
class env:
|
||||||
mail_from = acc1
|
mail_from = acc1
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ def run_cmd(args, out):
|
|||||||
"""Deploy chatmail services on the remote server."""
|
"""Deploy chatmail services on the remote server."""
|
||||||
|
|
||||||
remote_data = dns.get_initial_remote_data(args, out)
|
remote_data = dns.get_initial_remote_data(args, out)
|
||||||
if not remote_data:
|
if not dns.check_initial_remote_data(remote_data, print=out.red):
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
env = os.environ.copy()
|
env = os.environ.copy()
|
||||||
@@ -283,16 +283,14 @@ def main(args=None):
|
|||||||
if not hasattr(args, "func"):
|
if not hasattr(args, "func"):
|
||||||
return parser.parse_args(["-h"])
|
return parser.parse_args(["-h"])
|
||||||
|
|
||||||
ssh_exec_cache = []
|
ssh_cache = []
|
||||||
|
|
||||||
def get_sshexec():
|
def get_sshexec():
|
||||||
if not ssh_exec_cache:
|
if not ssh_cache:
|
||||||
print(f"[ssh] login to {args.config.mail_domain}")
|
print(f"[ssh] login to {args.config.mail_domain}")
|
||||||
ssh_exec = SSHExec(
|
ssh = SSHExec(args.config.mail_domain, remote_funcs, verbose=args.verbose)
|
||||||
args.config.mail_domain, remote_funcs, verbose=args.verbose
|
ssh_cache.append(ssh)
|
||||||
)
|
return ssh_cache[0]
|
||||||
ssh_exec_cache.append(ssh_exec)
|
|
||||||
return ssh_exec_cache[0]
|
|
||||||
|
|
||||||
args.get_sshexec = get_sshexec
|
args.get_sshexec = get_sshexec
|
||||||
|
|
||||||
@@ -313,7 +311,6 @@ def main(args=None):
|
|||||||
if res is None:
|
if res is None:
|
||||||
res = 0
|
res = 0
|
||||||
return res
|
return res
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
out.red("KeyboardInterrupt")
|
out.red("KeyboardInterrupt")
|
||||||
sys.exit(130)
|
sys.exit(130)
|
||||||
|
|||||||
@@ -9,15 +9,18 @@ from . import remote_funcs
|
|||||||
def get_initial_remote_data(args, out):
|
def get_initial_remote_data(args, out):
|
||||||
sshexec = args.get_sshexec()
|
sshexec = args.get_sshexec()
|
||||||
mail_domain = args.config.mail_domain
|
mail_domain = args.config.mail_domain
|
||||||
remote_data = sshexec.logged(
|
return sshexec.logged(
|
||||||
call=remote_funcs.perform_initial_checks, kwargs=dict(mail_domain=mail_domain)
|
call=remote_funcs.perform_initial_checks, kwargs=dict(mail_domain=mail_domain)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def check_initial_remote_data(remote_data, print=print):
|
||||||
|
mail_domain = remote_data["mail_domain"]
|
||||||
if not remote_data["A"] and not remote_data["AAAA"]:
|
if not remote_data["A"] and not remote_data["AAAA"]:
|
||||||
out.red("Missing A and/or AAAA DNS records for {mail_domain}!")
|
print("Missing A and/or AAAA DNS records for {mail_domain}!")
|
||||||
elif not remote_data["MTA_STS"]:
|
elif not remote_data["MTA_STS"]:
|
||||||
out.red("Missing MTA_STS record:")
|
print("Missing MTA-STS CNAME record:")
|
||||||
out(f"{mail_domain}. CNAME {mail_domain}")
|
print(f"mta-sts.{mail_domain}. CNAME {mail_domain}")
|
||||||
else:
|
else:
|
||||||
return remote_data
|
return remote_data
|
||||||
|
|
||||||
@@ -62,7 +65,7 @@ def show_dns(args, out, remote_data) -> int:
|
|||||||
with open(args.zonefile, "w+") as zf:
|
with open(args.zonefile, "w+") as zf:
|
||||||
zf.write(zonefile)
|
zf.write(zonefile)
|
||||||
out.green(f"DNS records successfully written to: {args.zonefile}")
|
out.green(f"DNS records successfully written to: {args.zonefile}")
|
||||||
return -1
|
return 0
|
||||||
|
|
||||||
if diff_records:
|
if diff_records:
|
||||||
out.red("Please set the following DNS entries at your DNS provider:\n")
|
out.red("Please set the following DNS entries at your DNS provider:\n")
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ All functions of this module
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
import traceback
|
||||||
from subprocess import CalledProcessError, check_output
|
from subprocess import CalledProcessError, check_output
|
||||||
|
|
||||||
|
|
||||||
@@ -31,11 +32,12 @@ def get_systemd_running():
|
|||||||
|
|
||||||
def perform_initial_checks(mail_domain):
|
def perform_initial_checks(mail_domain):
|
||||||
"""Collecting initial DNS zone content."""
|
"""Collecting initial DNS zone content."""
|
||||||
|
assert mail_domain
|
||||||
A = query_dns("A", mail_domain)
|
A = query_dns("A", mail_domain)
|
||||||
AAAA = query_dns("AAAA", mail_domain)
|
AAAA = query_dns("AAAA", mail_domain)
|
||||||
MTA_STS = query_dns("CNAME", f"mta-sts.{mail_domain}")
|
MTA_STS = query_dns("CNAME", f"mta-sts.{mail_domain}")
|
||||||
|
|
||||||
res = dict(A=A, AAAA=AAAA, MTA_STS=MTA_STS)
|
res = dict(mail_domain=mail_domain, A=A, AAAA=AAAA, MTA_STS=MTA_STS)
|
||||||
if not MTA_STS or (not A and not AAAA):
|
if not MTA_STS or (not A and not AAAA):
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@@ -69,14 +71,14 @@ def query_dns(typ, domain):
|
|||||||
print(res)
|
print(res)
|
||||||
if res:
|
if res:
|
||||||
return res.split("\n")[0]
|
return res.split("\n")[0]
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def check_zonefile(zonefile):
|
def check_zonefile(zonefile):
|
||||||
"""Check all expected zone file entries."""
|
"""Check expected zone file entries."""
|
||||||
diff = []
|
diff = []
|
||||||
|
|
||||||
for zf_line in zonefile.splitlines():
|
for zf_line in zonefile.splitlines():
|
||||||
print("")
|
|
||||||
print(f"dns-checking {zf_line!r}")
|
print(f"dns-checking {zf_line!r}")
|
||||||
zf_domain, zf_typ, zf_value = zf_line.split(maxsplit=2)
|
zf_domain, zf_typ, zf_value = zf_line.split(maxsplit=2)
|
||||||
zf_domain = zf_domain.rstrip(".")
|
zf_domain = zf_domain.rstrip(".")
|
||||||
@@ -89,16 +91,35 @@ def check_zonefile(zonefile):
|
|||||||
return diff
|
return diff
|
||||||
|
|
||||||
|
|
||||||
|
## Function Execution server
|
||||||
|
|
||||||
|
|
||||||
|
def _run_loop(cmd_channel):
|
||||||
|
while 1:
|
||||||
|
cmd = cmd_channel.receive()
|
||||||
|
if cmd is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
cmd_channel.send(_handle_one_request(cmd))
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_one_request(cmd):
|
||||||
|
func_name, kwargs = cmd
|
||||||
|
try:
|
||||||
|
res = globals()[func_name](**kwargs)
|
||||||
|
return ("finish", res)
|
||||||
|
except:
|
||||||
|
data = traceback.format_exc()
|
||||||
|
return ("error", data)
|
||||||
|
|
||||||
|
|
||||||
# check if this module is executed remotely
|
# check if this module is executed remotely
|
||||||
# and setup a simple serialized function-execution loop
|
# and setup a simple serialized function-execution loop
|
||||||
|
|
||||||
if __name__ == "__channelexec__":
|
if __name__ == "__channelexec__":
|
||||||
|
channel = channel # noqa (channel object gets injected)
|
||||||
|
|
||||||
def print(item):
|
# enable simple "print" debugging for anyone changing this module
|
||||||
channel.send(("log", item)) # noqa
|
globals()["print"] = lambda x="": channel.send(("log", x))
|
||||||
|
|
||||||
while 1:
|
_run_loop(channel)
|
||||||
func_name, kwargs = channel.receive() # noqa
|
|
||||||
kwargs = kwargs if kwargs else {}
|
|
||||||
res = globals()[func_name](**kwargs) # noqa
|
|
||||||
channel.send(("finish", res)) # noqa
|
|
||||||
|
|||||||
@@ -3,8 +3,13 @@ import sys
|
|||||||
import execnet
|
import execnet
|
||||||
|
|
||||||
|
|
||||||
|
class FuncError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SSHExec:
|
class SSHExec:
|
||||||
RemoteError = execnet.RemoteError
|
RemoteError = execnet.RemoteError
|
||||||
|
FuncError = FuncError
|
||||||
|
|
||||||
def __init__(self, host, remote_funcs, verbose=False, python="python3", timeout=60):
|
def __init__(self, host, remote_funcs, verbose=False, python="python3", timeout=60):
|
||||||
self.gateway = execnet.makegateway(f"ssh=root@{host}//python={python}")
|
self.gateway = execnet.makegateway(f"ssh=root@{host}//python={python}")
|
||||||
@@ -13,6 +18,8 @@ class SSHExec:
|
|||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
|
|
||||||
def __call__(self, call, kwargs=None, log_callback=None):
|
def __call__(self, call, kwargs=None, log_callback=None):
|
||||||
|
if kwargs is None:
|
||||||
|
kwargs = {}
|
||||||
self._remote_cmdloop_channel.send((call.__name__, kwargs))
|
self._remote_cmdloop_channel.send((call.__name__, kwargs))
|
||||||
while 1:
|
while 1:
|
||||||
code, data = self._remote_cmdloop_channel.receive(timeout=self.timeout)
|
code, data = self._remote_cmdloop_channel.receive(timeout=self.timeout)
|
||||||
@@ -20,6 +27,8 @@ class SSHExec:
|
|||||||
log_callback(data)
|
log_callback(data)
|
||||||
elif code == "finish":
|
elif code == "finish":
|
||||||
return data
|
return data
|
||||||
|
elif code == "error":
|
||||||
|
raise self.FuncError(data)
|
||||||
|
|
||||||
def logged(self, call, kwargs):
|
def logged(self, call, kwargs):
|
||||||
def log_progress(data):
|
def log_progress(data):
|
||||||
|
|||||||
@@ -40,6 +40,18 @@ class TestSSHExecutor:
|
|||||||
assert len(lines) > 4
|
assert len(lines) > 4
|
||||||
assert remote_funcs.perform_initial_checks.__doc__ in lines[0]
|
assert remote_funcs.perform_initial_checks.__doc__ in lines[0]
|
||||||
|
|
||||||
|
def test_exception(self, sshexec, capsys):
|
||||||
|
try:
|
||||||
|
sshexec.logged(
|
||||||
|
remote_funcs.perform_initial_checks,
|
||||||
|
kwargs=dict(mail_domain=None),
|
||||||
|
)
|
||||||
|
except sshexec.FuncError as e:
|
||||||
|
assert "remote_funcs.py" in str(e)
|
||||||
|
assert "AssertionError" in str(e)
|
||||||
|
else:
|
||||||
|
pytest.fail("didn't raise exception")
|
||||||
|
|
||||||
|
|
||||||
def test_remote(remote, imap_or_smtp):
|
def test_remote(remote, imap_or_smtp):
|
||||||
lineproducer = remote.iter_output(imap_or_smtp.logcmd)
|
lineproducer = remote.iter_output(imap_or_smtp.logcmd)
|
||||||
|
|||||||
50
cmdeploy/src/cmdeploy/tests/test_dns.py
Normal file
50
cmdeploy/src/cmdeploy/tests/test_dns.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from cmdeploy import remote_funcs
|
||||||
|
from cmdeploy.dns import check_initial_remote_data
|
||||||
|
|
||||||
|
|
||||||
|
class TestPerformInitialChecks:
|
||||||
|
@pytest.fixture
|
||||||
|
def mockdns(self, monkeypatch):
|
||||||
|
qdict = {
|
||||||
|
"A": {"some.domain": "1.1.1.1"},
|
||||||
|
"AAAA": {"some.domain": "fde5:cd7a:9e1c:3240:5a99:936f:cdac:53ae"},
|
||||||
|
"CNAME": {"mta-sts.some.domain": "some.domain"},
|
||||||
|
}.copy()
|
||||||
|
|
||||||
|
def query_dns(typ, domain):
|
||||||
|
try:
|
||||||
|
return qdict[typ][domain]
|
||||||
|
except KeyError:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
monkeypatch.setattr(remote_funcs, query_dns.__name__, query_dns)
|
||||||
|
return qdict
|
||||||
|
|
||||||
|
def test_perform_initial_checks_ok1(self, mockdns):
|
||||||
|
remote_data = remote_funcs.perform_initial_checks("some.domain")
|
||||||
|
assert len(remote_data) == 7
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("drop", ["A", "AAAA"])
|
||||||
|
def test_perform_initial_checks_with_one_of_A_AAAA(self, mockdns, drop):
|
||||||
|
del mockdns[drop]
|
||||||
|
remote_data = remote_funcs.perform_initial_checks("some.domain")
|
||||||
|
assert len(remote_data) == 7
|
||||||
|
assert not remote_data[drop]
|
||||||
|
|
||||||
|
l = []
|
||||||
|
res = check_initial_remote_data(remote_data, print=l.append)
|
||||||
|
assert res
|
||||||
|
assert not l
|
||||||
|
|
||||||
|
def test_perform_initial_checks_no_mta_sts(self, mockdns):
|
||||||
|
del mockdns["CNAME"]
|
||||||
|
remote_data = remote_funcs.perform_initial_checks("some.domain")
|
||||||
|
assert len(remote_data) == 4
|
||||||
|
assert not remote_data["MTA_STS"]
|
||||||
|
|
||||||
|
l = []
|
||||||
|
res = check_initial_remote_data(remote_data, print=l.append)
|
||||||
|
assert not res
|
||||||
|
assert len(l) == 2
|
||||||
Reference in New Issue
Block a user