diff --git a/cmdeploy/src/cmdeploy/cmdeploy.py b/cmdeploy/src/cmdeploy/cmdeploy.py index 6fdc2d4b..9b58511f 100644 --- a/cmdeploy/src/cmdeploy/cmdeploy.py +++ b/cmdeploy/src/cmdeploy/cmdeploy.py @@ -15,8 +15,7 @@ from pathlib import Path from chatmaild.config import read_config, write_initial_config from termcolor import colored -from . import remote_funcs -from .dns import NoIPRecords, show_dns +from . import dns, remote_funcs from .sshexec import SSHExec # @@ -54,7 +53,10 @@ def run_cmd_options(parser): def run_cmd(args, out): """Deploy chatmail services on the remote server.""" - retcode, remote_data = show_dns(args, out) + + remote_data = dns.get_initial_remote_data(args, out) + if not remote_data: + return 1 env = os.environ.copy() env["CHATMAIL_INI"] = args.inipath @@ -62,9 +64,9 @@ def run_cmd(args, out): pyinf = "pyinfra --dry" if args.dry_run else "pyinfra" cmd = f"{pyinf} --ssh-user root {args.config.mail_domain} {deploy_path}" - out.check_call(cmd, env=env) + retcode = out.check_call(cmd, env=env) if retcode == 0: - out.green("Deploy completed, call `cmdeploy test` next.") + out.green("Deploy completed, call `cmdeploy dns` next.") elif not remote_data["acme_account_url"]: out.red("Deploy completed but letsencrypt not configured") out.red("Run 'cmdeploy run' again") @@ -84,11 +86,10 @@ def dns_cmd_options(parser): def dns_cmd(args, out): """Check DNS entries and optionally generate dns zone file.""" - retcode, remote_data = show_dns(args, out) - for name in ["acme_account_url", "dkim_entry"]: - if not remote_data[name]: - # dns run insists on all records present - return 1 + remote_data = dns.get_initial_remote_data(args, out) + if not remote_data: + return 1 + retcode = dns.show_dns(args, out, remote_data) return retcode @@ -282,9 +283,16 @@ def main(args=None): if not hasattr(args, "func"): return parser.parse_args(["-h"]) - def get_sshexec(log=None): - print(f"[ssh] login to {args.config.mail_domain}") - return SSHExec(args.config.mail_domain, remote_funcs, log=log) + ssh_exec_cache = [] + + def get_sshexec(): + if not ssh_exec_cache: + print(f"[ssh] login to {args.config.mail_domain}") + ssh_exec = SSHExec( + args.config.mail_domain, remote_funcs, verbose=args.verbose + ) + ssh_exec_cache.append(ssh_exec) + return ssh_exec_cache[0] args.get_sshexec = get_sshexec @@ -309,9 +317,6 @@ def main(args=None): except KeyboardInterrupt: out.red("KeyboardInterrupt") sys.exit(130) - except NoIPRecords as e: - out.red(str(e)) - sys.exit(1) if __name__ == "__main__": diff --git a/cmdeploy/src/cmdeploy/dns.py b/cmdeploy/src/cmdeploy/dns.py index b51c2705..708e475d 100644 --- a/cmdeploy/src/cmdeploy/dns.py +++ b/cmdeploy/src/cmdeploy/dns.py @@ -1,71 +1,74 @@ import datetime import importlib -import sys from jinja2 import Template from . import remote_funcs -class NoIPRecords(Exception): - """Indicates that no DNS A or AAAA record is present.""" +def get_initial_remote_data(args, out): + sshexec = args.get_sshexec() + mail_domain = args.config.mail_domain + remote_data = sshexec.logged( + call=remote_funcs.perform_initial_checks, kwargs=dict(mail_domain=mail_domain) + ) + + if not remote_data["A"] and not remote_data["AAAA"]: + out.red("Missing A and/or AAAA DNS records for {mail_domain}!") + elif not remote_data["MTA_STS"]: + out.red("Missing MTA_STS record:") + out(f"{mail_domain}. CNAME {mail_domain}") + else: + return remote_data -def show_dns(args, out) -> int: +def show_dns(args, out, remote_data) -> int: """Check existing DNS records, optionally write them to zone file and return (exitcode, remote_data) tuple.""" - template = importlib.resources.files(__package__).joinpath("chatmail.zone.j2") - mail_domain = args.config.mail_domain - def log_progress(data): - sys.stdout.write(".") - sys.stdout.flush() + sshexec = args.get_sshexec() - sshexec = args.get_sshexec(log=print if args.verbose else log_progress) - print("Checking DNS entries ", end="\n" if args.verbose else "") + if not remote_data["acme_account_url"]: + out.red("could not get letsencrypt account url, please run 'cmdeploy run'") + return 1 - remote_data = sshexec(remote_funcs.perform_initial_checks, mail_domain=mail_domain) - - if not remote_data["ipv4"] and not remote_data["ipv6"]: - raise NoIPRecords(f"No A or AAAA DNS records set for {mail_domain}!") + if not remote_data["dkim_entry"]: + out.red("could not determine dkim_entry, please run 'cmdeploy run'") + return 1 sts_id = remote_data.get("sts_id") if not sts_id: sts_id = datetime.datetime.now().strftime("%Y%m%d%H%M") + template = importlib.resources.files(__package__).joinpath("chatmail.zone.j2") content = template.read_text() zonefile = Template(content).render( acme_account_url=remote_data.get("acme_account_url"), - dkim_entry=remote_data.get("dkim_entry"), - ipv4=remote_data["ipv4"], - ipv6=remote_data["ipv6"], + dkim_entry=remote_data["dkim_entry"], + ipv4=remote_data["A"], + ipv6=remote_data["AAAA"], sts_id=sts_id, chatmail_domain=args.config.mail_domain, ) - zonefile = "\n".join([x.strip() for x in zonefile.split("\n") if x.strip()]) + lines = [x.strip() for x in zonefile.split("\n") if x.strip()] + lines.append("") + zonefile = "\n".join(lines) - to_print = sshexec(remote_funcs.check_zonefile, zonefile=zonefile) - if not args.verbose: - print() + diff_records = sshexec.logged( + remote_funcs.check_zonefile, kwargs=dict(zonefile=zonefile) + ) if getattr(args, "zonefile", None): with open(args.zonefile, "w+") as zf: zf.write(zonefile) out.green(f"DNS records successfully written to: {args.zonefile}") - return 0, remote_data + return -1 - if to_print: - to_print.insert( - 0, "You should configure the following entries at your DNS provider:\n" - ) - to_print.append( - "\nIf you already configured the DNS entries, " - "wait a bit until the DNS entries propagate to the Internet." - ) - out.red("\n".join(to_print)) - exit_code = 1 + if diff_records: + out.red("Please set the following DNS entries at your DNS provider:\n") + for line in diff_records: + out(line) + return 1 else: out.green("Great! All your DNS entries are verified and correct.") - exit_code = 0 - - return exit_code, remote_data + return 0 diff --git a/cmdeploy/src/cmdeploy/remote_funcs.py b/cmdeploy/src/cmdeploy/remote_funcs.py index 32e02bec..6e8b78c6 100644 --- a/cmdeploy/src/cmdeploy/remote_funcs.py +++ b/cmdeploy/src/cmdeploy/remote_funcs.py @@ -1,12 +1,13 @@ """ -Functions to be executed on an ssh-connected host. +Pure python functions which execute remotely in a system Python interpreter. -All functions of this module need to work with Python builtin types -and standard library dependencies only. +All functions of this module -When a remote function executes remotely, it runs in a system python interpreter -without any installed dependencies. +- need to get and and return Python builtin data types only, +- can only use standard library dependencies, + +- can freely call each other. """ import re @@ -29,7 +30,14 @@ def get_systemd_running(): def perform_initial_checks(mail_domain): - res = {} + """Collecting initial DNS zone content.""" + A = query_dns("A", mail_domain) + AAAA = query_dns("AAAA", mail_domain) + MTA_STS = query_dns("CNAME", f"mta-sts.{mail_domain}") + + res = dict(A=A, AAAA=AAAA, MTA_STS=MTA_STS) + if not MTA_STS or (not A and not AAAA): + return res res["acme_account_url"] = shell("acmetool account-url", fail_ok=True) if not shell("dig", fail_ok=True): @@ -37,18 +45,9 @@ def perform_initial_checks(mail_domain): shell(f"unbound-control flush_zone {mail_domain}", fail_ok=True) res["dkim_entry"] = get_dkim_entry(mail_domain, dkim_selector="opendkim") - res["ipv4"] = query_dns("A", mail_domain) - res["ipv6"] = query_dns("AAAA", mail_domain) - - # parse out sts-id if exists - val = query_dns("TXT", f"_mta-sts.{mail_domain}") - if val: - # "v=STSv1; id={{ sts_id }}" - parts = val.split("id=") - if len(parts) == 2: - val = parts[1].rstrip('"') - res["sts_id"] = val - + # parse out sts-id if exists, example: "v=STSv1; id=2090123" + parts = query_dns("TXT", f"_mta-sts.{mail_domain}").split("id=") + res["sts_id"] = parts[1].rstrip('"') if len(parts) == 2 else "" return res @@ -73,6 +72,7 @@ def query_dns(typ, domain): def check_zonefile(zonefile): + """Check all expected zone file entries.""" diff = [] for zf_line in zonefile.splitlines(): @@ -99,5 +99,6 @@ if __name__ == "__channelexec__": while 1: func_name, kwargs = channel.receive() # noqa + kwargs = kwargs if kwargs else {} res = globals()[func_name](**kwargs) # noqa channel.send(("finish", res)) # noqa diff --git a/cmdeploy/src/cmdeploy/sshexec.py b/cmdeploy/src/cmdeploy/sshexec.py index 31974b36..36ef5837 100644 --- a/cmdeploy/src/cmdeploy/sshexec.py +++ b/cmdeploy/src/cmdeploy/sshexec.py @@ -1,20 +1,39 @@ +import sys + import execnet class SSHExec: RemoteError = execnet.RemoteError - def __init__(self, host, remote_funcs, log=None, python="python3", timeout=60): + def __init__(self, host, remote_funcs, verbose=False, python="python3", timeout=60): self.gateway = execnet.makegateway(f"ssh=root@{host}//python={python}") self._remote_cmdloop_channel = self.gateway.remote_exec(remote_funcs) - self.log = log self.timeout = timeout + self.verbose = verbose - def __call__(self, func, **kwargs): - self._remote_cmdloop_channel.send((func.__name__, kwargs)) + def __call__(self, call, kwargs=None, log_callback=None): + self._remote_cmdloop_channel.send((call.__name__, kwargs)) while 1: code, data = self._remote_cmdloop_channel.receive(timeout=self.timeout) - if code == "log" and self.log: - self.log(data) + if log_callback is not None and code == "log": + log_callback(data) elif code == "finish": return data + + def logged(self, call, kwargs): + def log_progress(data): + sys.stdout.write(".") + sys.stdout.flush() + + title = call.__doc__ + if not title: + title = call.__name__ + if self.verbose: + print("[ssh] " + title) + return self(call, kwargs, log_callback=print) + else: + print(title, end="") + res = self(call, kwargs, log_callback=log_progress) + print() + return res diff --git a/cmdeploy/src/cmdeploy/tests/online/test_1_basic.py b/cmdeploy/src/cmdeploy/tests/online/test_1_basic.py index eea07081..c94b3afa 100644 --- a/cmdeploy/src/cmdeploy/tests/online/test_1_basic.py +++ b/cmdeploy/src/cmdeploy/tests/online/test_1_basic.py @@ -7,18 +7,38 @@ from cmdeploy.sshexec import SSHExec class TestSSHExecutor: - @pytest.fixture + @pytest.fixture(scope="class") def sshexec(self, sshdomain): return SSHExec(sshdomain, remote_funcs) def test_ls(self, sshexec): - out = sshexec(remote_funcs.shell, command="ls") - out2 = sshexec(remote_funcs.shell, command="ls") + out = sshexec(call=remote_funcs.shell, kwargs=dict(command="ls")) + out2 = sshexec(call=remote_funcs.shell, kwargs=dict(command="ls")) assert out == out2 def test_perform_initial(self, sshexec, maildomain): - res = sshexec(remote_funcs.perform_initial_checks, mail_domain=maildomain) - assert res["ipv4"] or res["ipv6"] + res = sshexec( + remote_funcs.perform_initial_checks, kwargs=dict(mail_domain=maildomain) + ) + assert res["A"] or res["AAAA"] + + def test_logged(self, sshexec, maildomain, capsys): + sshexec.logged( + remote_funcs.perform_initial_checks, kwargs=dict(mail_domain=maildomain) + ) + out, err = capsys.readouterr() + assert out.startswith("Collecting") + assert out.endswith("....\n") + assert out.count("\n") == 1 + + sshexec.verbose = True + sshexec.logged( + remote_funcs.perform_initial_checks, kwargs=dict(mail_domain=maildomain) + ) + out, err = capsys.readouterr() + lines = out.split("\n") + assert len(lines) > 4 + assert remote_funcs.perform_initial_checks.__doc__ in lines[0] def test_remote(remote, imap_or_smtp): diff --git a/cmdeploy/src/cmdeploy/tests/plugin.py b/cmdeploy/src/cmdeploy/tests/plugin.py index c78866f1..dcca725f 100644 --- a/cmdeploy/src/cmdeploy/tests/plugin.py +++ b/cmdeploy/src/cmdeploy/tests/plugin.py @@ -35,7 +35,7 @@ def pytest_runtest_setup(item): pytest.skip("skipping slow test, use --slow to run") -@pytest.fixture +@pytest.fixture(scope="session") def chatmail_config(pytestconfig): current = basedir = Path().resolve() while 1: @@ -49,12 +49,12 @@ def chatmail_config(pytestconfig): pytest.skip(f"no chatmail.ini file found in {basedir} or parent dirs") -@pytest.fixture +@pytest.fixture(scope="session") def maildomain(chatmail_config): return chatmail_config.mail_domain -@pytest.fixture +@pytest.fixture(scope="session") def sshdomain(maildomain): return os.environ.get("CHATMAIL_SSH", maildomain)