]> mj.ucw.cz Git - pynsc.git/blob - nsconfig/core.py
88baf6623a16c5677469c03a20f6f960deecc7c3
[pynsc.git] / nsconfig / core.py
1 from collections import defaultdict
2 from datetime import datetime, timedelta
3 import dns.name
4 from dns.name import Name
5 from dns.node import Node
6 from dns.rdata import Rdata
7 from dns.rdataclass import RdataClass
8 from dns.rdatatype import RdataType
9 import dns.rdtypes.ANY.CNAME
10 import dns.rdtypes.ANY.MX
11 import dns.rdtypes.ANY.NS
12 import dns.rdtypes.ANY.PTR
13 import dns.rdtypes.ANY.SOA
14 import dns.rdtypes.ANY.TXT
15 import dns.rdtypes.IN.A
16 import dns.rdtypes.IN.AAAA
17 from dns.zone import Zone
18 from enum import Enum, auto
19 import hashlib
20 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
21 import json
22 from pathlib import Path
23 import socket
24 import sys
25 from typing import Optional, Dict, List, Self, DefaultDict, TextIO, TYPE_CHECKING
26
27 from nsconfig.util import flatten_list, parse_address, parse_network, parse_name, parse_duration
28 from nsconfig.util import IPAddress, IPNetwork, IPAddr
29
30
31 if TYPE_CHECKING:
32     from nsconfig.daemon import NscDaemon
33
34
35 class NscNode:
36     nsc_zone: 'NscZonePrimary'
37     name: str
38     node: Node
39     _ttl: int
40
41     def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None:
42         self.nsc_zone = nsc_zone
43         self.name = name
44         self.node = nsc_zone.zone.find_node(name, create=True)
45         self._ttl = nsc_zone.config.default_ttl
46
47     def ttl(self, seconds: Optional[int] = None, **kwargs) -> Self:
48         if seconds is not None:
49             self._ttl = seconds
50         elif kwargs:
51             self._ttl = parse_duration(timedelta(**kwargs))
52         else:
53             self._ttl = self.nsc_zone.config.default_ttl
54         return self
55
56     def _add(self, rec: Rdata) -> None:
57         rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
58         rds.add(rec, ttl=self._ttl)
59
60     def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
61         for a in map(parse_address, flatten_list(addrs)):
62             if isinstance(a, IPv4Address):
63                 self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
64             else:
65                 self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
66             if reverse:
67                 self.nsc_zone.nsc._add_reverse_mapping(a, parse_name(self.name, origin=self.nsc_zone.dns_name))
68         return self
69
70     def MX(self, pri: int, name: str) -> Self:
71         self._add(
72             dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, parse_name(name))
73         )
74         return self
75
76     def NS(self, *names: str | List[str]) -> Self:
77         for name in map(parse_name, flatten_list(names)):
78             self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
79         return self
80
81     def TXT(self, *text: str | List[str]) -> Self:
82         for txt in flatten_list(text):
83             self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, txt))
84         return self
85
86     def PTR(self, target: Name | str) -> Self:
87         self._add(dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target))
88         return self
89
90     def CNAME(self, target: Name | str) -> Self:
91         self._add(dns.rdtypes.ANY.CNAME.CNAME(RdataClass.IN, RdataType.CNAME, target))
92         return self
93
94     def generic(self, typ: str, text: str) -> Self:
95         self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
96         return self
97
98
99 class NscZoneConfig:
100     admin_email: str
101     refresh: int
102     retry: int
103     expire: int
104     min_ttl: int
105     default_ttl: int
106     origin_server: str
107     daemon_options: List[str]
108     add_null_mx: bool
109
110     default_config: Optional['NscZoneConfig'] = None
111
112     def __init__(self,
113                  admin_email: Optional[str] = None,
114                  refresh: Optional[int | timedelta] = None,
115                  retry: Optional[int | timedelta] = None,
116                  expire: Optional[int | timedelta] = None,
117                  min_ttl: Optional[int | timedelta] = None,
118                  default_ttl: Optional[int | timedelta] = None,
119                  origin_server: Optional[str] = None,
120                  daemon_options: Optional[List[str]] = None,
121                  add_daemon_options: Optional[List[str]] = None,
122                  add_null_mx: Optional[bool] = None,
123                  inherit_config: Optional['NscZoneConfig'] = None,
124                  ) -> None:
125         if inherit_config is None:
126             inherit_config = NscZoneConfig.default_config or self   # to satisfy the type checker
127         self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
128         self.refresh = parse_duration(refresh) if refresh is not None else inherit_config.refresh
129         self.retry = parse_duration(retry) if retry is not None else inherit_config.retry
130         self.expire = parse_duration(expire) if expire is not None else inherit_config.expire
131         self.min_ttl = parse_duration(min_ttl) if min_ttl is not None else inherit_config.min_ttl
132         self.default_ttl = parse_duration(default_ttl) if default_ttl is not None else inherit_config.default_ttl
133         self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
134         self.daemon_options = daemon_options if daemon_options is not None else inherit_config.daemon_options
135         self.add_null_mx = add_null_mx if add_null_mx is not None else inherit_config.add_null_mx
136         if add_daemon_options is not None:
137             self.daemon_options += add_daemon_options
138
139     def finalize(self) -> Self:
140         if not self.origin_server:
141             self.origin_server = socket.getfqdn()
142         if not self.admin_email:
143             self.admin_email = f'hostmaster@{self.origin_server}'
144         if self.default_ttl == 0:
145             self.default_ttl = self.min_ttl
146         return self
147
148
149 NscZoneConfig.default_config = NscZoneConfig(
150     admin_email="",
151     refresh=timedelta(hours=8),
152     retry=timedelta(hours=2),
153     expire=timedelta(days=14),
154     min_ttl=timedelta(days=1),
155     default_ttl=0,
156     origin_server="",
157     daemon_options=[],
158     add_null_mx=False,
159 )
160
161
162 class NscZoneState:
163     serial: int
164     hash: str
165
166     def __init__(self) -> None:
167         self.serial = 0
168         self.hash = 'none'
169
170     def load(self, file: Path) -> None:
171         try:
172             with open(file) as f:
173                 js = json.load(f)
174                 assert isinstance(js, dict)
175                 if 'serial' in js:
176                     self.serial = js['serial']
177                 if 'hash' in js:
178                     self.hash = js['hash']
179         except FileNotFoundError:
180             pass
181
182     def save(self, file: Path) -> None:
183         new_file = Path(str(file) + '.new')
184         with open(new_file, 'w') as f:
185             js = {
186                 'serial': self.serial,
187                 'hash': self.hash,
188             }
189             json.dump(js, f, indent=4, sort_keys=True)
190         new_file.replace(file)
191
192
193 class ZoneType(Enum):
194     primary = auto()
195     secondary = auto()
196     alias = auto()
197
198
199 class NscZone:
200     nsc: 'Nsc'
201     name: str
202     dns_name: Name
203     safe_name: str                          # For use in file names
204     zone_type: ZoneType
205     reverse_for: Optional[IPNetwork]
206
207     def __init__(self,
208                  nsc: 'Nsc',
209                  name: str,
210                  reverse_for: Optional[IPNetwork],
211                  **kwargs) -> None:
212         self.nsc = nsc
213         self.name = name
214         self.dns_name = dns.name.from_text(name)
215         self.safe_name = name.replace('/', '@')
216         self.config = NscZoneConfig(**kwargs).finalize()
217         self.reverse_for = reverse_for
218
219     def process(self) -> None:
220         pass
221
222     def is_changed(self) -> bool:
223         return False
224
225
226 class NscZonePrimary(NscZone):
227     zone: Zone
228     zone_file: Path
229     state_file: Path
230     state: NscZoneState
231     prev_state: NscZoneState
232     aliases: List['NscZoneAlias']
233
234     def __init__(self, *args, **kwargs) -> None:
235         super().__init__(*args, **kwargs)
236
237         self.zone_type = ZoneType.primary
238         self.zone_file = self.nsc.zone_dir / self.safe_name
239         self.state_file = self.nsc.state_dir / (self.safe_name + '.json')
240
241         self.state = NscZoneState()
242         self.prev_state = NscZoneState()
243         self.prev_state.load(self.state_file)
244
245         self.aliases = []
246
247         self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN)
248         self.update_soa()
249
250     def update_soa(self) -> None:
251         conf = self.config
252         soa = dns.rdtypes.ANY.SOA.SOA(
253             RdataClass.IN, RdataType.SOA,
254             mname=conf.origin_server,
255             rname=conf.admin_email.replace('@', '.'),   # FIXME: names with dots
256             serial=self.state.serial,
257             refresh=conf.refresh,
258             retry=conf.retry,
259             expire=conf.expire,
260             minimum=conf.min_ttl,
261         )
262         self.zone.delete_rdataset("", RdataType.SOA)
263         self[""]._add(soa)
264
265     def n(self, name: str) -> NscNode:
266         return NscNode(self, name)
267
268     def __getitem__(self, name: str) -> NscNode:
269         return NscNode(self, name)
270
271     def host(self, name: str, *args, reverse: bool = True) -> NscNode:
272         n = NscNode(self, name)
273         n.A(*args, reverse=reverse)
274         return n
275
276     def zone_header(self) -> str:
277         return (
278             f'; Zone file for {self.name}\n'
279             + '; Generated by NSC, please do not edit manually.\n'
280             + '\n')
281
282     def dump(self, file: Optional[TextIO] = None) -> None:
283         # Could use self.zone.to_file(sys.stdout), but we want better formatting
284         file = file or sys.stdout
285         file.write(self.zone_header())
286         file.write(f'$TTL\t\t{self.config.default_ttl}\n\n')
287         last_name = None
288         for name, ttl, rec in self.zone.iterate_rdatas():
289             if name == last_name:
290                 print_name = ""
291             else:
292                 print_name = name
293             file.write(f'{print_name}\t{ttl if ttl != self.config.default_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
294             last_name = name
295
296     def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
297         assert isinstance(self.reverse_for, IPv4Network)
298         parts = str(addr).split('.')
299         parts = parts[self.reverse_for.prefixlen // 8:]
300         name = '.'.join(reversed(parts))
301         self.n(name).PTR(ptr_to)
302
303     def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
304         assert isinstance(self.reverse_for, IPv6Network)
305         parts = addr.exploded.replace(':', "")
306         parts = parts[self.reverse_for.prefixlen // 4:]
307         name = '.'.join(reversed(parts))
308         self.n(name).PTR(ptr_to)
309
310     def gen_hash(self) -> None:
311         sha = hashlib.sha1()
312         sha.update(self.zone_header().encode('us-ascii'))
313         for name, ttl, rec in self.zone.iterate_rdatas():
314             text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
315             sha.update(text.encode('us-ascii'))
316         self.state.hash = sha.hexdigest()[:16]
317
318     def gen_serial(self) -> None:
319         prev = self.prev_state.serial
320         if self.state.hash == self.prev_state.hash and prev > 0:
321             self.state.serial = self.prev_state.serial
322         else:
323             base = int(self.nsc.start_time.strftime('%Y%m%d00'))
324             if prev <= base:
325                 self.state.serial = base + 1
326             else:
327                 self.state.serial = prev + 1
328                 if prev >= base + 99:
329                     print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')
330
331     def process(self) -> None:
332         if self.config.add_null_mx:
333             self.gen_null_mx()
334         self.gen_hash()
335         self.gen_serial()
336
337     def write_zone(self) -> None:
338         self.update_soa()
339         new_file = Path(str(self.zone_file) + '.new')
340         with open(new_file, 'w') as f:
341             self.dump(file=f)
342         new_file.replace(self.zone_file)
343
344     def write_state(self) -> None:
345         self.state.save(self.state_file)
346
347     def is_changed(self) -> bool:
348         return self.state.serial != self.prev_state.serial
349
350     def delegate_classless(self, net: str | IPNetwork, subdomain: Optional[str] = None) -> NscNode:
351         net = parse_network(net)
352         assert self.reverse_for is not None
353         assert isinstance(self.reverse_for, IPv4Network)
354         assert self.reverse_for.prefixlen % 8 == 0
355         assert isinstance(net, IPv4Network)
356         assert net.subnet_of(self.reverse_for)
357         assert net.prefixlen < self.reverse_for.prefixlen + 8
358
359         start = int(net.network_address.packed[net.prefixlen // 8])
360         num = 1 << (8 - net.prefixlen % 8)
361
362         if subdomain is None:
363             subdomain = f'{start}/{net.prefixlen}'
364
365         for i in range(start, start + num):
366             target = f'{i}.{subdomain}'
367             self[str(i)].CNAME(parse_name(target, relative=True))
368
369         return self[subdomain]
370
371     def gen_null_mx(self) -> None:
372         for name, node in self.zone.items():
373             rds_a = node.get_rdataset(RdataClass.IN, RdataType.A)
374             rds_aaaa = node.get_rdataset(RdataClass.IN, RdataType.AAAA)
375             if rds_a or rds_aaaa:
376                 mx_rds = node.get_rdataset(RdataClass.IN, RdataType.MX, create=True)
377                 if not mx_rds:
378                     mx_rds.add(
379                         dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, 0, dns.name.root),
380                         ttl=self.config.default_ttl,
381                     )
382
383
384 class NscZoneSecondary(NscZone):
385     primary_server: IPAddress
386     secondary_file: Path
387
388     def __init__(self, *args, primary_server=IPAddress, **kwargs) -> None:
389         super().__init__(*args, **kwargs)
390         self.zone_type = ZoneType.secondary
391         self.primary_server = primary_server
392         self.secondary_file = self.nsc.secondary_dir / self.safe_name
393
394
395 class NscZoneAlias(NscZone):
396     alias_for: NscZonePrimary
397
398     def __init__(self, *args, alias_for=NscZonePrimary, **kwargs) -> None:
399         assert isinstance(alias_for, NscZonePrimary)
400         super().__init__(*args, **kwargs)
401         self.zone_type = ZoneType.alias
402         self.alias_for = alias_for
403         self.zone_file = alias_for.zone_file
404         alias_for.aliases.append(self)
405
406     def is_changed(self) -> bool:
407         return self.alias_for.is_changed()
408
409
410 class Nsc:
411     start_time: datetime
412     zones: Dict[str, NscZone]
413     default_zone_config: NscZoneConfig
414     ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
415     ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
416     root_dir: Path
417     state_dir: Path
418     zone_dir: Path
419     secondary_dir: Path
420     daemon: 'NscDaemon'  # Set by DaemonConfig class
421
422     def __init__(self,
423                  directory: str = '.',
424                  daemon: Optional['NscDaemon'] = None,
425                  **kwargs) -> None:
426         self.start_time = datetime.now()
427         self.zones = {}
428         self.default_zone_config = NscZoneConfig(**kwargs)
429         self.ipv4_reverse = defaultdict(list)
430         self.ipv6_reverse = defaultdict(list)
431
432         self.root_dir = Path(directory)
433         self.state_dir = self.root_dir / 'state'
434         self.state_dir.mkdir(parents=True, exist_ok=True)
435         self.zone_dir = self.root_dir / 'zone'
436         self.zone_dir.mkdir(parents=True, exist_ok=True)
437         self.secondary_dir = self.root_dir / 'secondary'
438         self.secondary_dir.mkdir(parents=True, exist_ok=True)
439
440         if daemon is None:
441             from nsconfig.daemon import NscDaemonNull
442             daemon = NscDaemonNull()
443         self.daemon = daemon
444         daemon.setup(self)
445
446     def add_zone(self,
447                  name: Optional[str] = None,
448                  reverse_for: str | IPNetwork | None = None,
449                  alias_for: Optional[NscZonePrimary] = None,
450                  follow_primary: str | IPAddress | None = None,
451                  inherit_config: Optional[NscZoneConfig] = None,
452                  **kwargs) -> Zone:
453         if inherit_config is None:
454             inherit_config = self.default_zone_config
455
456         if reverse_for is not None:
457             if isinstance(reverse_for, str):
458                 reverse_for = ip_network(reverse_for, strict=True)
459             name = name or self._reverse_zone_name(reverse_for)
460         assert name is not None
461         assert name not in self.zones
462
463         z: NscZone
464         if alias_for is not None:
465             assert follow_primary is None
466             z = NscZoneAlias(self, name, reverse_for=reverse_for, alias_for=alias_for, inherit_config=inherit_config, **kwargs)
467         elif follow_primary is None:
468             z = NscZonePrimary(self, name, reverse_for=reverse_for, inherit_config=inherit_config, **kwargs)
469         else:
470             if isinstance(follow_primary, str):
471                 follow_primary = ip_address(follow_primary)
472             z = NscZoneSecondary(self, name, reverse_for=reverse_for, primary_server=follow_primary, inherit_config=inherit_config, **kwargs)
473
474         self.zones[name] = z
475         return z
476
477     def __getitem__(self, name: str) -> NscZone:
478         return self.zones[name]
479
480     def _reverse_zone_name(self, net: IPNetwork) -> str:
481         if isinstance(net, IPv4Network):
482             parts = str(net.network_address).split('.')
483             out = parts[:net.prefixlen // 8]
484             if net.prefixlen % 8 != 0:
485                 out.append(parts[len(out)] + '/' + str(net.prefixlen))
486             return '.'.join(reversed(out)) + '.in-addr.arpa'
487         elif isinstance(net, IPv6Network):
488             assert net.prefixlen % 4 == 0
489             nibbles = net.network_address.exploded.replace(':', "")
490             nibbles = nibbles[:net.prefixlen // 4]
491             return '.'.join(reversed(nibbles)) + '.ip6.arpa'
492         else:
493             raise NotImplementedError()
494
495     def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
496         if isinstance(addr, IPv4Address):
497             self.ipv4_reverse[addr].append(ptr_to)
498         else:
499             self.ipv6_reverse[addr].append(ptr_to)
500
501     def dump_reverse(self) -> None:
502         print('### Requests for reverse mappings ###')
503         for ipa4, name in sorted(self.ipv4_reverse.items()):
504             print(f'{ipa4}\t{name}')
505         for ipa6, name in sorted(self.ipv6_reverse.items()):
506             print(f'{ipa6}\t{name}')
507
508     def fill_reverse(self) -> None:
509         for z in self.zones.values():
510             if isinstance(z, NscZonePrimary) and z.reverse_for is not None:
511                 if isinstance(z.reverse_for, IPv4Network):
512                     for addr4, ptr_list in self.ipv4_reverse.items():
513                         if addr4 in z.reverse_for:
514                             for ptr_to in ptr_list:
515                                 z._add_ipv4_reverse(addr4, ptr_to)
516                 else:
517                     for addr6, ptr_list in self.ipv6_reverse.items():
518                         if addr6 in z.reverse_for:
519                             for ptr_to in ptr_list:
520                                 z._add_ipv6_reverse(addr6, ptr_to)
521
522     def get_zones(self) -> List[NscZone]:
523         return [self.zones[k] for k in sorted(self.zones.keys())]
524
525     def process(self) -> None:
526         self.fill_reverse()
527         for z in self.get_zones():
528             z.process()