1 # PyNSC: Main data structures
2 # (c) 2024 Martin Mareš <mj@ucw.cz>
4 from collections import defaultdict
5 from datetime import datetime, timedelta
7 from dns.name import Name
8 from dns.node import Node
9 from dns.rdata import Rdata
10 from dns.rdataclass import RdataClass
11 from dns.rdatatype import RdataType
12 import dns.rdtypes.ANY.CNAME
13 import dns.rdtypes.ANY.DNAME
14 import dns.rdtypes.ANY.MX
15 import dns.rdtypes.ANY.NS
16 import dns.rdtypes.ANY.PTR
17 import dns.rdtypes.ANY.SOA
18 import dns.rdtypes.ANY.TXT
19 import dns.rdtypes.IN.A
20 import dns.rdtypes.IN.AAAA
21 import dns.rdtypes.IN.SRV
22 from dns.zone import Zone
23 from enum import Enum, auto
25 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
27 from pathlib import Path
30 from typing import Optional, Dict, List, Self, DefaultDict, TextIO, Tuple, TYPE_CHECKING
32 from nsconfig.util import flatten_list, parse_address, parse_network, parse_name, parse_duration, parse_rname
33 from nsconfig.util import IPAddress, IPNetwork, IPAddr, NameParseMode
37 from nsconfig.daemon import NscDaemon
nsc_zone: 'NscZonePrimary'  # zone this node belongs to

def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None:
    """Attach to the node *name* of *nsc_zone*, creating the node if needed.

    The name is parsed relative to the zone. Records added through this
    object start out with the zone's default TTL.
    """
    self.nsc_zone = nsc_zone
    self.name = self._parse_lhs_name(name)
    zone_config = nsc_zone.config
    self.node = nsc_zone.zone.find_node(self.name, create=True)
    self._ttl = zone_config.default_ttl
# Set the TTL used for records added through this node from now on.
# An explicit `seconds` value is handled on the (not visible) lines after the
# first test. Otherwise, timedelta keyword arguments (hours=..., days=...) are
# converted with parse_duration(), or the zone's default TTL is restored.
# NOTE(review): the line handling `seconds`, the `elif`/`else` headers and the
# final return (presumably `return self`, per the Self annotation) are not
# present in this listing — confirm against the full file.
52 def ttl(self, seconds: Optional[int] = None, **kwargs) -> Self:
53 if seconds is not None:
56 self._ttl = parse_duration(timedelta(**kwargs))
58 self._ttl = self.nsc_zone.config.default_ttl
def _add(self, rec: Rdata) -> None:
    """Store *rec* in this node's rdataset of the same class and type, using the current TTL."""
    self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True).add(rec, ttl=self._ttl)
def _parse_lhs_name(self, name, **kwargs):
    """Parse an owner (left-hand side) name; owner names are always zone-relative."""
    relative_mode = NameParseMode.relative
    return parse_name(name, mode=relative_mode, **kwargs)
def _parse_rhs_name(self, name, **kwargs):
    """Parse a name appearing in record data, using the zone's configured name_parse_mode."""
    zone_mode = self.nsc_zone.config.name_parse_mode
    return parse_name(name, mode=zone_mode, **kwargs)
# Add an A record for every IPv4 address and an AAAA record for every other
# address. Arguments go through flatten_list() (presumably so lists may be
# passed as well as single addresses) and are normalized with parse_address().
# The reverse-mapping request is registered with the global Nsc object under
# the absolute form of this node's name; fill_reverse() later turns it into PTRs.
# NOTE(review): this listing omits some lines of the method — presumably an
# `else:` before the AAAA branch, an `if reverse:` guard before the reverse
# mapping, and the final `return self` — confirm against the full file.
71 def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
72 for a in map(parse_address, flatten_list(addrs)):
73 if isinstance(a, IPv4Address):
74 self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
76 self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
78 self.nsc_zone.nsc._add_reverse_mapping(a, self.name.choose_relativity(origin=self.nsc_zone.dns_name, relativize=False))
def CNAME(self, target: Name | str) -> Self:
    """Add a CNAME record pointing to *target*; return self for chaining.

    String targets are parsed with _parse_rhs_name(), like the targets of
    MX, NS and SRV records, so they honour the zone's name_parse_mode.
    Name objects are used unchanged.
    """
    if not isinstance(target, Name):
        # Passed as a bare str, dnspython would parse it with
        # dns.name.from_text(), i.e. always as an absolute name.
        target = self._parse_rhs_name(target)
    self._add(dns.rdtypes.ANY.CNAME.CNAME(RdataClass.IN, RdataType.CNAME, target))
    return self
