1 from collections import defaultdict
2 from datetime import datetime, timedelta
4 from dns.name import Name
5 from dns.node import Node
6 from dns.rdata import Rdata
7 from dns.rdataclass import RdataClass
8 from dns.rdatatype import RdataType
9 import dns.rdtypes.ANY.CNAME
10 import dns.rdtypes.ANY.MX
11 import dns.rdtypes.ANY.NS
12 import dns.rdtypes.ANY.PTR
13 import dns.rdtypes.ANY.SOA
14 import dns.rdtypes.ANY.TXT
15 import dns.rdtypes.IN.A
16 import dns.rdtypes.IN.AAAA
17 from dns.zone import Zone
18 from enum import Enum, auto
20 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
22 from pathlib import Path
25 from typing import Optional, Dict, List, Self, DefaultDict, TextIO, TYPE_CHECKING
27 from nsconfig.util import flatten_list, parse_address, parse_network, parse_name
28 from nsconfig.util import IPAddress, IPNetwork, IPAddr
32 from nsconfig.daemon import NscDaemon
# NscNode: handle for one owner name inside an NscZonePrimary; the record
# methods that follow add rdata to that name and return Self for chaining.
# NOTE(review): the class header and some lines are not visible in this
# excerpt, e.g. source line 43 -- presumably `self.name = name`, since A()
# reads self.name.
# nsc_zone: the primary zone this node belongs to.
36 nsc_zone: 'NscZonePrimary'
# Look up (creating it if absent) node `name` in nsc_zone.zone. Records added
# through this handle start out with the zone's minimum TTL.
41 def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None:
42 self.nsc_zone = nsc_zone
44 self.node = nsc_zone.zone.find_node(name, create=True)
45 self._ttl = nsc_zone._min_ttl
# Fluent TTL selector for records added afterwards through this node (_add
# reads self._ttl at insertion time).
# - Called with no arguments, it resets to the zone minimum.
# - Otherwise the arguments are forwarded to datetime.timedelta
#   (e.g. ttl(hours=1)) and stored as whole seconds.
# NOTE(review): the `else:` and `return self` lines are not visible here.
47 def ttl(self, *args, **kwargs) -> Self:
48 if not args and not kwargs:
49 self._ttl = self.nsc_zone._min_ttl
51 self._ttl = int(timedelta(*args, **kwargs).total_seconds())
54 def _add(self, rec: Rdata) -> None:
55 rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
56 rds.add(rec, ttl=self._ttl)
# Add address records for this name:
# - an A record for each IPv4 address;
# - an AAAA record otherwise (the `else:` line is not visible).
# Arguments pass through nsconfig.util.flatten_list / parse_address,
# presumably so single values and lists are both accepted; confirm in that
# module.
# When `reverse` is true (the guarding line is not visible), a request that
# the address reverse-resolve to "<node name>.<zone name>" is queued on the
# top-level config object via _add_reverse_mapping.
# NOTE(review): for the zone apex the concatenation may yield "@.zone" or
# ".zone", depending on how the apex node is named. Verify.
58 def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
59 for a in map(parse_address, flatten_list(addrs)):
60 if isinstance(a, IPv4Address):
61 self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
63 self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
65 self.nsc_zone.nsc._add_reverse_mapping(a, parse_name(self.name + '.' + self.nsc_zone.name))
# Add an MX record with preference `pri` and exchange host `name`
# (converted with parse_name).
# NOTE(review): the wrapping `self._add(`, the closing line and
# `return self` are not visible in this excerpt.
68 def MX(self, pri: int, name: str) -> Self:
70 dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, parse_name(name))
# Add one NS record per name. Names may be passed as several arguments
# and/or lists; they are flattened and converted with parse_name.
# (`return self` is not visible here.)
74 def NS(self, *names: str | List[str]) -> Self:
75 for name in map(parse_name, flatten_list(names)):
76 self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
# Add one TXT record per string. Arguments may be strings or lists of
# strings. (`return self` is not visible here.)
# NOTE(review): each string goes to the TXT constructor unchanged, but a
# single DNS character-string holds at most 255 bytes. Confirm how the
# installed dnspython treats longer values (e.g. DKIM keys); they may need
# splitting into <=255-byte chunks, passed as one record's strings.
79 def TXT(self, *text: str | List[str]) -> Self:
80 for txt in flatten_list(text):
81 self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, txt))
# Add a PTR record pointing at `target` (a dns.name.Name or its text form).
# Used by the reverse-zone helpers. (`return self` is not visible here.)
84 def PTR(self, target: Name | str) -> Self:
85 self._add(dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target))
# Add a CNAME record making this name an alias of `target` (Name or text).
# (`return self` is not visible here.)
88 def CNAME(self, target: Name | str) -> Self:
89 self._add(dns.rdtypes.ANY.CNAME.CNAME(RdataClass.IN, RdataType.CNAME, target))
# Add a record of any type from its textual form. `typ` names the record
# type (e.g. 'SRV') and `text` is the rdata in presentation format; both are
# parsed by dns.rdata.from_text in class IN. (`return self` is not visible.)
92 def generic(self, typ: str, text: str) -> Self:
93 self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
# NscZoneConfig class attributes. The class header and the other field
# annotations are not visible in this excerpt.
# daemon_options: options for the name-server daemon for zones using this
# config (presumably consumed by nsconfig.daemon; confirm).
104 daemon_options: List[str]
# Class-wide fallback used by __init__ when no inherit_config is given.
# It is assigned once, right after the class definition.
107 default_config: Optional['NscZoneConfig'] = None
# NscZoneConfig.__init__ parameters (the `def` line is not visible).
# Every argument left as None is taken from inherit_config, which defaults to
# NscZoneConfig.default_config.
110 admin_email: Optional[str] = None,
111 refresh: Optional[timedelta] = None,
112 retry: Optional[timedelta] = None,
113 expire: Optional[timedelta] = None,
114 min_ttl: Optional[timedelta] = None,
115 origin_server: Optional[str] = None,
116 daemon_options: Optional[List[str]] = None,
117 add_daemon_options: Optional[List[str]] = None,
118 add_null_mx: Optional[bool] = None,
119 inherit_config: Optional['NscZoneConfig'] = None,
# While the default config itself is being built, default_config is still
# None and `self` is used instead. Any field omitted in that call must
# therefore already have a value on the class.
121 if inherit_config is None:
122 inherit_config = NscZoneConfig.default_config or self # to satisfy the type checker
123 self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
124 self.refresh = refresh if refresh is not None else inherit_config.refresh
125 self.retry = retry if retry is not None else inherit_config.retry
126 self.expire = expire if expire is not None else inherit_config.expire
127 self.min_ttl = min_ttl if min_ttl is not None else inherit_config.min_ttl
128 self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
129 self.daemon_options = daemon_options if daemon_options is not None else inherit_config.daemon_options
130 self.add_null_mx = add_null_mx if add_null_mx is not None else inherit_config.add_null_mx
# add_daemon_options is appended to whatever daemon_options resolved to.
# NOTE(review): when daemon_options is None, self.daemon_options is the very
# same list object as inherit_config.daemon_options. `+=` then extends the
# parent's list (possibly the global default's) in place, leaking these
# options into every other config that shares it. The non-mutating
# `self.daemon_options = self.daemon_options + add_daemon_options` avoids this.
131 if add_daemon_options is not None:
132 self.daemon_options += add_daemon_options
# Fill in derived defaults:
# - origin_server falls back to this host's FQDN (socket.getfqdn());
# - admin_email falls back to hostmaster@<origin_server>.
# Applied when a zone is constructed (`NscZoneConfig(**kwargs).finalize()`).
# The trailing `return self` is not visible here.
134 def finalize(self) -> Self:
135 if not self.origin_server:
136 self.origin_server = socket.getfqdn()
137 if not self.admin_email:
138 self.admin_email = f'hostmaster@{self.origin_server}'
# Built-in defaults that other configs inherit from: SOA refresh 8h,
# retry 2h, expire 14 days, and a 1-day minimum TTL.
# Because default_config is still None while this call runs, __init__ falls
# back to `self` for omitted fields.
# (Further arguments and the closing parenthesis are not visible here.)
142 NscZoneConfig.default_config = NscZoneConfig(
144 refresh=timedelta(hours=8),
145 retry=timedelta(hours=2),
146 expire=timedelta(days=14),
147 min_ttl=timedelta(days=1),
# NscZoneState: per-zone data persisted between runs (serial number and
# content hash). The class header and this initialiser's body are not
# visible. gen_serial treats a previous serial <= 0 as "none", so the
# initial serial is presumably 0; confirm.
158 def __init__(self) -> None:
# Load serial and hash from the JSON state file `file`.
# A missing file is tolerated: FileNotFoundError is caught. The handler body
# is not visible; presumably the initial values are kept (first run).
# NOTE(review): the dict check is an `assert`, which `python -O` skips.
162 def load(self, file: Path) -> None:
164 with open(file) as f:
166 assert isinstance(js, dict)
168 self.serial = js['serial']
170 self.hash = js['hash']
171 except FileNotFoundError:
# Write the state as indented, key-sorted JSON to "<file>.new", then rename
# it over `file` with Path.replace. An interrupted write therefore leaves the
# previous state file intact.
174 def save(self, file: Path) -> None:
175 new_file = Path(str(file) + '.new')
176 with open(new_file, 'w') as f:
178 'serial': self.serial,
181 json.dump(js, f, indent=4, sort_keys=True)
182 new_file.replace(file)
# Kind of zone. This module references ZoneType.primary and
# ZoneType.secondary; the member definitions are not visible here.
185 class ZoneType(Enum):
# NscZone (class header not visible): common base of NscZonePrimary and
# NscZoneSecondary.
# safe_name is the zone name with '/' replaced by '@'. Reverse zones for
# non-octet prefixes have names like "64/26.2.1.10.in-addr.arpa".
193 safe_name: str # For use in file names
# reverse_for: the network this zone provides reverse mappings for, or None.
195 reverse_for: Optional[IPNetwork]
# __init__ fragment (most of the signature is not visible). The remaining
# keyword arguments (e.g. inherit_config) build this zone's NscZoneConfig,
# which is finalized immediately.
200 reverse_for: Optional[IPNetwork],
204 self.safe_name = name.replace('/', '@')
205 self.config = NscZoneConfig(**kwargs).finalize()
206 self.reverse_for = reverse_for
# Per-zone processing hook; NscZonePrimary overrides it. The base
# implementation's body is not visible in this excerpt.
208 def process(self) -> None:
# A zone whose records are generated here. It builds a dns.zone.Zone in
# memory, writes it to a zone file, and tracks serial/hash in a JSON state
# file. (Other field annotations are not visible in this excerpt.)
212 class NscZonePrimary(NscZone):
# State loaded from the previous run; compared with `state` to decide
# whether the serial must change.
218 prev_state: NscZoneState
# File paths derive from safe_name:
# - the zone file lives in nsc.zone_dir;
# - the state file "<safe_name>.json" lives in nsc.state_dir.
# The previous run's state is loaded eagerly; the fresh `state` is filled in
# later by gen_hash / gen_serial.
# _min_ttl caches config.min_ttl in whole seconds. It is the starting TTL of
# records added through NscNode, and the TTL that dump() leaves implicit.
220 def __init__(self, *args, **kwargs) -> None:
221 super().__init__(*args, **kwargs)
223 self.zone_type = ZoneType.primary
224 self.zone_file = self.nsc.zone_dir / self.safe_name
225 self.state_file = self.nsc.state_dir / (self.safe_name + '.json')
227 self.state = NscZoneState()
228 self.prev_state = NscZoneState()
229 self.prev_state.load(self.state_file)
231 self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN)
232 self._min_ttl = int(self.config.min_ttl.total_seconds())
# Rebuild the zone's SOA record from the zone config and the current serial.
# Timers come from the config as whole seconds; the admin e-mail becomes the
# SOA RNAME by replacing '@' with '.'.
# NOTE(review): lines 236, 246 and 248-249 are not visible. They are
# presumably `conf = self.config`, the closing parenthesis, and the statement
# that installs `soa` after the old SOA rdataset is deleted.
# NOTE(review) on the FIXME: under RFC 1035 mailbox encoding, a dot in the
# local part (first.last@example.com) must be escaped as "first\.last".
# Otherwise the RNAME designates a different mailbox.
235 def update_soa(self) -> None:
237 soa = dns.rdtypes.ANY.SOA.SOA(
238 RdataClass.IN, RdataType.SOA,
239 mname=conf.origin_server,
240 rname=conf.admin_email.replace('@', '.'), # FIXME: names with dots
241 serial=self.state.serial,
242 refresh=int(conf.refresh.total_seconds()),
243 retry=int(conf.retry.total_seconds()),
244 expire=int(conf.expire.total_seconds()),
245 minimum=int(conf.min_ttl.total_seconds()),
247 self.zone.delete_rdataset("", RdataType.SOA)
def n(self, name: str) -> NscNode:
    """Return a node handle for ``name`` in this zone (short form of ``zone[name]``)."""
    handle = NscNode(self, name)
    return handle
def __getitem__(self, name: str) -> NscNode:
    """Allow ``zone['www']``: wrap (creating if needed) the node ``www`` of this zone."""
    return NscNode(nsc_zone=self, name=name)
# Shorthand: create node `name` and add A/AAAA records for the given
# addresses, queuing reverse mappings unless reverse=False. The final line
# (presumably `return n`) is not visible.
256 def host(self, name: str, *args, reverse: bool = True) -> NscNode:
257 n = NscNode(self, name)
258 n.A(*args, reverse=reverse)
# Comment header placed at the top of the generated zone file. It is also
# fed into gen_hash, so editing this text changes the hash and forces a new
# serial. (The `return (` and closing lines are not visible.)
261 def zone_header(self) -> str:
263 f'; Zone file for {self.name}\n'
264 + '; Generated by NSC, please do not edit manually.\n'
# Write the zone to `file` (default: stdout) in a compact, tab-separated
# layout: the header first, then one line per record.
# - The owner column is presumably blanked when it repeats the previous
#   record's name (lines 271, 274-276 and 278 are not visible).
# - The TTL column is left empty when it equals the zone minimum.
# NOTE(review): an empty TTL field makes the zone-file parser fall back to
# its default TTL. Confirm the output sets a matching $TTL (not visible
# here); otherwise those records may get an unexpected TTL.
267 def dump(self, file: Optional[TextIO] = None) -> None:
268 # Could use self.zone.to_file(sys.stdout), but we want better formatting
269 file = file or sys.stdout
270 file.write(self.zone_header())
272 for name, ttl, rec in self.zone.iterate_rdatas():
273 if name == last_name:
277 file.write(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
280 def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
281 assert isinstance(self.reverse_for, IPv4Network)
282 parts = str(addr).split('.')
283 parts = parts[self.reverse_for.prefixlen // 8:]
284 name = '.'.join(reversed(parts))
285 self.n(name).PTR(ptr_to)
287 def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
288 assert isinstance(self.reverse_for, IPv6Network)
289 parts = addr.exploded.replace(':', "")
290 parts = parts[self.reverse_for.prefixlen // 4:]
291 name = '.'.join(reversed(parts))
292 self.n(name).PTR(ptr_to)
# Fingerprint the zone content: the header text plus every record, rendered
# as tab-separated ASCII lines. The first 16 hex digits are stored in
# state.hash. The hash object is created on a line not visible here (295).
# NOTE(review): if the SOA (which carries the serial) is already in the zone
# when this runs, a serial change would alter the hash. Confirm the call
# order in process() keeps the hash independent of the new serial.
294 def gen_hash(self) -> None:
296 sha.update(self.zone_header().encode('us-ascii'))
297 for name, ttl, rec in self.zone.iterate_rdatas():
298 text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
299 sha.update(text.encode('us-ascii'))
300 self.state.hash = sha.hexdigest()[:16]
# Choose the SOA serial in YYYYMMDDnn form, based on the run's start date.
# - If the content hash matches the previous run and a previous serial
#   exists (> 0), the old serial is kept.
# - Otherwise the serial becomes either today's base + 1 or the previous
#   serial + 1. The lines selecting between them (306, 308, 310) are not
#   visible; presumably base + 1 is used when the previous serial is older
#   than today's base.
# Once today's 99 increments are used up, a warning is printed, because the
# serial then runs into the next day's range.
302 def gen_serial(self) -> None:
303 prev = self.prev_state.serial
304 if self.state.hash == self.prev_state.hash and prev > 0:
305 self.state.serial = self.prev_state.serial
307 base = int(self.nsc.start_time.strftime('%Y%m%d00'))
309 self.state.serial = base + 1
311 self.state.serial = prev + 1
312 if prev >= base + 99:
313 print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')
# Per-zone processing for primary zones. When the config enables
# add_null_mx, null MX records are presumably generated (gen_null_mx).
# Lines 317-320 are not visible.
315 def process(self) -> None:
316 if self.config.add_null_mx:
# Write the zone file through a temporary "<zone_file>.new", then rename it
# over the real file so readers never see a partially written zone.
# (Lines 322 and 325, including the actual write, are not visible.)
321 def write_zone(self) -> None:
323 new_file = Path(str(self.zone_file) + '.new')
324 with open(new_file, 'w') as f:
326 new_file.replace(self.zone_file)
def write_state(self) -> None:
    """Persist this run's NscZoneState to ``self.state_file`` via its ``save`` method."""
    current, destination = self.state, self.state_file
    current.save(destination)
def is_changed(self) -> bool:
    """Tell whether this run assigned a serial different from the stored one."""
    new_serial = self.state.serial
    old_serial = self.prev_state.serial
    return new_serial != old_serial
def delegate_classless(self, net: str | IPNetwork, subdomain: Optional[str] = None) -> NscNode:
    """Set up RFC 2317-style classless delegation of ``net`` in this reverse zone.

    For each value ``i`` of the octet that ``net`` subdivides, a CNAME
    ``i`` -> ``i.<subdomain>`` (relative to this zone) is created.
    ``subdomain`` defaults to ``"<first value>/<prefixlen>"``. The node for
    ``subdomain`` is returned, presumably so the caller can add the NS
    records of the delegation.

    Preconditions, checked with ``assert``:
    - this is an IPv4 reverse zone on an octet boundary;
    - ``net`` is an IPv4 subnet of it whose prefix ends within the next octet.

    NOTE(review): the CNAME owners are single labels directly below the zone
    apex. That matches RFC 2317 for a /24 zone. For shorter zones (e.g. a /16)
    such a label names a whole /24, and a CNAME there does not cover the PTR
    names beneath it; confirm that case is intended.
    """
    net = parse_network(net)
    parent = self.reverse_for
    assert parent is not None
    assert isinstance(parent, IPv4Network)
    assert parent.prefixlen % 8 == 0
    assert isinstance(net, IPv4Network)
    assert net.subnet_of(parent)
    assert net.prefixlen < parent.prefixlen + 8

    octet_index, extra_bits = divmod(net.prefixlen, 8)
    first = net.network_address.packed[octet_index]
    count = 1 << (8 - extra_bits)

    label = subdomain if subdomain is not None else f'{first}/{net.prefixlen}'

    for value in range(first, first + count):
        self[str(value)].CNAME(parse_name(f'{value}.{label}', relative=True))

    return self[label]
# For every node that has A or AAAA records, add a "null MX" record
# (preference 0, exchange "." -- the RFC 7505 form), which declares that the
# name accepts no mail. get_rdataset(..., create=True) creates an empty MX
# rdataset when the node has none.
# NOTE(review): lines 361-362 and 364-366 are not visible. Presumably the
# null MX is only added when that MX rdataset is empty, since RFC 7505
# forbids mixing it with other MX records.
355 def gen_null_mx(self) -> None:
356 for name, node in self.zone.items():
357 rds_a = node.get_rdataset(RdataClass.IN, RdataType.A)
358 rds_aaaa = node.get_rdataset(RdataClass.IN, RdataType.AAAA)
359 if rds_a or rds_aaaa:
360 mx_rds = node.get_rdataset(RdataClass.IN, RdataType.MX, create=True)
363 dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, 0, dns.name.root),
# A zone served as a secondary. Its data comes from `primary_server`
# (presumably by zone transfer) and is stored under nsc.secondary_dir.
368 class NscZoneSecondary(NscZone):
369 primary_server: IPAddress
def __init__(self, *args, primary_server: IPAddress, **kwargs) -> None:
    """Create a secondary zone whose data comes from ``primary_server``.

    Positional and other keyword arguments are forwarded to ``NscZone``.
    The local copy's path is ``nsc.secondary_dir / safe_name``.
    """
    # `primary_server` is a required keyword-only parameter. The former
    # signature `primary_server=IPAddress` used the type alias as a *default
    # value* rather than an annotation. A caller that forgot the argument then
    # silently got a type object as its primary server instead of an error.
    super().__init__(*args, **kwargs)
    self.zone_type = ZoneType.secondary
    self.primary_server = primary_server
    self.secondary_file = self.nsc.secondary_dir / self.safe_name
# Top-level configuration object (its class header is not visible).
# zones: every zone, keyed by zone name.
381 zones: Dict[str, NscZone]
# Zone config inherited by zones added without an explicit inherit_config.
382 default_zone_config: NscZoneConfig
# Reverse-mapping requests (address -> PTR targets), queued by
# _add_reverse_mapping and turned into PTR records by fill_reverse.
383 ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
384 ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
389 daemon: 'NscDaemon' # Set by DaemonConfig class
# Initialiser fragment (the `def` line and some statements are not visible).
392 directory: str = '.',
393 daemon: Optional['NscDaemon'] = None,
# start_time is captured once per run; gen_serial derives the day's serial
# base from it.
# NOTE(review): datetime.now() is naive local time, so serial dates follow
# the host's local date.
395 self.start_time = datetime.now()
397 self.default_zone_config = NscZoneConfig(**kwargs)
398 self.ipv4_reverse = defaultdict(list)
399 self.ipv6_reverse = defaultdict(list)
# Output layout under `directory`, with each subdirectory created if missing:
# - state/: JSON state per zone;
# - zone/: generated zone files;
# - secondary/: secondary-zone data.
401 self.root_dir = Path(directory)
402 self.state_dir = self.root_dir / 'state'
403 self.state_dir.mkdir(parents=True, exist_ok=True)
404 self.zone_dir = self.root_dir / 'zone'
405 self.zone_dir.mkdir(parents=True, exist_ok=True)
406 self.secondary_dir = self.root_dir / 'secondary'
407 self.secondary_dir.mkdir(parents=True, exist_ok=True)
# Presumably taken when no daemon was given (the condition is not visible).
# The local import likely avoids an import cycle with nsconfig.daemon.
410 from nsconfig.daemon import NscDaemonNull
411 daemon = NscDaemonNull()
# Zone-creation method fragment (its `def` line is not visible).
416 name: Optional[str] = None,
417 reverse_for: str | IPNetwork | None = None,
418 follow_primary: str | IPAddress | None = None,
419 inherit_config: Optional[NscZoneConfig] = None,
421 if inherit_config is None:
422 inherit_config = self.default_zone_config
# Reverse zones may omit `name`; it is derived from reverse_for. A string
# network is parsed with strict=True, so host bits must be zero.
424 if reverse_for is not None:
425 if isinstance(reverse_for, str):
426 reverse_for = ip_network(reverse_for, strict=True)
427 name = name or self._reverse_zone_name(reverse_for)
# NOTE(review): these checks are asserts and vanish under `python -O`.
# Raising ValueError would be more robust.
428 assert name is not None
429 assert name not in self.zones
# With follow_primary unset, a primary zone is created. Otherwise a secondary
# zone is created, with follow_primary (a string is parsed as an IP address)
# as its primary server.
432 if follow_primary is None:
433 z = NscZonePrimary(self, name, reverse_for=reverse_for, inherit_config=inherit_config, **kwargs)
435 if isinstance(follow_primary, str):
436 follow_primary = ip_address(follow_primary)
437 z = NscZoneSecondary(self, name, reverse_for=reverse_for, primary_server=follow_primary, inherit_config=inherit_config, **kwargs)
442 def _reverse_zone_name(self, net: IPNetwork) -> str:
443 if isinstance(net, IPv4Network):
444 parts = str(net.network_address).split('.')
445 out = parts[:net.prefixlen // 8]
446 if net.prefixlen % 8 != 0:
447 out.append(parts[len(out)] + '/' + str(net.prefixlen))
448 return '.'.join(reversed(out)) + '.in-addr.arpa'
449 elif isinstance(net, IPv6Network):
450 assert net.prefixlen % 4 == 0
451 nibbles = net.network_address.exploded.replace(':', "")
452 nibbles = nibbles[:net.prefixlen // 4]
453 return '.'.join(reversed(nibbles)) + '.ip6.arpa'
455 raise NotImplementedError()
457 def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
458 if isinstance(addr, IPv4Address):
459 self.ipv4_reverse[addr].append(ptr_to)
461 self.ipv6_reverse[addr].append(ptr_to)
def dump_reverse(self) -> None:
    """Print all queued reverse-mapping requests to stdout.

    IPv4 requests come first, then IPv6, each in ascending address order.
    Each line is the address and its list of target names, tab-separated.
    """
    print('### Requests for reverse mappings ###')
    for table in (self.ipv4_reverse, self.ipv6_reverse):
        for address in sorted(table):
            print(f'{address}\t{table[address]}')
def fill_reverse(self) -> None:
    """Turn queued reverse-mapping requests into PTR records.

    Each NscZonePrimary with a ``reverse_for`` network receives a PTR record
    for every requested (address, target) pair whose address lies inside that
    network. Other zones are skipped.
    """
    for zone in self.zones.values():
        if not isinstance(zone, NscZonePrimary) or zone.reverse_for is None:
            continue
        net = zone.reverse_for
        if isinstance(net, IPv4Network):
            requests, add_ptr = self.ipv4_reverse, zone._add_ipv4_reverse
        else:
            requests, add_ptr = self.ipv6_reverse, zone._add_ipv6_reverse
        for addr, targets in requests.items():
            if addr not in net:
                continue
            for target in targets:
                add_ptr(addr, target)
def get_zones(self) -> 'List[NscZone]':
    """Return all registered zones, ordered by zone name."""
    return [zone for _, zone in sorted(self.zones.items(), key=lambda item: item[0])]
# Run-level processing over every zone, in zone-name order (get_zones).
# The loop body and any later steps are not visible in this excerpt.
487 def process(self) -> None:
489 for z in self.get_zones():