]> mj.ucw.cz Git - pynsc.git/blob - nsconfig/core.py
Implement alias zones
[pynsc.git] / nsconfig / core.py
1 from collections import defaultdict
2 from datetime import datetime, timedelta
3 import dns.name
4 from dns.name import Name
5 from dns.node import Node
6 from dns.rdata import Rdata
7 from dns.rdataclass import RdataClass
8 from dns.rdatatype import RdataType
9 import dns.rdtypes.ANY.CNAME
10 import dns.rdtypes.ANY.MX
11 import dns.rdtypes.ANY.NS
12 import dns.rdtypes.ANY.PTR
13 import dns.rdtypes.ANY.SOA
14 import dns.rdtypes.ANY.TXT
15 import dns.rdtypes.IN.A
16 import dns.rdtypes.IN.AAAA
17 from dns.zone import Zone
18 from enum import Enum, auto
19 import hashlib
20 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
21 import json
22 from pathlib import Path
23 import socket
24 import sys
25 from typing import Optional, Dict, List, Self, DefaultDict, TextIO, TYPE_CHECKING
26
27 from nsconfig.util import flatten_list, parse_address, parse_network, parse_name, parse_duration
28 from nsconfig.util import IPAddress, IPNetwork, IPAddr
29
30
31 if TYPE_CHECKING:
32     from nsconfig.daemon import NscDaemon
33
34
35 class NscNode:
36     nsc_zone: 'NscZonePrimary'
37     name: str
38     node: Node
39     _ttl: int
40
41     def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None:
42         self.nsc_zone = nsc_zone
43         self.name = name
44         self.node = nsc_zone.zone.find_node(name, create=True)
45         self._ttl = nsc_zone.config.default_ttl
46
47     def ttl(self, seconds: Optional[int] = None, **kwargs) -> Self:
48         if seconds is not None:
49             self._ttl = seconds
50         elif kwargs:
51             self._ttl = parse_duration(timedelta(**kwargs))
52         else:
53             self._ttl = self.nsc_zone.config.default_ttl
54         return self
55
56     def _add(self, rec: Rdata) -> None:
57         rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
58         rds.add(rec, ttl=self._ttl)
59
60     def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
61         for a in map(parse_address, flatten_list(addrs)):
62             if isinstance(a, IPv4Address):
63                 self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
64             else:
65                 self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
66             if reverse:
67                 self.nsc_zone.nsc._add_reverse_mapping(a, parse_name(self.name + '.' + self.nsc_zone.name))
68         return self
69
70     def MX(self, pri: int, name: str) -> Self:
71         self._add(
72             dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, parse_name(name))
73         )
74         return self
75
76     def NS(self, *names: str | List[str]) -> Self:
77         for name in map(parse_name, flatten_list(names)):
78             self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
79         return self
80
81     def TXT(self, *text: str | List[str]) -> Self:
82         for txt in flatten_list(text):
83             self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, txt))
84         return self
85
86     def PTR(self, target: Name | str) -> Self:
87         self._add(dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target))
88         return self
89
90     def CNAME(self, target: Name | str) -> Self:
91         self._add(dns.rdtypes.ANY.CNAME.CNAME(RdataClass.IN, RdataType.CNAME, target))
92         return self
93
94     def generic(self, typ: str, text: str) -> Self:
95         self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
96         return self
97
98
99 class NscZoneConfig:
100     admin_email: str
101     refresh: int
102     retry: int
103     expire: int
104     min_ttl: int
105     default_ttl: int
106     origin_server: str
107     daemon_options: List[str]
108     add_null_mx: bool
109
110     default_config: Optional['NscZoneConfig'] = None
111
112     def __init__(self,
113                  admin_email: Optional[str] = None,
114                  refresh: Optional[int | timedelta] = None,
115                  retry: Optional[int | timedelta] = None,
116                  expire: Optional[int | timedelta] = None,
117                  min_ttl: Optional[int | timedelta] = None,
118                  default_ttl: Optional[int | timedelta] = None,
119                  origin_server: Optional[str] = None,
120                  daemon_options: Optional[List[str]] = None,
121                  add_daemon_options: Optional[List[str]] = None,
122                  add_null_mx: Optional[bool] = None,
123                  inherit_config: Optional['NscZoneConfig'] = None,
124                  ) -> None:
125         if inherit_config is None:
126             inherit_config = NscZoneConfig.default_config or self   # to satisfy the type checker
127         self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
128         self.refresh = parse_duration(refresh) if refresh is not None else inherit_config.refresh
129         self.retry = parse_duration(retry) if retry is not None else inherit_config.retry
130         self.expire = parse_duration(expire) if expire is not None else inherit_config.expire
131         self.min_ttl = parse_duration(min_ttl) if min_ttl is not None else inherit_config.min_ttl
132         self.default_ttl = parse_duration(default_ttl) if default_ttl is not None else inherit_config.default_ttl
133         self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
134         self.daemon_options = daemon_options if daemon_options is not None else inherit_config.daemon_options
135         self.add_null_mx = add_null_mx if add_null_mx is not None else inherit_config.add_null_mx
136         if add_daemon_options is not None:
137             self.daemon_options += add_daemon_options
138
139     def finalize(self) -> Self:
140         if not self.origin_server:
141             self.origin_server = socket.getfqdn()
142         if not self.admin_email:
143             self.admin_email = f'hostmaster@{self.origin_server}'
144         if self.default_ttl == 0:
145             self.default_ttl = self.min_ttl
146         return self
147
148
149 NscZoneConfig.default_config = NscZoneConfig(
150     admin_email="",
151     refresh=timedelta(hours=8),
152     retry=timedelta(hours=2),
153     expire=timedelta(days=14),
154     min_ttl=timedelta(days=1),
155     default_ttl=0,
156     origin_server="",
157     daemon_options=[],
158     add_null_mx=False,
159 )
160
161
162 class NscZoneState:
163     serial: int
164     hash: str
165
166     def __init__(self) -> None:
167         self.serial = 0
168         self.hash = 'none'
169
170     def load(self, file: Path) -> None:
171         try:
172             with open(file) as f:
173                 js = json.load(f)
174                 assert isinstance(js, dict)
175                 if 'serial' in js:
176                     self.serial = js['serial']
177                 if 'hash' in js:
178                     self.hash = js['hash']
179         except FileNotFoundError:
180             pass
181
182     def save(self, file: Path) -> None:
183         new_file = Path(str(file) + '.new')
184         with open(new_file, 'w') as f:
185             js = {
186                 'serial': self.serial,
187                 'hash': self.hash,
188             }
189             json.dump(js, f, indent=4, sort_keys=True)
190         new_file.replace(file)
191
192
193 class ZoneType(Enum):
194     primary = auto()
195     secondary = auto()
196     alias = auto()
197
198
199 class NscZone:
200     nsc: 'Nsc'
201     name: str
202     safe_name: str                          # For use in file names
203     zone_type: ZoneType
204     reverse_for: Optional[IPNetwork]
205
206     def __init__(self,
207                  nsc: 'Nsc',
208                  name: str,
209                  reverse_for: Optional[IPNetwork],
210                  **kwargs) -> None:
211         self.nsc = nsc
212         self.name = name
213         self.safe_name = name.replace('/', '@')
214         self.config = NscZoneConfig(**kwargs).finalize()
215         self.reverse_for = reverse_for
216
217     def process(self) -> None:
218         pass
219
220     def is_changed(self) -> bool:
221         return False
222
223
224 class NscZonePrimary(NscZone):
225     zone: Zone
226     zone_file: Path
227     state_file: Path
228     state: NscZoneState
229     prev_state: NscZoneState
230     aliases: List['NscZoneAlias']
231
232     def __init__(self, *args, **kwargs) -> None:
233         super().__init__(*args, **kwargs)
234
235         self.zone_type = ZoneType.primary
236         self.zone_file = self.nsc.zone_dir / self.safe_name
237         self.state_file = self.nsc.state_dir / (self.safe_name + '.json')
238
239         self.state = NscZoneState()
240         self.prev_state = NscZoneState()
241         self.prev_state.load(self.state_file)
242
243         self.aliases = []
244
245         self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN)
246         self.update_soa()
247
248     def update_soa(self) -> None:
249         conf = self.config
250         soa = dns.rdtypes.ANY.SOA.SOA(
251             RdataClass.IN, RdataType.SOA,
252             mname=conf.origin_server,
253             rname=conf.admin_email.replace('@', '.'),   # FIXME: names with dots
254             serial=self.state.serial,
255             refresh=conf.refresh,
256             retry=conf.retry,
257             expire=conf.expire,
258             minimum=conf.min_ttl,
259         )
260         self.zone.delete_rdataset("", RdataType.SOA)
261         self[""]._add(soa)
262
263     def n(self, name: str) -> NscNode:
264         return NscNode(self, name)
265
266     def __getitem__(self, name: str) -> NscNode:
267         return NscNode(self, name)
268
269     def host(self, name: str, *args, reverse: bool = True) -> NscNode:
270         n = NscNode(self, name)
271         n.A(*args, reverse=reverse)
272         return n
273
274     def zone_header(self) -> str:
275         return (
276             f'; Zone file for {self.name}\n'
277             + '; Generated by NSC, please do not edit manually.\n'
278             + '\n')
279
280     def dump(self, file: Optional[TextIO] = None) -> None:
281         # Could use self.zone.to_file(sys.stdout), but we want better formatting
282         file = file or sys.stdout
283         file.write(self.zone_header())
284         file.write(f'$TTL\t\t{self.config.default_ttl}\n\n')
285         last_name = None
286         for name, ttl, rec in self.zone.iterate_rdatas():
287             if name == last_name:
288                 print_name = ""
289             else:
290                 print_name = name
291             file.write(f'{print_name}\t{ttl if ttl != self.config.default_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
292             last_name = name
293
294     def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
295         assert isinstance(self.reverse_for, IPv4Network)
296         parts = str(addr).split('.')
297         parts = parts[self.reverse_for.prefixlen // 8:]
298         name = '.'.join(reversed(parts))
299         self.n(name).PTR(ptr_to)
300
301     def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
302         assert isinstance(self.reverse_for, IPv6Network)
303         parts = addr.exploded.replace(':', "")
304         parts = parts[self.reverse_for.prefixlen // 4:]
305         name = '.'.join(reversed(parts))
306         self.n(name).PTR(ptr_to)
307
308     def gen_hash(self) -> None:
309         sha = hashlib.sha1()
310         sha.update(self.zone_header().encode('us-ascii'))
311         for name, ttl, rec in self.zone.iterate_rdatas():
312             text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
313             sha.update(text.encode('us-ascii'))
314         self.state.hash = sha.hexdigest()[:16]
315
316     def gen_serial(self) -> None:
317         prev = self.prev_state.serial
318         if self.state.hash == self.prev_state.hash and prev > 0:
319             self.state.serial = self.prev_state.serial
320         else:
321             base = int(self.nsc.start_time.strftime('%Y%m%d00'))
322             if prev <= base:
323                 self.state.serial = base + 1
324             else:
325                 self.state.serial = prev + 1
326                 if prev >= base + 99:
327                     print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')
328
329     def process(self) -> None:
330         if self.config.add_null_mx:
331             self.gen_null_mx()
332         self.gen_hash()
333         self.gen_serial()
334
335     def write_zone(self) -> None:
336         self.update_soa()
337         new_file = Path(str(self.zone_file) + '.new')
338         with open(new_file, 'w') as f:
339             self.dump(file=f)
340         new_file.replace(self.zone_file)
341
342     def write_state(self) -> None:
343         self.state.save(self.state_file)
344
345     def is_changed(self) -> bool:
346         return self.state.serial != self.prev_state.serial
347
348     def delegate_classless(self, net: str | IPNetwork, subdomain: Optional[str] = None) -> NscNode:
349         net = parse_network(net)
350         assert self.reverse_for is not None
351         assert isinstance(self.reverse_for, IPv4Network)
352         assert self.reverse_for.prefixlen % 8 == 0
353         assert isinstance(net, IPv4Network)
354         assert net.subnet_of(self.reverse_for)
355         assert net.prefixlen < self.reverse_for.prefixlen + 8
356
357         start = int(net.network_address.packed[net.prefixlen // 8])
358         num = 1 << (8 - net.prefixlen % 8)
359
360         if subdomain is None:
361             subdomain = f'{start}/{net.prefixlen}'
362
363         for i in range(start, start + num):
364             target = f'{i}.{subdomain}'
365             self[str(i)].CNAME(parse_name(target, relative=True))
366
367         return self[subdomain]
368
369     def gen_null_mx(self) -> None:
370         for name, node in self.zone.items():
371             rds_a = node.get_rdataset(RdataClass.IN, RdataType.A)
372             rds_aaaa = node.get_rdataset(RdataClass.IN, RdataType.AAAA)
373             if rds_a or rds_aaaa:
374                 mx_rds = node.get_rdataset(RdataClass.IN, RdataType.MX, create=True)
375                 if not mx_rds:
376                     mx_rds.add(
377                         dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, 0, dns.name.root),
378                         ttl=self.config.default_ttl,
379                     )
380
381
382 class NscZoneSecondary(NscZone):
383     primary_server: IPAddress
384     secondary_file: Path
385
386     def __init__(self, *args, primary_server=IPAddress, **kwargs) -> None:
387         super().__init__(*args, **kwargs)
388         self.zone_type = ZoneType.secondary
389         self.primary_server = primary_server
390         self.secondary_file = self.nsc.secondary_dir / self.safe_name
391
392
393 class NscZoneAlias(NscZone):
394     alias_for: NscZonePrimary
395
396     def __init__(self, *args, alias_for=NscZonePrimary, **kwargs) -> None:
397         assert isinstance(alias_for, NscZonePrimary)
398         super().__init__(*args, **kwargs)
399         self.zone_type = ZoneType.alias
400         self.alias_for = alias_for
401         self.zone_file = alias_for.zone_file
402         alias_for.aliases.append(self)
403
404     def is_changed(self) -> bool:
405         return self.alias_for.is_changed()
406
407
408 class Nsc:
409     start_time: datetime
410     zones: Dict[str, NscZone]
411     default_zone_config: NscZoneConfig
412     ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
413     ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
414     root_dir: Path
415     state_dir: Path
416     zone_dir: Path
417     secondary_dir: Path
418     daemon: 'NscDaemon'  # Set by DaemonConfig class
419
420     def __init__(self,
421                  directory: str = '.',
422                  daemon: Optional['NscDaemon'] = None,
423                  **kwargs) -> None:
424         self.start_time = datetime.now()
425         self.zones = {}
426         self.default_zone_config = NscZoneConfig(**kwargs)
427         self.ipv4_reverse = defaultdict(list)
428         self.ipv6_reverse = defaultdict(list)
429
430         self.root_dir = Path(directory)
431         self.state_dir = self.root_dir / 'state'
432         self.state_dir.mkdir(parents=True, exist_ok=True)
433         self.zone_dir = self.root_dir / 'zone'
434         self.zone_dir.mkdir(parents=True, exist_ok=True)
435         self.secondary_dir = self.root_dir / 'secondary'
436         self.secondary_dir.mkdir(parents=True, exist_ok=True)
437
438         if daemon is None:
439             from nsconfig.daemon import NscDaemonNull
440             daemon = NscDaemonNull()
441         self.daemon = daemon
442         daemon.setup(self)
443
444     def add_zone(self,
445                  name: Optional[str] = None,
446                  reverse_for: str | IPNetwork | None = None,
447                  alias_for: Optional[NscZonePrimary] = None,
448                  follow_primary: str | IPAddress | None = None,
449                  inherit_config: Optional[NscZoneConfig] = None,
450                  **kwargs) -> Zone:
451         if inherit_config is None:
452             inherit_config = self.default_zone_config
453
454         if reverse_for is not None:
455             if isinstance(reverse_for, str):
456                 reverse_for = ip_network(reverse_for, strict=True)
457             name = name or self._reverse_zone_name(reverse_for)
458         assert name is not None
459         assert name not in self.zones
460
461         z: NscZone
462         if alias_for is not None:
463             assert follow_primary is None
464             z = NscZoneAlias(self, name, reverse_for=reverse_for, alias_for=alias_for, inherit_config=inherit_config, **kwargs)
465         elif follow_primary is None:
466             z = NscZonePrimary(self, name, reverse_for=reverse_for, inherit_config=inherit_config, **kwargs)
467         else:
468             if isinstance(follow_primary, str):
469                 follow_primary = ip_address(follow_primary)
470             z = NscZoneSecondary(self, name, reverse_for=reverse_for, primary_server=follow_primary, inherit_config=inherit_config, **kwargs)
471
472         self.zones[name] = z
473         return z
474
475     def _reverse_zone_name(self, net: IPNetwork) -> str:
476         if isinstance(net, IPv4Network):
477             parts = str(net.network_address).split('.')
478             out = parts[:net.prefixlen // 8]
479             if net.prefixlen % 8 != 0:
480                 out.append(parts[len(out)] + '/' + str(net.prefixlen))
481             return '.'.join(reversed(out)) + '.in-addr.arpa'
482         elif isinstance(net, IPv6Network):
483             assert net.prefixlen % 4 == 0
484             nibbles = net.network_address.exploded.replace(':', "")
485             nibbles = nibbles[:net.prefixlen // 4]
486             return '.'.join(reversed(nibbles)) + '.ip6.arpa'
487         else:
488             raise NotImplementedError()
489
490     def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
491         if isinstance(addr, IPv4Address):
492             self.ipv4_reverse[addr].append(ptr_to)
493         else:
494             self.ipv6_reverse[addr].append(ptr_to)
495
496     def dump_reverse(self) -> None:
497         print('### Requests for reverse mappings ###')
498         for ipa4, name in sorted(self.ipv4_reverse.items()):
499             print(f'{ipa4}\t{name}')
500         for ipa6, name in sorted(self.ipv6_reverse.items()):
501             print(f'{ipa6}\t{name}')
502
503     def fill_reverse(self) -> None:
504         for z in self.zones.values():
505             if isinstance(z, NscZonePrimary) and z.reverse_for is not None:
506                 if isinstance(z.reverse_for, IPv4Network):
507                     for addr4, ptr_list in self.ipv4_reverse.items():
508                         if addr4 in z.reverse_for:
509                             for ptr_to in ptr_list:
510                                 z._add_ipv4_reverse(addr4, ptr_to)
511                 else:
512                     for addr6, ptr_list in self.ipv6_reverse.items():
513                         if addr6 in z.reverse_for:
514                             for ptr_to in ptr_list:
515                                 z._add_ipv6_reverse(addr6, ptr_to)
516
517     def get_zones(self) -> List[NscZone]:
518         return [self.zones[k] for k in sorted(self.zones.keys())]
519
520     def process(self) -> None:
521         self.fill_reverse()
522         for z in self.get_zones():
523             z.process()