more tests and refinements

This commit is contained in:
holger krekel
2024-07-15 12:42:06 +02:00
parent c3caddcec9
commit 79a9d2345b
9 changed files with 128 additions and 73 deletions

View File

@@ -79,3 +79,22 @@ def maildata(request):
return BytesParser(policy=policy.default).parsebytes(text.encode())
return maildata
@pytest.fixture
def mockout():
class MockOut:
captured_red = []
captured_green = []
captured_plain = []
def red(self, msg):
self.captured_red.append(msg)
def green(self, msg):
self.captured_green.append(msg)
def __call__(self, msg):
self.captured_plain.append(msg)
return MockOut()

View File

@@ -1,9 +1,11 @@
#
# Required DNS entries for chatmail servers
{% if ipv4 %}
{{ chatmail_domain }}. A {{ ipv4 }}
#
{% if A %}
{{ chatmail_domain }}. A {{ A }}
{% endif %}
{% if ipv6 %}
{{ chatmail_domain }}. AAAA {{ ipv6 }}
{% if AAAA %}
{{ chatmail_domain }}. AAAA {{ AAAA }}
{% endif %}
{{ chatmail_domain }}. MX 10 {{ chatmail_domain }}.
_mta-sts.{{ chatmail_domain }}. TXT "v=STSv1; id={{ sts_id }}"
@@ -11,7 +13,9 @@ mta-sts.{{ chatmail_domain }}. CNAME {{ chatmail_domain }}.
www.{{ chatmail_domain }}. CNAME {{ chatmail_domain }}.
{{ dkim_entry }}
# Recommended DNS entries
#
# Recommended DNS entries for interoperability and security-hardening
#
{{ chatmail_domain }}. TXT "v=spf1 a:{{ chatmail_domain }} ~all"
_dmarc.{{ chatmail_domain }}. TXT "v=DMARC1;p=reject;adkim=s;aspf=s"
@@ -20,7 +24,6 @@ _dmarc.{{ chatmail_domain }}. TXT "v=DMARC1;p=reject;adkim=s;aspf=s"
{% endif %}
_adsp._domainkey.{{ chatmail_domain }}. TXT "dkim=discardable"
# The following are technically not required for Delta Chat
_submission._tcp.{{ chatmail_domain }}. SRV 0 1 587 {{ chatmail_domain }}.
_submissions._tcp.{{ chatmail_domain }}. SRV 0 1 465 {{ chatmail_domain }}.
_imap._tcp.{{ chatmail_domain }}. SRV 0 1 143 {{ chatmail_domain }}.

View File

@@ -7,6 +7,7 @@ import argparse
import importlib.resources
import importlib.util
import os
import pathlib
import shutil
import subprocess
import sys
@@ -54,7 +55,8 @@ def run_cmd_options(parser):
def run_cmd(args, out):
"""Deploy chatmail services on the remote server."""
remote_data = dns.get_initial_remote_data(args)
sshexec = args.get_sshexec()
remote_data = dns.get_initial_remote_data(sshexec, args.config.mail_domain)
if not dns.check_initial_remote_data(remote_data, print=out.red):
return 1
@@ -80,16 +82,37 @@ def dns_cmd_options(parser):
parser.add_argument(
"--zonefile",
dest="zonefile",
help="print the whole zonefile for deploying directly",
type=pathlib.Path,
default=None,
help="write out a zonefile",
)
def dns_cmd(args, out):
"""Check DNS entries and optionally generate dns zone file."""
remote_data = dns.get_initial_remote_data(args)
sshexec = args.get_sshexec()
remote_data = dns.get_initial_remote_data(sshexec, args.config.mail_domain)
if not remote_data:
return 1
retcode = dns.show_dns(args, out, remote_data)
if not remote_data["acme_account_url"]:
out.red("could not get letsencrypt account url, please run 'cmdeploy run'")
return 1
if not remote_data["dkim_entry"]:
out.red("could not determine dkim_entry, please run 'cmdeploy run'")
return 1
zonefile = dns.get_filled_zone_file(remote_data)
if args.zonefile:
args.zonefile.write_text(zonefile)
out.green(f"DNS records successfully written to: {args.zonefile}")
return 0
retcode = dns.check_full_zone(
sshexec, remote_data=remote_data, zonefile=zonefile, out=out
)
return retcode
@@ -283,14 +306,9 @@ def main(args=None):
if not hasattr(args, "func"):
return parser.parse_args(["-h"])
ssh_cache = []
def get_sshexec():
if not ssh_cache:
print(f"[ssh] login to {args.config.mail_domain}")
ssh = SSHExec(args.config.mail_domain, remote_funcs, verbose=args.verbose)
ssh_cache.append(ssh)
return ssh_cache[0]
print(f"[ssh] login to {args.config.mail_domain}")
return SSHExec(args.config.mail_domain, remote_funcs, verbose=args.verbose)
args.get_sshexec = get_sshexec

