]> mj.ucw.cz Git - pynsc.git/blob - nsconfig/core.py
Implement add_null_mx
[pynsc.git] / nsconfig / core.py
1 from collections import defaultdict
2 from datetime import datetime, timedelta
3 import dns.name
4 from dns.name import Name
5 from dns.node import Node
6 from dns.rdata import Rdata
7 from dns.rdataclass import RdataClass
8 from dns.rdatatype import RdataType
9 import dns.rdtypes.ANY.CNAME
10 import dns.rdtypes.ANY.MX
11 import dns.rdtypes.ANY.NS
12 import dns.rdtypes.ANY.PTR
13 import dns.rdtypes.ANY.SOA
14 import dns.rdtypes.ANY.TXT
15 import dns.rdtypes.IN.A
16 import dns.rdtypes.IN.AAAA
17 from dns.zone import Zone
18 from enum import Enum, auto
19 import hashlib
20 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
21 import json
22 from pathlib import Path
23 import socket
24 import sys
25 from typing import Optional, Dict, List, Self, DefaultDict, TextIO, TYPE_CHECKING
26
27 from nsconfig.util import flatten_list, parse_address, parse_network, parse_name
28 from nsconfig.util import IPAddress, IPNetwork, IPAddr
29
30
31 if TYPE_CHECKING:
32     from nsconfig.daemon import NscDaemon
33
34
35 class NscNode:
36     nsc_zone: 'NscZonePrimary'
37     name: str
38     node: Node
39     _ttl: int
40
41     def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None:
42         self.nsc_zone = nsc_zone
43         self.name = name
44         self.node = nsc_zone.zone.find_node(name, create=True)
45         self._ttl = nsc_zone._min_ttl
46
47     def ttl(self, *args, **kwargs) -> Self:
48         if not args and not kwargs:
49             self._ttl = self.nsc_zone._min_ttl
50         else:
51             self._ttl = int(timedelta(*args, **kwargs).total_seconds())
52         return self
53
54     def _add(self, rec: Rdata) -> None:
55         rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
56         rds.add(rec, ttl=self._ttl)
57
58     def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
59         for a in map(parse_address, flatten_list(addrs)):
60             if isinstance(a, IPv4Address):
61                 self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
62             else:
63                 self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
64             if reverse:
65                 self.nsc_zone.nsc._add_reverse_mapping(a, parse_name(self.name + '.' + self.nsc_zone.name))
66         return self
67
68     def MX(self, pri: int, name: str) -> Self:
69         self._add(
70             dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, parse_name(name))
71         )
72         return self
73
74     def NS(self, *names: str | List[str]) -> Self:
75         for name in map(parse_name, flatten_list(names)):
76             self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
77         return self
78
79     def TXT(self, *text: str | List[str]) -> Self:
80         for txt in flatten_list(text):
81             self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, txt))
82         return self
83
84     def PTR(self, target: Name | str) -> Self:
85         self._add(dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target))
86         return self
87
88     def CNAME(self, target: Name | str) -> Self:
89         self._add(dns.rdtypes.ANY.CNAME.CNAME(RdataClass.IN, RdataType.CNAME, target))
90         return self
91
92     def generic(self, typ: str, text: str) -> Self:
93         self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
94         return self
95
96
97 class NscZoneConfig:
98     admin_email: str
99     refresh: timedelta
100     retry: timedelta
101     expire: timedelta
102     min_ttl: timedelta
103     origin_server: str
104     daemon_options: List[str]
105     add_null_mx: bool
106
107     default_config: Optional['NscZoneConfig'] = None
108
109     def __init__(self,
110                  admin_email: Optional[str] = None,
111                  refresh: Optional[timedelta] = None,
112                  retry: Optional[timedelta] = None,
113                  expire: Optional[timedelta] = None,
114                  min_ttl: Optional[timedelta] = None,
115                  origin_server: Optional[str] = None,
116                  daemon_options: Optional[List[str]] = None,
117                  add_daemon_options: Optional[List[str]] = None,
118                  add_null_mx: Optional[bool] = None,
119                  inherit_config: Optional['NscZoneConfig'] = None,
120                  ) -> None:
121         if inherit_config is None:
122             inherit_config = NscZoneConfig.default_config or self   # to satisfy the type checker
123         self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
124         self.refresh = refresh if refresh is not None else inherit_config.refresh
125         self.retry = retry if retry is not None else inherit_config.retry
126         self.expire = expire if expire is not None else inherit_config.expire
127         self.min_ttl = min_ttl if min_ttl is not None else inherit_config.min_ttl
128         self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
129         self.daemon_options = daemon_options if daemon_options is not None else inherit_config.daemon_options
130         self.add_null_mx = add_null_mx if add_null_mx is not None else inherit_config.add_null_mx
131         if add_daemon_options is not None:
132             self.daemon_options += add_daemon_options
133
134     def finalize(self) -> Self:
135         if not self.origin_server:
136             self.origin_server = socket.getfqdn()
137         if not self.admin_email:
138             self.admin_email = f'hostmaster@{self.origin_server}'
139         return self
140
141
142 NscZoneConfig.default_config = NscZoneConfig(
143     admin_email="",
144     refresh=timedelta(hours=8),
145     retry=timedelta(hours=2),
146     expire=timedelta(days=14),
147     min_ttl=timedelta(days=1),
148     origin_server="",
149     daemon_options=[],
150     add_null_mx=False,
151 )
152
153
154 class NscZoneState:
155     serial: int
156     hash: str
157
158     def __init__(self) -> None:
159         self.serial = 0
160         self.hash = 'none'
161
162     def load(self, file: Path) -> None:
163         try:
164             with open(file) as f:
165                 js = json.load(f)
166                 assert isinstance(js, dict)
167                 if 'serial' in js:
168                     self.serial = js['serial']
169                 if 'hash' in js:
170                     self.hash = js['hash']
171         except FileNotFoundError:
172             pass
173
174     def save(self, file: Path) -> None:
175         new_file = Path(str(file) + '.new')
176         with open(new_file, 'w') as f:
177             js = {
178                 'serial': self.serial,
179                 'hash': self.hash,
180             }
181             json.dump(js, f, indent=4, sort_keys=True)
182         new_file.replace(file)
183
184
185 class ZoneType(Enum):
186     primary = auto()
187     secondary = auto()
188
189
190 class NscZone:
191     nsc: 'Nsc'
192     name: str
193     safe_name: str                          # For use in file names
194     zone_type: ZoneType
195     reverse_for: Optional[IPNetwork]
196
197     def __init__(self,
198                  nsc: 'Nsc',
199                  name: str,
200                  reverse_for: Optional[IPNetwork],
201                  **kwargs) -> None:
202         self.nsc = nsc
203         self.name = name
204         self.safe_name = name.replace('/', '@')
205         self.config = NscZoneConfig(**kwargs).finalize()
206         self.reverse_for = reverse_for
207
208     def process(self) -> None:
209         pass
210
211
212 class NscZonePrimary(NscZone):
213     zone: Zone
214     _min_ttl: int
215     zone_file: Path
216     state_file: Path
217     state: NscZoneState
218     prev_state: NscZoneState
219
220     def __init__(self, *args, **kwargs) -> None:
221         super().__init__(*args, **kwargs)
222
223         self.zone_type = ZoneType.primary
224         self.zone_file = self.nsc.zone_dir / self.safe_name
225         self.state_file = self.nsc.state_dir / (self.safe_name + '.json')
226
227         self.state = NscZoneState()
228         self.prev_state = NscZoneState()
229         self.prev_state.load(self.state_file)
230
231         self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN)
232         self._min_ttl = int(self.config.min_ttl.total_seconds())
233         self.update_soa()
234
235     def update_soa(self) -> None:
236         conf = self.config
237         soa = dns.rdtypes.ANY.SOA.SOA(
238             RdataClass.IN, RdataType.SOA,
239             mname=conf.origin_server,
240             rname=conf.admin_email.replace('@', '.'),   # FIXME: names with dots
241             serial=self.state.serial,
242             refresh=int(conf.refresh.total_seconds()),
243             retry=int(conf.retry.total_seconds()),
244             expire=int(conf.expire.total_seconds()),
245             minimum=int(conf.min_ttl.total_seconds()),
246         )
247         self.zone.delete_rdataset("", RdataType.SOA)
248         self[""]._add(soa)
249
250     def n(self, name: str) -> NscNode:
251         return NscNode(self, name)
252
253     def __getitem__(self, name: str) -> NscNode:
254         return NscNode(self, name)
255
256     def host(self, name: str, *args, reverse: bool = True) -> NscNode:
257         n = NscNode(self, name)
258         n.A(*args, reverse=reverse)
259         return n
260
261     def zone_header(self) -> str:
262         return (
263             f'; Zone file for {self.name}\n'
264             + '; Generated by NSC, please do not edit manually.\n'
265             + '\n')
266
267     def dump(self, file: Optional[TextIO] = None) -> None:
268         # Could use self.zone.to_file(sys.stdout), but we want better formatting
269         file = file or sys.stdout
270         file.write(self.zone_header())
271         last_name = None
272         for name, ttl, rec in self.zone.iterate_rdatas():
273             if name == last_name:
274                 print_name = ""
275             else:
276                 print_name = name
277             file.write(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
278             last_name = name
279
280     def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
281         assert isinstance(self.reverse_for, IPv4Network)
282         parts = str(addr).split('.')
283         parts = parts[self.reverse_for.prefixlen // 8:]
284         name = '.'.join(reversed(parts))
285         self.n(name).PTR(ptr_to)
286
287     def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
288         assert isinstance(self.reverse_for, IPv6Network)
289         parts = addr.exploded.replace(':', "")
290         parts = parts[self.reverse_for.prefixlen // 4:]
291         name = '.'.join(reversed(parts))
292         self.n(name).PTR(ptr_to)
293
294     def gen_hash(self) -> None:
295         sha = hashlib.sha1()
296         sha.update(self.zone_header().encode('us-ascii'))
297         for name, ttl, rec in self.zone.iterate_rdatas():
298             text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
299             sha.update(text.encode('us-ascii'))
300         self.state.hash = sha.hexdigest()[:16]
301
302     def gen_serial(self) -> None:
303         prev = self.prev_state.serial
304         if self.state.hash == self.prev_state.hash and prev > 0:
305             self.state.serial = self.prev_state.serial
306         else:
307             base = int(self.nsc.start_time.strftime('%Y%m%d00'))
308             if prev <= base:
309                 self.state.serial = base + 1
310             else:
311                 self.state.serial = prev + 1
312                 if prev >= base + 99:
313                     print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')
314
315     def process(self) -> None:
316         if self.config.add_null_mx:
317             self.gen_null_mx()
318         self.gen_hash()
319         self.gen_serial()
320
321     def write_zone(self) -> None:
322         self.update_soa()
323         new_file = Path(str(self.zone_file) + '.new')
324         with open(new_file, 'w') as f:
325             self.dump(file=f)
326         new_file.replace(self.zone_file)
327
328     def write_state(self) -> None:
329         self.state.save(self.state_file)
330
331     def is_changed(self) -> bool:
332         return self.state.serial != self.prev_state.serial
333
334     def delegate_classless(self, net: str | IPNetwork, subdomain: Optional[str] = None) -> NscNode:
335         net = parse_network(net)
336         assert self.reverse_for is not None
337         assert isinstance(self.reverse_for, IPv4Network)
338         assert self.reverse_for.prefixlen % 8 == 0
339         assert isinstance(net, IPv4Network)
340         assert net.subnet_of(self.reverse_for)
341         assert net.prefixlen < self.reverse_for.prefixlen + 8
342
343         start = int(net.network_address.packed[net.prefixlen // 8])
344         num = 1 << (8 - net.prefixlen % 8)
345
346         if subdomain is None:
347             subdomain = f'{start}/{net.prefixlen}'
348
349         for i in range(start, start + num):
350             target = f'{i}.{subdomain}'
351             self[str(i)].CNAME(parse_name(target, relative=True))
352
353         return self[subdomain]
354
355     def gen_null_mx(self) -> None:
356         for name, node in self.zone.items():
357             rds_a = node.get_rdataset(RdataClass.IN, RdataType.A)
358             rds_aaaa = node.get_rdataset(RdataClass.IN, RdataType.AAAA)
359             if rds_a or rds_aaaa:
360                 mx_rds = node.get_rdataset(RdataClass.IN, RdataType.MX, create=True)
361                 if not mx_rds:
362                     mx_rds.add(
363                         dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, 0, dns.name.root),
364                         ttl=self._min_ttl,
365                     )
366
367
368 class NscZoneSecondary(NscZone):
369     primary_server: IPAddress
370     secondary_file: Path
371
372     def __init__(self, *args, primary_server=IPAddress, **kwargs) -> None:
373         super().__init__(*args, **kwargs)
374         self.zone_type = ZoneType.secondary
375         self.primary_server = primary_server
376         self.secondary_file = self.nsc.secondary_dir / self.safe_name
377
378
379 class Nsc:
380     start_time: datetime
381     zones: Dict[str, NscZone]
382     default_zone_config: NscZoneConfig
383     ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
384     ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
385     root_dir: Path
386     state_dir: Path
387     zone_dir: Path
388     secondary_dir: Path
389     daemon: 'NscDaemon'  # Set by DaemonConfig class
390
391     def __init__(self,
392                  directory: str = '.',
393                  daemon: Optional['NscDaemon'] = None,
394                  **kwargs) -> None:
395         self.start_time = datetime.now()
396         self.zones = {}
397         self.default_zone_config = NscZoneConfig(**kwargs)
398         self.ipv4_reverse = defaultdict(list)
399         self.ipv6_reverse = defaultdict(list)
400
401         self.root_dir = Path(directory)
402         self.state_dir = self.root_dir / 'state'
403         self.state_dir.mkdir(parents=True, exist_ok=True)
404         self.zone_dir = self.root_dir / 'zone'
405         self.zone_dir.mkdir(parents=True, exist_ok=True)
406         self.secondary_dir = self.root_dir / 'secondary'
407         self.secondary_dir.mkdir(parents=True, exist_ok=True)
408
409         if daemon is None:
410             from nsconfig.daemon import NscDaemonNull
411             daemon = NscDaemonNull()
412         self.daemon = daemon
413         daemon.setup(self)
414
415     def add_zone(self,
416                  name: Optional[str] = None,
417                  reverse_for: str | IPNetwork | None = None,
418                  follow_primary: str | IPAddress | None = None,
419                  inherit_config: Optional[NscZoneConfig] = None,
420                  **kwargs) -> Zone:
421         if inherit_config is None:
422             inherit_config = self.default_zone_config
423
424         if reverse_for is not None:
425             if isinstance(reverse_for, str):
426                 reverse_for = ip_network(reverse_for, strict=True)
427             name = name or self._reverse_zone_name(reverse_for)
428         assert name is not None
429         assert name not in self.zones
430
431         z: NscZone
432         if follow_primary is None:
433             z = NscZonePrimary(self, name, reverse_for=reverse_for, inherit_config=inherit_config, **kwargs)
434         else:
435             if isinstance(follow_primary, str):
436                 follow_primary = ip_address(follow_primary)
437             z = NscZoneSecondary(self, name, reverse_for=reverse_for, primary_server=follow_primary, inherit_config=inherit_config, **kwargs)
438
439         self.zones[name] = z
440         return z
441
442     def _reverse_zone_name(self, net: IPNetwork) -> str:
443         if isinstance(net, IPv4Network):
444             parts = str(net.network_address).split('.')
445             out = parts[:net.prefixlen // 8]
446             if net.prefixlen % 8 != 0:
447                 out.append(parts[len(out)] + '/' + str(net.prefixlen))
448             return '.'.join(reversed(out)) + '.in-addr.arpa'
449         elif isinstance(net, IPv6Network):
450             assert net.prefixlen % 4 == 0
451             nibbles = net.network_address.exploded.replace(':', "")
452             nibbles = nibbles[:net.prefixlen // 4]
453             return '.'.join(reversed(nibbles)) + '.ip6.arpa'
454         else:
455             raise NotImplementedError()
456
457     def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
458         if isinstance(addr, IPv4Address):
459             self.ipv4_reverse[addr].append(ptr_to)
460         else:
461             self.ipv6_reverse[addr].append(ptr_to)
462
463     def dump_reverse(self) -> None:
464         print('### Requests for reverse mappings ###')
465         for ipa4, name in sorted(self.ipv4_reverse.items()):
466             print(f'{ipa4}\t{name}')
467         for ipa6, name in sorted(self.ipv6_reverse.items()):
468             print(f'{ipa6}\t{name}')
469
470     def fill_reverse(self) -> None:
471         for z in self.zones.values():
472             if isinstance(z, NscZonePrimary) and z.reverse_for is not None:
473                 if isinstance(z.reverse_for, IPv4Network):
474                     for addr4, ptr_list in self.ipv4_reverse.items():
475                         if addr4 in z.reverse_for:
476                             for ptr_to in ptr_list:
477                                 z._add_ipv4_reverse(addr4, ptr_to)
478                 else:
479                     for addr6, ptr_list in self.ipv6_reverse.items():
480                         if addr6 in z.reverse_for:
481                             for ptr_to in ptr_list:
482                                 z._add_ipv6_reverse(addr6, ptr_to)
483
484     def get_zones(self) -> List[NscZone]:
485         return [self.zones[k] for k in sorted(self.zones.keys())]
486
487     def process(self) -> None:
488         self.fill_reverse()
489         for z in self.get_zones():
490             z.process()