]> mj.ucw.cz Git - pynsc.git/commitdiff
Support for secondary zones
authorMartin Mares <mj@ucw.cz>
Sat, 20 Apr 2024 23:21:44 +0000 (01:21 +0200)
committerMartin Mares <mj@ucw.cz>
Sat, 20 Apr 2024 23:21:44 +0000 (01:21 +0200)
example/__init__.py
nsconfig/cli.py
nsconfig/core.py

index 8309b1b0ac3b68651466ead212fd2db39ab9d4ed..980825075e08d718cc6175f7456196476b942321 100644 (file)
@@ -6,7 +6,9 @@ nsc = Nsc(
 )
 
 for rev in ['10.1.0.0/16', '10.2.0.0/16', 'fd12:3456:789a::/48']:
-    rz = nsc.add_reverse_zone(rev)
+    rz = nsc.add_zone(reverse_for=rev)
     rz[""].NS(['ns1.example.org', 'ns2.example.org'])
 
+nsc.add_zone('example.net', secondary_for='10.42.0.1')
+
 import example.example_org
index 241ede47d051a92559b75cc36b1ca12fe750837c..9d342c071ab232256db4c2f2add91a252af4a992 100644 (file)
@@ -2,7 +2,7 @@ import argparse
 from pathlib import Path
 from texttable import Texttable
 
-from nsconfig.core import Nsc
+from nsconfig.core import Nsc, ZoneType
 
 
 def do_test(nsc: Nsc) -> None:
@@ -10,14 +10,18 @@ def do_test(nsc: Nsc) -> None:
     test_dir.mkdir(exist_ok=True)
     for z in nsc.get_zones():
         print(f'Zone:        {z.name}')
-        print(f'Old serial:  {z.prev_state.serial}')
-        print(f'Old hash:    {z.prev_state.hash}')
-        print(f'New serial:  {z.state.serial}')
-        print(f'New hash:    {z.state.hash}')
-        out_file = test_dir / z.safe_name
-        print(f'Dumping to:  {out_file}')
-        with open(out_file, 'w') as f:
-            z.dump(file=f)
+        print(f'Type:        {z.zone_type.name}')
+        if z.zone_type == ZoneType.primary:
+            print(f'Old serial:  {z.prev_state.serial}')
+            print(f'Old hash:    {z.prev_state.hash}')
+            print(f'New serial:  {z.state.serial}')
+            print(f'New hash:    {z.state.hash}')
+            out_file = test_dir / z.safe_name
+            print(f'Dumping to:  {out_file}')
+            with open(out_file, 'w') as f:
+                z.dump(file=f)
+        else:
+            print(f'Primary:     {z.primary_server}')
         print()
 
 
@@ -28,6 +32,9 @@ def do_status(nsc: Nsc) -> None:
     table.set_deco(Texttable.HEADER)
 
     for z in nsc.get_zones():
+        if z.zone_type != ZoneType.primary:
+            table.add_row([z.name, 'secondary', "", "", "", ""])
+            continue
         if z.state.serial == z.prev_state.serial:
             action = ""
         else:
@@ -46,7 +53,7 @@ def do_status(nsc: Nsc) -> None:
 
 def do_update(nsc: Nsc) -> None:
     for z in nsc.get_zones():
-        if z.state.serial != z.prev_state.serial:
+        if z.zone_type == ZoneType.primary and z.state.serial != z.prev_state.serial:
             print(f'Updating zone {z.name} (serial {z.state.serial})')
             z.write_zone()
             z.write_state()
index ce19bd98d73e759b2d9de2b1c3f30433a4954243..f223e687d61eed510396b775d4e0e242511d5791 100644 (file)
@@ -14,6 +14,7 @@ import dns.rdtypes.ANY.TXT
 import dns.rdtypes.IN.A
 import dns.rdtypes.IN.AAAA
 from dns.zone import Zone
+from enum import Enum, auto
 import hashlib
 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
 import json
@@ -188,36 +189,80 @@ class NscZoneState:
         new_file.replace(file)
 
 
+class ZoneType(Enum):
+    primary = auto()
+    secondary = auto()
+
+
 class NscZone:
     nsc: 'Nsc'
     name: str
-    safe_name: str      # For use in file names
+    safe_name: str                          # For use in file names
     zone: Zone
     _min_ttl: int
     reverse_for: Optional[IPNetwork]
+    zone_type: ZoneType
+    primary_server: Optional[IPAddress]     # For secondary zones
     zone_file: Path
     state_file: Path
     state: NscZoneState
     prev_state: NscZoneState
 
-    def __init__(self, nsc: 'Nsc', name: str, reverse_for: Optional[IPNetwork] = None, **kwargs) -> None:
+    def __init__(self,
+                 nsc: 'Nsc',
+                 name: Optional[str] = None,
+                 reverse_for: str | IPNetwork | None = None,
+                 secondary_for: str | IPAddress | None = None,
+                 **kwargs) -> None:
+        if reverse_for is not None:
+            if isinstance(reverse_for, str):
+                reverse_for = ip_network(reverse_for, strict=True)
+            name = name or self._reverse_zone_name(reverse_for)
+        assert name is not None
+
+        if isinstance(secondary_for, str):
+            secondary_for = ip_address(secondary_for)
+
         self.nsc = nsc
         self.name = name
         self.safe_name = name.replace('/', '@')
         self.config = NscZoneConfig(**kwargs).finalize()
-        self.zone = dns.zone.Zone(origin=name, rdclass=RdataClass.IN)
-        self._min_ttl = int(self.config.min_ttl.total_seconds())
         self.reverse_for = reverse_for
+        self.primary_server = secondary_for
 
-        self.zone_file = nsc.zone_dir / self.safe_name
-        self.state_file = nsc.state_dir / (self.safe_name + '.json')
-        self.state = NscZoneState()
-        self.prev_state = NscZoneState()
-        self.prev_state.load(self.state_file)
+        if not secondary_for:
+            self.zone_type = ZoneType.primary
+            self.zone_file = nsc.zone_dir / self.safe_name
+            self.state_file = nsc.state_dir / (self.safe_name + '.json')
 
-        self.update_soa()
+            self.state = NscZoneState()
+            self.prev_state = NscZoneState()
+            self.prev_state.load(self.state_file)
+
+            self.zone = dns.zone.Zone(origin=name, rdclass=RdataClass.IN)
+            self._min_ttl = int(self.config.min_ttl.total_seconds())
+            self.update_soa()
+        else:
+            self.zone_type = ZoneType.secondary
+            self.zone_file = nsc.secondary_dir / self.safe_name
+
+    def _reverse_zone_name(self, net: IPNetwork) -> str:
+        if isinstance(net, IPv4Network):
+            parts = str(net.network_address).split('.')
+            out = parts[:net.prefixlen // 8]
+            if net.prefixlen % 8 != 0:
+                out.append(parts[len(out)] + '/' + str(net.prefixlen))
+            return '.'.join(reversed(out)) + '.in-addr.arpa'
+        elif isinstance(net, IPv6Network):
+            assert net.prefixlen % 4 == 0
+            nibbles = net.network_address.exploded.replace(':', "")
+            nibbles = nibbles[:net.prefixlen // 4]
+            return '.'.join(reversed(nibbles)) + '.ip6.arpa'
+        else:
+            raise NotImplementedError()
 
     def update_soa(self) -> None:
+        assert self.zone_type == ZoneType.primary
         conf = self.config
         soa = dns.rdtypes.ANY.SOA.SOA(
             RdataClass.IN, RdataType.SOA,
@@ -293,10 +338,12 @@ class NscZone:
                     print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')
 
     def process(self) -> None:
-        self.gen_hash()
-        self.gen_serial()
+        if self.zone_type == ZoneType.primary:
+            self.gen_hash()
+            self.gen_serial()
 
     def write_zone(self) -> None:
+        assert self.zone_type == ZoneType.primary
         self.update_soa()
         new_file = Path(str(self.zone_file) + '.new')
         with open(new_file, 'w') as f:
@@ -304,6 +351,7 @@ class NscZone:
         new_file.replace(self.zone_file)
 
     def write_state(self) -> None:
+        assert self.zone_type == ZoneType.primary
         self.state.save(self.state_file)
 
 
@@ -316,6 +364,7 @@ class Nsc:
     root_dir: Path
     state_dir: Path
     zone_dir: Path
+    secondary_dir: Path
 
     def __init__(self, directory: str = '.', **kwargs) -> None:
         self.start_time = datetime.now()
@@ -329,6 +378,8 @@ class Nsc:
         self.state_dir.mkdir(parents=True, exist_ok=True)
         self.zone_dir = self.root_dir / 'zone'
         self.zone_dir.mkdir(parents=True, exist_ok=True)
+        self.secondary_dir = self.root_dir / 'secondary'
+        self.secondary_dir.mkdir(parents=True, exist_ok=True)
 
     def add_zone(self, *args, inherit_config: Optional[NscZoneConfig] = None, **kwargs) -> Zone:
         if inherit_config is None:
@@ -338,27 +389,6 @@ class Nsc:
         self.zones[z.name] = z
         return z
 
-    def add_reverse_zone(self, net: str | IPNetwork, name: Optional[str] = None, **kwargs) -> Zone:
-        if not (isinstance(net, IPv4Network) or isinstance(net, IPv6Network)):
-            net = ip_network(net, strict=True)
-        name = name or self._reverse_zone_name(net)
-        return self.add_zone(name, reverse_for=net, **kwargs)
-
-    def _reverse_zone_name(self, net: IPNetwork) -> str:
-        if isinstance(net, IPv4Network):
-            parts = str(net.network_address).split('.')
-            out = parts[:net.prefixlen // 8]
-            if net.prefixlen % 8 != 0:
-                out.append(parts[len(out)] + '/' + str(net.prefixlen))
-            return '.'.join(reversed(out)) + '.in-addr.arpa'
-        elif isinstance(net, IPv6Network):
-            assert net.prefixlen % 4 == 0
-            nibbles = net.network_address.exploded.replace(':', "")
-            nibbles = nibbles[:net.prefixlen // 4]
-            return '.'.join(reversed(nibbles)) + '.ip6.arpa'
-        else:
-            raise NotImplementedError()
-
     def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
         if isinstance(addr, IPv4Address):
             self.ipv4_reverse[addr].append(ptr_to)
@@ -374,7 +404,7 @@ class Nsc:
 
     def fill_reverse(self) -> None:
         for z in self.zones.values():
-            if z.reverse_for is not None:
+            if z.zone_type == ZoneType.primary and z.reverse_for is not None:
                 if isinstance(z.reverse_for, IPv4Network):
                     for addr4, ptr_list in self.ipv4_reverse.items():
                         if addr4 in z.reverse_for: