1 from collections import defaultdict
2 from datetime import datetime, timedelta
4 from dns.name import Name
5 from dns.node import Node
6 from dns.rdata import Rdata
7 from dns.rdataclass import RdataClass
8 from dns.rdatatype import RdataType
9 import dns.rdtypes.ANY.MX
10 import dns.rdtypes.ANY.NS
11 import dns.rdtypes.ANY.PTR
12 import dns.rdtypes.ANY.SOA
13 import dns.rdtypes.ANY.TXT
14 import dns.rdtypes.IN.A
15 import dns.rdtypes.IN.AAAA
16 from dns.zone import Zone
17 from enum import Enum, auto
19 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
21 from pathlib import Path
24 from typing import Optional, Dict, List, Self, DefaultDict, TextIO, TYPE_CHECKING
26 from nsconfig.util import flatten_list
30 from nsconfig.daemon import NscDaemon
33 IPAddress = IPv4Address | IPv6Address
34 IPNetwork = IPv4Network | IPv6Network
35 IPAddr = str | IPAddress | List[str | IPAddress]
39 nsc_zone: 'NscZonePrimary'
44 def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None:
45 self.nsc_zone = nsc_zone
47 self.node = nsc_zone.zone.find_node(name, create=True)
48 self._ttl = nsc_zone._min_ttl
50 def ttl(self, *args, **kwargs) -> Self:
51 if not args and not kwargs:
52 self._ttl = self.nsc_zone._min_ttl
54 self._ttl = int(timedelta(*args, **kwargs).total_seconds())
def _add(self, rec: Rdata) -> None:
    """Store *rec* in this node using the node's current TTL.

    The rdataset matching the record's class and type is looked up on the
    node and created if it does not exist yet.
    """
    rdclass, rdtype = rec.rdclass, rec.rdtype
    self.node.find_rdataset(rdclass, rdtype, create=True).add(rec, ttl=self._ttl)
def _parse_addr(self, addr: IPAddr | str) -> IPAddress:
    """Coerce *addr* to an ``IPv4Address`` or ``IPv6Address``.

    Address objects are returned unchanged and strings are parsed with
    ``ip_address``. Any other value raises ``ValueError``.
    """
    if isinstance(addr, (IPv4Address, IPv6Address)):
        return addr
    if isinstance(addr, str):
        return ip_address(addr)
    raise ValueError('Cannot parse IP address')
69 def _parse_name(self, name: str) -> Name:
70 # FIXME: Names with escaped dots
72 return dns.name.from_text(name)
74 return dns.name.from_text(name, origin=None)
76 def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
77 for a in map(self._parse_addr, flatten_list(addrs)):
78 if isinstance(a, IPv4Address):
79 self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
81 self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
83 self.nsc_zone.nsc._add_reverse_mapping(a, dns.name.from_text(self.name + '.' + self.nsc_zone.name))
86 def MX(self, pri: int, name: str) -> Self:
88 dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, self._parse_name(name))
def NS(self, *names: str | List[str]) -> Self:
    """Add one NS record per name server and return ``self`` for chaining.

    Arguments may be strings or lists of strings (flattened by
    ``flatten_list``). Each name is converted with ``_parse_name``.
    """
    for raw in flatten_list(names):
        server = self._parse_name(raw)
        self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, server))
    return self
def TXT(self, *text: str | List[str]) -> Self:
    """Add one TXT record per text and return ``self`` for chaining.

    Arguments may be strings or lists of strings (flattened by
    ``flatten_list``). A single DNS character-string holds at most 255
    octets (RFC 1035, section 3.3). Each text is therefore UTF-8 encoded
    and split into consecutive 255-byte chunks, which together form one TXT
    record; this is the usual way long values such as DKIM keys are
    published. An empty text yields a record with one empty string.
    """
    for txt in flatten_list(text):
        data = txt.encode() if isinstance(txt, str) else txt
        chunks = [data[i:i + 255] for i in range(0, len(data), 255)] or [b'']
        self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, chunks))
    return self
def PTR(self, target: Name | str) -> Self:
    """Add a PTR record pointing at *target* and return ``self`` for chaining.

    String targets are converted with ``_parse_name``, so they follow the
    same relative/absolute convention as NS and MX targets. ``Name`` objects
    are used as they are.
    """
    if isinstance(target, str):
        # dnspython would otherwise parse the string itself with its default
        # (root) origin, making every string target absolute.
        target = self._parse_name(target)
    self._add(dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target))
    return self
