]> mj.ucw.cz Git - pynsc.git/blob - nsconfig/core.py
Updating zones
[pynsc.git] / nsconfig / core.py
1 from collections import defaultdict
2 from datetime import datetime, timedelta
3 import dns.name
4 from dns.name import Name
5 from dns.node import Node
6 from dns.rdata import Rdata
7 from dns.rdataclass import RdataClass
8 from dns.rdatatype import RdataType
9 import dns.rdtypes.ANY.MX
10 import dns.rdtypes.ANY.NS
11 import dns.rdtypes.ANY.PTR
12 import dns.rdtypes.ANY.SOA
13 import dns.rdtypes.ANY.TXT
14 import dns.rdtypes.IN.A
15 import dns.rdtypes.IN.AAAA
16 from dns.zone import Zone
17 import hashlib
18 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
19 import json
20 from pathlib import Path
21 import socket
22 import sys
23 from typing import Optional, Dict, List, Self, Tuple, DefaultDict, TextIO
24
25
26 IPAddress = IPv4Address | IPv6Address
27 IPNetwork = IPv4Network | IPv6Network
28 IPAddr = str | IPAddress | List[str | IPAddress]
29
30
31 class NscNode:
32     nsc_zone: 'NscZone'
33     name: str
34     node: Node
35     _ttl: int
36
37     def __init__(self, nsc_zone: 'NscZone', name: str) -> None:
38         self.nsc_zone = nsc_zone
39         self.name = name
40         self.node = nsc_zone.zone.find_node(name, create=True)
41         self._ttl = nsc_zone._min_ttl
42
43     def ttl(self, *args, **kwargs) -> Self:
44         if not args and not kwargs:
45             self._ttl = self.nsc_zone._min_ttl
46         else:
47             self._ttl = int(timedelta(*args, **kwargs).total_seconds())
48         return self
49
50     def _add(self, rec: Rdata) -> None:
51         rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
52         rds.add(rec, ttl=self._ttl)
53
54     def _parse_addrs(self, addrs: Tuple[IPAddr, ...]) -> List[IPAddress]:
55         out = []
56         for a in addrs:
57             if not isinstance(a, list):
58                 a = [a]
59             for b in a:
60                 if isinstance(b, IPv4Address) or isinstance(b, IPv6Address):
61                     out.append(b)
62                 else:
63                     out.append(ip_address(b))
64         return out
65
66     def _parse_name(self, name: str) -> Name:
67         # FIXME: Names with escaped dots
68         if '.' in name:
69             return dns.name.from_text(name)
70         else:
71             return dns.name.from_text(name, origin=None)
72
73     def _parse_names(self, names: str | List[str]) -> List[Name]:
74         if isinstance(names, str):
75             return [self._parse_name(names)]
76         else:
77             return [self._parse_name(n) for n in names]
78
79     def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
80         for a in self._parse_addrs(addrs):
81             if isinstance(a, IPv4Address):
82                 self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
83             else:
84                 self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
85             if reverse:
86                 self.nsc_zone.nsc._add_reverse_mapping(a, dns.name.from_text(self.name + '.' + self.nsc_zone.name))
87         return self
88
89     def MX(self, pri: int, name: str) -> Self:
90         self._add(
91             dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, self._parse_name(name))
92         )
93         return self
94
95     def NS(self, names: str | List[str]) -> Self:
96         # FIXME: Variadic?
97         for name in self._parse_names(names):
98             self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
99         return self
100
101     def TXT(self, text: str) -> Self:
102         self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, text))
103         return self
104
105     def PTR(self, target: Name | str) -> Self:
106         self._add(dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target))
107         return self
108
109     def generic(self, typ: str, text: str) -> Self:
110         self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
111         return self
112
113
114 class NscZoneConfig:
115     admin_email: str
116     refresh: timedelta
117     retry: timedelta
118     expire: timedelta
119     min_ttl: timedelta
120     origin_server: str
121
122     default_config: Optional['NscZoneConfig'] = None
123
124     def __init__(self,
125                  admin_email: Optional[str] = None,
126                  refresh: Optional[timedelta] = None,
127                  retry: Optional[timedelta] = None,
128                  expire: Optional[timedelta] = None,
129                  min_ttl: Optional[timedelta] = None,
130                  origin_server: Optional[str] = None,
131                  inherit_config: Optional['NscZoneConfig'] = None,
132                  ) -> None:
133         if inherit_config is None:
134             inherit_config = NscZoneConfig.default_config or self   # to satisfy the type checker
135         self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
136         self.refresh = refresh if refresh is not None else inherit_config.refresh
137         self.retry = retry if retry is not None else inherit_config.retry
138         self.expire = expire if expire is not None else inherit_config.expire
139         self.min_ttl = min_ttl if min_ttl is not None else inherit_config.min_ttl
140         self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
141
142     def finalize(self) -> Self:
143         if not self.origin_server:
144             self.origin_server = socket.getfqdn()
145         if not self.admin_email:
146             self.admin_email = f'hostmaster@{self.origin_server}'
147         return self
148
149
150 NscZoneConfig.default_config = NscZoneConfig(
151     admin_email="",
152     refresh=timedelta(hours=8),
153     retry=timedelta(hours=2),
154     expire=timedelta(days=14),
155     min_ttl=timedelta(days=1),
156     origin_server="",
157 )
158
159
160 class NscZoneState:
161     serial: int
162     hash: str
163
164     def __init__(self) -> None:
165         self.serial = 0
166         self.hash = 'none'
167
168     def load(self, file: Path) -> None:
169         try:
170             with open(file) as f:
171                 js = json.load(f)
172                 assert isinstance(js, dict)
173                 if 'serial' in js:
174                     self.serial = js['serial']
175                 if 'hash' in js:
176                     self.hash = js['hash']
177         except FileNotFoundError:
178             pass
179
180     def save(self, file: Path) -> None:
181         new_file = Path(str(file) + '.new')
182         with open(new_file, 'w') as f:
183             js = {
184                 'serial': self.serial,
185                 'hash': self.hash,
186             }
187             json.dump(js, f, indent=4, sort_keys=True)
188         new_file.replace(file)
189
190
191 class NscZone:
192     nsc: 'Nsc'
193     name: str
194     safe_name: str      # For use in file names
195     zone: Zone
196     _min_ttl: int
197     reverse_for: Optional[IPNetwork]
198     zone_file: Path
199     state_file: Path
200     state: NscZoneState
201     prev_state: NscZoneState
202
203     def __init__(self, nsc: 'Nsc', name: str, reverse_for: Optional[IPNetwork] = None, **kwargs) -> None:
204         self.nsc = nsc
205         self.name = name
206         self.safe_name = name.replace('/', '@')
207         self.config = NscZoneConfig(**kwargs).finalize()
208         self.zone = dns.zone.Zone(origin=name, rdclass=RdataClass.IN)
209         self._min_ttl = int(self.config.min_ttl.total_seconds())
210         self.reverse_for = reverse_for
211
212         self.zone_file = nsc.zone_dir / self.safe_name
213         self.state_file = nsc.state_dir / (self.safe_name + '.json')
214         self.state = NscZoneState()
215         self.prev_state = NscZoneState()
216         self.prev_state.load(self.state_file)
217
218         self.update_soa()
219
220     def update_soa(self) -> None:
221         conf = self.config
222         soa = dns.rdtypes.ANY.SOA.SOA(
223             RdataClass.IN, RdataType.SOA,
224             mname=conf.origin_server,
225             rname=conf.admin_email.replace('@', '.'),   # FIXME: names with dots
226             serial=self.state.serial,
227             refresh=int(conf.refresh.total_seconds()),
228             retry=int(conf.retry.total_seconds()),
229             expire=int(conf.expire.total_seconds()),
230             minimum=int(conf.min_ttl.total_seconds()),
231         )
232         self.zone.delete_rdataset("", RdataType.SOA)
233         self[""]._add(soa)
234
235     def n(self, name: str) -> NscNode:
236         return NscNode(self, name)
237
238     def __getitem__(self, name: str) -> NscNode:
239         return NscNode(self, name)
240
241     def host(self, name: str, *args, reverse: bool = True) -> NscNode:
242         n = NscNode(self, name)
243         n.A(*args, reverse=reverse)
244         return n
245
246     def dump(self, file: Optional[TextIO] = None) -> None:
247         # Could use self.zone.to_file(sys.stdout), but we want better formatting
248         file = file or sys.stdout
249         file.write(f'; Zone file for {self.name}\n\n')
250         last_name = None
251         for name, ttl, rec in self.zone.iterate_rdatas():
252             if name == last_name:
253                 print_name = ""
254             else:
255                 print_name = name
256             file.write(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
257             last_name = name
258
259     def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
260         # Called only for addresses from this reverse network
261         assert self.reverse_for is not None
262         parts = str(addr).split('.')
263         parts = parts[self.reverse_for.prefixlen // 8:]
264         name = '.'.join(reversed(parts))
265         self.n(name).PTR(ptr_to)
266
267     def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
268         # Called only for addresses from this reverse network
269         assert self.reverse_for is not None
270         parts = addr.exploded.replace(':', "")
271         parts = parts[self.reverse_for.prefixlen // 4:]
272         name = '.'.join(reversed(parts))
273         self.n(name).PTR(ptr_to)
274
275     def gen_hash(self) -> None:
276         sha = hashlib.sha1()
277         for name, ttl, rec in self.zone.iterate_rdatas():
278             text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
279             sha.update(text.encode('us-ascii'))
280         self.state.hash = sha.hexdigest()[:16]
281
282     def gen_serial(self) -> None:
283         prev = self.prev_state.serial
284         if self.state.hash == self.prev_state.hash and prev > 0:
285             self.state.serial = self.prev_state.serial
286         else:
287             base = int(self.nsc.start_time.strftime('%Y%m%d00'))
288             if prev <= base:
289                 self.state.serial = base + 1
290             else:
291                 self.state.serial = prev + 1
292                 if prev >= base + 99:
293                     print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')
294
295     def process(self) -> None:
296         self.gen_hash()
297         self.gen_serial()
298
299     def write_zone(self) -> None:
300         self.update_soa()
301         new_file = Path(str(self.zone_file) + '.new')
302         with open(new_file, 'w') as f:
303             self.dump(file=f)
304         new_file.replace(self.zone_file)
305
306     def write_state(self) -> None:
307         self.state.save(self.state_file)
308
309
310 class Nsc:
311     start_time: datetime
312     zones: Dict[str, NscZone]
313     default_zone_config: NscZoneConfig
314     ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
315     ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
316     root_dir: Path
317     state_dir: Path
318     zone_dir: Path
319
320     def __init__(self, directory: str = '.', **kwargs) -> None:
321         self.start_time = datetime.now()
322         self.zones = {}
323         self.default_zone_config = NscZoneConfig(**kwargs)
324         self.ipv4_reverse = defaultdict(list)
325         self.ipv6_reverse = defaultdict(list)
326
327         self.root_dir = Path(directory)
328         self.state_dir = self.root_dir / 'state'
329         self.state_dir.mkdir(parents=True, exist_ok=True)
330         self.zone_dir = self.root_dir / 'zone'
331         self.zone_dir.mkdir(parents=True, exist_ok=True)
332
333     def add_zone(self, *args, inherit_config: Optional[NscZoneConfig] = None, **kwargs) -> Zone:
334         if inherit_config is None:
335             inherit_config = self.default_zone_config
336         z = NscZone(self, *args, inherit_config=inherit_config, **kwargs)
337         assert z.name not in self.zones
338         self.zones[z.name] = z
339         return z
340
341     def add_reverse_zone(self, net: str | IPNetwork, name: Optional[str] = None, **kwargs) -> Zone:
342         if not (isinstance(net, IPv4Network) or isinstance(net, IPv6Network)):
343             net = ip_network(net, strict=True)
344         name = name or self._reverse_zone_name(net)
345         return self.add_zone(name, reverse_for=net, **kwargs)
346
347     def _reverse_zone_name(self, net: IPNetwork) -> str:
348         if isinstance(net, IPv4Network):
349             parts = str(net.network_address).split('.')
350             out = parts[:net.prefixlen // 8]
351             if net.prefixlen % 8 != 0:
352                 out.append(parts[len(out)] + '/' + str(net.prefixlen))
353             return '.'.join(reversed(out)) + '.in-addr.arpa'
354         elif isinstance(net, IPv6Network):
355             assert net.prefixlen % 4 == 0
356             nibbles = net.network_address.exploded.replace(':', "")
357             nibbles = nibbles[:net.prefixlen // 4]
358             return '.'.join(reversed(nibbles)) + '.ip6.arpa'
359         else:
360             raise NotImplementedError()
361
362     def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
363         if isinstance(addr, IPv4Address):
364             self.ipv4_reverse[addr].append(ptr_to)
365         else:
366             self.ipv6_reverse[addr].append(ptr_to)
367
368     def dump_reverse(self) -> None:
369         print('### Requests for reverse mappings ###')
370         for ipa4, name in sorted(self.ipv4_reverse.items()):
371             print(f'{ipa4}\t{name}')
372         for ipa6, name in sorted(self.ipv6_reverse.items()):
373             print(f'{ipa6}\t{name}')
374
375     def fill_reverse(self) -> None:
376         for z in self.zones.values():
377             if z.reverse_for is not None:
378                 if isinstance(z.reverse_for, IPv4Network):
379                     for addr4, ptr_list in self.ipv4_reverse.items():
380                         if addr4 in z.reverse_for:
381                             for ptr_to in ptr_list:
382                                 z._add_ipv4_reverse(addr4, ptr_to)
383                 else:
384                     for addr6, ptr_list in self.ipv6_reverse.items():
385                         if addr6 in z.reverse_for:
386                             for ptr_to in ptr_list:
387                                 z._add_ipv6_reverse(addr6, ptr_to)
388
389     def get_zones(self) -> List[NscZone]:
390         return [self.zones[k] for k in sorted(self.zones.keys())]
391
392     def process(self) -> None:
393         self.fill_reverse()
394         for z in self.get_zones():
395             z.process()