]> mj.ucw.cz Git - pynsc.git/blob - nsconfig/core.py
Rename test.py -> example.py
[pynsc.git] / nsconfig / core.py
1 # PyNSC: Main data structures
2 # (c) 2024 Martin Mareš <mj@ucw.cz>
3
4 from collections import defaultdict
5 from datetime import datetime, timedelta
6 import dns.name
7 from dns.name import Name
8 from dns.node import Node
9 from dns.rdata import Rdata
10 from dns.rdataclass import RdataClass
11 from dns.rdatatype import RdataType
12 import dns.rdtypes.ANY.CNAME
13 import dns.rdtypes.ANY.DNAME
14 import dns.rdtypes.ANY.MX
15 import dns.rdtypes.ANY.NS
16 import dns.rdtypes.ANY.PTR
17 import dns.rdtypes.ANY.SOA
18 import dns.rdtypes.ANY.TXT
19 import dns.rdtypes.IN.A
20 import dns.rdtypes.IN.AAAA
21 import dns.rdtypes.IN.SRV
22 from dns.zone import Zone
23 from enum import Enum, auto
24 import hashlib
25 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
26 import json
27 from pathlib import Path
28 import socket
29 import sys
30 from typing import Optional, Dict, List, Self, DefaultDict, TextIO, Tuple, TYPE_CHECKING
31
32 from nsconfig.util import flatten_list, parse_address, parse_network, parse_name, parse_duration, parse_rname
33 from nsconfig.util import IPAddress, IPNetwork, IPAddr, NameParseMode
34
35
36 if TYPE_CHECKING:
37     from nsconfig.daemon import NscDaemon
38
39
40 class NscNode:
41     nsc_zone: 'NscZonePrimary'
42     name: Name
43     node: Node
44     _ttl: int
45
46     def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None:
47         self.nsc_zone = nsc_zone
48         self.name = self._parse_lhs_name(name)
49         self.node = nsc_zone.zone.find_node(self.name, create=True)
50         self._ttl = nsc_zone.config.default_ttl
51
52     def ttl(self, seconds: Optional[int] = None, **kwargs) -> Self:
53         if seconds is not None:
54             self._ttl = seconds
55         elif kwargs:
56             self._ttl = parse_duration(timedelta(**kwargs))
57         else:
58             self._ttl = self.nsc_zone.config.default_ttl
59         return self
60
61     def _add(self, rec: Rdata) -> None:
62         rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
63         rds.add(rec, ttl=self._ttl)
64
65     def _parse_lhs_name(self, name, **kwargs):
66         return parse_name(name, mode=NameParseMode.relative, **kwargs)
67
68     def _parse_rhs_name(self, name, **kwargs):
69         return parse_name(name, mode=self.nsc_zone.config.name_parse_mode, **kwargs)
70
71     def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
72         for a in map(parse_address, flatten_list(addrs)):
73             if isinstance(a, IPv4Address):
74                 self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
75             else:
76                 self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
77             if reverse:
78                 self.nsc_zone.nsc._add_reverse_mapping(a, self.name.choose_relativity(origin=self.nsc_zone.dns_name, relativize=False))
79         return self
80
81     def CNAME(self, target: Name | str) -> Self:
82         self._add(dns.rdtypes.ANY.CNAME.CNAME(RdataClass.IN, RdataType.CNAME, target))
83         return self
84
85     def DNAME(self, target: Name | str) -> Self:
86         self._add(dns.rdtypes.ANY.DNAME.DNAME(RdataClass.IN, RdataType.DNAME, target))
87         return self
88
89     def MX(self, pri: int, name: str) -> Self:
90         self._add(
91             dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, self._parse_rhs_name(name))
92         )
93         return self
94
95     def MX_list(self, mxs: List[Tuple[int, str]]) -> Self:
96         for pri, name in mxs:
97             self.MX(pri, name)
98         return self
99
100     def NS(self, *names: str | List[str]) -> Self:
101         for name in flatten_list(names):
102             self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, self._parse_rhs_name(name)))
103         return self
104
105     def PTR(self, target: Name | str) -> Self:
106         self._add(dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target))
107         return self
108
109     def SRV(self, priority: int, weight: int, port: int, target: Name | str) -> Self:
110         self._add(dns.rdtypes.IN.SRV.SRV(RdataClass.IN, RdataType.SRV, priority, weight, port, self._parse_rhs_name(target)))
111         return self
112
113     def TXT(self, *text: str | List[str]) -> Self:
114         for txt in flatten_list(text):
115             self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, txt))
116         return self
117
118     def generic(self, typ: str, text: str) -> Self:
119         self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
120         return self
121
122     def alias(self, *aliases: str) -> Self:
123         # FIXME: Inter-zone aliases?
124         for alias in flatten_list(aliases):
125             self.nsc_zone[alias].CNAME(self.name)
126         return self
127
128
129 class NscZoneConfig:
130     admin_email: str
131     refresh: int
132     retry: int
133     expire: int
134     min_ttl: int
135     default_ttl: int
136     origin_server: str
137     daemon_options: List[str]
138     add_null_mx: bool
139     name_parse_mode: NameParseMode
140
141     default_config: Optional['NscZoneConfig'] = None
142
143     def __init__(self,
144                  *,
145                  admin_email: Optional[str] = None,
146                  refresh: Optional[int | timedelta] = None,
147                  retry: Optional[int | timedelta] = None,
148                  expire: Optional[int | timedelta] = None,
149                  min_ttl: Optional[int | timedelta] = None,
150                  default_ttl: Optional[int | timedelta] = None,
151                  origin_server: Optional[str] = None,
152                  daemon_options: Optional[List[str]] = None,
153                  add_daemon_options: Optional[List[str]] = None,
154                  add_null_mx: Optional[bool] = None,
155                  name_parse_mode: Optional[NameParseMode] = None,
156                  inherit_config: Optional['NscZoneConfig'] = None,
157                  ) -> None:
158         if inherit_config is None:
159             inherit_config = NscZoneConfig.default_config or self   # to satisfy the type checker
160         self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
161         self.refresh = parse_duration(refresh) if refresh is not None else inherit_config.refresh
162         self.retry = parse_duration(retry) if retry is not None else inherit_config.retry
163         self.expire = parse_duration(expire) if expire is not None else inherit_config.expire
164         self.min_ttl = parse_duration(min_ttl) if min_ttl is not None else inherit_config.min_ttl
165         self.default_ttl = parse_duration(default_ttl) if default_ttl is not None else inherit_config.default_ttl
166         self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
167         self.daemon_options = daemon_options if daemon_options is not None else inherit_config.daemon_options
168         self.add_null_mx = add_null_mx if add_null_mx is not None else inherit_config.add_null_mx
169         self.name_parse_mode = name_parse_mode if name_parse_mode is not None else inherit_config.name_parse_mode
170         if add_daemon_options is not None:
171             self.daemon_options += add_daemon_options
172
173     def finalize(self) -> Self:
174         if not self.origin_server:
175             self.origin_server = socket.getfqdn()
176         if not self.admin_email:
177             self.admin_email = f'hostmaster@{self.origin_server}'
178         if self.default_ttl == 0:
179             self.default_ttl = self.min_ttl
180         return self
181
182
183 NscZoneConfig.default_config = NscZoneConfig(
184     admin_email="",
185     refresh=timedelta(hours=8),
186     retry=timedelta(hours=2),
187     expire=timedelta(days=14),
188     min_ttl=timedelta(days=1),
189     default_ttl=0,
190     origin_server="",
191     daemon_options=[],
192     add_null_mx=False,
193     name_parse_mode=NameParseMode.absolute,
194 )
195
196
197 class NscZoneState:
198     serial: int
199     hash: str
200
201     def __init__(self) -> None:
202         self.serial = 0
203         self.hash = 'none'
204
205     def load(self, file: Path) -> None:
206         try:
207             with open(file) as f:
208                 js = json.load(f)
209                 assert isinstance(js, dict)
210                 if 'serial' in js:
211                     self.serial = js['serial']
212                 if 'hash' in js:
213                     self.hash = js['hash']
214         except FileNotFoundError:
215             pass
216
217     def save(self, file: Path) -> None:
218         new_file = Path(str(file) + '.new')
219         with open(new_file, 'w') as f:
220             js = {
221                 'serial': self.serial,
222                 'hash': self.hash,
223             }
224             json.dump(js, f, indent=4, sort_keys=True)
225         new_file.replace(file)
226
227
228 class ZoneType(Enum):
229     primary = auto()
230     secondary = auto()
231     alias = auto()
232
233
234 class NscZone:
235     nsc: 'Nsc'
236     name: str
237     dns_name: Name
238     safe_name: str                          # For use in file names
239     zone_type: ZoneType
240     config: NscZoneConfig
241     reverse_for: Optional[IPNetwork]
242
243     def __init__(self,
244                  nsc: 'Nsc',
245                  name: str,
246                  reverse_for: Optional[IPNetwork],
247                  **kwargs) -> None:
248         self.nsc = nsc
249         self.name = name
250         self.dns_name = dns.name.from_text(name)
251         self.safe_name = name.replace('/', '@')
252         self.config = NscZoneConfig(**kwargs).finalize()
253         self.reverse_for = reverse_for
254
255     def process(self) -> None:
256         pass
257
258     def is_changed(self) -> bool:
259         return False
260
261
262 class NscZonePrimary(NscZone):
263     zone: Zone
264     zone_file: Path
265     state_file: Path
266     state: NscZoneState
267     prev_state: NscZoneState
268     aliases: List['NscZoneAlias']
269
270     def __init__(self, *args, **kwargs) -> None:
271         super().__init__(*args, **kwargs)
272
273         self.zone_type = ZoneType.primary
274         self.zone_file = self.nsc.zone_dir / self.safe_name
275         self.state_file = self.nsc.state_dir / (self.safe_name + '.json')
276
277         self.state = NscZoneState()
278         self.prev_state = NscZoneState()
279         self.prev_state.load(self.state_file)
280
281         self.aliases = []
282
283         self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN)
284         self.update_soa()
285
286     @property
287     def root(self):
288         return NscNode(self, "")
289
290     def update_soa(self) -> None:
291         conf = self.config
292         soa = dns.rdtypes.ANY.SOA.SOA(
293             RdataClass.IN, RdataType.SOA,
294             mname=conf.origin_server,
295             rname=parse_rname(conf.admin_email),
296             serial=self.state.serial,
297             refresh=conf.refresh,
298             retry=conf.retry,
299             expire=conf.expire,
300             minimum=conf.min_ttl,
301         )
302         self.zone.delete_rdataset("", RdataType.SOA)
303         self.root._add(soa)
304
305     def n(self, name: str) -> NscNode:
306         return NscNode(self, name)
307
308     def __getitem__(self, name: str) -> NscNode:
309         return NscNode(self, name)
310
311     def host(self, name: str, *args, reverse: bool = True) -> NscNode:
312         n = NscNode(self, name)
313         n.A(*args, reverse=reverse)
314         return n
315
316     def zone_header(self) -> str:
317         return (
318             f'; Zone file for {self.name}\n'
319             + '; Generated by NSC, please do not edit manually.\n'
320             + '\n')
321
322     def dump(self, file: Optional[TextIO] = None) -> None:
323         # Could use self.zone.to_file(sys.stdout), but we want better formatting
324         file = file or sys.stdout
325         file.write(self.zone_header())
326         file.write(f'$TTL\t\t{self.config.default_ttl}\n\n')
327         last_name = None
328         for name, ttl, rec in self.zone.iterate_rdatas():
329             if name == last_name:
330                 print_name = ""
331             else:
332                 print_name = name
333             file.write(f'{print_name}\t{ttl if ttl != self.config.default_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
334             last_name = name
335
336     def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
337         assert isinstance(self.reverse_for, IPv4Network)
338         parts = str(addr).split('.')
339         parts = parts[self.reverse_for.prefixlen // 8:]
340         name = '.'.join(reversed(parts))
341         self.n(name).PTR(ptr_to)
342
343     def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
344         assert isinstance(self.reverse_for, IPv6Network)
345         parts = addr.exploded.replace(':', "")
346         parts = parts[self.reverse_for.prefixlen // 4:]
347         name = '.'.join(reversed(parts))
348         self.n(name).PTR(ptr_to)
349
350     def gen_hash(self) -> None:
351         sha = hashlib.sha1()
352         sha.update(self.zone_header().encode('us-ascii'))
353         for name, ttl, rec in self.zone.iterate_rdatas():
354             text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
355             sha.update(text.encode('us-ascii'))
356         self.state.hash = sha.hexdigest()[:16]
357
358     def gen_serial(self) -> None:
359         prev = self.prev_state.serial
360         if self.state.hash == self.prev_state.hash and prev > 0:
361             self.state.serial = self.prev_state.serial
362         else:
363             base = int(self.nsc.start_time.strftime('%Y%m%d00'))
364             if prev <= base:
365                 self.state.serial = base + 1
366             else:
367                 self.state.serial = prev + 1
368                 if prev >= base + 99:
369                     print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')
370
371     def process(self) -> None:
372         if self.config.add_null_mx:
373             self.gen_null_mx()
374         self.gen_hash()
375         self.gen_serial()
376
377     def write_zone(self) -> None:
378         self.update_soa()
379         new_file = Path(str(self.zone_file) + '.new')
380         with open(new_file, 'w') as f:
381             self.dump(file=f)
382         new_file.replace(self.zone_file)
383
384     def write_state(self) -> None:
385         self.state.save(self.state_file)
386
387     def is_changed(self) -> bool:
388         return self.state.serial != self.prev_state.serial
389
390     def delegate_classless(self, net: str | IPNetwork, subdomain: Optional[str] = None) -> NscNode:
391         net = parse_network(net)
392         assert self.reverse_for is not None
393         assert isinstance(self.reverse_for, IPv4Network)
394         assert self.reverse_for.prefixlen % 8 == 0
395         assert isinstance(net, IPv4Network)
396         assert net.subnet_of(self.reverse_for)
397         assert net.prefixlen < self.reverse_for.prefixlen + 8
398
399         start = int(net.network_address.packed[net.prefixlen // 8])
400         num = 1 << (8 - net.prefixlen % 8)
401
402         if subdomain is None:
403             subdomain = f'{start}/{net.prefixlen}'
404
405         for i in range(start, start + num):
406             target = f'{i}.{subdomain}'
407             self[str(i)].CNAME(parse_name(target, mode=NameParseMode.relative))
408
409         return self[subdomain]
410
411     def gen_null_mx(self) -> None:
412         for name, node in self.zone.items():
413             rds_a = node.get_rdataset(RdataClass.IN, RdataType.A)
414             rds_aaaa = node.get_rdataset(RdataClass.IN, RdataType.AAAA)
415             if rds_a or rds_aaaa:
416                 mx_rds = node.get_rdataset(RdataClass.IN, RdataType.MX, create=True)
417                 if not mx_rds:
418                     mx_rds.add(
419                         dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, 0, dns.name.root),
420                         ttl=self.config.default_ttl,
421                     )
422
423
424 class NscZoneSecondary(NscZone):
425     primary_server: IPAddress
426     secondary_file: Path
427
428     def __init__(self, *args, primary_server=IPAddress, **kwargs) -> None:
429         super().__init__(*args, **kwargs)
430         self.zone_type = ZoneType.secondary
431         self.primary_server = primary_server
432         self.secondary_file = self.nsc.secondary_dir / self.safe_name
433
434
435 class NscZoneAlias(NscZone):
436     alias_for: NscZonePrimary
437
438     def __init__(self, *args, alias_for=NscZonePrimary, **kwargs) -> None:
439         assert isinstance(alias_for, NscZonePrimary)
440         super().__init__(*args, **kwargs)
441         self.zone_type = ZoneType.alias
442         self.alias_for = alias_for
443         self.zone_file = alias_for.zone_file
444         alias_for.aliases.append(self)
445
446     def is_changed(self) -> bool:
447         return self.alias_for.is_changed()
448
449
450 class Nsc:
451     start_time: datetime
452     zones: Dict[str, NscZone]
453     default_zone_config: NscZoneConfig
454     ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
455     ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
456     root_dir: Path
457     state_dir: Path
458     zone_dir: Path
459     secondary_dir: Path
460     daemon: 'NscDaemon'
461
462     def __init__(self,
463                  *,
464                  directory: str = '.',
465                  daemon: Optional['NscDaemon'] = None,
466                  **kwargs) -> None:
467         self.start_time = datetime.now()
468         self.zones = {}
469         self.default_zone_config = NscZoneConfig(**kwargs)
470         self.ipv4_reverse = defaultdict(list)
471         self.ipv6_reverse = defaultdict(list)
472
473         self.root_dir = Path(directory)
474         self.state_dir = self.root_dir / 'state'
475         self.state_dir.mkdir(parents=True, exist_ok=True)
476         self.zone_dir = self.root_dir / 'zone'
477         self.zone_dir.mkdir(parents=True, exist_ok=True)
478         self.secondary_dir = self.root_dir / 'secondary'
479         self.secondary_dir.mkdir(parents=True, exist_ok=True)
480
481         if daemon is None:
482             from nsconfig.daemon import NscDaemonNull
483             daemon = NscDaemonNull()
484         self.daemon = daemon
485         daemon.setup(self)
486
487     def add_zone(self,
488                  name: Optional[str] = None,
489                  *,
490                  reverse_for: str | IPNetwork | None = None,
491                  alias_for: Optional[NscZonePrimary] = None,
492                  secondary_to: str | IPAddress | None = None,
493                  inherit_config: Optional[NscZoneConfig] = None,
494                  **kwargs) -> Zone:
495         if inherit_config is None:
496             inherit_config = self.default_zone_config
497
498         if reverse_for is not None:
499             if isinstance(reverse_for, str):
500                 reverse_for = ip_network(reverse_for, strict=True)
501             name = name or self._reverse_zone_name(reverse_for)
502         assert name is not None
503         assert name not in self.zones
504
505         z: NscZone
506         if alias_for is not None:
507             assert secondary_to is None
508             z = NscZoneAlias(self, name, reverse_for=reverse_for, alias_for=alias_for, inherit_config=inherit_config, **kwargs)
509         elif secondary_to is None:
510             z = NscZonePrimary(self, name, reverse_for=reverse_for, inherit_config=inherit_config, **kwargs)
511         else:
512             if isinstance(secondary_to, str):
513                 secondary_to = ip_address(secondary_to)
514             z = NscZoneSecondary(self, name, reverse_for=reverse_for, primary_server=secondary_to, inherit_config=inherit_config, **kwargs)
515
516         self.zones[name] = z
517         return z
518
519     def __getitem__(self, name: str) -> NscZone:
520         return self.zones[name]
521
522     def _reverse_zone_name(self, net: IPNetwork) -> str:
523         if isinstance(net, IPv4Network):
524             parts = str(net.network_address).split('.')
525             out = parts[:net.prefixlen // 8]
526             if net.prefixlen % 8 != 0:
527                 out.append(parts[len(out)] + '/' + str(net.prefixlen))
528             return '.'.join(reversed(out)) + '.in-addr.arpa'
529         elif isinstance(net, IPv6Network):
530             assert net.prefixlen % 4 == 0
531             nibbles = net.network_address.exploded.replace(':', "")
532             nibbles = nibbles[:net.prefixlen // 4]
533             return '.'.join(reversed(nibbles)) + '.ip6.arpa'
534         else:
535             raise NotImplementedError()
536
537     def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
538         if isinstance(addr, IPv4Address):
539             self.ipv4_reverse[addr].append(ptr_to)
540         else:
541             self.ipv6_reverse[addr].append(ptr_to)
542
543     def dump_reverse(self) -> None:
544         print('### Requests for reverse mappings ###')
545         for ipa4, name in sorted(self.ipv4_reverse.items()):
546             print(f'{ipa4}\t{name}')
547         for ipa6, name in sorted(self.ipv6_reverse.items()):
548             print(f'{ipa6}\t{name}')
549
550     def fill_reverse(self) -> None:
551         for z in self.zones.values():
552             if isinstance(z, NscZonePrimary) and z.reverse_for is not None:
553                 if isinstance(z.reverse_for, IPv4Network):
554                     for addr4, ptr_list in self.ipv4_reverse.items():
555                         if addr4 in z.reverse_for:
556                             for ptr_to in ptr_list:
557                                 z._add_ipv4_reverse(addr4, ptr_to)
558                 else:
559                     for addr6, ptr_list in self.ipv6_reverse.items():
560                         if addr6 in z.reverse_for:
561                             for ptr_to in ptr_list:
562                                 z._add_ipv6_reverse(addr6, ptr_to)
563
564     def get_zones(self) -> List[NscZone]:
565         return [self.zones[k] for k in sorted(self.zones.keys())]
566
567     def process(self) -> None:
568         self.fill_reverse()
569         for z in self.get_zones():
570             z.process()