1 from collections import defaultdict
2 from datetime import datetime, timedelta
4 from dns.name import Name
5 from dns.node import Node
6 from dns.rdata import Rdata
7 from dns.rdataclass import RdataClass
8 from dns.rdatatype import RdataType
9 import dns.rdtypes.ANY.MX
10 import dns.rdtypes.ANY.NS
11 import dns.rdtypes.ANY.PTR
12 import dns.rdtypes.ANY.SOA
13 import dns.rdtypes.ANY.TXT
14 import dns.rdtypes.IN.A
15 import dns.rdtypes.IN.AAAA
16 from dns.zone import Zone
18 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
20 from pathlib import Path
23 from typing import Optional, Dict, List, Self, Tuple, DefaultDict, TextIO
# A parsed IP address of either family.
26 IPAddress = IPv4Address | IPv6Address
# A parsed IP network of either family.
27 IPNetwork = IPv4Network | IPv6Network
# Address argument accepted by the record helpers: one address (text or parsed) or a list of them.
28 IPAddr = str | IPAddress | List[str | IPAddress]
def __init__(self, nsc_zone: 'NscZone', name: str) -> None:
    """Attach to the node called `name` in `nsc_zone`, creating it if needed.

    The TTL used for added records starts out as the zone's minimum TTL
    (`nsc_zone._min_ttl`).
    """
    self.nsc_zone = nsc_zone
    # A() reads self.name when it builds the target of a reverse mapping.
    self.name = name
    self.node = nsc_zone.zone.find_node(name, create=True)
    self._ttl = nsc_zone._min_ttl
def ttl(self, *args, **kwargs) -> Self:
    """Set the TTL applied to records added to this node from now on.

    The arguments are handed to `datetime.timedelta`, e.g. `ttl(hours=1)`.
    Called without any argument, the TTL is reset to the zone's minimum TTL.

    Returns `self` so that calls can be chained.
    """
    if args or kwargs:
        self._ttl = int(timedelta(*args, **kwargs).total_seconds())
    else:
        self._ttl = self.nsc_zone._min_ttl
    return self
def _add(self, rec: Rdata) -> None:
    """Store `rec` in this node's rdataset for its class and type.

    The rdataset is created when it does not exist yet; the record is added
    with the node's current TTL.
    """
    rdataset = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
    rdataset.add(rec, ttl=self._ttl)
def _parse_addrs(self, addrs: Tuple[IPAddr, ...]) -> List[IPAddress]:
    """Flatten `addrs` into a list of parsed IP addresses.

    Each item is either a single address or a list of addresses. Items that
    are already `IPv4Address`/`IPv6Address` objects are kept as they are;
    anything else is parsed with `ipaddress.ip_address`, which raises
    `ValueError` for text that is not a valid address.
    """
    out: List[IPAddress] = []
    for item in addrs:
        group = item if isinstance(item, list) else [item]
        for addr in group:
            if isinstance(addr, (IPv4Address, IPv6Address)):
                out.append(addr)
            else:
                out.append(ip_address(addr))
    return out
def _parse_name(self, name: str) -> Name:
    """Convert a textual name into a `dns.name.Name`.

    A name containing a dot is parsed with dnspython's default origin; a
    name without any dot is parsed with `origin=None`.
    """
    # FIXME: Names with escaped dots
    # NOTE(review): the dot test is reconstructed from the FIXME above — confirm.
    if '.' in name:
        return dns.name.from_text(name)
    return dns.name.from_text(name, origin=None)
def _parse_names(self, names: str | List[str]) -> List[Name]:
    """Parse one name or a list of names with `_parse_name`; always returns a list."""
    texts = [names] if isinstance(names, str) else names
    return [self._parse_name(text) for text in texts]
def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
    """Add an A record for every IPv4 address and an AAAA record for every IPv6 one.

    addrs: addresses in any form accepted by `_parse_addrs`.
    reverse: when true, each address is also registered with the owning Nsc
        (`_add_reverse_mapping`) so that reverse zones can point back here.

    Returns `self` so that calls can be chained.
    """
    zone_name = self.nsc_zone.name
    # An empty node name (or dnspython's '@') is the zone apex; joining it
    # with a dot would give a name that starts with an empty label.
    if self.name in ('', '@'):
        fqdn = zone_name
    else:
        fqdn = self.name + '.' + zone_name

    for a in self._parse_addrs(addrs):
        if isinstance(a, IPv4Address):
            self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
        else:
            self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
        if reverse:
            self.nsc_zone.nsc._add_reverse_mapping(a, dns.name.from_text(fqdn))
    return self
