add DNS tests, make remote ssh-exec errors show locally, cleanup ssh-bootstrap

This commit is contained in:
holger krekel
2024-07-11 12:23:24 +02:00
parent 1cb64b4777
commit 4f175eec94
6 changed files with 116 additions and 24 deletions

View File

@@ -55,7 +55,7 @@ def run_cmd(args, out):
"""Deploy chatmail services on the remote server."""
remote_data = dns.get_initial_remote_data(args, out)
if not remote_data:
if not dns.check_initial_remote_data(remote_data, print=out.red):
return 1
env = os.environ.copy()
@@ -283,16 +283,14 @@ def main(args=None):
if not hasattr(args, "func"):
return parser.parse_args(["-h"])
ssh_exec_cache = []
ssh_cache = []
def get_sshexec():
if not ssh_exec_cache:
if not ssh_cache:
print(f"[ssh] login to {args.config.mail_domain}")
ssh_exec = SSHExec(
args.config.mail_domain, remote_funcs, verbose=args.verbose
)
ssh_exec_cache.append(ssh_exec)
return ssh_exec_cache[0]
ssh = SSHExec(args.config.mail_domain, remote_funcs, verbose=args.verbose)
ssh_cache.append(ssh)
return ssh_cache[0]
args.get_sshexec = get_sshexec
@@ -313,7 +311,6 @@ def main(args=None):
if res is None:
res = 0
return res
except KeyboardInterrupt:
out.red("KeyboardInterrupt")
sys.exit(130)

View File

