DNS: use local dig if ssh fails

This commit is contained in:
missytake
2023-12-19 15:30:20 +01:00
parent 5ff98a571c
commit 36a4381484
2 changed files with 37 additions and 31 deletions

View File

@@ -211,10 +211,13 @@ class Out:
color = "red" if red else ("green" if green else None)
print(colored(msg, color), file=file)
def shell_output(self, arg, no_print=False):
def shell_output(self, arg, no_print=False, timeout=10):
if not no_print:
self(f"[$ {arg}]", file=sys.stderr)
return subprocess.check_output(arg, shell=True).decode()
output = subprocess.STDOUT
else:
output = subprocess.DEVNULL
return subprocess.check_output(arg, shell=True, timeout=timeout, stderr=output).decode()
def check_call(self, arg, env=None, quiet=False):
if not quiet:

View File

@@ -11,14 +11,23 @@ class DNS:
def __init__(self, out, mail_domain):
self.session = requests.Session()
self.out = out
self.ssh = f"ssh root@{mail_domain}"
self.ssh = f"ssh root@{mail_domain} -- "
try:
self.shell(f"unbound-control flush {mail_domain}")
self.shell(f"unbound-control flush {mail_domain}", retry_local=False, warn_reachable=True)
except subprocess.CalledProcessError:
pass
def shell(self, cmd):
return self.out.shell_output(f"{self.ssh} -- {cmd}", no_print=True)
def shell(self, cmd, retry_local=False, warn_reachable=False):
try:
return self.out.shell_output(f"{self.ssh}{cmd}", no_print=True, timeout=3)
except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e:
str(e)
if warn_reachable and ("exit status 255" in str(e) or "timed out" in str(e)):
print(f"Warning: can't reach the server with: {self.ssh[:-4]}")
if retry_local:
return self.out.shell_output(f"{cmd}", no_print=True)
if e == subprocess.CalledProcessError:
raise
def get_ipv4(self):
cmd = "ip a | grep 'inet ' | grep 'scope global' | grep -oE '[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}' | head -1"
@@ -28,26 +37,15 @@ class DNS:
cmd = "ip a | grep inet6 | grep 'scope global' | sed -e 's#/64 scope global##' | sed -e 's#inet6##'"
return self.shell(cmd).strip()
def get(self, typ: str, domain: str) -> str:
def get(self, typ: str, domain: str) -> str | None:
"""Get a DNS entry"""
dig_result = self.shell(f"dig {typ} {domain}")
dig_result = self.shell(f"dig {typ} {domain}", retry_local=True)
line_num = 0
for line in dig_result.splitlines():
line_num += 1
if line.strip() == ";; ANSWER SECTION:":
return dig_result.splitlines()[line_num].split("\t")[-1]
def resolve(self, domain: str) -> str:
result = self.get("A", domain)
try:
assert ipaddress.ip_address(result).version == 4
except ValueError:
result = self.get("CNAME", domain)
return self.resolve(result)
except AssertionError:
result = self.get("AAAA", domain)
return result
def check_ptr_record(self, ip: str, mail_domain) -> str:
"""Check the PTR record for an IPv4 or IPv6 address."""
result = self.get("-x", ip)
@@ -96,11 +94,13 @@ def show_dns(args, out):
)
.strip()
)
if args.zonefile:
try:
with open(args.zonefile, "w+") as zf:
zf.write(zonefile)
print(f"DNS records successfully written to: {args.zonefile}")
print(f"DNS records successfully written to: {args.zonefile}")
return
except AttributeError:
pass
started_dkim_parsing = False
for line in zonefile.splitlines():
line = line.format(
@@ -178,17 +178,19 @@ def show_dns(args, out):
def check_necessary_dns(out, mail_domain):
"""Check whether $mail_domain and mta-sts.$mail_domain resolve."""
dns = DNS(out, mail_domain)
try:
ipaddress = dns.resolve(mail_domain)
mta_ipadress = dns.resolve("mta-sts." + mail_domain)
except subprocess.CalledProcessError:
ipaddress = None
mta_ipadress = None
ipv4 = dns.get("A", mail_domain)
ipv6 = dns.get("AAAA", mail_domain)
mta_entry = dns.get("CNAME", "mta-sts." + mail_domain)
mta_ip = dns.get("A", mta_entry)
if not mta_ip:
mta_ip = dns.get("AAAA", mta_entry)
to_print = []
if not ipaddress:
to_print.append(f"\tA\t{mail_domain}.\t\t<your server's IPv4 address>")
elif not mta_ipadress or mta_ipadress != ipaddress:
to_print.append(f"\tCNAME\tmta-sts.{mail_domain}.\t{mail_domain}.")
if not (ipv4 or ipv6):
to_print.append(f"\t{mail_domain}.\t\t\tA<your server's IPv4 address>")
if not mta_ip or not (mta_ip == ipv4 or mta_ip == ipv6):
#print(mta_entry, mta_ip)
#print(ipv4, ipv6)
to_print.append(f"\tmta-sts.{mail_domain}.\tCNAME\t{mail_domain}.")
if to_print:
to_print.insert(
0,
@@ -198,4 +200,5 @@ def check_necessary_dns(out, mail_domain):
print(line)
print()
else:
print("All necessary DNS entries seem to be set.")
return True