mirror of
https://github.com/chatmail/relay.git
synced 2026-05-10 16:04:37 +00:00
DNS: use local dig if ssh fails
This commit is contained in:
@@ -211,10 +211,13 @@ class Out:
|
||||
color = "red" if red else ("green" if green else None)
|
||||
print(colored(msg, color), file=file)
|
||||
|
||||
def shell_output(self, arg, no_print=False):
|
||||
def shell_output(self, arg, no_print=False, timeout=10):
|
||||
if not no_print:
|
||||
self(f"[$ {arg}]", file=sys.stderr)
|
||||
return subprocess.check_output(arg, shell=True).decode()
|
||||
output = subprocess.STDOUT
|
||||
else:
|
||||
output = subprocess.DEVNULL
|
||||
return subprocess.check_output(arg, shell=True, timeout=timeout, stderr=output).decode()
|
||||
|
||||
def check_call(self, arg, env=None, quiet=False):
|
||||
if not quiet:
|
||||
|
||||
@@ -11,14 +11,23 @@ class DNS:
|
||||
def __init__(self, out, mail_domain):
|
||||
self.session = requests.Session()
|
||||
self.out = out
|
||||
self.ssh = f"ssh root@{mail_domain}"
|
||||
self.ssh = f"ssh root@{mail_domain} -- "
|
||||
try:
|
||||
self.shell(f"unbound-control flush {mail_domain}")
|
||||
self.shell(f"unbound-control flush {mail_domain}", retry_local=False, warn_reachable=True)
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
|
||||
def shell(self, cmd):
|
||||
return self.out.shell_output(f"{self.ssh} -- {cmd}", no_print=True)
|
||||
def shell(self, cmd, retry_local=False, warn_reachable=False):
|
||||
try:
|
||||
return self.out.shell_output(f"{self.ssh}{cmd}", no_print=True, timeout=3)
|
||||
except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e:
|
||||
str(e)
|
||||
if warn_reachable and ("exit status 255" in str(e) or "timed out" in str(e)):
|
||||
print(f"Warning: can't reach the server with: {self.ssh[:-4]}")
|
||||
if retry_local:
|
||||
return self.out.shell_output(f"{cmd}", no_print=True)
|
||||
if e == subprocess.CalledProcessError:
|
||||
raise
|
||||
|
||||
def get_ipv4(self):
|
||||
cmd = "ip a | grep 'inet ' | grep 'scope global' | grep -oE '[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}' | head -1"
|
||||
@@ -28,26 +37,15 @@ class DNS:
|
||||
cmd = "ip a | grep inet6 | grep 'scope global' | sed -e 's#/64 scope global##' | sed -e 's#inet6##'"
|
||||
return self.shell(cmd).strip()
|
||||
|
||||
def get(self, typ: str, domain: str) -> str:
|
||||
def get(self, typ: str, domain: str) -> str | None:
|
||||
"""Get a DNS entry"""
|
||||
dig_result = self.shell(f"dig {typ} {domain}")
|
||||
dig_result = self.shell(f"dig {typ} {domain}", retry_local=True)
|
||||
line_num = 0
|
||||
for line in dig_result.splitlines():
|
||||
line_num += 1
|
||||
if line.strip() == ";; ANSWER SECTION:":
|
||||
return dig_result.splitlines()[line_num].split("\t")[-1]
|
||||
|
||||
def resolve(self, domain: str) -> str:
|
||||
result = self.get("A", domain)
|
||||
try:
|
||||
assert ipaddress.ip_address(result).version == 4
|
||||
except ValueError:
|
||||
result = self.get("CNAME", domain)
|
||||
return self.resolve(result)
|
||||
except AssertionError:
|
||||
result = self.get("AAAA", domain)
|
||||
return result
|
||||
|
||||
def check_ptr_record(self, ip: str, mail_domain) -> str:
|
||||
"""Check the PTR record for an IPv4 or IPv6 address."""
|
||||
result = self.get("-x", ip)
|
||||
@@ -96,11 +94,13 @@ def show_dns(args, out):
|
||||
)
|
||||
.strip()
|
||||
)
|
||||
if args.zonefile:
|
||||
try:
|
||||
with open(args.zonefile, "w+") as zf:
|
||||
zf.write(zonefile)
|
||||
print(f"DNS records successfully written to: {args.zonefile}")
|
||||
print(f"DNS records successfully written to: {args.zonefile}")
|
||||
return
|
||||
except AttributeError:
|
||||
pass
|
||||
started_dkim_parsing = False
|
||||
for line in zonefile.splitlines():
|
||||
line = line.format(
|
||||
@@ -178,17 +178,19 @@ def show_dns(args, out):
|
||||
def check_necessary_dns(out, mail_domain):
|
||||
"""Check whether $mail_domain and mta-sts.$mail_domain resolve."""
|
||||
dns = DNS(out, mail_domain)
|
||||
try:
|
||||
ipaddress = dns.resolve(mail_domain)
|
||||
mta_ipadress = dns.resolve("mta-sts." + mail_domain)
|
||||
except subprocess.CalledProcessError:
|
||||
ipaddress = None
|
||||
mta_ipadress = None
|
||||
ipv4 = dns.get("A", mail_domain)
|
||||
ipv6 = dns.get("AAAA", mail_domain)
|
||||
mta_entry = dns.get("CNAME", "mta-sts." + mail_domain)
|
||||
mta_ip = dns.get("A", mta_entry)
|
||||
if not mta_ip:
|
||||
mta_ip = dns.get("AAAA", mta_entry)
|
||||
to_print = []
|
||||
if not ipaddress:
|
||||
to_print.append(f"\tA\t{mail_domain}.\t\t<your server's IPv4 address>")
|
||||
elif not mta_ipadress or mta_ipadress != ipaddress:
|
||||
to_print.append(f"\tCNAME\tmta-sts.{mail_domain}.\t{mail_domain}.")
|
||||
if not (ipv4 or ipv6):
|
||||
to_print.append(f"\t{mail_domain}.\t\t\tA<your server's IPv4 address>")
|
||||
if not mta_ip or not (mta_ip == ipv4 or mta_ip == ipv6):
|
||||
#print(mta_entry, mta_ip)
|
||||
#print(ipv4, ipv6)
|
||||
to_print.append(f"\tmta-sts.{mail_domain}.\tCNAME\t{mail_domain}.")
|
||||
if to_print:
|
||||
to_print.insert(
|
||||
0,
|
||||
@@ -198,4 +200,5 @@ def check_necessary_dns(out, mail_domain):
|
||||
print(line)
|
||||
print()
|
||||
else:
|
||||
print("All necessary DNS entries seem to be set.")
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user