mj.ucw.cz Git - pynsc.git/blob - nsconfig/core.py
Unify processing of arguments to record-generating methods
[pynsc.git] / nsconfig / core.py
1 from collections import defaultdict
2 from datetime import datetime, timedelta
3 import dns.name
4 from dns.name import Name
5 from dns.node import Node
6 from dns.rdata import Rdata
7 from dns.rdataclass import RdataClass
8 from dns.rdatatype import RdataType
9 import dns.rdtypes.ANY.MX
10 import dns.rdtypes.ANY.NS
11 import dns.rdtypes.ANY.PTR
12 import dns.rdtypes.ANY.SOA
13 import dns.rdtypes.ANY.TXT
14 import dns.rdtypes.IN.A
15 import dns.rdtypes.IN.AAAA
16 from dns.zone import Zone
17 from enum import Enum, auto
18 import hashlib
19 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
20 import json
21 from pathlib import Path
22 import socket
23 import sys
24 from typing import Optional, Dict, List, Self, Tuple, DefaultDict, TextIO, TYPE_CHECKING
25
26 from nsconfig.util import flatten_list
27
28
29 if TYPE_CHECKING:
30     from nsconfig.daemon import NscDaemon
31
32
33 IPAddress = IPv4Address | IPv6Address
34 IPNetwork = IPv4Network | IPv6Network
35 IPAddr = str | IPAddress | List[str | IPAddress]
36
37
class NscNode:
    """One owner name inside a primary zone, with chainable record adders.

    Every record-adding method returns ``self``, so calls can be chained,
    e.g. ``zone['www'].ttl(hours=1).A('192.0.2.1').TXT('hello')``.  Records
    are added with the TTL most recently set by :meth:`ttl` (initially the
    zone's minimum TTL).
    """

    nsc_zone: 'NscZonePrimary'
    name: str
    node: Node
    _ttl: int

    def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None:
        self.nsc_zone = nsc_zone
        self.name = name
        # create=True: merely looking a name up makes the node exist in the zone.
        self.node = nsc_zone.zone.find_node(name, create=True)
        self._ttl = nsc_zone._min_ttl

    def ttl(self, *args, **kwargs) -> Self:
        """Set the TTL for records added from now on.

        Arguments are passed to :class:`datetime.timedelta` (e.g.
        ``ttl(hours=1)``); with no arguments the TTL is reset to the zone's
        minimum TTL.
        """
        if not args and not kwargs:
            self._ttl = self.nsc_zone._min_ttl
        else:
            self._ttl = int(timedelta(*args, **kwargs).total_seconds())
        return self

    def _add(self, rec: Rdata) -> None:
        """Add *rec* to this node's rdataset of the same class/type, using the current TTL."""
        rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
        rds.add(rec, ttl=self._ttl)

    def _parse_addr(self, addr: IPAddr | str) -> IPAddress:
        """Return *addr* as an IPv4Address/IPv6Address, parsing strings.

        Raises:
            ValueError: if *addr* is neither a string nor an address object,
                or the string is not a valid IP address.
        """
        if isinstance(addr, IPv4Address) or isinstance(addr, IPv6Address):
            return addr
        elif isinstance(addr, str):
            return ip_address(addr)
        else:
            raise ValueError('Cannot parse IP address')

    def _parse_name(self, name: str) -> Name:
        """Convert a user-supplied target name to a dnspython Name.

        Names containing a dot are taken as absolute; names without a dot
        are kept relative (no origin is appended).
        """
        # FIXME: Names with escaped dots
        if '.' in name:
            return dns.name.from_text(name)
        else:
            return dns.name.from_text(name, origin=None)

    def _full_name(self) -> Name:
        """Return the absolute name of this node.

        ``""`` and ``"@"`` denote the zone apex; a name that already ends with
        a dot is absolute and is used unchanged.  (Plain string concatenation
        with the zone name would create an empty label in both cases.)
        """
        return dns.name.from_text(self.name, origin=self.nsc_zone.zone.origin)

    @staticmethod
    def _split_txt(txt: str | bytes) -> Tuple[bytes, ...]:
        """Split a TXT value into character-strings of at most 255 bytes.

        dnspython limits each TXT character-string to 255 bytes, so longer
        values (e.g. DKIM keys) are stored as consecutive chunks of one record.
        """
        data = txt.encode() if isinstance(txt, str) else bytes(txt)
        if not data:
            return (b'',)
        return tuple(data[i:i + 255] for i in range(0, len(data), 255))

    def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
        """Add A or AAAA records (chosen by address family) for *addrs*.

        Arguments may be strings, address objects, or lists of them.  With
        *reverse*, each address is also registered for a PTR record pointing
        back to this node.
        """
        ptr_to = self._full_name() if reverse else None
        for a in map(self._parse_addr, flatten_list(addrs)):
            if isinstance(a, IPv4Address):
                self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
            else:
                self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
            if ptr_to is not None:
                self.nsc_zone.nsc._add_reverse_mapping(a, ptr_to)
        return self

    def MX(self, pri: int, name: str) -> Self:
        """Add an MX record with preference *pri* pointing to *name*."""
        self._add(
            dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, self._parse_name(name))
        )
        return self

    def NS(self, *names: str | List[str]) -> Self:
        """Add one NS record per given name (lists are flattened)."""
        for name in map(self._parse_name, flatten_list(names)):
            self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
        return self

    def TXT(self, *text: str | List[str]) -> Self:
        """Add one TXT record per given string (lists are flattened).

        Values longer than 255 bytes are split into several character-strings
        within the same record.
        """
        for txt in flatten_list(text):
            chunks = self._split_txt(txt)
            self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, chunks))
        return self

    def PTR(self, target: Name | str) -> Self:
        """Add a PTR record pointing to *target*."""
        self._add(dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target))
        return self

    def generic(self, typ: str, text: str) -> Self:
        """Add a record of type *typ* given in master-file text form."""
        self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
        return self
109
110
111 class NscZoneConfig:
112     admin_email: str
113     refresh: timedelta
114     retry: timedelta
115     expire: timedelta
116     min_ttl: timedelta
117     origin_server: str
118
119     default_config: Optional['NscZoneConfig'] = None
120
121     def __init__(self,
122                  admin_email: Optional[str] = None,
123                  refresh: Optional[timedelta] = None,
124                  retry: Optional[timedelta] = None,
125                  expire: Optional[timedelta] = None,
126                  min_ttl: Optional[timedelta] = None,
127                  origin_server: Optional[str] = None,
128                  inherit_config: Optional['NscZoneConfig'] = None,
129                  ) -> None:
130         if inherit_config is None:
131             inherit_config = NscZoneConfig.default_config or self   # to satisfy the type checker
132         self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
133         self.refresh = refresh if refresh is not None else inherit_config.refresh
134         self.retry = retry if retry is not None else inherit_config.retry
135         self.expire = expire if expire is not None else inherit_config.expire
136         self.min_ttl = min_ttl if min_ttl is not None else inherit_config.min_ttl
137         self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
138
139     def finalize(self) -> Self:
140         if not self.origin_server:
141             self.origin_server = socket.getfqdn()
142         if not self.admin_email:
143             self.admin_email = f'hostmaster@{self.origin_server}'
144         return self
145
146
# Built-in defaults inherited by every zone config that does not override
# them.  The empty strings are placeholders that finalize() replaces with
# this host's FQDN and hostmaster@<origin server>.
NscZoneConfig.default_config = NscZoneConfig(
    origin_server="",
    admin_email="",
    min_ttl=timedelta(days=1),
    refresh=timedelta(hours=8),
    retry=timedelta(hours=2),
    expire=timedelta(days=14),
)
155
156
class NscZoneState:
    """Per-zone state persisted between runs: the serial and a content hash.

    Stored as a JSON object with the keys ``serial`` and ``hash``.
    """

    serial: int
    hash: str

    def __init__(self) -> None:
        self.serial = 0
        self.hash = 'none'

    def load(self, file: Path) -> None:
        """Load state from *file*; a missing file keeps the defaults.

        Keys absent from the file also keep their current values.

        Raises:
            ValueError: if the file is not valid JSON or does not contain
                a JSON object.
        """
        try:
            with open(file, encoding='utf-8') as f:
                js = json.load(f)
        except FileNotFoundError:
            return
        # Explicit check instead of ``assert``, which ``python -O`` removes.
        if not isinstance(js, dict):
            raise ValueError(f'{file}: zone state must be a JSON object')
        if 'serial' in js:
            self.serial = js['serial']
        if 'hash' in js:
            self.hash = js['hash']

    def save(self, file: Path) -> None:
        """Write the state to *file*.

        The data goes to a sibling ``.new`` file first, which is then renamed
        over *file*, so a reader never sees a half-written state file.
        """
        new_file = Path(str(file) + '.new')
        with open(new_file, 'w', encoding='utf-8') as f:
            js = {
                'serial': self.serial,
                'hash': self.hash,
            }
            json.dump(js, f, indent=4, sort_keys=True)
        new_file.replace(file)
186
187
class ZoneType(Enum):
    """Role this configuration plays for a zone."""
    primary = 1      # records are generated by this configuration
    secondary = 2    # data comes from a primary server
191
192
class NscZone:
    """Common base for the primary and secondary zones managed by ``Nsc``."""

    nsc: 'Nsc'
    name: str
    safe_name: str                          # For use in file names
    zone_type: ZoneType
    reverse_for: Optional[IPNetwork]

    def __init__(self,
                 nsc: 'Nsc',
                 name: str,
                 reverse_for: Optional[IPNetwork],
                 **kwargs) -> None:
        """Store the zone identity and build its finalized config.

        Remaining keyword arguments are passed to ``NscZoneConfig``.
        """
        self.nsc, self.name = nsc, name
        # Classless reverse zone names (e.g. "64/26.2.0.192.in-addr.arpa")
        # contain '/', which cannot appear in a file name.
        self.safe_name = '@'.join(name.split('/'))
        self.config = NscZoneConfig(**kwargs).finalize()
        self.reverse_for = reverse_for

    def process(self) -> None:
        """Per-zone processing hook; the base class has nothing to do."""
        return None
213
214
class NscZonePrimary(NscZone):
    """A primary zone whose records are built here and written to a zone file.

    Records are added through :class:`NscNode` objects (see :meth:`n`,
    ``zone[name]`` and :meth:`host`).  The SOA serial is chosen by
    :meth:`gen_serial` from a hash of the zone contents and the state saved
    by the previous run.
    """

    zone: Zone
    _min_ttl: int               # zone minimum TTL in seconds (default record TTL)
    zone_file: Path
    state_file: Path
    state: NscZoneState         # state computed in this run
    prev_state: NscZoneState    # state loaded from state_file

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.zone_type = ZoneType.primary
        self.zone_file = self.nsc.zone_dir / self.safe_name
        self.state_file = self.nsc.state_dir / (self.safe_name + '.json')

        self.state = NscZoneState()
        self.prev_state = NscZoneState()
        self.prev_state.load(self.state_file)

        self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN)
        self._min_ttl = int(self.config.min_ttl.total_seconds())
        self.update_soa()

    @staticmethod
    def _email_to_rname(email: str) -> str:
        """Convert an e-mail address to the textual SOA RNAME form.

        The last ``@`` becomes a label separator.  Dots (and backslashes) in
        the local part are escaped so the local part stays a single label,
        as the RFC 1035 mailbox encoding requires.  A string without ``@`` is
        returned unchanged.
        """
        local, at, domain = email.rpartition('@')
        if not at:
            return email
        return local.replace('\\', '\\\\').replace('.', '\\.') + '.' + domain

    def update_soa(self) -> None:
        """Replace the apex SOA record with one built from the config and current serial."""
        conf = self.config
        soa = dns.rdtypes.ANY.SOA.SOA(
            RdataClass.IN, RdataType.SOA,
            mname=conf.origin_server,
            rname=self._email_to_rname(conf.admin_email),
            serial=self.state.serial,
            refresh=int(conf.refresh.total_seconds()),
            retry=int(conf.retry.total_seconds()),
            expire=int(conf.expire.total_seconds()),
            minimum=int(conf.min_ttl.total_seconds()),
        )
        self.zone.delete_rdataset("", RdataType.SOA)
        self[""]._add(soa)

    def n(self, name: str) -> NscNode:
        """Return a node handle for *name* (created if missing)."""
        return NscNode(self, name)

    def __getitem__(self, name: str) -> NscNode:
        """Same as :meth:`n`: ``zone['www']``."""
        return NscNode(self, name)

    def host(self, name: str, *args, reverse: bool = True) -> NscNode:
        """Return node *name* after adding A/AAAA records for the addresses in *args*."""
        n = NscNode(self, name)
        n.A(*args, reverse=reverse)
        return n

    def dump(self, file: Optional[TextIO] = None) -> None:
        """Write the zone in master-file format to *file* (default: stdout).

        Consecutive records of one owner print the owner only once, and TTLs
        equal to the zone minimum are omitted.  A ``$TTL`` directive is
        emitted first: without it, an omitted TTL means "same as the
        previous record" (RFC 1035), so records following one with a
        custom TTL would silently inherit that TTL.
        """
        # Could use self.zone.to_file(sys.stdout), but we want better formatting
        file = file or sys.stdout
        file.write(f'; Zone file for {self.name}\n\n')
        file.write(f'$TTL {self._min_ttl}\n')
        last_name = None
        for name, ttl, rec in self.zone.iterate_rdatas():
            print_name = "" if name == last_name else name
            print_ttl = "" if ttl == self._min_ttl else ttl
            file.write(f'{print_name}\t{print_ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n')
            last_name = name

    def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
        """Add a PTR for *addr* -> *ptr_to*, named relative to this reverse zone."""
        assert isinstance(self.reverse_for, IPv4Network)
        parts = str(addr).split('.')
        # Drop the octets already covered by the zone name.
        parts = parts[self.reverse_for.prefixlen // 8:]
        name = '.'.join(reversed(parts))
        self.n(name).PTR(ptr_to)

    def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
        """Add a PTR for *addr* -> *ptr_to*, named by nibbles relative to this zone."""
        assert isinstance(self.reverse_for, IPv6Network)
        parts = addr.exploded.replace(':', "")
        # Drop the nibbles already covered by the zone name.
        parts = parts[self.reverse_for.prefixlen // 4:]
        name = '.'.join(reversed(parts))
        self.n(name).PTR(ptr_to)

    def gen_hash(self) -> None:
        """Store a short SHA-1 digest of all records in ``state.hash``."""
        sha = hashlib.sha1()
        for name, ttl, rec in self.zone.iterate_rdatas():
            text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
            sha.update(text.encode('us-ascii'))
        self.state.hash = sha.hexdigest()[:16]

    def gen_serial(self) -> None:
        """Choose ``state.serial`` in the YYYYMMDDnn format.

        The previous serial is kept when the content hash is unchanged;
        otherwise it is bumped to at least today's ``YYYYMMDD01`` (based on
        ``Nsc.start_time``).  A warning is printed when more than 99 changes
        in one day push the serial past the day's range.
        """
        prev = self.prev_state.serial
        if self.state.hash == self.prev_state.hash and prev > 0:
            self.state.serial = self.prev_state.serial
        else:
            base = int(self.nsc.start_time.strftime('%Y%m%d00'))
            if prev <= base:
                self.state.serial = base + 1
            else:
                self.state.serial = prev + 1
                if prev >= base + 99:
                    print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')

    def process(self) -> None:
        """Compute the content hash and the resulting serial."""
        if self.zone_type == ZoneType.primary:
            self.gen_hash()
            self.gen_serial()

    def write_zone(self) -> None:
        """Refresh the SOA (new serial) and write the zone file via a ``.new`` file."""
        self.update_soa()
        new_file = Path(str(self.zone_file) + '.new')
        with open(new_file, 'w') as f:
            self.dump(file=f)
        new_file.replace(self.zone_file)

    def write_state(self) -> None:
        """Persist the current state to ``state_file``."""
        self.state.save(self.state_file)

    def is_changed(self) -> bool:
        """Return True if the serial differs from the previous run's."""
        return self.state.serial != self.prev_state.serial
328
329
class NscZoneSecondary(NscZone):
    """A zone whose data comes from a primary server (see ``add_zone``'s follow_primary)."""

    primary_server: IPAddress
    secondary_file: Path

    def __init__(self, *args, primary_server: IPAddress | str, **kwargs) -> None:
        """Create the zone; other arguments are passed on to ``NscZone``.

        Args:
            primary_server: address of the primary server.  This is a required
                keyword argument (previously a missing value silently became
                the ``IPAddress`` type object).  A string is parsed with
                ``ipaddress.ip_address``.

        Raises:
            ValueError: if *primary_server* is a string that is not a valid
                IP address.
        """
        if isinstance(primary_server, str):
            primary_server = ip_address(primary_server)
        super().__init__(*args, **kwargs)
        self.zone_type = ZoneType.secondary
        self.primary_server = primary_server
        self.secondary_file = self.nsc.secondary_dir / self.safe_name
339
340
341 class Nsc:
342     start_time: datetime
343     zones: Dict[str, NscZone]
344     default_zone_config: NscZoneConfig
345     ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
346     ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
347     root_dir: Path
348     state_dir: Path
349     zone_dir: Path
350     secondary_dir: Path
351     daemon: 'NscDaemon'  # Set by DaemonConfig class
352
353     def __init__(self,
354                  directory: str = '.',
355                  daemon: Optional['NscDaemon'] = None,
356                  **kwargs) -> None:
357         self.start_time = datetime.now()
358         self.zones = {}
359         self.default_zone_config = NscZoneConfig(**kwargs)
360         self.ipv4_reverse = defaultdict(list)
361         self.ipv6_reverse = defaultdict(list)
362
363         self.root_dir = Path(directory)
364         self.state_dir = self.root_dir / 'state'
365         self.state_dir.mkdir(parents=True, exist_ok=True)
366         self.zone_dir = self.root_dir / 'zone'
367         self.zone_dir.mkdir(parents=True, exist_ok=True)
368         self.secondary_dir = self.root_dir / 'secondary'
369         self.secondary_dir.mkdir(parents=True, exist_ok=True)
370
371         if daemon is None:
372             from nsconfig.daemon import NscDaemonNull
373             daemon = NscDaemonNull()
374         self.daemon = daemon
375         daemon.setup(self)
376
377     def add_zone(self,
378                  name: Optional[str] = None,
379                  reverse_for: str | IPNetwork | None = None,
380                  follow_primary: str | IPAddress | None = None,
381                  inherit_config: Optional[NscZoneConfig] = None,
382                  **kwargs) -> Zone:
383         if inherit_config is None:
384             inherit_config = self.default_zone_config
385
386         if reverse_for is not None:
387             if isinstance(reverse_for, str):
388                 reverse_for = ip_network(reverse_for, strict=True)
389             name = name or self._reverse_zone_name(reverse_for)
390         assert name is not None
391         assert name not in self.zones
392
393         z: NscZone
394         if follow_primary is None:
395             z = NscZonePrimary(self, name, reverse_for=reverse_for, inherit_config=inherit_config, **kwargs)
396         else:
397             if isinstance(follow_primary, str):
398                 follow_primary = ip_address(follow_primary)
399             z = NscZoneSecondary(self, name, reverse_for=reverse_for, primary_server=follow_primary, inherit_config=inherit_config, **kwargs)
400
401         self.zones[name] = z
402         return z
403
404     def _reverse_zone_name(self, net: IPNetwork) -> str:
405         if isinstance(net, IPv4Network):
406             parts = str(net.network_address).split('.')
407             out = parts[:net.prefixlen // 8]
408             if net.prefixlen % 8 != 0:
409                 out.append(parts[len(out)] + '/' + str(net.prefixlen))
410             return '.'.join(reversed(out)) + '.in-addr.arpa'
411         elif isinstance(net, IPv6Network):
412             assert net.prefixlen % 4 == 0
413             nibbles = net.network_address.exploded.replace(':', "")
414             nibbles = nibbles[:net.prefixlen // 4]
415             return '.'.join(reversed(nibbles)) + '.ip6.arpa'
416         else:
417             raise NotImplementedError()
418
419     def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
420         if isinstance(addr, IPv4Address):
421             self.ipv4_reverse[addr].append(ptr_to)
422         else:
423             self.ipv6_reverse[addr].append(ptr_to)
424
425     def dump_reverse(self) -> None:
426         print('### Requests for reverse mappings ###')
427         for ipa4, name in sorted(self.ipv4_reverse.items()):
428             print(f'{ipa4}\t{name}')
429         for ipa6, name in sorted(self.ipv6_reverse.items()):
430             print(f'{ipa6}\t{name}')
431
432     def fill_reverse(self) -> None:
433         for z in self.zones.values():
434             if isinstance(z, NscZonePrimary) and z.reverse_for is not None:
435                 if isinstance(z.reverse_for, IPv4Network):
436                     for addr4, ptr_list in self.ipv4_reverse.items():
437                         if addr4 in z.reverse_for:
438                             for ptr_to in ptr_list:
439                                 z._add_ipv4_reverse(addr4, ptr_to)
440                 else:
441                     for addr6, ptr_list in self.ipv6_reverse.items():
442                         if addr6 in z.reverse_for:
443                             for ptr_to in ptr_list:
444                                 z._add_ipv6_reverse(addr6, ptr_to)
445
446     def get_zones(self) -> List[NscZone]:
447         return [self.zones[k] for k in sorted(self.zones.keys())]
448
449     def process(self) -> None:
450         self.fill_reverse()
451         for z in self.get_zones():
452             z.process()