from collections import defaultdict
from datetime import datetime, timedelta
from enum import Enum, auto
import hashlib
from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
import json
from pathlib import Path
import socket
import sys
from typing import Optional, Dict, List, Self, Tuple, DefaultDict, TextIO

import dns.name
from dns.name import Name
from dns.node import Node
import dns.rdata
from dns.rdata import Rdata
from dns.rdataclass import RdataClass
from dns.rdatatype import RdataType
import dns.rdtypes.ANY.MX
import dns.rdtypes.ANY.NS
import dns.rdtypes.ANY.PTR
import dns.rdtypes.ANY.SOA
import dns.rdtypes.ANY.TXT
import dns.rdtypes.IN.A
import dns.rdtypes.IN.AAAA
import dns.zone
from dns.zone import Zone
# A single parsed IP address or network, of either address family.
IPAddress = IPv4Address | IPv6Address
IPNetwork = IPv4Network | IPv6Network
# Address arguments accepted by the record helpers (e.g. NscNode.A):
# a textual or already-parsed address, or a list of such values.
IPAddr = str | IPAddress | List[str | IPAddress]
nsc_zone: 'NscZonePrimary'

def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None:
    """Create a handle for adding records to node *name* of *nsc_zone*.

    The underlying dnspython node is looked up and created if it does not
    exist yet. Records added through this handle start with the zone's
    default TTL (nsc_zone._min_ttl).

    Args:
        nsc_zone: the primary zone the node belongs to.
        name: node name as text, as accepted by the zone's find_node().
    """
    self.nsc_zone = nsc_zone
    # Kept as text: A() resolves it against the zone origin for reverse mappings.
    self.name = name
    self.node = nsc_zone.zone.find_node(name, create=True)
    self._ttl = nsc_zone._min_ttl
44 def ttl(self, *args, **kwargs) -> Self:
45 if not args and not kwargs:
46 self._ttl = self.nsc_zone._min_ttl
48 self._ttl = int(timedelta(*args, **kwargs).total_seconds())
def _add(self, rec: Rdata) -> None:
    """Store *rec* in this node under its own class and type.

    The matching rdataset is created on demand, and the node's current TTL
    is passed along with the record.
    """
    self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True).add(rec, ttl=self._ttl)
def _parse_addrs(self, addrs: Tuple[IPAddr, ...]) -> List[IPAddress]:
    """Flatten address arguments into a list of parsed IP addresses.

    Each element of *addrs* is either a single address or a list (or
    tuple) of addresses. Addresses may be given as text or as
    IPv4Address/IPv6Address objects; text is parsed with ip_address(),
    which raises ValueError on malformed input.
    """
    out: List[IPAddress] = []
    for a in addrs:
        group = a if isinstance(a, (list, tuple)) else [a]
        for b in group:
            if isinstance(b, (IPv4Address, IPv6Address)):
                out.append(b)
            else:
                out.append(ip_address(b))
    return out
def _parse_name(self, name: str) -> Name:
    """Convert a domain name given as text (used as record data) to a dns Name.

    NOTE(review): the two return statements below are not separated by any
    visible condition, so as written the second one is unreachable and every
    name goes through dns.name.from_text() with its default origin. A
    branching line (presumably choosing between fully-qualified and relative
    parsing) seems to be missing here -- confirm against the full file.
    """
    # FIXME: Names with escaped dots
    return dns.name.from_text(name)
    return dns.name.from_text(name, origin=None)
def _parse_names(self, names: str | List[str]) -> List[Name]:
    """Parse a single name, or a list of names, into a list of dns Names."""
    candidates = [names] if isinstance(names, str) else names
    return [self._parse_name(candidate) for candidate in candidates]
def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
    """Add address records for this node.

    IPv4 addresses produce A records, IPv6 addresses AAAA records (see
    _parse_addrs for the accepted argument forms). When *reverse* is true,
    each address is also registered with the owning Nsc object via
    _add_reverse_mapping(), pointing at this node's fully-qualified name.

    Returns self so calls can be chained.
    """
    fqdn = None
    if reverse:
        # Resolve the node name against the zone origin; unlike string
        # concatenation this handles '@'/'' (the apex) and absolute names.
        fqdn = dns.name.from_text(self.name, origin=self.nsc_zone.zone.origin)
    for a in self._parse_addrs(addrs):
        if isinstance(a, IPv4Address):
            self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
        else:
            self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
        if reverse:
            self.nsc_zone.nsc._add_reverse_mapping(a, fqdn)
    return self
def MX(self, pri: int, name: str) -> Self:
    """Add an MX record with preference *pri* pointing to mail server *name*.

    *name* is converted with _parse_name(). Returns self so calls can be
    chained.
    """
    self._add(
        dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, self._parse_name(name))
    )
    return self
def NS(self, names: str | List[str]) -> Self:
    """Add one NS record per name in *names* (a single name or a list).

    Returns self so calls can be chained.
    """
    for name in self._parse_names(names):
        self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
    return self
def TXT(self, text: str) -> Self:
    """Add a TXT record carrying *text*.

    A single character-string inside a TXT record may hold at most 255
    bytes. Longer texts (e.g. DKIM keys) are therefore split into
    consecutive 255-byte strings of one record; consumers such as SPF and
    DKIM verifiers concatenate them. Text is encoded as UTF-8. Values that
    are neither str nor bytes are passed to dnspython unchanged.

    Returns self so calls can be chained.
    """
    if isinstance(text, (str, bytes)):
        data = text.encode('utf-8') if isinstance(text, str) else text
        strings = [data[i:i + 255] for i in range(0, len(data), 255)] or [b'']
    else:
        strings = text
    self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, strings))
    return self
def PTR(self, target: Name | str) -> Self:
    """Add a PTR record pointing to *target* (a dns Name or its text form).

    Returns self so calls can be chained.
    """
    self._add(dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target))
    return self
def generic(self, typ: str, text: str) -> Self:
    """Add a record of any type, given its type name and presentation text.

    For example generic('SRV', '0 5 5060 sip.example.org.'). The text is
    parsed by dns.rdata.from_text(). Returns self so calls can be chained.
    """
    self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
    return self
default_config: Optional['NscZoneConfig'] = None  # inherited from when no inherit_config is given

def __init__(self,
             admin_email: Optional[str] = None,
             refresh: Optional[timedelta] = None,
             retry: Optional[timedelta] = None,
             expire: Optional[timedelta] = None,
             min_ttl: Optional[timedelta] = None,
             origin_server: Optional[str] = None,
             inherit_config: Optional['NscZoneConfig'] = None,
             ) -> None:
    """Create a zone configuration.

    Each setting left as None is copied from *inherit_config*, or from
    NscZoneConfig.default_config when no inherit_config is given.

    Args:
        admin_email: contact address, published as the SOA RNAME.
        refresh, retry, expire: SOA timers.
        min_ttl: SOA minimum; also the default TTL of records.
        origin_server: primary name server, published as the SOA MNAME.
        inherit_config: configuration to take unset values from.
    """
    if inherit_config is not None:
        base = inherit_config
    else:
        # While default_config itself is being built it is still None and
        # there is nothing to inherit from; falling back to self keeps
        # `base` non-optional for the type checker, so every setting must
        # then be passed explicitly.
        base = NscZoneConfig.default_config or self

    def pick(value, attr):
        # An explicit argument wins; None means "inherit".
        return value if value is not None else getattr(base, attr)

    self.admin_email = pick(admin_email, 'admin_email')
    self.refresh = pick(refresh, 'refresh')
    self.retry = pick(retry, 'retry')
    self.expire = pick(expire, 'expire')
    self.min_ttl = pick(min_ttl, 'min_ttl')
    self.origin_server = pick(origin_server, 'origin_server')
143 def finalize(self) -> Self:
144 if not self.origin_server:
145 self.origin_server = socket.getfqdn()
146 if not self.admin_email:
147 self.admin_email = f'hostmaster@{self.origin_server}'
# Built-in defaults that every other NscZoneConfig inherits from.
# admin_email and origin_server are deliberately empty: since this object
# has nothing to inherit from, every setting must be given explicitly, and
# finalize() fills empty values in per zone (FQDN / hostmaster@...).
NscZoneConfig.default_config = NscZoneConfig(
    admin_email='',
    refresh=timedelta(hours=8),
    retry=timedelta(hours=2),
    expire=timedelta(days=14),
    min_ttl=timedelta(days=1),
    origin_server='',
)
def __init__(self) -> None:
    """Start with an empty state: serial 0 and an empty content hash.

    Serial 0 never counts as a valid previous serial in
    NscZonePrimary.gen_serial(), so a zone without a saved state always
    gets a fresh serial.
    """
    self.serial = 0
    self.hash = ''
def load(self, file: Path) -> None:
    """Load the state saved by save() from *file*, if it exists.

    A missing file leaves the current values untouched. So does a missing
    key: each key is copied only if present.

    Raises:
        ValueError: if the file does not contain a JSON object.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    try:
        with open(file, encoding='utf-8') as f:
            js = json.load(f)
    except FileNotFoundError:
        return
    # Explicit check rather than assert, which disappears under `python -O`.
    if not isinstance(js, dict):
        raise ValueError(f'{file}: expected a JSON object')
    if 'serial' in js:
        self.serial = js['serial']
    if 'hash' in js:
        self.hash = js['hash']
def save(self, file: Path) -> None:
    """Write the state (serial and hash) to *file* as JSON.

    The data is written to '<file>.new' first and then renamed over *file*,
    so an interrupted run never leaves a half-written state file behind.
    """
    new_file = Path(str(file) + '.new')
    js = {
        'serial': self.serial,
        'hash': self.hash,
    }
    with open(new_file, 'w', encoding='utf-8') as f:
        json.dump(js, f, indent=4, sort_keys=True)
        f.write('\n')
    new_file.replace(file)
class ZoneType(Enum):
    """Kind of zone: generated here (primary) or served as a secondary."""
    primary = auto()
    secondary = auto()
safe_name: str  # For use in file names
reverse_for: Optional[IPNetwork]

def __init__(self,
             nsc: 'Nsc', name: str,
             reverse_for: Optional[IPNetwork],
             **kwargs) -> None:
    """Initialization shared by primary and secondary zones.

    Args:
        nsc: the Nsc object that owns this zone.
        name: the zone name.
        reverse_for: the network this zone holds reverse mappings for, or None.
        **kwargs: passed to NscZoneConfig (e.g. inherit_config); the
            resulting configuration is finalized immediately.
    """
    self.nsc = nsc
    self.name = name
    # Classless reverse zone names contain '/', which cannot appear in a
    # file name.
    self.safe_name = name.replace('/', '@')
    self.config = NscZoneConfig(**kwargs).finalize()
    self.reverse_for = reverse_for
def process(self) -> None:
    """Per-zone processing hook; the base class has nothing to do.

    NscZonePrimary overrides it.
    """
class NscZonePrimary(NscZone):
    """A zone whose records are built here and written out as a zone file.

    Next to the zone file, a small JSON state file remembers the last serial
    number and a hash of the zone contents. The SOA serial therefore changes
    only when the contents do.
    """
    zone: Zone
    _min_ttl: int  # default record TTL in seconds (also the SOA minimum)
    zone_file: Path
    state_file: Path
    state: NscZoneState  # state computed by this run
    prev_state: NscZoneState  # state saved by the previous run, if any

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.zone_type = ZoneType.primary
        self.zone_file = self.nsc.zone_dir / self.safe_name
        self.state_file = self.nsc.state_dir / (self.safe_name + '.json')

        self.state = NscZoneState()
        self.prev_state = NscZoneState()
        self.prev_state.load(self.state_file)

        self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN)
        self._min_ttl = int(self.config.min_ttl.total_seconds())
        # Install a provisional SOA right away, so that its parameters are
        # part of the content hash computed by gen_hash().
        self.update_soa()

    @staticmethod
    def _email_to_rname(email: str) -> str:
        """Convert an e-mail address to SOA RNAME text.

        The '@' becomes a label separator. Dots in the local part are
        escaped so they are not read as label separators
        (john.doe@example.org -> john\\.doe.example.org). Text without
        '@' is returned unchanged.
        """
        local, at, domain = email.rpartition('@')
        if not at:
            return email
        return local.replace('.', '\\.') + '.' + domain

    def update_soa(self) -> None:
        """Replace the zone's SOA record with one built from the config and current serial."""
        conf = self.config
        soa = dns.rdtypes.ANY.SOA.SOA(
            RdataClass.IN, RdataType.SOA,
            mname=conf.origin_server,
            rname=self._email_to_rname(conf.admin_email),
            serial=self.state.serial,
            refresh=int(conf.refresh.total_seconds()),
            retry=int(conf.retry.total_seconds()),
            expire=int(conf.expire.total_seconds()),
            minimum=int(conf.min_ttl.total_seconds()),
        )
        self.zone.delete_rdataset("", RdataType.SOA)
        self.zone.find_rdataset("", RdataType.SOA, create=True).add(soa, ttl=self._min_ttl)

    def n(self, name: str) -> NscNode:
        """Return a handle for adding records to node *name* of this zone."""
        return NscNode(self, name)

    def __getitem__(self, name: str) -> NscNode:
        """zone[name] is shorthand for zone.n(name)."""
        return self.n(name)

    def host(self, name: str, *args, reverse: bool = True) -> NscNode:
        """Create node *name* with address records for the addresses in *args*.

        See NscNode.A for the accepted address forms and for *reverse*.
        Returns the node handle so more records can be added.
        """
        n = NscNode(self, name)
        n.A(*args, reverse=reverse)
        return n

    def dump(self, file: Optional[TextIO] = None) -> None:
        """Write the zone in master-file format to *file* (default: stdout)."""
        # Could use self.zone.to_file(sys.stdout), but we want better formatting
        if file is None:
            file = sys.stdout
        file.write(f'; Zone file for {self.name}\n\n')
        # Records written below with an empty TTL column get this TTL.
        file.write(f'$TTL {self._min_ttl}\n\n')
        last_name = None
        for name, ttl, rec in self.zone.iterate_rdatas():
            # An empty owner column means "same owner as the previous line".
            print_name = '' if name == last_name else name
            file.write(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
            last_name = name

    def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
        """Add a PTR record for *addr* (inside reverse_for) pointing to *ptr_to*."""
        assert isinstance(self.reverse_for, IPv4Network)
        parts = str(addr).split('.')
        # Leading octets covered by the prefix are already part of the zone name.
        parts = parts[self.reverse_for.prefixlen // 8:]
        name = '.'.join(reversed(parts))
        self.n(name).PTR(ptr_to)

    def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
        """Add a PTR record for *addr* (inside reverse_for) pointing to *ptr_to*."""
        assert isinstance(self.reverse_for, IPv6Network)
        parts = addr.exploded.replace(':', "")
        # One label per nibble; nibbles covered by the prefix are in the zone name.
        parts = parts[self.reverse_for.prefixlen // 4:]
        name = '.'.join(reversed(parts))
        self.n(name).PTR(ptr_to)

    def gen_hash(self) -> None:
        """Store a short hash of all records (in iteration order) in self.state.hash."""
        sha = hashlib.sha1()
        for name, ttl, rec in self.zone.iterate_rdatas():
            text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
            sha.update(text.encode('us-ascii'))
        self.state.hash = sha.hexdigest()[:16]

    def gen_serial(self) -> None:
        """Choose this run's date-based serial (YYYYMMDDnn) in self.state.serial.

        The previous serial is kept if the content hash did not change;
        otherwise the serial is the first one of today, or one more than the
        previous serial if that is already today's or later.
        """
        prev = self.prev_state.serial
        if self.state.hash == self.prev_state.hash and prev > 0:
            self.state.serial = prev
        else:
            base = int(self.nsc.start_time.strftime('%Y%m%d00'))
            if prev < base:
                self.state.serial = base + 1
            else:
                self.state.serial = prev + 1
                # More than 99 changes today: the serial spills into tomorrow's range.
                if prev >= base + 99:
                    print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')

    def process(self) -> None:
        """Compute the content hash and the serial number for this run."""
        if self.zone_type == ZoneType.primary:
            self.gen_hash()
            self.gen_serial()

    def write_zone(self) -> None:
        """Write the zone file with an up-to-date SOA, replacing the old file via '<file>.new'."""
        self.update_soa()
        new_file = Path(str(self.zone_file) + '.new')
        with open(new_file, 'w') as f:
            self.dump(file=f)
        new_file.replace(self.zone_file)

    def write_state(self) -> None:
        """Save this run's serial and hash for the next run."""
        self.state.save(self.state_file)
class NscZoneSecondary(NscZone):
    """A zone served as a secondary, with *primary_server* as its primary."""
    primary_server: IPAddress
    secondary_file: Path  # this zone's file under nsc.secondary_dir

    def __init__(self, *args, primary_server: IPAddress, **kwargs) -> None:
        """Create a secondary zone.

        Args:
            *args, **kwargs: passed to NscZone.__init__.
            primary_server: address of the zone's primary server. Previously
                this was mistakenly written as a default value
                (`primary_server=IPAddress`) instead of an annotation.
        """
        super().__init__(*args, **kwargs)
        self.zone_type = ZoneType.secondary
        self.primary_server = primary_server
        self.secondary_file = self.nsc.secondary_dir / self.safe_name
zones: Dict[str, NscZone]  # all zones, keyed by zone name
default_zone_config: NscZoneConfig
ipv4_reverse: DefaultDict[IPv4Address, List[Name]]  # PTR targets requested per address
ipv6_reverse: DefaultDict[IPv6Address, List[Name]]

def __init__(self, directory: str = '.', **kwargs) -> None:
    """Set up an empty zone collection rooted at *directory*.

    The state/, zone/ and secondary/ subdirectories are created if missing.
    Extra keyword arguments form the default zone configuration (see
    NscZoneConfig).
    """
    self.start_time = datetime.now()
    self.zones = {}
    self.default_zone_config = NscZoneConfig(**kwargs)
    self.ipv4_reverse = defaultdict(list)
    self.ipv6_reverse = defaultdict(list)

    self.root_dir = Path(directory)
    self.state_dir = self.root_dir / 'state'
    self.zone_dir = self.root_dir / 'zone'
    self.secondary_dir = self.root_dir / 'secondary'
    for subdir in (self.state_dir, self.zone_dir, self.secondary_dir):
        subdir.mkdir(parents=True, exist_ok=True)
# NOTE(review): this is the tail of a method whose `def` line (and closing
# `... **kwargs) -> ...:` line; `kwargs` is used below) is not visible in
# this excerpt, so the code is left untouched. From the visible lines it
# creates a zone: it picks the config to inherit, derives the zone name from
# *reverse_for* when no name is given, and builds an NscZonePrimary, or an
# NscZoneSecondary whose primary server is *secondary_for*.
    name: Optional[str] = None,
    reverse_for: str | IPNetwork | None = None,
    secondary_for: str | IPAddress | None = None,
    inherit_config: Optional[NscZoneConfig] = None,
    # Zones inherit the Nsc-wide default configuration unless told otherwise.
    if inherit_config is None:
        inherit_config = self.default_zone_config
    # A reverse zone may be given as network text; without an explicit name,
    # the zone name is derived from the network.
    if reverse_for is not None:
        if isinstance(reverse_for, str):
            reverse_for = ip_network(reverse_for, strict=True)
        name = name or self._reverse_zone_name(reverse_for)
    # NOTE(review): these asserts validate caller input and are stripped
    # under `python -O`; raising ValueError would be more robust.
    assert name is not None
    assert name not in self.zones
    if secondary_for is None:
        z = NscZonePrimary(self, name, reverse_for=reverse_for, inherit_config=inherit_config, **kwargs)
    # NOTE(review): an `else:` line appears to be missing here; as visible,
    # the secondary-zone lines below are not separated from the primary case.
        if isinstance(secondary_for, str):
            secondary_for = ip_address(secondary_for)
        z = NscZoneSecondary(self, name, reverse_for=reverse_for, primary_server=secondary_for, inherit_config=inherit_config, **kwargs)
def _reverse_zone_name(self, net: IPNetwork) -> str:
    """Return the name of the reverse zone holding PTR records for *net*.

    IPv4: the octets covered by the prefix, reversed, under in-addr.arpa.
    A prefix that does not end on an octet boundary gets a classless label
    '<octet>/<prefixlen>' (e.g. 192.0.2.64/26 -> 64/26.2.0.192.in-addr.arpa).

    IPv6: one label per nibble of the prefix, reversed, under ip6.arpa.

    Raises:
        ValueError: IPv6 prefix length not a multiple of 4.
        NotImplementedError: *net* is not an IPv4Network/IPv6Network.
    """
    if isinstance(net, IPv4Network):
        octets = str(net.network_address).split('.')
        labels = octets[:net.prefixlen // 8]
        if net.prefixlen % 8 != 0:
            labels.append(octets[len(labels)] + '/' + str(net.prefixlen))
        suffix = 'in-addr.arpa'
    elif isinstance(net, IPv6Network):
        if net.prefixlen % 4 != 0:
            raise ValueError(f'IPv6 reverse zone {net} must have a prefix length divisible by 4')
        nibbles = net.network_address.exploded.replace(':', "")
        labels = list(nibbles[:net.prefixlen // 4])
        suffix = 'ip6.arpa'
    else:
        raise NotImplementedError()
    # Joining the suffix as a label avoids a leading dot for /0 networks.
    return '.'.join([*reversed(labels), suffix])
def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
    """Record a request for a PTR record mapping *addr* to *ptr_to*.

    Requests are kept per address family, in ipv4_reverse or ipv6_reverse.
    """
    if isinstance(addr, IPv4Address):
        self.ipv4_reverse[addr].append(ptr_to)
    else:
        self.ipv6_reverse[addr].append(ptr_to)
def dump_reverse(self) -> None:
    """Print all requested reverse mappings, one 'address<TAB>target' per line.

    IPv4 addresses come first, then IPv6, each family in address order.
    An address with several targets gets one line per target.
    """
    print('### Requests for reverse mappings ###')
    for mapping in (self.ipv4_reverse, self.ipv6_reverse):
        for addr, names in sorted(mapping.items()):
            for name in names:
                print(f'{addr}\t{name}')
def fill_reverse(self) -> None:
    """Add requested PTR records to every primary reverse zone that covers them.

    For each NscZonePrimary with reverse_for set, every recorded request
    whose address lies in that network becomes a PTR record in that zone.
    Only the request table of the zone's address family is scanned.
    """
    for z in self.zones.values():
        if not isinstance(z, NscZonePrimary) or z.reverse_for is None:
            continue
        if isinstance(z.reverse_for, IPv4Network):
            for addr4, ptr_list in self.ipv4_reverse.items():
                if addr4 in z.reverse_for:
                    for ptr_to in ptr_list:
                        z._add_ipv4_reverse(addr4, ptr_to)
        else:
            for addr6, ptr_list in self.ipv6_reverse.items():
                if addr6 in z.reverse_for:
                    for ptr_to in ptr_list:
                        z._add_ipv6_reverse(addr6, ptr_to)
def get_zones(self) -> List[NscZone]:
    """Return all zones, ordered by zone name."""
    by_name = sorted(self.zones.items(), key=lambda item: item[0])
    return [zone for _zone_name, zone in by_name]
440 def process(self) -> None:
442 for z in self.get_zones():