1 from collections import defaultdict
2 from datetime import datetime, timedelta
4 from dns.name import Name
5 from dns.node import Node
6 from dns.rdata import Rdata
7 from dns.rdataclass import RdataClass
8 from dns.rdatatype import RdataType
9 import dns.rdtypes.ANY.MX
10 import dns.rdtypes.ANY.NS
11 import dns.rdtypes.ANY.PTR
12 import dns.rdtypes.ANY.SOA
13 import dns.rdtypes.ANY.TXT
14 import dns.rdtypes.IN.A
15 import dns.rdtypes.IN.AAAA
16 from dns.zone import Zone
17 from enum import Enum, auto
19 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
21 from pathlib import Path
24 from typing import Optional, Dict, List, Self, Tuple, DefaultDict, TextIO, TYPE_CHECKING
28 from nsconfig.daemon import NscDaemon
# Type aliases used in annotations throughout this module.
# A single IP address of either family.
31 IPAddress = IPv4Address | IPv6Address
# A single IP network of either family.
32 IPNetwork = IPv4Network | IPv6Network
# One address argument accepted by NscNode.A(): a textual address, an address
# object, or a list mixing both (flattened by NscNode._parse_addrs).
33 IPAddr = str | IPAddress | List[str | IPAddress]
# NOTE(review): the header of the enclosing class is not visible in this excerpt.
# Its __init__(nsc_zone, name) matches the NscNode(self, name) calls made by
# NscZonePrimary, so this appears to be the NscNode record-builder helper.
# The primary zone this helper adds records to; assigned in __init__.
37 nsc_zone: 'NscZonePrimary'
# Create a helper for adding records to owner `name` within `nsc_zone`.
# The dnspython node is looked up in nsc_zone.zone and created if missing.
# The TTL for added records starts at the zone's minimum TTL (nsc_zone._min_ttl).
42 def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None:
43 self.nsc_zone = nsc_zone
# NOTE(review): A() reads self.name, but no assignment to it is visible here.
# Presumably it is set on an omitted line; TODO confirm.
45 self.node = nsc_zone.zone.find_node(name, create=True)
46 self._ttl = nsc_zone._min_ttl
# Set the TTL applied to records added afterwards through this helper.
# With no arguments, the TTL is reset to the zone's minimum TTL.
# Otherwise the arguments go straight to datetime.timedelta (e.g. ttl(hours=1))
# and are converted to whole seconds; int() truncates any fraction.
# NOTE(review): the `else:` between the two assignments is not visible in this
# excerpt, nor is the `return self` implied by the `-> Self` annotation.
48 def ttl(self, *args, **kwargs) -> Self:
49 if not args and not kwargs:
50 self._ttl = self.nsc_zone._min_ttl
52 self._ttl = int(timedelta(*args, **kwargs).total_seconds())
def _add(self, rec: 'Rdata') -> None:
    """Add a single record to this node using the helper's current TTL.

    The rdataset is looked up (and created if absent) by class, type and
    ``covers``. ``covers`` matters for signature types such as RRSIG:
    signatures over different types must live in separate rdatasets. Without
    it, a second RRSIG would not match the rdataset that adopted the first
    one's covered type, and a duplicate set would be created. For ordinary
    types ``rec.covers()`` is ``RdataType.NONE``, the same lookup as before.

    Args:
        rec: the record data to add.
    """
    rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, covers=rec.covers(), create=True)
    # Passing ttl= lets dnspython adjust the rdataset TTL (per its
    # Rdataset.update_ttl: set when empty, otherwise lowered to the minimum).
    rds.add(rec, ttl=self._ttl)