View File

@@ -6,9 +6,7 @@ from jinja2 import Template
from . import remote_funcs
def get_initial_remote_data(args):
sshexec = args.get_sshexec()
mail_domain = args.config.mail_domain
def get_initial_remote_data(sshexec, mail_domain):
return sshexec.logged(
call=remote_funcs.perform_initial_checks, kwargs=dict(mail_domain=mail_domain)
)
@@ -25,7 +23,7 @@ def check_initial_remote_data(remote_data, print=print):
return remote_data
def get_filled_zone_file(remote_data, mail_domain):
def get_filled_zone_file(remote_data):
sts_id = remote_data.get("sts_id")
if not sts_id:
sts_id = datetime.datetime.now().strftime("%Y%m%d%H%M")
@@ -33,12 +31,12 @@ def get_filled_zone_file(remote_data, mail_domain):
template = importlib.resources.files(__package__).joinpath("chatmail.zone.j2")
content = template.read_text()
zonefile = Template(content).render(
acme_account_url=remote_data.get("acme_account_url"),
acme_account_url=remote_data["acme_account_url"],
dkim_entry=remote_data["dkim_entry"],
ipv4=remote_data["A"],
ipv6=remote_data["AAAA"],
A=remote_data["A"],
AAAA=remote_data["AAAA"],
sts_id=sts_id,
chatmail_domain=mail_domain,
chatmail_domain=remote_data["mail_domain"],
)
lines = [x.strip() for x in zonefile.split("\n") if x.strip()]
lines.append("")
@@ -46,28 +44,10 @@ def get_filled_zone_file(remote_data, mail_domain):
return zonefile
def show_dns(args, out, remote_data) -> int:
def check_full_zone(sshexec, remote_data, out, zonefile) -> int:
"""Check existing DNS records, optionally write them to zone file
and return (exitcode, remote_data) tuple."""
sshexec = args.get_sshexec()
if not remote_data["acme_account_url"]:
out.red("could not get letsencrypt account url, please run 'cmdeploy run'")
return 1
if not remote_data["dkim_entry"]:
out.red("could not determine dkim_entry, please run 'cmdeploy run'")
return 1
zonefile = get_filled_zone_file(remote_data, args.config.mail_domain)
if getattr(args, "zonefile", None):
with open(args.zonefile, "w+") as zf:
zf.write(zonefile)
out.green(f"DNS records successfully written to: {args.zonefile}")
return 0
required_diff, recommended_diff = sshexec.logged(
remote_funcs.check_zonefile, kwargs=dict(zonefile=zonefile)
)

View File

@@ -31,7 +31,7 @@ def get_systemd_running():
def perform_initial_checks(mail_domain):
"""Collecting initial DNS zone content."""
"""Collecting initial DNS settings."""
assert mail_domain
A = query_dns("A", mail_domain)
AAAA = query_dns("AAAA", mail_domain)

View File

@@ -7,6 +7,10 @@ class FuncError(Exception):
pass
def print_stderr(item="", end="\n"):
print(item, file=sys.stderr, end=end)
class SSHExec:
RemoteError = execnet.RemoteError
FuncError = FuncError
@@ -32,17 +36,17 @@ class SSHExec:
def logged(self, call, kwargs):
def log_progress(data):
sys.stdout.write(".")
sys.stdout.flush()
sys.stderr.write(".")
sys.stderr.flush()
title = call.__doc__
if not title:
title = call.__name__
if self.verbose:
print("[ssh] " + title)
return self(call, kwargs, log_callback=print)
print_stderr("[ssh] " + title)
return self(call, kwargs, log_callback=print_stderr)
else:
print(title, end="")
print_stderr(title, end="")
res = self(call, kwargs, log_callback=log_progress)
print()
print_stderr()
return res

View File

@@ -27,16 +27,16 @@ class TestSSHExecutor:
remote_funcs.perform_initial_checks, kwargs=dict(mail_domain=maildomain)
)
out, err = capsys.readouterr()
assert out.startswith("Collecting")
assert out.endswith("....\n")
assert out.count("\n") == 1
assert err.startswith("Collecting")
assert err.endswith("....\n")
assert err.count("\n") == 1
sshexec.verbose = True
sshexec.logged(
remote_funcs.perform_initial_checks, kwargs=dict(mail_domain=maildomain)
)
out, err = capsys.readouterr()
lines = out.split("\n")
lines = err.split("\n")
assert len(lines) > 4
assert remote_funcs.perform_initial_checks.__doc__ in lines[0]

View File

@@ -80,14 +80,14 @@ def pytest_report_header():
@pytest.fixture
def data(request):
def cm_data(request):
datadir = request.fspath.dirpath("data")
class Data:
class CMData:
def get(self, name):
return datadir.join(name).read()
return Data()
return CMData()
@pytest.fixture

View File

@@ -1,7 +1,7 @@
import pytest
from cmdeploy import remote_funcs
from cmdeploy.dns import check_initial_remote_data
from cmdeploy.dns import check_full_zone, check_initial_remote_data
@pytest.fixture
@@ -59,9 +59,13 @@ class TestPerformInitialChecks:
assert len(l) == 2
def parse_zonefile_into_dict(zonefile, mockdns_base):
def parse_zonefile_into_dict(zonefile, mockdns_base, only_required=False):
for zf_line in zonefile.split("\n"):
if not zf_line.strip() or zf_line.startswith("#"):
if zf_line.startswith("#"):
if "Recommended" in zf_line and only_required:
return
continue
if not zf_line.strip():
continue
zf_domain, zf_typ, zf_value = zf_line.split(maxsplit=2)
zf_domain = zf_domain.rstrip(".")
@@ -69,18 +73,45 @@ def parse_zonefile_into_dict(zonefile, mockdns_base):
mockdns_base.setdefault(zf_typ, {})[zf_domain] = zf_value
def test_check_zonefile_all_ok(data, mockdns_base):
zonefile = data.get("zftest.zone")
parse_zonefile_into_dict(zonefile, mockdns_base)
required_diff, recommended_diff = remote_funcs.check_zonefile(zonefile)
assert not required_diff and not recommended_diff
class MockSSHExec:
def logged(self, func, kwargs):
return func(**kwargs)
def call(self, func, kwargs):
return func(**kwargs)
def test_check_zonefile_recommended_not_set(data, mockdns_base):
zonefile = data.get("zftest.zone")
class TestZonefileChecks:
def test_check_zonefile_all_ok(self, cm_data, mockdns_base):
zonefile = cm_data.get("zftest.zone")
parse_zonefile_into_dict(zonefile, mockdns_base)
required_diff, recommended_diff = remote_funcs.check_zonefile(zonefile)
assert not required_diff and not recommended_diff
zonefile_mocked = zonefile.split("# Recommended")[0]
parse_zonefile_into_dict(zonefile_mocked, mockdns_base)
required_diff, recommended_diff = remote_funcs.check_zonefile(zonefile)
assert not required_diff
assert len(recommended_diff) == 8
def test_check_zonefile_recommended_not_set(self, cm_data, mockdns_base):
zonefile = cm_data.get("zftest.zone")
zonefile_mocked = zonefile.split("# Recommended")[0]
parse_zonefile_into_dict(zonefile_mocked, mockdns_base)
required_diff, recommended_diff = remote_funcs.check_zonefile(zonefile)
assert not required_diff
assert len(recommended_diff) == 8
def test_check_zonefile_output_required_fine(self, cm_data, mockdns_base, mockout):
zonefile = cm_data.get("zftest.zone")
zonefile_mocked = zonefile.split("# Recommended")[0]
parse_zonefile_into_dict(zonefile_mocked, mockdns_base, only_required=True)
mssh = MockSSHExec()
res = check_full_zone(mssh, mockdns_base, out=mockout, zonefile=zonefile)
assert res == 0
assert "WARNING" in mockout.captured_plain[0]
assert len(mockout.captured_plain) == 9
def test_check_zonefile_output_full(self, cm_data, mockdns_base, mockout):
zonefile = cm_data.get("zftest.zone")
parse_zonefile_into_dict(zonefile, mockdns_base)
mssh = MockSSHExec()
res = check_full_zone(mssh, mockdns_base, out=mockout, zonefile=zonefile)
assert res == 0
assert not mockout.captured_red
assert "correct" in mockout.captured_green[0]
assert not mockout.captured_red