def generic(self, typ: str, text: str) -> Self:
    """Add an IN-class record of type *typ* parsed from its textual form *text*.

    This is an escape hatch for record types without a dedicated helper.
    Returns ``self`` for chaining.
    """
    rdata = dns.rdata.from_text(RdataClass.IN, typ, text)
    self._add(rdata)
    return self
118 daemon_options: List[str]
120 default_config: Optional['NscZoneConfig'] = None
123 admin_email: Optional[str] = None,
124 refresh: Optional[timedelta] = None,
125 retry: Optional[timedelta] = None,
126 expire: Optional[timedelta] = None,
127 min_ttl: Optional[timedelta] = None,
128 origin_server: Optional[str] = None,
129 daemon_options: Optional[List[str]] = None,
130 add_daemon_options: Optional[List[str]] = None,
131 inherit_config: Optional['NscZoneConfig'] = None,
if inherit_config is None:
    inherit_config = NscZoneConfig.default_config or self  # to satisfy the type checker
# Every setting not given explicitly is taken from inherit_config.
self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
self.refresh = refresh if refresh is not None else inherit_config.refresh
self.retry = retry if retry is not None else inherit_config.retry
self.expire = expire if expire is not None else inherit_config.expire
self.min_ttl = min_ttl if min_ttl is not None else inherit_config.min_ttl
self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
# Take a private copy of the option list. Sharing the parent's (or the
# caller's) list would make the in-place extension below leak
# add_daemon_options into inherit_config (normally the shared default
# config) and thus into every other zone that inherits from it.
self.daemon_options = list(daemon_options if daemon_options is not None else inherit_config.daemon_options)
if add_daemon_options is not None:
    self.daemon_options += add_daemon_options
