1 # PyNSC: Main data structures
2 # (c) 2024 Martin Mareš <mj@ucw.cz>
4 from collections import defaultdict
5 from datetime import datetime, timedelta
7 from dns.name import Name
8 from dns.node import Node
9 from dns.rdata import Rdata
10 from dns.rdataclass import RdataClass
11 from dns.rdatatype import RdataType
12 import dns.rdtypes.ANY.CNAME
13 import dns.rdtypes.ANY.MX
14 import dns.rdtypes.ANY.NS
15 import dns.rdtypes.ANY.PTR
16 import dns.rdtypes.ANY.SOA
17 import dns.rdtypes.ANY.TXT
18 import dns.rdtypes.IN.A
19 import dns.rdtypes.IN.AAAA
20 from dns.zone import Zone
21 from enum import Enum, auto
23 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
25 from pathlib import Path
28 from typing import Optional, Dict, List, Self, DefaultDict, TextIO, TYPE_CHECKING
30 from nsconfig.util import flatten_list, parse_address, parse_network, parse_name, parse_duration
31 from nsconfig.util import IPAddress, IPNetwork, IPAddr
35 from nsconfig.daemon import NscDaemon
39 nsc_zone: 'NscZonePrimary'
44 def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None:
45 self.nsc_zone = nsc_zone
47 self.node = nsc_zone.zone.find_node(name, create=True)
48 self._ttl = nsc_zone.config.default_ttl
50 def ttl(self, seconds: Optional[int] = None, **kwargs) -> Self:
51 if seconds is not None:
54 self._ttl = parse_duration(timedelta(**kwargs))
56 self._ttl = self.nsc_zone.config.default_ttl
def _add(self, rec: Rdata) -> None:
    """Store *rec* on this node.

    The record goes into the node's rdataset of the same class and type
    (created on demand) using the TTL currently selected for this builder.
    """
    self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True).add(rec, ttl=self._ttl)
