]> mj.ucw.cz Git - pynsc.git/blob - nsc.py
Enter NscZoneConfig
[pynsc.git] / nsc.py
1 #!/usr/bin/env python3
2
3 from dataclasses import dataclass
4 from datetime import timedelta
5 import dns.name
6 from dns.name import Name
7 from dns.node import Node
8 from dns.rdata import Rdata
9 from dns.rdataclass import RdataClass
10 from dns.rdatatype import RdataType
11 import dns.rdtypes.ANY.MX
12 import dns.rdtypes.ANY.NS
13 import dns.rdtypes.ANY.SOA
14 import dns.rdtypes.ANY.TXT
15 import dns.rdtypes.IN.A
16 import dns.rdtypes.IN.AAAA
17 from dns.zone import Zone
18 from ipaddress import ip_address, IPv4Address, IPv6Address
19 import socket
20 import sys
21 from typing import Optional, Dict, List, Self, Tuple
22
23
24 IPAddress = IPv4Address | IPv6Address
25 IPAddr = str | IPAddress | List[str | IPAddress]
26
27
28 class NscNode:
29     nsc_zone: 'NscZone'
30     name: str
31     node: Node
32     _ttl: int
33
34     def __init__(self, nsc_zone: 'NscZone', name: str) -> None:
35         self.nsc_zone = nsc_zone
36         self.name = name
37         self.node = nsc_zone.zone.find_node(name, create=True)
38         self._ttl = nsc_zone._min_ttl
39
40     def ttl(self, *args, **kwargs) -> Self:
41         if not args and not kwargs:
42             self._ttl = self.nsc_zone._min_ttl
43         else:
44             self._ttl = int(timedelta(*args, **kwargs).total_seconds())
45         return self
46
47     def _add(self, rec: Rdata) -> None:
48         rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
49         rds.add(rec, ttl=self._ttl)
50
51     def _parse_addrs(self, addrs: Tuple[IPAddr, ...]) -> List[IPAddress]:
52         out = []
53         for a in addrs:
54             if not isinstance(a, list):
55                 a = [a]
56             for b in a:
57                 if isinstance(b, IPv4Address) or isinstance(b, IPv6Address):
58                     out.append(b)
59                 else:
60                     out.append(ip_address(b))
61         return out
62
63     def _parse_name(self, name: str) -> Name:
64         # FIXME: Names with escaped dots
65         if '.' in name:
66             return dns.name.from_text(name)
67         else:
68             return dns.name.from_text(name, origin=None)
69
70     def _parse_names(self, names: str | List[str]) -> List[Name]:
71         if isinstance(names, str):
72             return [self._parse_name(names)]
73         else:
74             return [self._parse_name(n) for n in names]
75
76     def A(self, *addrs: IPAddr) -> Self:
77         for a in self._parse_addrs(addrs):
78             if isinstance(a, IPv4Address):
79                 self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
80             else:
81                 self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
82         return self
83
84     def MX(self, pri: int, name: str) -> Self:
85         self._add(
86             dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, self._parse_name(name))
87         )
88         return self
89
90     def NS(self, names: str | List[str]) -> Self:
91         for name in self._parse_names(names):
92             self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
93         return self
94
95     def TXT(self, text: str) -> Self:
96         self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, text))
97         return self
98
99     def generic(self, typ: str, text: str) -> Self:
100         self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
101         return self
102
103
104 class NscZoneConfig:
105     admin_email: str
106     refresh: timedelta
107     retry: timedelta
108     expire: timedelta
109     min_ttl: timedelta
110     origin_server: str
111
112     default_config: Optional['NscZoneConfig'] = None
113
114     def __init__(self,
115                  admin_email: Optional[str] = None,
116                  refresh: Optional[timedelta] = None,
117                  retry: Optional[timedelta] = None,
118                  expire: Optional[timedelta] = None,
119                  min_ttl: Optional[timedelta] = None,
120                  origin_server: Optional[str] = None,
121                  inherit_config: Optional['NscZoneConfig'] = None,
122                  ) -> None:
123         if inherit_config is None:
124             inherit_config = NscZoneConfig.default_config or self   # to satisfy the type checker
125         self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
126         self.refresh = refresh if refresh is not None else inherit_config.refresh
127         self.retry = retry if retry is not None else inherit_config.retry
128         self.expire = expire if expire is not None else inherit_config.expire
129         self.min_ttl = min_ttl if min_ttl is not None else inherit_config.min_ttl
130         self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
131
132     def finalize(self) -> Self:
133         if not self.origin_server:
134             self.origin_server = socket.getfqdn()
135         if not self.admin_email:
136             self.admin_email = f'hostmaster@{self.origin_server}'
137         return self
138
139
140 NscZoneConfig.default_config = NscZoneConfig(
141     admin_email="",
142     refresh=timedelta(hours=8),
143     retry=timedelta(hours=2),
144     expire=timedelta(days=14),
145     min_ttl=timedelta(days=1),
146     origin_server="",
147 )
148
149
150 class NscZone:
151     name: str
152     zone: Zone
153     _min_ttl: int
154
155     def __init__(self, name: str, **kwargs) -> None:
156         self.name = name
157         self.config = NscZoneConfig(**kwargs).finalize()
158         self.zone = dns.zone.Zone(origin=name, rdclass=RdataClass.IN)
159         self._min_ttl = int(self.config.min_ttl.total_seconds())
160
161         conf = self.config
162         root = self[""]
163         root._add(
164             dns.rdtypes.ANY.SOA.SOA(
165                 RdataClass.IN, RdataType.SOA,
166                 mname=conf.origin_server,
167                 rname=conf.admin_email.replace('@', '.'),   # FIXME: names with dots
168                 serial=12345,
169                 refresh=int(conf.refresh.total_seconds()),
170                 retry=int(conf.retry.total_seconds()),
171                 expire=int(conf.expire.total_seconds()),
172                 minimum=int(conf.min_ttl.total_seconds()),
173             )
174         )
175
176     def n(self, name: str) -> NscNode:
177         return NscNode(self, name)
178
179     def __getitem__(self, name: str) -> NscNode:
180         return NscNode(self, name)
181
182     def host(self, name: str, *args) -> NscNode:
183         n = NscNode(self, name)
184         n.A(*args)
185         return n
186
187     def dump(self) -> None:
188         # Could use self.zone.to_file(sys.stdout), but we want better formatting
189         last_name = None
190         for name, ttl, rec in self.zone.iterate_rdatas():
191             if name == last_name:
192                 print_name = ""
193             else:
194                 print_name = name
195             print(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}')
196             last_name = name
197
198
199 class Nsc:
200     zones: Dict[str, Zone]
201     default_zone_config: NscZoneConfig
202
203     def __init__(self, **kwargs) -> None:
204         self.zones = {}
205         self.default_zone_config = NscZoneConfig(**kwargs)
206
207     def add_zone(self, *args, inherit_config: Optional[NscZoneConfig] = None, **kwargs) -> Zone:
208         if inherit_config is None:
209             inherit_config = self.default_zone_config
210         dom = NscZone(*args, inherit_config=inherit_config, **kwargs)
211         assert dom.name not in self.zones
212         self.zones[dom.name] = dom
213         return dom
214
215
216 class MyZone(Zone):
217     admin_email = 'admin@ucw.cz'
218     origin_server = 'ns.ucw.cz'
219
220
221 c = Nsc()
222 z = c.add_zone('ucw.cz')  # origin_server='jabberwock.ucw.cz')
223
224 z[""].NS(['jabberwock', 'chirigo.gebbeth.cz', 'drak.ucw.cz'])
225
226 z['jabberwock'].A('1.2.3.4', '2a00:da80:fff0:2::2')
227
228 z.host('test', '1.2.3.4', ['5.6.7.8', '8.7.6.5'])
229
230 (z['mnau']
231     .A('195.113.31.123')
232     .MX(0, 'jabberwock')
233     .ttl(minutes=15)
234     .TXT('hey?')
235     .generic('HINFO', 'Something fishy'))
236
237 z.dump()