# Flatten the address arguments given to A() into a list of ipaddress objects.
# Each argument may be a single address or a list of addresses (IPAddr).
# Values that are already IPv4Address/IPv6Address objects are checked for
# explicitly. The others are converted with ipaddress.ip_address(), which
# raises ValueError for text that is not a valid address.
# NOTE(review): the list initialisation, the loop headers, the branch that keeps
# address objects and the return statement are not visible in this excerpt.
# The description above assumes the usual shape; TODO confirm.
59 def _parse_addrs(self, addrs: Tuple[IPAddr, ...]) -> List[IPAddress]:
62 if not isinstance(a, list):
65 if isinstance(b, IPv4Address) or isinstance(b, IPv6Address):
68 out.append(ip_address(b))
# Convert a textual DNS name into a dns.name.Name.
# Two conversions are visible:
#   - from_text(name), which uses dnspython's default origin (the root);
#   - from_text(name, origin=None), which yields a relative name.
# The condition that chooses between them is not visible in this excerpt.
# Presumably names ending in '.' are treated as absolute; TODO confirm.
71 def _parse_name(self, name: str) -> Name:
72 # FIXME: Names with escaped dots
74 return dns.name.from_text(name)
76 return dns.name.from_text(name, origin=None)
# Normalise a single name or a list of names into a list of dns.name.Name.
# Each name is converted with _parse_name().
# NOTE(review): the `else:` before the list branch is not visible in this excerpt.
78 def _parse_names(self, names: str | List[str]) -> List[Name]:
79 if isinstance(names, str):
80 return [self._parse_name(names)]
82 return [self._parse_name(n) for n in names]
# Add address records for this owner name: one A record per IPv4 address and
# one AAAA record per IPv6 address.
# Addresses are normalised by _parse_addrs, so strings, address objects and
# lists are all accepted.
# When `reverse` is true, each address is also registered through
# nsc._add_reverse_mapping(). fill_reverse() can then later add a PTR back to
# "<name>.<zone name>" in a reverse zone whose network contains the address.
# NOTE(review): the `else:` before the AAAA branch, the `if reverse:` guard and
# the trailing `return self` are not visible in this excerpt.
# NOTE(review): the PTR target is built by plain string concatenation. An owner
# written as '@' or as an absolute name would give a wrong target; confirm how
# callers name nodes.
84 def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
85 for a in self._parse_addrs(addrs):
86 if isinstance(a, IPv4Address):
87 self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
89 self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
91 self.nsc_zone.nsc._add_reverse_mapping(a, dns.name.from_text(self.name + '.' + self.nsc_zone.name))
# Add an MX record with preference `pri`, pointing at mail exchanger `name`.
# `name` is converted with _parse_name().
# NOTE(review): the enclosing self._add( call and the `return self` are on lines
# not visible in this excerpt.
94 def MX(self, pri: int, name: str) -> Self:
96 dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, self._parse_name(name))
# Add one NS record per name; accepts a single name or a list of names.
# NOTE(review): the `return self` implied by `-> Self` is not visible in this excerpt.
100 def NS(self, names: str | List[str]) -> Self:
102 for name in self._parse_names(names):
103 self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
# Add a TXT record; `text` is passed unchanged as the TXT constructor's strings argument.
# NOTE(review): a single TXT character-string is limited to 255 bytes, and
# nothing here splits a longer `text`. Confirm how the installed dnspython
# handles such input.
# NOTE(review): the `return self` implied by `-> Self` is not visible in this excerpt.
106 def TXT(self, text: str) -> Self:
107 self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, text))
# Add a PTR record pointing at `target`.
# NOTE(review): unlike MX/NS, a str target is not passed through _parse_name.
# It goes to the dnspython constructor as-is; confirm it is parsed as intended
# (absolute vs relative).
# NOTE(review): the `return self` implied by `-> Self` is not visible in this excerpt.
110 def PTR(self, target: Name | str) -> Self:
111 self._add(dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target))
# Add a class-IN record of any type. The type is named by `typ`, and `text`
# holds its rdata in presentation format; both are parsed by dns.rdata.from_text.
# NOTE(review): no origin is passed to from_text, so how relative names inside
# `text` are resolved depends on dnspython; TODO confirm.
# NOTE(review): the `return self` implied by `-> Self` is not visible in this excerpt.
114 def generic(self, typ: str, text: str) -> Self:
115 self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
# NOTE(review): this excerpt does not show the class header (the class is
# referred to as NscZoneConfig below) or the `def __init__(self,` line and
# closing `) -> None:` of its constructor.
# Class-wide fallback configuration. It is assigned at module level after the
# class and is used as the parent when no inherit_config is passed.
127 default_config: Optional['NscZoneConfig'] = None
# Constructor parameters. Every setting left as None is copied from
# inherit_config, which defaults to NscZoneConfig.default_config.
130 admin_email: Optional[str] = None,
131 refresh: Optional[timedelta] = None,
132 retry: Optional[timedelta] = None,
133 expire: Optional[timedelta] = None,
134 min_ttl: Optional[timedelta] = None,
135 origin_server: Optional[str] = None,
136 inherit_config: Optional['NscZoneConfig'] = None,
138 if inherit_config is None:
# While the default config itself is being built, default_config is still None,
# so `self` is used as the parent. Any setting not passed explicitly would then
# be read from an attribute that has not been assigned yet. Presumably the
# default config supplies, or otherwise tolerates, every field; TODO confirm.
139 inherit_config = NscZoneConfig.default_config or self # to satisfy the type checker
140 self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
141 self.refresh = refresh if refresh is not None else inherit_config.refresh
142 self.retry = retry if retry is not None else inherit_config.retry
143 self.expire = expire if expire is not None else inherit_config.expire
144 self.min_ttl = min_ttl if min_ttl is not None else inherit_config.min_ttl
145 self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
# Fill in settings that are still empty (falsy), mutating this config in place:
#   - origin_server defaults to this machine's FQDN (socket.getfqdn());
#   - admin_email defaults to hostmaster@<origin_server>.
# NOTE(review): `socket` is not among the imports visible in this excerpt, and
# neither is the `return self` implied by `-> Self`. Presumably both are on
# omitted lines.
147 def finalize(self) -> Self:
148 if not self.origin_server:
149 self.origin_server = socket.getfqdn()
150 if not self.admin_email:
151 self.admin_email = f'hostmaster@{self.origin_server}'
# Built-in defaults inherited by every zone config that does not override them:
# refresh 8 hours, retry 2 hours, expire 14 days, minimum TTL 1 day.
# NOTE(review): some arguments and the closing parenthesis of this call are not
# visible in this excerpt.
155 NscZoneConfig.default_config = NscZoneConfig(
157 refresh=timedelta(hours=8),
158 retry=timedelta(hours=2),
159 expire=timedelta(days=14),
160 min_ttl=timedelta(days=1),
# NOTE(review): this excerpt does not show the class header or the body of
# __init__. The class is the per-zone state object NscZoneState (it is
# constructed as NscZoneState() by NscZonePrimary); its load()/save() read and
# write self.serial and self.hash.
169 def __init__(self) -> None:
# Load the serial and hash of a previous run from the JSON state file `file`.
# A missing file is caught by the FileNotFoundError branch. Its body is not
# visible; presumably it leaves the values from __init__ in place. TODO confirm.
# NOTE(review): the `try:` line and the line that parses the file into `js` are
# not visible in this excerpt.
173 def load(self, file: Path) -> None:
175 with open(file) as f:
# NOTE(review): assert-based validation is skipped under `python -O`. A
# malformed state file would then fail later, with a less clear error.
177 assert isinstance(js, dict)
179 self.serial = js['serial']
181 self.hash = js['hash']
182 except FileNotFoundError:
# Write the state to `file` as pretty-printed JSON with sorted keys.
# The data goes to "<file>.new" first and is then moved over `file` with
# Path.replace(), so an interrupted write leaves the previous state file intact.
# NOTE(review): part of the `js` dict literal is not visible in this excerpt;
# only the 'serial' entry is shown.
185 def save(self, file: Path) -> None:
186 new_file = Path(str(file) + '.new')
187 with open(new_file, 'w') as f:
189 'serial': self.serial,
192 json.dump(js, f, indent=4, sort_keys=True)
193 new_file.replace(file)
# Kind of zone. This excerpt uses ZoneType.primary and ZoneType.secondary; the
# member definitions themselves are not visible here.
196 class ZoneType(Enum):
# NOTE(review): this excerpt does not show the header of the zone base class
# (NscZonePrimary and NscZoneSecondary derive from NscZone) or the start of
# its __init__ signature.
204 safe_name: str # For use in file names
# Network this zone provides reverse DNS for, or None if it is not a reverse zone.
206 reverse_for: Optional[IPNetwork]
211 reverse_for: Optional[IPNetwork],
# '/' cannot appear in a file name, so a classless reverse zone such as
# "64/26.1.168.192.in-addr.arpa" is stored under a name with '@' instead.
215 self.safe_name = name.replace('/', '@')
# The remaining keyword arguments become this zone's config. finalize() fills in
# the origin server and admin e-mail when they are empty.
216 self.config = NscZoneConfig(**kwargs).finalize()
217 self.reverse_for = reverse_for
# Per-zone processing hook, overridden by NscZonePrimary.process().
# The base-class body is not visible in this excerpt.
219 def process(self) -> None:
# A zone served as primary:
#   - records are built in memory in a dnspython zone (self.zone);
#   - the zone is written out as a zone file;
#   - its SOA serial comes from comparing a hash of the contents with the
#     state saved by the previous run.
223 class NscZonePrimary(NscZone):
# NOTE(review): any class attributes between the header and prev_state are not
# visible in this excerpt.
# State loaded from the previous run's state file; compared against self.state.
229 prev_state: NscZoneState
# Set up a primary zone:
#   - the zone file and state file paths, derived from safe_name;
#   - the previous run's state, loaded from <state_dir>/<safe_name>.json;
#   - a fresh state for this run;
#   - an empty class-IN dnspython zone.
231 def __init__(self, *args, **kwargs) -> None:
232 super().__init__(*args, **kwargs)
234 self.zone_type = ZoneType.primary
235 self.zone_file = self.nsc.zone_dir / self.safe_name
236 self.state_file = self.nsc.state_dir / (self.safe_name + '.json')
238 self.state = NscZoneState()
239 self.prev_state = NscZoneState()
240 self.prev_state.load(self.state_file)
242 self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN)
# Minimum TTL in whole seconds. NscNode uses it as the default record TTL, and
# dump() omits TTLs equal to it.
243 self._min_ttl = int(self.config.min_ttl.total_seconds())
# Rebuild the zone's SOA record from the zone config and the current serial,
# after deleting the existing SOA at the apex ("" relative to the origin).
# NOTE(review): this excerpt does not show the line binding `conf` (presumably
# self.config), the closing parenthesis of the SOA call, or the line that
# stores `soa` in the zone.
246 def update_soa(self) -> None:
248 soa = dns.rdtypes.ANY.SOA.SOA(
249 RdataClass.IN, RdataType.SOA,
250 mname=conf.origin_server,
# RNAME encodes the mailbox with '@' replaced by '.'. A local part that itself
# contains '.' would need escaping, hence the FIXME.
251 rname=conf.admin_email.replace('@', '.'), # FIXME: names with dots
252 serial=self.state.serial,
253 refresh=int(conf.refresh.total_seconds()),
254 retry=int(conf.retry.total_seconds()),
255 expire=int(conf.expire.total_seconds()),
256 minimum=int(conf.min_ttl.total_seconds()),
258 self.zone.delete_rdataset("", RdataType.SOA)
def n(self, name: str) -> NscNode:
    """Return a new NscNode helper for owner *name* in this zone.

    A new helper object is created on each call. NscNode's constructor finds
    the underlying dnspython node, creating it if needed.
    """
    node_helper = NscNode(self, name)
    return node_helper
def __getitem__(self, name: str) -> NscNode:
    """Support ``zone['www']``: build a fresh NscNode helper for *name*.

    Equivalent to calling ``n(name)`` on this zone.
    """
    helper = NscNode(self, name)
    return helper
# Shorthand: create node `name` and add A/AAAA records for the addresses in
# *args. Reverse mappings are requested unless reverse=False.
# NOTE(review): the `return n` implied by the `-> NscNode` annotation is not
# visible in this excerpt.
267 def host(self, name: str, *args, reverse: bool = True) -> NscNode:
268 n = NscNode(self, name)
269 n.A(*args, reverse=reverse)
# Write the zone as zone-file text to `file` (stdout by default).
# Each record becomes one tab-separated line: name, TTL, type, rdata.
# The TTL column is left empty when the TTL equals the zone's minimum TTL.
# NOTE(review): this excerpt does not show the initialisation of last_name or
# the lines computing print_name. Presumably a repeated owner name is blanked;
# only the comparison is visible.
# NOTE(review): `sys` is not among the imports visible in this excerpt.
272 def dump(self, file: Optional[TextIO] = None) -> None:
273 # Could use self.zone.to_file(sys.stdout), but we want better formatting
274 file = file or sys.stdout
275 file.write(f'; Zone file for {self.name}\n\n')
277 for name, ttl, rec in self.zone.iterate_rdatas():
278 if name == last_name:
282 file.write(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
    """Add a PTR record for IPv4 address *addr* pointing at *ptr_to*.

    The owner name is relative to this reverse zone. Octets covered by whole
    bytes of the zone's prefix are dropped, and the remaining octets are
    written in reverse order.
    """
    assert isinstance(self.reverse_for, IPv4Network)
    skip = self.reverse_for.prefixlen // 8
    octets = str(addr).split('.')[skip:]
    octets.reverse()
    self.n('.'.join(octets)).PTR(ptr_to)
def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
    """Add a PTR record for IPv6 address *addr* pointing at *ptr_to*.

    The owner name is relative to this reverse zone. Nibbles covered by the
    zone's prefix are dropped, and the remaining nibbles are written in reverse
    order, one label each.
    """
    assert isinstance(self.reverse_for, IPv6Network)
    all_nibbles = addr.exploded.replace(':', "")
    own_nibbles = all_nibbles[self.reverse_for.prefixlen // 4:]
    self.n('.'.join(own_nibbles[::-1])).PTR(ptr_to)
# Fingerprint the zone contents.
# Each record is hashed as one text line (owner, TTL, type, rdata), encoded as
# US-ASCII. The first 16 hex digits of the digest are stored in self.state.hash.
# NOTE(review): the line creating `sha` is not visible in this excerpt; the name
# suggests a hashlib SHA object. TODO confirm. Also confirm whether an SOA
# record is present at this point, since its serial would then feed into the hash.
299 def gen_hash(self) -> None:
301 for name, ttl, rec in self.zone.iterate_rdatas():
302 text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
303 sha.update(text.encode('us-ascii'))
304 self.state.hash = sha.hexdigest()[:16]
# Choose this run's SOA serial in YYYYMMDDnn form, based on nsc.start_time.
# If the content hash is unchanged and a previous serial exists (> 0), that
# serial is kept. Otherwise the new serial is either today's base + 1
# (presumably when the previous serial is below today's base) or
# previous + 1. Once the two-digit daily counter is exhausted
# (prev >= base + 99), a warning is printed and the serial spills into the
# next day's range.
# NOTE(review): the `else:` and the `if` line choosing between base + 1 and
# prev + 1 are not visible in this excerpt.
306 def gen_serial(self) -> None:
307 prev = self.prev_state.serial
308 if self.state.hash == self.prev_state.hash and prev > 0:
309 self.state.serial = self.prev_state.serial
311 base = int(self.nsc.start_time.strftime('%Y%m%d00'))
313 self.state.serial = base + 1
315 self.state.serial = prev + 1
316 if prev >= base + 99:
317 print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')
# Processing for a primary zone. Only the zone-type check is visible in this
# excerpt; the steps it guards are on omitted lines.
# NOTE(review): zone_type is always ZoneType.primary for this class (set in
# __init__), so the check looks redundant.
319 def process(self) -> None:
320 if self.zone_type == ZoneType.primary:
# Write the zone file without exposing a half-written file: the output goes to
# "<zone_file>.new", which is then moved over zone_file with Path.replace().
# NOTE(review): the statement inside the `with` (presumably a dump(f) call) is
# not visible in this excerpt.
324 def write_zone(self) -> None:
326 new_file = Path(str(self.zone_file) + '.new')
327 with open(new_file, 'w') as f:
329 new_file.replace(self.zone_file)
def write_state(self) -> None:
    """Persist this run's zone state to the zone's state file."""
    current = self.state
    current.save(self.state_file)
def is_changed(self) -> bool:
    """Tell whether this run's serial differs from the previous run's serial."""
    old_serial = self.prev_state.serial
    new_serial = self.state.serial
    return not (new_serial == old_serial)
class NscZoneSecondary(NscZone):
    """A zone this server carries as a secondary, following a primary server.

    Keyword Args:
        primary_server: address of the primary this zone follows. It is
            required. Previously the signature read ``primary_server=IPAddress``,
            which made the *type alias* the default value: omitting the argument
            silently stored a ``UnionType`` object instead of an address.
    """

    primary_server: IPAddress  # address of the primary server this zone follows
    secondary_file: Path       # local file for this zone, under nsc.secondary_dir

    def __init__(self, *args, primary_server: IPAddress, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.zone_type = ZoneType.secondary
        self.primary_server = primary_server
        # safe_name has '/' replaced, so classless reverse zones map to a valid file name.
        self.secondary_file = self.nsc.secondary_dir / self.safe_name
# NOTE(review): the header of this top-level configuration class is not visible
# in this excerpt. Zone code reaches it as `self.nsc`.
# All configured zones, keyed by zone name.
351 zones: Dict[str, NscZone]
# Config that new zones inherit unless another inherit_config is given.
352 default_zone_config: NscZoneConfig
# Reverse-mapping requests collected by _add_reverse_mapping():
# address -> names that the address's PTR records should point to.
353 ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
354 ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
359 daemon: 'NscDaemon' # Set by DaemonConfig class
# NOTE(review): this excerpt does not show the `def __init__(self,` line, the
# `**kwargs) -> None:` line, or several statements (e.g. initialising
# self.zones, and the guard before the fallback daemon below).
362 directory: str = '.',
363 daemon: Optional['NscDaemon'] = None,
# Naive local time of this run; gen_serial() takes the date part of SOA serials from it.
365 self.start_time = datetime.now()
# Remaining keyword arguments define the default config for zones.
367 self.default_zone_config = NscZoneConfig(**kwargs)
368 self.ipv4_reverse = defaultdict(list)
369 self.ipv6_reverse = defaultdict(list)
# Working directories under `directory`, created if missing.
371 self.root_dir = Path(directory)
372 self.state_dir = self.root_dir / 'state'
373 self.state_dir.mkdir(parents=True, exist_ok=True)
374 self.zone_dir = self.root_dir / 'zone'
375 self.zone_dir.mkdir(parents=True, exist_ok=True)
376 self.secondary_dir = self.root_dir / 'secondary'
377 self.secondary_dir.mkdir(parents=True, exist_ok=True)
# Fallback daemon object, presumably used when no daemon was passed; the guard
# is not visible. It is imported locally, perhaps to avoid an import cycle.
380 from nsconfig.daemon import NscDaemonNull
381 daemon = NscDaemonNull()
# NOTE(review): this excerpt does not show the `def ...(self,` line of this zone
# factory, or the lines that register and return the new zone.
386 name: Optional[str] = None,
387 reverse_for: str | IPNetwork | None = None,
388 follow_primary: str | IPAddress | None = None,
389 inherit_config: Optional[NscZoneConfig] = None,
391 if inherit_config is None:
392 inherit_config = self.default_zone_config
# A reverse zone may be given by its network alone; the zone name is then
# derived from it (see _reverse_zone_name). strict=True rejects networks with
# host bits set.
394 if reverse_for is not None:
395 if isinstance(reverse_for, str):
396 reverse_for = ip_network(reverse_for, strict=True)
397 name = name or self._reverse_zone_name(reverse_for)
# NOTE(review): these asserts guard caller input (a missing name, a duplicate
# zone) and disappear under `python -O`. Raising ValueError would be more robust.
398 assert name is not None
399 assert name not in self.zones
# Without follow_primary, the zone is a primary whose records are built here.
# With it, the zone is a secondary following that primary's address.
402 if follow_primary is None:
403 z = NscZonePrimary(self, name, reverse_for=reverse_for, inherit_config=inherit_config, **kwargs)
405 if isinstance(follow_primary, str):
406 follow_primary = ip_address(follow_primary)
407 z = NscZoneSecondary(self, name, reverse_for=reverse_for, primary_server=follow_primary, inherit_config=inherit_config, **kwargs)
def _reverse_zone_name(self, net: 'IPNetwork') -> str:
    """Return the reverse-DNS zone name for network *net*.

    IPv4 networks map to ``in-addr.arpa`` with one label per whole octet of
    the prefix, most significant octet last. A prefix that is not a multiple
    of 8 gets an extra ``<octet>/<prefixlen>`` label for the partial octet
    (RFC 2317 style), e.g. ``192.168.1.64/26`` ->
    ``64/26.1.168.192.in-addr.arpa``.

    IPv6 networks map to ``ip6.arpa`` with one label per nibble of the prefix.

    Raises:
        ValueError: the IPv6 prefix length is not a multiple of 4.
        NotImplementedError: *net* is neither an IPv4 nor an IPv6 network.
    """
    if isinstance(net, IPv4Network):
        parts = str(net.network_address).split('.')
        labels = parts[:net.prefixlen // 8]
        if net.prefixlen % 8 != 0:
            labels.append(parts[len(labels)] + '/' + str(net.prefixlen))
        suffix = ['in-addr', 'arpa']
    elif isinstance(net, IPv6Network):
        # Raise instead of assert: this validates user configuration and must
        # still fire under `python -O`.
        if net.prefixlen % 4 != 0:
            raise ValueError(f'IPv6 reverse zone prefix length must be a multiple of 4: {net}')
        nibbles = net.network_address.exploded.replace(':', "")
        labels = list(nibbles[:net.prefixlen // 4])
        suffix = ['ip6', 'arpa']
    else:
        raise NotImplementedError()
    # Joining the suffix as list items (instead of concatenating '.in-addr.arpa')
    # keeps a /0 network from yielding a name with a leading dot.
    return '.'.join([*reversed(labels), *suffix])
# Record that `addr` should get a PTR record pointing at `ptr_to`.
# Requests are kept per address family; fill_reverse() turns them into records.
# NOTE(review): the `else:` that sends IPv6 addresses to ipv6_reverse is not
# visible in this excerpt.
427 def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
428 if isinstance(addr, IPv4Address):
429 self.ipv4_reverse[addr].append(ptr_to)
431 self.ipv6_reverse[addr].append(ptr_to)
def dump_reverse(self) -> None:
    """Print every collected reverse-mapping request to stdout.

    Output is a header line, then one line per address (IPv4 first, each
    family sorted by address): ``<address>\\t<name>, <name>, ...``. The names
    are those registered for the address via ``_add_reverse_mapping``.
    """
    print('### Requests for reverse mappings ###')
    for table in (self.ipv4_reverse, self.ipv6_reverse):
        for addr, names in sorted(table.items()):
            # The values are lists of names. Printing a list directly shows each
            # element's repr rather than its text form, hence the str() join.
            print(f'{addr}\t{", ".join(str(n) for n in names)}')
# Turn the collected reverse-mapping requests into PTR records.
# For every primary zone with a reverse_for network, each requested address
# inside that network gets one PTR per requesting name. Addresses not covered
# by any such zone are silently skipped.
# NOTE(review): the `else:` introducing the IPv6 branch is not visible in this excerpt.
440 def fill_reverse(self) -> None:
441 for z in self.zones.values():
442 if isinstance(z, NscZonePrimary) and z.reverse_for is not None:
443 if isinstance(z.reverse_for, IPv4Network):
444 for addr4, ptr_list in self.ipv4_reverse.items():
445 if addr4 in z.reverse_for:
446 for ptr_to in ptr_list:
447 z._add_ipv4_reverse(addr4, ptr_to)
449 for addr6, ptr_list in self.ipv6_reverse.items():
450 if addr6 in z.reverse_for:
451 for ptr_to in ptr_list:
452 z._add_ipv6_reverse(addr6, ptr_to)
def get_zones(self) -> List[NscZone]:
    """Return every configured zone, ordered by zone name."""
    by_name = sorted(self.zones.items(), key=lambda item: item[0])
    return [zone for _zone_name, zone in by_name]
# Process all zones in name order (via get_zones()).
# The rest of this method lies beyond the end of this excerpt.
457 def process(self) -> None:
459 for z in self.get_zones():