]> mj.ucw.cz Git - pynsc.git/blobdiff - nsc.py
Reverse mappings
[pynsc.git] / nsc.py
diff --git a/nsc.py b/nsc.py
index 41abe3c8dad591ed052aa9a674967ae41d588d1f..2a01f23f7f27715410bcd7fc6a177e7520a19b0e 100755 (executable)
--- a/nsc.py
+++ b/nsc.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 
+from collections import defaultdict
 from dataclasses import dataclass
 from datetime import timedelta
 import dns.name
@@ -10,18 +11,21 @@ from dns.rdataclass import RdataClass
 from dns.rdatatype import RdataType
 import dns.rdtypes.ANY.MX
 import dns.rdtypes.ANY.NS
+import dns.rdtypes.ANY.PTR
 import dns.rdtypes.ANY.SOA
 import dns.rdtypes.ANY.TXT
 import dns.rdtypes.IN.A
 import dns.rdtypes.IN.AAAA
 from dns.zone import Zone
-from ipaddress import ip_address, IPv4Address, IPv6Address
+from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
 import socket
 import sys
-from typing import Optional, Dict, List, Self
+from typing import Optional, Dict, List, Self, Tuple, DefaultDict
 
 
 IPAddress = IPv4Address | IPv6Address
+IPNetwork = IPv4Network | IPv6Network
+IPAddr = str | IPAddress | List[str | IPAddress]
 
 
 class NscNode:
@@ -34,11 +38,11 @@ class NscNode:
         self.nsc_zone = nsc_zone
         self.name = name
         self.node = nsc_zone.zone.find_node(name, create=True)
-        self._ttl = int(nsc_zone.min_ttl.total_seconds())
+        self._ttl = nsc_zone._min_ttl
 
     def ttl(self, *args, **kwargs) -> Self:
         if not args and not kwargs:
-            self._ttl = int(self.nsc_zone.min_ttl.total_seconds())
+            self._ttl = self.nsc_zone._min_ttl
         else:
             self._ttl = int(timedelta(*args, **kwargs).total_seconds())
         return self
@@ -47,15 +51,16 @@ class NscNode:
         rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
         rds.add(rec, ttl=self._ttl)
 
-    def _parse_addrs(self, addrs: str | IPAddress | List[str | IPAddress]) -> List[IPAddress]:
-        if not isinstance(addrs, list):
-            addrs = [addrs]
+    def _parse_addrs(self, addrs: Tuple[IPAddr, ...]) -> List[IPAddress]:
         out = []
         for a in addrs:
-            if isinstance(a, IPv4Address) or isinstance(a, IPv6Address):
-                out.append(a)
-            else:
-                out.append(ip_address(a))
+            if not isinstance(a, list):
+                a = [a]
+            for b in a:
+                if isinstance(b, IPv4Address) or isinstance(b, IPv6Address):
+                    out.append(b)
+                else:
+                    out.append(ip_address(b))
         return out
 
     def _parse_name(self, name: str) -> Name:
@@ -71,12 +76,14 @@ class NscNode:
         else:
             return [self._parse_name(n) for n in names]
 
-    def A(self, addrs: str | IPAddress | List[str | IPAddress]) -> Self:
+    def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
         for a in self._parse_addrs(addrs):
             if isinstance(a, IPv4Address):
                 self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
             else:
                 self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
+            if reverse:
+                self.nsc_zone.nsc._add_reverse_mapping(a, dns.name.from_text(self.name + '.' + self.nsc_zone.name))
         return self
 
     def MX(self, pri: int, name: str) -> Self:
@@ -86,6 +93,7 @@ class NscNode:
         return self
 
     def NS(self, names: str | List[str]) -> Self:
+        # FIXME: Variadic?
         for name in self._parse_names(names):
             self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
         return self
@@ -94,55 +102,88 @@ class NscNode:
         self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, text))
         return self
 
+    def PTR(self, target: Name | str) -> Self:
+        self._add(dns.rdtypes.ANY.PTR.PTR(RdataClass.IN, RdataType.PTR, target))
+        return self
+
     def generic(self, typ: str, text: str) -> Self:
         self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
+        return self
 
 
-class NscZone:
-    name: str
-    admin_email: Optional[str] = None
-    refresh: timedelta = timedelta(hours=8)
-    retry: timedelta = timedelta(hours=2)
-    expire: timedelta = timedelta(days=14)
-    min_ttl: timedelta = timedelta(days=1)
-    origin_server: Optional[str] = None
-    zone: Zone
+class NscZoneConfig:
+    admin_email: str
+    refresh: timedelta
+    retry: timedelta
+    expire: timedelta
+    min_ttl: timedelta
+    origin_server: str
+
+    default_config: Optional['NscZoneConfig'] = None
 
     def __init__(self,
-                 name: str,
                  admin_email: Optional[str] = None,
                  refresh: Optional[timedelta] = None,
                  retry: Optional[timedelta] = None,
                  expire: Optional[timedelta] = None,
                  min_ttl: Optional[timedelta] = None,
                  origin_server: Optional[str] = None,
+                 inherit_config: Optional['NscZoneConfig'] = None,
                  ) -> None:
-        self.name = name
-        self.admin_email = admin_email if admin_email is not None else self.admin_email
-        self.refresh = refresh if refresh is not None else self.refresh
-        self.retry = retry if retry is not None else self.retry
-        self.expire = expire if expire is not None else self.expire
-        self.min_ttl = min_ttl if min_ttl is not None else self.min_ttl
-        self.origin_server = origin_server if origin_server is not None else self.origin_server
-        self.zone = dns.zone.Zone(origin=name, rdclass=RdataClass.IN)
-
-        if self.origin_server is None:
+        if inherit_config is None:
+            inherit_config = NscZoneConfig.default_config or self   # to satisfy the type checker
+        self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
+        self.refresh = refresh if refresh is not None else inherit_config.refresh
+        self.retry = retry if retry is not None else inherit_config.retry
+        self.expire = expire if expire is not None else inherit_config.expire
+        self.min_ttl = min_ttl if min_ttl is not None else inherit_config.min_ttl
+        self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
+
+    def finalize(self) -> Self:
+        if not self.origin_server:
             self.origin_server = socket.getfqdn()
+        if not self.admin_email:
+            self.admin_email = f'hostmaster@{self.origin_server}'
+        return self
+
+
+NscZoneConfig.default_config = NscZoneConfig(
+    admin_email="",
+    refresh=timedelta(hours=8),
+    retry=timedelta(hours=2),
+    expire=timedelta(days=14),
+    min_ttl=timedelta(days=1),
+    origin_server="",
+)
+
+
+class NscZone:
+    nsc: 'Nsc'
+    name: str
+    zone: Zone
+    _min_ttl: int
+    reverse_for: Optional[IPNetwork]
 
-        if self.admin_email is None:
-            self.admin_email = f'root@{self.origin_server}'
+    def __init__(self, nsc: 'Nsc', name: str, reverse_for: Optional[IPNetwork] = None, **kwargs) -> None:
+        self.nsc = nsc
+        self.name = name
+        self.config = NscZoneConfig(**kwargs).finalize()
+        self.zone = dns.zone.Zone(origin=name, rdclass=RdataClass.IN)
+        self._min_ttl = int(self.config.min_ttl.total_seconds())
+        self.reverse_for = reverse_for
 
+        conf = self.config
         root = self[""]
         root._add(
             dns.rdtypes.ANY.SOA.SOA(
                 RdataClass.IN, RdataType.SOA,
-                mname=self.origin_server,
-                rname=self.admin_email.replace('@', '.'),   # FIXME: names with dots
+                mname=conf.origin_server,
+                rname=conf.admin_email.replace('@', '.'),   # FIXME: names with dots
                 serial=12345,
-                refresh=int(self.refresh.total_seconds()),
-                retry=int(self.retry.total_seconds()),
-                expire=int(self.expire.total_seconds()),
-                minimum=int(self.min_ttl.total_seconds()),
+                refresh=int(conf.refresh.total_seconds()),
+                retry=int(conf.retry.total_seconds()),
+                expire=int(conf.expire.total_seconds()),
+                minimum=int(conf.min_ttl.total_seconds()),
             )
         )
 
@@ -152,43 +193,121 @@ class NscZone:
     def __getitem__(self, name: str) -> NscNode:
         return NscNode(self, name)
 
+    def host(self, name: str, *args, reverse: bool = True) -> NscNode:
+        n = NscNode(self, name)
+        n.A(*args, reverse=reverse)
+        return n
+
     def dump(self) -> None:
         # Could use self.zone.to_file(sys.stdout), but we want better formatting
+        print(f'; Zone file for {self.name}')
         last_name = None
-        min_ttl = int(self.min_ttl.total_seconds())
         for name, ttl, rec in self.zone.iterate_rdatas():
             if name == last_name:
                 print_name = ""
             else:
                 print_name = name
-            print(f'{print_name}\t{ttl if ttl != min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}')
+            print(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}')
             last_name = name
 
-
-class Config:
-    zones: Dict[str, Zone]
-
-    def __init__(self) -> None:
+    def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
+        # Called only for addresses from this reverse network
+        assert self.reverse_for is not None
+        parts = str(addr).split('.')
+        parts = parts[self.reverse_for.prefixlen // 8:]
+        name = '.'.join(reversed(parts))
+        self.n(name).PTR(ptr_to)
+
+    def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
+        # Called only for addresses from this reverse network
+        assert self.reverse_for is not None
+        parts = addr.exploded.replace(':', "")
+        parts = parts[self.reverse_for.prefixlen // 4:]
+        name = '.'.join(reversed(parts))
+        self.n(name).PTR(ptr_to)
+
+
+class Nsc:
+    zones: Dict[str, NscZone]
+    default_zone_config: NscZoneConfig
+    ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
+    ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
+
+    def __init__(self, **kwargs) -> None:
         self.zones = {}
+        self.default_zone_config = NscZoneConfig(**kwargs)
+        self.ipv4_reverse = defaultdict(list)
+        self.ipv6_reverse = defaultdict(list)
+
+    def add_zone(self, *args, inherit_config: Optional[NscZoneConfig] = None, **kwargs) -> Zone:
+        if inherit_config is None:
+            inherit_config = self.default_zone_config
+        z = NscZone(self, *args, inherit_config=inherit_config, **kwargs)
+        assert z.name not in self.zones
+        self.zones[z.name] = z
+        return z
+
+    def add_reverse_zone(self, net: str | IPNetwork, name: Optional[str] = None, **kwargs) -> Zone:
+        if not (isinstance(net, IPv4Network) or isinstance(net, IPv6Network)):
+            net = ip_network(net, strict=True)
+        name = name or self._reverse_zone_name(net)
+        return self.add_zone(name, reverse_for=net, **kwargs)
+
+    def _reverse_zone_name(self, net: IPNetwork) -> str:
+        if isinstance(net, IPv4Network):
+            parts = str(net.network_address).split('.')
+            out = parts[:net.prefixlen // 8]
+            if net.prefixlen % 8 != 0:
+                out.append(parts[len(out)] + '/' + str(net.prefixlen))
+            return '.'.join(reversed(out)) + '.in-addr.arpa'
+        elif isinstance(net, IPv6Network):
+            assert net.prefixlen % 4 == 0
+            nibbles = net.network_address.exploded.replace(':', "")
+            nibbles = nibbles[:net.prefixlen // 4]
+            return '.'.join(reversed(nibbles)) + '.ip6.arpa'
+        else:
+            raise NotImplementedError()
 
-    def add_zone(self, *args, **kwargs) -> Zone:
-        dom = NscZone(*args, **kwargs)
-        assert dom.name not in self.zones
-        self.zones[dom.name] = dom
-        return dom
-
-
-class MyZone(Zone):
-    admin_email = 'admin@ucw.cz'
-    origin_server = 'ns.ucw.cz'
-
-
-c = Config()
-z = c.add_zone('ucw.cz')  # origin_server='jabberwock.ucw.cz')
+    def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
+        if isinstance(addr, IPv4Address):
+            self.ipv4_reverse[addr].append(ptr_to)
+        else:
+            self.ipv6_reverse[addr].append(ptr_to)
+
+    def dump_reverse(self) -> None:
+        print('### Requests for reverse mappings ###')
+        for ipa4, name in sorted(self.ipv4_reverse.items()):
+            print(f'{ipa4}\t{name}')
+        for ipa6, name in sorted(self.ipv6_reverse.items()):
+            print(f'{ipa6}\t{name}')
+
+    def fill_reverse(self) -> None:
+        for z in self.zones.values():
+            if z.reverse_for is not None:
+                if isinstance(z.reverse_for, IPv4Network):
+                    for addr4, ptr_list in self.ipv4_reverse.items():
+                        if addr4 in z.reverse_for:
+                            for ptr_to in ptr_list:
+                                z._add_ipv4_reverse(addr4, ptr_to)
+                else:
+                    for addr6, ptr_list in self.ipv6_reverse.items():
+                        if addr6 in z.reverse_for:
+                            for ptr_to in ptr_list:
+                                z._add_ipv6_reverse(addr6, ptr_to)
+
+
+c = Nsc(
+    admin_email='admin@ucw.cz',
+    origin_server='ns.ucw.cz',
+)
+
+z = c.add_zone('ucw.cz')
 
 z[""].NS(['jabberwock', 'chirigo.gebbeth.cz', 'drak.ucw.cz'])
 
-z['jabberwock'].A(['1.2.3.4', '2a00:da80:fff0:2::2'])
+z['jabberwock'].A('1.2.3.4', '2a00:da80:fff0:2::2', '195.113.31.123')
+
+z.host('test', '1.2.3.4', ['5.6.7.8', '8.7.6.5'])
 
 (z['mnau']
     .A('195.113.31.123')
@@ -198,3 +317,12 @@ z['jabberwock'].A(['1.2.3.4', '2a00:da80:fff0:2::2'])
     .generic('HINFO', 'Something fishy'))
 
 z.dump()
+
+r = c.add_reverse_zone('195.113.0.0/16')
+r2 = c.add_reverse_zone('2a00:da80:fff0:2::/64')
+
+c.dump_reverse()
+c.fill_reverse()
+
+r.dump()
+r2.dump()