@@ -9,15 +9,18 @@ from . import remote_funcs
def get_initial_remote_data(args, out):
sshexec = args.get_sshexec()
mail_domain = args.config.mail_domain
remote_data = sshexec.logged(
return sshexec.logged(
call=remote_funcs.perform_initial_checks, kwargs=dict(mail_domain=mail_domain)
)
def check_initial_remote_data(remote_data, print=print):
mail_domain = remote_data["mail_domain"]
if not remote_data["A"] and not remote_data["AAAA"]:
out.red("Missing A and/or AAAA DNS records for {mail_domain}!")
print("Missing A and/or AAAA DNS records for {mail_domain}!")
elif not remote_data["MTA_STS"]:
out.red("Missing MTA_STS record:")
out(f"{mail_domain}. CNAME {mail_domain}")
print("Missing MTA-STS CNAME record:")
print(f"mta-sts.{mail_domain}. CNAME {mail_domain}")
else:
return remote_data
@@ -62,7 +65,7 @@ def show_dns(args, out, remote_data) -> int:
with open(args.zonefile, "w+") as zf:
zf.write(zonefile)
out.green(f"DNS records successfully written to: {args.zonefile}")
return -1
return 0
if diff_records:
out.red("Please set the following DNS entries at your DNS provider:\n")

View File

@@ -11,6 +11,7 @@ All functions of this module
"""
import re
import traceback
from subprocess import CalledProcessError, check_output
@@ -31,11 +32,12 @@ def get_systemd_running():
def perform_initial_checks(mail_domain):
"""Collecting initial DNS zone content."""
assert mail_domain
A = query_dns("A", mail_domain)
AAAA = query_dns("AAAA", mail_domain)
MTA_STS = query_dns("CNAME", f"mta-sts.{mail_domain}")
res = dict(A=A, AAAA=AAAA, MTA_STS=MTA_STS)
res = dict(mail_domain=mail_domain, A=A, AAAA=AAAA, MTA_STS=MTA_STS)
if not MTA_STS or (not A and not AAAA):
return res
@@ -69,14 +71,14 @@ def query_dns(typ, domain):
print(res)
if res:
return res.split("\n")[0]
return ""
def check_zonefile(zonefile):
"""Check all expected zone file entries."""
"""Check expected zone file entries."""
diff = []
for zf_line in zonefile.splitlines():
print("")
print(f"dns-checking {zf_line!r}")
zf_domain, zf_typ, zf_value = zf_line.split(maxsplit=2)
zf_domain = zf_domain.rstrip(".")
@@ -89,16 +91,35 @@ def check_zonefile(zonefile):
return diff
## Function Execution server
def _run_loop(cmd_channel):
while 1:
cmd = cmd_channel.receive()
if cmd is None:
break
cmd_channel.send(_handle_one_request(cmd))
def _handle_one_request(cmd):
func_name, kwargs = cmd
try:
res = globals()[func_name](**kwargs)
return ("finish", res)
except:
data = traceback.format_exc()
return ("error", data)
# check if this module is executed remotely
# and setup a simple serialized function-execution loop
if __name__ == "__channelexec__":
channel = channel # noqa (channel object gets injected)
def print(item):
channel.send(("log", item)) # noqa
# enable simple "print" debugging for anyone changing this module
globals()["print"] = lambda x="": channel.send(("log", x))
while 1:
func_name, kwargs = channel.receive() # noqa
kwargs = kwargs if kwargs else {}
res = globals()[func_name](**kwargs) # noqa
channel.send(("finish", res)) # noqa
_run_loop(channel)

View File

@@ -3,8 +3,13 @@ import sys
import execnet
class FuncError(Exception):
pass
class SSHExec:
RemoteError = execnet.RemoteError
FuncError = FuncError
def __init__(self, host, remote_funcs, verbose=False, python="python3", timeout=60):
self.gateway = execnet.makegateway(f"ssh=root@{host}//python={python}")
@@ -13,6 +18,8 @@ class SSHExec:
self.verbose = verbose
def __call__(self, call, kwargs=None, log_callback=None):
if kwargs is None:
kwargs = {}
self._remote_cmdloop_channel.send((call.__name__, kwargs))
while 1:
code, data = self._remote_cmdloop_channel.receive(timeout=self.timeout)
@@ -20,6 +27,8 @@ class SSHExec:
log_callback(data)
elif code == "finish":
return data
elif code == "error":
raise self.FuncError(data)
def logged(self, call, kwargs):
def log_progress(data):

View File

@@ -40,6 +40,18 @@ class TestSSHExecutor:
assert len(lines) > 4
assert remote_funcs.perform_initial_checks.__doc__ in lines[0]
def test_exception(self, sshexec, capsys):
try:
sshexec.logged(
remote_funcs.perform_initial_checks,
kwargs=dict(mail_domain=None),
)
except sshexec.FuncError as e:
assert "remote_funcs.py" in str(e)
assert "AssertionError" in str(e)
else:
pytest.fail("didn't raise exception")
def test_remote(remote, imap_or_smtp):
lineproducer = remote.iter_output(imap_or_smtp.logcmd)

View File

@@ -0,0 +1,50 @@
import pytest
from cmdeploy import remote_funcs
from cmdeploy.dns import check_initial_remote_data
class TestPerformInitialChecks:
@pytest.fixture
def mockdns(self, monkeypatch):
qdict = {
"A": {"some.domain": "1.1.1.1"},
"AAAA": {"some.domain": "fde5:cd7a:9e1c:3240:5a99:936f:cdac:53ae"},
"CNAME": {"mta-sts.some.domain": "some.domain"},
}.copy()
def query_dns(typ, domain):
try:
return qdict[typ][domain]
except KeyError:
return ""
monkeypatch.setattr(remote_funcs, query_dns.__name__, query_dns)
return qdict
def test_perform_initial_checks_ok1(self, mockdns):
remote_data = remote_funcs.perform_initial_checks("some.domain")
assert len(remote_data) == 7
@pytest.mark.parametrize("drop", ["A", "AAAA"])
def test_perform_initial_checks_with_one_of_A_AAAA(self, mockdns, drop):
del mockdns[drop]
remote_data = remote_funcs.perform_initial_checks("some.domain")
assert len(remote_data) == 7
assert not remote_data[drop]
l = []
res = check_initial_remote_data(remote_data, print=l.append)
assert res
assert not l
def test_perform_initial_checks_no_mta_sts(self, mockdns):
del mockdns["CNAME"]
remote_data = remote_funcs.perform_initial_checks("some.domain")
assert len(remote_data) == 4
assert not remote_data["MTA_STS"]
l = []
res = check_initial_remote_data(remote_data, print=l.append)
assert not res
assert len(l) == 2