def MX(self, pri: int, name: str) -> Self:
    """Add an MX record with preference `pri` and mail exchanger `name`.

    `name` is converted with `_parse_name`. Returns `self` for chaining.
    """
    self._add(
        dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, self._parse_name(name))
    )
    return self
def NS(self, names: str | List[str]) -> Self:
    """Add one NS record per name in `names` (a single name or a list of names).

    Names are converted with `_parse_names`. Returns `self` for chaining.
    """
    for name in self._parse_names(names):
        self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
    return self
def TXT(self, text: str) -> Self:
    """Add a TXT record carrying `text`. Returns `self` for chaining."""
    self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, text))
    return self
def PTR(self, target: Name | str) -> Self:
    """Add a PTR record pointing to `target`. Returns `self` for chaining."""
    self._add(dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target))
    return self
def generic(self, typ: str, text: str) -> Self:
    """Add a class-IN record of type `typ`, parsed from its textual form `text`.

    Parsing is done by `dns.rdata.from_text`. Returns `self` for chaining.
    """
    self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
    return self
# Class-wide fallback configuration: __init__ inherits from it when no inherit_config is passed.
122 default_config: Optional['NscZoneConfig'] = None
def __init__(self,
             admin_email: Optional[str] = None,
             refresh: Optional[timedelta] = None,
             retry: Optional[timedelta] = None,
             expire: Optional[timedelta] = None,
             min_ttl: Optional[timedelta] = None,
             origin_server: Optional[str] = None,
             inherit_config: Optional['NscZoneConfig'] = None,
             ) -> None:
    """Build a zone configuration, inheriting every setting that is not given.

    Each parameter left as None is copied from `inherit_config`. When
    `inherit_config` is None as well, `NscZoneConfig.default_config` is used,
    and if that is not set either, the new object inherits from itself.
    """
    if inherit_config is None:
        inherit_config = NscZoneConfig.default_config or self  # to satisfy the type checker

    def inherited(value, attr):
        # getattr with a default: when the object inherits from itself the
        # attribute may not have been assigned yet.
        return value if value is not None else getattr(inherit_config, attr, None)

    self.admin_email = inherited(admin_email, 'admin_email')
    self.refresh = inherited(refresh, 'refresh')
    self.retry = inherited(retry, 'retry')
    self.expire = inherited(expire, 'expire')
    self.min_ttl = inherited(min_ttl, 'min_ttl')
    self.origin_server = inherited(origin_server, 'origin_server')
def finalize(self) -> Self:
    """Fill in the settings that are still unset and return `self`.

    An empty `origin_server` becomes this machine's fully qualified host
    name (`socket.getfqdn()`); an empty `admin_email` becomes
    `hostmaster@<origin_server>`.
    """
    if not self.origin_server:
        self.origin_server = socket.getfqdn()
    if not self.admin_email:
        self.admin_email = f'hostmaster@{self.origin_server}'
    return self
# Built-in defaults every zone configuration falls back to. admin_email and
# origin_server are left unset here; finalize() derives them when needed.
NscZoneConfig.default_config = NscZoneConfig(
    refresh=timedelta(hours=8),
    retry=timedelta(hours=2),
    expire=timedelta(days=14),
    min_ttl=timedelta(days=1),
)
# NOTE(review): only the signature is visible in this listing. load() and save() use
# self.serial and self.hash, so presumably both get their initial values here — confirm.
164 def __init__(self) -> None:
def load(self, file: Path) -> None:
    """Load `serial` and `hash` from the JSON state file `file`.

    A missing file is not an error: the current values are kept. Keys absent
    from the file leave the corresponding attribute untouched.
    """
    try:
        with open(file) as f:
            js = json.load(f)
    except FileNotFoundError:
        # No state has been saved for this zone yet.
        return
    assert isinstance(js, dict)
    if 'serial' in js:
        self.serial = js['serial']
    if 'hash' in js:
        self.hash = js['hash']
def save(self, file: Path) -> None:
    """Write `serial` and `hash` to `file` as JSON.

    The data is first written to `<file>.new`, which then replaces `file`,
    so a failed write does not leave a truncated state file behind.
    """
    new_file = Path(str(file) + '.new')
    with open(new_file, 'w') as f:
        js = {
            'serial': self.serial,
            'hash': self.hash,
        }
        json.dump(js, f, indent=4, sort_keys=True)
    new_file.replace(file)
# Zone name with '/' replaced by '@' (see __init__).
194 safe_name: str # For use in file names
# Network this zone provides reverse mappings for; None for a zone created without reverse_for.
197 reverse_for: Optional[IPNetwork]
# State loaded from the zone's state file in __init__; gen_serial compares against it.
201 prev_state: NscZoneState
def __init__(self, nsc: 'Nsc', name: str, reverse_for: Optional[IPNetwork] = None, **kwargs) -> None:
    """Create the zone `name` owned by `nsc`.

    reverse_for: the network this zone holds reverse mappings for, if any.
    kwargs: passed to `NscZoneConfig`; the resulting configuration is finalized.

    The state saved by a previous run, if any, is loaded into `prev_state`.
    """
    # Both are read by other methods (gen_serial uses self.nsc, dump uses self.name).
    self.nsc = nsc
    self.name = name
    # Reverse zone names may contain '/', which cannot appear in a file name.
    self.safe_name = name.replace('/', '@')
    self.config = NscZoneConfig(**kwargs).finalize()
    self.zone = dns.zone.Zone(origin=name, rdclass=RdataClass.IN)
    self._min_ttl = int(self.config.min_ttl.total_seconds())
    self.reverse_for = reverse_for

    self.zone_file = nsc.zone_dir / self.safe_name
    self.state_file = nsc.state_dir / (self.safe_name + '.json')
    self.state = NscZoneState()
    self.prev_state = NscZoneState()
    self.prev_state.load(self.state_file)
def update_soa(self) -> None:
    """Replace the zone's SOA record with one built from the config and the current serial.

    Any existing SOA at the zone apex is deleted first, so the method can be
    called repeatedly.
    """
    conf = self.config

    # In the SOA mailbox the '@' becomes a dot, so dots inside the local
    # part have to be escaped to remain part of the first label.
    local, at, domain = conf.admin_email.rpartition('@')
    if at:
        rname = local.replace('.', r'\.') + '.' + domain
    else:
        rname = conf.admin_email

    soa = dns.rdtypes.ANY.SOA.SOA(
        RdataClass.IN, RdataType.SOA,
        mname=conf.origin_server,
        rname=rname,
        serial=self.state.serial,
        refresh=int(conf.refresh.total_seconds()),
        retry=int(conf.retry.total_seconds()),
        expire=int(conf.expire.total_seconds()),
        minimum=int(conf.min_ttl.total_seconds()),
    )
    self.zone.delete_rdataset("", RdataType.SOA)
    # NOTE(review): the statement that stores the SOA is not visible in this
    # listing; adding it through the apex node is the presumed intent — confirm.
    self.n("")._add(soa)
def n(self, name: str) -> NscNode:
    """Return an `NscNode` for `name` in this zone."""
    node = NscNode(self, name)
    return node
def __getitem__(self, name: str) -> NscNode:
    """Subscript form of `n()`: `zone[name]` gives an `NscNode` for `name`."""
    node = NscNode(self, name)
    return node
def host(self, name: str, *args, reverse: bool = True) -> NscNode:
    """Create the node `name` and give it address records for `args`.

    The addresses and the `reverse` flag are passed to `NscNode.A`. The node
    is returned so that further records can be added to it.
    """
    n = NscNode(self, name)
    n.A(*args, reverse=reverse)
    return n
def dump(self, file: Optional[TextIO] = None) -> None:
    """Write the zone to `file` (standard output by default) in zone-file form.

    One tab-separated line is written per record: owner name, TTL, type and
    data. The owner name is left empty when it repeats the previous line's,
    and the TTL is left empty when it equals the zone's minimum TTL.
    """
    # Could use self.zone.to_file(sys.stdout), but we want better formatting
    if file is None:
        file = sys.stdout
    file.write(f'; Zone file for {self.name}\n\n')
    last_name = None
    for name, ttl, rec in self.zone.iterate_rdatas():
        print_name = "" if name == last_name else name
        print_ttl = "" if ttl == self._min_ttl else ttl
        file.write(f'{print_name}\t{print_ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n')
        last_name = name
def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
    """Add a PTR record to `ptr_to` for `addr` in this IPv4 reverse zone.

    The record name consists of the address octets not covered by whole
    bytes of the zone's prefix, in reverse order.
    """
    # Called only for addresses from this reverse network
    assert self.reverse_for is not None
    covered = self.reverse_for.prefixlen // 8
    octets = str(addr).split('.')[covered:]
    octets.reverse()
    self.n('.'.join(octets)).PTR(ptr_to)
def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
    """Add a PTR record to `ptr_to` for `addr` in this IPv6 reverse zone.

    The record name consists of the hex digits of the address not covered
    by the zone's prefix, one digit per label, in reverse order.
    """
    # Called only for addresses from this reverse network
    assert self.reverse_for is not None
    covered = self.reverse_for.prefixlen // 4
    nibbles = addr.exploded.replace(':', "")[covered:]
    self.n('.'.join(nibbles[::-1])).PTR(ptr_to)
def gen_hash(self) -> None:
    """Store a digest of the zone's records in `self.state.hash`.

    Every record contributes its name, TTL, type and textual data; the first
    16 hex digits of the digest are kept. gen_serial compares the value with
    the one saved by the previous run to detect changes.
    """
    # NOTE(review): the line creating the hash object is not visible in this
    # listing; SHA-1 is assumed from the variable name — confirm.
    sha = hashlib.sha1()
    for name, ttl, rec in self.zone.iterate_rdatas():
        text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
        sha.update(text.encode('us-ascii'))
    self.state.hash = sha.hexdigest()[:16]
def gen_serial(self) -> None:
    """Choose the zone's serial number and store it in `self.state.serial`.

    If the zone hash equals the previous run's and a previous serial exists,
    that serial is kept. Otherwise the new serial is one above the larger of
    the previous serial and today's base `YYYYMMDD00` (taken from the Nsc
    start time). A warning is printed when the previous serial is already at
    or above `base + 99`.
    """
    prev = self.prev_state.serial
    if self.state.hash == self.prev_state.hash and prev > 0:
        self.state.serial = prev
        return

    base = int(self.nsc.start_time.strftime('%Y%m%d00'))
    self.state.serial = max(prev, base) + 1
    if prev >= base + 99:
        print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')
# NOTE(review): only the signature is visible in this listing. gen_hash() and gen_serial()
# have no other visible caller, so presumably they are invoked from here — confirm.
295 def process(self) -> None:
def write_zone(self) -> None:
    """Write the zone to `self.zone_file`.

    The text is first written to `<zone_file>.new`, which then replaces the
    zone file, so a failed write does not leave a truncated file behind.
    """
    # Rebuild the SOA so that the file carries the current serial;
    # update_soa() replaces any existing SOA, so repeating it is harmless.
    self.update_soa()
    new_file = Path(str(self.zone_file) + '.new')
    with open(new_file, 'w') as f:
        self.dump(f)
    new_file.replace(self.zone_file)
def write_state(self) -> None:
    """Save the zone's current state to its state file."""
    state = self.state
    state.save(self.state_file)
# All zones added so far, keyed by zone name (see add_zone).
312 zones: Dict[str, NscZone]
# Configuration that zones inherit from unless add_zone is given another one.
313 default_zone_config: NscZoneConfig
# Requested reverse mappings: address -> names that should get a PTR record (see _add_reverse_mapping).
314 ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
315 ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
def __init__(self, directory: str = '.', **kwargs) -> None:
    """Set up an empty configuration rooted at `directory`.

    kwargs: passed to `NscZoneConfig` to form the default zone configuration.

    The `state` and `zone` sub-directories of `directory` are created if
    they do not exist.
    """
    self.start_time = datetime.now()
    # add_zone() registers zones here.
    self.zones = {}
    self.default_zone_config = NscZoneConfig(**kwargs)
    self.ipv4_reverse = defaultdict(list)
    self.ipv6_reverse = defaultdict(list)

    self.root_dir = Path(directory)
    self.state_dir = self.root_dir / 'state'
    self.state_dir.mkdir(parents=True, exist_ok=True)
    self.zone_dir = self.root_dir / 'zone'
    self.zone_dir.mkdir(parents=True, exist_ok=True)
def add_zone(self, *args, inherit_config: Optional[NscZoneConfig] = None, **kwargs) -> NscZone:
    """Create a zone, register it under its name and return it.

    args, kwargs: passed to the `NscZone` constructor.
    inherit_config: configuration the zone inherits from; defaults to
        `self.default_zone_config`.

    A zone name may be added only once.
    """
    if inherit_config is None:
        inherit_config = self.default_zone_config
    z = NscZone(self, *args, inherit_config=inherit_config, **kwargs)
    assert z.name not in self.zones
    self.zones[z.name] = z
    return z
def add_reverse_zone(self, net: str | IPNetwork, name: Optional[str] = None, **kwargs) -> NscZone:
    """Add the reverse zone for the network `net` and return it.

    net: a network object or its textual form; text is parsed strictly, so
        `ValueError` is raised when host bits are set.
    name: zone name; derived from the network by `_reverse_zone_name` when
        not given.
    kwargs: passed to `add_zone`.
    """
    if not isinstance(net, (IPv4Network, IPv6Network)):
        net = ip_network(net, strict=True)
    name = name or self._reverse_zone_name(net)
    return self.add_zone(name, reverse_for=net, **kwargs)
def _reverse_zone_name(self, net: IPNetwork) -> str:
    """Return the name of the reverse zone for `net`.

    IPv4: the octets covered by whole bytes of the prefix, reversed, under
    `in-addr.arpa`; a prefix that does not end on an octet boundary adds a
    leading `<octet>/<prefixlen>` label.
    IPv6: the hex digits covered by the prefix, reversed, one per label,
    under `ip6.arpa`; the prefix length must be a multiple of 4.
    """
    if isinstance(net, IPv4Network):
        octets = str(net.network_address).split('.')
        whole = net.prefixlen // 8
        labels = octets[:whole]
        if net.prefixlen % 8 != 0:
            labels.append(octets[whole] + '/' + str(net.prefixlen))
        labels.reverse()
        return '.'.join(labels) + '.in-addr.arpa'

    if isinstance(net, IPv6Network):
        assert net.prefixlen % 4 == 0
        digits = net.network_address.exploded.replace(':', "")
        covered = digits[:net.prefixlen // 4]
        return '.'.join(covered[::-1]) + '.ip6.arpa'

    raise NotImplementedError()
def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
    """Record that `addr` should get a PTR record pointing to `ptr_to`.

    The request is stored in the table for the address family;
    fill_reverse() later turns the requests into records.
    """
    if isinstance(addr, IPv4Address):
        self.ipv4_reverse[addr].append(ptr_to)
    else:
        self.ipv6_reverse[addr].append(ptr_to)
def dump_reverse(self) -> None:
    """Print all requested reverse mappings, IPv4 first, sorted by address."""
    print('### Requests for reverse mappings ###')
    for requests in (self.ipv4_reverse, self.ipv6_reverse):
        for addr, names in sorted(requests.items()):
            print(f'{addr}\t{names}')
def fill_reverse(self) -> None:
    """Create PTR records in every reverse zone for the requested mappings.

    For each zone with a `reverse_for` network, every requested address of
    the matching family that lies inside the network is handed to the
    zone's `_add_ipv4_reverse` / `_add_ipv6_reverse`, once per target name.
    """
    for z in self.zones.values():
        net = z.reverse_for
        if net is None:
            continue
        if isinstance(net, IPv4Network):
            requests, add_ptr = self.ipv4_reverse, z._add_ipv4_reverse
        else:
            requests, add_ptr = self.ipv6_reverse, z._add_ipv6_reverse
        for addr, ptr_list in requests.items():
            if addr in net:
                for ptr_to in ptr_list:
                    add_ptr(addr, ptr_to)
def get_zones(self) -> List[NscZone]:
    """Return all zones ordered by zone name."""
    ordered_names = sorted(self.zones)
    return [self.zones[zone_name] for zone_name in ordered_names]
392 def process(self) -> None:
394 for z in self.get_zones():