]> mj.ucw.cz Git - pynsc.git/commitdiff
Refactor primary/secondary to use a class hierarchy
authorMartin Mares <mj@ucw.cz>
Sat, 20 Apr 2024 23:38:46 +0000 (01:38 +0200)
committerMartin Mares <mj@ucw.cz>
Sat, 20 Apr 2024 23:38:46 +0000 (01:38 +0200)
nsconfig/core.py

index f223e687d61eed510396b775d4e0e242511d5791..f48a5be4780ac3edf02cf3ea45fb6ca3f65b2049 100644 (file)
@@ -30,12 +30,12 @@ IPAddr = str | IPAddress | List[str | IPAddress]
 
 
 class NscNode:
-    nsc_zone: 'NscZone'
+    nsc_zone: 'NscZonePrimary'
     name: str
     node: Node
     _ttl: int
 
-    def __init__(self, nsc_zone: 'NscZone', name: str) -> None:
+    def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None:
         self.nsc_zone = nsc_zone
         self.name = name
         self.node = nsc_zone.zone.find_node(name, create=True)
@@ -198,71 +198,48 @@ class NscZone:
     nsc: 'Nsc'
     name: str
     safe_name: str                          # For use in file names
-    zone: Zone
-    _min_ttl: int
-    reverse_for: Optional[IPNetwork]
     zone_type: ZoneType
-    primary_server: Optional[IPAddress]     # For secondary zones
-    zone_file: Path
-    state_file: Path
-    state: NscZoneState
-    prev_state: NscZoneState
+    reverse_for: Optional[IPNetwork]
 
     def __init__(self,
                  nsc: 'Nsc',
-                 name: Optional[str] = None,
-                 reverse_for: str | IPNetwork | None = None,
-                 secondary_for: str | IPAddress | None = None,
+                 name: str,
+                 reverse_for: Optional[IPNetwork],
                  **kwargs) -> None:
-        if reverse_for is not None:
-            if isinstance(reverse_for, str):
-                reverse_for = ip_network(reverse_for, strict=True)
-            name = name or self._reverse_zone_name(reverse_for)
-        assert name is not None
-
-        if isinstance(secondary_for, str):
-            secondary_for = ip_address(secondary_for)
-
         self.nsc = nsc
         self.name = name
         self.safe_name = name.replace('/', '@')
         self.config = NscZoneConfig(**kwargs).finalize()
         self.reverse_for = reverse_for
-        self.primary_server = secondary_for
 
-        if not secondary_for:
-            self.zone_type = ZoneType.primary
-            self.zone_file = nsc.zone_dir / self.safe_name
-            self.state_file = nsc.state_dir / (self.safe_name + '.json')
+    def process(self) -> None:
+        pass
 
-            self.state = NscZoneState()
-            self.prev_state = NscZoneState()
-            self.prev_state.load(self.state_file)
 
-            self.zone = dns.zone.Zone(origin=name, rdclass=RdataClass.IN)
-            self._min_ttl = int(self.config.min_ttl.total_seconds())
-            self.update_soa()
-        else:
-            self.zone_type = ZoneType.secondary
-            self.zone_file = nsc.secondary_dir / self.safe_name
+class NscZonePrimary(NscZone):
+    zone: Zone
+    _min_ttl: int
+    zone_file: Path
+    state_file: Path
+    state: NscZoneState
+    prev_state: NscZoneState
 
-    def _reverse_zone_name(self, net: IPNetwork) -> str:
-        if isinstance(net, IPv4Network):
-            parts = str(net.network_address).split('.')
-            out = parts[:net.prefixlen // 8]
-            if net.prefixlen % 8 != 0:
-                out.append(parts[len(out)] + '/' + str(net.prefixlen))
-            return '.'.join(reversed(out)) + '.in-addr.arpa'
-        elif isinstance(net, IPv6Network):
-            assert net.prefixlen % 4 == 0
-            nibbles = net.network_address.exploded.replace(':', "")
-            nibbles = nibbles[:net.prefixlen // 4]
-            return '.'.join(reversed(nibbles)) + '.ip6.arpa'
-        else:
-            raise NotImplementedError()
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+        self.zone_type = ZoneType.primary
+        self.zone_file = self.nsc.zone_dir / self.safe_name
+        self.state_file = self.nsc.state_dir / (self.safe_name + '.json')
+
+        self.state = NscZoneState()
+        self.prev_state = NscZoneState()
+        self.prev_state.load(self.state_file)
+
+        self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN)
+        self._min_ttl = int(self.config.min_ttl.total_seconds())
+        self.update_soa()
 
     def update_soa(self) -> None:
-        assert self.zone_type == ZoneType.primary
         conf = self.config
         soa = dns.rdtypes.ANY.SOA.SOA(
             RdataClass.IN, RdataType.SOA,
@@ -302,16 +279,14 @@ class NscZone:
             last_name = name
 
     def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
-        # Called only for addresses from this reverse network
-        assert self.reverse_for is not None
+        assert isinstance(self.reverse_for, IPv4Network)
         parts = str(addr).split('.')
         parts = parts[self.reverse_for.prefixlen // 8:]
         name = '.'.join(reversed(parts))
         self.n(name).PTR(ptr_to)
 
     def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
-        # Called only for addresses from this reverse network
-        assert self.reverse_for is not None
+        assert isinstance(self.reverse_for, IPv6Network)
         parts = addr.exploded.replace(':', "")
         parts = parts[self.reverse_for.prefixlen // 4:]
         name = '.'.join(reversed(parts))
@@ -343,7 +318,6 @@ class NscZone:
             self.gen_serial()
 
     def write_zone(self) -> None:
-        assert self.zone_type == ZoneType.primary
         self.update_soa()
         new_file = Path(str(self.zone_file) + '.new')
         with open(new_file, 'w') as f:
@@ -351,10 +325,20 @@ class NscZone:
         new_file.replace(self.zone_file)
 
     def write_state(self) -> None:
-        assert self.zone_type == ZoneType.primary
         self.state.save(self.state_file)
 
 
+class NscZoneSecondary(NscZone):
+    primary_server: IPAddress
+    secondary_file: Path
+
+    def __init__(self, *args, primary_server=IPAddress, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.zone_type = ZoneType.secondary
+        self.primary_server = primary_server
+        self.secondary_file = self.nsc.secondary_dir / self.safe_name
+
+
 class Nsc:
     start_time: datetime
     zones: Dict[str, NscZone]
@@ -381,14 +365,48 @@ class Nsc:
         self.secondary_dir = self.root_dir / 'secondary'
         self.secondary_dir.mkdir(parents=True, exist_ok=True)
 
-    def add_zone(self, *args, inherit_config: Optional[NscZoneConfig] = None, **kwargs) -> Zone:
+    def add_zone(self,
+                 name: Optional[str] = None,
+                 reverse_for: str | IPNetwork | None = None,
+                 secondary_for: str | IPAddress | None = None,
+                 inherit_config: Optional[NscZoneConfig] = None,
+                 **kwargs) -> Zone:
         if inherit_config is None:
             inherit_config = self.default_zone_config
-        z = NscZone(self, *args, inherit_config=inherit_config, **kwargs)
-        assert z.name not in self.zones
-        self.zones[z.name] = z
+
+        if reverse_for is not None:
+            if isinstance(reverse_for, str):
+                reverse_for = ip_network(reverse_for, strict=True)
+            name = name or self._reverse_zone_name(reverse_for)
+        assert name is not None
+        assert name not in self.zones
+
+        z: NscZone
+        if secondary_for is None:
+            z = NscZonePrimary(self, name, reverse_for=reverse_for, inherit_config=inherit_config, **kwargs)
+        else:
+            if isinstance(secondary_for, str):
+                secondary_for = ip_address(secondary_for)
+            z = NscZoneSecondary(self, name, reverse_for=reverse_for, primary_server=secondary_for, inherit_config=inherit_config, **kwargs)
+
+        self.zones[name] = z
         return z
 
+    def _reverse_zone_name(self, net: IPNetwork) -> str:
+        if isinstance(net, IPv4Network):
+            parts = str(net.network_address).split('.')
+            out = parts[:net.prefixlen // 8]
+            if net.prefixlen % 8 != 0:
+                out.append(parts[len(out)] + '/' + str(net.prefixlen))
+            return '.'.join(reversed(out)) + '.in-addr.arpa'
+        elif isinstance(net, IPv6Network):
+            assert net.prefixlen % 4 == 0
+            nibbles = net.network_address.exploded.replace(':', "")
+            nibbles = nibbles[:net.prefixlen // 4]
+            return '.'.join(reversed(nibbles)) + '.ip6.arpa'
+        else:
+            raise NotImplementedError()
+
     def _add_reverse_mapping(self, addr: IPAddress, ptr_to: Name) -> None:
         if isinstance(addr, IPv4Address):
             self.ipv4_reverse[addr].append(ptr_to)
@@ -404,7 +422,7 @@ class Nsc:
 
     def fill_reverse(self) -> None:
         for z in self.zones.values():
-            if z.zone_type == ZoneType.primary and z.reverse_for is not None:
+            if isinstance(z, NscZonePrimary) and z.reverse_for is not None:
                 if isinstance(z.reverse_for, IPv4Network):
                     for addr4, ptr_list in self.ipv4_reverse.items():
                         if addr4 in z.reverse_for: