1 from collections import defaultdict
2 from datetime import datetime, timedelta
4 from dns.name import Name
5 from dns.node import Node
6 from dns.rdata import Rdata
7 from dns.rdataclass import RdataClass
8 from dns.rdatatype import RdataType
9 import dns.rdtypes.ANY.CNAME
10 import dns.rdtypes.ANY.MX
11 import dns.rdtypes.ANY.NS
12 import dns.rdtypes.ANY.PTR
13 import dns.rdtypes.ANY.SOA
14 import dns.rdtypes.ANY.TXT
15 import dns.rdtypes.IN.A
16 import dns.rdtypes.IN.AAAA
17 from dns.zone import Zone
18 from enum import Enum, auto
20 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
22 from pathlib import Path
25 from typing import Optional, Dict, List, Self, DefaultDict, TextIO, TYPE_CHECKING
27 from nsconfig.util import flatten_list, parse_address, parse_network, parse_name
28 from nsconfig.util import IPAddress, IPNetwork, IPAddr
32 from nsconfig.daemon import NscDaemon
nsc_zone: 'NscZonePrimary'  # zone this builder adds records to

def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None:
    """Bind a record builder to the node *name* of *nsc_zone*.

    The node is looked up in the zone's dnspython ``Zone`` and created when
    it does not exist yet. Records added through this builder start out with
    the zone's minimum TTL.
    """
    dns_zone = nsc_zone.zone
    self.nsc_zone = nsc_zone
    self.name = name
    self.node = dns_zone.find_node(name, create=True)
    self._ttl = nsc_zone._min_ttl
47 def ttl(self, *args, **kwargs) -> Self:
48 if not args and not kwargs:
49 self._ttl = self.nsc_zone._min_ttl
51 self._ttl = int(timedelta(*args, **kwargs).total_seconds())
def _add(self, rec: 'Rdata') -> None:
    """Store *rec* on this node, using the builder's current TTL.

    The rdataset for the record's class and type is created on first use.
    """
    target_set = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
    target_set.add(rec, ttl=self._ttl)
def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
    """Add address records for every address in *addrs*.

    Each argument may be a single address or a (possibly nested) list of
    addresses; they are flattened and parsed with ``parse_address``.
    IPv4 addresses become A records, all other addresses AAAA records.

    When *reverse* is true, each address is also passed to
    ``nsc._add_reverse_mapping`` together with this node's absolute name,
    so that a PTR record can be generated for it later.

    Returns the builder itself so calls can be chained.
    """
    zone_name = self.nsc_zone.name
    if self.name in ('', '@'):
        # Zone apex: '@.example.com' or '.example.com' would be wrong/invalid.
        fqdn = zone_name
    elif self.name.endswith('.'):
        # Already absolute; appending the zone name again would be wrong.
        fqdn = self.name
    else:
        fqdn = self.name + '.' + zone_name
    for a in map(parse_address, flatten_list(addrs)):
        if isinstance(a, IPv4Address):
            self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
        else:
            self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
        if reverse:
            self.nsc_zone.nsc._add_reverse_mapping(a, parse_name(fqdn))
    return self
def MX(self, pri: int, name: str) -> Self:
    """Add an MX record with preference *pri* that points at mail host *name*.

    Returns the builder itself so calls can be chained.
    """
    exchange = parse_name(name)
    record = dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, exchange)
    self._add(record)
    return self
def NS(self, *names: str | List[str]) -> Self:
    """Add one NS record for each name in *names* (nested lists allowed).

    Returns the builder itself so calls can be chained.
    """
    for raw_name in flatten_list(names):
        server = parse_name(raw_name)
        self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, server))
    return self
def TXT(self, *text: str | List[str]) -> Self:
    """Add one TXT record per string in *text* (nested lists allowed).

    A TXT record holds a sequence of character-strings, each at most 255
    bytes long. Each input string is UTF-8 encoded and cut into 255-byte
    pieces that together form a single record. Long values such as DKIM
    keys are therefore stored as intended. Strings of up to 255 bytes
    produce exactly the same record as before.

    Returns the builder itself so calls can be chained.
    """
    for txt in flatten_list(text):
        data = txt.encode('utf-8') if isinstance(txt, str) else txt
        # dnspython rejects a single character-string over 255 bytes. Given a
        # bare str, it then falls back to iterating it, which would store one
        # string per character. Passing explicit chunks avoids that. An empty
        # input still yields a single empty string, as before.
        chunks = tuple(data[i:i + 255] for i in range(0, len(data), 255)) or (b'',)
        self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, chunks))
    return self