def DNAME(self, target: Name | str) -> Self:
    """Add a DNAME record redirecting this subtree to *target*; return self.

    String targets are parsed with _parse_rhs_name(), like the targets of
    MX, NS and SRV records, so they honour the zone's name_parse_mode.
    Name objects are used unchanged.
    """
    if not isinstance(target, Name):
        # Passed as a bare str, dnspython would parse it with
        # dns.name.from_text(), i.e. always as an absolute name.
        target = self._parse_rhs_name(target)
    self._add(dns.rdtypes.ANY.DNAME.DNAME(RdataClass.IN, RdataType.DNAME, target))
    return self
def MX(self, pri: int, name: str) -> Self:
    """Add an MX record with preference *pri* and mail exchanger *name*; return self."""
    exchange = self._parse_rhs_name(name)
    mx_record = dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, exchange)
    self._add(mx_record)
    return self
# Add several MX records from a list of (preference, exchange name) pairs.
# NOTE(review): the body of this method is not present in this listing;
# presumably it calls MX() for each pair and returns self — confirm.
95 def MX_list(self, mxs: List[Tuple[int, str]]) -> Self:
def NS(self, *names: str | List[str]) -> Self:
    """Add one NS record per server name given (lists are flattened); return self."""
    for server_name in flatten_list(names):
        server = self._parse_rhs_name(server_name)
        self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, server))
    return self
def PTR(self, target: Name | str) -> Self:
    """Add a PTR record pointing to *target*; return self for chaining.

    String targets are parsed with _parse_rhs_name(), like the targets of
    MX, NS and SRV records, so they honour the zone's name_parse_mode.
    Name objects (as passed by the reverse-zone helpers) are used unchanged.
    """
    if not isinstance(target, Name):
        # Passed as a bare str, dnspython would parse it with
        # dns.name.from_text(), i.e. always as an absolute name.
        target = self._parse_rhs_name(target)
    self._add(dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target))
    return self
def SRV(self, priority: int, weight: int, port: int, target: Name | str) -> Self:
    """Add an SRV record (priority, weight, port, target host); return self."""
    target_name = self._parse_rhs_name(target)
    srv_record = dns.rdtypes.IN.SRV.SRV(RdataClass.IN, RdataType.SRV, priority, weight, port, target_name)
    self._add(srv_record)
    return self
def TXT(self, *text: str | List[str]) -> Self:
    """Add one TXT record per text given (lists are flattened); return self.

    A DNS character-string holds at most 255 bytes. Longer texts (DKIM keys,
    long SPF policies) are therefore UTF-8 encoded and split into consecutive
    255-byte character-strings of the same TXT record. Consumers concatenate
    them. If an over-long ``str`` were handed to dnspython directly, its
    singleton check would fail, it would iterate the string, and it would
    store every single character as a separate character-string.
    """
    max_chunk = 255
    for txt in flatten_list(text):
        if isinstance(txt, str):
            txt = txt.encode()
        if isinstance(txt, (bytes, bytearray)) and len(txt) > max_chunk:
            txt = tuple(bytes(txt[i:i + max_chunk]) for i in range(0, len(txt), max_chunk))
        self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, txt))
    return self
def generic(self, typ: str, text: str) -> Self:
    """Add a record of any type *typ*, given in zone-file presentation format; return self."""
    parsed = dns.rdata.from_text(RdataClass.IN, typ, text)
    self._add(parsed)
    return self
def alias(self, *aliases: str) -> Self:
    """Make every given name in the same zone a CNAME pointing to this node; return self."""
    # FIXME: Inter-zone aliases?
    zone = self.nsc_zone
    for alias_name in flatten_list(aliases):
        zone[alias_name].CNAME(self.name)
    return self
daemon_options: List[str]
name_parse_mode: NameParseMode
default_config: Optional['NscZoneConfig'] = None

