]> mj.ucw.cz Git - pynsc.git/blob - nsconfig/core.py
Add mechanism for classless reverse delegations
[pynsc.git] / nsconfig / core.py
1 from collections import defaultdict
2 from datetime import datetime, timedelta
3 import dns.name
4 from dns.name import Name
5 from dns.node import Node
6 from dns.rdata import Rdata
7 from dns.rdataclass import RdataClass
8 from dns.rdatatype import RdataType
9 import dns.rdtypes.ANY.CNAME
10 import dns.rdtypes.ANY.MX
11 import dns.rdtypes.ANY.NS
12 import dns.rdtypes.ANY.PTR
13 import dns.rdtypes.ANY.SOA
14 import dns.rdtypes.ANY.TXT
15 import dns.rdtypes.IN.A
16 import dns.rdtypes.IN.AAAA
17 from dns.zone import Zone
18 from enum import Enum, auto
19 import hashlib
20 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
21 import json
22 from pathlib import Path
23 import socket
24 import sys
25 from typing import Optional, Dict, List, Self, DefaultDict, TextIO, TYPE_CHECKING
26
27 from nsconfig.util import flatten_list, parse_address, parse_network, parse_name
28 from nsconfig.util import IPAddress, IPNetwork, IPAddr
29
30
31 if TYPE_CHECKING:
32     from nsconfig.daemon import NscDaemon
33
34
35 class NscNode:
36     nsc_zone: 'NscZonePrimary'
37     name: str
38     node: Node
39     _ttl: int
40
41     def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None:
42         self.nsc_zone = nsc_zone
43         self.name = name
44         self.node = nsc_zone.zone.find_node(name, create=True)
45         self._ttl = nsc_zone._min_ttl
46
47     def ttl(self, *args, **kwargs) -> Self:
48         if not args and not kwargs:
49             self._ttl = self.nsc_zone._min_ttl
50         else:
51             self._ttl = int(timedelta(*args, **kwargs).total_seconds())
52         return self
53
54     def _add(self, rec: Rdata) -> None:
55         rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
56         rds.add(rec, ttl=self._ttl)
57
58     def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
59         for a in map(parse_address, flatten_list(addrs)):
60             if isinstance(a, IPv4Address):
61                 self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
62             else:
63                 self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
64             if reverse:
65                 self.nsc_zone.nsc._add_reverse_mapping(a, parse_name(self.name + '.' + self.nsc_zone.name))
66         return self
67
68     def MX(self, pri: int, name: str) -> Self:
69         self._add(
70             dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, parse_name(name))
71         )
72         return self
73
74     def NS(self, *names: str | List[str]) -> Self:
75         for name in map(parse_name, flatten_list(names)):
76             self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
77         return self
78
79     def TXT(self, *text: str | List[str]) -> Self:
80         for txt in flatten_list(text):
81             self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, txt))
82         return self
83
84     def PTR(self, target: Name | str) -> Self:
85         self._add(dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target))
86         return self
87
88     def CNAME(self, target: Name | str) -> Self:
89         self._add(dns.rdtypes.ANY.CNAME.CNAME(RdataClass.IN, RdataType.CNAME, target))
90         return self
91
92     def generic(self, typ: str, text: str) -> Self:
93         self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
94         return self
95
96
97 class NscZoneConfig:
98     admin_email: str
99     refresh: timedelta
100     retry: timedelta
101     expire: timedelta
102     min_ttl: timedelta
103     origin_server: str
104     daemon_options: List[str]
105
106     default_config: Optional['NscZoneConfig'] = None
107
108     def __init__(self,
109                  admin_email: Optional[str] = None,
110                  refresh: Optional[timedelta] = None,
111                  retry: Optional[timedelta] = None,
112                  expire: Optional[timedelta] = None,
113                  min_ttl: Optional[timedelta] = None,
114                  origin_server: Optional[str] = None,
115                  daemon_options: Optional[List[str]] = None,
116                  add_daemon_options: Optional[List[str]] = None,
117                  inherit_config: Optional['NscZoneConfig'] = None,
118                  ) -> None:
119         if inherit_config is None:
120             inherit_config = NscZoneConfig.default_config or self   # to satisfy the type checker
121         self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
122         self.refresh = refresh if refresh is not None else inherit_config.refresh
123         self.retry = retry if retry is not None else inherit_config.retry
124         self.expire = expire if expire is not None else inherit_config.expire
125         self.min_ttl = min_ttl if min_ttl is not None else inherit_config.min_ttl
126         self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
127         self.daemon_options = daemon_options if daemon_options is not None else inherit_config.daemon_options
128         if add_daemon_options is not None:
129             self.daemon_options += add_daemon_options
130
131     def finalize(self) -> Self:
132         if not self.origin_server:
133             self.origin_server = socket.getfqdn()
134         if not self.admin_email:
135             self.admin_email = f'hostmaster@{self.origin_server}'
136         return self
137
138
139 NscZoneConfig.default_config = NscZoneConfig(
140     admin_email="",
141     refresh=timedelta(hours=8),
142     retry=timedelta(hours=2),
143     expire=timedelta(days=14),
144     min_ttl=timedelta(days=1),
145     origin_server="",
146     daemon_options=[],
147 )
148
149
150 class NscZoneState:
151     serial: int
152     hash: str
153
154     def __init__(self) -> None:
155         self.serial = 0
156         self.hash = 'none'
157
158     def load(self, file: Path) -> None:
159         try:
160             with open(file) as f:
161                 js = json.load(f)
162                 assert isinstance(js, dict)
163                 if 'serial' in js:
164                     self.serial = js['serial']
165                 if 'hash' in js:
166                     self.hash = js['hash']
167         except FileNotFoundError:
168             pass
169
170     def save(self, file: Path) -> None:
171         new_file = Path(str(file) + '.new')
172         with open(new_file, 'w') as f:
173             js = {
174                 'serial': self.serial,
175                 'hash': self.hash,
176             }
177             json.dump(js, f, indent=4, sort_keys=True)
178         new_file.replace(file)
179
180
181 class ZoneType(Enum):
182     primary = auto()
183     secondary = auto()
184
185
186 class NscZone:
187     nsc: 'Nsc'
188     name: str
189     safe_name: str                          # For use in file names
190     zone_type: ZoneType
191     reverse_for: Optional[IPNetwork]
192
193     def __init__(self,
194                  nsc: 'Nsc',
195                  name: str,
196                  reverse_for: Optional[IPNetwork],
197                  **kwargs) -> None:
198         self.nsc = nsc
199         self.name = name
200         self.safe_name = name.replace('/', '@')
201         self.config = NscZoneConfig(**kwargs).finalize()
202         self.reverse_for = reverse_for
203
204     def process(self) -> None:
205         pass
206
207
208 class NscZonePrimary(NscZone):
209     zone: Zone
210     _min_ttl: int
211     zone_file: Path
212     state_file: Path
213     state: NscZoneState
214     prev_state: NscZoneState
215
216     def __init__(self, *args, **kwargs) -> None:
217         super().__init__(*args, **kwargs)
218
219         self.zone_type = ZoneType.primary
220         self.zone_file = self.nsc.zone_dir / self.safe_name
221         self.state_file = self.nsc.state_dir / (self.safe_name + '.json')
222
223         self.state = NscZoneState()
224         self.prev_state = NscZoneState()
225         self.prev_state.load(self.state_file)
226
227         self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN)
228         self._min_ttl = int(self.config.min_ttl.total_seconds())
229         self.update_soa()
230
231     def update_soa(self) -> None:
232         conf = self.config
233         soa = dns.rdtypes.ANY.SOA.SOA(
234             RdataClass.IN, RdataType.SOA,
235             mname=conf.origin_server,
236             rname=conf.admin_email.replace('@', '.'),   # FIXME: names with dots
237             serial=self.state.serial,
238             refresh=int(conf.refresh.total_seconds()),
239             retry=int(conf.retry.total_seconds()),
240             expire=int(conf.expire.total_seconds()),
241             minimum=int(conf.min_ttl.total_seconds()),
242         )
243         self.zone.delete_rdataset("", RdataType.SOA)
244         self[""]._add(soa)
245
246     def n(self, name: str) -> NscNode:
247         return NscNode(self, name)
248
249     def __getitem__(self, name: str) -> NscNode:
250         return NscNode(self, name)
251
252     def host(self, name: str, *args, reverse: bool = True) -> NscNode:
253         n = NscNode(self, name)
254         n.A(*args, reverse=reverse)
255         return n
256
257     def zone_header(self) -> str:
258         return (
259             f'; Zone file for {self.name}\n'
260             + '; Generated by NSC, please do not edit manually.\n'
261             + '\n')
262
263     def dump(self, file: Optional[TextIO] = None) -> None:
264         # Could use self.zone.to_file(sys.stdout), but we want better formatting
265         file = file or sys.stdout
266         file.write(self.zone_header())
267         last_name = None
268         for name, ttl, rec in self.zone.iterate_rdatas():
269             if name == last_name:
270                 print_name = ""
271             else:
272                 print_name = name
273             file.write(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
274             last_name = name
275
276     def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
277         assert isinstance(self.reverse_for, IPv4Network)
278         parts = str(addr).split('.')
279         parts = parts[self.reverse_for.prefixlen // 8:]
280         name = '.'.join(reversed(parts))
281         self.n(name).PTR(ptr_to)
282
283     def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
284         assert isinstance(self.reverse_for, IPv6Network)
285         parts = addr.exploded.replace(':', "")
286         parts = parts[self.reverse_for.prefixlen // 4:]
287         name = '.'.join(reversed(parts))
288         self.n(name).PTR(ptr_to)
289
290     def gen_hash(self) -> None:
291         sha = hashlib.sha1()
292         sha.update(self.zone_header().encode('us-ascii'))
293         for name, ttl, rec in self.zone.iterate_rdatas():
294             text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
295             sha.update(text.encode('us-ascii'))
296         self.state.hash = sha.hexdigest()[:16]
297
298     def gen_serial(self) -> None:
299         prev = self.prev_state.serial
300         if self.state.hash == self.prev_state.hash and prev > 0:
301             self.state.serial = self.prev_state.serial
302         else:
303             base = int(self.nsc.start_time.strftime('%Y%m%d00'))
304             if prev <= base:
305                 self.state.serial = base + 1
306             else:
307                 self.state.serial = prev + 1
308                 if prev >= base + 99:
309                     print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')
310
311     def process(self) -> None:
312         if self.zone_type == ZoneType.primary:
313             self.gen_hash()
314             self.gen_serial()
315
316     def write_zone(self) -> None:
317         self.update_soa()
318         new_file = Path(str(self.zone_file) + '.new')
319         with open(new_file, 'w') as f:
320             self.dump(file=f)
321         new_file.replace(self.zone_file)
322
323     def write_state(self) -> None:
324         self.state.save(self.state_file)
325
326     def is_changed(self) -> bool:
327         return self.state.serial != self.prev_state.serial
328
329     def delegate_classless(self, net: str | IPNetwork, subdomain: Optional[str] = None) -> NscNode:
330         net = parse_network(net)
331         assert self.reverse_for is not None
332         assert isinstance(self.reverse_for, IPv4Network)
333         assert self.reverse_for.prefixlen % 8 == 0
334         assert isinstance(net, IPv4Network)
335         assert net.subnet_of(self.reverse_for)
336         assert net.prefixlen < self.reverse_for.prefixlen + 8
337
338         start = int(net.network_address.packed[net.prefixlen // 8])
339         num = 1 << (8 - net.prefixlen % 8)
340
341         if subdomain is None:
342             subdomain = f'{start}/{net.prefixlen}'
343
344         for i in range(start, start + num):
345             target = f'{i}.{subdomain}'
346             self[str(i)].CNAME(parse_name(target, relative=True))
347
348         return self[subdomain]
349
350
351 class NscZoneSecondary(NscZone):
352     primary_server: IPAddress
353     secondary_file: Path
354
355     def __init__(self, *args, primary_server=IPAddress, **kwargs) -> None:
356         super().__init__(*args, **kwargs)
357         self.zone_type = ZoneType.secondary
358         self.primary_server = primary_server
359         self.secondary_file = self.nsc.secondary_dir / self.safe_name
360
361
362 class Nsc:
363     start_time: datetime
364     zones: Dict[str, NscZone]
365     default_zone_config: NscZoneConfig
366     ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
367     ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
368     root_dir: Path
369     state_dir: Path
370     zone_dir: Path
371     secondary_dir: Path
372     daemon: 'NscDaemon'  # Set by DaemonConfig class
373
374     def __init__(self,
375                  directory: str = '.',
376                  daemon: Optional['NscDaemon'] = None,
377                  **kwargs) -> None:
378         self.start_time = datetime.now()
379         self.zones = {}
380         self.default_zone_config = NscZoneConfig(**kwargs)
381         self.ipv4_reverse = defaultdict(list)
382         self.ipv6_reverse = defaultdict(list)
383
384         self.root_dir = Path(directory)
385         self.state_dir = self.root_dir / 'state'
386         self.state_dir.mkdir(parents=True, exist_ok=True)
387         self.zone_dir = self.root_dir / 'zone'
388         self.zone_dir.mkdir(parents=True, exist_ok=True)
389         self.secondary_dir = self.root_dir / 'secondary'
390         self.secondary_dir.mkdir(parents=True, exist_ok=True)
391
392         if daemon is None:
393             from nsconfig.daemon import NscDaemonNull
394             daemon = NscDaemonNull()
395         self.daemon = daemon
396         daemon.setup(self)
397
398     def add_zone(self,
399                  name: Optional[str] = None,
400                  reverse_for: str | IPNetwork | None = None,
401                  follow_primary: str | IPAddress | None = None,
402                  inherit_config: Optional[NscZoneConfig] = None,
403                  **kwargs) -> Zone:
404         if inherit_config is None:
405             inherit_config = self.default_zone_config
406
407         if reverse_for is not None:
408             if isinstance(reverse_for, str):
409                 reverse_for = ip_network(reverse_for, strict=True)
410             name = name or self._reverse_zone_name(reverse_for)
411         assert name is not None
412         assert name not in self.zones
413
414         z: NscZone
415         if follow_primary is None:
416             z = NscZonePrimary(self, name, reverse_for=reverse_for, inherit_config=inherit_config, **kwargs)
417         else:
418             if isinstance(follow_primary, str):
419                 follow_primary = ip_address(follow_primary)
420             z = NscZoneSecondary(self, name, reverse_for=reverse_for, primary_server=follow_primary, inherit_config=inherit_config, **kwargs)
421
422         self.zones[name] = z
423         return z
424
425     def _reverse_zone_name(self, net: IPNetwork) -> str:
426         if isinstance(net, IPv4Network):
427             parts = str(net.network_address).split('.')
428             out = parts[:net.prefixlen // 8]
429             if net.prefixlen % 8 != 0:
430                 out.append(parts[len(out)] + '/' + str(net.prefixlen))
431             return '.'.join(reversed(out)) + '.in-addr.arpa'
432         elif isinstance(net, IPv6Network):
433             assert net.prefixlen % 4 == 0
434             nibbles = net.network_address.exploded.replace(':', "")
435             nibbles = nibbles[:net.prefixlen // 4]
436             return '.'.join(reversed(nibbles)) + '.ip6.arpa'
437         else:
438             raise NotImplementedError()
439
440     def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
441         if isinstance(addr, IPv4Address):
442             self.ipv4_reverse[addr].append(ptr_to)
443         else:
444             self.ipv6_reverse[addr].append(ptr_to)
445
446     def dump_reverse(self) -> None:
447         print('### Requests for reverse mappings ###')
448         for ipa4, name in sorted(self.ipv4_reverse.items()):
449             print(f'{ipa4}\t{name}')
450         for ipa6, name in sorted(self.ipv6_reverse.items()):
451             print(f'{ipa6}\t{name}')
452
453     def fill_reverse(self) -> None:
454         for z in self.zones.values():
455             if isinstance(z, NscZonePrimary) and z.reverse_for is not None:
456                 if isinstance(z.reverse_for, IPv4Network):
457                     for addr4, ptr_list in self.ipv4_reverse.items():
458                         if addr4 in z.reverse_for:
459                             for ptr_to in ptr_list:
460                                 z._add_ipv4_reverse(addr4, ptr_to)
461                 else:
462                     for addr6, ptr_list in self.ipv6_reverse.items():
463                         if addr6 in z.reverse_for:
464                             for ptr_to in ptr_list:
465                                 z._add_ipv6_reverse(addr6, ptr_to)
466
467     def get_zones(self) -> List[NscZone]:
468         return [self.zones[k] for k in sorted(self.zones.keys())]
469
470     def process(self) -> None:
471         self.fill_reverse()
472         for z in self.get_zones():
473             z.process()