def PTR(self, target: Name | str) -> Self:
    """Add a PTR record pointing at *target* (a dnspython Name or its text form).

    Returns the builder itself so calls can be chained.
    """
    pointer = dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target)
    self._add(pointer)
    return self
def CNAME(self, target: Name | str) -> Self:
    """Make this node an alias (CNAME) of *target* (a dnspython Name or text).

    Returns the builder itself so calls can be chained.
    """
    alias = dns.rdtypes.ANY.CNAME.CNAME(RdataClass.IN, RdataType.CNAME, target)
    self._add(alias)
    return self
def generic(self, typ: str, text: str) -> Self:
    """Add a record of type *typ* parsed from its zone-file text form *text*.

    This covers record types that have no dedicated helper method.
    Returns the builder itself so calls can be chained.
    """
    parsed = dns.rdata.from_text(RdataClass.IN, typ, text)
    self._add(parsed)
    return self
daemon_options: List[str]  # options handed through to the DNS daemon configuration
default_config: Optional['NscZoneConfig'] = None  # fallback used when no inherit_config is given

def __init__(self,
    admin_email: Optional[str] = None,
    refresh: Optional[timedelta] = None,
    retry: Optional[timedelta] = None,
    expire: Optional[timedelta] = None,
    min_ttl: Optional[timedelta] = None,
    origin_server: Optional[str] = None,
    daemon_options: Optional[List[str]] = None,
    add_daemon_options: Optional[List[str]] = None,
    inherit_config: Optional['NscZoneConfig'] = None,
) -> None:
    """Build a zone configuration.

    Every setting passed as ``None`` is copied from *inherit_config*. When
    that is also ``None``, it is copied from ``NscZoneConfig.default_config``.
    *daemon_options* replaces the inherited daemon options.
    *add_daemon_options* is appended to whichever list results.
    """
    if inherit_config is None:
        # Falls back to self only while default_config is still None (i.e. while
        # the default config itself is being built); `or self` also keeps the
        # type checker happy with a non-Optional value.
        inherit_config = NscZoneConfig.default_config or self
    self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
    self.refresh = refresh if refresh is not None else inherit_config.refresh
    self.retry = retry if retry is not None else inherit_config.retry
    self.expire = expire if expire is not None else inherit_config.expire
    self.min_ttl = min_ttl if min_ttl is not None else inherit_config.min_ttl
    self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
    base_options = daemon_options if daemon_options is not None else inherit_config.daemon_options
    # Copy the list: it belongs to the caller or to the inherited config (often
    # the shared default), and the `+=` below would otherwise extend it in place.
    self.daemon_options = list(base_options)
    if add_daemon_options is not None:
        self.daemon_options += add_daemon_options
131 def finalize(self) -> Self:
132 if not self.origin_server:
133 self.origin_server = socket.getfqdn()
134 if not self.admin_email:
135 self.admin_email = f'hostmaster@{self.origin_server}'
# Module-wide fallback settings: an NscZoneConfig built without inherit_config
# copies every unset option from NscZoneConfig.default_config.
# SOA timers (used by update_soa): refresh 8 h, retry 2 h, expire 14 days,
# minimum TTL 1 day.
139 NscZoneConfig.default_config = NscZoneConfig(
141 refresh=timedelta(hours=8),
142 retry=timedelta(hours=2),
143 expire=timedelta(days=14),
144 min_ttl=timedelta(days=1),
# Create an empty per-zone state (serial number and content hash).
# NOTE(review): gen_serial compares prev_state.serial with `> 0` and adds to it,
# so the initial serial is presumably an int such as 0 — confirm in the body.
154 def __init__(self) -> None:
# Read the serial and hash recorded by a previous run from the JSON *file*.
# FileNotFoundError is caught, so a first run presumably keeps the values set
# by __init__.
158 def load(self, file: Path) -> None:
160 with open(file) as f:
# NOTE(review): assert-based validation is skipped under `python -O`.
162 assert isinstance(js, dict)
164 self.serial = js['serial']
166 self.hash = js['hash']
167 except FileNotFoundError:
def save(self, file: Path) -> None:
    """Persist ``serial`` and ``hash`` to *file* as pretty-printed JSON.

    The data is first written to ``<file>.new``, which then replaces *file*
    in a single ``Path.replace`` call. An interrupted write therefore never
    leaves a truncated state file behind.
    """
    staging = Path(f'{file}.new')
    with open(staging, 'w') as out:
        payload = {
            'serial': self.serial,
            'hash': self.hash,
        }
        json.dump(payload, out, indent=4, sort_keys=True)
    staging.replace(file)
# Kind of zone: NscZonePrimary sets its zone_type to ZoneType.primary,
# NscZoneSecondary sets it to ZoneType.secondary.
181 class ZoneType(Enum):
# File-system-safe variant of the zone name: '/' (which appears in RFC 2317
# style classless reverse names such as '64/26.2.0.192.in-addr.arpa') is
# replaced by '@' so the name can be used for zone and state file names.
189 safe_name: str # For use in file names
# Network this zone holds reverse (PTR) records for, or None for a forward zone.
191 reverse_for: Optional[IPNetwork]
196 reverse_for: Optional[IPNetwork],
200 self.safe_name = name.replace('/', '@')
# Remaining keyword arguments (e.g. inherit_config) build this zone's config;
# finalize() fills in the origin-server and admin-e-mail defaults.
201 self.config = NscZoneConfig(**kwargs).finalize()
202 self.reverse_for = reverse_for
# Per-zone processing hook; NscZonePrimary overrides it.
# NOTE(review): the base implementation's body is not shown here.
204 def process(self) -> None:
# Zone whose records are built by NSC and written out as a zone file
# (zone_file), with serial/hash bookkeeping kept in state_file.
208 class NscZonePrimary(NscZone):
# State saved by the previous run; compared with `state` to decide whether
# the serial has to change.
214 prev_state: NscZoneState
216 def __init__(self, *args, **kwargs) -> None:
217 super().__init__(*args, **kwargs)
219 self.zone_type = ZoneType.primary
220 self.zone_file = self.nsc.zone_dir / self.safe_name
221 self.state_file = self.nsc.state_dir / (self.safe_name + '.json')
# `state` is filled in during this run (gen_hash / gen_serial);
# `prev_state` holds what the last run saved to state_file.
223 self.state = NscZoneState()
224 self.prev_state = NscZoneState()
225 self.prev_state.load(self.state_file)
227 self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN)
# Minimum TTL in whole seconds: the starting TTL of every NscNode builder and
# the TTL that dump() leaves out of the TTL column.
228 self._min_ttl = int(self.config.min_ttl.total_seconds())
# Build the SOA record for the zone apex from the zone config and the serial
# currently held in self.state.
# NOTE(review): `conf` is presumably self.config (its assignment is not shown here).
231 def update_soa(self) -> None:
233 soa = dns.rdtypes.ANY.SOA.SOA(
234 RdataClass.IN, RdataType.SOA,
235 mname=conf.origin_server,
236 rname=conf.admin_email.replace('@', '.'), # FIXME: a '.' in the local part must be escaped as '\.' so it stays one label (john.doe@example.com -> john\.doe.example.com)
237 serial=self.state.serial,
238 refresh=int(conf.refresh.total_seconds()),
239 retry=int(conf.retry.total_seconds()),
240 expire=int(conf.expire.total_seconds()),
241 minimum=int(conf.min_ttl.total_seconds()),
# Drop any existing SOA at the apex ("" is the zone origin) before the new
# record is presumably stored in its place.
243 self.zone.delete_rdataset("", RdataType.SOA)
def n(self, name: str) -> NscNode:
    """Return a record builder for node *name* of this zone (created if missing)."""
    builder = NscNode(self, name)
    return builder
def __getitem__(self, name: str) -> NscNode:
    """Support ``zone[name]`` as shorthand for a record builder at *name*."""
    builder = NscNode(self, name)
    return builder
def host(self, name: str, *args, reverse: bool = True) -> NscNode:
    """Create node *name* with A/AAAA records for the addresses in *args*.

    *args* and *reverse* are forwarded to ``NscNode.A``; *reverse* controls
    whether PTR records are requested. Returns the node for further chaining.
    """
    node = NscNode(self, name)
    node.A(*args, reverse=reverse)
    return node
def zone_header(self) -> str:
    """Return the comment banner that starts the generated zone file.

    The banner is also part of the content hash computed by gen_hash.
    """
    title = f'; Zone file for {self.name}\n'
    notice = '; Generated by NSC, please do not edit manually.\n'
    return title + notice
# Write the zone in zone-file format to *file* (default: sys.stdout),
# starting with zone_header(). The TTL column is left empty for records whose
# TTL equals the zone minimum (self._min_ttl).
263 def dump(self, file: Optional[TextIO] = None) -> None:
264 # Could use self.zone.to_file(sys.stdout), but we want better formatting
265 file = file or sys.stdout
266 file.write(self.zone_header())
268 for name, ttl, rec in self.zone.iterate_rdatas():
# Presumably blanks the owner name when it repeats the previous record's
# (zone-file shorthand); the lines that set print_name/last_name are not shown.
269 if name == last_name:
273 file.write(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: 'Name') -> None:
    """Add a PTR record for IPv4 *addr* to this reverse zone, pointing at *ptr_to*.

    The owner name is relative to the zone. It consists of the octets not
    covered by whole bytes of the zone prefix, in reverse order: for a /24
    zone just the last octet, for a /16 zone "d.c".
    """
    assert isinstance(self.reverse_for, IPv4Network)
    covered = self.reverse_for.prefixlen // 8
    octets = str(addr).split('.')[covered:]
    octets.reverse()
    self.n('.'.join(octets)).PTR(ptr_to)
def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: 'Name') -> None:
    """Add a PTR record for IPv6 *addr* to this reverse zone, pointing at *ptr_to*.

    The owner name is relative to the zone. It consists of the hex nibbles of
    the fully expanded address that follow the zone prefix, in reverse order
    and separated by dots.
    """
    assert isinstance(self.reverse_for, IPv6Network)
    nibbles = addr.exploded.replace(':', "")
    host_part = nibbles[self.reverse_for.prefixlen // 4:]
    self.n('.'.join(host_part[::-1])).PTR(ptr_to)
# Fingerprint the generated zone content (header plus every record, as text)
# and store the first 16 hex digits in self.state.hash; gen_serial compares it
# with the previous run's hash.
# NOTE(review): if an SOA record is already present in self.zone here, its
# serial is hashed too — confirm the call order keeps the hash independent of
# the serial.
290 def gen_hash(self) -> None:
292 sha.update(self.zone_header().encode('us-ascii'))
293 for name, ttl, rec in self.zone.iterate_rdatas():
294 text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
295 sha.update(text.encode('us-ascii'))
296 self.state.hash = sha.hexdigest()[:16]
# Choose the SOA serial for this run.
# Unchanged content (same hash) with a previous non-zero serial keeps that
# serial. Otherwise serials follow the date-based YYYYMMDDnn scheme: `base` is
# the run date (nsc.start_time) with nn = 00, and the new serial is either
# base + 1 or prev + 1.
298 def gen_serial(self) -> None:
299 prev = self.prev_state.serial
300 if self.state.hash == self.prev_state.hash and prev > 0:
301 self.state.serial = self.prev_state.serial
303 base = int(self.nsc.start_time.strftime('%Y%m%d00'))
305 self.state.serial = base + 1
307 self.state.serial = prev + 1
# More than 99 changes on one date push the serial past that date's nn = 99 limit.
308 if prev >= base + 99:
309 print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')
# Run the generation steps for this primary zone.
# NOTE(review): __init__ always sets zone_type to ZoneType.primary, so this
# check only matters if a subclass changes it; the steps themselves are not
# shown here.
311 def process(self) -> None:
312 if self.zone_type == ZoneType.primary:
# Write the zone file: output goes to '<zone_file>.new', which then replaces
# zone_file in one Path.replace call (atomic on POSIX when both paths are on
# the same file system), so a half-written zone file is never left in place.
316 def write_zone(self) -> None:
318 new_file = Path(str(self.zone_file) + '.new')
319 with open(new_file, 'w') as f:
321 new_file.replace(self.zone_file)
def write_state(self) -> None:
    """Persist this run's zone state to ``state_file`` for the next run."""
    current, destination = self.state, self.state_file
    current.save(destination)
def is_changed(self) -> bool:
    """Return True when this run's serial differs from the previous run's."""
    same_serial = self.state.serial == self.prev_state.serial
    return not same_serial
def delegate_classless(self, net: str | IPNetwork, subdomain: Optional[str] = None) -> NscNode:
    """Delegate part of this IPv4 reverse zone in RFC 2317 (classless) style.

    For every value *i* of the octet that *net* splits, a CNAME
    ``<i>`` -> ``<i>.<subdomain>`` is added to this zone. *subdomain*
    defaults to ``<first value>/<net prefix length>`` (e.g. ``64/26``).
    The node for *subdomain* is returned, for instance so the caller can
    add the NS records of the delegated sub-zone.

    Raises:
        ValueError: if this zone is not an IPv4 reverse zone on an octet
            boundary, or *net* is not an IPv4 subnet of it that is shorter
            than the zone prefix + 8 bits.
    """
    net = parse_network(net)
    rev = self.reverse_for
    # Explicit checks instead of `assert`: asserts are stripped under
    # `python -O`, and bad input would then silently produce wrong CNAMEs.
    if not isinstance(rev, IPv4Network):
        raise ValueError(f'Zone {self.name} is not an IPv4 reverse zone')
    if rev.prefixlen % 8 != 0:
        raise ValueError(f'Zone {self.name} does not end on an octet boundary ({rev})')
    if not isinstance(net, IPv4Network):
        raise ValueError(f'{net} is not an IPv4 network')
    if not net.subnet_of(rev):
        raise ValueError(f'{net} is not inside {rev}')
    if net.prefixlen >= rev.prefixlen + 8:
        raise ValueError(f'{net} is too small to delegate from {rev}')

    # Value of the octet being split, and how many values *net* covers in it.
    start = net.network_address.packed[net.prefixlen // 8]
    count = 1 << (8 - net.prefixlen % 8)

    if subdomain is None:
        subdomain = f'{start}/{net.prefixlen}'

    for i in range(start, start + count):
        self[str(i)].CNAME(parse_name(f'{i}.{subdomain}', relative=True))

    return self[subdomain]
class NscZoneSecondary(NscZone):
    """Zone served as a secondary copy of the zone held by ``primary_server``."""

    primary_server: IPAddress
    secondary_file: Path

    def __init__(self, *args, primary_server: IPAddress, **kwargs) -> None:
        # `primary_server` is a required keyword-only argument: there is no
        # meaningful default address. (Writing `primary_server=IPAddress` would
        # make the type alias itself the default value.)
        super().__init__(*args, **kwargs)
        self.zone_type = ZoneType.secondary
        self.primary_server = primary_server
        self.secondary_file = self.nsc.secondary_dir / self.safe_name
# All zones, keyed by zone name.
364 zones: Dict[str, NscZone]
# Config inherited by zones created without an explicit inherit_config.
365 default_zone_config: NscZoneConfig
# PTR requests collected via _add_reverse_mapping (called by NscNode.A with
# reverse=True), keyed by address; fill_reverse() turns them into PTR records.
366 ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
367 ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
372 daemon: 'NscDaemon' # Set by DaemonConfig class
375 directory: str = '.',
376 daemon: Optional['NscDaemon'] = None,
# Naive local time; NscZonePrimary.gen_serial derives the date part of serials from it.
378 self.start_time = datetime.now()
# Remaining keyword arguments become the default zone config.
380 self.default_zone_config = NscZoneConfig(**kwargs)
381 self.ipv4_reverse = defaultdict(list)
382 self.ipv6_reverse = defaultdict(list)
# Output layout under *directory*: state/, zone/ and secondary/, created if missing.
384 self.root_dir = Path(directory)
385 self.state_dir = self.root_dir / 'state'
386 self.state_dir.mkdir(parents=True, exist_ok=True)
387 self.zone_dir = self.root_dir / 'zone'
388 self.zone_dir.mkdir(parents=True, exist_ok=True)
389 self.secondary_dir = self.root_dir / 'secondary'
390 self.secondary_dir.mkdir(parents=True, exist_ok=True)
# Fallback daemon when none is supplied; imported locally rather than at module
# level (presumably to avoid a circular import with nsconfig.daemon).
393 from nsconfig.daemon import NscDaemonNull
394 daemon = NscDaemonNull()
# Parameters of the zone factory method (its `def` line is not shown here).
# *reverse_for* makes this a reverse zone whose name defaults to one derived
# from the network; *follow_primary* makes it a secondary zone following that
# primary server instead of a primary zone.
399 name: Optional[str] = None,
400 reverse_for: str | IPNetwork | None = None,
401 follow_primary: str | IPAddress | None = None,
402 inherit_config: Optional[NscZoneConfig] = None,
404 if inherit_config is None:
405 inherit_config = self.default_zone_config
407 if reverse_for is not None:
# strict=True rejects networks with host bits set (e.g. '192.0.2.1/24').
408 if isinstance(reverse_for, str):
409 reverse_for = ip_network(reverse_for, strict=True)
410 name = name or self._reverse_zone_name(reverse_for)
# NOTE(review): these asserts validate caller input and disappear under `python -O`.
411 assert name is not None
412 assert name not in self.zones
415 if follow_primary is None:
416 z = NscZonePrimary(self, name, reverse_for=reverse_for, inherit_config=inherit_config, **kwargs)
418 if isinstance(follow_primary, str):
419 follow_primary = ip_address(follow_primary)
420 z = NscZoneSecondary(self, name, reverse_for=reverse_for, primary_server=follow_primary, inherit_config=inherit_config, **kwargs)
def _reverse_zone_name(self, net: 'IPNetwork') -> str:
    """Return the reverse-DNS zone name for *net*.

    IPv4 networks produce one label per whole octet of the prefix under
    ``in-addr.arpa``. A prefix that does not end on an octet boundary gets
    an RFC 2317 style ``<octet>/<prefixlen>`` first label, for example
    ``64/26.2.0.192.in-addr.arpa``. IPv6 networks produce one label per
    nibble of the prefix under ``ip6.arpa``.

    Raises:
        ValueError: for an IPv6 prefix length that is not a multiple of 4.
        NotImplementedError: for anything other than an IPv4/IPv6 network.
    """
    if isinstance(net, IPv4Network):
        octets = str(net.network_address).split('.')
        labels = octets[:net.prefixlen // 8]
        if net.prefixlen % 8 != 0:
            labels.append(octets[len(labels)] + '/' + str(net.prefixlen))
        suffix = ['in-addr', 'arpa']
    elif isinstance(net, IPv6Network):
        if net.prefixlen % 4 != 0:
            raise ValueError(f'IPv6 reverse zones need a prefix length divisible by 4, got {net}')
        nibbles = net.network_address.exploded.replace(':', "")
        labels = list(nibbles[:net.prefixlen // 4])
        suffix = ['ip6', 'arpa']
    else:
        raise NotImplementedError()
    # Join the suffix as ordinary labels so that a /0 network yields
    # 'in-addr.arpa' / 'ip6.arpa' rather than an invalid leading-dot name.
    return '.'.join(labels[::-1] + suffix)
def _add_reverse_mapping(self, addr: 'IPAddress', ptr_to: 'Name') -> None:
    """Record that *addr* should get a PTR record pointing at *ptr_to*.

    Requests are grouped by address family and later placed into the
    matching reverse zones by ``fill_reverse``.
    """
    table = self.ipv4_reverse if isinstance(addr, IPv4Address) else self.ipv6_reverse
    table[addr].append(ptr_to)
def dump_reverse(self) -> None:
    """Print every requested reverse mapping, IPv4 first, each family sorted by address."""
    print('### Requests for reverse mappings ###')
    for table in (self.ipv4_reverse, self.ipv6_reverse):
        for address, names in sorted(table.items()):
            print(f'{address}\t{names}')
def fill_reverse(self) -> None:
    """Add the collected PTR requests to each primary reverse zone covering them."""
    for zone in self.zones.values():
        if not isinstance(zone, NscZonePrimary) or zone.reverse_for is None:
            continue
        network = zone.reverse_for
        if isinstance(network, IPv4Network):
            requests, add_ptr = self.ipv4_reverse, zone._add_ipv4_reverse
        else:
            requests, add_ptr = self.ipv6_reverse, zone._add_ipv6_reverse
        for address, targets in requests.items():
            if address not in network:
                continue
            for target in targets:
                add_ptr(address, target)
def get_zones(self) -> List['NscZone']:
    """Return all zones ordered by zone name."""
    ordered = sorted(self.zones.items(), key=lambda item: item[0])
    return [zone for _, zone in ordered]
470 def process(self) -> None:
472 for z in self.get_zones():