mj.ucw.cz Git - pynsc.git/blob - nsconfig/core.py
First bits of daemon configuration
[pynsc.git] / nsconfig / core.py
1 from collections import defaultdict
2 from datetime import datetime, timedelta
3 import dns.name
4 from dns.name import Name
5 from dns.node import Node
6 from dns.rdata import Rdata
7 from dns.rdataclass import RdataClass
8 from dns.rdatatype import RdataType
9 import dns.rdtypes.ANY.MX
10 import dns.rdtypes.ANY.NS
11 import dns.rdtypes.ANY.PTR
12 import dns.rdtypes.ANY.SOA
13 import dns.rdtypes.ANY.TXT
14 import dns.rdtypes.IN.A
15 import dns.rdtypes.IN.AAAA
16 from dns.zone import Zone
17 from enum import Enum, auto
18 import hashlib
19 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
20 import json
21 from pathlib import Path
22 import socket
23 import sys
24 from typing import Optional, Dict, List, Self, Tuple, DefaultDict, TextIO, TYPE_CHECKING
25
26
27 if TYPE_CHECKING:
28     from nsconfig.daemon import NscDaemon
29
30
# An address or network of either IP family, as produced by the ipaddress module.
IPAddress = IPv4Address | IPv6Address
IPNetwork = IPv4Network | IPv6Network
# Loose address argument accepted by the record helpers (see NscNode._parse_addrs):
# a textual address, an address object, or a list mixing both.
IPAddr = str | IPAddress | List[str | IPAddress]
34
35
class NscNode:
    """Builder for the records of one owner name inside a primary zone.

    Instances are handed out by NscZonePrimary (``zone[name]``, ``zone.n()``,
    ``zone.host()``). Every record-adding method returns ``self`` so calls can
    be chained, e.g. ``zone['www'].ttl(hours=1).A('192.0.2.1')``.
    """

    nsc_zone: 'NscZonePrimary'
    name: str           # Owner name as given by the caller; "" is the zone apex
    node: Node          # dnspython node holding this name's rdatasets
    _ttl: int           # TTL in seconds passed along with records added from now on

    def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None:
        self.nsc_zone = nsc_zone
        self.name = name
        self.node = nsc_zone.zone.find_node(name, create=True)
        self._ttl = nsc_zone._min_ttl

    def ttl(self, *args, **kwargs) -> Self:
        """Set the TTL used for records added afterwards.

        Arguments are passed to ``timedelta`` (e.g. ``ttl(hours=1)``); the
        value is truncated to whole seconds. Without arguments the zone's
        minimum TTL is restored.
        """
        if not args and not kwargs:
            self._ttl = self.nsc_zone._min_ttl
        else:
            self._ttl = int(timedelta(*args, **kwargs).total_seconds())
        return self

    def _add(self, rec: Rdata) -> None:
        """Add *rec* to this node's rdataset of the same class/type, with the current TTL."""
        rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
        rds.add(rec, ttl=self._ttl)

    def _parse_addrs(self, addrs: Tuple[IPAddr, ...]) -> List[IPAddress]:
        """Flatten address arguments (strings, address objects, or lists of them).

        Raises:
            ValueError: a string is not a valid IPv4/IPv6 address.
        """
        out: List[IPAddress] = []
        for a in addrs:
            items = a if isinstance(a, list) else [a]
            for b in items:
                out.append(b if isinstance(b, (IPv4Address, IPv6Address)) else ip_address(b))
        return out

    def _parse_name(self, name: str) -> Name:
        """Parse a record target: names containing a dot are absolute, others stay relative."""
        # FIXME: Names with escaped dots
        if '.' in name:
            return dns.name.from_text(name)
        else:
            return dns.name.from_text(name, origin=None)

    def _parse_names(self, names: str | List[str]) -> List[Name]:
        """Parse one name or a list of names with _parse_name()."""
        if isinstance(names, str):
            return [self._parse_name(names)]
        else:
            return [self._parse_name(n) for n in names]

    def _fqdn(self) -> Name:
        """Return the absolute name of this node, including the zone apex."""
        # Joining strings as "name.zone" breaks for the apex (""): it yields
        # ".zone", which dns.name.from_text rejects as an empty label.
        origin = dns.name.from_text(self.nsc_zone.name)
        return dns.name.from_text(self.name, origin=origin)

    def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
        """Add A or AAAA records (chosen by address family) for *addrs*.

        With ``reverse=True`` every address is also registered with the Nsc
        object as a request for a PTR record pointing back to this name.
        """
        for a in self._parse_addrs(addrs):
            if isinstance(a, IPv4Address):
                self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
            else:
                self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
            if reverse:
                self.nsc_zone.nsc._add_reverse_mapping(a, self._fqdn())
        return self

    def MX(self, pri: int, name: str) -> Self:
        """Add an MX record with preference *pri* pointing to *name*."""
        self._add(
            dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, self._parse_name(name))
        )
        return self

    def NS(self, names: str | List[str]) -> Self:
        """Add NS records for one name or a list of names."""
        # FIXME: Variadic?
        for name in self._parse_names(names):
            self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
        return self

    def TXT(self, text: str) -> Self:
        """Add a TXT record holding *text*, split into strings of at most 255 bytes."""
        # A single TXT character-string is limited to 255 bytes (RFC 1035,
        # 3.3). Given one over-long str, dnspython rejects it as a whole and
        # then iterates it, producing one string per character; so split the
        # UTF-8 encoding explicitly. An empty text stays one empty string.
        data = text.encode()
        chunks = [data[i:i + 255] for i in range(0, len(data), 255)] or [b'']
        self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, chunks))
        return self

    def PTR(self, target: Name | str) -> Self:
        """Add a PTR record pointing to *target*."""
        self._add(dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target))
        return self

    def generic(self, typ: str, text: str) -> Self:
        """Add a record of type *typ* parsed from its textual RDATA *text*."""
        self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
        return self
