mj.ucw.cz Git - pynsc.git/blob - nsconfig/core.py
Refactor primary/secondary to use a class hierarchy
[pynsc.git] / nsconfig / core.py
1 from collections import defaultdict
2 from datetime import datetime, timedelta
3 import dns.name
4 from dns.name import Name
5 from dns.node import Node
6 from dns.rdata import Rdata
7 from dns.rdataclass import RdataClass
8 from dns.rdatatype import RdataType
9 import dns.rdtypes.ANY.MX
10 import dns.rdtypes.ANY.NS
11 import dns.rdtypes.ANY.PTR
12 import dns.rdtypes.ANY.SOA
13 import dns.rdtypes.ANY.TXT
14 import dns.rdtypes.IN.A
15 import dns.rdtypes.IN.AAAA
16 from dns.zone import Zone
17 from enum import Enum, auto
18 import hashlib
19 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
20 import json
21 from pathlib import Path
22 import socket
23 import sys
24 from typing import Optional, Dict, List, Self, Tuple, DefaultDict, TextIO
25
26
# Type aliases used throughout the module.
IPAddress = IPv4Address | IPv6Address           # a single IP address object of either family
IPNetwork = IPv4Network | IPv6Network           # an IP network object of either family
# Address argument accepted by NscNode.A() / NscZonePrimary.host(): a string
# (parsed with ip_address()), an address object, or a list of those.
IPAddr = str | IPAddress | List[str | IPAddress]
30
31
class NscNode:
    """Fluent builder for the records of one owner name in a primary zone.

    Every record-adding method returns ``self``, so calls can be chained,
    e.g. ``zone['www'].ttl(hours=1).A('192.0.2.1')``.
    """

    nsc_zone: 'NscZonePrimary'
    name: str           # owner name as given by the caller ("" / "@" is the apex)
    node: Node          # underlying dnspython node, created on demand
    _ttl: int           # TTL in seconds for records added from now on

    def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None:
        self.nsc_zone = nsc_zone
        self.name = name
        self.node = nsc_zone.zone.find_node(name, create=True)
        self._ttl = nsc_zone._min_ttl

    def ttl(self, *args, **kwargs) -> Self:
        """Set the TTL used for records added afterwards.

        Arguments are passed to ``timedelta``. Without arguments the TTL is
        reset to the zone's minimum TTL.
        """
        if not args and not kwargs:
            self._ttl = self.nsc_zone._min_ttl
        else:
            self._ttl = int(timedelta(*args, **kwargs).total_seconds())
        return self

    def _add(self, rec: Rdata) -> None:
        """Add *rec* to the matching rdataset of this node with the current TTL."""
        rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
        rds.add(rec, ttl=self._ttl)

    def _parse_addrs(self, addrs: Tuple[IPAddr, ...]) -> List[IPAddress]:
        """Flatten *addrs* (items may be lists/tuples) into address objects."""
        out: List[IPAddress] = []
        for a in addrs:
            group = a if isinstance(a, (list, tuple)) else [a]
            for b in group:
                out.append(b if isinstance(b, (IPv4Address, IPv6Address)) else ip_address(b))
        return out

    def _parse_name(self, name: str) -> Name:
        """Parse a record target: dotted names are absolute, others stay relative."""
        # FIXME: Names with escaped dots
        if '.' in name:
            return dns.name.from_text(name)
        else:
            return dns.name.from_text(name, origin=None)

    def _parse_names(self, names: str | List[str]) -> List[Name]:
        """Parse one name or a list of names with _parse_name()."""
        if isinstance(names, str):
            return [self._parse_name(names)]
        else:
            return [self._parse_name(n) for n in names]

    def _absolute_name(self) -> Name:
        """Return the absolute owner name of this node.

        Resolving against the zone origin handles the apex ("" or "@") and
        names that are already absolute. Plain string concatenation turned ""
        into an invalid ".zone" and "@" into a bogus "@.zone".
        """
        return dns.name.from_text(self.name, origin=self.nsc_zone.zone.origin)

    def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
        """Add A/AAAA records for *addrs*.

        With *reverse* set, each address is also registered with the owning
        ``Nsc`` as a request for a PTR record back to this name.
        """
        for a in self._parse_addrs(addrs):
            if isinstance(a, IPv4Address):
                self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
            else:
                self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
            if reverse:
                self.nsc_zone.nsc._add_reverse_mapping(a, self._absolute_name())
        return self

    def MX(self, pri: int, name: str) -> Self:
        """Add an MX record with preference *pri* pointing to *name*."""
        self._add(
            dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, self._parse_name(name))
        )
        return self

    def NS(self, names: str | List[str]) -> Self:
        """Add one NS record per name in *names*."""
        # FIXME: Variadic?
        for name in self._parse_names(names):
            self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
        return self

    def TXT(self, text: str) -> Self:
        """Add a TXT record containing *text* (UTF-8 encoded).

        A DNS character-string holds at most 255 bytes. Longer text, such as
        DKIM keys, is therefore split into consecutive 255-byte strings of a
        single TXT record, which consumers like SPF/DKIM concatenate.
        """
        data = text.encode()
        chunks = [data[i:i + 255] for i in range(0, len(data), 255)] or [b'']
        self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, chunks))
        return self

    def PTR(self, target: Name | str) -> Self:
        """Add a PTR record pointing to *target*."""
        self._add(dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target))
        return self

    def generic(self, typ: str, text: str) -> Self:
        """Add a record of type *typ* given in master-file text form."""
        self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
        return self
