-from collections import defaultdict
-from datetime import timedelta
-import dns.name
-from dns.name import Name
-from dns.node import Node
-from dns.rdata import Rdata
-from dns.rdataclass import RdataClass
-from dns.rdatatype import RdataType
-import dns.rdtypes.ANY.MX
-import dns.rdtypes.ANY.NS
-import dns.rdtypes.ANY.PTR
-import dns.rdtypes.ANY.SOA
-import dns.rdtypes.ANY.TXT
-import dns.rdtypes.IN.A
-import dns.rdtypes.IN.AAAA
-from dns.zone import Zone
-from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
-import socket
-from typing import Optional, Dict, List, Self, Tuple, DefaultDict
-
-
-IPAddress = IPv4Address | IPv6Address
-IPNetwork = IPv4Network | IPv6Network
-IPAddr = str | IPAddress | List[str | IPAddress]
-
-
-class NscNode:
- nsc_zone: 'NscZone'
- name: str
- node: Node
- _ttl: int
-
- def __init__(self, nsc_zone: 'NscZone', name: str) -> None:
- self.nsc_zone = nsc_zone
- self.name = name
- self.node = nsc_zone.zone.find_node(name, create=True)
- self._ttl = nsc_zone._min_ttl
-
- def ttl(self, *args, **kwargs) -> Self:
- if not args and not kwargs:
- self._ttl = self.nsc_zone._min_ttl
- else:
- self._ttl = int(timedelta(*args, **kwargs).total_seconds())
- return self
-
- def _add(self, rec: Rdata) -> None:
- rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
- rds.add(rec, ttl=self._ttl)
-
- def _parse_addrs(self, addrs: Tuple[IPAddr, ...]) -> List[IPAddress]:
- out = []
- for a in addrs:
- if not isinstance(a, list):
- a = [a]
- for b in a:
- if isinstance(b, IPv4Address) or isinstance(b, IPv6Address):
- out.append(b)
- else:
- out.append(ip_address(b))
- return out
-
- def _parse_name(self, name: str) -> Name:
- # FIXME: Names with escaped dots
- if '.' in name:
- return dns.name.from_text(name)
- else:
- return dns.name.from_text(name, origin=None)
-
- def _parse_names(self, names: str | List[str]) -> List[Name]:
- if isinstance(names, str):
- return [self._parse_name(names)]
- else:
- return [self._parse_name(n) for n in names]
-
- def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
- for a in self._parse_addrs(addrs):
- if isinstance(a, IPv4Address):
- self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
- else:
- self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
- if reverse:
- self.nsc_zone.nsc._add_reverse_mapping(a, dns.name.from_text(self.name + '.' + self.nsc_zone.name))
- return self
-
- def MX(self, pri: int, name: str) -> Self:
- self._add(
- dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, self._parse_name(name))
- )
- return self
-
- def NS(self, names: str | List[str]) -> Self:
- # FIXME: Variadic?
- for name in self._parse_names(names):
- self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
- return self
-
- def TXT(self, text: str) -> Self:
- self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, text))
- return self
-
- def PTR(self, target: Name | str) -> Self:
- self._add(dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target))
- return self
-
- def generic(self, typ: str, text: str) -> Self:
- self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
- return self
-
-
-class NscZoneConfig:
- admin_email: str
- refresh: timedelta
- retry: timedelta
- expire: timedelta
- min_ttl: timedelta
- origin_server: str
-
- default_config: Optional['NscZoneConfig'] = None
-
- def __init__(self,
- admin_email: Optional[str] = None,
- refresh: Optional[timedelta] = None,
- retry: Optional[timedelta] = None,
- expire: Optional[timedelta] = None,
- min_ttl: Optional[timedelta] = None,
- origin_server: Optional[str] = None,
- inherit_config: Optional['NscZoneConfig'] = None,
- ) -> None:
- if inherit_config is None:
- inherit_config = NscZoneConfig.default_config or self # to satisfy the type checker
- self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
- self.refresh = refresh if refresh is not None else inherit_config.refresh
- self.retry = retry if retry is not None else inherit_config.retry
- self.expire = expire if expire is not None else inherit_config.expire
- self.min_ttl = min_ttl if min_ttl is not None else inherit_config.min_ttl
- self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
-
- def finalize(self) -> Self:
- if not self.origin_server:
- self.origin_server = socket.getfqdn()
- if not self.admin_email:
- self.admin_email = f'hostmaster@{self.origin_server}'
- return self
-
-
-NscZoneConfig.default_config = NscZoneConfig(
- admin_email="",
- refresh=timedelta(hours=8),
- retry=timedelta(hours=2),
- expire=timedelta(days=14),
- min_ttl=timedelta(days=1),
- origin_server="",
-)
-
-
-class NscZone:
- nsc: 'Nsc'
- name: str
- zone: Zone
- _min_ttl: int
- reverse_for: Optional[IPNetwork]
-
- def __init__(self, nsc: 'Nsc', name: str, reverse_for: Optional[IPNetwork] = None, **kwargs) -> None:
- self.nsc = nsc
- self.name = name
- self.config = NscZoneConfig(**kwargs).finalize()
- self.zone = dns.zone.Zone(origin=name, rdclass=RdataClass.IN)
- self._min_ttl = int(self.config.min_ttl.total_seconds())
- self.reverse_for = reverse_for
-
- conf = self.config
- root = self[""]
- root._add(
- dns.rdtypes.ANY.SOA.SOA(
- RdataClass.IN, RdataType.SOA,
- mname=conf.origin_server,
- rname=conf.admin_email.replace('@', '.'), # FIXME: names with dots
- serial=12345,
- refresh=int(conf.refresh.total_seconds()),
- retry=int(conf.retry.total_seconds()),
- expire=int(conf.expire.total_seconds()),
- minimum=int(conf.min_ttl.total_seconds()),
- )
- )
-
- def n(self, name: str) -> NscNode:
- return NscNode(self, name)
-
- def __getitem__(self, name: str) -> NscNode:
- return NscNode(self, name)
-
- def host(self, name: str, *args, reverse: bool = True) -> NscNode:
- n = NscNode(self, name)
- n.A(*args, reverse=reverse)
- return n
-
- def dump(self) -> None:
- # Could use self.zone.to_file(sys.stdout), but we want better formatting
- print(f'; Zone file for {self.name}')
- last_name = None
- for name, ttl, rec in self.zone.iterate_rdatas():
- if name == last_name:
- print_name = ""
- else:
- print_name = name
- print(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}')
- last_name = name
-
- def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
- # Called only for addresses from this reverse network
- assert self.reverse_for is not None
- parts = str(addr).split('.')
- parts = parts[self.reverse_for.prefixlen // 8:]
- name = '.'.join(reversed(parts))
- self.n(name).PTR(ptr_to)
-
- def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
- # Called only for addresses from this reverse network
- assert self.reverse_for is not None
- parts = addr.exploded.replace(':', "")
- parts = parts[self.reverse_for.prefixlen // 4:]
- name = '.'.join(reversed(parts))
- self.n(name).PTR(ptr_to)
-
-
-class Nsc:
- zones: Dict[str, NscZone]
- default_zone_config: NscZoneConfig
- ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
- ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
-
- def __init__(self, **kwargs) -> None:
- self.zones = {}
- self.default_zone_config = NscZoneConfig(**kwargs)
- self.ipv4_reverse = defaultdict(list)
- self.ipv6_reverse = defaultdict(list)
-
- def add_zone(self, *args, inherit_config: Optional[NscZoneConfig] = None, **kwargs) -> Zone:
- if inherit_config is None:
- inherit_config = self.default_zone_config
- z = NscZone(self, *args, inherit_config=inherit_config, **kwargs)
- assert z.name not in self.zones
- self.zones[z.name] = z
- return z
-
- def add_reverse_zone(self, net: str | IPNetwork, name: Optional[str] = None, **kwargs) -> Zone:
- if not (isinstance(net, IPv4Network) or isinstance(net, IPv6Network)):
- net = ip_network(net, strict=True)
- name = name or self._reverse_zone_name(net)
- return self.add_zone(name, reverse_for=net, **kwargs)
-
- def _reverse_zone_name(self, net: IPNetwork) -> str:
- if isinstance(net, IPv4Network):
- parts = str(net.network_address).split('.')
- out = parts[:net.prefixlen // 8]
- if net.prefixlen % 8 != 0:
- out.append(parts[len(out)] + '/' + str(net.prefixlen))
- return '.'.join(reversed(out)) + '.in-addr.arpa'
- elif isinstance(net, IPv6Network):
- assert net.prefixlen % 4 == 0
- nibbles = net.network_address.exploded.replace(':', "")
- nibbles = nibbles[:net.prefixlen // 4]
- return '.'.join(reversed(nibbles)) + '.ip6.arpa'
- else:
- raise NotImplementedError()
-
- def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
- if isinstance(addr, IPv4Address):
- self.ipv4_reverse[addr].append(ptr_to)
- else:
- self.ipv6_reverse[addr].append(ptr_to)
-
- def dump_reverse(self) -> None:
- print('### Requests for reverse mappings ###')
- for ipa4, name in sorted(self.ipv4_reverse.items()):
- print(f'{ipa4}\t{name}')
- for ipa6, name in sorted(self.ipv6_reverse.items()):
- print(f'{ipa6}\t{name}')
-
- def fill_reverse(self) -> None:
- for z in self.zones.values():
- if z.reverse_for is not None:
- if isinstance(z.reverse_for, IPv4Network):
- for addr4, ptr_list in self.ipv4_reverse.items():
- if addr4 in z.reverse_for:
- for ptr_to in ptr_list:
- z._add_ipv4_reverse(addr4, ptr_to)
- else:
- for addr6, ptr_list in self.ipv6_reverse.items():
- if addr6 in z.reverse_for:
- for ptr_to in ptr_list:
- z._add_ipv6_reverse(addr6, ptr_to)
-
- def dump(self) -> None:
- for z in self.zones.values():
- z.dump()