mj.ucw.cz Git - pynsc.git/blob - nsconfig/core.py
Support for secondary zones
[pynsc.git] / nsconfig / core.py
1 from collections import defaultdict
2 from datetime import datetime, timedelta
3 import dns.name
4 from dns.name import Name
5 from dns.node import Node
6 from dns.rdata import Rdata
7 from dns.rdataclass import RdataClass
8 from dns.rdatatype import RdataType
9 import dns.rdtypes.ANY.MX
10 import dns.rdtypes.ANY.NS
11 import dns.rdtypes.ANY.PTR
12 import dns.rdtypes.ANY.SOA
13 import dns.rdtypes.ANY.TXT
14 import dns.rdtypes.IN.A
15 import dns.rdtypes.IN.AAAA
16 from dns.zone import Zone
17 from enum import Enum, auto
18 import hashlib
19 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
20 import json
21 from pathlib import Path
22 import socket
23 import sys
24 from typing import Optional, Dict, List, Self, Tuple, DefaultDict, TextIO
25
26
27 IPAddress = IPv4Address | IPv6Address
28 IPNetwork = IPv4Network | IPv6Network
29 IPAddr = str | IPAddress | List[str | IPAddress]
30
31
class NscNode:
    """Record builder for a single owner name within an NscZone.

    Obtained via NscZone.n(), NscZone[...] or NscZone.host().  Every record
    method adds rdata to the underlying dnspython node and returns self, so
    calls can be chained, e.g. ``zone['www'].ttl(hours=1).A('192.0.2.1')``.
    """

    nsc_zone: 'NscZone'
    name: str
    node: Node
    _ttl: int

    def __init__(self, nsc_zone: 'NscZone', name: str) -> None:
        """Bind to the node *name* of *nsc_zone*'s zone, creating it if needed.

        Records added later use the zone's minimum TTL until ttl() is called.
        """
        self.nsc_zone = nsc_zone
        self.name = name
        self.node = nsc_zone.zone.find_node(name, create=True)
        self._ttl = nsc_zone._min_ttl

    def ttl(self, *args, **kwargs) -> Self:
        """Set the TTL used for records added after this call.

        Arguments are those of datetime.timedelta (e.g. ``ttl(hours=1)``);
        with no arguments the TTL is reset to the zone's minimum TTL.
        """
        if not args and not kwargs:
            self._ttl = self.nsc_zone._min_ttl
        else:
            self._ttl = int(timedelta(*args, **kwargs).total_seconds())
        return self

    def _add(self, rec: Rdata) -> None:
        """Add *rec* to the matching rdataset of this node with the current TTL."""
        rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
        rds.add(rec, ttl=self._ttl)

    def _parse_addrs(self, addrs: Tuple[IPAddr, ...]) -> List[IPAddress]:
        """Flatten *addrs* (strings, ipaddress objects, or lists of them) into address objects."""
        out = []
        for a in addrs:
            if not isinstance(a, list):
                a = [a]
            for b in a:
                if isinstance(b, IPv4Address) or isinstance(b, IPv6Address):
                    out.append(b)
                else:
                    out.append(ip_address(b))
        return out

    def _parse_name(self, name: str) -> Name:
        """Parse a target name given by the user.

        Names containing a dot are parsed with dns.name.from_text's default
        origin (the root), i.e. as absolute names; names without a dot are
        kept relative.
        """
        # FIXME: Names with escaped dots
        if '.' in name:
            return dns.name.from_text(name)
        else:
            return dns.name.from_text(name, origin=None)

    def _parse_names(self, names: str | List[str]) -> List[Name]:
        """Parse a single name or a list of names with _parse_name()."""
        if isinstance(names, str):
            return [self._parse_name(names)]
        else:
            return [self._parse_name(n) for n in names]

    def _fqdn(self) -> Name:
        """Return the absolute name of this node.

        The node name is resolved relative to the zone name, so the apex
        ("" or "@") maps to the zone name itself and an already absolute
        name (trailing dot) is kept as is.  Joining the strings with '.'
        instead would produce an empty leading label for the apex.
        """
        origin = dns.name.from_text(self.nsc_zone.name)
        return dns.name.from_text(self.name, origin=origin)

    def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
        """Add an A record per IPv4 address and an AAAA record per IPv6 address.

        Addresses may be strings, ipaddress objects, or lists of either.
        If *reverse* is true, each address is also registered with the owning
        Nsc as a request for a PTR record pointing back to this node.
        """
        for a in self._parse_addrs(addrs):
            if isinstance(a, IPv4Address):
                self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
            else:
                self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
            if reverse:
                self.nsc_zone.nsc._add_reverse_mapping(a, self._fqdn())
        return self

    def MX(self, pri: int, name: str) -> Self:
        """Add an MX record with preference *pri* pointing to *name*."""
        self._add(
            dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, self._parse_name(name))
        )
        return self

    def NS(self, names: str | List[str]) -> Self:
        """Add an NS record for each of *names* (a single name or a list)."""
        # FIXME: Variadic?
        for name in self._parse_names(names):
            self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
        return self

    def TXT(self, text: str) -> Self:
        """Add a TXT record holding *text*.

        A single TXT character-string holds at most 255 bytes, so longer
        text (e.g. DKIM keys) is split into consecutive 255-byte strings in
        the same record; protocols such as SPF and DKIM define the value as
        their concatenation.  A list/tuple of strings is passed to dnspython
        unchanged, one character-string per element.
        """
        if isinstance(text, (str, bytes)):
            data = text.encode() if isinstance(text, str) else text
            strings = [data[i:i + 255] for i in range(0, len(data), 255)] or [b'']
        else:
            strings = text
        self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, strings))
        return self

    def PTR(self, target: Name | str) -> Self:
        """Add a PTR record pointing to *target*."""
        self._add(dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target))
        return self

    def generic(self, typ: str, text: str) -> Self:
        """Add a record of type *typ* parsed from its presentation-format *text*."""
        self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
        return self
113
114
115 class NscZoneConfig:
116     admin_email: str
117     refresh: timedelta
118     retry: timedelta
119     expire: timedelta
120     min_ttl: timedelta
121     origin_server: str
122
123     default_config: Optional['NscZoneConfig'] = None
124
125     def __init__(self,
126                  admin_email: Optional[str] = None,
127                  refresh: Optional[timedelta] = None,
128                  retry: Optional[timedelta] = None,
129                  expire: Optional[timedelta] = None,
130                  min_ttl: Optional[timedelta] = None,
131                  origin_server: Optional[str] = None,
132                  inherit_config: Optional['NscZoneConfig'] = None,
133                  ) -> None:
134         if inherit_config is None:
135             inherit_config = NscZoneConfig.default_config or self   # to satisfy the type checker
136         self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
137         self.refresh = refresh if refresh is not None else inherit_config.refresh
138         self.retry = retry if retry is not None else inherit_config.retry
139         self.expire = expire if expire is not None else inherit_config.expire
140         self.min_ttl = min_ttl if min_ttl is not None else inherit_config.min_ttl
141         self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
142
143     def finalize(self) -> Self:
144         if not self.origin_server:
145             self.origin_server = socket.getfqdn()
146         if not self.admin_email:
147             self.admin_email = f'hostmaster@{self.origin_server}'
148         return self
149
150
# Built-in defaults inherited by every NscZoneConfig that names no explicit
# parent.  The empty admin_email / origin_server are filled in by
# NscZoneConfig.finalize().
NscZoneConfig.default_config = NscZoneConfig(
    origin_server="",
    admin_email="",
    refresh=timedelta(hours=8),
    retry=timedelta(hours=2),
    expire=timedelta(weeks=2),
    min_ttl=timedelta(hours=24),
)
159
160
class NscZoneState:
    """Persistent per-zone state: last serial number and content hash.

    Stored as a JSON object with keys "serial" and "hash".
    """

    serial: int
    hash: str

    def __init__(self) -> None:
        self.serial = 0
        self.hash = 'none'

    def load(self, file: Path) -> None:
        """Load state from *file*; a missing file leaves the defaults untouched.

        Keys absent from the file leave the corresponding attribute unchanged.

        Raises:
            ValueError: if the file is not valid JSON, is not a JSON object,
                or holds a serial/hash of the wrong type.
        """
        try:
            with open(file) as f:
                js = json.load(f)
        except FileNotFoundError:
            return
        # Explicit checks instead of assert: asserts vanish under python -O.
        if not isinstance(js, dict):
            raise ValueError(f'{file}: zone state must be a JSON object')
        if 'serial' in js:
            serial = js['serial']
            if not isinstance(serial, int) or isinstance(serial, bool):
                raise ValueError(f'{file}: serial must be an integer, got {serial!r}')
            self.serial = serial
        if 'hash' in js:
            hash_value = js['hash']
            if not isinstance(hash_value, str):
                raise ValueError(f'{file}: hash must be a string, got {hash_value!r}')
            self.hash = hash_value

    def save(self, file: Path) -> None:
        """Write state to *file*.

        The data goes to '<file>.new' first, which is then renamed over
        *file*, so an interrupted write does not leave a truncated state file.
        """
        new_file = Path(str(file) + '.new')
        with open(new_file, 'w') as f:
            js = {
                'serial': self.serial,
                'hash': self.hash,
            }
            json.dump(js, f, indent=4, sort_keys=True)
        new_file.replace(file)
190
191
class ZoneType(Enum):
    """Role of a zone on this server: authoritative primary or secondary copy."""

    primary = 1
    secondary = 2
195
196
class NscZone:
    """One DNS zone managed by Nsc, either primary or secondary.

    A primary zone holds an in-memory dnspython Zone with a generated SOA
    record, plus state (serial and content hash) persisted between runs.
    A secondary zone only records its primary server's address and the
    path of its zone file.
    """

    nsc: 'Nsc'
    name: str
    safe_name: str                          # For use in file names
    config: NscZoneConfig
    zone: Zone
    _min_ttl: int
    reverse_for: Optional[IPNetwork]
    zone_type: ZoneType
    primary_server: Optional[IPAddress]     # For secondary zones
    zone_file: Path
    state_file: Path
    state: NscZoneState
    prev_state: NscZoneState

    def __init__(self,
                 nsc: 'Nsc',
                 name: Optional[str] = None,
                 reverse_for: str | IPNetwork | None = None,
                 secondary_for: str | IPAddress | None = None,
                 **kwargs) -> None:
        """Create a zone.

        Args:
            nsc: owning Nsc (supplies directories and start time).
            name: zone name; may be omitted when *reverse_for* is given, in
                which case it is derived from the network.
            reverse_for: network (or its text form) this zone holds PTR
                records for.
            secondary_for: address (or its text form) of the primary server;
                when given, the zone is a secondary and holds no records.
            **kwargs: passed to NscZoneConfig (SOA settings, inherit_config).

        Raises:
            ValueError: if neither *name* nor *reverse_for* is given, or an
                address/network string cannot be parsed.
        """
        if reverse_for is not None:
            if isinstance(reverse_for, str):
                reverse_for = ip_network(reverse_for, strict=True)
            name = name or self._reverse_zone_name(reverse_for)
        if name is None:
            raise ValueError('A zone needs either a name or reverse_for')

        if isinstance(secondary_for, str):
            secondary_for = ip_address(secondary_for)

        self.nsc = nsc
        self.name = name
        self.safe_name = name.replace('/', '@')
        self.config = NscZoneConfig(**kwargs).finalize()
        self.reverse_for = reverse_for
        self.primary_server = secondary_for

        if secondary_for is None:
            self.zone_type = ZoneType.primary
            self.zone_file = nsc.zone_dir / self.safe_name
            self.state_file = nsc.state_dir / (self.safe_name + '.json')

            self.state = NscZoneState()
            self.prev_state = NscZoneState()
            self.prev_state.load(self.state_file)

            self.zone = dns.zone.Zone(origin=name, rdclass=RdataClass.IN)
            self._min_ttl = int(self.config.min_ttl.total_seconds())
            self.update_soa()
        else:
            self.zone_type = ZoneType.secondary
            self.zone_file = nsc.secondary_dir / self.safe_name

    def _reverse_zone_name(self, net: IPNetwork) -> str:
        """Return the in-addr.arpa / ip6.arpa zone name for *net*.

        IPv4 prefixes that are not a multiple of 8 get a classless label of
        the form 'first-address-octet/prefixlen' (e.g. '0/26'), as described
        in RFC 2317.

        Raises:
            ValueError: for an IPv6 prefix that is not a multiple of 4.
        """
        if isinstance(net, IPv4Network):
            parts = str(net.network_address).split('.')
            labels = parts[:net.prefixlen // 8]
            if net.prefixlen % 8 != 0:
                labels.append(parts[len(labels)] + '/' + str(net.prefixlen))
            # Appending the suffix labels keeps a /0 network from producing a leading dot.
            return '.'.join(list(reversed(labels)) + ['in-addr', 'arpa'])
        elif isinstance(net, IPv6Network):
            if net.prefixlen % 4 != 0:
                raise ValueError(f'IPv6 reverse zone {net} is not on a nibble boundary')
            nibbles = net.network_address.exploded.replace(':', "")
            nibbles = nibbles[:net.prefixlen // 4]
            return '.'.join(list(reversed(nibbles)) + ['ip6', 'arpa'])
        else:
            raise NotImplementedError()

    @staticmethod
    def _email_to_rname(email: str) -> str:
        r"""Convert an e-mail address to the text form of an SOA RNAME.

        The last '@' becomes a label separator and dots in the local part
        are escaped so they stay within the first label, e.g.
        'john.doe@example.com' -> 'john\.doe.example.com'.  Text without
        an '@' is returned unchanged.
        """
        local, at, domain = email.rpartition('@')
        if not at:
            return email
        return local.replace('.', '\\.') + '.' + domain

    def update_soa(self) -> None:
        """(Re)create the apex SOA record from the config and current serial."""
        assert self.zone_type == ZoneType.primary
        conf = self.config
        soa = dns.rdtypes.ANY.SOA.SOA(
            RdataClass.IN, RdataType.SOA,
            mname=conf.origin_server,
            rname=self._email_to_rname(conf.admin_email),
            serial=self.state.serial,
            refresh=int(conf.refresh.total_seconds()),
            retry=int(conf.retry.total_seconds()),
            expire=int(conf.expire.total_seconds()),
            minimum=int(conf.min_ttl.total_seconds()),
        )
        self.zone.delete_rdataset("", RdataType.SOA)
        self[""]._add(soa)

    def n(self, name: str) -> NscNode:
        """Return a record builder for the node *name* (same as zone[name])."""
        return NscNode(self, name)

    def __getitem__(self, name: str) -> NscNode:
        """Return a record builder for the node *name*."""
        return NscNode(self, name)

    def host(self, name: str, *args, reverse: bool = True) -> NscNode:
        """Create node *name* with A/AAAA records for the addresses in *args*."""
        n = NscNode(self, name)
        n.A(*args, reverse=reverse)
        return n

    def dump(self, file: Optional[TextIO] = None) -> None:
        """Write the zone in master-file format to *file* (default: stdout).

        Owner names are printed only when they differ from the previous
        record, and TTLs equal to the zone's minimum TTL are left out.  The
        $TTL directive (RFC 2308) makes those omitted TTLs mean the minimum
        TTL; without it, RFC 1035 says an omitted TTL repeats the last one
        stated, which may be a different explicit TTL.
        """
        # Could use self.zone.to_file(sys.stdout), but we want better formatting
        file = file or sys.stdout
        file.write(f'; Zone file for {self.name}\n\n')
        file.write(f'$TTL {self._min_ttl}\n\n')
        last_name = None
        for name, ttl, rec in self.zone.iterate_rdatas():
            if name == last_name:
                print_name = ""
            else:
                print_name = name
            file.write(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
            last_name = name

    def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
        """Add a PTR record for *addr* pointing to *ptr_to*."""
        # Called only for addresses from this reverse network
        assert self.reverse_for is not None
        parts = str(addr).split('.')
        parts = parts[self.reverse_for.prefixlen // 8:]
        name = '.'.join(reversed(parts))
        self.n(name).PTR(ptr_to)

    def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
        """Add a PTR record for *addr* pointing to *ptr_to*."""
        # Called only for addresses from this reverse network
        assert self.reverse_for is not None
        parts = addr.exploded.replace(':', "")
        parts = parts[self.reverse_for.prefixlen // 4:]
        name = '.'.join(reversed(parts))
        self.n(name).PTR(ptr_to)

    def gen_hash(self) -> None:
        """Store in state.hash the first 16 hex digits of a SHA-1 over all records."""
        sha = hashlib.sha1()
        for name, ttl, rec in self.zone.iterate_rdatas():
            text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
            sha.update(text.encode('us-ascii'))
        self.state.hash = sha.hexdigest()[:16]

    def gen_serial(self) -> None:
        """Choose state.serial.

        The previous serial is kept when the content hash is unchanged.
        Otherwise the serial becomes YYYYMMDD00 + 1 for the run's start date,
        or previous + 1 if that is larger.  A warning is printed when the
        two-digit daily counter overflows.
        """
        prev = self.prev_state.serial
        if self.state.hash == self.prev_state.hash and prev > 0:
            self.state.serial = self.prev_state.serial
        else:
            base = int(self.nsc.start_time.strftime('%Y%m%d00'))
            if prev <= base:
                self.state.serial = base + 1
            else:
                self.state.serial = prev + 1
                if prev >= base + 99:
                    print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')

    def process(self) -> None:
        """Compute hash and serial for primary zones; secondaries need nothing."""
        if self.zone_type == ZoneType.primary:
            self.gen_hash()
            self.gen_serial()

    def write_zone(self) -> None:
        """Refresh the SOA and write the zone file via a '.new' file renamed into place."""
        assert self.zone_type == ZoneType.primary
        self.update_soa()
        new_file = Path(str(self.zone_file) + '.new')
        with open(new_file, 'w') as f:
            self.dump(file=f)
        new_file.replace(self.zone_file)

    def write_state(self) -> None:
        """Persist the zone's state (serial and hash) to its state file."""
        assert self.zone_type == ZoneType.primary
        self.state.save(self.state_file)
356
357
class Nsc:
    """Top-level configuration: a set of zones plus their on-disk layout.

    Zones are created with add_zone().  Addresses added with reverse=True
    are collected in ipv4_reverse / ipv6_reverse and turned into PTR records
    in matching reverse zones by fill_reverse().
    """

    start_time: datetime
    zones: Dict[str, NscZone]
    default_zone_config: NscZoneConfig
    ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
    ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
    root_dir: Path
    state_dir: Path
    zone_dir: Path
    secondary_dir: Path

    def __init__(self, directory: str = '.', **kwargs) -> None:
        """Set up the configuration rooted at *directory*.

        The subdirectories state/, zone/ and secondary/ are created if
        missing.  *kwargs* are default zone settings passed to
        NscZoneConfig; anything not given there is inherited from
        NscZoneConfig.default_config.
        """
        self.start_time = datetime.now()
        self.zones = {}
        self.default_zone_config = NscZoneConfig(**kwargs)
        self.ipv4_reverse = defaultdict(list)
        self.ipv6_reverse = defaultdict(list)

        self.root_dir = Path(directory)
        self.state_dir = self.root_dir / 'state'
        self.state_dir.mkdir(parents=True, exist_ok=True)
        self.zone_dir = self.root_dir / 'zone'
        self.zone_dir.mkdir(parents=True, exist_ok=True)
        self.secondary_dir = self.root_dir / 'secondary'
        self.secondary_dir.mkdir(parents=True, exist_ok=True)

    def add_zone(self, *args, inherit_config: Optional[NscZoneConfig] = None, **kwargs) -> NscZone:
        """Create, register and return a new zone.

        Remaining arguments are passed to NscZone.  The zone's config
        inherits from *inherit_config*, or from this Nsc's default zone
        config when not given.

        Raises:
            ValueError: if a zone with the same name was already added.
        """
        if inherit_config is None:
            inherit_config = self.default_zone_config
        z = NscZone(self, *args, inherit_config=inherit_config, **kwargs)
        if z.name in self.zones:
            raise ValueError(f'Zone {z.name} defined more than once')
        self.zones[z.name] = z
        return z

    def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
        """Record a request for a PTR record from *addr* to *ptr_to*."""
        if isinstance(addr, IPv4Address):
            self.ipv4_reverse[addr].append(ptr_to)
        else:
            self.ipv6_reverse[addr].append(ptr_to)

    def dump_reverse(self) -> None:
        """Print all requested reverse mappings, sorted by address."""
        print('### Requests for reverse mappings ###')
        for ipa4, names4 in sorted(self.ipv4_reverse.items()):
            targets4 = ', '.join(str(n) for n in names4)
            print(f'{ipa4}\t{targets4}')
        for ipa6, names6 in sorted(self.ipv6_reverse.items()):
            targets6 = ', '.join(str(n) for n in names6)
            print(f'{ipa6}\t{targets6}')

    def fill_reverse(self) -> None:
        """Add PTR records to every primary reverse zone.

        Each requested address inside a zone's network gets PTR records in
        that zone; an address covered by several reverse zones gets them in
        each of them.
        """
        for z in self.zones.values():
            if z.zone_type == ZoneType.primary and z.reverse_for is not None:
                if isinstance(z.reverse_for, IPv4Network):
                    for addr4, ptr_list in self.ipv4_reverse.items():
                        if addr4 in z.reverse_for:
                            for ptr_to in ptr_list:
                                z._add_ipv4_reverse(addr4, ptr_to)
                else:
                    for addr6, ptr_list in self.ipv6_reverse.items():
                        if addr6 in z.reverse_for:
                            for ptr_to in ptr_list:
                                z._add_ipv6_reverse(addr6, ptr_to)

    def get_zones(self) -> List[NscZone]:
        """Return all zones sorted by name."""
        return [self.zones[k] for k in sorted(self.zones.keys())]

    def process(self) -> None:
        """Fill reverse zones, then compute hashes and serials of all zones."""
        self.fill_reverse()
        for z in self.get_zones():
            z.process()