diff --git a/cmdeploy/src/cmdeploy/cmdeploy.py b/cmdeploy/src/cmdeploy/cmdeploy.py index 36917257..c68e877a 100644 --- a/cmdeploy/src/cmdeploy/cmdeploy.py +++ b/cmdeploy/src/cmdeploy/cmdeploy.py @@ -15,7 +15,7 @@ from pathlib import Path from termcolor import colored from chatmaild.config import read_config, write_initial_config -from cmdeploy.dns import resolve, resolve_mx, get +from cmdeploy.dns import DNS # @@ -38,8 +38,9 @@ def init_cmd(args, out): else: write_initial_config(args.inipath, args.chatmail_domain) out.green(f"created config file for {args.chatmail_domain} in {args.inipath}") - ipaddress = resolve(args.chatmail_domain) - mta_ipadress = resolve("mta-sts." + args.chatmail_domain) + dns = DNS() + ipaddress = dns.resolve(args.chatmail_domain) + mta_ipadress = dns.resolve("mta-sts." + args.chatmail_domain) entries = 0 to_print = ["Now you should add %dnsentry% at your DNS provider:\n"] if not ipaddress: @@ -81,8 +82,9 @@ def run_cmd(args, out): cmd = f"{pyinf} --ssh-user root {args.config.mail_domain} {deploy_path}" mail_domain = args.config.mail_domain - root_ip = resolve(mail_domain) - mta_ip = resolve(f"mta-sts.{mail_domain}") + dns = DNS() + root_ip = dns.resolve(mail_domain) + mta_ip = dns.resolve(f"mta-sts.{mail_domain}") if not root_ip or root_ip != mta_ip: out.red("DNS entries missing. Show instructions with:\n") print(f"\tcmdeploy init {mail_domain}\n") @@ -95,6 +97,7 @@ def dns_cmd(args, out): template = importlib.resources.files(__package__).joinpath("chatmail.zone.f") ssh = f"ssh root@{args.config.mail_domain}" get_ipv6 = "ip a | grep inet6 | grep 'scope global' | sed -e 's#/64 scope global##' | sed -e 's#inet6##'" + dns = DNS() def read_dkim_entries(entry): lines = [] @@ -123,29 +126,29 @@ def dns_cmd(args, out): ).strip() if " MX " in line: domain, typ, prio, value = line.split() - current = resolve_mx(domain[:-1]) + current = dns.resolve_mx(domain[:-1]) if not current[0]: print(line) elif current[1] != value: print(line.replace(prio, str(current[0] + 1))) if " SRV " in line: domain, typ, prio, weight, port, value = line.split() - current = get("SRV", domain[:-1]) + current = dns.get("SRV", domain[:-1]) if current != f"{prio} {weight} {port} {value}": print(line) if " AAAA " in line: domain, value = line.split(" AAAA ") - current = get("AAAA", domain.strip()[:-1]) + current = dns.get("AAAA", domain.strip()[:-1]) if current != value: print(line) if " CAA " in line: domain, value = line.split(" IN CAA ") - current = get("CAA", domain.strip()[:-1]) + current = dns.get("CAA", domain.strip()[:-1]) if current != value: print(line) if " TXT " in line: domain, value = line.split(" TXT ") - current = get("TXT", domain.strip()[:-1]) + current = dns.get("TXT", domain.strip()[:-1]) if domain.startswith("_mta-sts."): if current.split("id=")[0] == value.split("id=")[0]: continue @@ -154,7 +157,7 @@ def dns_cmd(args, out): if " IN TXT ( " in line: line += f.read() domain, data = line.split(" IN TXT ") - current = get("TXT", domain.strip()[:-1]).replace('" "', '"\n "') + current = dns.get("TXT", domain.strip()[:-1]).replace('" "', '"\n "') current = f"( {current} )" if current.replace(";", "\\;") != data: print( diff --git a/cmdeploy/src/cmdeploy/dns.py b/cmdeploy/src/cmdeploy/dns.py index 0d6fdb56..617c16d6 100644 --- a/cmdeploy/src/cmdeploy/dns.py +++ b/cmdeploy/src/cmdeploy/dns.py @@ -12,48 +12,50 @@ dns_types = { } -def get(typ: str, domain: str) -> str: - """Get a DNS entry""" - r = requests.get( - url, - params={"name": domain, "type": typ}, - headers={"accept": "application/dns-json"}, - ) +class DNS: + def __init__(self): + self.session = requests.Session() - j = r.json() - if "Answer" in j: - for answer in j["Answer"]: - if answer["type"] == dns_types[typ]: - return answer["data"] - return "" + def get(self, typ: str, domain: str) -> str: + """Get a DNS entry""" + r = self.session.get( + url, + params={"name": domain, "type": typ}, + headers={"accept": "application/dns-json"}, + ) + j = r.json() + if "Answer" in j: + for answer in j["Answer"]: + if answer["type"] == dns_types[typ]: + return answer["data"] + return "" -def resolve_mx(domain: str) -> (str, str): - """Resolve an MX entry""" - r = requests.get( - url, - params={"name": domain, "type": "MX"}, - headers={"accept": "application/dns-json"}, - ) + def resolve_mx(self, domain: str) -> (str, str): + """Resolve an MX entry""" + r = self.session.get( + url, + params={"name": domain, "type": "MX"}, + headers={"accept": "application/dns-json"}, + ) - j = r.json() - if "Answer" in j: - result = (0, None) - for answer in j["Answer"]: - if answer["type"] == dns_types["MX"]: - prio, server_name = answer["data"].split() - if int(prio) > result[0]: - result = (int(prio), server_name) + j = r.json() + if "Answer" in j: + result = (0, None) + for answer in j["Answer"]: + if answer["type"] == dns_types["MX"]: + prio, server_name = answer["data"].split() + if int(prio) > result[0]: + result = (int(prio), server_name) + return result + return None, None + + def resolve(self, domain: str) -> str: + result = self.get("A", domain) + if not result: + result = self.get("CNAME", domain) + if result: + result = self.get("A", result[:-1]) + if not result: + result = self.get("AAAA", domain) return result - return None, None - - -def resolve(domain: str) -> str: - result = get("A", domain) - if not result: - result = get("CNAME", domain) - if result: - result = get("A", result[:-1]) - if not result: - result = get("AAAA", domain) - return result