]> mj.ucw.cz Git - pynsc.git/blob - nsc.py
Reverse mappings
[pynsc.git] / nsc.py
1 #!/usr/bin/env python3
2
3 from collections import defaultdict
4 from dataclasses import dataclass
5 from datetime import timedelta
6 import dns.name
7 from dns.name import Name
8 from dns.node import Node
9 from dns.rdata import Rdata
10 from dns.rdataclass import RdataClass
11 from dns.rdatatype import RdataType
12 import dns.rdtypes.ANY.MX
13 import dns.rdtypes.ANY.NS
14 import dns.rdtypes.ANY.PTR
15 import dns.rdtypes.ANY.SOA
16 import dns.rdtypes.ANY.TXT
17 import dns.rdtypes.IN.A
18 import dns.rdtypes.IN.AAAA
19 from dns.zone import Zone
20 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
21 import socket
22 import sys
23 from typing import Optional, Dict, List, Self, Tuple, DefaultDict
24
25
26 IPAddress = IPv4Address | IPv6Address
27 IPNetwork = IPv4Network | IPv6Network
28 IPAddr = str | IPAddress | List[str | IPAddress]
29
30
class NscNode:
    """Builder for the records of one owner name inside an NscZone.

    Every public record-adding method returns ``self`` so calls can be
    chained, e.g. ``zone['www'].A('192.0.2.1').TXT('hello')``.
    """

    nsc_zone: 'NscZone'
    name: str       # owner name as given by the caller, relative to the zone ("" = apex)
    node: Node
    _ttl: int       # TTL in seconds applied to records added from now on

    def __init__(self, nsc_zone: 'NscZone', name: str) -> None:
        self.nsc_zone = nsc_zone
        self.name = name
        self.node = nsc_zone.zone.find_node(name, create=True)
        self._ttl = nsc_zone._min_ttl

    def ttl(self, *args, **kwargs) -> Self:
        """Set the TTL for records added after this call.

        The arguments are passed to ``timedelta``; calling with no
        arguments resets the TTL to the zone's minimum TTL.
        """
        if not args and not kwargs:
            self._ttl = self.nsc_zone._min_ttl
        else:
            self._ttl = int(timedelta(*args, **kwargs).total_seconds())
        return self

    def _add(self, rec: Rdata) -> None:
        """Add ``rec`` to this node's rdataset of the same class and type."""
        rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
        rds.add(rec, ttl=self._ttl)

    def _parse_addrs(self, addrs: Tuple[IPAddr, ...]) -> List[IPAddress]:
        """Flatten ``addrs`` (addresses or lists of addresses) into address objects.

        Strings are parsed with ``ip_address``, which raises ValueError on
        malformed input; already-parsed addresses are passed through.
        """
        out: List[IPAddress] = []
        for a in addrs:
            for b in (a if isinstance(a, list) else [a]):
                out.append(b if isinstance(b, (IPv4Address, IPv6Address)) else ip_address(b))
        return out

    def _parse_name(self, name: str) -> Name:
        """Parse a record target name.

        A name made of a single label (e.g. ``"jabberwock"``) stays relative
        to the zone; a name with several labels (``"drak.ucw.cz"``) or a
        trailing dot is treated as absolute. Labels are counted after
        parsing, so an escaped dot (``"a\\.b"``) does not make a name absolute.
        """
        rel = dns.name.from_text(name, origin=None)
        if rel.is_absolute() or len(rel.labels) <= 1:
            return rel
        return rel.concatenate(dns.name.root)

    def _parse_names(self, names: str | List[str]) -> List[Name]:
        """Parse one name or a list of names with ``_parse_name``."""
        if isinstance(names, str):
            return [self._parse_name(names)]
        return [self._parse_name(n) for n in names]

    def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
        """Add A (IPv4) and/or AAAA (IPv6) records for every given address.

        When ``reverse`` is true, a PTR pointing back at this node is
        requested from the parent Nsc for each address.
        """
        ptr_to: Optional[Name] = None
        if reverse:
            # Resolve against the zone origin rather than concatenating strings,
            # so the apex ("" or "@") and absolute names (trailing dot) also
            # yield a valid fully-qualified target.
            ptr_to = dns.name.from_text(self.name, origin=self.nsc_zone.zone.origin)
        for a in self._parse_addrs(addrs):
            if isinstance(a, IPv4Address):
                self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
            else:
                self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
            if ptr_to is not None:
                self.nsc_zone.nsc._add_reverse_mapping(a, ptr_to)
        return self

    def MX(self, pri: int, name: str) -> Self:
        """Add an MX record with preference ``pri`` pointing at ``name``."""
        self._add(
            dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, self._parse_name(name))
        )
        return self

    def NS(self, names: str | List[str], *more_names: str | List[str]) -> Self:
        """Add NS records.

        Accepts a name or a list of names, and optionally more of either as
        extra positional arguments (``NS('a', 'b')`` or ``NS(['a', 'b'])``).
        """
        for group in (names, *more_names):
            for name in self._parse_names(group):
                self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
        return self

    def TXT(self, text: str) -> Self:
        """Add a TXT record built from ``text``."""
        self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, text))
        return self

    def PTR(self, target: Name | str) -> Self:
        """Add a PTR record pointing at ``target``."""
        self._add(dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target))
        return self

    def generic(self, typ: str, text: str) -> Self:
        """Add a record of type ``typ`` parsed from its presentation-format ``text``."""
        self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
        return self
109     def generic(self, typ: str, text: str) -> Self:
110         self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
111         return self
112
113
114 class NscZoneConfig:
115     admin_email: str
116     refresh: timedelta
117     retry: timedelta
118     expire: timedelta
119     min_ttl: timedelta
120     origin_server: str
121
122     default_config: Optional['NscZoneConfig'] = None
123
124     def __init__(self,
125                  admin_email: Optional[str] = None,
126                  refresh: Optional[timedelta] = None,
127                  retry: Optional[timedelta] = None,
128                  expire: Optional[timedelta] = None,
129                  min_ttl: Optional[timedelta] = None,
130                  origin_server: Optional[str] = None,
131                  inherit_config: Optional['NscZoneConfig'] = None,
132                  ) -> None:
133         if inherit_config is None:
134             inherit_config = NscZoneConfig.default_config or self   # to satisfy the type checker
135         self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
136         self.refresh = refresh if refresh is not None else inherit_config.refresh
137         self.retry = retry if retry is not None else inherit_config.retry
138         self.expire = expire if expire is not None else inherit_config.expire
139         self.min_ttl = min_ttl if min_ttl is not None else inherit_config.min_ttl
140         self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
141
142     def finalize(self) -> Self:
143         if not self.origin_server:
144             self.origin_server = socket.getfqdn()
145         if not self.admin_email:
146             self.admin_email = f'hostmaster@{self.origin_server}'
147         return self
148
149
# Library-wide fallback for every NscZoneConfig created without an explicit
# inherit_config. The empty strings are placeholders that
# NscZoneConfig.finalize() replaces with host-derived values.
NscZoneConfig.default_config = NscZoneConfig(
    refresh=timedelta(hours=8),
    retry=timedelta(hours=2),
    expire=timedelta(weeks=2),
    min_ttl=timedelta(hours=24),
    admin_email='',
    origin_server='',
)
158
159
class NscZone:
    """A single DNS zone under construction, created with its SOA record."""

    nsc: 'Nsc'
    name: str
    config: NscZoneConfig
    zone: Zone
    _min_ttl: int                       # config.min_ttl in whole seconds
    reverse_for: Optional[IPNetwork]    # network whose PTRs live here, if a reverse zone

    def __init__(self, nsc: 'Nsc', name: str, reverse_for: Optional[IPNetwork] = None, **kwargs) -> None:
        self.nsc = nsc
        self.name = name
        self.config = NscZoneConfig(**kwargs).finalize()
        self.zone = dns.zone.Zone(origin=name, rdclass=RdataClass.IN)
        self._min_ttl = int(self.config.min_ttl.total_seconds())
        self.reverse_for = reverse_for

        conf = self.config
        root = self[""]
        root._add(
            dns.rdtypes.ANY.SOA.SOA(
                RdataClass.IN, RdataType.SOA,
                mname=conf.origin_server,
                rname=self._email_to_rname(conf.admin_email),
                serial=12345,   # NOTE(review): fixed placeholder serial
                refresh=int(conf.refresh.total_seconds()),
                retry=int(conf.retry.total_seconds()),
                expire=int(conf.expire.total_seconds()),
                minimum=int(conf.min_ttl.total_seconds()),
            )
        )

    @staticmethod
    def _email_to_rname(email: str) -> str:
        """Convert a mailbox address to the textual SOA RNAME form.

        The last ``@`` becomes a label separator. Dots inside the local part
        are escaped, so that ``john.doe@example.org`` maps to
        ``john\\.doe.example.org`` rather than ``john@doe.example.org``.
        Input without an ``@`` is returned unchanged (assumed already in
        RNAME form).
        """
        local, at, domain = email.rpartition('@')
        if not at:
            return email
        return local.replace('.', '\\.') + '.' + domain

    def n(self, name: str) -> NscNode:
        """Return a record builder for ``name`` (relative to this zone)."""
        return NscNode(self, name)

    def __getitem__(self, name: str) -> NscNode:
        """Same as ``n(name)``: ``zone['www']``."""
        return NscNode(self, name)

    def host(self, name: str, *args, reverse: bool = True) -> NscNode:
        """Create node ``name`` with A/AAAA records for the addresses in ``args``."""
        n = NscNode(self, name)
        n.A(*args, reverse=reverse)
        return n

    def dump(self) -> None:
        """Print the zone's records to stdout in a master-file-like layout.

        The owner name is printed only when it differs from the previous
        record's, and the TTL column is left empty when it equals the zone's
        minimum TTL.
        """
        # Could use self.zone.to_file(sys.stdout), but we want better formatting
        print(f'; Zone file for {self.name}')
        last_name = None
        for name, ttl, rec in self.zone.iterate_rdatas():
            print_name = "" if name == last_name else name
            print_ttl = "" if ttl == self._min_ttl else ttl
            print(f'{print_name}\t{print_ttl}\t{rec.rdtype.name}\t{rec.to_text()}')
            last_name = name

    def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
        """Add a PTR for ``addr`` (which must lie in ``reverse_for``) to ``ptr_to``."""
        assert self.reverse_for is not None
        # Drop the octets already covered by the zone name; the remaining
        # octets, reversed, form the owner name relative to the zone.
        parts = str(addr).split('.')
        parts = parts[self.reverse_for.prefixlen // 8:]
        name = '.'.join(reversed(parts))
        self.n(name).PTR(ptr_to)

    def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
        """Add a PTR for ``addr`` (which must lie in ``reverse_for``) to ``ptr_to``."""
        assert self.reverse_for is not None
        # Same as the IPv4 case, but with one label per hex nibble.
        parts = addr.exploded.replace(':', "")
        parts = parts[self.reverse_for.prefixlen // 4:]
        name = '.'.join(reversed(parts))
        self.n(name).PTR(ptr_to)
228
229
class Nsc:
    """Top-level container: all zones plus pending reverse-mapping requests."""

    zones: Dict[str, NscZone]
    default_zone_config: NscZoneConfig
    ipv4_reverse: DefaultDict[IPv4Address, List[Name]]   # address -> PTR targets requested
    ipv6_reverse: DefaultDict[IPv6Address, List[Name]]

    def __init__(self, **kwargs) -> None:
        """Keyword arguments become the default NscZoneConfig for new zones."""
        self.zones = {}
        self.default_zone_config = NscZoneConfig(**kwargs)
        self.ipv4_reverse = defaultdict(list)
        self.ipv6_reverse = defaultdict(list)

    def add_zone(self, *args, inherit_config: Optional[NscZoneConfig] = None, **kwargs) -> NscZone:
        """Create and register a zone; arguments are forwarded to NscZone.

        Raises:
            ValueError: a zone with the same name is already registered.
        """
        if inherit_config is None:
            inherit_config = self.default_zone_config
        z = NscZone(self, *args, inherit_config=inherit_config, **kwargs)
        if z.name in self.zones:
            raise ValueError(f'Zone {z.name} is already defined')
        self.zones[z.name] = z
        return z

    def add_reverse_zone(self, net: str | IPNetwork, name: Optional[str] = None, **kwargs) -> NscZone:
        """Create a reverse zone holding PTRs for ``net``.

        ``net`` may be given as text (parsed with ``strict=True``, so host
        bits must be zero). ``name`` defaults to the in-addr.arpa /
        ip6.arpa name derived from ``net``.
        """
        if not isinstance(net, (IPv4Network, IPv6Network)):
            net = ip_network(net, strict=True)
        name = name or self._reverse_zone_name(net)
        return self.add_zone(name, reverse_for=net, **kwargs)

    def _reverse_zone_name(self, net: IPNetwork) -> str:
        """Derive the reverse zone name for ``net``.

        IPv4 prefixes that are not a multiple of 8 get a classless label of
        the form ``<octet>/<prefixlen>``. IPv6 prefixes must be a multiple
        of 4 (whole nibbles).

        Raises:
            ValueError: an IPv6 prefix length is not a multiple of 4.
        """
        if isinstance(net, IPv4Network):
            octets = str(net.network_address).split('.')
            labels = octets[:net.prefixlen // 8]
            if net.prefixlen % 8 != 0:
                labels.append(f'{octets[len(labels)]}/{net.prefixlen}')
            # Joining the suffix as labels keeps /0 valid ("in-addr.arpa").
            return '.'.join([*reversed(labels), 'in-addr', 'arpa'])
        elif isinstance(net, IPv6Network):
            if net.prefixlen % 4 != 0:
                raise ValueError(f'IPv6 reverse zone needs a nibble-aligned prefix, got /{net.prefixlen}')
            nibbles = net.network_address.exploded.replace(':', "")
            nibbles = nibbles[:net.prefixlen // 4]
            return '.'.join([*reversed(nibbles), 'ip6', 'arpa'])
        else:
            raise NotImplementedError()

    def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
        """Record that ``addr`` should get a PTR pointing at ``ptr_to``."""
        if isinstance(addr, IPv4Address):
            self.ipv4_reverse[addr].append(ptr_to)
        else:
            self.ipv6_reverse[addr].append(ptr_to)

    def dump_reverse(self) -> None:
        """Print every requested reverse mapping as ``address<TAB>name``."""
        print('### Requests for reverse mappings ###')
        for table in (self.ipv4_reverse, self.ipv6_reverse):
            for addr, names in sorted(table.items()):
                # One line per target instead of printing the list's repr.
                for name in names:
                    print(f'{addr}\t{name}')

    def fill_reverse(self) -> None:
        """Add PTR records to every reverse zone whose network contains a requested address.

        Addresses not covered by any reverse zone are skipped.
        """
        for z in self.zones.values():
            net = z.reverse_for
            if net is None:
                continue
            if isinstance(net, IPv4Network):
                for addr4, ptr_list in self.ipv4_reverse.items():
                    if addr4 in net:
                        for ptr_to in ptr_list:
                            z._add_ipv4_reverse(addr4, ptr_to)
            else:
                for addr6, ptr_list in self.ipv6_reverse.items():
                    if addr6 in net:
                        for ptr_to in ptr_list:
                            z._add_ipv6_reverse(addr6, ptr_to)
297
298
# Demo: build the ucw.cz forward zone and two reverse zones, then print them.
c = Nsc(origin_server='ns.ucw.cz', admin_email='admin@ucw.cz')

z = c.add_zone('ucw.cz')
z.n('').NS(['jabberwock', 'chirigo.gebbeth.cz', 'drak.ucw.cz'])
z.n('jabberwock').A('1.2.3.4', '2a00:da80:fff0:2::2', '195.113.31.123')
z.host('test', '1.2.3.4', ['5.6.7.8', '8.7.6.5'])
z.n('mnau') \
    .A('195.113.31.123') \
    .MX(0, 'jabberwock') \
    .ttl(minutes=15) \
    .TXT('hey?') \
    .generic('HINFO', 'Something fishy')
z.dump()

r = c.add_reverse_zone('195.113.0.0/16')
r2 = c.add_reverse_zone('2a00:da80:fff0:2::/64')

c.dump_reverse()
c.fill_reverse()

r.dump()
r2.dump()