mj.ucw.cz Git - pynsc.git/blob - nsc/core.py
Split to modules
[pynsc.git] / nsc / core.py
1 from collections import defaultdict
2 from datetime import timedelta
3 import dns.name
4 from dns.name import Name
5 from dns.node import Node
6 from dns.rdata import Rdata
7 from dns.rdataclass import RdataClass
8 from dns.rdatatype import RdataType
9 import dns.rdtypes.ANY.MX
10 import dns.rdtypes.ANY.NS
11 import dns.rdtypes.ANY.PTR
12 import dns.rdtypes.ANY.SOA
13 import dns.rdtypes.ANY.TXT
14 import dns.rdtypes.IN.A
15 import dns.rdtypes.IN.AAAA
16 from dns.zone import Zone
17 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
18 import socket
19 from typing import Optional, Dict, List, Self, Tuple, DefaultDict
20
21
# Either family of IP address (ipaddress.IPv4Address or IPv6Address).
IPAddress = IPv4Address | IPv6Address
# Either family of IP network (ipaddress.IPv4Network or IPv6Network).
IPNetwork = IPv4Network | IPv6Network
# Accepted form of address arguments: a textual address, an address object,
# or a list mixing both (flattened by NscNode._parse_addrs).
IPAddr = str | IPAddress | List[str | IPAddress]
25
26
class NscNode:
    """One owner name inside an NscZone, with a fluent API for adding records.

    Every record-adding method returns ``self`` so calls can be chained, e.g.
    ``zone['www'].ttl(hours=1).A('192.0.2.1').TXT('hello')``.
    """

    nsc_zone: 'NscZone'     # zone this node belongs to
    name: str               # owner name as given by the caller (relative to the zone)
    node: Node              # underlying dnspython node, created on demand
    _ttl: int               # TTL in seconds applied to records added from now on

    def __init__(self, nsc_zone: 'NscZone', name: str) -> None:
        self.nsc_zone = nsc_zone
        self.name = name
        self.node = nsc_zone.zone.find_node(name, create=True)
        self._ttl = nsc_zone._min_ttl

    def ttl(self, *args, **kwargs) -> Self:
        """Set the TTL used by records subsequently added through this node.

        Arguments are passed to ``datetime.timedelta`` and truncated to whole
        seconds. With no arguments the TTL is reset to the zone's minimum TTL.
        """
        if not args and not kwargs:
            self._ttl = self.nsc_zone._min_ttl
        else:
            self._ttl = int(timedelta(*args, **kwargs).total_seconds())
        return self

    def _add(self, rec: Rdata) -> None:
        """Add one rdata to this node's matching rdataset with the current TTL."""
        rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
        rds.add(rec, ttl=self._ttl)

    def _parse_addrs(self, addrs: Tuple[IPAddr, ...]) -> List[IPAddress]:
        """Flatten address arguments into a list of ipaddress objects.

        Each element may be an address object, anything ``ip_address`` accepts
        (typically a string), or a list/tuple of those.

        Raises:
            ValueError: an element is not a valid IP address.
        """
        out: List[IPAddress] = []
        for a in addrs:
            # Tuples are accepted as well as lists so either sequence type works.
            items = a if isinstance(a, (list, tuple)) else [a]
            for b in items:
                if isinstance(b, (IPv4Address, IPv6Address)):
                    out.append(b)
                else:
                    out.append(ip_address(b))
        return out

    def _parse_name(self, name: str) -> Name:
        """Parse a target name used inside record data.

        A name containing a dot is parsed as absolute; a single label is kept
        as a relative name (presumably meant relative to the zone origin).
        """
        # FIXME: Names with escaped dots
        if '.' in name:
            return dns.name.from_text(name)
        else:
            return dns.name.from_text(name, origin=None)

    def _parse_names(self, names: str | List[str]) -> List[Name]:
        """Parse one name or a list of names with _parse_name."""
        if isinstance(names, str):
            return [self._parse_name(names)]
        else:
            return [self._parse_name(n) for n in names]

    def _absolute_name(self) -> Name:
        """Return this node's fully-qualified name.

        The node name is resolved against the zone origin rather than glued on
        with a '.', so the apex ("" or "@") maps to the origin itself instead of
        producing an empty label, and an already-absolute name is kept as given.
        """
        return dns.name.from_text(self.name, origin=dns.name.from_text(self.nsc_zone.name))

    def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
        """Add an A (IPv4) or AAAA (IPv6) record for every given address.

        With ``reverse`` true, this node's fully-qualified name is also
        registered with the owning Nsc, from which Nsc.fill_reverse() later
        creates PTR records in matching reverse zones.
        """
        for a in self._parse_addrs(addrs):
            if isinstance(a, IPv4Address):
                self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
            else:
                self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
            if reverse:
                self.nsc_zone.nsc._add_reverse_mapping(a, self._absolute_name())
        return self

    def MX(self, pri: int, name: str) -> Self:
        """Add an MX record with preference ``pri`` pointing at ``name``."""
        self._add(
            dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, self._parse_name(name))
        )
        return self

    def NS(self, names: str | List[str]) -> Self:
        """Add an NS record for each of one or more name server names."""
        # FIXME: Variadic?
        for name in self._parse_names(names):
            self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
        return self

    def TXT(self, text: str) -> Self:
        """Add a TXT record holding ``text`` (UTF-8 encoded).

        A single TXT character-string holds at most 255 bytes, so longer text
        (e.g. DKIM keys) is split into consecutive 255-byte strings inside the
        same record. Empty text still yields one (empty) string.
        """
        data = text.encode()
        chunks = [data[i:i + 255] for i in range(0, len(data), 255)] or [b'']
        self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, tuple(chunks)))
        return self

    def PTR(self, target: Name | str) -> Self:
        """Add a PTR record pointing at ``target``.

        NOTE: unlike MX/NS, string targets are handed to dnspython as-is and do
        not go through _parse_name.
        """
        self._add(dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target))
        return self

    def generic(self, typ: str, text: str) -> Self:
        """Add a record of any type ``typ`` parsed from its text form ``text``."""
        self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
        return self
