]> mj.ucw.cz Git - pynsc.git/blob - nsconfig/core.py
Rename test.py -> example.py
[pynsc.git] / nsconfig / core.py
1 # PyNSC: Main data structures
2 # (c) 2024 Martin Mareš <mj@ucw.cz>
3
4 from collections import defaultdict
5 from datetime import datetime, timedelta
6 import dns.name
7 from dns.name import Name
8 from dns.node import Node
9 from dns.rdata import Rdata
10 from dns.rdataclass import RdataClass
11 from dns.rdatatype import RdataType
12 import dns.rdtypes.ANY.CNAME
13 import dns.rdtypes.ANY.MX
14 import dns.rdtypes.ANY.NS
15 import dns.rdtypes.ANY.PTR
16 import dns.rdtypes.ANY.SOA
17 import dns.rdtypes.ANY.TXT
18 import dns.rdtypes.IN.A
19 import dns.rdtypes.IN.AAAA
20 from dns.zone import Zone
21 from enum import Enum, auto
22 import hashlib
23 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
24 import json
25 from pathlib import Path
26 import socket
27 import sys
28 from typing import Optional, Dict, List, Self, DefaultDict, TextIO, TYPE_CHECKING
29
30 from nsconfig.util import flatten_list, parse_address, parse_network, parse_name, parse_duration
31 from nsconfig.util import IPAddress, IPNetwork, IPAddr
32
33
34 if TYPE_CHECKING:
35     from nsconfig.daemon import NscDaemon
36
37
38 class NscNode:
39     nsc_zone: 'NscZonePrimary'
40     name: str
41     node: Node
42     _ttl: int
43
44     def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None:
45         self.nsc_zone = nsc_zone
46         self.name = name
47         self.node = nsc_zone.zone.find_node(name, create=True)
48         self._ttl = nsc_zone.config.default_ttl
49
50     def ttl(self, seconds: Optional[int] = None, **kwargs) -> Self:
51         if seconds is not None:
52             self._ttl = seconds
53         elif kwargs:
54             self._ttl = parse_duration(timedelta(**kwargs))
55         else:
56             self._ttl = self.nsc_zone.config.default_ttl
57         return self
58
59     def _add(self, rec: Rdata) -> None:
60         rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
61         rds.add(rec, ttl=self._ttl)
62
63     def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
64         for a in map(parse_address, flatten_list(addrs)):
65             if isinstance(a, IPv4Address):
66                 self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
67             else:
68                 self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
69             if reverse:
70                 self.nsc_zone.nsc._add_reverse_mapping(a, parse_name(self.name, origin=self.nsc_zone.dns_name))
71         return self
72
73     def MX(self, pri: int, name: str) -> Self:
74         self._add(
75             dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, parse_name(name))
76         )
77         return self
78
79     def NS(self, *names: str | List[str]) -> Self:
80         for name in map(parse_name, flatten_list(names)):
81             self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
82         return self
83
84     def TXT(self, *text: str | List[str]) -> Self:
85         for txt in flatten_list(text):
86             self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, txt))
87         return self
88
89     def PTR(self, target: Name | str) -> Self:
90         self._add(dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target))
91         return self
92
93     def CNAME(self, target: Name | str) -> Self:
94         self._add(dns.rdtypes.ANY.CNAME.CNAME(RdataClass.IN, RdataType.CNAME, target))
95         return self
96
97     def generic(self, typ: str, text: str) -> Self:
98         self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
99         return self
100
101
102 class NscZoneConfig:
103     admin_email: str
104     refresh: int
105     retry: int
106     expire: int
107     min_ttl: int
108     default_ttl: int
109     origin_server: str
110     daemon_options: List[str]
111     add_null_mx: bool
112
113     default_config: Optional['NscZoneConfig'] = None
114
115     def __init__(self,
116                  admin_email: Optional[str] = None,
117                  refresh: Optional[int | timedelta] = None,
118                  retry: Optional[int | timedelta] = None,
119                  expire: Optional[int | timedelta] = None,
120                  min_ttl: Optional[int | timedelta] = None,
121                  default_ttl: Optional[int | timedelta] = None,
122                  origin_server: Optional[str] = None,
123                  daemon_options: Optional[List[str]] = None,
124                  add_daemon_options: Optional[List[str]] = None,
125                  add_null_mx: Optional[bool] = None,
126                  inherit_config: Optional['NscZoneConfig'] = None,
127                  ) -> None:
128         if inherit_config is None:
129             inherit_config = NscZoneConfig.default_config or self   # to satisfy the type checker
130         self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
131         self.refresh = parse_duration(refresh) if refresh is not None else inherit_config.refresh
132         self.retry = parse_duration(retry) if retry is not None else inherit_config.retry
133         self.expire = parse_duration(expire) if expire is not None else inherit_config.expire
134         self.min_ttl = parse_duration(min_ttl) if min_ttl is not None else inherit_config.min_ttl
135         self.default_ttl = parse_duration(default_ttl) if default_ttl is not None else inherit_config.default_ttl
136         self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
137         self.daemon_options = daemon_options if daemon_options is not None else inherit_config.daemon_options
138         self.add_null_mx = add_null_mx if add_null_mx is not None else inherit_config.add_null_mx
139         if add_daemon_options is not None:
140             self.daemon_options += add_daemon_options
141
142     def finalize(self) -> Self:
143         if not self.origin_server:
144             self.origin_server = socket.getfqdn()
145         if not self.admin_email:
146             self.admin_email = f'hostmaster@{self.origin_server}'
147         if self.default_ttl == 0:
148             self.default_ttl = self.min_ttl
149         return self
150
151
152 NscZoneConfig.default_config = NscZoneConfig(
153     admin_email="",
154     refresh=timedelta(hours=8),
155     retry=timedelta(hours=2),
156     expire=timedelta(days=14),
157     min_ttl=timedelta(days=1),
158     default_ttl=0,
159     origin_server="",
160     daemon_options=[],
161     add_null_mx=False,
162 )
163
164
165 class NscZoneState:
166     serial: int
167     hash: str
168
169     def __init__(self) -> None:
170         self.serial = 0
171         self.hash = 'none'
172
173     def load(self, file: Path) -> None:
174         try:
175             with open(file) as f:
176                 js = json.load(f)
177                 assert isinstance(js, dict)
178                 if 'serial' in js:
179                     self.serial = js['serial']
180                 if 'hash' in js:
181                     self.hash = js['hash']
182         except FileNotFoundError:
183             pass
184
185     def save(self, file: Path) -> None:
186         new_file = Path(str(file) + '.new')
187         with open(new_file, 'w') as f:
188             js = {
189                 'serial': self.serial,
190                 'hash': self.hash,
191             }
192             json.dump(js, f, indent=4, sort_keys=True)
193         new_file.replace(file)
194
195
196 class ZoneType(Enum):
197     primary = auto()
198     secondary = auto()
199     alias = auto()
200
201
202 class NscZone:
203     nsc: 'Nsc'
204     name: str
205     dns_name: Name
206     safe_name: str                          # For use in file names
207     zone_type: ZoneType
208     reverse_for: Optional[IPNetwork]
209
210     def __init__(self,
211                  nsc: 'Nsc',
212                  name: str,
213                  reverse_for: Optional[IPNetwork],
214                  **kwargs) -> None:
215         self.nsc = nsc
216         self.name = name
217         self.dns_name = dns.name.from_text(name)
218         self.safe_name = name.replace('/', '@')
219         self.config = NscZoneConfig(**kwargs).finalize()
220         self.reverse_for = reverse_for
221
222     def process(self) -> None:
223         pass
224
225     def is_changed(self) -> bool:
226         return False
227
228
229 class NscZonePrimary(NscZone):
230     zone: Zone
231     zone_file: Path
232     state_file: Path
233     state: NscZoneState
234     prev_state: NscZoneState
235     aliases: List['NscZoneAlias']
236
237     def __init__(self, *args, **kwargs) -> None:
238         super().__init__(*args, **kwargs)
239
240         self.zone_type = ZoneType.primary
241         self.zone_file = self.nsc.zone_dir / self.safe_name
242         self.state_file = self.nsc.state_dir / (self.safe_name + '.json')
243
244         self.state = NscZoneState()
245         self.prev_state = NscZoneState()
246         self.prev_state.load(self.state_file)
247
248         self.aliases = []
249
250         self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN)
251         self.update_soa()
252
253     def update_soa(self) -> None:
254         conf = self.config
255         soa = dns.rdtypes.ANY.SOA.SOA(
256             RdataClass.IN, RdataType.SOA,
257             mname=conf.origin_server,
258             rname=conf.admin_email.replace('@', '.'),   # FIXME: names with dots
259             serial=self.state.serial,
260             refresh=conf.refresh,
261             retry=conf.retry,
262             expire=conf.expire,
263             minimum=conf.min_ttl,
264         )
265         self.zone.delete_rdataset("", RdataType.SOA)
266         self[""]._add(soa)
267
268     def n(self, name: str) -> NscNode:
269         return NscNode(self, name)
270
271     def __getitem__(self, name: str) -> NscNode:
272         return NscNode(self, name)
273
274     def host(self, name: str, *args, reverse: bool = True) -> NscNode:
275         n = NscNode(self, name)
276         n.A(*args, reverse=reverse)
277         return n
278
279     def zone_header(self) -> str:
280         return (
281             f'; Zone file for {self.name}\n'
282             + '; Generated by NSC, please do not edit manually.\n'
283             + '\n')
284
285     def dump(self, file: Optional[TextIO] = None) -> None:
286         # Could use self.zone.to_file(sys.stdout), but we want better formatting
287         file = file or sys.stdout
288         file.write(self.zone_header())
289         file.write(f'$TTL\t\t{self.config.default_ttl}\n\n')
290         last_name = None
291         for name, ttl, rec in self.zone.iterate_rdatas():
292             if name == last_name:
293                 print_name = ""
294             else:
295                 print_name = name
296             file.write(f'{print_name}\t{ttl if ttl != self.config.default_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
297             last_name = name
298
299     def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
300         assert isinstance(self.reverse_for, IPv4Network)
301         parts = str(addr).split('.')
302         parts = parts[self.reverse_for.prefixlen // 8:]
303         name = '.'.join(reversed(parts))
304         self.n(name).PTR(ptr_to)
305
306     def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
307         assert isinstance(self.reverse_for, IPv6Network)
308         parts = addr.exploded.replace(':', "")
309         parts = parts[self.reverse_for.prefixlen // 4:]
310         name = '.'.join(reversed(parts))
311         self.n(name).PTR(ptr_to)
312
313     def gen_hash(self) -> None:
314         sha = hashlib.sha1()
315         sha.update(self.zone_header().encode('us-ascii'))
316         for name, ttl, rec in self.zone.iterate_rdatas():
317             text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
318             sha.update(text.encode('us-ascii'))
319         self.state.hash = sha.hexdigest()[:16]
320
321     def gen_serial(self) -> None:
322         prev = self.prev_state.serial
323         if self.state.hash == self.prev_state.hash and prev > 0:
324             self.state.serial = self.prev_state.serial
325         else:
326             base = int(self.nsc.start_time.strftime('%Y%m%d00'))
327             if prev <= base:
328                 self.state.serial = base + 1
329             else:
330                 self.state.serial = prev + 1
331                 if prev >= base + 99:
332                     print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')
333
334     def process(self) -> None:
335         if self.config.add_null_mx:
336             self.gen_null_mx()
337         self.gen_hash()
338         self.gen_serial()
339
340     def write_zone(self) -> None:
341         self.update_soa()
342         new_file = Path(str(self.zone_file) + '.new')
343         with open(new_file, 'w') as f:
344             self.dump(file=f)
345         new_file.replace(self.zone_file)
346
347     def write_state(self) -> None:
348         self.state.save(self.state_file)
349
350     def is_changed(self) -> bool:
351         return self.state.serial != self.prev_state.serial
352
353     def delegate_classless(self, net: str | IPNetwork, subdomain: Optional[str] = None) -> NscNode:
354         net = parse_network(net)
355         assert self.reverse_for is not None
356         assert isinstance(self.reverse_for, IPv4Network)
357         assert self.reverse_for.prefixlen % 8 == 0
358         assert isinstance(net, IPv4Network)
359         assert net.subnet_of(self.reverse_for)
360         assert net.prefixlen < self.reverse_for.prefixlen + 8
361
362         start = int(net.network_address.packed[net.prefixlen // 8])
363         num = 1 << (8 - net.prefixlen % 8)
364
365         if subdomain is None:
366             subdomain = f'{start}/{net.prefixlen}'
367
368         for i in range(start, start + num):
369             target = f'{i}.{subdomain}'
370             self[str(i)].CNAME(parse_name(target, relative=True))
371
372         return self[subdomain]
373
374     def gen_null_mx(self) -> None:
375         for name, node in self.zone.items():
376             rds_a = node.get_rdataset(RdataClass.IN, RdataType.A)
377             rds_aaaa = node.get_rdataset(RdataClass.IN, RdataType.AAAA)
378             if rds_a or rds_aaaa:
379                 mx_rds = node.get_rdataset(RdataClass.IN, RdataType.MX, create=True)
380                 if not mx_rds:
381                     mx_rds.add(
382                         dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, 0, dns.name.root),
383                         ttl=self.config.default_ttl,
384                     )
385
386
387 class NscZoneSecondary(NscZone):
388     primary_server: IPAddress
389     secondary_file: Path
390
391     def __init__(self, *args, primary_server=IPAddress, **kwargs) -> None:
392         super().__init__(*args, **kwargs)
393         self.zone_type = ZoneType.secondary
394         self.primary_server = primary_server
395         self.secondary_file = self.nsc.secondary_dir / self.safe_name
396
397
398 class NscZoneAlias(NscZone):
399     alias_for: NscZonePrimary
400
401     def __init__(self, *args, alias_for=NscZonePrimary, **kwargs) -> None:
402         assert isinstance(alias_for, NscZonePrimary)
403         super().__init__(*args, **kwargs)
404         self.zone_type = ZoneType.alias
405         self.alias_for = alias_for
406         self.zone_file = alias_for.zone_file
407         alias_for.aliases.append(self)
408
409     def is_changed(self) -> bool:
410         return self.alias_for.is_changed()
411
412
413 class Nsc:
414     start_time: datetime
415     zones: Dict[str, NscZone]
416     default_zone_config: NscZoneConfig
417     ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
418     ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
419     root_dir: Path
420     state_dir: Path
421     zone_dir: Path
422     secondary_dir: Path
423     daemon: 'NscDaemon'  # Set by DaemonConfig class
424
425     def __init__(self,
426                  directory: str = '.',
427                  daemon: Optional['NscDaemon'] = None,
428                  **kwargs) -> None:
429         self.start_time = datetime.now()
430         self.zones = {}
431         self.default_zone_config = NscZoneConfig(**kwargs)
432         self.ipv4_reverse = defaultdict(list)
433         self.ipv6_reverse = defaultdict(list)
434
435         self.root_dir = Path(directory)
436         self.state_dir = self.root_dir / 'state'
437         self.state_dir.mkdir(parents=True, exist_ok=True)
438         self.zone_dir = self.root_dir / 'zone'
439         self.zone_dir.mkdir(parents=True, exist_ok=True)
440         self.secondary_dir = self.root_dir / 'secondary'
441         self.secondary_dir.mkdir(parents=True, exist_ok=True)
442
443         if daemon is None:
444             from nsconfig.daemon import NscDaemonNull
445             daemon = NscDaemonNull()
446         self.daemon = daemon
447         daemon.setup(self)
448
449     def add_zone(self,
450                  name: Optional[str] = None,
451                  reverse_for: str | IPNetwork | None = None,
452                  alias_for: Optional[NscZonePrimary] = None,
453                  follow_primary: str | IPAddress | None = None,
454                  inherit_config: Optional[NscZoneConfig] = None,
455                  **kwargs) -> Zone:
456         if inherit_config is None:
457             inherit_config = self.default_zone_config
458
459         if reverse_for is not None:
460             if isinstance(reverse_for, str):
461                 reverse_for = ip_network(reverse_for, strict=True)
462             name = name or self._reverse_zone_name(reverse_for)
463         assert name is not None
464         assert name not in self.zones
465
466         z: NscZone
467         if alias_for is not None:
468             assert follow_primary is None
469             z = NscZoneAlias(self, name, reverse_for=reverse_for, alias_for=alias_for, inherit_config=inherit_config, **kwargs)
470         elif follow_primary is None:
471             z = NscZonePrimary(self, name, reverse_for=reverse_for, inherit_config=inherit_config, **kwargs)
472         else:
473             if isinstance(follow_primary, str):
474                 follow_primary = ip_address(follow_primary)
475             z = NscZoneSecondary(self, name, reverse_for=reverse_for, primary_server=follow_primary, inherit_config=inherit_config, **kwargs)
476
477         self.zones[name] = z
478         return z
479
480     def __getitem__(self, name: str) -> NscZone:
481         return self.zones[name]
482
483     def _reverse_zone_name(self, net: IPNetwork) -> str:
484         if isinstance(net, IPv4Network):
485             parts = str(net.network_address).split('.')
486             out = parts[:net.prefixlen // 8]
487             if net.prefixlen % 8 != 0:
488                 out.append(parts[len(out)] + '/' + str(net.prefixlen))
489             return '.'.join(reversed(out)) + '.in-addr.arpa'
490         elif isinstance(net, IPv6Network):
491             assert net.prefixlen % 4 == 0
492             nibbles = net.network_address.exploded.replace(':', "")
493             nibbles = nibbles[:net.prefixlen // 4]
494             return '.'.join(reversed(nibbles)) + '.ip6.arpa'
495         else:
496             raise NotImplementedError()
497
498     def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
499         if isinstance(addr, IPv4Address):
500             self.ipv4_reverse[addr].append(ptr_to)
501         else:
502             self.ipv6_reverse[addr].append(ptr_to)
503
504     def dump_reverse(self) -> None:
505         print('### Requests for reverse mappings ###')
506         for ipa4, name in sorted(self.ipv4_reverse.items()):
507             print(f'{ipa4}\t{name}')
508         for ipa6, name in sorted(self.ipv6_reverse.items()):
509             print(f'{ipa6}\t{name}')
510
511     def fill_reverse(self) -> None:
512         for z in self.zones.values():
513             if isinstance(z, NscZonePrimary) and z.reverse_for is not None:
514                 if isinstance(z.reverse_for, IPv4Network):
515                     for addr4, ptr_list in self.ipv4_reverse.items():
516                         if addr4 in z.reverse_for:
517                             for ptr_to in ptr_list:
518                                 z._add_ipv4_reverse(addr4, ptr_to)
519                 else:
520                     for addr6, ptr_list in self.ipv6_reverse.items():
521                         if addr6 in z.reverse_for:
522                             for ptr_to in ptr_list:
523                                 z._add_ipv6_reverse(addr6, ptr_to)
524
525     def get_zones(self) -> List[NscZone]:
526         return [self.zones[k] for k in sorted(self.zones.keys())]
527
528     def process(self) -> None:
529         self.fill_reverse()
530         for z in self.get_zones():
531             z.process()