]> mj.ucw.cz Git - pynsc.git/blobdiff - nsconfig/core.py
More daemon configuration
[pynsc.git] / nsconfig / core.py
index ce19bd98d73e759b2d9de2b1c3f30433a4954243..8a3bb74f694453655887758f39063fa20182779e 100644 (file)
@@ -14,13 +14,18 @@ import dns.rdtypes.ANY.TXT
 import dns.rdtypes.IN.A
 import dns.rdtypes.IN.AAAA
 from dns.zone import Zone
 import dns.rdtypes.IN.A
 import dns.rdtypes.IN.AAAA
 from dns.zone import Zone
+from enum import Enum, auto
 import hashlib
 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
 import json
 from pathlib import Path
 import socket
 import sys
 import hashlib
 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
 import json
 from pathlib import Path
 import socket
 import sys
-from typing import Optional, Dict, List, Self, Tuple, DefaultDict, TextIO
+from typing import Optional, Dict, List, Self, Tuple, DefaultDict, TextIO, TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+    from nsconfig.daemon import NscDaemon
 
 
 IPAddress = IPv4Address | IPv6Address
 
 
 IPAddress = IPv4Address | IPv6Address
@@ -29,12 +34,12 @@ IPAddr = str | IPAddress | List[str | IPAddress]
 
 
 class NscNode:
 
 
 class NscNode:
-    nsc_zone: 'NscZone'
+    nsc_zone: 'NscZonePrimary'
     name: str
     node: Node
     _ttl: int
 
     name: str
     node: Node
     _ttl: int
 
-    def __init__(self, nsc_zone: 'NscZone', name: str) -> None:
+    def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None:
         self.nsc_zone = nsc_zone
         self.name = name
         self.node = nsc_zone.zone.find_node(name, create=True)
         self.nsc_zone = nsc_zone
         self.name = name
         self.node = nsc_zone.zone.find_node(name, create=True)
@@ -188,33 +193,54 @@ class NscZoneState:
         new_file.replace(file)
 
 
         new_file.replace(file)
 
 
+class ZoneType(Enum):
+    primary = auto()
+    secondary = auto()
+
+
 class NscZone:
     nsc: 'Nsc'
     name: str
 class NscZone:
     nsc: 'Nsc'
     name: str
-    safe_name: str      # For use in file names
-    zone: Zone
-    _min_ttl: int
+    safe_name: str                          # For use in file names
+    zone_type: ZoneType
     reverse_for: Optional[IPNetwork]
     reverse_for: Optional[IPNetwork]
-    zone_file: Path
-    state_file: Path
-    state: NscZoneState
-    prev_state: NscZoneState
 
 
-    def __init__(self, nsc: 'Nsc', name: str, reverse_for: Optional[IPNetwork] = None, **kwargs) -> None:
+    def __init__(self,
+                 nsc: 'Nsc',
+                 name: str,
+                 reverse_for: Optional[IPNetwork],
+                 **kwargs) -> None:
         self.nsc = nsc
         self.name = name
         self.safe_name = name.replace('/', '@')
         self.config = NscZoneConfig(**kwargs).finalize()
         self.nsc = nsc
         self.name = name
         self.safe_name = name.replace('/', '@')
         self.config = NscZoneConfig(**kwargs).finalize()
-        self.zone = dns.zone.Zone(origin=name, rdclass=RdataClass.IN)
-        self._min_ttl = int(self.config.min_ttl.total_seconds())
         self.reverse_for = reverse_for
 
         self.reverse_for = reverse_for
 
-        self.zone_file = nsc.zone_dir / self.safe_name
-        self.state_file = nsc.state_dir / (self.safe_name + '.json')
+    def process(self) -> None:
+        pass
+
+
+class NscZonePrimary(NscZone):
+    zone: Zone
+    _min_ttl: int
+    zone_file: Path
+    state_file: Path
+    state: NscZoneState
+    prev_state: NscZoneState
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+        self.zone_type = ZoneType.primary
+        self.zone_file = self.nsc.zone_dir / self.safe_name
+        self.state_file = self.nsc.state_dir / (self.safe_name + '.json')
+
         self.state = NscZoneState()
         self.prev_state = NscZoneState()
         self.prev_state.load(self.state_file)
 
         self.state = NscZoneState()
         self.prev_state = NscZoneState()
         self.prev_state.load(self.state_file)
 
+        self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN)
+        self._min_ttl = int(self.config.min_ttl.total_seconds())
         self.update_soa()
 
     def update_soa(self) -> None:
         self.update_soa()
 
     def update_soa(self) -> None:
@@ -257,16 +283,14 @@ class NscZone:
             last_name = name
 
     def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
             last_name = name
 
     def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