108
109
110 class NscZoneConfig:
111     admin_email: str
112     refresh: timedelta
113     retry: timedelta
114     expire: timedelta
115     min_ttl: timedelta
116     origin_server: str
117
118     default_config: Optional['NscZoneConfig'] = None
119
120     def __init__(self,
121                  admin_email: Optional[str] = None,
122                  refresh: Optional[timedelta] = None,
123                  retry: Optional[timedelta] = None,
124                  expire: Optional[timedelta] = None,
125                  min_ttl: Optional[timedelta] = None,
126                  origin_server: Optional[str] = None,
127                  inherit_config: Optional['NscZoneConfig'] = None,
128                  ) -> None:
129         if inherit_config is None:
130             inherit_config = NscZoneConfig.default_config or self   # to satisfy the type checker
131         self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
132         self.refresh = refresh if refresh is not None else inherit_config.refresh
133         self.retry = retry if retry is not None else inherit_config.retry
134         self.expire = expire if expire is not None else inherit_config.expire
135         self.min_ttl = min_ttl if min_ttl is not None else inherit_config.min_ttl
136         self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
137
138     def finalize(self) -> Self:
139         if not self.origin_server:
140             self.origin_server = socket.getfqdn()
141         if not self.admin_email:
142             self.admin_email = f'hostmaster@{self.origin_server}'
143         return self
144
145
# Built-in defaults inherited by every NscZoneConfig created without an explicit
# inherit_config. The empty strings are placeholders that
# NscZoneConfig.finalize() replaces (origin_server -> socket.getfqdn(),
# admin_email -> hostmaster@<origin_server>).
NscZoneConfig.default_config = NscZoneConfig(
    origin_server="",
    admin_email="",
    refresh=timedelta(hours=8),
    retry=timedelta(hours=2),
    expire=timedelta(weeks=2),
    min_ttl=timedelta(hours=24),
)
154
155
class NscZone:
    """A forward or reverse DNS zone under construction, backed by a dnspython Zone."""

    nsc: 'Nsc'                          # owning Nsc; receives reverse-mapping requests
    name: str                           # zone origin exactly as passed by the caller
    zone: Zone                          # underlying dnspython zone
    _min_ttl: int                       # config.min_ttl in whole seconds; default node TTL
    reverse_for: Optional[IPNetwork]    # network whose PTRs live here, None for forward zones

    def __init__(self, nsc: 'Nsc', name: str, reverse_for: Optional[IPNetwork] = None, **kwargs) -> None:
        """Create the zone and add its SOA record at the apex.

        Args:
            nsc: owning Nsc instance.
            name: zone origin (passed to dnspython as the zone origin).
            reverse_for: network this zone holds PTR records for, if any.
            **kwargs: forwarded to NscZoneConfig (including ``inherit_config``);
                the resulting config is finalized.
        """
        self.nsc = nsc
        self.name = name
        self.config = NscZoneConfig(**kwargs).finalize()
        self.zone = dns.zone.Zone(origin=name, rdclass=RdataClass.IN)
        self._min_ttl = int(self.config.min_ttl.total_seconds())
        self.reverse_for = reverse_for

        conf = self.config
        root = self[""]
        root._add(
            dns.rdtypes.ANY.SOA.SOA(
                RdataClass.IN, RdataType.SOA,
                mname=conf.origin_server,
                rname=self._email_to_rname(conf.admin_email),
                serial=12345,   # NOTE(review): hard-coded placeholder serial — confirm intended
                refresh=int(conf.refresh.total_seconds()),
                retry=int(conf.retry.total_seconds()),
                expire=int(conf.expire.total_seconds()),
                minimum=int(conf.min_ttl.total_seconds()),
            )
        )

    @staticmethod
    def _email_to_rname(email: str) -> str:
        """Convert a mailbox such as ``john.doe@example.com`` to SOA RNAME text.

        The last '@' becomes a label separator, and dots in the local part are
        backslash-escaped so they stay inside the first label instead of
        splitting it. Text without '@' is returned unchanged (assumed to be in
        RNAME form already).
        """
        local, at, domain = email.rpartition('@')
        if not at:
            return email
        return local.replace('.', '\\.') + '.' + domain

    def n(self, name: str) -> NscNode:
        """Return a new NscNode for ``name`` in this zone (same as ``zone[name]``)."""
        return NscNode(self, name)

    def __getitem__(self, name: str) -> NscNode:
        """Return a new NscNode for ``name`` in this zone."""
        return NscNode(self, name)

    def host(self, name: str, *args, reverse: bool = True) -> NscNode:
        """Create node ``name`` with A/AAAA records for the given addresses."""
        n = NscNode(self, name)
        n.A(*args, reverse=reverse)
        return n

    def dump(self) -> None:
        """Print the zone in zone-file-like form on stdout.

        The owner name is printed only when it changes, and the TTL column is
        left blank for records whose TTL equals the zone's minimum TTL.
        """
        # Could use self.zone.to_file(sys.stdout), but we want better formatting
        print(f'; Zone file for {self.name}')
        last_name = None
        for name, ttl, rec in self.zone.iterate_rdatas():
            if name == last_name:
                print_name = ""
            else:
                print_name = name
            print(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}')
            last_name = name

    def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
        """Add a PTR record for ``addr`` pointing at ``ptr_to``."""
        # Called only for addresses from this reverse network
        assert self.reverse_for is not None
        # Drop the whole octets already encoded in the zone name; the remaining
        # octets, reversed, form the owner name relative to the zone.
        parts = str(addr).split('.')
        parts = parts[self.reverse_for.prefixlen // 8:]
        name = '.'.join(reversed(parts))
        self.n(name).PTR(ptr_to)

    def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
        """Add a PTR record for ``addr`` pointing at ``ptr_to``."""
        # Called only for addresses from this reverse network
        assert self.reverse_for is not None
        # Same idea with hex nibbles: skip those covered by the zone prefix.
        parts = addr.exploded.replace(':', "")
        parts = parts[self.reverse_for.prefixlen // 4:]
        name = '.'.join(reversed(parts))
        self.n(name).PTR(ptr_to)
224
225
class Nsc:
    """Top-level builder: a set of zones plus the reverse mappings they request."""

    zones: Dict[str, NscZone]                               # registered zones by name
    default_zone_config: NscZoneConfig                      # inherited by add_zone() by default
    ipv4_reverse: DefaultDict[IPv4Address, List[Name]]      # address -> requested PTR targets
    ipv6_reverse: DefaultDict[IPv6Address, List[Name]]      # address -> requested PTR targets

    def __init__(self, **kwargs) -> None:
        """Keyword arguments form the NscZoneConfig that zones inherit by default."""
        self.zones = {}
        self.default_zone_config = NscZoneConfig(**kwargs)
        self.ipv4_reverse = defaultdict(list)
        self.ipv6_reverse = defaultdict(list)

    def add_zone(self, *args, inherit_config: Optional[NscZoneConfig] = None, **kwargs) -> NscZone:
        """Create, register and return a new NscZone.

        Remaining arguments are passed to NscZone; ``inherit_config`` defaults
        to this builder's default_zone_config.

        Raises:
            ValueError: a zone with the same name is already registered.
        """
        if inherit_config is None:
            inherit_config = self.default_zone_config
        z = NscZone(self, *args, inherit_config=inherit_config, **kwargs)
        # Explicit check rather than assert, which disappears under `python -O`.
        if z.name in self.zones:
            raise ValueError(f'Zone {z.name} is already defined')
        self.zones[z.name] = z
        return z

    def add_reverse_zone(self, net: str | IPNetwork, name: Optional[str] = None, **kwargs) -> NscZone:
        """Create a reverse zone for ``net``.

        ``net`` may be a network object or text accepted by
        ``ip_network(..., strict=True)``. When ``name`` is not given it is
        derived with _reverse_zone_name().
        """
        if not isinstance(net, (IPv4Network, IPv6Network)):
            net = ip_network(net, strict=True)
        name = name or self._reverse_zone_name(net)
        return self.add_zone(name, reverse_for=net, **kwargs)

    def _reverse_zone_name(self, net: IPNetwork) -> str:
        """Return the reverse-zone name for ``net``.

        IPv4 networks map under in-addr.arpa; a prefix length that is not a
        multiple of 8 gets an extra "<octet>/<prefixlen>" label (RFC 2317 style).
        IPv6 networks map under ip6.arpa, one label per hex nibble.

        Raises:
            ValueError: IPv6 prefix length is not a multiple of 4.
            NotImplementedError: ``net`` is not an IPv4/IPv6 network.
        """
        if isinstance(net, IPv4Network):
            octets = str(net.network_address).split('.')
            labels = octets[:net.prefixlen // 8]
            if net.prefixlen % 8 != 0:
                labels.append(f'{octets[len(labels)]}/{net.prefixlen}')
            suffix = ['in-addr', 'arpa']
        elif isinstance(net, IPv6Network):
            if net.prefixlen % 4 != 0:
                raise ValueError(f'IPv6 reverse zone needs a prefix length divisible by 4, got /{net.prefixlen}')
            nibbles = net.network_address.exploded.replace(':', "")
            labels = list(nibbles[:net.prefixlen // 4])
            suffix = ['ip6', 'arpa']
        else:
            raise NotImplementedError()
        # Joining a single list avoids a leading '.' when there are no labels (/0).
        return '.'.join([*reversed(labels), *suffix])

    def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
        """Record that ``addr`` should get a PTR record pointing at ``ptr_to``."""
        if isinstance(addr, IPv4Address):
            self.ipv4_reverse[addr].append(ptr_to)
        else:
            self.ipv6_reverse[addr].append(ptr_to)

    def dump_reverse(self) -> None:
        """Print all requested reverse mappings, one "address<TAB>name" line each."""
        print('### Requests for reverse mappings ###')
        for ipa4, names4 in sorted(self.ipv4_reverse.items()):
            for ptr_to in names4:
                print(f'{ipa4}\t{ptr_to}')
        for ipa6, names6 in sorted(self.ipv6_reverse.items()):
            for ptr_to in names6:
                print(f'{ipa6}\t{ptr_to}')

    def fill_reverse(self) -> None:
        """Add PTR records to every reverse zone for the addresses it covers."""
        for z in self.zones.values():
            if z.reverse_for is None:
                continue
            if isinstance(z.reverse_for, IPv4Network):
                for addr4, ptr_list in self.ipv4_reverse.items():
                    if addr4 in z.reverse_for:
                        for ptr_to in ptr_list:
                            z._add_ipv4_reverse(addr4, ptr_to)
            else:
                for addr6, ptr_list in self.ipv6_reverse.items():
                    if addr6 in z.reverse_for:
                        for ptr_to in ptr_list:
                            z._add_ipv6_reverse(addr6, ptr_to)

    def dump(self) -> None:
        """Print every registered zone (see NscZone.dump)."""
        for z in self.zones.values():
            z.dump()