1 from collections import defaultdict
2 from datetime import datetime, timedelta
4 from dns.name import Name
5 from dns.node import Node
6 from dns.rdata import Rdata
7 from dns.rdataclass import RdataClass
8 from dns.rdatatype import RdataType
9 import dns.rdtypes.ANY.MX
10 import dns.rdtypes.ANY.NS
11 import dns.rdtypes.ANY.PTR
12 import dns.rdtypes.ANY.SOA
13 import dns.rdtypes.ANY.TXT
14 import dns.rdtypes.IN.A
15 import dns.rdtypes.IN.AAAA
16 from dns.zone import Zone
17 from enum import Enum, auto
19 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
21 from pathlib import Path
24 from typing import Optional, Dict, List, Self, Tuple, DefaultDict, TextIO
27 IPAddress = IPv4Address | IPv6Address
28 IPNetwork = IPv4Network | IPv6Network
29 IPAddr = str | IPAddress | List[str | IPAddress]
def __init__(self, nsc_zone: 'NscZone', name: str) -> None:
    """Bind this helper to the node called *name* inside *nsc_zone*.

    Args:
        nsc_zone: Owning zone wrapper; its ``zone`` and ``_min_ttl``
            attributes are read here.
        name: Node name, handed to ``find_node`` unchanged.
    """
    self.nsc_zone = nsc_zone
    # NOTE(review): this assignment is not present in the extract; it was
    # restored because A() reads ``self.name`` — confirm against upstream.
    self.name = name
    # create=True: the node is created on first reference.
    self.node = nsc_zone.zone.find_node(name, create=True)
    # Records start out with the zone's default TTL until ttl() changes it.
    self._ttl = nsc_zone._min_ttl
44 def ttl(self, *args, **kwargs) -> Self:
45 if not args and not kwargs:
46 self._ttl = self.nsc_zone._min_ttl
48 self._ttl = int(timedelta(*args, **kwargs).total_seconds())
def _add(self, rec: Rdata) -> None:
    """Attach *rec* to this node using the currently selected TTL.

    The rdataset matching the record's class and type is looked up on
    ``self.node`` (created when absent) and the record is added to it.
    """
    rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
    rds.add(rec, ttl=self._ttl)
def _parse_addrs(self, addrs: Tuple[IPAddr, ...]) -> List[IPAddress]:
    """Flatten *addrs* into a list of parsed address objects.

    Each item may be a textual address, an already parsed address, or a
    list mixing both.  Parsed addresses are passed through unchanged;
    everything else goes through ``ipaddress.ip_address``, which raises
    ``ValueError`` for text that is not a valid address.

    NOTE(review): the accumulator, the two ``for`` headers and the
    ``else:``/``return`` lines are missing from the extract and were
    reconstructed around the visible statements — confirm against upstream.
    """
    out = []
    for item in addrs:
        group = item if isinstance(item, list) else [item]
        for addr in group:
            if isinstance(addr, (IPv4Address, IPv6Address)):
                out.append(addr)
            else:
                out.append(ip_address(addr))
    return out
# Convert the textual *name* into a Name via dns.name.from_text.
# Two calls are visible: one relying on from_text's default origin and one
# passing origin=None explicitly.
# NOTE(review): original lines 69 and 71 are absent from this extract —
# presumably the condition that selects between the two calls and its
# `else:`. The condition cannot be determined from here; confirm upstream.
67 def _parse_name(self, name: str) -> Name:
68 # FIXME: Names with escaped dots
70 return dns.name.from_text(name)
72 return dns.name.from_text(name, origin=None)
def _parse_names(self, names: str | List[str]) -> List[Name]:
    """Parse one name or a list of names with ``_parse_name``.

    A bare string is treated as a single name, not as an iterable of
    characters.  The result is always a list, in input order.
    """
    if isinstance(names, str):
        names = [names]
    return [self._parse_name(n) for n in names]
def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
    """Add an A record for every IPv4 address and an AAAA for every IPv6 one.

    Args:
        *addrs: Addresses in any form accepted by ``_parse_addrs``.
        reverse: When true, each address is also registered with the
            top-level object via ``_add_reverse_mapping`` together with
            this node's fully qualified name.

    Returns:
        ``self``, so calls can be chained.

    NOTE(review): the ``else:``, ``if reverse:`` and ``return self`` lines
    are missing from the extract and were reconstructed from the signature
    and the visible statements — confirm against upstream.
    """
    for a in self._parse_addrs(addrs):
        if isinstance(a, IPv4Address):
            self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
        else:
            self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
        if reverse:
            # Resolve the node name against the zone origin instead of gluing
            # strings together: concatenation yields ".zone" (an empty-label
            # error) for the apex node "" and a literal "@" label for "@".
            origin = dns.name.from_text(self.nsc_zone.name)
            owner = dns.name.from_text(self.name, origin=origin)
            self.nsc_zone.nsc._add_reverse_mapping(a, owner)
    return self
def MX(self, pri: int, name: str) -> Self:
    """Add an MX record with preference *pri* pointing at *name*.

    *name* is converted with ``_parse_name`` before it is stored.

    Returns:
        ``self``, so calls can be chained.
    """
    self._add(
        dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, self._parse_name(name))
    )
    return self
def NS(self, names: str | List[str]) -> Self:
    """Add one NS record per name in *names* (a single name or a list).

    Returns:
        ``self``, so calls can be chained.
    """
    # NOTE(review): original line 97 is absent from the extract; nothing in
    # the visible code depends on it — confirm against upstream.
    for name in self._parse_names(names):
        self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
    return self
def TXT(self, text: str) -> Self:
    """Add a TXT record whose content is *text*.

    Returns:
        ``self``, so calls can be chained.
    """
    self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, text))
    return self
def PTR(self, target: Name | str) -> Self:
    """Add a PTR record pointing at *target*.

    *target* is handed to the PTR rdata constructor as given; no local
    name parsing is applied here.

    Returns:
        ``self``, so calls can be chained.
    """
    self._add(dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target))
    return self
def generic(self, typ: str, text: str) -> Self:
    """Add a record of arbitrary type from its textual form.

    Args:
        typ: Record type, passed to ``dns.rdata.from_text`` as given.
        text: Record data in the text form ``dns.rdata.from_text`` parses.

    Returns:
        ``self``, so calls can be chained.
    """
    self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
    return self
# Fallback configuration consulted by NscZoneConfig.__init__ when no
# inherit_config is supplied; None until one is installed.
default_config: Optional['NscZoneConfig'] = None
def __init__(self,
             admin_email: Optional[str] = None,
             refresh: Optional[timedelta] = None,
             retry: Optional[timedelta] = None,
             expire: Optional[timedelta] = None,
             min_ttl: Optional[timedelta] = None,
             origin_server: Optional[str] = None,
             inherit_config: Optional['NscZoneConfig'] = None,
             ) -> None:
    """Build a configuration, filling unspecified fields from a parent.

    Every field given as ``None`` is copied from *inherit_config*.  When
    *inherit_config* itself is ``None`` the class-level ``default_config``
    is used.  Values that are falsy but not ``None`` (for example ``""``)
    count as specified and are kept.

    NOTE(review): the ``def __init__(self,`` and ``) -> None:`` lines are
    missing from the extract and were reconstructed; whether the original
    signature contains anything else (e.g. ``*``) cannot be told from here.
    """
    if inherit_config is None:
        inherit_config = NscZoneConfig.default_config or self  # to satisfy the type checker
    self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
    self.refresh = refresh if refresh is not None else inherit_config.refresh
    self.retry = retry if retry is not None else inherit_config.retry
    self.expire = expire if expire is not None else inherit_config.expire
    self.min_ttl = min_ttl if min_ttl is not None else inherit_config.min_ttl
    self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
143 def finalize(self) -> Self:
144 if not self.origin_server:
145 self.origin_server = socket.getfqdn()
146 if not self.admin_email:
147 self.admin_email = f'hostmaster@{self.origin_server}'
# Install the class-wide default configuration: refresh 8 hours, retry
# 2 hours, expire 14 days, minimum TTL 1 day.
# NOTE(review): original lines 152 and 157 and the closing parenthesis are
# absent from this extract; any further arguments given there (presumably
# admin_email / origin_server) cannot be determined from here.
151 NscZoneConfig.default_config = NscZoneConfig(
153 refresh=timedelta(hours=8),
154 retry=timedelta(hours=2),
155 expire=timedelta(days=14),
156 min_ttl=timedelta(days=1),
# Constructor of the per-zone state object.
# NOTE(review): the body (original lines 166-167) is absent from this
# extract. Other methods in the file read and write `self.serial` and
# `self.hash`, so presumably both are initialised here — the initial values
# cannot be determined; confirm upstream.
165 def __init__(self) -> None:
# Read previously saved state from *file*: the loaded object must be a dict,
# and its 'serial' and 'hash' entries are copied onto this instance. A
# missing file is handled by the FileNotFoundError clause.
# NOTE(review): original lines 170, 172, 174, 176 and 179 are absent from
# this extract (presumably `try:`, the statement that binds `js`, and the
# body of the except clause). Whether the two assignments are guarded by
# conditions on lines 174/176 cannot be determined from here.
169 def load(self, file: Path) -> None:
171 with open(file) as f:
173 assert isinstance(js, dict)
175 self.serial = js['serial']
177 self.hash = js['hash']
178 except FileNotFoundError:
def save(self, file: Path) -> None:
    """Write the state to *file* as JSON.

    The data goes to ``<file>.new`` first and is then moved over *file*
    with ``Path.replace``, so *file* is only ever swapped as a whole.

    NOTE(review): the dict literal's opening/closing lines and its second
    entry are missing from the extract; the ``'hash'`` key was reconstructed
    from what ``load`` reads back — confirm against upstream.
    """
    new_file = Path(str(file) + '.new')
    js = {
        'serial': self.serial,
        'hash': self.hash,
    }
    with open(new_file, 'w') as f:
        json.dump(js, f, indent=4, sort_keys=True)
    new_file.replace(file)
# Enumeration distinguishing the kinds of zone handled by this file.
# NOTE(review): the member definitions are absent from this extract; the
# rest of the file references ZoneType.primary and ZoneType.secondary.
192 class ZoneType(Enum):
# Attribute declarations of the zone wrapper.
# NOTE(review): further declarations between these lines are absent from
# the extract; only the visible ones are restored here.
safe_name: str  # For use in file names
reverse_for: Optional[IPNetwork]  # Network this zone holds PTR records for, if any
primary_server: Optional[IPAddress]  # For secondary zones
prev_state: NscZoneState  # State loaded from the zone's state file
def __init__(self,
             nsc: 'Nsc',
             name: Optional[str] = None,
             reverse_for: str | IPNetwork | None = None,
             secondary_for: str | IPAddress | None = None,
             **kwargs) -> None:
    """Create a primary or secondary zone attached to *nsc*.

    Args:
        nsc: Top-level object; supplies the zone/state/secondary directories.
        name: Zone name.  May be omitted for reverse zones, in which case
            it is derived from *reverse_for*.
        reverse_for: Network (object or text) this zone provides reverse
            records for.  Text is parsed strictly, so host bits must be zero.
        secondary_for: Address of the primary server.  When given, the zone
            is a secondary zone and no zone data or state is set up.
        **kwargs: Forwarded to ``NscZoneConfig``.

    NOTE(review): the ``def`` line, the ``nsc`` parameter, the ``**kwargs``
    line, the ``self.nsc``/``self.name`` assignments and the ``else:`` are
    missing from the extract.  They were reconstructed from how the rest of
    the file uses them (``nsc.zone_dir``, ``self.nsc.start_time``,
    ``self.name``, ``NscZone(self, ...)``) — confirm against upstream.
    """
    if reverse_for is not None:
        if isinstance(reverse_for, str):
            reverse_for = ip_network(reverse_for, strict=True)
        name = name or self._reverse_zone_name(reverse_for)
    assert name is not None

    if isinstance(secondary_for, str):
        secondary_for = ip_address(secondary_for)

    self.nsc = nsc
    self.name = name
    # '/' appears in names of classless reverse zones and cannot be used in
    # a file name.
    self.safe_name = name.replace('/', '@')
    self.config = NscZoneConfig(**kwargs).finalize()
    self.reverse_for = reverse_for
    self.primary_server = secondary_for

    if not secondary_for:
        self.zone_type = ZoneType.primary
        self.zone_file = nsc.zone_dir / self.safe_name
        self.state_file = nsc.state_dir / (self.safe_name + '.json')

        # Fresh state for this run, plus whatever the previous run saved.
        self.state = NscZoneState()
        self.prev_state = NscZoneState()
        self.prev_state.load(self.state_file)

        self.zone = dns.zone.Zone(origin=name, rdclass=RdataClass.IN)
        self._min_ttl = int(self.config.min_ttl.total_seconds())
    else:
        self.zone_type = ZoneType.secondary
        self.zone_file = nsc.secondary_dir / self.safe_name
def _reverse_zone_name(self, net: IPNetwork) -> str:
    """Return the name of the reverse zone covering *net*.

    IPv4: the whole octets of the prefix, reversed, under ``in-addr.arpa``.
    A prefix that ends inside an octet adds a leading
    ``<octet>/<prefixlen>`` label.  IPv6: the prefix nibbles, reversed,
    under ``ip6.arpa``.

    A prefix length of 0 yields the bare suffix (``in-addr.arpa`` /
    ``ip6.arpa``) rather than a name with a leading dot.

    Raises:
        ValueError: IPv6 prefix length is not a multiple of 4.  (Previously
            an ``assert``, which is skipped under ``python -O``.)
        NotImplementedError: *net* is neither an IPv4 nor an IPv6 network.
    """
    if isinstance(net, IPv4Network):
        octets = str(net.network_address).split('.')
        whole, partial_bits = divmod(net.prefixlen, 8)
        labels = octets[:whole]
        if partial_bits:
            labels.append(f'{octets[whole]}/{net.prefixlen}')
        suffix = 'in-addr.arpa'
    elif isinstance(net, IPv6Network):
        if net.prefixlen % 4 != 0:
            raise ValueError(
                f'IPv6 reverse zone prefix must be a multiple of 4 bits, got /{net.prefixlen}'
            )
        nibbles = net.network_address.exploded.replace(':', "")
        labels = list(nibbles[:net.prefixlen // 4])
        suffix = 'ip6.arpa'
    else:
        raise NotImplementedError()
    # Joining labels and suffix together avoids a leading '.' when there are
    # no labels at all.
    return '.'.join([*reversed(labels), suffix])
# Build the zone's SOA record from the zone configuration and the current
# serial, after removing any existing SOA rdataset at the zone apex ("").
# Only valid for primary zones (asserted).
# NOTE(review): original lines 266, 276 and 278 are absent from this
# extract — presumably the binding of `conf`, the closing parenthesis of the
# SOA(...) call, and the statement that stores the new `soa` in the zone.
# How the record is stored cannot be determined from here.
264 def update_soa(self) -> None:
265 assert self.zone_type == ZoneType.primary
267 soa = dns.rdtypes.ANY.SOA.SOA(
268 RdataClass.IN, RdataType.SOA,
269 mname=conf.origin_server,
# The admin mailbox is written in domain-name form: '@' becomes '.'.
270 rname=conf.admin_email.replace('@', '.'), # FIXME: names with dots
271 serial=self.state.serial,
# The timing fields are timedeltas in the config; SOA takes whole seconds.
272 refresh=int(conf.refresh.total_seconds()),
273 retry=int(conf.retry.total_seconds()),
274 expire=int(conf.expire.total_seconds()),
275 minimum=int(conf.min_ttl.total_seconds()),
277 self.zone.delete_rdataset("", RdataType.SOA)
def n(self, name: str) -> NscNode:
    """Return a record helper for the node *name* in this zone.

    A new ``NscNode`` is built on every call.
    """
    return NscNode(self, name)
def __getitem__(self, name: str) -> NscNode:
    """Subscript form of ``n()``: ``zone[name]`` returns a new ``NscNode``."""
    return NscNode(self, name)
def host(self, name: str, *args, reverse: bool = True) -> NscNode:
    """Create the node *name* and add address records for *args*.

    Args:
        name: Node name inside this zone.
        *args: Addresses, forwarded to ``NscNode.A``.
        reverse: Forwarded to ``NscNode.A``.

    Returns:
        The node helper, so further records can be chained onto it.

    NOTE(review): the ``return`` line is missing from the extract and was
    reconstructed from the declared return type — confirm against upstream.
    """
    n = NscNode(self, name)
    n.A(*args, reverse=reverse)
    return n
def dump(self, file: Optional[TextIO] = None) -> None:
    """Write the zone in a tab-separated text form to *file*.

    Args:
        file: Text stream to write to; ``sys.stdout`` when ``None``.

    Each record is one line ``name<TAB>ttl<TAB>type<TAB>data``.  The TTL
    column is left empty when it equals the zone's ``_min_ttl``.

    NOTE(review): the bookkeeping around ``last_name``/``print_name`` is
    missing from the extract.  It was reconstructed as "leave the name
    column empty when it repeats the previous line's name" — confirm
    against upstream.
    """
    # Could use self.zone.to_file(sys.stdout), but we want better formatting
    if file is None:
        file = sys.stdout
    file.write(f'; Zone file for {self.name}\n\n')
    last_name = None
    for name, ttl, rec in self.zone.iterate_rdatas():
        print_name = "" if name == last_name else name
        file.write(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
        last_name = name
def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
    """Add a PTR record for *addr* pointing at *ptr_to* to this reverse zone.

    The node name consists of the octets of *addr* that are not covered by
    whole octets of the zone's prefix, in reverse order.
    """
    # Called only for addresses from this reverse network
    assert self.reverse_for is not None
    octets = str(addr).split('.')
    host_part = octets[self.reverse_for.prefixlen // 8:]
    self.n('.'.join(reversed(host_part))).PTR(ptr_to)
def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
    """Add a PTR record for *addr* pointing at *ptr_to* to this reverse zone.

    The node name consists of the hex nibbles of *addr* that lie beyond the
    zone's prefix, in reverse order.
    """
    # Called only for addresses from this reverse network
    assert self.reverse_for is not None
    nibbles = addr.exploded.replace(':', "")
    host_part = nibbles[self.reverse_for.prefixlen // 4:]
    self.n('.'.join(reversed(host_part))).PTR(ptr_to)
# Compute a fingerprint of the zone contents and store its first 16 hex
# digits in self.state.hash. Every record contributes one line made of
# name, TTL, type and text form.
# NOTE(review): original line 321, which creates the `sha` object, is absent
# from this extract — the hash algorithm cannot be determined from here.
320 def gen_hash(self) -> None:
322 for name, ttl, rec in self.zone.iterate_rdatas():
323 text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
# Encoding as US-ASCII raises UnicodeEncodeError for non-ASCII text.
324 sha.update(text.encode('us-ascii'))
325 self.state.hash = sha.hexdigest()[:16]
def gen_serial(self) -> None:
    """Choose the serial number for this run and store it in ``self.state``.

    * Content hash unchanged and a previous serial exists: keep it.
    * Otherwise, with ``base`` = ``YYYYMMDD00`` of the run's start time:
      a previous serial below ``base`` gives ``base + 1``; anything else
      gives the previous serial plus one.

    A warning is printed once the previous serial has reached
    ``base + 99``, i.e. the two-digit daily counter has been used up.

    NOTE(review): the ``else:`` lines and the comparison of the previous
    serial against ``base`` are missing from the extract and were
    reconstructed from the visible assignments — confirm against upstream.
    """
    prev = self.prev_state.serial
    if self.state.hash == self.prev_state.hash and prev > 0:
        self.state.serial = prev
        return

    base = int(self.nsc.start_time.strftime('%Y%m%d00'))
    if prev < base:
        self.state.serial = base + 1
        return

    self.state.serial = prev + 1
    if prev >= base + 99:
        print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')
# Per-zone processing step; the visible code only acts on primary zones.
# NOTE(review): the statements inside the `if` (original lines 342-343) are
# absent from this extract, so what is done for a primary zone cannot be
# determined from here.
340 def process(self) -> None:
341 if self.zone_type == ZoneType.primary:
# Write the zone file of a primary zone (asserted): a '<zone_file>.new' file
# is opened for writing and afterwards moved over the real zone file.
# NOTE(review): original lines 347 and 350 are absent from this extract —
# presumably a step before opening the file and the statement that writes
# the zone into `f`; neither can be determined from here.
345 def write_zone(self) -> None:
346 assert self.zone_type == ZoneType.primary
348 new_file = Path(str(self.zone_file) + '.new')
349 with open(new_file, 'w') as f:
351 new_file.replace(self.zone_file)
def write_state(self) -> None:
    """Save the current state of a primary zone to its state file.

    Secondary zones have no state; calling this on one trips the assertion.
    """
    assert self.zone_type == ZoneType.primary
    self.state.save(self.state_file)
# Attribute declarations of the top-level object.
zones: Dict[str, NscZone]  # All registered zones, keyed by zone name
default_zone_config: NscZoneConfig  # Parent configuration for zones added via add_zone()
# Requested reverse mappings: address -> names the PTR records should point to
ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
def __init__(self, directory: str = '.', **kwargs) -> None:
    """Set up an empty zone collection rooted at *directory*.

    Args:
        directory: Root directory.  The sub-directories ``state``, ``zone``
            and ``secondary`` are created beneath it when missing.
        **kwargs: Forwarded to ``NscZoneConfig`` to build the configuration
            that zones added later inherit from.

    NOTE(review): the initialisation of ``self.zones`` is missing from the
    extract; it was restored because ``add_zone`` and ``get_zones`` use it
    as a dict — confirm against upstream.
    """
    # Captured once per run; used when serial numbers are generated.
    self.start_time = datetime.now()
    self.zones = {}
    self.default_zone_config = NscZoneConfig(**kwargs)
    self.ipv4_reverse = defaultdict(list)
    self.ipv6_reverse = defaultdict(list)

    self.root_dir = Path(directory)
    self.state_dir = self.root_dir / 'state'
    self.state_dir.mkdir(parents=True, exist_ok=True)
    self.zone_dir = self.root_dir / 'zone'
    self.zone_dir.mkdir(parents=True, exist_ok=True)
    self.secondary_dir = self.root_dir / 'secondary'
    self.secondary_dir.mkdir(parents=True, exist_ok=True)
def add_zone(self, *args, inherit_config: Optional[NscZoneConfig] = None, **kwargs) -> NscZone:
    """Create a zone, register it under its name and return it.

    Args:
        *args, **kwargs: Forwarded to ``NscZone`` after ``self``.
        inherit_config: Parent configuration for the new zone; this
            object's ``default_zone_config`` when ``None``.

    Returns:
        The new ``NscZone``.  The annotation previously said ``Zone``,
        which is the dnspython class and not what is built here.

    NOTE(review): the ``return`` line is missing from the extract and was
    reconstructed from the declared return value — confirm against upstream.
    """
    if inherit_config is None:
        inherit_config = self.default_zone_config
    z = NscZone(self, *args, inherit_config=inherit_config, **kwargs)
    # Each zone name may be registered only once.
    assert z.name not in self.zones
    self.zones[z.name] = z
    return z
def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
    """Record that *addr* should get a PTR record pointing at *ptr_to*.

    The request is only queued here, in ``ipv4_reverse`` or
    ``ipv6_reverse`` depending on the address family.  Several names may
    be queued for the same address.
    """
    if isinstance(addr, IPv4Address):
        self.ipv4_reverse[addr].append(ptr_to)
    else:
        self.ipv6_reverse[addr].append(ptr_to)
def dump_reverse(self) -> None:
    """Print all queued reverse-mapping requests to standard output.

    After a header line, one line is printed per address: the address, a
    tab, and the list of names queued for it.  IPv4 addresses come first,
    then IPv6, each group sorted by address.
    """
    print('### Requests for reverse mappings ###')
    for ipa4, names4 in sorted(self.ipv4_reverse.items()):
        print(f'{ipa4}\t{names4}')
    for ipa6, names6 in sorted(self.ipv6_reverse.items()):
        print(f'{ipa6}\t{names6}')
def fill_reverse(self) -> None:
    """Turn the queued reverse-mapping requests into PTR records.

    For every primary zone that is a reverse zone, each queued address of
    the matching family that lies inside the zone's network gets one PTR
    record per queued name.  Addresses not covered by any reverse zone are
    skipped.
    """
    for z in self.zones.values():
        if z.zone_type != ZoneType.primary or z.reverse_for is None:
            continue
        if isinstance(z.reverse_for, IPv4Network):
            for addr4, ptr_list in self.ipv4_reverse.items():
                if addr4 in z.reverse_for:
                    for ptr_to in ptr_list:
                        z._add_ipv4_reverse(addr4, ptr_to)
        else:
            for addr6, ptr_list in self.ipv6_reverse.items():
                if addr6 in z.reverse_for:
                    for ptr_to in ptr_list:
                        z._add_ipv6_reverse(addr6, ptr_to)
def get_zones(self) -> List[NscZone]:
    """Return all registered zones, ordered by zone name."""
    return [self.zones[key] for key in sorted(self.zones)]
422 def process(self) -> None:
424 for z in self.get_zones():