1 from collections import defaultdict
2 from datetime import datetime, timedelta
4 from dns.name import Name
5 from dns.node import Node
6 from dns.rdata import Rdata
7 from dns.rdataclass import RdataClass
8 from dns.rdatatype import RdataType
9 import dns.rdtypes.ANY.MX
10 import dns.rdtypes.ANY.NS
11 import dns.rdtypes.ANY.PTR
12 import dns.rdtypes.ANY.SOA
13 import dns.rdtypes.ANY.TXT
14 import dns.rdtypes.IN.A
15 import dns.rdtypes.IN.AAAA
16 from dns.zone import Zone
17 from enum import Enum, auto
19 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
21 from pathlib import Path
24 from typing import Optional, Dict, List, Self, Tuple, DefaultDict, TextIO, TYPE_CHECKING
26 from nsconfig.util import flatten_list
30 from nsconfig.daemon import NscDaemon
# Type aliases shared by this module.
# IPAddress: one parsed address of either family.
33 IPAddress = IPv4Address | IPv6Address
# IPNetwork: one parsed network of either family (used for reverse zones).
34 IPNetwork = IPv4Network | IPv6Network
# IPAddr: what address-taking APIs such as NscNode.A accept -- a string, a
# parsed address, or a list of those (NscNode.A flattens lists before parsing).
35 IPAddr = str | IPAddress | List[str | IPAddress]
# Owning primary zone; NscNode reads its `zone`, `_min_ttl`, `nsc` and `name`.
39 nsc_zone: 'NscZonePrimary'
# Handle for one owner name inside a primary zone. The owner is looked up in
# the zone's dnspython Zone and created if absent; records added through this
# handle start with the zone's minimum TTL (`_min_ttl`, in seconds).
# NOTE(review): NscNode.A reads `self.name`, which is not assigned on these
# lines -- presumably set from `name` right after nsc_zone; confirm.
44 def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None:
45 self.nsc_zone = nsc_zone
47 self.node = nsc_zone.zone.find_node(name, create=True)
48 self._ttl = nsc_zone._min_ttl
def ttl(self, *args, **kwargs) -> Self:
    """Set the TTL used for records subsequently added through this node.

    Called with no arguments, the TTL is reset to the zone's minimum TTL
    (``nsc_zone._min_ttl``).  Otherwise every argument is forwarded to
    :class:`datetime.timedelta`, and the duration is stored in whole seconds
    (``int`` truncates fractions).

    Returns ``self`` so calls can be chained.
    """
    if args or kwargs:
        duration = timedelta(*args, **kwargs)
        self._ttl = int(duration.total_seconds())
    else:
        self._ttl = self.nsc_zone._min_ttl
    return self
def _add(self, rec: Rdata) -> None:
    """Store *rec* in this node's rdataset of the same class and type.

    The rdataset is created on demand, and the record is added with the
    node's current TTL (as set by ``ttl``).
    """
    self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True).add(rec, ttl=self._ttl)
def _parse_addr(self, addr: str | IPAddress) -> IPAddress:
    """Convert one address argument into an ``ipaddress`` object.

    ``IPv4Address`` / ``IPv6Address`` instances are returned unchanged.
    Strings are parsed with :func:`ipaddress.ip_address`, which raises
    ``ValueError`` for malformed text.

    Raises:
        ValueError: *addr* is of any other type -- including a list the
            caller forgot to flatten.  The message names the offending value.
    """
    if isinstance(addr, (IPv4Address, IPv6Address)):
        return addr
    if isinstance(addr, str):
        return ip_address(addr)
    raise ValueError(f'Cannot parse IP address: {addr!r}')
# Convert a textual owner/target name into a dns.name.Name.
# The two visible branches both call dns.name.from_text: one with its default
# origin (the root, giving an absolute name) and one with origin=None (giving
# a relative name, presumably interpreted relative to the zone -- confirm).
# NOTE(review): the condition that picks the branch is not on these lines; the
# FIXME suggests it looks at dots in `name`, which escaped dots would fool.
69 def _parse_name(self, name: str) -> Name:
70 # FIXME: Names with escaped dots
72 return dns.name.from_text(name)
74 return dns.name.from_text(name, origin=None)
# Add an A record for every IPv4 address and an AAAA record for every IPv6
# address. Arguments may be strings, ipaddress objects or lists of them
# (flattened with flatten_list, parsed with _parse_addr).
# Each address is also registered with the owning NscConfig through
# _add_reverse_mapping, so that fill_reverse can later emit a matching PTR.
# The PTR target is built from this node's name and the zone name.
# NOTE(review): `reverse` is not read on the visible lines; presumably an
# `if reverse:` guard precedes the registration -- confirm.
# NOTE(review): for an apex node ('@' or '') the target text becomes
# '@.<zone>' or '.<zone>', which is not the zone apex.
76 def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
77 for a in map(self._parse_addr, flatten_list(addrs)):
78 if isinstance(a, IPv4Address):
79 self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
81 self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
83 self.nsc_zone.nsc._add_reverse_mapping(a, dns.name.from_text(self.name + '.' + self.nsc_zone.name))
def MX(self, pri: int, name: str) -> Self:
    """Add an MX record with preference *pri* whose exchange is *name*.

    *name* is converted with ``_parse_name``.  Returns ``self`` for chaining.
    """
    exchange = self._parse_name(name)
    record = dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, exchange)
    self._add(record)
    return self
def NS(self, *names: str | List[str]) -> Self:
    """Add one NS record per name; lists are flattened first.

    Each name is converted with ``_parse_name``.  Returns ``self``.
    """
    for target in flatten_list(names):
        server = self._parse_name(target)
        self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, server))
    return self
def TXT(self, *text: str | List[str]) -> Self:
    """Add one TXT record per string; lists are flattened first.

    A DNS character-string holds at most 255 octets, and dnspython rejects
    longer ones.  A longer string (e.g. a DKIM public key) is therefore
    UTF-8 encoded and split into consecutive 255-octet chunks.  The chunks
    are stored as the character-strings of a single TXT record, which
    consumers such as SPF/DKIM concatenate.  Strings of at most 255 octets
    are passed through exactly as before.

    Returns ``self`` for chaining.
    """
    max_octets = 255
    for txt in flatten_list(text):
        value = txt
        if isinstance(txt, str):
            encoded = txt.encode()
            if len(encoded) > max_octets:
                value = [encoded[i:i + max_octets] for i in range(0, len(encoded), max_octets)]
        self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, value))
    return self
def PTR(self, target: Name | str) -> Self:
    """Add a PTR record pointing at *target*.  Returns ``self``."""
    record = dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target)
    self._add(record)
    return self
def generic(self, typ: str, text: str) -> Self:
    """Add a record of type *typ* parsed from its zone-file rdata *text*.

    This is an escape hatch for record types without a dedicated helper.
    Returns ``self``.
    """
    parsed = dns.rdata.from_text(RdataClass.IN, typ, text)
    self._add(parsed)
    return self
# Class-level fallback used when no inherit_config is given; it is assigned
# after the class body (NscZoneConfig.default_config = NscZoneConfig(...)).
119 default_config: Optional['NscZoneConfig'] = None
# Constructor parameters: every setting left as None is inherited from
# `inherit_config`, or from NscZoneConfig.default_config when that is None.
# Durations are timedelta objects.
122 admin_email: Optional[str] = None,
123 refresh: Optional[timedelta] = None,
124 retry: Optional[timedelta] = None,
125 expire: Optional[timedelta] = None,
126 min_ttl: Optional[timedelta] = None,
127 origin_server: Optional[str] = None,
128 inherit_config: Optional['NscZoneConfig'] = None,
# While default_config itself is being built it is still None, so
# inherit_config falls back to `self`; any setting not passed explicitly is
# then read from `self` before being assigned.
# NOTE(review): that only works if matching class-level defaults exist on
# lines not shown here -- otherwise AttributeError; confirm.
130 if inherit_config is None:
131 inherit_config = NscZoneConfig.default_config or self # to satisfy the type checker
132 self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
133 self.refresh = refresh if refresh is not None else inherit_config.refresh
134 self.retry = retry if retry is not None else inherit_config.retry
135 self.expire = expire if expire is not None else inherit_config.expire
136 self.min_ttl = min_ttl if min_ttl is not None else inherit_config.min_ttl
137 self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
def finalize(self) -> Self:
    """Fill in settings that still have no (truthy) value.

    ``origin_server`` defaults to this host's FQDN (``socket.getfqdn()``).
    ``admin_email`` then defaults to ``hostmaster@<origin_server>``.
    Returns ``self`` so the call can follow the constructor directly.
    """
    self.origin_server = self.origin_server or socket.getfqdn()
    self.admin_email = self.admin_email or f'hostmaster@{self.origin_server}'
    return self
# Module-wide default zone settings, inherited by any config that is not
# given another one: refresh 8 h, retry 2 h, expire 14 days, minimum TTL 1 day.
# NOTE(review): the call continues past these lines (any further arguments
# and the closing parenthesis) -- confirm what else it sets.
147 NscZoneConfig.default_config = NscZoneConfig(
149 refresh=timedelta(hours=8),
150 retry=timedelta(hours=2),
151 expire=timedelta(days=14),
152 min_ttl=timedelta(days=1),
# Persisted per-zone state: `serial` and `hash` (see load/save).
# NOTE(review): the constructor body is not visible; gen_serial treats a
# serial of 0 or less as "no previous serial", so presumably these start at
# 0 and an empty hash -- confirm.
161 def __init__(self) -> None:
# Load `serial` and `hash` from the state file written by save() (JSON).
# The top-level value must be an object -- checked with assert, which is
# skipped under `python -O`. A missing file raises FileNotFoundError, which is
# caught here, so a zone without a state file starts from the constructor
# defaults (the handler body is not visible -- confirm).
# NOTE(review): open() uses the locale's default encoding.
165 def load(self, file: Path) -> None:
167 with open(file) as f:
169 assert isinstance(js, dict)
171 self.serial = js['serial']
173 self.hash = js['hash']
174 except FileNotFoundError:
# Write the state as pretty-printed JSON (sorted keys) to "<file>.new", then
# rename it over `file` with Path.replace. On POSIX the rename is atomic, so
# readers see either the old or the new file, never a partial one.
# NOTE(review): there is no fsync before the rename, so the new contents may
# not be durable after a crash; open() uses the locale's default encoding.
177 def save(self, file: Path) -> None:
178 new_file = Path(str(file) + '.new')
179 with open(new_file, 'w') as f:
181 'serial': self.serial,
184 json.dump(js, f, indent=4, sort_keys=True)
185 new_file.replace(file)
# Kind of zone: NscZonePrimary sets ZoneType.primary and NscZoneSecondary
# sets ZoneType.secondary (member definitions are not visible here).
188 class ZoneType(Enum):
# --- NscZone: common base of primary and secondary zones ---
# Zone name with '/' replaced by '@'. Classless reverse-zone names (as built
# by NscConfig._reverse_zone_name) contain '/', which cannot appear in file
# names under the zone/, state/ and secondary/ directories.
196 safe_name: str # For use in file names
# Network this zone provides reverse DNS for; None for a forward zone.
198 reverse_for: Optional[IPNetwork]
# Constructor fragment: the remaining keyword arguments build this zone's
# NscZoneConfig, which is finalized immediately (origin server / admin email
# defaults are filled in).
203 reverse_for: Optional[IPNetwork],
207 self.safe_name = name.replace('/', '@')
208 self.config = NscZoneConfig(**kwargs).finalize()
209 self.reverse_for = reverse_for
# Per-zone processing hook; NscZonePrimary overrides it.
211 def process(self) -> None:
# Zone whose data nsconfig generates itself: records are collected in a
# dnspython Zone, hashed, given a date-based serial and written to zone_file.
215 class NscZonePrimary(NscZone):
# State saved by the previous run (compared by gen_serial and is_changed).
221 prev_state: NscZoneState
# Set up output paths, load the previous run's state and create an empty
# dnspython zone for this origin. `_min_ttl` is the config's min_ttl in whole
# seconds; it is the default TTL of new records and the TTL that dump()
# leaves implicit.
223 def __init__(self, *args, **kwargs) -> None:
224 super().__init__(*args, **kwargs)
226 self.zone_type = ZoneType.primary
227 self.zone_file = self.nsc.zone_dir / self.safe_name
228 self.state_file = self.nsc.state_dir / (self.safe_name + '.json')
# Fresh state for this run; prev_state is loaded from disk for comparison.
230 self.state = NscZoneState()
231 self.prev_state = NscZoneState()
232 self.prev_state.load(self.state_file)
234 self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN)
235 self._min_ttl = int(self.config.min_ttl.total_seconds())
# Build the SOA record from the zone config (`conf`, presumably self.config)
# and the current serial, with all timers in whole seconds. Then remove any
# existing SOA rdataset at the zone apex ("" names the origin).
# NOTE(review): adding `soa` back to the zone is not on the visible lines --
# confirm it follows the delete.
# NOTE(review): rname turns every '@' into '.', but a '.' in the mailbox's
# local part must be escaped as '\.' (RFC 1035) -- see the FIXME.
238 def update_soa(self) -> None:
240 soa = dns.rdtypes.ANY.SOA.SOA(
241 RdataClass.IN, RdataType.SOA,
242 mname=conf.origin_server,
243 rname=conf.admin_email.replace('@', '.'), # FIXME: names with dots
244 serial=self.state.serial,
245 refresh=int(conf.refresh.total_seconds()),
246 retry=int(conf.retry.total_seconds()),
247 expire=int(conf.expire.total_seconds()),
248 minimum=int(conf.min_ttl.total_seconds()),
250 self.zone.delete_rdataset("", RdataType.SOA)
def n(self, name: str) -> NscNode:
    """Return a new NscNode handle for owner *name* in this zone."""
    node = NscNode(self, name)
    return node
def __getitem__(self, name: str) -> NscNode:
    """Subscript shorthand: ``zone['www']`` returns an NscNode for ``'www'``."""
    return NscNode(nsc_zone=self, name=name)
def host(self, name: str, *args, reverse: bool = True) -> NscNode:
    """Create node *name* and give it A/AAAA records for the addresses in *args*.

    *reverse* is passed through to ``NscNode.A``.  Returns the node so more
    records can be added to it.
    """
    node = NscNode(nsc_zone=self, name=name)
    node.A(*args, reverse=reverse)
    return node
# Write the zone in zone-file syntax to `file` (sys.stdout by default), one
# tab-separated line per record: owner, TTL, type, rdata. The TTL column is
# left empty when it equals the zone's minimum TTL.
# NOTE(review): an empty TTL makes the name server fall back to its default
# ($TTL, or the SOA minimum on some servers); confirm a matching $TTL line is
# written on lines not visible here.
# NOTE(review): `name` is compared with `last_name`, presumably to blank out
# repeated owner names -- confirm.
264 def dump(self, file: Optional[TextIO] = None) -> None:
265 # Could use self.zone.to_file(sys.stdout), but we want better formatting
266 file = file or sys.stdout
267 file.write(f'; Zone file for {self.name}\n\n')
269 for name, ttl, rec in self.zone.iterate_rdatas():
270 if name == last_name:
274 file.write(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
    """Add a PTR record *addr* -> *ptr_to* inside this IPv4 reverse zone.

    The owner name consists of the octets of *addr* that the zone prefix
    does not cover (the first ``prefixlen // 8`` octets are dropped), in
    reverse order.  For example, 192.0.2.10 in a /24 zone gets owner ``10``.
    """
    assert isinstance(self.reverse_for, IPv4Network)
    skip = self.reverse_for.prefixlen // 8
    host_octets = str(addr).split('.')[skip:]
    host_octets.reverse()
    self.n('.'.join(host_octets)).PTR(ptr_to)
def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
    """Add a PTR record *addr* -> *ptr_to* inside this ip6.arpa zone.

    *addr* is expanded to its 32 hex nibbles.  The ``prefixlen // 4``
    nibbles covered by the zone are dropped, and the remaining nibbles are
    joined with dots in reverse order to form the owner name.
    """
    assert isinstance(self.reverse_for, IPv6Network)
    nibbles = addr.exploded.replace(':', '')
    host_nibbles = nibbles[self.reverse_for.prefixlen // 4:]
    owner = '.'.join(host_nibbles[::-1])
    self.n(owner).PTR(ptr_to)
# Fingerprint the zone contents. Every record (owner, TTL, type, rdata) is
# rendered as a text line and fed to the hash object `sha`. The first 16 hex
# digits go to state.hash, which gen_serial compares with the previous run.
# NOTE(review): `sha` is created on a line not visible here -- confirm the
# algorithm. Text is encoded as US-ASCII, which assumes to_text() escapes
# any non-ASCII bytes (otherwise this raises).
# NOTE(review): if the SOA (which carries the serial) is already in the zone
# when this runs, the hash changes whenever the serial does -- confirm the
# call order in process().
291 def gen_hash(self) -> None:
293 for name, ttl, rec in self.zone.iterate_rdatas():
294 text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
295 sha.update(text.encode('us-ascii'))
296 self.state.hash = sha.hexdigest()[:16]
def gen_serial(self) -> None:
    """Choose this run's SOA serial using the date-based ``YYYYMMDDnn`` scheme.

    * Same content hash as last run and a previous serial exists (> 0):
      keep the previous serial.
    * Otherwise, if the previous serial is below today's base
      ``YYYYMMDD00`` (taken from ``nsc.start_time``): use base + 1.
    * Otherwise increment the previous serial.  A warning is printed when
      the increment leaves today's 00-99 range.
    """
    prev = self.prev_state.serial
    if prev > 0 and self.state.hash == self.prev_state.hash:
        self.state.serial = prev
        return
    base = int(self.nsc.start_time.strftime('%Y%m%d00'))
    if prev >= base:
        self.state.serial = prev + 1
        if prev >= base + 99:
            print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')
    else:
        self.state.serial = base + 1
# Per-run processing of a primary zone. The steps guarded by this check are
# on lines not visible here (presumably gen_hash / gen_serial / update_soa --
# confirm).
# NOTE(review): __init__ always sets zone_type to ZoneType.primary, so this
# check appears to be always true for this class.
311 def process(self) -> None:
312 if self.zone_type == ZoneType.primary:
# Rewrite the zone file via a temporary "<zone_file>.new", which is renamed
# over zone_file with Path.replace (atomic on POSIX).
# NOTE(review): the write inside the `with` block is not visible (presumably
# self.dump(f)); open() uses the locale's default encoding.
316 def write_zone(self) -> None:
318 new_file = Path(str(self.zone_file) + '.new')
319 with open(new_file, 'w') as f:
321 new_file.replace(self.zone_file)
def write_state(self) -> None:
    """Persist this run's state to the zone's ``.json`` state file."""
    state, path = self.state, self.state_file
    state.save(path)
def is_changed(self) -> bool:
    """Return True when this run's serial differs from the previously saved one."""
    unchanged = self.state.serial == self.prev_state.serial
    return not unchanged
# Zone served as a secondary. nsconfig records which primary to follow and
# where the local copy lives (secondary_dir/<safe_name>).
330 class NscZoneSecondary(NscZone):
# Address of the primary server this zone follows.
331 primary_server: IPAddress
def __init__(self, *args, primary_server: IPAddress | str, **kwargs) -> None:
    """Create a secondary zone that follows *primary_server*.

    Args:
        *args, **kwargs: forwarded unchanged to ``NscZone.__init__``.
        primary_server: keyword-only and required; the address of the
            primary to follow.  A string is parsed with
            ``ipaddress.ip_address``.

    Raises:
        TypeError: *primary_server* is neither an address nor a string.
        ValueError: *primary_server* is a malformed address string.
    """
    # This was declared `primary_server=IPAddress`: the type alias was the
    # parameter's *default value*, not its annotation.  Omitting the argument
    # therefore silently stored a typing object.  Require it and validate it
    # before any state is set up.
    if isinstance(primary_server, str):
        primary_server = ip_address(primary_server)
    elif not isinstance(primary_server, (IPv4Address, IPv6Address)):
        raise TypeError(f'primary_server must be an IP address, got {primary_server!r}')
    super().__init__(*args, **kwargs)
    self.zone_type = ZoneType.secondary
    self.primary_server = primary_server
    # Local copy of the zone lives under the config's secondary/ directory.
    self.secondary_file = self.nsc.secondary_dir / self.safe_name
# --- NscConfig: top-level configuration owning all zones ---
# All zones, keyed by zone name (get_zones returns them sorted by this key).
343 zones: Dict[str, NscZone]
# Config that new zones inherit from unless another one is passed.
344 default_zone_config: NscZoneConfig
# PTR requests collected from A/AAAA definitions: address -> names it should
# point back to. fill_reverse turns them into records in reverse zones.
345 ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
346 ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
351 daemon: 'NscDaemon' # Set by DaemonConfig class
# Constructor fragment. `directory` is the root under which the state/, zone/
# and secondary/ subdirectories are created if missing. The remaining keyword
# arguments form the default zone config.
354 directory: str = '.',
355 daemon: Optional['NscDaemon'] = None,
# One timestamp for the whole run (naive local time); gen_serial takes the
# serial's date from it.
357 self.start_time = datetime.now()
359 self.default_zone_config = NscZoneConfig(**kwargs)
360 self.ipv4_reverse = defaultdict(list)
361 self.ipv6_reverse = defaultdict(list)
363 self.root_dir = Path(directory)
364 self.state_dir = self.root_dir / 'state'
365 self.state_dir.mkdir(parents=True, exist_ok=True)
366 self.zone_dir = self.root_dir / 'zone'
367 self.zone_dir.mkdir(parents=True, exist_ok=True)
368 self.secondary_dir = self.root_dir / 'secondary'
369 self.secondary_dir.mkdir(parents=True, exist_ok=True)
# Fallback daemon, presumably used when `daemon` is None (the guard is not
# visible). The local import presumably avoids an import cycle with
# nsconfig.daemon -- confirm.
372 from nsconfig.daemon import NscDaemonNull
373 daemon = NscDaemonNull()
# Zone factory fragment: creates an NscZonePrimary, or an NscZoneSecondary
# when `follow_primary` is given (a string is parsed with ip_address).
# For a reverse zone, `reverse_for` may be a string, which is parsed with
# ip_network(strict=True) (ValueError if host bits are set). The zone name
# then defaults to the computed in-addr.arpa / ip6.arpa name.
# The zone config inherits from default_zone_config unless one is passed.
# NOTE(review): the two asserts validate caller input (missing name,
# duplicate zone) but are removed under `python -O`; raising ValueError would
# be safer.
378 name: Optional[str] = None,
379 reverse_for: str | IPNetwork | None = None,
380 follow_primary: str | IPAddress | None = None,
381 inherit_config: Optional[NscZoneConfig] = None,
383 if inherit_config is None:
384 inherit_config = self.default_zone_config
386 if reverse_for is not None:
387 if isinstance(reverse_for, str):
388 reverse_for = ip_network(reverse_for, strict=True)
389 name = name or self._reverse_zone_name(reverse_for)
390 assert name is not None
391 assert name not in self.zones
394 if follow_primary is None:
395 z = NscZonePrimary(self, name, reverse_for=reverse_for, inherit_config=inherit_config, **kwargs)
397 if isinstance(follow_primary, str):
398 follow_primary = ip_address(follow_primary)
399 z = NscZoneSecondary(self, name, reverse_for=reverse_for, primary_server=follow_primary, inherit_config=inherit_config, **kwargs)
def _reverse_zone_name(self, net: IPNetwork) -> str:
    """Return the reverse-DNS zone name for network *net*.

    IPv4 (``in-addr.arpa``): the octets fully covered by the prefix, in
    reverse order.  A prefix that is not a multiple of 8 adds an RFC 2317
    style label ``<octet>/<prefixlen>``; for example, 192.0.2.64/26 maps to
    ``64/26.2.0.192.in-addr.arpa``.

    IPv6 (``ip6.arpa``): one label per covered nibble, in reverse order.
    The prefix length must be a multiple of 4.

    A zero-length prefix maps to plain ``in-addr.arpa`` / ``ip6.arpa``.

    Raises:
        ValueError: IPv6 prefix length is not a multiple of 4.
        NotImplementedError: *net* is neither an IPv4 nor an IPv6 network.
    """
    if isinstance(net, IPv4Network):
        octets = str(net.network_address).split('.')
        labels = octets[:net.prefixlen // 8]
        if net.prefixlen % 8 != 0:
            labels.append(f'{octets[len(labels)]}/{net.prefixlen}')
        suffix = ['in-addr', 'arpa']
    elif isinstance(net, IPv6Network):
        # Explicit check instead of `assert`: asserts vanish under `python -O`,
        # and a non-nibble prefix would then silently produce a zone that
        # does not cover the whole network.
        if net.prefixlen % 4 != 0:
            raise ValueError(
                f'IPv6 reverse zone needs a prefix length divisible by 4, got /{net.prefixlen}')
        labels = list(net.network_address.exploded.replace(':', '')[:net.prefixlen // 4])
        suffix = ['ip6', 'arpa']
    else:
        raise NotImplementedError()
    # Join one label list so an empty prefix does not leave a leading dot.
    return '.'.join([*reversed(labels), *suffix])
def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
    """Record that *addr* should get a PTR record pointing at *ptr_to*.

    Requests are grouped by address family; ``fill_reverse`` later turns them
    into PTR records in the matching reverse zones.
    """
    table = self.ipv4_reverse if isinstance(addr, IPv4Address) else self.ipv6_reverse
    table[addr].append(ptr_to)
def dump_reverse(self) -> None:
    """Print the collected reverse-mapping requests to stdout.

    Output is a header, then one line per address: IPv4 first, then IPv6,
    each sorted by address.  A line holds the address, a tab, and the
    requested PTR targets separated by ", ".  Targets are printed with
    ``str()``, so dnspython names appear in zone-file text form rather than
    as a list of ``repr`` values.
    """
    print('### Requests for reverse mappings ###')
    for table in (self.ipv4_reverse, self.ipv6_reverse):
        for addr in sorted(table):
            targets = ', '.join(str(target) for target in table[addr])
            print(f'{addr}\t{targets}')
def fill_reverse(self) -> None:
    """Turn the collected reverse-mapping requests into PTR records.

    For every primary zone that serves reverse DNS, each requested address of
    the zone's family that lies inside the zone's network gets one PTR per
    requested target.
    """
    for zone in self.zones.values():
        if not isinstance(zone, NscZonePrimary) or zone.reverse_for is None:
            continue
        net = zone.reverse_for
        if isinstance(net, IPv4Network):
            requests, add_ptr = self.ipv4_reverse, zone._add_ipv4_reverse
        else:
            requests, add_ptr = self.ipv6_reverse, zone._add_ipv6_reverse
        for addr, targets in requests.items():
            if addr not in net:
                continue
            for target in targets:
                add_ptr(addr, target)
def get_zones(self) -> List[NscZone]:
    """Return all configured zones, ordered by zone name."""
    return [zone for _, zone in sorted(self.zones.items(), key=lambda item: item[0])]
449 def process(self) -> None:
451 for z in self.get_zones():