-        # Called only for addresses from this reverse network
-        assert self.reverse_for is not None
+        assert isinstance(self.reverse_for, IPv4Network)
         parts = str(addr).split('.')
         parts = parts[self.reverse_for.prefixlen // 8:]
         name = '.'.join(reversed(parts))
         self.n(name).PTR(ptr_to)
 
     def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
         parts = str(addr).split('.')
         parts = parts[self.reverse_for.prefixlen // 8:]
         name = '.'.join(reversed(parts))
         self.n(name).PTR(ptr_to)
 
     def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
-        # Called only for addresses from this reverse network
-        assert self.reverse_for is not None
+        assert isinstance(self.reverse_for, IPv6Network)
         parts = addr.exploded.replace(':', "")
         parts = parts[self.reverse_for.prefixlen // 4:]
         name = '.'.join(reversed(parts))
         parts = addr.exploded.replace(':', "")
         parts = parts[self.reverse_for.prefixlen // 4:]
         name = '.'.join(reversed(parts))
@@ -293,8 +317,9 @@ class NscZone:
                     print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')
 
     def process(self) -> None:
                     print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')
 
     def process(self) -> None:
-        self.gen_hash()
-        self.gen_serial()
+        if self.zone_type == ZoneType.primary:
+            self.gen_hash()
+            self.gen_serial()
 
     def write_zone(self) -> None:
         self.update_soa()
 
     def write_zone(self) -> None:
         self.update_soa()
@@ -306,6 +331,20 @@ class NscZone:
     def write_state(self) -> None:
         self.state.save(self.state_file)
 
     def write_state(self) -> None:
         self.state.save(self.state_file)
 
+    def is_changed(self) -> bool:
+        return self.state.serial != self.prev_state.serial
+
+
+class NscZoneSecondary(NscZone):
+    primary_server: IPAddress
+    secondary_file: Path
+
+    def __init__(self, *args, primary_server=IPAddress, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.zone_type = ZoneType.secondary
+        self.primary_server = primary_server
+        self.secondary_file = self.nsc.secondary_dir / self.safe_name
+
 
 class Nsc:
     start_time: datetime
 
 class Nsc:
     start_time: datetime
@@ -316,8 +355,13 @@ class Nsc:
     root_dir: Path
     state_dir: Path
     zone_dir: Path
     root_dir: Path
     state_dir: Path
     zone_dir: Path
+    secondary_dir: Path
+    daemon: 'NscDaemon'  # Set by DaemonConfig class
 
 
-    def __init__(self, directory: str = '.', **kwargs) -> None:
+    def __init__(self,
+                 directory: str = '.',
+                 daemon: Optional['NscDaemon'] = None,
+                 **kwargs) -> None:
         self.start_time = datetime.now()
         self.zones = {}
         self.default_zone_config = NscZoneConfig(**kwargs)
         self.start_time = datetime.now()
         self.zones = {}
         self.default_zone_config = NscZoneConfig(**kwargs)
@@ -329,20 +373,41 @@ class Nsc:
         self.state_dir.mkdir(parents=True, exist_ok=True)
         self.zone_dir = self.root_dir / 'zone'
         self.zone_dir.mkdir(parents=True, exist_ok=True)
         self.state_dir.mkdir(parents=True, exist_ok=True)
         self.zone_dir = self.root_dir / 'zone'
         self.zone_dir.mkdir(parents=True, exist_ok=True)
-
-    def add_zone(self, *args, inherit_config: Optional[NscZoneConfig] = None, **kwargs) -> Zone:
+        self.secondary_dir = self.root_dir / 'secondary'
+        self.secondary_dir.mkdir(parents=True, exist_ok=True)
+
+        if daemon is None:
+            from nsconfig.daemon import NscDaemonNull
+            daemon = NscDaemonNull()
+        self.daemon = daemon
+        daemon.setup(self)
+
+    def add_zone(self,
+                 name: Optional[str] = None,
+                 reverse_for: str | IPNetwork | None = None,
+                 follow_primary: str | IPAddress | None = None,
+                 inherit_config: Optional[NscZoneConfig] = None,
+                 **kwargs) -> Zone:
         if inherit_config is None:
             inherit_config = self.default_zone_config
         if inherit_config is None:
             inherit_config = self.default_zone_config
-        z = NscZone(self, *args, inherit_config=inherit_config, **kwargs)
-        assert z.name not in self.zones
-        self.zones[z.name] = z
-        return z
 
 
-    def add_reverse_zone(self, net: str | IPNetwork, name: Optional[str] = None, **kwargs) -> Zone:
-        if not (isinstance(net, IPv4Network) or isinstance(net, IPv6Network)):
-            net = ip_network(net, strict=True)
-        name = name or self._reverse_zone_name(net)
-        return self.add_zone(name, reverse_for=net, **kwargs)
+        if reverse_for is not None:
+            if isinstance(reverse_for, str):
+                reverse_for = ip_network(reverse_for, strict=True)
+            name = name or self._reverse_zone_name(reverse_for)
+        assert name is not None
+        assert name not in self.zones
+
+        z: NscZone
+        if follow_primary is None:
+            z = NscZonePrimary(self, name, reverse_for=reverse_for, inherit_config=inherit_config, **kwargs)
+        else:
+            if isinstance(follow_primary, str):
+                follow_primary = ip_address(follow_primary)
+            z = NscZoneSecondary(self, name, reverse_for=reverse_for, primary_server=follow_primary, inherit_config=inherit_config, **kwargs)
+
+        self.zones[name] = z
+        return z
 
     def _reverse_zone_name(self, net: IPNetwork) -> str:
         if isinstance(net, IPv4Network):
 
     def _reverse_zone_name(self, net: IPNetwork) -> str:
         if isinstance(net, IPv4Network):
@@ -374,7 +439,7 @@ class Nsc:
 
     def fill_reverse(self) -> None:
         for z in self.zones.values():
 
     def fill_reverse(self) -> None:
         for z in self.zones.values():
-            if z.reverse_for is not None:
+            if isinstance(z, NscZonePrimary) and z.reverse_for is not None:
                 if isinstance(z.reverse_for, IPv4Network):
                     for addr4, ptr_list in self.ipv4_reverse.items():
                         if addr4 in z.reverse_for:
                 if isinstance(z.reverse_for, IPv4Network):
                     for addr4, ptr_list in self.ipv4_reverse.items():
                         if addr4 in z.reverse_for: