]> mj.ucw.cz Git - pynsc.git/blob - nsc.py
41abe3c8dad591ed052aa9a674967ae41d588d1f
[pynsc.git] / nsc.py
1 #!/usr/bin/env python3
2
3 from dataclasses import dataclass
4 from datetime import timedelta
5 import dns.name
6 from dns.name import Name
7 from dns.node import Node
8 from dns.rdata import Rdata
9 from dns.rdataclass import RdataClass
10 from dns.rdatatype import RdataType
11 import dns.rdtypes.ANY.MX
12 import dns.rdtypes.ANY.NS
13 import dns.rdtypes.ANY.SOA
14 import dns.rdtypes.ANY.TXT
15 import dns.rdtypes.IN.A
16 import dns.rdtypes.IN.AAAA
17 from dns.zone import Zone
18 from ipaddress import ip_address, IPv4Address, IPv6Address
19 import socket
20 import sys
21 from typing import Optional, Dict, List, Self
22
23
24 IPAddress = IPv4Address | IPv6Address
25
26
class NscNode:
    """One owner name inside an NscZone, with a fluent API for adding records.

    Every public record-adding method returns ``self`` so calls can be chained,
    e.g. ``zone['mnau'].A('195.113.31.123').ttl(minutes=15).TXT('hey?')``.
    """

    nsc_zone: 'NscZone'   # zone this node belongs to
    name: str             # owner name exactly as given by the caller ("" is the zone apex)
    node: Node            # underlying dnspython node that holds the rdatasets
    _ttl: int             # TTL in seconds applied to records added from now on

    def __init__(self, nsc_zone: 'NscZone', name: str) -> None:
        self.nsc_zone = nsc_zone
        self.name = name
        # Fetch the node from the zone, creating it on first use.
        self.node = nsc_zone.zone.find_node(name, create=True)
        # Records start out with the zone's min_ttl until ttl() changes it.
        self._ttl = int(nsc_zone.min_ttl.total_seconds())

    def ttl(self, *args, **kwargs) -> Self:
        """Set the TTL used for records added after this call.

        Arguments are forwarded to :class:`datetime.timedelta`
        (e.g. ``ttl(minutes=15)``). With no arguments, the TTL is reset to the
        zone's ``min_ttl``.
        """
        if not args and not kwargs:
            self._ttl = int(self.nsc_zone.min_ttl.total_seconds())
        else:
            self._ttl = int(timedelta(*args, **kwargs).total_seconds())
        return self

    def _add(self, rec: Rdata) -> None:
        """Add *rec* to this node's rdataset of the same class/type, with the current TTL."""
        rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
        rds.add(rec, ttl=self._ttl)

    def _parse_addrs(self, addrs: str | IPAddress | List[str | IPAddress]) -> List[IPAddress]:
        """Normalize one address or a list/tuple of addresses to ipaddress objects.

        Items that are already IPv4Address/IPv6Address are kept as-is; anything
        else is parsed with :func:`ipaddress.ip_address` (ValueError on bad input).
        """
        if not isinstance(addrs, (list, tuple)):
            addrs = [addrs]
        return [
            a if isinstance(a, (IPv4Address, IPv6Address)) else ip_address(a)
            for a in addrs
        ]

    def _parse_name(self, name: str) -> Name:
        """Parse a domain name given in text form.

        A name consisting of more than one label (i.e. containing an unescaped
        dot, like 'chirigo.gebbeth.cz') is taken as fully qualified and made
        absolute. A single-label name such as 'jabberwock' (or '' / '@') is left
        relative, to be resolved against the zone origin. A name ending in a dot
        is always absolute. A backslash-escaped dot is part of a label, so it
        does not by itself make the name absolute.
        """
        parsed = dns.name.from_text(name, origin=None)
        if len(parsed.labels) > 1:
            # Returns the name unchanged if it is already absolute.
            parsed = parsed.derelativize(dns.name.root)
        return parsed

    def _parse_names(self, names: str | List[str]) -> List[Name]:
        """Parse a single name or a list of names with :meth:`_parse_name`."""
        if isinstance(names, str):
            return [self._parse_name(names)]
        return [self._parse_name(n) for n in names]

    def A(self, addrs: str | IPAddress | List[str | IPAddress]) -> Self:
        """Add address records: A for IPv4 addresses, AAAA for IPv6 addresses."""
        for a in self._parse_addrs(addrs):
            if isinstance(a, IPv4Address):
                self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
            else:
                self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
        return self

    def MX(self, pri: int, name: str) -> Self:
        """Add an MX record with preference *pri* pointing at mail server *name*."""
        self._add(
            dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, self._parse_name(name))
        )
        return self

    def NS(self, names: str | List[str]) -> Self:
        """Add one NS record per name server in *names*."""
        for name in self._parse_names(names):
            self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
        return self

    def TXT(self, text: str) -> Self:
        """Add a TXT record containing the single string *text*."""
        self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, text))
        return self

    def generic(self, typ: str, text: str) -> Self:
        """Add a record of any type *typ* (e.g. 'HINFO') from its zone-file text form."""
        self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
        # Return self like every other record method, so the chain can continue.
        return self
100
class NscZone:
    """A DNS zone under construction, wrapping a dnspython :class:`Zone`.

    The class attributes below are defaults. Subclasses may override them, and
    constructor arguments that are not None take precedence over both. The
    SOA record is created at the zone apex by the constructor.
    """

    name: str                                  # zone origin, e.g. 'ucw.cz'
    admin_email: Optional[str] = None          # SOA RNAME in e-mail form; None -> root@<origin_server>
    refresh: timedelta = timedelta(hours=8)    # SOA REFRESH
    retry: timedelta = timedelta(hours=2)      # SOA RETRY
    expire: timedelta = timedelta(days=14)     # SOA EXPIRE
    min_ttl: timedelta = timedelta(days=1)     # SOA MINIMUM; also the default TTL of new records
    origin_server: Optional[str] = None        # SOA MNAME; None -> this host's FQDN
    zone: Zone                                 # underlying dnspython zone

    def __init__(self,
                 name: str,
                 admin_email: Optional[str] = None,
                 refresh: Optional[timedelta] = None,
                 retry: Optional[timedelta] = None,
                 expire: Optional[timedelta] = None,
                 min_ttl: Optional[timedelta] = None,
                 origin_server: Optional[str] = None,
                 ) -> None:
        self.name = name
        # None means "keep the (possibly subclass-overridden) class default".
        self.admin_email = admin_email if admin_email is not None else self.admin_email
        self.refresh = refresh if refresh is not None else self.refresh
        self.retry = retry if retry is not None else self.retry
        self.expire = expire if expire is not None else self.expire
        self.min_ttl = min_ttl if min_ttl is not None else self.min_ttl
        self.origin_server = origin_server if origin_server is not None else self.origin_server
        self.zone = dns.zone.Zone(origin=name, rdclass=RdataClass.IN)

        if self.origin_server is None:
            self.origin_server = socket.getfqdn()

        if self.admin_email is None:
            self.admin_email = f'root@{self.origin_server}'

        root = self[""]
        root._add(
            dns.rdtypes.ANY.SOA.SOA(
                RdataClass.IN, RdataType.SOA,
                mname=self.origin_server,
                rname=self._email_to_rname(self.admin_email),
                # NOTE(review): fixed placeholder serial; presumably to be
                # replaced by a real serial scheme.
                serial=12345,
                refresh=int(self.refresh.total_seconds()),
                retry=int(self.retry.total_seconds()),
                expire=int(self.expire.total_seconds()),
                minimum=int(self.min_ttl.total_seconds()),
            )
        )

    @staticmethod
    def _email_to_rname(email: str) -> str:
        """Convert an e-mail address to the textual form of an SOA RNAME.

        The part before the last '@' becomes a single label, so any dots (and
        backslashes) in it are backslash-escaped. This way, for example,
        'john.doe@example.org' keeps 'john.doe' as the mailbox instead of
        turning into john@doe.example.org. A string without '@' is returned
        unchanged.
        """
        local, at, domain = email.rpartition('@')
        if not at:
            return email
        escaped = local.replace('\\', '\\\\').replace('.', '\\.')
        return f'{escaped}.{domain}'

    def n(self, name: str) -> NscNode:
        """Return a node handle for owner *name* ("" is the zone apex)."""
        return NscNode(self, name)

    def __getitem__(self, name: str) -> NscNode:
        """Same as :meth:`n`, so nodes can be written as ``zone['www']``."""
        return self.n(name)

    def dump(self) -> None:
        """Print all records of the zone to stdout, one per line, tab-separated.

        Columns are owner name, TTL, type and rdata. The owner name is blank
        when it repeats the previous line's, and the TTL is blank when it
        equals the zone's min_ttl.
        """
        # Could use self.zone.to_file(sys.stdout), but we want better formatting
        last_name = None
        min_ttl = int(self.min_ttl.total_seconds())
        for name, ttl, rec in self.zone.iterate_rdatas():
            print_name = "" if name == last_name else name
            print(f'{print_name}\t{ttl if ttl != min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}')
            last_name = name
167
class Config:
    """A collection of zones, keyed by zone name."""

    zones: Dict[str, NscZone]   # zone name -> zone object

    def __init__(self) -> None:
        self.zones = {}

    def add_zone(self, *args, zone_class: type[NscZone] = NscZone, **kwargs) -> NscZone:
        """Create a zone and register it under its name.

        Positional and keyword arguments are forwarded to *zone_class*, which is
        NscZone by default. Pass an NscZone subclass to use its overridden
        defaults.

        Returns:
            The newly created zone.

        Raises:
            ValueError: if a zone with the same name is already registered.
        """
        dom = zone_class(*args, **kwargs)
        # Explicit check instead of assert, which would be stripped under -O.
        if dom.name in self.zones:
            raise ValueError(f'Duplicate zone {dom.name!r}')
        self.zones[dom.name] = dom
        return dom
179
180
class MyZone(NscZone):
    """Zone with site-wide SOA defaults.

    Overrides NscZone's class-level defaults. Values passed to the constructor
    still take precedence over these.
    """
    admin_email = 'admin@ucw.cz'
    origin_server = 'ns.ucw.cz'
184
185
# Example configuration: build the ucw.cz zone and print its records.
# add_zone() also accepts origin_server=... to set the SOA MNAME explicitly;
# otherwise the local host's FQDN is used.
c = Config()
z = c.add_zone('ucw.cz')

# Name servers of the zone apex.
z.n('').NS(['jabberwock', 'chirigo.gebbeth.cz', 'drak.ucw.cz'])

# Dual-stack host: one A and one AAAA record.
z.n('jabberwock').A(['1.2.3.4', '2a00:da80:fff0:2::2'])

# Order matters: ttl() only affects records added after it (TXT and HINFO).
z.n('mnau') \
    .A('195.113.31.123') \
    .MX(0, 'jabberwock') \
    .ttl(minutes=15) \
    .TXT('hey?') \
    .generic('HINFO', 'Something fishy')

z.dump()