]> mj.ucw.cz Git - pynsc.git/blob - nsconfig/core.py
53739ef304afb03744c24c24b125aa70971835a4
[pynsc.git] / nsconfig / core.py
1 from collections import defaultdict
2 from datetime import datetime, timedelta
3 import dns.name
4 from dns.name import Name
5 from dns.node import Node
6 from dns.rdata import Rdata
7 from dns.rdataclass import RdataClass
8 from dns.rdatatype import RdataType
9 import dns.rdtypes.ANY.MX
10 import dns.rdtypes.ANY.NS
11 import dns.rdtypes.ANY.PTR
12 import dns.rdtypes.ANY.SOA
13 import dns.rdtypes.ANY.TXT
14 import dns.rdtypes.IN.A
15 import dns.rdtypes.IN.AAAA
16 from dns.zone import Zone
17 from enum import Enum, auto
18 import hashlib
19 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
20 import json
21 from pathlib import Path
22 import socket
23 import sys
24 from typing import Optional, Dict, List, Self, DefaultDict, TextIO, TYPE_CHECKING
25
26 from nsconfig.util import flatten_list
27
28
29 if TYPE_CHECKING:
30     from nsconfig.daemon import NscDaemon
31
32
33 IPAddress = IPv4Address | IPv6Address
34 IPNetwork = IPv4Network | IPv6Network
35 IPAddr = str | IPAddress | List[str | IPAddress]
36
37
class NscNode:
    """Record builder for a single owner name inside a primary zone.

    Obtained via ``zone[name]``, ``zone.n(name)`` or ``zone.host(...)``.
    Every record method adds rdata to the underlying dnspython node and
    returns ``self``, so calls can be chained.
    """
    nsc_zone: 'NscZonePrimary'   # zone this node belongs to
    name: str                    # owner name exactly as passed in
    node: Node                   # dnspython node holding this name's rdatasets
    _ttl: int                    # TTL in seconds applied to records added from now on

    def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None:
        self.nsc_zone = nsc_zone
        self.name = name
        self.node = nsc_zone.zone.find_node(name, create=True)
        self._ttl = nsc_zone._min_ttl

    def ttl(self, *args, **kwargs) -> Self:
        """Set the TTL used for records added after this call.

        The arguments are passed to :class:`datetime.timedelta` (for example
        ``ttl(hours=1)``). Called without arguments, it restores the zone's
        minimum TTL.
        """
        if not args and not kwargs:
            self._ttl = self.nsc_zone._min_ttl
        else:
            self._ttl = int(timedelta(*args, **kwargs).total_seconds())
        return self

    def _add(self, rec: Rdata) -> None:
        """Add *rec* to the rdataset of its class/type, creating the rdataset if needed."""
        rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
        rds.add(rec, ttl=self._ttl)

    def _parse_addr(self, addr: IPAddr | str) -> IPAddress:
        """Return *addr* as an IPv4Address/IPv6Address, parsing strings.

        Raises:
            ValueError: if a string is not an IP address, or *addr* is of
                any other type.
        """
        if isinstance(addr, (IPv4Address, IPv6Address)):
            return addr
        if isinstance(addr, str):
            return ip_address(addr)
        raise ValueError('Cannot parse IP address')

    def _parse_name(self, name: str) -> Name:
        """Parse a target name given to MX/NS.

        A name with more than one label (e.g. ``mail.example.org``) is treated
        as fully qualified. A single label (e.g. ``mail``) stays a relative
        name. Labels are counted after parsing, so an escaped dot (``a\\.b``)
        is part of a label and does not make the name fully qualified.
        """
        parsed = dns.name.from_text(name, origin=None)
        if len(parsed.labels) > 1:
            # No-op for names that are already absolute (trailing dot).
            return parsed.derelativize(dns.name.root)
        return parsed

    def _absolute_name(self) -> Name:
        """Return the fully-qualified owner name of this node.

        The name is parsed relative to the zone origin rather than glued
        together as text. This way the zone apex (``''`` or ``'@'``) and names
        already written with a trailing dot resolve correctly.
        """
        return dns.name.from_text(self.name, origin=self.nsc_zone.zone.origin)

    def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
        """Add A (IPv4) and AAAA (IPv6) records for the given addresses.

        Arguments may be addresses, address strings, or lists of them.
        If *reverse* is set, each address is also registered with the Nsc
        object as a request for a PTR record pointing back at this name.
        """
        for a in map(self._parse_addr, flatten_list(addrs)):
            if isinstance(a, IPv4Address):
                self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
            else:
                self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
            if reverse:
                self.nsc_zone.nsc._add_reverse_mapping(a, self._absolute_name())
        return self

    def MX(self, pri: int, name: str) -> Self:
        """Add an MX record with preference *pri* pointing at *name*."""
        self._add(
            dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, self._parse_name(name))
        )
        return self

    def NS(self, *names: str | List[str]) -> Self:
        """Add one NS record per name; names may also be given as lists."""
        for name in map(self._parse_name, flatten_list(names)):
            self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
        return self

    def TXT(self, *text: str | List[str]) -> Self:
        """Add one TXT record per string; strings may also be given as lists."""
        for txt in flatten_list(text):
            self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, txt))
        return self

    def PTR(self, target: Name | str) -> Self:
        """Add a PTR record pointing at *target*."""
        self._add(dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target))
        return self

    def generic(self, typ: str, text: str) -> Self:
        """Add a record of type *typ* whose rdata is given in text form."""
        self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
        return self
