1 from collections import defaultdict
2 from datetime import datetime, timedelta
4 from dns.name import Name
5 from dns.node import Node
6 from dns.rdata import Rdata
7 from dns.rdataclass import RdataClass
8 from dns.rdatatype import RdataType
9 import dns.rdtypes.ANY.CNAME
10 import dns.rdtypes.ANY.MX
11 import dns.rdtypes.ANY.NS
12 import dns.rdtypes.ANY.PTR
13 import dns.rdtypes.ANY.SOA
14 import dns.rdtypes.ANY.TXT
15 import dns.rdtypes.IN.A
16 import dns.rdtypes.IN.AAAA
17 from dns.zone import Zone
18 from enum import Enum, auto
20 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
22 from pathlib import Path
25 from typing import Optional, Dict, List, Self, DefaultDict, TextIO, TYPE_CHECKING
27 from nsconfig.util import flatten_list, parse_address, parse_network, parse_name, parse_duration
28 from nsconfig.util import IPAddress, IPNetwork, IPAddr
32 from nsconfig.daemon import NscDaemon
36 nsc_zone: 'NscZonePrimary'
41 def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None:
42 self.nsc_zone = nsc_zone
44 self.node = nsc_zone.zone.find_node(name, create=True)
45 self._ttl = nsc_zone.config.default_ttl
47 def ttl(self, seconds: Optional[int] = None, **kwargs) -> Self:
48 if seconds is not None:
51 self._ttl = parse_duration(timedelta(**kwargs))
53 self._ttl = self.nsc_zone.config.default_ttl
56 def _add(self, rec: Rdata) -> None:
57 rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
58 rds.add(rec, ttl=self._ttl)
60 def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
61 for a in map(parse_address, flatten_list(addrs)):
62 if isinstance(a, IPv4Address):
63 self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
65 self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
67 self.nsc_zone.nsc._add_reverse_mapping(a, parse_name(self.name, origin=self.nsc_zone.dns_name))
70 def MX(self, pri: int, name: str) -> Self:
72 dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, parse_name(name))
def NS(self, *names: str | List[str]) -> Self:
    """Add one NS record for every name in ``names``.

    The arguments are passed through ``flatten_list`` and each item is
    converted with ``parse_name`` before the record is built.
    Returns the node itself so that calls can be chained.
    """
    for raw_name in flatten_list(names):
        server = parse_name(raw_name)
        self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, server))
    return self
def TXT(self, *text: str | List[str]) -> Self:
    """Add one TXT record for every item of ``text``.

    The arguments are passed through ``flatten_list`` first.
    Returns the node itself so that calls can be chained.
    """
    make_txt = dns.rdtypes.ANY.TXT.TXT
    for chunk in flatten_list(text):
        self._add(make_txt(RdataClass.IN, RdataType.TXT, chunk))
    return self
def PTR(self, target: Name | str) -> Self:
    """Add a PTR record pointing at ``target``.

    Returns the node itself so that calls can be chained.
    """
    record = dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target)
    self._add(record)
    return self
def CNAME(self, target: Name | str) -> Self:
    """Add a CNAME record pointing at ``target``.

    Returns the node itself so that calls can be chained.
    """
    record = dns.rdtypes.ANY.CNAME.CNAME(RdataClass.IN, RdataType.CNAME, target)
    self._add(record)
    return self
def generic(self, typ: str, text: str) -> Self:
    """Add a record of arbitrary type ``typ`` parsed from its text form.

    ``text`` is handed to ``dns.rdata.from_text`` together with class IN.
    Returns the node itself so that calls can be chained.
    """
    record = dns.rdata.from_text(RdataClass.IN, typ, text)
    self._add(record)
    return self