117
118
119 class NscZoneConfig:
120     admin_email: str
121     refresh: timedelta
122     retry: timedelta
123     expire: timedelta
124     min_ttl: timedelta
125     origin_server: str
126
127     default_config: Optional['NscZoneConfig'] = None
128
129     def __init__(self,
130                  admin_email: Optional[str] = None,
131                  refresh: Optional[timedelta] = None,
132                  retry: Optional[timedelta] = None,
133                  expire: Optional[timedelta] = None,
134                  min_ttl: Optional[timedelta] = None,
135                  origin_server: Optional[str] = None,
136                  inherit_config: Optional['NscZoneConfig'] = None,
137                  ) -> None:
138         if inherit_config is None:
139             inherit_config = NscZoneConfig.default_config or self   # to satisfy the type checker
140         self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
141         self.refresh = refresh if refresh is not None else inherit_config.refresh
142         self.retry = retry if retry is not None else inherit_config.retry
143         self.expire = expire if expire is not None else inherit_config.expire
144         self.min_ttl = min_ttl if min_ttl is not None else inherit_config.min_ttl
145         self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
146
147     def finalize(self) -> Self:
148         if not self.origin_server:
149             self.origin_server = socket.getfqdn()
150         if not self.admin_email:
151             self.admin_email = f'hostmaster@{self.origin_server}'
152         return self
153
154
# Built-in defaults inherited by every zone config that does not override them.
# The empty strings are placeholders resolved per zone by NscZoneConfig.finalize().
NscZoneConfig.default_config = NscZoneConfig(
    origin_server="",
    admin_email="",
    min_ttl=timedelta(days=1),
    expire=timedelta(days=14),
    retry=timedelta(hours=2),
    refresh=timedelta(hours=8),
)
163
164
class NscZoneState:
    """Per-zone state kept between runs: the last SOA serial and content hash.

    Stored as a small JSON object; NscZonePrimary compares the hashes to
    decide whether the serial has to change.
    """

    serial: int
    hash: str

    def __init__(self) -> None:
        self.serial = 0
        self.hash = 'none'      # Placeholder; never equal to a hex digest

    def load(self, file: Path) -> None:
        """Load state from *file*; a missing file leaves the defaults in place.

        Raises:
            ValueError: the file is not valid JSON (json.JSONDecodeError is a
                ValueError), is not a JSON object, or holds values of the
                wrong type.
        """
        try:
            with open(file, encoding='utf-8') as f:
                js = json.load(f)
        except FileNotFoundError:
            return

        # Explicit checks instead of assert: asserts vanish under "python -O"
        # and a corrupted state file would then fail later in a confusing way.
        if not isinstance(js, dict):
            raise ValueError(f'{file}: state must be a JSON object')
        serial = js.get('serial', self.serial)
        if not isinstance(serial, int) or isinstance(serial, bool):
            raise ValueError(f'{file}: serial must be an integer')
        digest = js.get('hash', self.hash)
        if not isinstance(digest, str):
            raise ValueError(f'{file}: hash must be a string')
        self.serial = serial
        self.hash = digest

    def save(self, file: Path) -> None:
        """Write the state to *file* via a ``.new`` temporary file and a rename.

        The rename (``Path.replace``) means the old file is only replaced
        once the new content has been fully written.
        """
        new_file = Path(str(file) + '.new')
        js = {
            'serial': self.serial,
            'hash': self.hash,
        }
        with open(new_file, 'w', encoding='utf-8') as f:
            json.dump(js, f, indent=4, sort_keys=True)
        new_file.replace(file)