def __init__(
    self,
    admin_email: Optional[str] = None,
    refresh: Optional[int | timedelta] = None,
    retry: Optional[int | timedelta] = None,
    expire: Optional[int | timedelta] = None,
    min_ttl: Optional[int | timedelta] = None,
    default_ttl: Optional[int | timedelta] = None,
    origin_server: Optional[str] = None,
    daemon_options: Optional[List[str]] = None,
    add_daemon_options: Optional[List[str]] = None,
    add_null_mx: Optional[bool] = None,
    name_parse_mode: Optional[NameParseMode] = None,
    inherit_config: Optional['NscZoneConfig'] = None,
) -> None:
    """Build a zone configuration, inheriting every option left as None.

    Unset options are copied from *inherit_config*. If that is not given,
    NscZoneConfig.default_config is used, or the new object itself while the
    default config is being constructed. Durations (refresh, retry, expire,
    min_ttl, default_ttl) are normalized with parse_duration().

    *daemon_options* replaces the inherited option list. *add_daemon_options*
    is appended to whichever list ends up in effect. The resulting list is
    always a private copy owned by this config.
    """
    if inherit_config is None:
        inherit_config = NscZoneConfig.default_config or self    # to satisfy the type checker
    self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
    self.refresh = parse_duration(refresh) if refresh is not None else inherit_config.refresh
    self.retry = parse_duration(retry) if retry is not None else inherit_config.retry
    self.expire = parse_duration(expire) if expire is not None else inherit_config.expire
    self.min_ttl = parse_duration(min_ttl) if min_ttl is not None else inherit_config.min_ttl
    self.default_ttl = parse_duration(default_ttl) if default_ttl is not None else inherit_config.default_ttl
    self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
    # Copy the list: the inherited one belongs to the parent config (often the
    # shared default_config), and extending it in place below would leak
    # add_daemon_options into every other config inheriting from that parent.
    base_options = daemon_options if daemon_options is not None else inherit_config.daemon_options
    self.daemon_options = list(base_options)
    self.add_null_mx = add_null_mx if add_null_mx is not None else inherit_config.add_null_mx
    self.name_parse_mode = name_parse_mode if name_parse_mode is not None else inherit_config.name_parse_mode
    if add_daemon_options is not None:
        self.daemon_options.extend(add_daemon_options)
173 def finalize(self) -> Self:
174 if not self.origin_server:
175 self.origin_server = socket.getfqdn()
176 if not self.admin_email:
177 self.admin_email = f'hostmaster@{self.origin_server}'
178 if self.default_ttl == 0:
179 self.default_ttl = self.min_ttl
# Built-in defaults used by every NscZoneConfig created without an explicit
# inherit_config. While this object is constructed, default_config is still
# None, so the constructor falls back to inheriting from the object itself.
# NOTE(review): this presumably means every option must be passed here
# explicitly. Several arguments and the closing parenthesis are not present
# in this listing — confirm against the full file.
183 NscZoneConfig.default_config = NscZoneConfig(
185 refresh=timedelta(hours=8),
186 retry=timedelta(hours=2),
187 expire=timedelta(days=14),
188 min_ttl=timedelta(days=1),
193 name_parse_mode=NameParseMode.absolute,
# Persistent per-zone state: serial number and content hash, kept in a JSON
# file between runs (presumably the NscZoneState class used by NscZonePrimary).
# NOTE(review): the class header, the body of __init__, the json.load() line
# and `try:` in load(), the body of its except branch, and part of the dict
# literal in save() are not present in this listing — confirm.
201 def __init__(self) -> None:
# Read `serial` and `hash` from *file*; a missing file is tolerated by the
# FileNotFoundError handler (presumably leaving the initial values).
205 def load(self, file: Path) -> None:
207 with open(file) as f:
209 assert isinstance(js, dict)
211 self.serial = js['serial']
213 self.hash = js['hash']
214 except FileNotFoundError:
# Write the state via "<file>.new" and rename it over *file*, so the state
# file is never left half-written (Path.replace is atomic on POSIX).
217 def save(self, file: Path) -> None:
218 new_file = Path(str(file) + '.new')
219 with open(new_file, 'w') as f:
221 'serial': self.serial,
224 json.dump(js, f, indent=4, sort_keys=True)
225 new_file.replace(file)
# Kind of a zone; the code uses ZoneType.primary, ZoneType.secondary and
# ZoneType.alias. NOTE(review): the member definitions are not present in
# this listing.
228 class ZoneType(Enum):
# Common base of primary, secondary and alias zones.
# NOTE(review): the class header, some attribute declarations, most of the
# __init__ signature and the bodies of process()/is_changed() are not present
# in this listing — confirm against the full file.
238 safe_name: str # For use in file names
240 config: NscZoneConfig
241 reverse_for: Optional[IPNetwork]
246 reverse_for: Optional[IPNetwork],
250 self.dns_name = dns.name.from_text(name)
# '/' appears in classless reverse-zone names such as
# "0/26.2.0.192.in-addr.arpa", so it is replaced to get a usable file name.
251 self.safe_name = name.replace('/', '@')
# Each zone gets its own finalized config; the remaining keyword arguments
# (presumably including inherit_config, which Nsc passes) go to NscZoneConfig.
252 self.config = NscZoneConfig(**kwargs).finalize()
253 self.reverse_for = reverse_for
255 def process(self) -> None:
258 def is_changed(self) -> bool:
# A zone whose data NSC generates. Records are collected in a dnspython Zone
# (self.zone) and written to zone_file. Serial bookkeeping uses a per-zone
# JSON state file.
262 class NscZonePrimary(NscZone):
# State loaded from the previous run (compared in gen_serial / is_changed)
267 prev_state: NscZoneState
# Alias zones sharing this zone's file (NscZoneAlias appends itself here)
268 aliases: List['NscZoneAlias']
# Set up file locations and state, then create an empty IN-class zone.
# NOTE(review): some lines are not present in this listing — presumably
# including the initialization of self.aliases, which NscZoneAlias appends to.
270 def __init__(self, *args, **kwargs) -> None:
271 super().__init__(*args, **kwargs)
273 self.zone_type = ZoneType.primary
274 self.zone_file = self.nsc.zone_dir / self.safe_name
275 self.state_file = self.nsc.state_dir / (self.safe_name + '.json')
# `state` is filled during this run; `prev_state` is what was saved last time
277 self.state = NscZoneState()
278 self.prev_state = NscZoneState()
279 self.prev_state.load(self.state_file)
283 self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN)
# Returns a node handle for the empty relative name, presumably the zone apex.
# NOTE(review): the `def` line (and any decorator) of this accessor is not
# present in this listing.
288 return NscNode(self, "")
# Build the zone's SOA from the zone config and the serial chosen for this
# run, delete the existing SOA rdataset at the apex (""), and presumably
# store the new one.
# NOTE(review): the binding of `conf` (presumably self.config), the
# retry/expire arguments, the closing parenthesis and the line adding `soa`
# are not present in this listing. mname is passed as a str, which dnspython
# parses as an absolute name.
290 def update_soa(self) -> None:
292 soa = dns.rdtypes.ANY.SOA.SOA(
293 RdataClass.IN, RdataType.SOA,
294 mname=conf.origin_server,
295 rname=parse_rname(conf.admin_email),
296 serial=self.state.serial,
297 refresh=conf.refresh,
300 minimum=conf.min_ttl,
302 self.zone.delete_rdataset("", RdataType.SOA)
def n(self, name: str) -> NscNode:
    """Shorthand returning a node handle for *name* (relative) in this zone."""
    node = NscNode(self, name)
    return node
def __getitem__(self, name: str) -> NscNode:
    """Support ``zone['www']`` as a synonym for ``zone.n('www')``."""
    node = NscNode(self, name)
    return node
def host(self, name: str, *args, reverse: bool = True) -> NscNode:
    """Create node *name* with address records for *args* and return the node.

    *reverse* is passed on to NscNode.A().
    """
    node = NscNode(self, name)
    node.A(*args, reverse=reverse)
    return node
# Comment lines put at the top of the generated zone file. dump() writes
# them, and gen_hash() hashes them, so changing the header changes the hash.
# NOTE(review): the `return (` line and the rest of the expression are not
# present in this listing.
316 def zone_header(self) -> str:
318 f'; Zone file for {self.name}\n'
319 + '; Generated by NSC, please do not edit manually.\n'
# Write the zone in zone-file syntax to *file* (stdout by default). Output
# is the header, a $TTL line with the default TTL, then one line per record.
# The TTL column is left empty for records using the default TTL.
# NOTE(review): the lines initializing `last_name` and computing
# `print_name` are not present in this listing. Presumably the owner name is
# printed only when it differs from the previous record's.
322 def dump(self, file: Optional[TextIO] = None) -> None:
323 # Could use self.zone.to_file(sys.stdout), but we want better formatting
324 file = file or sys.stdout
325 file.write(self.zone_header())
326 file.write(f'$TTL\t\t{self.config.default_ttl}\n\n')
328 for name, ttl, rec in self.zone.iterate_rdatas():
329 if name == last_name:
333 file.write(f'{print_name}\t{ttl if ttl != self.config.default_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
    """Add a PTR for IPv4 *addr* pointing to *ptr_to* in this reverse zone.

    The owner name consists of the address octets below the zone's
    octet-aligned prefix, in reverse order.
    """
    assert isinstance(self.reverse_for, IPv4Network)
    skip = self.reverse_for.prefixlen // 8
    octets = str(addr).split('.')[skip:]
    self.n('.'.join(octets[::-1])).PTR(ptr_to)
def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
    """Add a PTR for IPv6 *addr* pointing to *ptr_to* in this reverse zone.

    The owner name consists of the address nibbles below the zone's
    nibble-aligned prefix, in reverse order, separated by dots.
    """
    assert isinstance(self.reverse_for, IPv6Network)
    skip = self.reverse_for.prefixlen // 4
    nibbles = addr.exploded.replace(':', "")[skip:]
    self.n('.'.join(nibbles[::-1])).PTR(ptr_to)
# Fingerprint the zone content and store it in self.state.hash. The input is
# the header plus every record as owner, TTL, type and text. gen_serial()
# compares it with the previous run's hash. Only the first 16 hex digits are
# kept. NOTE(review): the line creating `sha` is not present in this listing,
# so the hash algorithm cannot be confirmed from here.
350 def gen_hash(self) -> None:
352 sha.update(self.zone_header().encode('us-ascii'))
353 for name, ttl, rec in self.zone.iterate_rdatas():
354 text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
355 sha.update(text.encode('us-ascii'))
356 self.state.hash = sha.hexdigest()[:16]
# Choose the SOA serial for this run in YYYYMMDDnn form, based on
# self.nsc.start_time. If the content hash is unchanged and a previous serial
# exists, that serial is kept. Otherwise the serial becomes today's base + 1,
# or previous + 1 (presumably when the previous serial is already at or past
# today's base). A warning is printed when the two-digit counter would pass
# 99, i.e. the serial spills into the next day's range.
# NOTE(review): the early return / `else:` and the `if prev < base:` test are
# not present in this listing — confirm against the full file.
358 def gen_serial(self) -> None:
359 prev = self.prev_state.serial
360 if self.state.hash == self.prev_state.hash and prev > 0:
361 self.state.serial = self.prev_state.serial
363 base = int(self.nsc.start_time.strftime('%Y%m%d00'))
365 self.state.serial = base + 1
367 self.state.serial = prev + 1
368 if prev >= base + 99:
369 print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')
# Per-run processing of the zone. When the add_null_mx option is set,
# presumably gen_null_mx() is called.
# NOTE(review): the rest of the method is not present in this listing.
371 def process(self) -> None:
372 if self.config.add_null_mx:
# Write the zone file via "<zone_file>.new", then rename it over the real
# zone file, so readers never see a partially written file.
# NOTE(review): the lines between (presumably a dump(f) call) are not present
# in this listing.
377 def write_zone(self) -> None:
379 new_file = Path(str(self.zone_file) + '.new')
380 with open(new_file, 'w') as f:
382 new_file.replace(self.zone_file)
def write_state(self) -> None:
    """Persist this run's state (serial and hash) to the zone's state file."""
    target = self.state_file
    self.state.save(target)
def is_changed(self) -> bool:
    """Tell whether this run assigned a serial different from the previous run's."""
    new, old = self.state.serial, self.prev_state.serial
    return new != old
def delegate_classless(self, net: str | IPNetwork, subdomain: Optional[str] = None) -> NscNode:
    """Set up RFC 2317 classless delegation of *net* from this reverse zone.

    For every last octet in *net*, a CNAME ``<octet>`` -> ``<octet>.<subdomain>``
    is added. The default subdomain is ``"<first octet>/<prefixlen>"``. The node
    for *subdomain* is returned, presumably so the caller can add the
    delegation's NS records.

    Only IPv4 is supported. This zone must be octet-aligned, and *net* must
    lie inside it with a prefix length below the zone's prefix length + 8.
    """
    net = parse_network(net)
    zone_net = self.reverse_for
    assert zone_net is not None
    assert isinstance(zone_net, IPv4Network)
    assert zone_net.prefixlen % 8 == 0
    assert isinstance(net, IPv4Network)
    assert net.subnet_of(zone_net)
    assert net.prefixlen < zone_net.prefixlen + 8

    first = net.network_address.packed[net.prefixlen // 8]
    count = 2 ** (8 - net.prefixlen % 8)
    if subdomain is None:
        subdomain = f'{first}/{net.prefixlen}'

    for octet in range(first, first + count):
        self[str(octet)].CNAME(parse_name(f'{octet}.{subdomain}', mode=NameParseMode.relative))

    return self[subdomain]
# For every node that has A or AAAA records, add a "null MX" with the zone's
# default TTL. A null MX has preference 0 and exchange "." (root), as in
# RFC 7505, and states that the host accepts no mail.
# NOTE(review): two lines between fetching mx_rds and the MX constructor
# (presumably a check that the node has no MX yet, and `mx_rds.add(`) and the
# closing parenthesis are not present in this listing. If existing MX records
# are not checked, a null MX would sit next to real MX records, which
# RFC 7505 forbids — confirm.
411 def gen_null_mx(self) -> None:
412 for name, node in self.zone.items():
413 rds_a = node.get_rdataset(RdataClass.IN, RdataType.A)
414 rds_aaaa = node.get_rdataset(RdataClass.IN, RdataType.AAAA)
415 if rds_a or rds_aaaa:
416 mx_rds = node.get_rdataset(RdataClass.IN, RdataType.MX, create=True)
419 dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, 0, dns.name.root),
420 ttl=self.config.default_ttl,
class NscZoneSecondary(NscZone):
    """A secondary zone whose master copy lives on *primary_server*.

    The local copy's path is secondary_dir/<safe_name>.
    """
    primary_server: IPAddress
    secondary_file: Path

    def __init__(self, *args, primary_server: IPAddress, **kwargs) -> None:
        """Create the zone; *primary_server* is a required keyword argument.

        It must be annotated, not defaulted: a default of the IPAddress type
        alias would let a missing argument silently store a type object as
        the server address. Other arguments are forwarded to NscZone.
        """
        super().__init__(*args, **kwargs)
        self.zone_type = ZoneType.secondary
        self.primary_server = primary_server
        self.secondary_file = self.nsc.secondary_dir / self.safe_name
class NscZoneAlias(NscZone):
    """A zone that reuses the zone file generated for a primary zone."""
    alias_for: NscZonePrimary

    def __init__(self, *args, alias_for: NscZonePrimary, **kwargs) -> None:
        """Create an alias of *alias_for* (required keyword argument).

        It must be annotated, not defaulted: a default of the NscZonePrimary
        class would turn a missing argument into a confusing assertion
        failure. Other arguments are forwarded to NscZone.
        """
        assert isinstance(alias_for, NscZonePrimary)
        super().__init__(*args, **kwargs)
        self.zone_type = ZoneType.alias
        self.alias_for = alias_for
        # Share the primary's zone file instead of generating our own
        self.zone_file = alias_for.zone_file
        alias_for.aliases.append(self)

    def is_changed(self) -> bool:
        """An alias changes exactly when the zone it aliases changes."""
        return self.alias_for.is_changed()
# Zones by name
452 zones: Dict[str, NscZone]
# Config that zones inherit from unless given another one
453 default_zone_config: NscZoneConfig
# Reverse-mapping requests collected from NscNode.A(); consumed by fill_reverse()
454 ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
455 ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
# NOTE(review): the class header, the def line and other parameters of
# __init__, and several body lines (e.g. initialization of self.zones and the
# daemon handling around the import below) are not present in this listing.
464 directory: str = '.',
465 daemon: Optional['NscDaemon'] = None,
# Local (naive) time; gen_serial() derives the date part of serials from it
467 self.start_time = datetime.now()
# Remaining keyword arguments configure the default zone config
469 self.default_zone_config = NscZoneConfig(**kwargs)
470 self.ipv4_reverse = defaultdict(list)
471 self.ipv6_reverse = defaultdict(list)
# Output layout under *directory*: state/, zone/ and secondary/, created on demand
473 self.root_dir = Path(directory)
474 self.state_dir = self.root_dir / 'state'
475 self.state_dir.mkdir(parents=True, exist_ok=True)
476 self.zone_dir = self.root_dir / 'zone'
477 self.zone_dir.mkdir(parents=True, exist_ok=True)
478 self.secondary_dir = self.root_dir / 'secondary'
479 self.secondary_dir.mkdir(parents=True, exist_ok=True)
# Presumably used when no daemon was given (the guarding line is not visible)
482 from nsconfig.daemon import NscDaemonNull
483 daemon = NscDaemonNull()
# Define a new zone (primary, alias or secondary) and register it.
# NOTE(review): the method's def line, some parameters, and the lines after
# the zone is built (presumably storing it in self.zones and returning it)
# are not present in this listing — confirm against the full file.
488 name: Optional[str] = None,
490 reverse_for: str | IPNetwork | None = None,
491 alias_for: Optional[NscZonePrimary] = None,
492 secondary_to: str | IPAddress | None = None,
493 inherit_config: Optional[NscZoneConfig] = None,
495 if inherit_config is None:
496 inherit_config = self.default_zone_config
# Reverse zones may omit the name; it is derived from the network. A string
# network is parsed with strict=True, so set host bits raise ValueError.
498 if reverse_for is not None:
499 if isinstance(reverse_for, str):
500 reverse_for = ip_network(reverse_for, strict=True)
501 name = name or self._reverse_zone_name(reverse_for)
502 assert name is not None
503 assert name not in self.zones
# alias_for -> alias zone; otherwise no secondary_to -> primary zone;
# otherwise a secondary zone fetched from secondary_to
506 if alias_for is not None:
507 assert secondary_to is None
508 z = NscZoneAlias(self, name, reverse_for=reverse_for, alias_for=alias_for, inherit_config=inherit_config, **kwargs)
509 elif secondary_to is None:
510 z = NscZonePrimary(self, name, reverse_for=reverse_for, inherit_config=inherit_config, **kwargs)
512 if isinstance(secondary_to, str):
513 secondary_to = ip_address(secondary_to)
514 z = NscZoneSecondary(self, name, reverse_for=reverse_for, primary_server=secondary_to, inherit_config=inherit_config, **kwargs)
def __getitem__(self, name: str) -> NscZone:
    """Look up a defined zone by name; unknown names raise KeyError."""
    zone = self.zones[name]
    return zone
def _reverse_zone_name(self, net: IPNetwork) -> str:
    """Return the in-addr.arpa / ip6.arpa zone name covering *net*.

    An IPv4 prefix that is not a multiple of 8 gets an RFC 2317 style label
    "<octet>/<prefixlen>" below the enclosing octet-aligned name. IPv6
    prefixes must be nibble-aligned.
    """
    if isinstance(net, IPv6Network):
        assert net.prefixlen % 4 == 0
        nibbles = net.network_address.exploded.replace(':', "")[:net.prefixlen // 4]
        return '.'.join(nibbles[::-1]) + '.ip6.arpa'
    if isinstance(net, IPv4Network):
        octets = str(net.network_address).split('.')
        whole = net.prefixlen // 8
        labels = octets[:whole]
        if net.prefixlen % 8:
            labels.append(f'{octets[whole]}/{net.prefixlen}')
        return '.'.join(reversed(labels)) + '.in-addr.arpa'
    raise NotImplementedError()
def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
    """Record that *addr* should get a PTR record pointing to *ptr_to*."""
    requests = self.ipv4_reverse if isinstance(addr, IPv4Address) else self.ipv6_reverse
    requests[addr].append(ptr_to)
def dump_reverse(self) -> None:
    """Print all collected reverse-mapping requests, IPv4 first, each in address order."""
    print('### Requests for reverse mappings ###')
    for requests in (self.ipv4_reverse, self.ipv6_reverse):
        for address, targets in sorted(requests.items()):
            print(f'{address}\t{targets}')
def fill_reverse(self) -> None:
    """Add PTR records to every primary reverse zone for the requested addresses it covers."""
    for zone in self.zones.values():
        if not isinstance(zone, NscZonePrimary) or zone.reverse_for is None:
            continue
        net = zone.reverse_for
        if isinstance(net, IPv4Network):
            requests, add_ptr = self.ipv4_reverse, zone._add_ipv4_reverse
        else:
            requests, add_ptr = self.ipv6_reverse, zone._add_ipv6_reverse
        for address, targets in requests.items():
            if address in net:
                for target in targets:
                    add_ptr(address, target)
def get_zones(self) -> List[NscZone]:
    """Return all defined zones ordered by zone name."""
    return [zone for _, zone in sorted(self.zones.items())]
# Process all zones in name order (see get_zones).
# NOTE(review): only the start of this method is visible in this listing.
567 def process(self) -> None:
569 for z in self.get_zones():