DNS: fix CNAME resolving, don't print ssh commands for DNS requests

This commit is contained in:
missytake
2023-12-14 18:12:56 +01:00
parent 146def2f06
commit 03aab4043c
2 changed files with 21 additions and 11 deletions

View File

@@ -215,8 +215,9 @@ class Out:
color = "red" if red else ("green" if green else None)
print(colored(msg, color), file=file)
def shell_output(self, arg):
self(f"[$ {arg}]", file=sys.stderr)
def shell_output(self, arg, no_print=False):
if not no_print:
self(f"[$ {arg}]", file=sys.stderr)
return subprocess.check_output(arg, shell=True).decode()
def check_call(self, arg, env=None, quiet=False):

View File

@@ -1,3 +1,5 @@
import ipaddress
import requests
import importlib
import subprocess
@@ -10,19 +12,25 @@ class DNS:
self.session = requests.Session()
self.out = out
self.ssh = f"ssh root@{mail_domain}"
self.out.shell_output(f"{self.ssh} -- unbound-control flush {mail_domain}")
try:
self.shell(f"unbound-control flush {mail_domain}")
except subprocess.CalledProcessError:
pass
def shell(self, cmd):
return self.out.shell_output(f"{self.ssh} -- {cmd}", no_print=True)
def get_ipv4(self):
cmd = "ip a | grep 'inet ' | grep 'scope global' | grep -oE '[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}' | head -1"
return self.out.shell_output(f"{self.ssh} -- {cmd}").strip()
return self.shell(cmd).strip()
def get_ipv6(self):
cmd = "ip a | grep inet6 | grep 'scope global' | sed -e 's#/64 scope global##' | sed -e 's#inet6##'"
return self.out.shell_output(f"{self.ssh} -- {cmd}").strip()
return self.shell(cmd).strip()
def get(self, typ: str, domain: str) -> str:
"""Get a DNS entry"""
dig_result = self.out.shell_output(f"{self.ssh} -- dig {typ} {domain}")
dig_result = self.shell(f"dig {typ} {domain}")
line_num = 0
for line in dig_result.splitlines():
line_num += 1
@@ -31,12 +39,13 @@ class DNS:
def resolve(self, domain: str) -> str:
result = self.get("A", domain)
if not result:
try:
assert ipaddress.ip_address(result).version == 4
except ValueError:
result = self.get("CNAME", domain)
if result:
result = self.get("A", result[:-1])
if not result:
result = self.get("AAAA", domain)
return self.resolve(result)
except AssertionError:
result = self.get("AAAA", domain)
return result
def check_ptr_record(self, ip: str, mail_domain) -> str: