diff --git a/cmdeploy/src/cmdeploy/cmdeploy.py b/cmdeploy/src/cmdeploy/cmdeploy.py index 9b58511f..5a1f8c7f 100644 --- a/cmdeploy/src/cmdeploy/cmdeploy.py +++ b/cmdeploy/src/cmdeploy/cmdeploy.py @@ -55,7 +55,7 @@ def run_cmd(args, out): """Deploy chatmail services on the remote server.""" remote_data = dns.get_initial_remote_data(args, out) - if not remote_data: + if not dns.check_initial_remote_data(remote_data, print=out.red): return 1 env = os.environ.copy() @@ -283,16 +283,14 @@ def main(args=None): if not hasattr(args, "func"): return parser.parse_args(["-h"]) - ssh_exec_cache = [] + ssh_cache = [] def get_sshexec(): - if not ssh_exec_cache: + if not ssh_cache: print(f"[ssh] login to {args.config.mail_domain}") - ssh_exec = SSHExec( - args.config.mail_domain, remote_funcs, verbose=args.verbose - ) - ssh_exec_cache.append(ssh_exec) - return ssh_exec_cache[0] + ssh = SSHExec(args.config.mail_domain, remote_funcs, verbose=args.verbose) + ssh_cache.append(ssh) + return ssh_cache[0] args.get_sshexec = get_sshexec @@ -313,7 +311,6 @@ def main(args=None): if res is None: res = 0 return res - except KeyboardInterrupt: out.red("KeyboardInterrupt") sys.exit(130) diff --git a/cmdeploy/src/cmdeploy/dns.py b/cmdeploy/src/cmdeploy/dns.py index 708e475d..d97d268b 100644 --- a/cmdeploy/src/cmdeploy/dns.py +++ b/cmdeploy/src/cmdeploy/dns.py @@ -9,15 +9,18 @@ from . import remote_funcs def get_initial_remote_data(args, out): sshexec = args.get_sshexec() mail_domain = args.config.mail_domain - remote_data = sshexec.logged( + return sshexec.logged( call=remote_funcs.perform_initial_checks, kwargs=dict(mail_domain=mail_domain) ) + +def check_initial_remote_data(remote_data, print=print): + mail_domain = remote_data["mail_domain"] if not remote_data["A"] and not remote_data["AAAA"]: - out.red("Missing A and/or AAAA DNS records for {mail_domain}!") + print("Missing A and/or AAAA DNS records for {mail_domain}!") elif not remote_data["MTA_STS"]: - out.red("Missing MTA_STS record:") - out(f"{mail_domain}. CNAME {mail_domain}") + print("Missing MTA-STS CNAME record:") + print(f"mta-sts.{mail_domain}. CNAME {mail_domain}") else: return remote_data @@ -62,7 +65,7 @@ def show_dns(args, out, remote_data) -> int: with open(args.zonefile, "w+") as zf: zf.write(zonefile) out.green(f"DNS records successfully written to: {args.zonefile}") - return -1 + return 0 if diff_records: out.red("Please set the following DNS entries at your DNS provider:\n") diff --git a/cmdeploy/src/cmdeploy/remote_funcs.py b/cmdeploy/src/cmdeploy/remote_funcs.py index 6e8b78c6..4a0f3cd6 100644 --- a/cmdeploy/src/cmdeploy/remote_funcs.py +++ b/cmdeploy/src/cmdeploy/remote_funcs.py @@ -11,6 +11,7 @@ All functions of this module """ import re +import traceback from subprocess import CalledProcessError, check_output @@ -31,11 +32,12 @@ def get_systemd_running(): def perform_initial_checks(mail_domain): """Collecting initial DNS zone content.""" + assert mail_domain A = query_dns("A", mail_domain) AAAA = query_dns("AAAA", mail_domain) MTA_STS = query_dns("CNAME", f"mta-sts.{mail_domain}") - res = dict(A=A, AAAA=AAAA, MTA_STS=MTA_STS) + res = dict(mail_domain=mail_domain, A=A, AAAA=AAAA, MTA_STS=MTA_STS) if not MTA_STS or (not A and not AAAA): return res @@ -69,14 +71,14 @@ def query_dns(typ, domain): print(res) if res: return res.split("\n")[0] + return "" def check_zonefile(zonefile): - """Check all expected zone file entries.""" + """Check expected zone file entries.""" diff = [] for zf_line in zonefile.splitlines(): - print("") print(f"dns-checking {zf_line!r}") zf_domain, zf_typ, zf_value = zf_line.split(maxsplit=2) zf_domain = zf_domain.rstrip(".") @@ -89,16 +91,35 @@ def check_zonefile(zonefile): return diff +## Function Execution server + + +def _run_loop(cmd_channel): + while 1: + cmd = cmd_channel.receive() + if cmd is None: + break + + cmd_channel.send(_handle_one_request(cmd)) + + +def _handle_one_request(cmd): + func_name, kwargs = cmd + try: + res = globals()[func_name](**kwargs) + return ("finish", res) + except: + data = traceback.format_exc() + return ("error", data) + + # check if this module is executed remotely # and setup a simple serialized function-execution loop if __name__ == "__channelexec__": + channel = channel # noqa (channel object gets injected) - def print(item): - channel.send(("log", item)) # noqa + # enable simple "print" debugging for anyone changing this module + globals()["print"] = lambda x="": channel.send(("log", x)) - while 1: - func_name, kwargs = channel.receive() # noqa - kwargs = kwargs if kwargs else {} - res = globals()[func_name](**kwargs) # noqa - channel.send(("finish", res)) # noqa + _run_loop(channel) diff --git a/cmdeploy/src/cmdeploy/sshexec.py b/cmdeploy/src/cmdeploy/sshexec.py index 36ef5837..4fdaa9d9 100644 --- a/cmdeploy/src/cmdeploy/sshexec.py +++ b/cmdeploy/src/cmdeploy/sshexec.py @@ -3,8 +3,13 @@ import sys import execnet +class FuncError(Exception): + pass + + class SSHExec: RemoteError = execnet.RemoteError + FuncError = FuncError def __init__(self, host, remote_funcs, verbose=False, python="python3", timeout=60): self.gateway = execnet.makegateway(f"ssh=root@{host}//python={python}") @@ -13,6 +18,8 @@ class SSHExec: self.verbose = verbose def __call__(self, call, kwargs=None, log_callback=None): + if kwargs is None: + kwargs = {} self._remote_cmdloop_channel.send((call.__name__, kwargs)) while 1: code, data = self._remote_cmdloop_channel.receive(timeout=self.timeout) @@ -20,6 +27,8 @@ class SSHExec: log_callback(data) elif code == "finish": return data + elif code == "error": + raise self.FuncError(data) def logged(self, call, kwargs): def log_progress(data): diff --git a/cmdeploy/src/cmdeploy/tests/online/test_1_basic.py b/cmdeploy/src/cmdeploy/tests/online/test_1_basic.py index c94b3afa..44b68c18 100644 --- a/cmdeploy/src/cmdeploy/tests/online/test_1_basic.py +++ b/cmdeploy/src/cmdeploy/tests/online/test_1_basic.py @@ -40,6 +40,18 @@ class TestSSHExecutor: assert len(lines) > 4 assert remote_funcs.perform_initial_checks.__doc__ in lines[0] + def test_exception(self, sshexec, capsys): + try: + sshexec.logged( + remote_funcs.perform_initial_checks, + kwargs=dict(mail_domain=None), + ) + except sshexec.FuncError as e: + assert "remote_funcs.py" in str(e) + assert "AssertionError" in str(e) + else: + pytest.fail("didn't raise exception") + def test_remote(remote, imap_or_smtp): lineproducer = remote.iter_output(imap_or_smtp.logcmd) diff --git a/cmdeploy/src/cmdeploy/tests/test_dns.py b/cmdeploy/src/cmdeploy/tests/test_dns.py new file mode 100644 index 00000000..0740595d --- /dev/null +++ b/cmdeploy/src/cmdeploy/tests/test_dns.py @@ -0,0 +1,50 @@ +import pytest + +from cmdeploy import remote_funcs +from cmdeploy.dns import check_initial_remote_data + + +class TestPerformInitialChecks: + @pytest.fixture + def mockdns(self, monkeypatch): + qdict = { + "A": {"some.domain": "1.1.1.1"}, + "AAAA": {"some.domain": "fde5:cd7a:9e1c:3240:5a99:936f:cdac:53ae"}, + "CNAME": {"mta-sts.some.domain": "some.domain"}, + }.copy() + + def query_dns(typ, domain): + try: + return qdict[typ][domain] + except KeyError: + return "" + + monkeypatch.setattr(remote_funcs, query_dns.__name__, query_dns) + return qdict + + def test_perform_initial_checks_ok1(self, mockdns): + remote_data = remote_funcs.perform_initial_checks("some.domain") + assert len(remote_data) == 7 + + @pytest.mark.parametrize("drop", ["A", "AAAA"]) + def test_perform_initial_checks_with_one_of_A_AAAA(self, mockdns, drop): + del mockdns[drop] + remote_data = remote_funcs.perform_initial_checks("some.domain") + assert len(remote_data) == 7 + assert not remote_data[drop] + + l = [] + res = check_initial_remote_data(remote_data, print=l.append) + assert res + assert not l + + def test_perform_initial_checks_no_mta_sts(self, mockdns): + del mockdns["CNAME"] + remote_data = remote_funcs.perform_initial_checks("some.domain") + assert len(remote_data) == 4 + assert not remote_data["MTA_STS"] + + l = [] + res = check_initial_remote_data(remote_data, print=l.append) + assert not res + assert len(l) == 2