diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ea09593..b19f2ec1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ ## untagged +- cmdeploy: make --ssh-host work with localhost + ([#659](https://github.com/chatmail/relay/pull/659)) + - Update iroh-relay to 0.35.0 ([#650](https://github.com/chatmail/relay/pull/650)) diff --git a/cmdeploy/src/cmdeploy/cmdeploy.py b/cmdeploy/src/cmdeploy/cmdeploy.py index 145c4bf6..8ae2481b 100644 --- a/cmdeploy/src/cmdeploy/cmdeploy.py +++ b/cmdeploy/src/cmdeploy/cmdeploy.py @@ -61,14 +61,15 @@ def run_cmd_options(parser): parser.add_argument( "--ssh-host", dest="ssh_host", - help="specify an SSH host to deploy to; uses mail_domain from chatmail.ini by default", + help="Deploy to 'localhost' or to a specific SSH host", ) def run_cmd(args, out): """Deploy chatmail services on the remote server.""" - sshexec = args.get_sshexec() + ssh_host = args.ssh_host if args.ssh_host else args.config.mail_domain + sshexec = get_sshexec(ssh_host) require_iroh = args.config.enable_iroh_relay remote_data = dns.get_initial_remote_data(sshexec, args.config.mail_domain) if not dns.check_initial_remote_data(remote_data, print=out.red): @@ -80,8 +81,11 @@ def run_cmd(args, out): env["CHATMAIL_REQUIRE_IROH"] = "True" if require_iroh else "" deploy_path = importlib.resources.files(__package__).joinpath("deploy.py").resolve() pyinf = "pyinfra --dry" if args.dry_run else "pyinfra" - ssh_host = args.config.mail_domain if not args.ssh_host else args.ssh_host + cmd = f"{pyinf} --ssh-user root {ssh_host} {deploy_path} -y" + if ssh_host == "localhost": + cmd = f"{pyinf} @local {deploy_path} -y" + if version.parse(pyinfra.__version__) < version.parse("3"): out.red("Please re-run scripts/initenv.sh to update pyinfra to version 3.") return 1 @@ -118,11 +122,17 @@ def dns_cmd_options(parser): default=None, help="write out a zonefile", ) + parser.add_argument( + "--ssh-host", + dest="ssh_host", + help="Run the DNS queries on 'localhost' or on a specific SSH host", + ) def dns_cmd(args, out): """Check DNS entries and optionally generate dns zone file.""" - sshexec = args.get_sshexec() + ssh_host = args.ssh_host if args.ssh_host else args.config.mail_domain + sshexec = get_sshexec(ssh_host, verbose=args.verbose) remote_data = dns.get_initial_remote_data(sshexec, args.config.mail_domain) if not remote_data: return 1 @@ -331,6 +341,14 @@ def get_parser(): return parser +def get_sshexec(ssh_host: str, verbose=True): + if ssh_host in ["localhost", "@local"]: + return "localhost" + if verbose: + print(f"[ssh] login to {ssh_host}") + return SSHExec(ssh_host, verbose=verbose) + + def main(args=None): """Provide main entry point for 'cmdeploy' CLI invocation.""" parser = get_parser() @@ -338,12 +356,6 @@ def main(args=None): if not hasattr(args, "func"): return parser.parse_args(["-h"]) - def get_sshexec(): - print(f"[ssh] login to {args.config.mail_domain}") - return SSHExec(args.config.mail_domain, verbose=args.verbose) - - args.get_sshexec = get_sshexec - out = Out() kwargs = {} if args.func.__name__ not in ("init_cmd", "fmt_cmd"): diff --git a/cmdeploy/src/cmdeploy/dns.py b/cmdeploy/src/cmdeploy/dns.py index 17456fd7..6277d158 100644 --- a/cmdeploy/src/cmdeploy/dns.py +++ b/cmdeploy/src/cmdeploy/dns.py @@ -7,9 +7,13 @@ from . import remote def get_initial_remote_data(sshexec, mail_domain): - return sshexec.logged( - call=remote.rdns.perform_initial_checks, kwargs=dict(mail_domain=mail_domain) - ) + if sshexec == "localhost": + result = remote.rdns.perform_initial_checks(mail_domain) + else: + result = sshexec.logged( + call=remote.rdns.perform_initial_checks, kwargs=dict(mail_domain=mail_domain) + ) + return result def check_initial_remote_data(remote_data, *, print=print): @@ -44,10 +48,14 @@ def check_full_zone(sshexec, remote_data, out, zonefile) -> int: """Check existing DNS records, optionally write them to zone file and return (exitcode, remote_data) tuple.""" - required_diff, recommended_diff = sshexec.logged( - remote.rdns.check_zonefile, - kwargs=dict(zonefile=zonefile, mail_domain=remote_data["mail_domain"]), - ) + if sshexec == "localhost": + required_diff, recommended_diff = remote.rdns.check_zonefile( + zonefile=zonefile, verbose=False + ) + else: + required_diff, recommended_diff = sshexec.logged( + remote.rdns.check_zonefile, kwargs=dict(zonefile=zonefile, verbose=False), + ) returncode = 0 if required_diff: diff --git a/cmdeploy/src/cmdeploy/remote/rdns.py b/cmdeploy/src/cmdeploy/remote/rdns.py index fd847efd..7340a777 100644 --- a/cmdeploy/src/cmdeploy/remote/rdns.py +++ b/cmdeploy/src/cmdeploy/remote/rdns.py @@ -12,23 +12,23 @@ All functions of this module import re -from .rshell import CalledProcessError, shell +from .rshell import CalledProcessError, shell, log_progress -def perform_initial_checks(mail_domain): +def perform_initial_checks(mail_domain, pre_command=""): """Collecting initial DNS settings.""" assert mail_domain - if not shell("dig", fail_ok=True): - shell("apt-get update && apt-get install -y dnsutils") + if not shell("dig", fail_ok=True, print=log_progress): + shell("apt-get update && apt-get install -y dnsutils", print=log_progress) A = query_dns("A", mail_domain) AAAA = query_dns("AAAA", mail_domain) MTA_STS = query_dns("CNAME", f"mta-sts.{mail_domain}") WWW = query_dns("CNAME", f"www.{mail_domain}") res = dict(mail_domain=mail_domain, A=A, AAAA=AAAA, MTA_STS=MTA_STS, WWW=WWW) - res["acme_account_url"] = shell("acmetool account-url", fail_ok=True) + res["acme_account_url"] = shell(pre_command + "acmetool account-url", fail_ok=True, print=log_progress) res["dkim_entry"], res["web_dkim_entry"] = get_dkim_entry( - mail_domain, dkim_selector="opendkim" + mail_domain, pre_command, dkim_selector="opendkim" ) if not MTA_STS or not WWW or (not A and not AAAA): @@ -40,11 +40,12 @@ def perform_initial_checks(mail_domain): return res -def get_dkim_entry(mail_domain, dkim_selector): +def get_dkim_entry(mail_domain, pre_command, dkim_selector): try: dkim_pubkey = shell( - f"openssl rsa -in /etc/dkimkeys/{dkim_selector}.private " - "-pubout 2>/dev/null | awk '/-/{next}{printf(\"%s\",$0)}'" + f"{pre_command}openssl rsa -in /etc/dkimkeys/{dkim_selector}.private " + "-pubout 2>/dev/null | awk '/-/{next}{printf(\"%s\",$0)}'", + print=log_progress ) except CalledProcessError: return @@ -61,7 +62,7 @@ def query_dns(typ, domain): # Get autoritative nameserver from the SOA record. soa_answers = [ x.split() - for x in shell(f"dig -r -q {domain} -t SOA +noall +authority +answer").split( + for x in shell(f"dig -r -q {domain} -t SOA +noall +authority +answer", print=log_progress).split( "\n" ) ] @@ -71,13 +72,13 @@ def query_dns(typ, domain): ns = soa[0][4] # Query authoritative nameserver directly to bypass DNS cache. - res = shell(f"dig @{ns} -r -q {domain} -t {typ} +short") + res = shell(f"dig @{ns} -r -q {domain} -t {typ} +short", print=log_progress) if res: return res.split("\n")[0] return "" -def check_zonefile(zonefile, mail_domain): +def check_zonefile(zonefile, verbose=True): """Check expected zone file entries.""" required = True required_diff = [] @@ -89,7 +90,7 @@ def check_zonefile(zonefile, mail_domain): continue if not zf_line.strip() or zf_line.startswith(";"): continue - print(f"dns-checking {zf_line!r}") + print(f"dns-checking {zf_line!r}") if verbose else log_progress("") zf_domain, zf_typ, zf_value = zf_line.split(maxsplit=2) zf_domain = zf_domain.rstrip(".") zf_value = zf_value.strip() diff --git a/cmdeploy/src/cmdeploy/remote/rshell.py b/cmdeploy/src/cmdeploy/remote/rshell.py index 042c5bf2..f8166816 100644 --- a/cmdeploy/src/cmdeploy/remote/rshell.py +++ b/cmdeploy/src/cmdeploy/remote/rshell.py @@ -1,7 +1,14 @@ +import sys + from subprocess import DEVNULL, CalledProcessError, check_output -def shell(command, fail_ok=False): +def log_progress(data): + sys.stderr.write(".") + sys.stderr.flush() + + +def shell(command, fail_ok=False, print=print): print(f"$ {command}") args = dict(shell=True) if fail_ok: diff --git a/cmdeploy/src/cmdeploy/sshexec.py b/cmdeploy/src/cmdeploy/sshexec.py index 8a87e781..400ce50d 100644 --- a/cmdeploy/src/cmdeploy/sshexec.py +++ b/cmdeploy/src/cmdeploy/sshexec.py @@ -42,6 +42,7 @@ def bootstrap_remote(gateway, remote=remote): def print_stderr(item="", end="\n"): print(item, file=sys.stderr, end=end) + sys.stderr.flush() class SSHExec: @@ -70,10 +71,6 @@ class SSHExec: raise self.FuncError(data) def logged(self, call, kwargs): - def log_progress(data): - sys.stderr.write(".") - sys.stderr.flush() - title = call.__doc__ if not title: title = call.__name__ @@ -82,6 +79,6 @@ class SSHExec: return self(call, kwargs, log_callback=print_stderr) else: print_stderr(title, end="") - res = self(call, kwargs, log_callback=log_progress) + res = self(call, kwargs, log_callback=remote.rshell.log_progress) print_stderr() return res diff --git a/cmdeploy/src/cmdeploy/tests/test_dns.py b/cmdeploy/src/cmdeploy/tests/test_dns.py index fd11095f..d6f756b7 100644 --- a/cmdeploy/src/cmdeploy/tests/test_dns.py +++ b/cmdeploy/src/cmdeploy/tests/test_dns.py @@ -89,18 +89,14 @@ class TestZonefileChecks: def test_check_zonefile_all_ok(self, cm_data, mockdns_base): zonefile = cm_data.get("zftest.zone") parse_zonefile_into_dict(zonefile, mockdns_base) - required_diff, recommended_diff = remote.rdns.check_zonefile( - zonefile, "some.domain" - ) + required_diff, recommended_diff = remote.rdns.check_zonefile(zonefile) assert not required_diff and not recommended_diff def test_check_zonefile_recommended_not_set(self, cm_data, mockdns_base): zonefile = cm_data.get("zftest.zone") zonefile_mocked = zonefile.split("; Recommended")[0] parse_zonefile_into_dict(zonefile_mocked, mockdns_base) - required_diff, recommended_diff = remote.rdns.check_zonefile( - zonefile, "some.domain" - ) + required_diff, recommended_diff = remote.rdns.check_zonefile(zonefile) assert not required_diff assert len(recommended_diff) == 8