cmdeploy: allow to run SSH commands locally

fix #604
related to #629
pulled out of https://github.com/Keonik1/relay/pull/3
This commit is contained in:
missytake
2025-10-08 09:37:46 +02:00
parent 0ed7c360a9
commit fdabed5c67
7 changed files with 66 additions and 42 deletions

View File

@@ -2,6 +2,9 @@
## untagged ## untagged
- cmdeploy: make --ssh-host work with localhost
([#659](https://github.com/chatmail/relay/pull/659))
- Update iroh-relay to 0.35.0 - Update iroh-relay to 0.35.0
([#650](https://github.com/chatmail/relay/pull/650)) ([#650](https://github.com/chatmail/relay/pull/650))

View File

@@ -61,14 +61,15 @@ def run_cmd_options(parser):
parser.add_argument( parser.add_argument(
"--ssh-host", "--ssh-host",
dest="ssh_host", dest="ssh_host",
help="specify an SSH host to deploy to; uses mail_domain from chatmail.ini by default", help="Deploy to 'localhost' or to a specific SSH host",
) )
def run_cmd(args, out): def run_cmd(args, out):
"""Deploy chatmail services on the remote server.""" """Deploy chatmail services on the remote server."""
sshexec = args.get_sshexec() ssh_host = args.ssh_host if args.ssh_host else args.config.mail_domain
sshexec = get_sshexec(ssh_host)
require_iroh = args.config.enable_iroh_relay require_iroh = args.config.enable_iroh_relay
remote_data = dns.get_initial_remote_data(sshexec, args.config.mail_domain) remote_data = dns.get_initial_remote_data(sshexec, args.config.mail_domain)
if not dns.check_initial_remote_data(remote_data, print=out.red): if not dns.check_initial_remote_data(remote_data, print=out.red):
@@ -80,8 +81,11 @@ def run_cmd(args, out):
env["CHATMAIL_REQUIRE_IROH"] = "True" if require_iroh else "" env["CHATMAIL_REQUIRE_IROH"] = "True" if require_iroh else ""
deploy_path = importlib.resources.files(__package__).joinpath("deploy.py").resolve() deploy_path = importlib.resources.files(__package__).joinpath("deploy.py").resolve()
pyinf = "pyinfra --dry" if args.dry_run else "pyinfra" pyinf = "pyinfra --dry" if args.dry_run else "pyinfra"
ssh_host = args.config.mail_domain if not args.ssh_host else args.ssh_host
cmd = f"{pyinf} --ssh-user root {ssh_host} {deploy_path} -y" cmd = f"{pyinf} --ssh-user root {ssh_host} {deploy_path} -y"
if ssh_host == "localhost":
cmd = f"{pyinf} @local {deploy_path} -y"
if version.parse(pyinfra.__version__) < version.parse("3"): if version.parse(pyinfra.__version__) < version.parse("3"):
out.red("Please re-run scripts/initenv.sh to update pyinfra to version 3.") out.red("Please re-run scripts/initenv.sh to update pyinfra to version 3.")
return 1 return 1
@@ -118,11 +122,17 @@ def dns_cmd_options(parser):
default=None, default=None,
help="write out a zonefile", help="write out a zonefile",
) )
parser.add_argument(
"--ssh-host",
dest="ssh_host",
help="Run the DNS queries on 'localhost' or on a specific SSH host",
)
def dns_cmd(args, out): def dns_cmd(args, out):
"""Check DNS entries and optionally generate dns zone file.""" """Check DNS entries and optionally generate dns zone file."""
sshexec = args.get_sshexec() ssh_host = args.ssh_host if args.ssh_host else args.config.mail_domain
sshexec = get_sshexec(ssh_host, verbose=args.verbose)
remote_data = dns.get_initial_remote_data(sshexec, args.config.mail_domain) remote_data = dns.get_initial_remote_data(sshexec, args.config.mail_domain)
if not remote_data: if not remote_data:
return 1 return 1
@@ -331,6 +341,14 @@ def get_parser():
return parser return parser
def get_sshexec(ssh_host: str, verbose=True):
if ssh_host in ["localhost", "@local"]:
return "localhost"
if verbose:
print(f"[ssh] login to {ssh_host}")
return SSHExec(ssh_host, verbose=verbose)
def main(args=None): def main(args=None):
"""Provide main entry point for 'cmdeploy' CLI invocation.""" """Provide main entry point for 'cmdeploy' CLI invocation."""
parser = get_parser() parser = get_parser()
@@ -338,12 +356,6 @@ def main(args=None):
if not hasattr(args, "func"): if not hasattr(args, "func"):
return parser.parse_args(["-h"]) return parser.parse_args(["-h"])
def get_sshexec():
print(f"[ssh] login to {args.config.mail_domain}")
return SSHExec(args.config.mail_domain, verbose=args.verbose)
args.get_sshexec = get_sshexec
out = Out() out = Out()
kwargs = {} kwargs = {}
if args.func.__name__ not in ("init_cmd", "fmt_cmd"): if args.func.__name__ not in ("init_cmd", "fmt_cmd"):

View File

@@ -7,9 +7,13 @@ from . import remote
def get_initial_remote_data(sshexec, mail_domain): def get_initial_remote_data(sshexec, mail_domain):
return sshexec.logged( if sshexec == "localhost":
call=remote.rdns.perform_initial_checks, kwargs=dict(mail_domain=mail_domain) result = remote.rdns.perform_initial_checks(mail_domain)
) else:
result = sshexec.logged(
call=remote.rdns.perform_initial_checks, kwargs=dict(mail_domain=mail_domain)
)
return result
def check_initial_remote_data(remote_data, *, print=print): def check_initial_remote_data(remote_data, *, print=print):
@@ -44,10 +48,14 @@ def check_full_zone(sshexec, remote_data, out, zonefile) -> int:
"""Check existing DNS records, optionally write them to zone file """Check existing DNS records, optionally write them to zone file
and return (exitcode, remote_data) tuple.""" and return (exitcode, remote_data) tuple."""
required_diff, recommended_diff = sshexec.logged( if sshexec == "localhost":
remote.rdns.check_zonefile, required_diff, recommended_diff = remote.rdns.check_zonefile(
kwargs=dict(zonefile=zonefile, mail_domain=remote_data["mail_domain"]), zonefile=zonefile, verbose=False
) )
else:
required_diff, recommended_diff = sshexec.logged(
remote.rdns.check_zonefile, kwargs=dict(zonefile=zonefile, verbose=False),
)
returncode = 0 returncode = 0
if required_diff: if required_diff:

View File

@@ -12,23 +12,23 @@ All functions of this module
import re import re
from .rshell import CalledProcessError, shell from .rshell import CalledProcessError, shell, log_progress
def perform_initial_checks(mail_domain): def perform_initial_checks(mail_domain, pre_command=""):
"""Collecting initial DNS settings.""" """Collecting initial DNS settings."""
assert mail_domain assert mail_domain
if not shell("dig", fail_ok=True): if not shell("dig", fail_ok=True, print=log_progress):
shell("apt-get update && apt-get install -y dnsutils") shell("apt-get update && apt-get install -y dnsutils", print=log_progress)
A = query_dns("A", mail_domain) A = query_dns("A", mail_domain)
AAAA = query_dns("AAAA", mail_domain) AAAA = query_dns("AAAA", mail_domain)
MTA_STS = query_dns("CNAME", f"mta-sts.{mail_domain}") MTA_STS = query_dns("CNAME", f"mta-sts.{mail_domain}")
WWW = query_dns("CNAME", f"www.{mail_domain}") WWW = query_dns("CNAME", f"www.{mail_domain}")
res = dict(mail_domain=mail_domain, A=A, AAAA=AAAA, MTA_STS=MTA_STS, WWW=WWW) res = dict(mail_domain=mail_domain, A=A, AAAA=AAAA, MTA_STS=MTA_STS, WWW=WWW)
res["acme_account_url"] = shell("acmetool account-url", fail_ok=True) res["acme_account_url"] = shell(pre_command + "acmetool account-url", fail_ok=True, print=log_progress)
res["dkim_entry"], res["web_dkim_entry"] = get_dkim_entry( res["dkim_entry"], res["web_dkim_entry"] = get_dkim_entry(
mail_domain, dkim_selector="opendkim" mail_domain, pre_command, dkim_selector="opendkim"
) )
if not MTA_STS or not WWW or (not A and not AAAA): if not MTA_STS or not WWW or (not A and not AAAA):
@@ -40,11 +40,12 @@ def perform_initial_checks(mail_domain):
return res return res
def get_dkim_entry(mail_domain, dkim_selector): def get_dkim_entry(mail_domain, pre_command, dkim_selector):
try: try:
dkim_pubkey = shell( dkim_pubkey = shell(
f"openssl rsa -in /etc/dkimkeys/{dkim_selector}.private " f"{pre_command}openssl rsa -in /etc/dkimkeys/{dkim_selector}.private "
"-pubout 2>/dev/null | awk '/-/{next}{printf(\"%s\",$0)}'" "-pubout 2>/dev/null | awk '/-/{next}{printf(\"%s\",$0)}'",
print=log_progress
) )
except CalledProcessError: except CalledProcessError:
return return
@@ -61,7 +62,7 @@ def query_dns(typ, domain):
# Get autoritative nameserver from the SOA record. # Get autoritative nameserver from the SOA record.
soa_answers = [ soa_answers = [
x.split() x.split()
for x in shell(f"dig -r -q {domain} -t SOA +noall +authority +answer").split( for x in shell(f"dig -r -q {domain} -t SOA +noall +authority +answer", print=log_progress).split(
"\n" "\n"
) )
] ]
@@ -71,13 +72,13 @@ def query_dns(typ, domain):
ns = soa[0][4] ns = soa[0][4]
# Query authoritative nameserver directly to bypass DNS cache. # Query authoritative nameserver directly to bypass DNS cache.
res = shell(f"dig @{ns} -r -q {domain} -t {typ} +short") res = shell(f"dig @{ns} -r -q {domain} -t {typ} +short", print=log_progress)
if res: if res:
return res.split("\n")[0] return res.split("\n")[0]
return "" return ""
def check_zonefile(zonefile, mail_domain): def check_zonefile(zonefile, verbose=True):
"""Check expected zone file entries.""" """Check expected zone file entries."""
required = True required = True
required_diff = [] required_diff = []
@@ -89,7 +90,7 @@ def check_zonefile(zonefile, mail_domain):
continue continue
if not zf_line.strip() or zf_line.startswith(";"): if not zf_line.strip() or zf_line.startswith(";"):
continue continue
print(f"dns-checking {zf_line!r}") print(f"dns-checking {zf_line!r}") if verbose else log_progress("")
zf_domain, zf_typ, zf_value = zf_line.split(maxsplit=2) zf_domain, zf_typ, zf_value = zf_line.split(maxsplit=2)
zf_domain = zf_domain.rstrip(".") zf_domain = zf_domain.rstrip(".")
zf_value = zf_value.strip() zf_value = zf_value.strip()

View File

@@ -1,7 +1,14 @@
import sys
from subprocess import DEVNULL, CalledProcessError, check_output from subprocess import DEVNULL, CalledProcessError, check_output
def shell(command, fail_ok=False): def log_progress(data):
sys.stderr.write(".")
sys.stderr.flush()
def shell(command, fail_ok=False, print=print):
print(f"$ {command}") print(f"$ {command}")
args = dict(shell=True) args = dict(shell=True)
if fail_ok: if fail_ok:

View File

@@ -42,6 +42,7 @@ def bootstrap_remote(gateway, remote=remote):
def print_stderr(item="", end="\n"): def print_stderr(item="", end="\n"):
print(item, file=sys.stderr, end=end) print(item, file=sys.stderr, end=end)
sys.stderr.flush()
class SSHExec: class SSHExec:
@@ -70,10 +71,6 @@ class SSHExec:
raise self.FuncError(data) raise self.FuncError(data)
def logged(self, call, kwargs): def logged(self, call, kwargs):
def log_progress(data):
sys.stderr.write(".")
sys.stderr.flush()
title = call.__doc__ title = call.__doc__
if not title: if not title:
title = call.__name__ title = call.__name__
@@ -82,6 +79,6 @@ class SSHExec:
return self(call, kwargs, log_callback=print_stderr) return self(call, kwargs, log_callback=print_stderr)
else: else:
print_stderr(title, end="") print_stderr(title, end="")
res = self(call, kwargs, log_callback=log_progress) res = self(call, kwargs, log_callback=remote.rshell.log_progress)
print_stderr() print_stderr()
return res return res

View File

@@ -89,18 +89,14 @@ class TestZonefileChecks:
def test_check_zonefile_all_ok(self, cm_data, mockdns_base): def test_check_zonefile_all_ok(self, cm_data, mockdns_base):
zonefile = cm_data.get("zftest.zone") zonefile = cm_data.get("zftest.zone")
parse_zonefile_into_dict(zonefile, mockdns_base) parse_zonefile_into_dict(zonefile, mockdns_base)
required_diff, recommended_diff = remote.rdns.check_zonefile( required_diff, recommended_diff = remote.rdns.check_zonefile(zonefile)
zonefile, "some.domain"
)
assert not required_diff and not recommended_diff assert not required_diff and not recommended_diff
def test_check_zonefile_recommended_not_set(self, cm_data, mockdns_base): def test_check_zonefile_recommended_not_set(self, cm_data, mockdns_base):
zonefile = cm_data.get("zftest.zone") zonefile = cm_data.get("zftest.zone")
zonefile_mocked = zonefile.split("; Recommended")[0] zonefile_mocked = zonefile.split("; Recommended")[0]
parse_zonefile_into_dict(zonefile_mocked, mockdns_base) parse_zonefile_into_dict(zonefile_mocked, mockdns_base)
required_diff, recommended_diff = remote.rdns.check_zonefile( required_diff, recommended_diff = remote.rdns.check_zonefile(zonefile)
zonefile, "some.domain"
)
assert not required_diff assert not required_diff
assert len(recommended_diff) == 8 assert len(recommended_diff) == 8