109
110
111 class NscZoneConfig:
112     admin_email: str
113     refresh: timedelta
114     retry: timedelta
115     expire: timedelta
116     min_ttl: timedelta
117     origin_server: str
118     daemon_options: List[str]
119
120     default_config: Optional['NscZoneConfig'] = None
121
122     def __init__(self,
123                  admin_email: Optional[str] = None,
124                  refresh: Optional[timedelta] = None,
125                  retry: Optional[timedelta] = None,
126                  expire: Optional[timedelta] = None,
127                  min_ttl: Optional[timedelta] = None,
128                  origin_server: Optional[str] = None,
129                  daemon_options: Optional[List[str]] = None,
130                  add_daemon_options: Optional[List[str]] = None,
131                  inherit_config: Optional['NscZoneConfig'] = None,
132                  ) -> None:
133         if inherit_config is None:
134             inherit_config = NscZoneConfig.default_config or self   # to satisfy the type checker
135         self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
136         self.refresh = refresh if refresh is not None else inherit_config.refresh
137         self.retry = retry if retry is not None else inherit_config.retry
138         self.expire = expire if expire is not None else inherit_config.expire
139         self.min_ttl = min_ttl if min_ttl is not None else inherit_config.min_ttl
140         self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
141         self.daemon_options = daemon_options if daemon_options is not None else inherit_config.daemon_options
142         if add_daemon_options is not None:
143             self.daemon_options += add_daemon_options
144
145     def finalize(self) -> Self:
146         if not self.origin_server:
147             self.origin_server = socket.getfqdn()
148         if not self.admin_email:
149             self.admin_email = f'hostmaster@{self.origin_server}'
150         return self
151
152
# Built-in defaults used when a zone config has nothing else to inherit from.
# The empty origin_server / admin_email are placeholders that
# NscZoneConfig.finalize() replaces with real values.
NscZoneConfig.default_config = NscZoneConfig(
    origin_server="",
    admin_email="",
    refresh=timedelta(hours=8),
    retry=timedelta(hours=2),
    expire=timedelta(weeks=2),
    min_ttl=timedelta(days=1),
    daemon_options=[],
)
162
163
class NscZoneState:
    """Persistent per-zone state: the last SOA serial and zone content hash."""
    serial: int
    hash: str

    def __init__(self) -> None:
        self.serial = 0
        self.hash = 'none'

    def load(self, file: Path) -> None:
        """Load state from the JSON *file*.

        A missing file leaves the defaults untouched. Keys absent from the
        file also keep their current values.

        Raises:
            ValueError: if the file is not valid JSON or its top level is not
                a JSON object (validated explicitly, not with ``assert``, so
                the check survives ``python -O``).
        """
        try:
            with open(file, encoding='utf-8') as f:
                js = json.load(f)
        except FileNotFoundError:
            return
        if not isinstance(js, dict):
            raise ValueError(f'{file}: zone state must be a JSON object')
        if 'serial' in js:
            self.serial = js['serial']
        if 'hash' in js:
            self.hash = js['hash']

    def save(self, file: Path) -> None:
        """Write the state to *file* as JSON.

        The data is first written to a sibling ``<file>.new``, which then
        replaces *file*. This way a reader should never see a partially
        written *file*.
        """
        new_file = Path(str(file) + '.new')
        with open(new_file, 'w', encoding='utf-8') as f:
            js = {
                'serial': self.serial,
                'hash': self.hash,
            }
            json.dump(js, f, indent=4, sort_keys=True)
        new_file.replace(file)
193
194
class ZoneType(Enum):
    """Kind of zone.

    A zone is either generated here (primary) or follows another primary
    server (secondary).
    """
    primary = 1
    secondary = 2
198
199
class NscZone:
    """State common to primary and secondary zones.

    Subclasses set ``zone_type`` and may override :meth:`process`.
    """
    nsc: 'Nsc'                              # owning configuration object
    name: str                               # zone name as registered
    safe_name: str                          # For use in file names
    zone_type: ZoneType
    reverse_for: Optional[IPNetwork]        # network this zone holds PTRs for, if any

    def __init__(self,
                 nsc: 'Nsc',
                 name: str,
                 reverse_for: Optional[IPNetwork],
                 **kwargs) -> None:
        """Store identity fields and build the finalized zone config from *kwargs*."""
        self.nsc, self.name = nsc, name
        # Classless reverse zone names contain '/', which cannot appear in a file name.
        self.safe_name = '@'.join(name.split('/'))
        self.config = NscZoneConfig(**kwargs).finalize()
        self.reverse_for = reverse_for

    def process(self) -> None:
        """Per-zone processing hook; the base implementation does nothing."""
        return None
220
221
class NscZonePrimary(NscZone):
    """A zone whose records are built here and written out as a zone file.

    Records are added through NscNode objects (``zone[name]``,
    ``zone.n(name)``, ``zone.host(...)``). A JSON state file remembers the
    last serial and a hash of the zone content, so the SOA serial only
    advances when the content changes.
    """
    zone: Zone
    _min_ttl: int                 # config.min_ttl in whole seconds
    zone_file: Path               # generated zone file under nsc.zone_dir
    state_file: Path              # JSON state file under nsc.state_dir
    state: NscZoneState           # state computed during this run
    prev_state: NscZoneState      # state loaded from state_file at construction

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.zone_type = ZoneType.primary
        self.zone_file = self.nsc.zone_dir / self.safe_name
        self.state_file = self.nsc.state_dir / (self.safe_name + '.json')

        self.state = NscZoneState()
        self.prev_state = NscZoneState()
        self.prev_state.load(self.state_file)

        self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN)
        self._min_ttl = int(self.config.min_ttl.total_seconds())
        self.update_soa()

    @staticmethod
    def _email_to_rname(email: str) -> str:
        """Convert an e-mail address to the text form of an SOA RNAME.

        The last '@' becomes a label separator. Dots inside the local part are
        backslash-escaped so they stay inside the first label; for example
        ``john.doe@example.org`` becomes ``john\\.doe.example.org``. Input
        without '@' is returned unchanged.
        """
        local, at, domain = email.rpartition('@')
        if not at:
            return email
        return local.replace('.', '\\.') + '.' + domain

    def update_soa(self) -> None:
        """(Re)create the apex SOA record from the config and ``state.serial``."""
        conf = self.config
        soa = dns.rdtypes.ANY.SOA.SOA(
            RdataClass.IN, RdataType.SOA,
            mname=conf.origin_server,
            rname=self._email_to_rname(conf.admin_email),
            serial=self.state.serial,
            refresh=int(conf.refresh.total_seconds()),
            retry=int(conf.retry.total_seconds()),
            expire=int(conf.expire.total_seconds()),
            minimum=int(conf.min_ttl.total_seconds()),
        )
        self.zone.delete_rdataset("", RdataType.SOA)
        self[""]._add(soa)

    def n(self, name: str) -> NscNode:
        """Return a record builder for *name* (same as ``self[name]``)."""
        return NscNode(self, name)

    def __getitem__(self, name: str) -> NscNode:
        return NscNode(self, name)

    def host(self, name: str, *args, reverse: bool = True) -> NscNode:
        """Create node *name* with A/AAAA records for the addresses in *args*."""
        n = NscNode(self, name)
        n.A(*args, reverse=reverse)
        return n

    def zone_header(self) -> str:
        """Return the comment header written at the top of the zone file."""
        return (
            f'; Zone file for {self.name}\n'
            + '; Generated by NSC, please do not edit manually.\n'
            + '\n')

    def dump(self, file: Optional[TextIO] = None) -> None:
        """Write the zone in zone-file format to *file* (default: stdout).

        Consecutive records with the same owner leave the name column empty.
        Records whose TTL equals the zone's minimum TTL leave the TTL column
        empty.
        """
        # Could use self.zone.to_file(sys.stdout), but we want better formatting
        # NOTE(review): no $TTL directive is emitted, so records with an empty
        # TTL column get whatever default TTL the zone loader applies — confirm
        # this is intended.
        file = file or sys.stdout
        file.write(self.zone_header())
        last_name = None
        for name, ttl, rec in self.zone.iterate_rdatas():
            if name == last_name:
                print_name = ""
            else:
                print_name = name
            file.write(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
            last_name = name

    def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
        """Add a PTR for *addr* -> *ptr_to* to this IPv4 reverse zone."""
        assert isinstance(self.reverse_for, IPv4Network)
        parts = str(addr).split('.')
        # Drop the leading octets already covered by the zone name.
        parts = parts[self.reverse_for.prefixlen // 8:]
        name = '.'.join(reversed(parts))
        self.n(name).PTR(ptr_to)

    def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
        """Add a PTR for *addr* -> *ptr_to* to this IPv6 reverse zone."""
        assert isinstance(self.reverse_for, IPv6Network)
        parts = addr.exploded.replace(':', "")
        # Drop the leading nibbles already covered by the zone name.
        parts = parts[self.reverse_for.prefixlen // 4:]
        name = '.'.join(reversed(parts))
        self.n(name).PTR(ptr_to)

    def gen_hash(self) -> None:
        """Store a 16-hex-digit SHA-1 digest of header and records in ``state.hash``."""
        sha = hashlib.sha1()
        # UTF-8 equals ASCII for ASCII-only text, so existing hashes stay
        # valid. Unlike 'us-ascii', it cannot fail on a non-ASCII zone name
        # in the header.
        sha.update(self.zone_header().encode('utf-8'))
        for name, ttl, rec in self.zone.iterate_rdatas():
            text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
            sha.update(text.encode('utf-8'))
        self.state.hash = sha.hexdigest()[:16]

    def gen_serial(self) -> None:
        """Choose the SOA serial for this run and store it in ``state.serial``.

        If the content hash is unchanged and a previous serial exists, that
        serial is kept. Otherwise the serial becomes YYYYMMDD01 for the day of
        ``nsc.start_time``, or previous + 1 if the previous serial is already
        at or above that. A warning is printed once the two-digit daily
        counter overflows.
        """
        prev = self.prev_state.serial
        if self.state.hash == self.prev_state.hash and prev > 0:
            self.state.serial = self.prev_state.serial
        else:
            base = int(self.nsc.start_time.strftime('%Y%m%d00'))
            if prev <= base:
                self.state.serial = base + 1
            else:
                self.state.serial = prev + 1
                if prev >= base + 99:
                    print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')

    def process(self) -> None:
        """Compute the content hash, then the serial derived from it."""
        if self.zone_type == ZoneType.primary:
            self.gen_hash()
            self.gen_serial()

    def write_zone(self) -> None:
        """Refresh the SOA with the current serial and write the zone file.

        The zone is written to ``<zone_file>.new`` and then renamed over
        ``zone_file``.
        """
        self.update_soa()
        new_file = Path(str(self.zone_file) + '.new')
        with open(new_file, 'w', encoding='utf-8') as f:
            self.dump(file=f)
        new_file.replace(self.zone_file)

    def write_state(self) -> None:
        """Persist ``state`` to ``state_file``."""
        self.state.save(self.state_file)

    def is_changed(self) -> bool:
        """Return True if this run's serial differs from the previous run's."""
        return self.state.serial != self.prev_state.serial
342
343
class NscZoneSecondary(NscZone):
    """A zone that follows another primary server instead of being generated here."""
    primary_server: IPAddress
    secondary_file: Path            # file for this zone under nsc.secondary_dir

    def __init__(self, *args, primary_server: IPAddress | str, **kwargs) -> None:
        """Create a secondary zone.

        Args:
            primary_server: address of the primary server. Strings are parsed
                with ip_address(), which raises ValueError if they are invalid.
                (The original signature used ``=IPAddress``, which made the
                type alias a default value, so omitting the argument silently
                stored a type object. The argument is now required.)
        """
        super().__init__(*args, **kwargs)
        self.zone_type = ZoneType.secondary
        if isinstance(primary_server, str):
            primary_server = ip_address(primary_server)
        self.primary_server = primary_server
        self.secondary_file = self.nsc.secondary_dir / self.safe_name
353
354
class Nsc:
    """Top-level NSC configuration: the set of zones plus reverse-mapping requests.

    On construction the directories ``state``, ``zone`` and ``secondary``
    are created under *directory*.
    """
    start_time: datetime
    zones: Dict[str, NscZone]
    default_zone_config: NscZoneConfig
    ipv4_reverse: DefaultDict[IPv4Address, List[Name]]   # address -> names requesting a PTR
    ipv6_reverse: DefaultDict[IPv6Address, List[Name]]   # address -> names requesting a PTR
    root_dir: Path
    state_dir: Path
    zone_dir: Path
    secondary_dir: Path
    daemon: 'NscDaemon'  # Set by DaemonConfig class

    def __init__(self,
                 directory: str | Path = '.',
                 daemon: Optional['NscDaemon'] = None,
                 **kwargs) -> None:
        """Create the configuration.

        Args:
            directory: root directory for state, zone and secondary files.
            daemon: daemon integration object; defaults to NscDaemonNull().
                Its setup() is called with this object.
            **kwargs: NscZoneConfig settings forming default_zone_config.
        """
        self.start_time = datetime.now()
        self.zones = {}
        self.default_zone_config = NscZoneConfig(**kwargs)
        self.ipv4_reverse = defaultdict(list)
        self.ipv6_reverse = defaultdict(list)

        self.root_dir = Path(directory)
        self.state_dir = self.root_dir / 'state'
        self.state_dir.mkdir(parents=True, exist_ok=True)
        self.zone_dir = self.root_dir / 'zone'
        self.zone_dir.mkdir(parents=True, exist_ok=True)
        self.secondary_dir = self.root_dir / 'secondary'
        self.secondary_dir.mkdir(parents=True, exist_ok=True)

        if daemon is None:
            from nsconfig.daemon import NscDaemonNull
            daemon = NscDaemonNull()
        self.daemon = daemon
        daemon.setup(self)

    def add_zone(self,
                 name: Optional[str] = None,
                 reverse_for: str | IPNetwork | None = None,
                 follow_primary: str | IPAddress | None = None,
                 inherit_config: Optional[NscZoneConfig] = None,
                 **kwargs) -> NscZone:
        """Create, register and return a new zone.

        Args:
            name: zone name. It may be omitted when *reverse_for* is given;
                the in-addr.arpa / ip6.arpa name is then derived from the
                network.
            reverse_for: network whose PTR records this zone should hold.
                Strings are parsed strictly (host bits must be zero).
            follow_primary: if given, the zone is a secondary following this
                primary server address. Otherwise it is a primary zone.
            inherit_config: config to inherit unspecified settings from;
                defaults to ``default_zone_config``.
            **kwargs: further NscZoneConfig settings.

        Returns:
            The new NscZonePrimary or NscZoneSecondary.

        Raises:
            ValueError: if no name can be determined, a zone of that name
                already exists, or an address/network string is invalid.
        """
        if inherit_config is None:
            inherit_config = self.default_zone_config

        if reverse_for is not None:
            if isinstance(reverse_for, str):
                reverse_for = ip_network(reverse_for, strict=True)
            name = name or self._reverse_zone_name(reverse_for)
        if name is None:
            raise ValueError('Zone name must be given unless reverse_for is set')
        if name in self.zones:
            raise ValueError(f'Zone {name} is already defined')

        z: NscZone
        if follow_primary is None:
            z = NscZonePrimary(self, name, reverse_for=reverse_for, inherit_config=inherit_config, **kwargs)
        else:
            if isinstance(follow_primary, str):
                follow_primary = ip_address(follow_primary)
            z = NscZoneSecondary(self, name, reverse_for=reverse_for, primary_server=follow_primary, inherit_config=inherit_config, **kwargs)

        self.zones[name] = z
        return z

    def _reverse_zone_name(self, net: IPNetwork) -> str:
        """Return the reverse-DNS zone name for *net*.

        For IPv4 prefixes that are not a multiple of 8, the last label takes
        the form ``<octet>/<prefixlen>`` (e.g. 192.0.2.64/26 gives
        ``64/26.2.0.192.in-addr.arpa``).

        Raises:
            ValueError: for an IPv6 prefix that is not a multiple of 4.
            NotImplementedError: for any other network type.
        """
        if isinstance(net, IPv4Network):
            parts = str(net.network_address).split('.')
            out = parts[:net.prefixlen // 8]
            if net.prefixlen % 8 != 0:
                out.append(parts[len(out)] + '/' + str(net.prefixlen))
            return '.'.join(reversed(out)) + '.in-addr.arpa'
        elif isinstance(net, IPv6Network):
            if net.prefixlen % 4 != 0:
                raise ValueError(f'IPv6 reverse zone prefix must be a multiple of 4: {net}')
            nibbles = net.network_address.exploded.replace(':', "")
            nibbles = nibbles[:net.prefixlen // 4]
            return '.'.join(reversed(nibbles)) + '.ip6.arpa'
        else:
            raise NotImplementedError()

    def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
        """Record a request for a PTR record from *addr* to *ptr_to*."""
        if isinstance(addr, IPv4Address):
            self.ipv4_reverse[addr].append(ptr_to)
        else:
            self.ipv6_reverse[addr].append(ptr_to)

    def dump_reverse(self) -> None:
        """Print every requested reverse mapping: address, tab, comma-separated names."""
        print('### Requests for reverse mappings ###')
        # Each family is sorted separately: IPv4 and IPv6 addresses are not comparable.
        for ipa4, names4 in sorted(self.ipv4_reverse.items()):
            print(f'{ipa4}\t' + ', '.join(map(str, names4)))
        for ipa6, names6 in sorted(self.ipv6_reverse.items()):
            print(f'{ipa6}\t' + ', '.join(map(str, names6)))

    def fill_reverse(self) -> None:
        """Add PTR records to every primary reverse zone.

        Each requested address that falls inside a zone's network gets a PTR
        in that zone. An address covered by several such zones gets a PTR in
        each of them.
        """
        for z in self.zones.values():
            if isinstance(z, NscZonePrimary) and z.reverse_for is not None:
                if isinstance(z.reverse_for, IPv4Network):
                    for addr4, ptr_list in self.ipv4_reverse.items():
                        if addr4 in z.reverse_for:
                            for ptr_to in ptr_list:
                                z._add_ipv4_reverse(addr4, ptr_to)
                else:
                    for addr6, ptr_list in self.ipv6_reverse.items():
                        if addr6 in z.reverse_for:
                            for ptr_to in ptr_list:
                                z._add_ipv6_reverse(addr6, ptr_to)

    def get_zones(self) -> List[NscZone]:
        """Return all zones ordered by zone name."""
        return [self.zones[k] for k in sorted(self.zones)]

    def process(self) -> None:
        """Fill reverse zones, then run each zone's process() in name order."""
        self.fill_reverse()
        for z in self.get_zones():
            z.process()