113
114
115 class NscZoneConfig:
116     admin_email: str
117     refresh: timedelta
118     retry: timedelta
119     expire: timedelta
120     min_ttl: timedelta
121     origin_server: str
122
123     default_config: Optional['NscZoneConfig'] = None
124
125     def __init__(self,
126                  admin_email: Optional[str] = None,
127                  refresh: Optional[timedelta] = None,
128                  retry: Optional[timedelta] = None,
129                  expire: Optional[timedelta] = None,
130                  min_ttl: Optional[timedelta] = None,
131                  origin_server: Optional[str] = None,
132                  inherit_config: Optional['NscZoneConfig'] = None,
133                  ) -> None:
134         if inherit_config is None:
135             inherit_config = NscZoneConfig.default_config or self   # to satisfy the type checker
136         self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
137         self.refresh = refresh if refresh is not None else inherit_config.refresh
138         self.retry = retry if retry is not None else inherit_config.retry
139         self.expire = expire if expire is not None else inherit_config.expire
140         self.min_ttl = min_ttl if min_ttl is not None else inherit_config.min_ttl
141         self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
142
143     def finalize(self) -> Self:
144         if not self.origin_server:
145             self.origin_server = socket.getfqdn()
146         if not self.admin_email:
147             self.admin_email = f'hostmaster@{self.origin_server}'
148         return self
149
150
# Built-in defaults inherited by every zone configuration. The empty strings
# are placeholders: NscZoneConfig.finalize() replaces them with this host's
# FQDN and a hostmaster@<that host> contact address.
NscZoneConfig.default_config = NscZoneConfig(
    origin_server="",
    admin_email="",
    refresh=timedelta(hours=8),
    retry=timedelta(hours=2),
    expire=timedelta(weeks=2),
    min_ttl=timedelta(hours=24),
)
159
160
class NscZoneState:
    """Per-zone bookkeeping persisted between runs as a small JSON object.

    Holds the SOA serial and the content hash, so the next run can tell
    whether the zone changed.
    """

    serial: int     # SOA serial; 0 until one has been recorded
    hash: str       # content hash of the zone; 'none' until one has been recorded

    def __init__(self) -> None:
        self.serial = 0
        self.hash = 'none'

    def load(self, file: Path) -> None:
        """Read the state from *file*.

        A missing file leaves the defaults in place. Keys absent from the
        JSON object keep their current values.

        Raises:
            ValueError: if *file* does not hold a JSON object. Malformed JSON
                also raises this, because json.JSONDecodeError subclasses
                ValueError.
        """
        try:
            with open(file, encoding='utf-8') as f:
                js = json.load(f)
        except FileNotFoundError:
            return
        # An explicit check rather than `assert`, which `python -O` strips.
        if not isinstance(js, dict):
            raise ValueError(f'{file}: zone state must be a JSON object')
        if 'serial' in js:
            self.serial = js['serial']
        if 'hash' in js:
            self.hash = js['hash']

    def save(self, file: Path) -> None:
        """Write the state to ``<file>.new``, then rename it over *file*.

        Readers therefore never see a half-written file.
        """
        new_file = Path(str(file) + '.new')
        with open(new_file, 'w', encoding='utf-8') as f:
            js = {
                'serial': self.serial,
                'hash': self.hash,
            }
            json.dump(js, f, indent=4, sort_keys=True)
        new_file.replace(file)