194
195
class ZoneType(Enum):
    """Role of a zone on this server."""
    primary = 1         # Zone content generated here (NscZonePrimary)
    secondary = 2       # Zone following a primary server (NscZoneSecondary)
199
200
class NscZone:
    """Common part of the primary and secondary zones registered with an Nsc."""

    nsc: 'Nsc'
    name: str
    safe_name: str                          # For use in file names
    zone_type: ZoneType
    reverse_for: Optional[IPNetwork]

    def __init__(self,
                 nsc: 'Nsc',
                 name: str,
                 reverse_for: Optional[IPNetwork],
                 **kwargs) -> None:
        """Store the zone identity and build its config.

        Remaining keyword arguments go to NscZoneConfig, whose result is
        finalized right away.
        """
        self.nsc, self.name = nsc, name
        # '/' (present in classless reverse zone names) cannot appear in a file name.
        self.safe_name = '@'.join(name.split('/'))
        self.config = NscZoneConfig(**kwargs).finalize()
        self.reverse_for = reverse_for

    def process(self) -> None:
        """Per-run processing hook; the generic zone has nothing to do."""
        return None
221
222
class NscZonePrimary(NscZone):
    """A zone whose contents are generated here and written as a zone file.

    Records are added through NscNode builders (``zone[name]``, ``zone.n()``,
    ``zone.host()``). ``process()`` computes a content hash and the serial;
    ``write_zone()`` and ``write_state()`` persist the results.
    """

    zone: Zone
    _min_ttl: int                   # Default TTL in seconds (config.min_ttl)
    zone_file: Path
    state_file: Path
    state: NscZoneState             # State computed in this run
    prev_state: NscZoneState        # State loaded from the previous run

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.zone_type = ZoneType.primary
        self.zone_file = self.nsc.zone_dir / self.safe_name
        self.state_file = self.nsc.state_dir / (self.safe_name + '.json')

        self.state = NscZoneState()
        self.prev_state = NscZoneState()
        self.prev_state.load(self.state_file)

        self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN)
        self._min_ttl = int(self.config.min_ttl.total_seconds())
        self.update_soa()

    @staticmethod
    def _email_to_rname(email: str) -> str:
        """Convert a mailbox address to the textual SOA RNAME form.

        The last '@' becomes a label separator. The local part must stay a
        single label, so its dots are escaped as ``\\.``. Text without '@'
        is returned unchanged.
        """
        local, sep, domain = email.rpartition('@')
        if not sep:
            return email
        return local.replace('.', '\\.') + '.' + domain

    def update_soa(self) -> None:
        """Replace the apex SOA record using the config and the current serial."""
        conf = self.config
        soa = dns.rdtypes.ANY.SOA.SOA(
            RdataClass.IN, RdataType.SOA,
            mname=conf.origin_server,
            rname=self._email_to_rname(conf.admin_email),
            serial=self.state.serial,
            refresh=int(conf.refresh.total_seconds()),
            retry=int(conf.retry.total_seconds()),
            expire=int(conf.expire.total_seconds()),
            minimum=int(conf.min_ttl.total_seconds()),
        )
        self.zone.delete_rdataset("", RdataType.SOA)
        self[""]._add(soa)

    def n(self, name: str) -> NscNode:
        """Return a record builder for *name* ("" is the zone apex)."""
        return NscNode(self, name)

    def __getitem__(self, name: str) -> NscNode:
        """``zone[name]`` is shorthand for ``zone.n(name)``."""
        return NscNode(self, name)

    def host(self, name: str, *args, reverse: bool = True) -> NscNode:
        """Create *name* with A/AAAA records for the addresses in *args*.

        With ``reverse=True`` PTR records are requested for the addresses too.
        """
        n = NscNode(self, name)
        n.A(*args, reverse=reverse)
        return n

    def dump(self, file: Optional[TextIO] = None) -> None:
        """Write the zone in master-file format to *file* (default: stdout).

        Owner names are printed only when they change. TTLs equal to the zone
        default are left out and covered by a ``$TTL`` directive.
        """
        # Could use self.zone.to_file(sys.stdout), but we want better formatting
        file = file or sys.stdout
        file.write(f'; Zone file for {self.name}\n\n')
        # Without $TTL, records lacking a TTL would depend on the obsolete
        # fallback to the SOA MINIMUM field (RFC 2308, section 4) or on a
        # server-specific default.
        file.write(f'$TTL {self._min_ttl}\n\n')
        last_name = None
        for name, ttl, rec in self.zone.iterate_rdatas():
            print_name = "" if name == last_name else name
            ttl_text = "" if ttl == self._min_ttl else ttl
            file.write(f'{print_name}\t{ttl_text}\t{rec.rdtype.name}\t{rec.to_text()}\n')
            last_name = name

    def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
        """Add a PTR record for *addr* (inside reverse_for) pointing to *ptr_to*."""
        assert isinstance(self.reverse_for, IPv4Network)
        # Keep the octets not covered by the zone's whole-octet prefix (e.g.
        # the last one for a /24) and reverse them into a relative owner name.
        parts = str(addr).split('.')
        parts = parts[self.reverse_for.prefixlen // 8:]
        name = '.'.join(reversed(parts))
        self.n(name).PTR(ptr_to)

    def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
        """Add a PTR record for *addr* (inside reverse_for) pointing to *ptr_to*."""
        assert isinstance(self.reverse_for, IPv6Network)
        # Same as the IPv4 case, with one label per hex nibble.
        parts = addr.exploded.replace(':', "")
        parts = parts[self.reverse_for.prefixlen // 4:]
        name = '.'.join(reversed(parts))
        self.n(name).PTR(ptr_to)

    def gen_hash(self) -> None:
        """Store in self.state a short SHA-1 digest of all records of the zone."""
        sha = hashlib.sha1()
        for name, ttl, rec in self.zone.iterate_rdatas():
            text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
            sha.update(text.encode('us-ascii'))
        self.state.hash = sha.hexdigest()[:16]

    def gen_serial(self) -> None:
        """Choose the SOA serial for this run (YYYYMMDDnn scheme).

        The previous serial is kept when the content hash is unchanged.
        Otherwise the serial becomes max(previous + 1, YYYYMMDD01) for the
        date of nsc.start_time. A warning is printed when this runs past the
        99 changes available for that date.
        """
        prev = self.prev_state.serial
        if self.state.hash == self.prev_state.hash and prev > 0:
            self.state.serial = self.prev_state.serial
        else:
            base = int(self.nsc.start_time.strftime('%Y%m%d00'))
            if prev <= base:
                self.state.serial = base + 1
            else:
                self.state.serial = prev + 1
                if prev >= base + 99:
                    print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')

    def process(self) -> None:
        """Compute the content hash and the serial for this run."""
        if self.zone_type == ZoneType.primary:
            self.gen_hash()
            self.gen_serial()

    def write_zone(self) -> None:
        """Refresh the SOA with the chosen serial and write the zone file.

        The file is written under a ``.new`` name first and renamed over the
        old one afterwards.
        """
        self.update_soa()
        new_file = Path(str(self.zone_file) + '.new')
        with open(new_file, 'w') as f:
            self.dump(file=f)
        new_file.replace(self.zone_file)

    def write_state(self) -> None:
        """Persist this run's serial and hash to the state file."""
        self.state.save(self.state_file)

    def is_changed(self) -> bool:
        """Return True if this run assigned a different serial than the previous one."""
        return self.state.serial != self.prev_state.serial
336
337
class NscZoneSecondary(NscZone):
    """A zone that this server serves as a secondary of *primary_server*."""

    primary_server: IPAddress
    secondary_file: Path

    def __init__(self, *args, primary_server: IPAddress | str, **kwargs) -> None:
        """Create the zone; other arguments are passed to NscZone.

        Args:
            primary_server: address of the primary server (required keyword);
                a string is parsed with ``ip_address()``.

        Raises:
            ValueError: *primary_server* is a string that is not an IP address.
        """
        super().__init__(*args, **kwargs)
        self.zone_type = ZoneType.secondary
        # Required keyword: there is no meaningful default primary server.
        if isinstance(primary_server, str):
            primary_server = ip_address(primary_server)
        self.primary_server = primary_server
        self.secondary_file = self.nsc.secondary_dir / self.safe_name
347
348
349 class Nsc:
350     start_time: datetime
351     zones: Dict[str, NscZone]
352     default_zone_config: NscZoneConfig
353     ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
354     ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
355     root_dir: Path
356     state_dir: Path
357     zone_dir: Path
358     secondary_dir: Path
359     daemon: Optional['NscDaemon']  # Set by DaemonConfig class
360
361     def __init__(self,
362                  directory: str = '.',
363                  daemon: Optional['NscDaemon'] = None,
364                  **kwargs) -> None:
365         self.start_time = datetime.now()
366         self.zones = {}
367         self.default_zone_config = NscZoneConfig(**kwargs)
368         self.ipv4_reverse = defaultdict(list)
369         self.ipv6_reverse = defaultdict(list)
370
371         self.root_dir = Path(directory)
372         self.state_dir = self.root_dir / 'state'
373         self.state_dir.mkdir(parents=True, exist_ok=True)
374         self.zone_dir = self.root_dir / 'zone'
375         self.zone_dir.mkdir(parents=True, exist_ok=True)
376         self.secondary_dir = self.root_dir / 'secondary'
377         self.secondary_dir.mkdir(parents=True, exist_ok=True)
378
379         self.daemon = daemon
380         if daemon is not None:
381             daemon.setup(self)
382
383     def add_zone(self,
384                  name: Optional[str] = None,
385                  reverse_for: str | IPNetwork | None = None,
386                  follow_primary: str | IPAddress | None = None,
387                  inherit_config: Optional[NscZoneConfig] = None,
388                  **kwargs) -> Zone:
389         if inherit_config is None:
390             inherit_config = self.default_zone_config
391
392         if reverse_for is not None:
393             if isinstance(reverse_for, str):
394                 reverse_for = ip_network(reverse_for, strict=True)
395             name = name or self._reverse_zone_name(reverse_for)
396         assert name is not None
397         assert name not in self.zones
398
399         z: NscZone
400         if follow_primary is None:
401             z = NscZonePrimary(self, name, reverse_for=reverse_for, inherit_config=inherit_config, **kwargs)
402         else:
403             if isinstance(follow_primary, str):
404                 follow_primary = ip_address(follow_primary)
405             z = NscZoneSecondary(self, name, reverse_for=reverse_for, primary_server=follow_primary, inherit_config=inherit_config, **kwargs)
406
407         self.zones[name] = z
408         return z
409
410     def _reverse_zone_name(self, net: IPNetwork) -> str:
411         if isinstance(net, IPv4Network):
412             parts = str(net.network_address).split('.')
413             out = parts[:net.prefixlen // 8]
414             if net.prefixlen % 8 != 0:
415                 out.append(parts[len(out)] + '/' + str(net.prefixlen))
416             return '.'.join(reversed(out)) + '.in-addr.arpa'
417         elif isinstance(net, IPv6Network):
418             assert net.prefixlen % 4 == 0
419             nibbles = net.network_address.exploded.replace(':', "")
420             nibbles = nibbles[:net.prefixlen // 4]
421             return '.'.join(reversed(nibbles)) + '.ip6.arpa'
422         else:
423             raise NotImplementedError()
424
425     def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
426         if isinstance(addr, IPv4Address):
427             self.ipv4_reverse[addr].append(ptr_to)
428         else:
429             self.ipv6_reverse[addr].append(ptr_to)
430
431     def dump_reverse(self) -> None:
432         print('### Requests for reverse mappings ###')
433         for ipa4, name in sorted(self.ipv4_reverse.items()):
434             print(f'{ipa4}\t{name}')
435         for ipa6, name in sorted(self.ipv6_reverse.items()):
436             print(f'{ipa6}\t{name}')
437
438     def fill_reverse(self) -> None:
439         for z in self.zones.values():
440             if isinstance(z, NscZonePrimary) and z.reverse_for is not None:
441                 if isinstance(z.reverse_for, IPv4Network):
442                     for addr4, ptr_list in self.ipv4_reverse.items():
443                         if addr4 in z.reverse_for:
444                             for ptr_to in ptr_list:
445                                 z._add_ipv4_reverse(addr4, ptr_to)
446                 else:
447                     for addr6, ptr_list in self.ipv6_reverse.items():
448                         if addr6 in z.reverse_for:
449                             for ptr_to in ptr_list:
450                                 z._add_ipv6_reverse(addr6, ptr_to)
451
452     def get_zones(self) -> List[NscZone]:
453         return [self.zones[k] for k in sorted(self.zones.keys())]
454
455     def process(self) -> None:
456         self.fill_reverse()
457         for z in self.get_zones():
458             z.process()