145 def finalize(self) -> Self:
146 if not self.origin_server:
147 self.origin_server = socket.getfqdn()
148 if not self.admin_email:
149 self.admin_email = f'hostmaster@{self.origin_server}'
153 NscZoneConfig.default_config = NscZoneConfig(
155 refresh=timedelta(hours=8),
156 retry=timedelta(hours=2),
157 expire=timedelta(days=14),
158 min_ttl=timedelta(days=1),
168 def __init__(self) -> None:
172 def load(self, file: Path) -> None:
174 with open(file) as f:
176 assert isinstance(js, dict)
178 self.serial = js['serial']
180 self.hash = js['hash']
181 except FileNotFoundError:
184 def save(self, file: Path) -> None:
185 new_file = Path(str(file) + '.new')
186 with open(new_file, 'w') as f:
188 'serial': self.serial,
191 json.dump(js, f, indent=4, sort_keys=True)
192 new_file.replace(file)
195 class ZoneType(Enum):
203 safe_name: str # For use in file names
205 reverse_for: Optional[IPNetwork]
210 reverse_for: Optional[IPNetwork],
214 self.safe_name = name.replace('/', '@')
215 self.config = NscZoneConfig(**kwargs).finalize()
216 self.reverse_for = reverse_for
218 def process(self) -> None:
222 class NscZonePrimary(NscZone):
228 prev_state: NscZoneState
230 def __init__(self, *args, **kwargs) -> None:
231 super().__init__(*args, **kwargs)
233 self.zone_type = ZoneType.primary
234 self.zone_file = self.nsc.zone_dir / self.safe_name
235 self.state_file = self.nsc.state_dir / (self.safe_name + '.json')
237 self.state = NscZoneState()
238 self.prev_state = NscZoneState()
239 self.prev_state.load(self.state_file)
241 self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN)
242 self._min_ttl = int(self.config.min_ttl.total_seconds())
245 def update_soa(self) -> None:
247 soa = dns.rdtypes.ANY.SOA.SOA(
248 RdataClass.IN, RdataType.SOA,
249 mname=conf.origin_server,
250 rname=conf.admin_email.replace('@', '.'), # FIXME: names with dots
251 serial=self.state.serial,
252 refresh=int(conf.refresh.total_seconds()),
253 retry=int(conf.retry.total_seconds()),
254 expire=int(conf.expire.total_seconds()),
255 minimum=int(conf.min_ttl.total_seconds()),
257 self.zone.delete_rdataset("", RdataType.SOA)
def n(self, name: str) -> NscNode:
    """Return an ``NscNode`` handle for owner *name* in this zone.

    The node is created in the zone if it does not exist yet.
    """
    node = NscNode(self, name)
    return node
def __getitem__(self, name: str) -> NscNode:
    """Subscript shorthand for ``n()``, so that ``zone['www']`` equals ``zone.n('www')``."""
    return self.n(name)
def host(self, name: str, *args, reverse: bool = True) -> NscNode:
    """Create node *name* and add address records for the addresses in *args*.

    *reverse* is passed through to ``NscNode.A``. The node is returned so
    further records can be chained onto it.
    """
    node = self.n(name)
    node.A(*args, reverse=reverse)
    return node
271 def zone_header(self) -> str:
273 f'; Zone file for {self.name}\n'
274 + '; Generated by NSC, please do not edit manually.\n'
277 def dump(self, file: Optional[TextIO] = None) -> None:
278 # Could use self.zone.to_file(sys.stdout), but we want better formatting
279 file = file or sys.stdout
280 file.write(self.zone_header())
282 for name, ttl, rec in self.zone.iterate_rdatas():
283 if name == last_name:
287 file.write(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
    """Add a PTR record for IPv4 *addr* pointing at *ptr_to* in this reverse zone.

    The owner name consists of the address octets not covered by the zone's
    prefix (``prefixlen // 8`` whole octets are dropped), in reverse order.
    """
    net = self.reverse_for
    assert isinstance(net, IPv4Network)
    octets = str(addr).split('.')
    local = octets[net.prefixlen // 8:]
    self.n('.'.join(local[::-1])).PTR(ptr_to)
def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
    """Add a PTR record for IPv6 *addr* pointing at *ptr_to* in this reverse zone.

    The owner name consists of the address nibbles not covered by the zone's
    prefix (``prefixlen // 4`` nibbles are dropped), one label per nibble,
    in reverse order.
    """
    net = self.reverse_for
    assert isinstance(net, IPv6Network)
    nibbles = addr.exploded.replace(':', '')
    local = nibbles[net.prefixlen // 4:]
    self.n('.'.join(local[::-1])).PTR(ptr_to)
304 def gen_hash(self) -> None:
306 sha.update(self.zone_header().encode('us-ascii'))
307 for name, ttl, rec in self.zone.iterate_rdatas():
308 text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
309 sha.update(text.encode('us-ascii'))
310 self.state.hash = sha.hexdigest()[:16]
312 def gen_serial(self) -> None:
313 prev = self.prev_state.serial
314 if self.state.hash == self.prev_state.hash and prev > 0:
315 self.state.serial = self.prev_state.serial
317 base = int(self.nsc.start_time.strftime('%Y%m%d00'))
319 self.state.serial = base + 1
321 self.state.serial = prev + 1
322 if prev >= base + 99:
323 print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')
325 def process(self) -> None:
326 if self.zone_type == ZoneType.primary:
330 def write_zone(self) -> None:
332 new_file = Path(str(self.zone_file) + '.new')
333 with open(new_file, 'w') as f:
335 new_file.replace(self.zone_file)
def write_state(self) -> None:
    """Persist the current zone state (serial and hash) to this zone's state file."""
    state, path = self.state, self.state_file
    state.save(path)
def is_changed(self) -> bool:
    """Report whether this run assigned a serial different from the previous run's."""
    previous = self.prev_state.serial
    return previous != self.state.serial
class NscZoneSecondary(NscZone):
    """Zone that follows (is a secondary of) the primary server at *primary_server*."""

    primary_server: IPAddress
    secondary_file: Path  # zone's safe_name inside nsc.secondary_dir

    def __init__(self, *args, primary_server: IPAddress, **kwargs) -> None:
        # primary_server was previously written as `primary_server=IPAddress`,
        # which made the type alias itself the default value. It is now a
        # required keyword-only argument, as intended.
        super().__init__(*args, **kwargs)
        self.zone_type = ZoneType.secondary
        self.primary_server = primary_server
        self.secondary_file = self.nsc.secondary_dir / self.safe_name
357 zones: Dict[str, NscZone]
358 default_zone_config: NscZoneConfig
359 ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
360 ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
365 daemon: 'NscDaemon' # Set by DaemonConfig class
368 directory: str = '.',
369 daemon: Optional['NscDaemon'] = None,
371 self.start_time = datetime.now()
373 self.default_zone_config = NscZoneConfig(**kwargs)
374 self.ipv4_reverse = defaultdict(list)
375 self.ipv6_reverse = defaultdict(list)
377 self.root_dir = Path(directory)
378 self.state_dir = self.root_dir / 'state'
379 self.state_dir.mkdir(parents=True, exist_ok=True)
380 self.zone_dir = self.root_dir / 'zone'
381 self.zone_dir.mkdir(parents=True, exist_ok=True)
382 self.secondary_dir = self.root_dir / 'secondary'
383 self.secondary_dir.mkdir(parents=True, exist_ok=True)
386 from nsconfig.daemon import NscDaemonNull
387 daemon = NscDaemonNull()
392 name: Optional[str] = None,
393 reverse_for: str | IPNetwork | None = None,
394 follow_primary: str | IPAddress | None = None,
395 inherit_config: Optional[NscZoneConfig] = None,
397 if inherit_config is None:
398 inherit_config = self.default_zone_config
400 if reverse_for is not None:
401 if isinstance(reverse_for, str):
402 reverse_for = ip_network(reverse_for, strict=True)
403 name = name or self._reverse_zone_name(reverse_for)
404 assert name is not None
405 assert name not in self.zones
408 if follow_primary is None:
409 z = NscZonePrimary(self, name, reverse_for=reverse_for, inherit_config=inherit_config, **kwargs)
411 if isinstance(follow_primary, str):
412 follow_primary = ip_address(follow_primary)
413 z = NscZoneSecondary(self, name, reverse_for=reverse_for, primary_server=follow_primary, inherit_config=inherit_config, **kwargs)
418 def _reverse_zone_name(self, net: IPNetwork) -> str:
419 if isinstance(net, IPv4Network):
420 parts = str(net.network_address).split('.')
421 out = parts[:net.prefixlen // 8]
422 if net.prefixlen % 8 != 0:
423 out.append(parts[len(out)] + '/' + str(net.prefixlen))
424 return '.'.join(reversed(out)) + '.in-addr.arpa'
425 elif isinstance(net, IPv6Network):
426 assert net.prefixlen % 4 == 0
427 nibbles = net.network_address.exploded.replace(':', "")
428 nibbles = nibbles[:net.prefixlen // 4]
429 return '.'.join(reversed(nibbles)) + '.ip6.arpa'
431 raise NotImplementedError()
def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
    """Record a request for a PTR record from *addr* to *ptr_to*.

    Requests are kept per address family. ``fill_reverse`` later turns them
    into PTR records in whichever reverse zones contain *addr*.
    """
    table = self.ipv4_reverse if isinstance(addr, IPv4Address) else self.ipv6_reverse
    table[addr].append(ptr_to)
def dump_reverse(self) -> None:
    """Print the collected reverse-mapping requests to stdout.

    IPv4 requests are printed first, then IPv6, each sorted by address. Each
    line holds the address, a tab, and the comma-separated target names.
    The names are formatted with ``str()``; previously the repr of the whole
    list was printed.
    """
    print('### Requests for reverse mappings ###')
    for table in (self.ipv4_reverse, self.ipv6_reverse):
        for addr, names in sorted(table.items()):
            print(f'{addr}\t{", ".join(map(str, names))}')
def fill_reverse(self) -> None:
    """Turn the collected reverse-mapping requests into PTR records.

    For every primary zone that has ``reverse_for`` set, each requested
    address of the matching family that lies inside that network gets one
    PTR record per requested target.
    """
    for zone in self.zones.values():
        if not isinstance(zone, NscZonePrimary) or zone.reverse_for is None:
            continue
        net = zone.reverse_for
        if isinstance(net, IPv4Network):
            requests, add_ptr = self.ipv4_reverse, zone._add_ipv4_reverse
        else:
            requests, add_ptr = self.ipv6_reverse, zone._add_ipv6_reverse
        for addr, targets in requests.items():
            if addr in net:
                for target in targets:
                    add_ptr(addr, target)
def get_zones(self) -> List[NscZone]:
    """Return all configured zones, ordered by zone name."""
    return [zone for _, zone in sorted(self.zones.items())]
463 def process(self) -> None:
465 for z in self.get_zones():