190
191
class ZoneType(Enum):
    """Kind of zone: primary or secondary."""

    # Explicit values matching what auto() assigns (1, 2).
    primary = 1
    secondary = 2
195
196
class NscZone:
    """Common base of primary and secondary zones managed by an ``Nsc``.

    Holds the owning ``Nsc``, the zone name (plus a file-name-safe variant),
    the finalized configuration, and the optional network this zone serves
    reverse mappings for.
    """

    nsc: 'Nsc'
    name: str
    safe_name: str                          # For use in file names
    zone_type: ZoneType                     # assigned by the subclass constructor
    reverse_for: Optional[IPNetwork]
    config: NscZoneConfig                   # finalized per-zone configuration

    def __init__(self,
                 nsc: 'Nsc',
                 name: str,
                 reverse_for: Optional[IPNetwork],
                 **kwargs) -> None:
        self.nsc = nsc
        self.name = name
        # RFC 2317 classless reverse zone names contain '/', which cannot
        # appear in a file name, so it is mapped to '@'.
        self.safe_name = '@'.join(name.split('/'))
        # Remaining keyword arguments (including inherit_config) are
        # NscZoneConfig parameters.
        self.config = NscZoneConfig(**kwargs).finalize()
        self.reverse_for = reverse_for

    def process(self) -> None:
        """Per-run processing hook; the base class has nothing to do."""