63 def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
64 for a in map(parse_address, flatten_list(addrs)):
65 if isinstance(a, IPv4Address):
66 self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
68 self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
70 self.nsc_zone.nsc._add_reverse_mapping(a, parse_name(self.name, origin=self.nsc_zone.dns_name))
73 def MX(self, pri: int, name: str) -> Self:
75 dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, parse_name(name))
79 def NS(self, *names: str | List[str]) -> Self:
80 for name in map(parse_name, flatten_list(names)):
81 self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
84 def TXT(self, *text: str | List[str]) -> Self:
85 for txt in flatten_list(text):
86 self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, txt))
89 def PTR(self, target: Name | str) -> Self:
90 self._add(dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target))
93 def CNAME(self, target: Name | str) -> Self:
94 self._add(dns.rdtypes.ANY.CNAME.CNAME(RdataClass.IN, RdataType.CNAME, target))
97 def generic(self, typ: str, text: str) -> Self:
98 self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
110 daemon_options: List[str]
113 default_config: Optional['NscZoneConfig'] = None
116 admin_email: Optional[str] = None,
117 refresh: Optional[int | timedelta] = None,
118 retry: Optional[int | timedelta] = None,
119 expire: Optional[int | timedelta] = None,
120 min_ttl: Optional[int | timedelta] = None,
121 default_ttl: Optional[int | timedelta] = None,
122 origin_server: Optional[str] = None,
123 daemon_options: Optional[List[str]] = None,
124 add_daemon_options: Optional[List[str]] = None,
125 add_null_mx: Optional[bool] = None,
126 inherit_config: Optional['NscZoneConfig'] = None,
128 if inherit_config is None:
129 inherit_config = NscZoneConfig.default_config or self # to satisfy the type checker
130 self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
131 self.refresh = parse_duration(refresh) if refresh is not None else inherit_config.refresh
132 self.retry = parse_duration(retry) if retry is not None else inherit_config.retry
133 self.expire = parse_duration(expire) if expire is not None else inherit_config.expire
134 self.min_ttl = parse_duration(min_ttl) if min_ttl is not None else inherit_config.min_ttl
135 self.default_ttl = parse_duration(default_ttl) if default_ttl is not None else inherit_config.default_ttl
136 self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
137 self.daemon_options = daemon_options if daemon_options is not None else inherit_config.daemon_options
138 self.add_null_mx = add_null_mx if add_null_mx is not None else inherit_config.add_null_mx
139 if add_daemon_options is not None:
140 self.daemon_options += add_daemon_options
142 def finalize(self) -> Self:
143 if not self.origin_server:
144 self.origin_server = socket.getfqdn()
145 if not self.admin_email:
146 self.admin_email = f'hostmaster@{self.origin_server}'
147 if self.default_ttl == 0:
148 self.default_ttl = self.min_ttl
152 NscZoneConfig.default_config = NscZoneConfig(
154 refresh=timedelta(hours=8),
155 retry=timedelta(hours=2),
156 expire=timedelta(days=14),
157 min_ttl=timedelta(days=1),
169 def __init__(self) -> None:
173 def load(self, file: Path) -> None:
175 with open(file) as f:
177 assert isinstance(js, dict)
179 self.serial = js['serial']
181 self.hash = js['hash']
182 except FileNotFoundError:
185 def save(self, file: Path) -> None:
186 new_file = Path(str(file) + '.new')
187 with open(new_file, 'w') as f:
189 'serial': self.serial,
192 json.dump(js, f, indent=4, sort_keys=True)
193 new_file.replace(file)
196 class ZoneType(Enum):
206 safe_name: str # For use in file names
208 reverse_for: Optional[IPNetwork]
213 reverse_for: Optional[IPNetwork],
217 self.dns_name = dns.name.from_text(name)
218 self.safe_name = name.replace('/', '@')
219 self.config = NscZoneConfig(**kwargs).finalize()
220 self.reverse_for = reverse_for
222 def process(self) -> None:
225 def is_changed(self) -> bool:
229 class NscZonePrimary(NscZone):
234 prev_state: NscZoneState
235 aliases: List['NscZoneAlias']
237 def __init__(self, *args, **kwargs) -> None:
238 super().__init__(*args, **kwargs)
240 self.zone_type = ZoneType.primary
241 self.zone_file = self.nsc.zone_dir / self.safe_name
242 self.state_file = self.nsc.state_dir / (self.safe_name + '.json')
244 self.state = NscZoneState()
245 self.prev_state = NscZoneState()
246 self.prev_state.load(self.state_file)
250 self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN)
253 def update_soa(self) -> None:
255 soa = dns.rdtypes.ANY.SOA.SOA(
256 RdataClass.IN, RdataType.SOA,
257 mname=conf.origin_server,
258 rname=conf.admin_email.replace('@', '.'), # FIXME: names with dots
259 serial=self.state.serial,
260 refresh=conf.refresh,
263 minimum=conf.min_ttl,
265 self.zone.delete_rdataset("", RdataType.SOA)
def n(self, name: str) -> NscNode:
    """Return an NscNode record builder for *name* in this zone.

    NscNode's constructor looks the node up in this zone, creating it if needed.
    """
    return NscNode(nsc_zone=self, name=name)
def __getitem__(self, name: str) -> NscNode:
    """Subscript shorthand: ``zone[name]`` builds the same NscNode as ``zone.n(name)``."""
    return self.n(name)
274 def host(self, name: str, *args, reverse: bool = True) -> NscNode:
275 n = NscNode(self, name)
276 n.A(*args, reverse=reverse)
279 def zone_header(self) -> str:
281 f'; Zone file for {self.name}\n'
282 + '; Generated by NSC, please do not edit manually.\n'
285 def dump(self, file: Optional[TextIO] = None) -> None:
286 # Could use self.zone.to_file(sys.stdout), but we want better formatting
287 file = file or sys.stdout
288 file.write(self.zone_header())
289 file.write(f'$TTL\t\t{self.config.default_ttl}\n\n')
291 for name, ttl, rec in self.zone.iterate_rdatas():
292 if name == last_name:
296 file.write(f'{print_name}\t{ttl if ttl != self.config.default_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
    """Add a PTR record for IPv4 address *addr* pointing to *ptr_to*.

    The owner name is made from the octets of *addr* that are not already
    covered by whole octets of this zone's prefix, in reverse order.
    For example, 192.0.2.5 inside a /24 reverse zone becomes ``5``.
    """
    assert isinstance(self.reverse_for, IPv4Network)
    covered_octets = self.reverse_for.prefixlen // 8
    host_octets = str(addr).split('.')[covered_octets:]
    self.n('.'.join(host_octets[::-1])).PTR(ptr_to)
def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
    """Add a PTR record for IPv6 address *addr* pointing to *ptr_to*.

    The owner name consists of the hex nibbles of *addr* below this zone's
    prefix (whole nibbles only). They are listed least significant first and
    separated by dots, matching the ``ip6.arpa`` naming used for the zone.
    """
    assert isinstance(self.reverse_for, IPv6Network)
    covered_nibbles = self.reverse_for.prefixlen // 4
    host_nibbles = addr.exploded.replace(':', "")[covered_nibbles:]
    self.n('.'.join(host_nibbles[::-1])).PTR(ptr_to)
313 def gen_hash(self) -> None:
315 sha.update(self.zone_header().encode('us-ascii'))
316 for name, ttl, rec in self.zone.iterate_rdatas():
317 text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
318 sha.update(text.encode('us-ascii'))
319 self.state.hash = sha.hexdigest()[:16]
321 def gen_serial(self) -> None:
322 prev = self.prev_state.serial
323 if self.state.hash == self.prev_state.hash and prev > 0:
324 self.state.serial = self.prev_state.serial
326 base = int(self.nsc.start_time.strftime('%Y%m%d00'))
328 self.state.serial = base + 1
330 self.state.serial = prev + 1
331 if prev >= base + 99:
332 print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')
334 def process(self) -> None:
335 if self.config.add_null_mx:
340 def write_zone(self) -> None:
342 new_file = Path(str(self.zone_file) + '.new')
343 with open(new_file, 'w') as f:
345 new_file.replace(self.zone_file)
def write_state(self) -> None:
    """Persist this run's zone state to the zone's state file."""
    state_file = self.state_file
    self.state.save(state_file)
def is_changed(self) -> bool:
    """Report whether this run's serial differs from the one recorded previously."""
    previous_serial = self.prev_state.serial
    return not (self.state.serial == previous_serial)
def delegate_classless(self, net: str | IPNetwork, subdomain: Optional[str] = None) -> NscNode:
    """Delegate part of this IPv4 reverse zone using classless CNAMEs.

    For every value of the octet where *net* begins, a CNAME
    ``<value>`` -> ``<value>.<subdomain>`` is added to this zone.
    *subdomain* defaults to ``<first value>/<prefixlen of net>``. The node for
    *subdomain* is returned, presumably so the caller can attach NS records
    for the delegated zone.

    NOTE(review): the CNAME owners are single labels directly below this
    zone, so they only alias individual addresses when this zone is a /24.
    For shorter zones the assertions below still pass; confirm that is
    intended.
    """
    net = parse_network(net)
    zone_net = self.reverse_for
    assert zone_net is not None
    assert isinstance(zone_net, IPv4Network)
    assert zone_net.prefixlen % 8 == 0
    assert isinstance(net, IPv4Network)
    assert net.subnet_of(zone_net)
    assert net.prefixlen < zone_net.prefixlen + 8

    # Indexing bytes already yields an int: the first value of the delegated octet
    first = net.network_address.packed[net.prefixlen // 8]
    count = 1 << (8 - net.prefixlen % 8)

    if subdomain is None:
        subdomain = f'{first}/{net.prefixlen}'

    for octet in range(first, first + count):
        self[str(octet)].CNAME(parse_name(f'{octet}.{subdomain}', relative=True))

    return self[subdomain]
374 def gen_null_mx(self) -> None:
375 for name, node in self.zone.items():
376 rds_a = node.get_rdataset(RdataClass.IN, RdataType.A)
377 rds_aaaa = node.get_rdataset(RdataClass.IN, RdataType.AAAA)
378 if rds_a or rds_aaaa:
379 mx_rds = node.get_rdataset(RdataClass.IN, RdataType.MX, create=True)
382 dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, 0, dns.name.root),
383 ttl=self.config.default_ttl,
387 class NscZoneSecondary(NscZone):
388 primary_server: IPAddress
def __init__(self, *args, primary_server: IPAddress, **kwargs) -> None:
    """Create a secondary zone whose data is transferred from *primary_server*.

    All other arguments are forwarded unchanged to the base-class constructor.
    The transferred data is kept in a file named after the zone's safe name
    inside the configuration's secondary directory.
    """
    # primary_server is a required keyword-only argument. It was previously
    # written as ``primary_server=IPAddress``, which made the type alias the
    # default value instead of annotating the parameter.
    super().__init__(*args, **kwargs)
    self.zone_type = ZoneType.secondary
    self.primary_server = primary_server
    self.secondary_file = self.nsc.secondary_dir / self.safe_name
class NscZoneAlias(NscZone):
    """A zone that serves the same contents as an existing primary zone.

    The alias reuses the primary's zone file, registers itself in the
    primary's ``aliases`` list, and counts as changed whenever the primary does.
    """
    alias_for: NscZonePrimary

    def __init__(self, *args, alias_for: NscZonePrimary, **kwargs) -> None:
        # alias_for is a required keyword-only argument. It was previously
        # written as ``alias_for=NscZonePrimary``, which made the class object
        # itself the default value instead of annotating the parameter.
        assert isinstance(alias_for, NscZonePrimary)
        super().__init__(*args, **kwargs)
        self.zone_type = ZoneType.alias
        self.alias_for = alias_for
        self.zone_file = alias_for.zone_file
        alias_for.aliases.append(self)

    def is_changed(self) -> bool:
        """An alias changes exactly when the zone it aliases changes."""
        return self.alias_for.is_changed()
415 zones: Dict[str, NscZone]
416 default_zone_config: NscZoneConfig
417 ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
418 ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
423 daemon: 'NscDaemon' # Set by DaemonConfig class
426 directory: str = '.',
427 daemon: Optional['NscDaemon'] = None,
429 self.start_time = datetime.now()
431 self.default_zone_config = NscZoneConfig(**kwargs)
432 self.ipv4_reverse = defaultdict(list)
433 self.ipv6_reverse = defaultdict(list)
435 self.root_dir = Path(directory)
436 self.state_dir = self.root_dir / 'state'
437 self.state_dir.mkdir(parents=True, exist_ok=True)
438 self.zone_dir = self.root_dir / 'zone'
439 self.zone_dir.mkdir(parents=True, exist_ok=True)
440 self.secondary_dir = self.root_dir / 'secondary'
441 self.secondary_dir.mkdir(parents=True, exist_ok=True)
444 from nsconfig.daemon import NscDaemonNull
445 daemon = NscDaemonNull()
450 name: Optional[str] = None,
451 reverse_for: str | IPNetwork | None = None,
452 alias_for: Optional[NscZonePrimary] = None,
453 follow_primary: str | IPAddress | None = None,
454 inherit_config: Optional[NscZoneConfig] = None,
456 if inherit_config is None:
457 inherit_config = self.default_zone_config
459 if reverse_for is not None:
460 if isinstance(reverse_for, str):
461 reverse_for = ip_network(reverse_for, strict=True)
462 name = name or self._reverse_zone_name(reverse_for)
463 assert name is not None
464 assert name not in self.zones
467 if alias_for is not None:
468 assert follow_primary is None
469 z = NscZoneAlias(self, name, reverse_for=reverse_for, alias_for=alias_for, inherit_config=inherit_config, **kwargs)
470 elif follow_primary is None:
471 z = NscZonePrimary(self, name, reverse_for=reverse_for, inherit_config=inherit_config, **kwargs)
473 if isinstance(follow_primary, str):
474 follow_primary = ip_address(follow_primary)
475 z = NscZoneSecondary(self, name, reverse_for=reverse_for, primary_server=follow_primary, inherit_config=inherit_config, **kwargs)
def __getitem__(self, name: str) -> NscZone:
    """Return the zone registered under *name*; unknown names raise KeyError."""
    zone = self.zones[name]
    return zone
def _reverse_zone_name(self, net: IPNetwork) -> str:
    """Return the reverse-DNS zone name for network *net*.

    IPv4 networks map below ``in-addr.arpa``, with one label per whole
    prefix octet, most significant octet last. If the prefix does not end on
    an octet boundary, the next octet is added as ``<octet>/<prefixlen>``
    (classless style). IPv6 networks map below ``ip6.arpa``, with one label
    per prefix nibble; their prefix length must be a multiple of 4.
    A zero-length prefix yields the bare ``in-addr.arpa`` / ``ip6.arpa``.

    Raises NotImplementedError for anything that is not an IPv4/IPv6 network.
    """
    if isinstance(net, IPv4Network):
        octets = str(net.network_address).split('.')
        labels = octets[:net.prefixlen // 8]
        if net.prefixlen % 8 != 0:
            labels.append(f'{octets[len(labels)]}/{net.prefixlen}')
        suffix = ['in-addr', 'arpa']
    elif isinstance(net, IPv6Network):
        assert net.prefixlen % 4 == 0
        nibbles = net.network_address.exploded.replace(':', "")
        labels = list(nibbles[:net.prefixlen // 4])
        suffix = ['ip6', 'arpa']
    else:
        raise NotImplementedError()
    # Join everything in one go: the old `'.'.join(...) + '.in-addr.arpa'`
    # produced an empty leading label ('.in-addr.arpa') for a /0 prefix.
    return '.'.join(labels[::-1] + suffix)
def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
    """Record that *ptr_to* requests a PTR record for *addr*.

    IPv4 addresses are filed in ``ipv4_reverse`` and everything else in
    ``ipv6_reverse``. fill_reverse later turns these requests into records.
    """
    requests = self.ipv4_reverse if isinstance(addr, IPv4Address) else self.ipv6_reverse
    requests[addr].append(ptr_to)
def dump_reverse(self) -> None:
    """Print all requested reverse mappings to stdout (debugging aid).

    Prints a header line, then one line per address: IPv4 addresses first,
    then IPv6, each family in ascending address order. Each line holds the
    address, a tab, and the comma-separated names that requested a PTR for it.
    """
    print('### Requests for reverse mappings ###')
    for requests in (self.ipv4_reverse, self.ipv6_reverse):
        for addr in sorted(requests):
            # Each value is a list of names. Join them instead of printing the
            # list repr (which shows e.g. "[<DNS name ...>]").
            names = ', '.join(str(name) for name in requests[addr])
            print(f'{addr}\t{names}')
def fill_reverse(self) -> None:
    """Add the requested PTR records to every primary reverse zone.

    For each primary zone with ``reverse_for`` set, every recorded address
    of the matching family that lies inside the zone's network gets one PTR
    per requesting name. The records are added via the zone's
    _add_ipv4_reverse / _add_ipv6_reverse helpers.
    """
    for zone in self.zones.values():
        if not isinstance(zone, NscZonePrimary) or zone.reverse_for is None:
            continue
        if isinstance(zone.reverse_for, IPv4Network):
            requests, add_ptr = self.ipv4_reverse, zone._add_ipv4_reverse
        else:
            requests, add_ptr = self.ipv6_reverse, zone._add_ipv6_reverse
        for addr, targets in requests.items():
            if addr not in zone.reverse_for:
                continue
            for target in targets:
                add_ptr(addr, target)
def get_zones(self) -> List[NscZone]:
    """Return every configured zone, ordered by zone name."""
    ordered = sorted(self.zones.items(), key=lambda item: item[0])
    return [zone for _zone_name, zone in ordered]
528 def process(self) -> None:
530 for z in self.get_zones():