107 daemon_options: List[str]
110 default_config: Optional['NscZoneConfig'] = None
def __init__(self,
             admin_email: Optional[str] = None,
             refresh: Optional[int | timedelta] = None,
             retry: Optional[int | timedelta] = None,
             expire: Optional[int | timedelta] = None,
             min_ttl: Optional[int | timedelta] = None,
             default_ttl: Optional[int | timedelta] = None,
             origin_server: Optional[str] = None,
             daemon_options: Optional[List[str]] = None,
             add_daemon_options: Optional[List[str]] = None,
             add_null_mx: Optional[bool] = None,
             inherit_config: Optional['NscZoneConfig'] = None,
             ) -> None:
    """Build a zone configuration, inheriting every setting left as ``None``.

    Settings given explicitly win; durations are normalised with
    ``parse_duration``. Anything left as ``None`` is copied from
    ``inherit_config``, which itself defaults to
    ``NscZoneConfig.default_config``.

    ``add_daemon_options`` is appended to the effective ``daemon_options``
    (explicit or inherited) without modifying the list it was taken from.
    """
    if inherit_config is None:
        inherit_config = NscZoneConfig.default_config or self  # to satisfy the type checker
    self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
    self.refresh = parse_duration(refresh) if refresh is not None else inherit_config.refresh
    self.retry = parse_duration(retry) if retry is not None else inherit_config.retry
    self.expire = parse_duration(expire) if expire is not None else inherit_config.expire
    self.min_ttl = parse_duration(min_ttl) if min_ttl is not None else inherit_config.min_ttl
    self.default_ttl = parse_duration(default_ttl) if default_ttl is not None else inherit_config.default_ttl
    self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
    self.daemon_options = daemon_options if daemon_options is not None else inherit_config.daemon_options
    self.add_null_mx = add_null_mx if add_null_mx is not None else inherit_config.add_null_mx
    if add_daemon_options is not None:
        # Build a new list instead of using `+=`: at this point
        # self.daemon_options is the very same list object as the inherited
        # (or caller-supplied) one, so extending it in place would leak the
        # extra options into every other config sharing that list.
        self.daemon_options = self.daemon_options + add_daemon_options
def finalize(self) -> Self:
    """Fill in settings that are still empty and return the config itself.

    * an empty ``origin_server`` becomes this host's FQDN,
    * an empty ``admin_email`` becomes ``hostmaster@<origin_server>``,
    * a ``default_ttl`` of 0 is replaced by ``min_ttl``.
    """
    self.origin_server = self.origin_server or socket.getfqdn()
    self.admin_email = self.admin_email or f'hostmaster@{self.origin_server}'
    if self.default_ttl == 0:
        self.default_ttl = self.min_ttl
    return self
149 NscZoneConfig.default_config = NscZoneConfig(
151 refresh=timedelta(hours=8),
152 retry=timedelta(hours=2),
153 expire=timedelta(days=14),
154 min_ttl=timedelta(days=1),
166 def __init__(self) -> None:
170 def load(self, file: Path) -> None:
172 with open(file) as f:
174 assert isinstance(js, dict)
176 self.serial = js['serial']
178 self.hash = js['hash']
179 except FileNotFoundError:
182 def save(self, file: Path) -> None:
183 new_file = Path(str(file) + '.new')
184 with open(new_file, 'w') as f:
186 'serial': self.serial,
189 json.dump(js, f, indent=4, sort_keys=True)
190 new_file.replace(file)
193 class ZoneType(Enum):
203 safe_name: str # For use in file names
205 reverse_for: Optional[IPNetwork]
210 reverse_for: Optional[IPNetwork],
214 self.dns_name = dns.name.from_text(name)
215 self.safe_name = name.replace('/', '@')
216 self.config = NscZoneConfig(**kwargs).finalize()
217 self.reverse_for = reverse_for
219 def process(self) -> None:
222 def is_changed(self) -> bool:
226 class NscZonePrimary(NscZone):
231 prev_state: NscZoneState
232 aliases: List['NscZoneAlias']
234 def __init__(self, *args, **kwargs) -> None:
235 super().__init__(*args, **kwargs)
237 self.zone_type = ZoneType.primary
238 self.zone_file = self.nsc.zone_dir / self.safe_name
239 self.state_file = self.nsc.state_dir / (self.safe_name + '.json')
241 self.state = NscZoneState()
242 self.prev_state = NscZoneState()
243 self.prev_state.load(self.state_file)
247 self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN)
250 def update_soa(self) -> None:
252 soa = dns.rdtypes.ANY.SOA.SOA(
253 RdataClass.IN, RdataType.SOA,
254 mname=conf.origin_server,
255 rname=conf.admin_email.replace('@', '.'), # FIXME: names with dots
256 serial=self.state.serial,
257 refresh=conf.refresh,
260 minimum=conf.min_ttl,
262 self.zone.delete_rdataset("", RdataType.SOA)
def n(self, name: str) -> NscNode:
    """Return a record builder for the node ``name`` of this zone."""
    node = NscNode(self, name)
    return node
def __getitem__(self, name: str) -> NscNode:
    """Subscript shorthand: ``zone[name]`` gives the same result as ``zone.n(name)``."""
    return self.n(name)
def host(self, name: str, *args, reverse: bool = True) -> NscNode:
    """Create the node ``name`` and give it address records in one step.

    All positional arguments are forwarded to ``NscNode.A`` together with
    the ``reverse`` flag. The node is returned so that further records
    can be added to it.
    """
    node = self.n(name)
    node.A(*args, reverse=reverse)
    return node
276 def zone_header(self) -> str:
278 f'; Zone file for {self.name}\n'
279 + '; Generated by NSC, please do not edit manually.\n'
282 def dump(self, file: Optional[TextIO] = None) -> None:
283 # Could use self.zone.to_file(sys.stdout), but we want better formatting
284 file = file or sys.stdout
285 file.write(self.zone_header())
286 file.write(f'$TTL\t\t{self.config.default_ttl}\n\n')
288 for name, ttl, rec in self.zone.iterate_rdatas():
289 if name == last_name:
293 file.write(f'{print_name}\t{ttl if ttl != self.config.default_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
296 def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
297 assert isinstance(self.reverse_for, IPv4Network)
298 parts = str(addr).split('.')
299 parts = parts[self.reverse_for.prefixlen // 8:]
300 name = '.'.join(reversed(parts))
301 self.n(name).PTR(ptr_to)
303 def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
304 assert isinstance(self.reverse_for, IPv6Network)
305 parts = addr.exploded.replace(':', "")
306 parts = parts[self.reverse_for.prefixlen // 4:]
307 name = '.'.join(reversed(parts))
308 self.n(name).PTR(ptr_to)
310 def gen_hash(self) -> None:
312 sha.update(self.zone_header().encode('us-ascii'))
313 for name, ttl, rec in self.zone.iterate_rdatas():
314 text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
315 sha.update(text.encode('us-ascii'))
316 self.state.hash = sha.hexdigest()[:16]
318 def gen_serial(self) -> None:
319 prev = self.prev_state.serial
320 if self.state.hash == self.prev_state.hash and prev > 0:
321 self.state.serial = self.prev_state.serial
323 base = int(self.nsc.start_time.strftime('%Y%m%d00'))
325 self.state.serial = base + 1
327 self.state.serial = prev + 1
328 if prev >= base + 99:
329 print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')
331 def process(self) -> None:
332 if self.config.add_null_mx:
337 def write_zone(self) -> None:
339 new_file = Path(str(self.zone_file) + '.new')
340 with open(new_file, 'w') as f:
342 new_file.replace(self.zone_file)
def write_state(self) -> None:
    """Persist the zone's current state into its state file."""
    current = self.state
    current.save(self.state_file)
def is_changed(self) -> bool:
    """Tell whether the serial number differs from the previously saved one."""
    current_serial = self.state.serial
    previous_serial = self.prev_state.serial
    return current_serial != previous_serial
def delegate_classless(self, net: str | IPNetwork, subdomain: Optional[str] = None) -> NscNode:
    """Delegate a sub-octet IPv4 range of this reverse zone via CNAMEs.

    For every address of ``net`` a CNAME ``<octet>`` -> ``<octet>.<subdomain>``
    is created (the scheme described in RFC 2317). When ``subdomain`` is not
    given, it defaults to ``<first octet>/<prefix length>``.

    The zone must be an IPv4 reverse zone on an octet boundary and ``net``
    must be an IPv4 subnet of it that is less than 8 bits longer; these
    conditions are checked with assertions.

    Returns the node of ``subdomain`` (typically used to add NS records).
    """
    net = parse_network(net)
    parent = self.reverse_for
    assert parent is not None
    assert isinstance(parent, IPv4Network)
    assert parent.prefixlen % 8 == 0
    assert isinstance(net, IPv4Network)
    assert net.subnet_of(parent)
    assert net.prefixlen < parent.prefixlen + 8

    # Octet in which the delegated range lives, and how many of its bits
    # are fixed by the prefix.
    octet_index, fixed_bits = divmod(net.prefixlen, 8)
    first = int(net.network_address.packed[octet_index])
    count = 1 << (8 - fixed_bits)

    if subdomain is None:
        subdomain = f'{first}/{net.prefixlen}'

    for octet in range(first, first + count):
        alias_target = parse_name(f'{octet}.{subdomain}', relative=True)
        self[str(octet)].CNAME(alias_target)

    return self[subdomain]
371 def gen_null_mx(self) -> None:
372 for name, node in self.zone.items():
373 rds_a = node.get_rdataset(RdataClass.IN, RdataType.A)
374 rds_aaaa = node.get_rdataset(RdataClass.IN, RdataType.AAAA)
375 if rds_a or rds_aaaa:
376 mx_rds = node.get_rdataset(RdataClass.IN, RdataType.MX, create=True)
379 dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, 0, dns.name.root),
380 ttl=self.config.default_ttl,
384 class NscZoneSecondary(NscZone):
385 primary_server: IPAddress
def __init__(self, *args, primary_server: IPAddress, **kwargs) -> None:
    """Set up a secondary zone that follows ``primary_server``.

    ``primary_server`` is a required keyword argument. (It used to be
    written as ``primary_server=IPAddress``, which made the *type* the
    default value, so omitting the argument silently stored the type
    object instead of an address.)

    All other arguments are forwarded to the base zone constructor.
    """
    super().__init__(*args, **kwargs)
    self.zone_type = ZoneType.secondary
    self.primary_server = primary_server
    self.secondary_file = self.nsc.secondary_dir / self.safe_name
class NscZoneAlias(NscZone):
    """A zone that shares the zone file of another, primary zone."""

    # The primary zone whose zone file and change status this alias reuses.
    alias_for: 'NscZonePrimary'

    def __init__(self, *args, alias_for: 'NscZonePrimary', **kwargs) -> None:
        """Create an alias of ``alias_for`` and register it with that zone.

        ``alias_for`` is a required keyword argument. (It used to be
        written as ``alias_for=NscZonePrimary``, which made the class
        itself the default value rather than annotating the parameter.)

        All other arguments are forwarded to the base zone constructor.
        """
        assert isinstance(alias_for, NscZonePrimary)
        super().__init__(*args, **kwargs)
        self.zone_type = ZoneType.alias
        self.alias_for = alias_for
        self.zone_file = alias_for.zone_file
        alias_for.aliases.append(self)

    def is_changed(self) -> bool:
        """An alias is changed exactly when the zone it aliases is changed."""
        return self.alias_for.is_changed()
412 zones: Dict[str, NscZone]
413 default_zone_config: NscZoneConfig
414 ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
415 ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
420 daemon: 'NscDaemon' # Set by DaemonConfig class
423 directory: str = '.',
424 daemon: Optional['NscDaemon'] = None,
426 self.start_time = datetime.now()
428 self.default_zone_config = NscZoneConfig(**kwargs)
429 self.ipv4_reverse = defaultdict(list)
430 self.ipv6_reverse = defaultdict(list)
432 self.root_dir = Path(directory)
433 self.state_dir = self.root_dir / 'state'
434 self.state_dir.mkdir(parents=True, exist_ok=True)
435 self.zone_dir = self.root_dir / 'zone'
436 self.zone_dir.mkdir(parents=True, exist_ok=True)
437 self.secondary_dir = self.root_dir / 'secondary'
438 self.secondary_dir.mkdir(parents=True, exist_ok=True)
441 from nsconfig.daemon import NscDaemonNull
442 daemon = NscDaemonNull()
447 name: Optional[str] = None,
448 reverse_for: str | IPNetwork | None = None,
449 alias_for: Optional[NscZonePrimary] = None,
450 follow_primary: str | IPAddress | None = None,
451 inherit_config: Optional[NscZoneConfig] = None,
453 if inherit_config is None:
454 inherit_config = self.default_zone_config
456 if reverse_for is not None:
457 if isinstance(reverse_for, str):
458 reverse_for = ip_network(reverse_for, strict=True)
459 name = name or self._reverse_zone_name(reverse_for)
460 assert name is not None
461 assert name not in self.zones
464 if alias_for is not None:
465 assert follow_primary is None
466 z = NscZoneAlias(self, name, reverse_for=reverse_for, alias_for=alias_for, inherit_config=inherit_config, **kwargs)
467 elif follow_primary is None:
468 z = NscZonePrimary(self, name, reverse_for=reverse_for, inherit_config=inherit_config, **kwargs)
470 if isinstance(follow_primary, str):
471 follow_primary = ip_address(follow_primary)
472 z = NscZoneSecondary(self, name, reverse_for=reverse_for, primary_server=follow_primary, inherit_config=inherit_config, **kwargs)
477 def __getitem__(self, name: str) -> NscZone:
478 return self.zones[name]
480 def _reverse_zone_name(self, net: IPNetwork) -> str:
481 if isinstance(net, IPv4Network):
482 parts = str(net.network_address).split('.')
483 out = parts[:net.prefixlen // 8]
484 if net.prefixlen % 8 != 0:
485 out.append(parts[len(out)] + '/' + str(net.prefixlen))
486 return '.'.join(reversed(out)) + '.in-addr.arpa'
487 elif isinstance(net, IPv6Network):
488 assert net.prefixlen % 4 == 0
489 nibbles = net.network_address.exploded.replace(':', "")
490 nibbles = nibbles[:net.prefixlen // 4]
491 return '.'.join(reversed(nibbles)) + '.ip6.arpa'
493 raise NotImplementedError()
495 def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
496 if isinstance(addr, IPv4Address):
497 self.ipv4_reverse[addr].append(ptr_to)
499 self.ipv6_reverse[addr].append(ptr_to)
def dump_reverse(self) -> None:
    """Print all requested reverse mappings, IPv4 first, sorted by address.

    Each line shows the address and the list of names requested for it.
    """
    print('### Requests for reverse mappings ###')
    for pending in (self.ipv4_reverse, self.ipv6_reverse):
        for address, names in sorted(pending.items()):
            print(f'{address}\t{names}')
def fill_reverse(self) -> None:
    """Distribute the collected reverse-mapping requests into reverse zones.

    Every primary zone that has ``reverse_for`` set receives a PTR record
    for each requested address of the matching family that lies inside
    its network.
    """
    for zone in self.zones.values():
        if not isinstance(zone, NscZonePrimary) or zone.reverse_for is None:
            continue
        # Pick the request table and the zone method for this address family.
        if isinstance(zone.reverse_for, IPv4Network):
            pending, add_ptr = self.ipv4_reverse, zone._add_ipv4_reverse
        else:
            pending, add_ptr = self.ipv6_reverse, zone._add_ipv6_reverse
        for address, targets in pending.items():
            if address in zone.reverse_for:
                for target in targets:
                    add_ptr(address, target)
def get_zones(self) -> List[NscZone]:
    """Return all registered zones, ordered by zone name."""
    ordered_names = sorted(self.zones)
    return [self.zones[zone_name] for zone_name in ordered_names]
525 def process(self) -> None:
527 for z in self.get_zones():