217
218
class NscZonePrimary(NscZone):
    """A zone whose records are defined here and written out as a zone file.

    Records live in an in-memory dnspython ``Zone``. A JSON state file keeps
    the serial and content hash of the previous run, so the serial is only
    increased when the zone contents change.
    """

    zone: Zone
    _min_ttl: int               # config.min_ttl in seconds; default record TTL
    zone_file: Path             # <zone_dir>/<safe_name>
    state_file: Path            # <state_dir>/<safe_name>.json
    state: NscZoneState         # state computed during this run
    prev_state: NscZoneState    # state loaded from state_file (defaults if absent)

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.zone_type = ZoneType.primary
        self.zone_file = self.nsc.zone_dir / self.safe_name
        self.state_file = self.nsc.state_dir / (self.safe_name + '.json')

        self.state = NscZoneState()
        self.prev_state = NscZoneState()
        self.prev_state.load(self.state_file)

        self.zone = Zone(origin=self.name, rdclass=RdataClass.IN)
        self._min_ttl = int(self.config.min_ttl.total_seconds())
        # Building the SOA before anything else makes the apex the first node
        # and the SOA its first rdataset. update_soa() preserves that order.
        self.update_soa()

    def update_soa(self) -> None:
        """(Re)build the apex SOA from the configuration and the current serial.

        The existing SOA rdataset is emptied and refilled in place. It used to
        be deleted and re-created, which appended it after the other apex
        records (or moved the whole apex node to the end when the SOA was its
        only rdataset). The written zone then no longer started with the SOA.
        dump() omits TTLs equal to the minimum, so a leading SOA is what lets
        servers such as BIND derive the default TTL when no $TTL is present.
        """
        conf = self.config
        soa = dns.rdtypes.ANY.SOA.SOA(
            RdataClass.IN, RdataType.SOA,
            mname=conf.origin_server,
            rname=conf.admin_email.replace('@', '.'),   # FIXME: names with dots
            serial=self.state.serial,
            refresh=int(conf.refresh.total_seconds()),
            retry=int(conf.retry.total_seconds()),
            expire=int(conf.expire.total_seconds()),
            minimum=int(conf.min_ttl.total_seconds()),
        )
        apex = self[""]
        old = apex.node.get_rdataset(RdataClass.IN, RdataType.SOA)
        if old is not None:
            old.clear()
        apex._add(soa)

    def n(self, name: str) -> NscNode:
        """Return a record builder for *name* ("" is the zone apex)."""
        return NscNode(self, name)

    def __getitem__(self, name: str) -> NscNode:
        """Same as n(), for ``zone['www']`` syntax."""
        return NscNode(self, name)

    def host(self, name: str, *args, reverse: bool = True) -> NscNode:
        """Create node *name* with A/AAAA records for *args*; return the node."""
        n = NscNode(self, name)
        n.A(*args, reverse=reverse)
        return n

    def dump(self, file: Optional[TextIO] = None) -> None:
        """Write the zone in master-file format to *file* (stdout by default).

        An owner name is printed only when it differs from the previous
        record's. TTLs equal to the zone's minimum TTL are left out.
        """
        # Could use self.zone.to_file(sys.stdout), but we want better formatting
        file = file or sys.stdout
        file.write(f'; Zone file for {self.name}\n\n')
        last_name = None
        for name, ttl, rec in self.zone.iterate_rdatas():
            if name == last_name:
                print_name = ""
            else:
                print_name = name
            file.write(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
            last_name = name

    def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
        """Add a PTR for *addr* to this IPv4 reverse zone.

        The owner is the address octets below the zone's prefix, reversed.
        """
        assert isinstance(self.reverse_for, IPv4Network)
        parts = str(addr).split('.')
        parts = parts[self.reverse_for.prefixlen // 8:]
        name = '.'.join(reversed(parts))
        self.n(name).PTR(ptr_to)

    def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
        """Add a PTR for *addr* to this IPv6 reverse zone.

        The owner is the address nibbles below the zone's prefix, reversed.
        """
        assert isinstance(self.reverse_for, IPv6Network)
        parts = addr.exploded.replace(':', "")
        parts = parts[self.reverse_for.prefixlen // 4:]
        name = '.'.join(reversed(parts))
        self.n(name).PTR(ptr_to)

    def gen_hash(self) -> None:
        """Store the first 16 hex digits of a SHA-1 over all records in state.hash."""
        sha = hashlib.sha1()
        for name, ttl, rec in self.zone.iterate_rdatas():
            text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
            sha.update(text.encode('us-ascii'))
        self.state.hash = sha.hexdigest()[:16]

    def gen_serial(self) -> None:
        """Choose state.serial for this run.

        The previous serial is kept if the content hash is unchanged.
        Otherwise a YYYYMMDDnn serial based on the run's start date is used,
        always greater than the previous one.
        """
        prev = self.prev_state.serial
        if self.state.hash == self.prev_state.hash and prev > 0:
            self.state.serial = prev
        else:
            base = int(self.nsc.start_time.strftime('%Y%m%d00'))
            if prev <= base:
                self.state.serial = base + 1
            else:
                self.state.serial = prev + 1
                # More than 99 changes in one day spill into the next date.
                if prev >= base + 99:
                    print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')

    def process(self) -> None:
        """Compute the content hash and choose the serial for this run."""
        # Instances of this class are always primary zones, so no zone_type check is needed.
        self.gen_hash()
        self.gen_serial()

    def write_zone(self) -> None:
        """Refresh the SOA with the chosen serial and replace zone_file via a .new file."""
        self.update_soa()
        new_file = Path(str(self.zone_file) + '.new')
        with open(new_file, 'w') as f:
            self.dump(file=f)
        new_file.replace(self.zone_file)

    def write_state(self) -> None:
        """Persist this run's state to state_file."""
        self.state.save(self.state_file)
329
330
class NscZoneSecondary(NscZone):
    """A zone this server is secondary for.

    *primary_server* is the address given as ``secondary_for`` to
    ``Nsc.add_zone``.
    """

    primary_server: IPAddress
    secondary_file: Path            # <secondary_dir>/<safe_name>

    def __init__(self, *args, primary_server: IPAddress, **kwargs) -> None:
        # primary_server is a required keyword-only argument. It used to be
        # written `primary_server=IPAddress`, which made the union *type* the
        # default value when the argument was omitted.
        super().__init__(*args, **kwargs)
        self.zone_type = ZoneType.secondary
        self.primary_server = primary_server
        self.secondary_file = self.nsc.secondary_dir / self.safe_name
340
341
342 class Nsc:
343     start_time: datetime
344     zones: Dict[str, NscZone]
345     default_zone_config: NscZoneConfig
346     ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
347     ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
348     root_dir: Path
349     state_dir: Path
350     zone_dir: Path
351     secondary_dir: Path
352
353     def __init__(self, directory: str = '.', **kwargs) -> None:
354         self.start_time = datetime.now()
355         self.zones = {}
356         self.default_zone_config = NscZoneConfig(**kwargs)
357         self.ipv4_reverse = defaultdict(list)
358         self.ipv6_reverse = defaultdict(list)
359
360         self.root_dir = Path(directory)
361         self.state_dir = self.root_dir / 'state'
362         self.state_dir.mkdir(parents=True, exist_ok=True)
363         self.zone_dir = self.root_dir / 'zone'
364         self.zone_dir.mkdir(parents=True, exist_ok=True)
365         self.secondary_dir = self.root_dir / 'secondary'
366         self.secondary_dir.mkdir(parents=True, exist_ok=True)
367
368     def add_zone(self,
369                  name: Optional[str] = None,
370                  reverse_for: str | IPNetwork | None = None,
371                  secondary_for: str | IPAddress | None = None,
372                  inherit_config: Optional[NscZoneConfig] = None,
373                  **kwargs) -> Zone:
374         if inherit_config is None:
375             inherit_config = self.default_zone_config
376
377         if reverse_for is not None:
378             if isinstance(reverse_for, str):
379                 reverse_for = ip_network(reverse_for, strict=True)
380             name = name or self._reverse_zone_name(reverse_for)
381         assert name is not None
382         assert name not in self.zones
383
384         z: NscZone
385         if secondary_for is None:
386             z = NscZonePrimary(self, name, reverse_for=reverse_for, inherit_config=inherit_config, **kwargs)
387         else:
388             if isinstance(secondary_for, str):
389                 secondary_for = ip_address(secondary_for)
390             z = NscZoneSecondary(self, name, reverse_for=reverse_for, primary_server=secondary_for, inherit_config=inherit_config, **kwargs)
391
392         self.zones[name] = z
393         return z
394
395     def _reverse_zone_name(self, net: IPNetwork) -> str:
396         if isinstance(net, IPv4Network):
397             parts = str(net.network_address).split('.')
398             out = parts[:net.prefixlen // 8]
399             if net.prefixlen % 8 != 0:
400                 out.append(parts[len(out)] + '/' + str(net.prefixlen))
401             return '.'.join(reversed(out)) + '.in-addr.arpa'
402         elif isinstance(net, IPv6Network):
403             assert net.prefixlen % 4 == 0
404             nibbles = net.network_address.exploded.replace(':', "")
405             nibbles = nibbles[:net.prefixlen // 4]
406             return '.'.join(reversed(nibbles)) + '.ip6.arpa'
407         else:
408             raise NotImplementedError()
409
410     def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
411         if isinstance(addr, IPv4Address):
412             self.ipv4_reverse[addr].append(ptr_to)
413         else:
414             self.ipv6_reverse[addr].append(ptr_to)
415
416     def dump_reverse(self) -> None:
417         print('### Requests for reverse mappings ###')
418         for ipa4, name in sorted(self.ipv4_reverse.items()):
419             print(f'{ipa4}\t{name}')
420         for ipa6, name in sorted(self.ipv6_reverse.items()):
421             print(f'{ipa6}\t{name}')
422
423     def fill_reverse(self) -> None:
424         for z in self.zones.values():
425             if isinstance(z, NscZonePrimary) and z.reverse_for is not None:
426                 if isinstance(z.reverse_for, IPv4Network):
427                     for addr4, ptr_list in self.ipv4_reverse.items():
428                         if addr4 in z.reverse_for:
429                             for ptr_to in ptr_list:
430                                 z._add_ipv4_reverse(addr4, ptr_to)
431                 else:
432                     for addr6, ptr_list in self.ipv6_reverse.items():
433                         if addr6 in z.reverse_for:
434                             for ptr_to in ptr_list:
435                                 z._add_ipv6_reverse(addr6, ptr_to)
436
437     def get_zones(self) -> List[NscZone]:
438         return [self.zones[k] for k in sorted(self.zones.keys())]
439
440     def process(self) -> None:
441         self.fill_reverse()
442         for z in self.get_zones():
443             z.process()