]> mj.ucw.cz Git - pynsc.git/blob - nsc.py
Variadic A and host()
[pynsc.git] / nsc.py
1 #!/usr/bin/env python3
2
3 from dataclasses import dataclass
4 from datetime import timedelta
5 import dns.name
6 from dns.name import Name
7 from dns.node import Node
8 from dns.rdata import Rdata
9 from dns.rdataclass import RdataClass
10 from dns.rdatatype import RdataType
11 import dns.rdtypes.ANY.MX
12 import dns.rdtypes.ANY.NS
13 import dns.rdtypes.ANY.SOA
14 import dns.rdtypes.ANY.TXT
15 import dns.rdtypes.IN.A
16 import dns.rdtypes.IN.AAAA
17 from dns.zone import Zone
18 from ipaddress import ip_address, IPv4Address, IPv6Address
19 import socket
20 import sys
21 from typing import Optional, Dict, List, Self, Tuple
22
23
# An already-parsed IPv4 or IPv6 address object.
IPAddress = IPv4Address | IPv6Address
# What address-taking methods accept per argument: a textual address, a parsed
# address object, or a list mixing both (see NscNode._parse_addrs).
IPAddr = str | IPAddress | List[str | IPAddress]
26
27
class NscNode:
    """Fluent builder for the records attached to one name of an NscZone.

    Every record-adding method returns ``self`` so that calls can be chained,
    e.g. ``zone['www'].A('192.0.2.1').ttl(minutes=5).TXT('hello')``.
    Records are added with the TTL that is current at the time of the call;
    it starts at the zone's ``min_ttl`` and can be changed with :meth:`ttl`.
    """

    nsc_zone: 'NscZone'
    name: str
    node: Node
    _ttl: int

    def __init__(self, nsc_zone: 'NscZone', name: str) -> None:
        """Look up (or create) the node called *name* in *nsc_zone*."""
        self.nsc_zone = nsc_zone
        self.name = name
        self.node = nsc_zone.zone.find_node(name, create=True)
        self._ttl = int(nsc_zone.min_ttl.total_seconds())

    def ttl(self, *args, **kwargs) -> Self:
        """Set the TTL used for records added from now on.

        The arguments are passed straight to ``datetime.timedelta``.
        Calling ``ttl()`` without arguments resets the TTL to the zone's
        ``min_ttl``.
        """
        if args or kwargs:
            self._ttl = int(timedelta(*args, **kwargs).total_seconds())
        else:
            self._ttl = int(self.nsc_zone.min_ttl.total_seconds())
        return self

    def _add(self, rec: Rdata) -> None:
        """Add *rec* to the matching rdataset of this node with the current TTL."""
        rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
        rds.add(rec, ttl=self._ttl)

    def _parse_addrs(self, addrs: Tuple[IPAddr, ...]) -> List[IPAddress]:
        """Flatten *addrs* (scalars and lists) into a list of address objects.

        Strings are converted with ``ipaddress.ip_address``, which raises
        ``ValueError`` for text that is not a valid IPv4/IPv6 address.
        """
        out: List[IPAddress] = []
        for item in addrs:
            group = item if isinstance(item, list) else [item]
            for addr in group:
                if isinstance(addr, (IPv4Address, IPv6Address)):
                    out.append(addr)
                else:
                    out.append(ip_address(addr))
        return out

    def _parse_name(self, name: str) -> Name:
        """Parse a host name given in zone-file text syntax.

        A name consisting of a single label (``'jabberwock'``) stays relative.
        A name with several labels, or with a trailing dot, is made absolute.

        The decision is taken on the *parsed* labels rather than on the raw
        text, so an escaped dot (``r'foo\\.bar'``) is correctly seen as part
        of a single label and does not make the name absolute.
        """
        parsed = dns.name.from_text(name, origin=None)
        if parsed.is_absolute() or len(parsed) > 1:
            # derelativize() returns already-absolute names unchanged.
            return parsed.derelativize(dns.name.root)
        return parsed

    def _parse_names(self, names: str | List[str]) -> List[Name]:
        """Parse one name or a list of names with :meth:`_parse_name`."""
        if isinstance(names, str):
            names = [names]
        return [self._parse_name(n) for n in names]

    def A(self, *addrs: IPAddr) -> Self:
        """Add address records: A for IPv4 addresses, AAAA for IPv6 ones."""
        for addr in self._parse_addrs(addrs):
            if isinstance(addr, IPv4Address):
                self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(addr)))
            else:
                self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(addr)))
        return self

    def MX(self, pri: int, name: str) -> Self:
        """Add an MX record with preference *pri* pointing to *name*."""
        self._add(
            dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, self._parse_name(name))
        )
        return self

    def NS(self, names: str | List[str]) -> Self:
        """Add one NS record per name in *names*."""
        for name in self._parse_names(names):
            self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
        return self

    def TXT(self, text: str) -> Self:
        """Add a TXT record carrying *text*.

        A single character-string in a TXT record is limited to 255 bytes,
        so longer text is split into consecutive 255-byte strings of one
        record instead of being handed over as one oversized string.
        """
        data = text.encode()
        # `or (b'',)` keeps an empty text as one empty string.
        chunks = tuple(data[i:i + 255] for i in range(0, len(data), 255)) or (b'',)
        self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, chunks))
        return self

    def generic(self, typ: str, text: str) -> Self:
        """Add a record of type *typ* whose data is parsed from *text*."""
        self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
        return self
102
103
class NscZone:
    """A DNS zone under construction, wrapping a dnspython ``Zone``.

    The class attributes below are the defaults used when the corresponding
    constructor argument is not given (they can be overridden in a subclass).
    Creating an instance immediately adds the SOA record to the zone apex.
    """

    name: str
    # Defaults to root@<origin_server> when left as None.
    admin_email: Optional[str] = None
    refresh: timedelta = timedelta(hours=8)
    retry: timedelta = timedelta(hours=2)
    expire: timedelta = timedelta(days=14)
    # Also used as the default TTL of records added through NscNode.
    min_ttl: timedelta = timedelta(days=1)
    # Defaults to socket.getfqdn() when left as None.
    origin_server: Optional[str] = None
    zone: Zone

    def __init__(self,
                 name: str,
                 admin_email: Optional[str] = None,
                 refresh: Optional[timedelta] = None,
                 retry: Optional[timedelta] = None,
                 expire: Optional[timedelta] = None,
                 min_ttl: Optional[timedelta] = None,
                 origin_server: Optional[str] = None,
                 ) -> None:
        """Create the zone *name* and its SOA record.

        Any argument left as ``None`` falls back to the class-level default
        of the same name.
        """
        self.name = name
        self.admin_email = admin_email if admin_email is not None else self.admin_email
        self.refresh = refresh if refresh is not None else self.refresh
        self.retry = retry if retry is not None else self.retry
        self.expire = expire if expire is not None else self.expire
        self.min_ttl = min_ttl if min_ttl is not None else self.min_ttl
        self.origin_server = origin_server if origin_server is not None else self.origin_server
        self.zone = dns.zone.Zone(origin=name, rdclass=RdataClass.IN)

        if self.origin_server is None:
            self.origin_server = socket.getfqdn()

        if self.admin_email is None:
            self.admin_email = f'root@{self.origin_server}'

        root = self[""]
        root._add(
            dns.rdtypes.ANY.SOA.SOA(
                RdataClass.IN, RdataType.SOA,
                mname=self.origin_server,
                rname=self._email_to_rname(self.admin_email),
                serial=12345,   # TODO: the serial number is still hard-coded
                refresh=int(self.refresh.total_seconds()),
                retry=int(self.retry.total_seconds()),
                expire=int(self.expire.total_seconds()),
                minimum=int(self.min_ttl.total_seconds()),
            )
        )

    @staticmethod
    def _email_to_rname(email: str) -> str:
        """Convert an e-mail address to the text form of a SOA rname.

        The local part becomes the first label of the name, so dots (and
        backslashes) inside it are escaped; otherwise ``john.doe@example.org``
        would turn into the mailbox ``john@doe.example.org``. Text without
        an ``@`` is returned unchanged.
        """
        local, at, domain = email.rpartition('@')
        if not at:
            return email
        # Escape backslashes first so the escapes added for dots stay intact.
        escaped = local.replace('\\', '\\\\').replace('.', '\\.')
        return f'{escaped}.{domain}'

    def n(self, name: str) -> NscNode:
        """Return a node builder for *name*; same as ``zone[name]``."""
        return self[name]

    def __getitem__(self, name: str) -> NscNode:
        """Return a node builder for *name* (``""`` is the zone apex)."""
        return NscNode(self, name)

    def host(self, name: str, *args) -> NscNode:
        """Create node *name*, add address records for *args*, return the node."""
        n = NscNode(self, name)
        n.A(*args)
        return n

    def dump(self) -> None:
        """Print all records of the zone to standard output, tab-separated.

        The owner name is printed only on the first record of each run of
        records with the same name, and the TTL is omitted when it equals
        the zone's ``min_ttl``.
        """
        # Could use self.zone.to_file(sys.stdout), but we want better formatting
        last_name = None
        min_ttl = int(self.min_ttl.total_seconds())
        for name, ttl, rec in self.zone.iterate_rdatas():
            print_name = "" if name == last_name else name
            print(f'{print_name}\t{ttl if ttl != min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}')
            last_name = name
174
175
class Config:
    """Collection of the zones defined by a configuration, keyed by zone name."""

    # Values are the NscZone wrappers created by add_zone(), not raw dnspython zones.
    zones: Dict[str, NscZone]

    def __init__(self) -> None:
        self.zones = {}

    def add_zone(self, *args, **kwargs) -> NscZone:
        """Create an NscZone from the given arguments and register it.

        All arguments are passed on to the ``NscZone`` constructor.

        Raises:
            AssertionError: a zone of the same name is already registered.
        """
        dom = NscZone(*args, **kwargs)
        # Raised explicitly (instead of using `assert`) so that the check is
        # not stripped under `python -O`; the exception type is unchanged.
        if dom.name in self.zones:
            raise AssertionError(f'Zone {dom.name} is already defined')
        self.zones[dom.name] = dom
        return dom
187
188
# Zone subclass carrying site-specific admin e-mail and origin server values.
# NOTE(review): this derives from dnspython's Zone, not from NscZone, although
# the two attributes mirror NscZone's class-level defaults; the class is not
# used by the code below -- confirm which base class was intended.
class MyZone(Zone):
    admin_email = 'admin@ucw.cz'
    origin_server = 'ns.ucw.cz'
192
193
# Example configuration: build the zone 'ucw.cz' and print it.
# No origin_server/admin_email is given, so NscZone falls back to its defaults.
c = Config()
z = c.add_zone('ucw.cz')  # origin_server='jabberwock.ucw.cz')

# z[""] is the node with the empty name, where NscZone also puts the SOA record.
z[""].NS(['jabberwock', 'chirigo.gebbeth.cz', 'drak.ucw.cz'])

# A() takes both address families: IPv4 becomes an A record, IPv6 an AAAA record.
z['jabberwock'].A('1.2.3.4', '2a00:da80:fff0:2::2')

# host() creates the node and adds addresses; scalars and lists can be mixed.
z.host('test', '1.2.3.4', ['5.6.7.8', '8.7.6.5'])

# Chained calls on one node; ttl() only affects the records added after it
# (here the TXT and HINFO records).
(z['mnau']
    .A('195.113.31.123')
    .MX(0, 'jabberwock')
    .ttl(minutes=15)
    .TXT('hey?')
    .generic('HINFO', 'Something fishy'))

# Print the resulting zone to standard output.
z.dump()