]> mj.ucw.cz Git - pynsc.git/blobdiff - nsconfig/core.py
More daemon configuration
[pynsc.git] / nsconfig / core.py
index 86f56ae90f7cd9b8e403cb773f095819fa35e7ab..8a3bb74f694453655887758f39063fa20182779e 100644 (file)
@@ -1,5 +1,5 @@
 from collections import defaultdict
-from datetime import timedelta
+from datetime import datetime, timedelta
 import dns.name
 from dns.name import Name
 from dns.node import Node
@@ -14,9 +14,18 @@ import dns.rdtypes.ANY.TXT
 import dns.rdtypes.IN.A
 import dns.rdtypes.IN.AAAA
 from dns.zone import Zone
+from enum import Enum, auto
+import hashlib
 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
+import json
+from pathlib import Path
 import socket
-from typing import Optional, Dict, List, Self, Tuple, DefaultDict
+import sys
+from typing import Optional, Dict, List, Self, Tuple, DefaultDict, TextIO, TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+    from nsconfig.daemon import NscDaemon
 
 
 IPAddress = IPv4Address | IPv6Address
@@ -25,12 +34,12 @@ IPAddr = str | IPAddress | List[str | IPAddress]
 
 
 class NscNode:
-    nsc_zone: 'NscZone'
+    nsc_zone: 'NscZonePrimary'
     name: str
     node: Node
     _ttl: int
 
-    def __init__(self, nsc_zone: 'NscZone', name: str) -> None:
+    def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None:
         self.nsc_zone = nsc_zone
         self.name = name
         self.node = nsc_zone.zone.find_node(name, create=True)
@@ -153,35 +162,101 @@ NscZoneConfig.default_config = NscZoneConfig(
 )
 
 
+class NscZoneState:
+    serial: int
+    hash: str
+
+    def __init__(self) -> None:
+        self.serial = 0
+        self.hash = 'none'
+
+    def load(self, file: Path) -> None:
+        try:
+            with open(file) as f:
+                js = json.load(f)
+                assert isinstance(js, dict)
+                if 'serial' in js:
+                    self.serial = js['serial']
+                if 'hash' in js:
+                    self.hash = js['hash']
+        except FileNotFoundError:
+            pass
+
+    def save(self, file: Path) -> None:
+        new_file = Path(str(file) + '.new')
+        with open(new_file, 'w') as f:
+            js = {
+                'serial': self.serial,
+                'hash': self.hash,
+            }
+            json.dump(js, f, indent=4, sort_keys=True)
+        new_file.replace(file)
+
+
+class ZoneType(Enum):
+    primary = auto()
+    secondary = auto()
+
+
 class NscZone:
     nsc: 'Nsc'
     name: str
-    zone: Zone
-    _min_ttl: int
+    safe_name: str                          # For use in file names
+    zone_type: ZoneType
     reverse_for: Optional[IPNetwork]
 
-    def __init__(self, nsc: 'Nsc', name: str, reverse_for: Optional[IPNetwork] = None, **kwargs) -> None:
+    def __init__(self,
+                 nsc: 'Nsc',
+                 name: str,
+                 reverse_for: Optional[IPNetwork],
+                 **kwargs) -> None:
         self.nsc = nsc
         self.name = name
+        self.safe_name = name.replace('/', '@')
         self.config = NscZoneConfig(**kwargs).finalize()
-        self.zone = dns.zone.Zone(origin=name, rdclass=RdataClass.IN)
-        self._min_ttl = int(self.config.min_ttl.total_seconds())
         self.reverse_for = reverse_for
 
+    def process(self) -> None:
+        pass
+
+
+class NscZonePrimary(NscZone):
+    zone: Zone
+    _min_ttl: int
+    zone_file: Path
+    state_file: Path
+    state: NscZoneState
+    prev_state: NscZoneState
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+        self.zone_type = ZoneType.primary
+        self.zone_file = self.nsc.zone_dir / self.safe_name
+        self.state_file = self.nsc.state_dir / (self.safe_name + '.json')
+
+        self.state = NscZoneState()
+        self.prev_state = NscZoneState()
+        self.prev_state.load(self.state_file)
+
+        self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN)
+        self._min_ttl = int(self.config.min_ttl.total_seconds())
+        self.update_soa()
+
+    def update_soa(self) -> None:
         conf = self.config
-        root = self[""]
-        root._add(
-            dns.rdtypes.ANY.SOA.SOA(
-                RdataClass.IN, RdataType.SOA,
-                mname=conf.origin_server,
-                rname=conf.admin_email.replace('@', '.'),   # FIXME: names with dots
-                serial=12345,
-                refresh=int(conf.refresh.total_seconds()),
-                retry=int(conf.retry.total_seconds()),
-                expire=int(conf.expire.total_seconds()),
-                minimum=int(conf.min_ttl.total_seconds()),
-            )
+        soa = dns.rdtypes.ANY.SOA.SOA(
+            RdataClass.IN, RdataType.SOA,
+            mname=conf.origin_server,
+            rname=conf.admin_email.replace('@', '.'),   # FIXME: names with dots
+            serial=self.state.serial,
+            refresh=int(conf.refresh.total_seconds()),
+            retry=int(conf.retry.total_seconds()),
+            expire=int(conf.expire.total_seconds()),
+            minimum=int(conf.min_ttl.total_seconds()),
         )
+        self.zone.delete_rdataset("", RdataType.SOA)
+        self[""]._add(soa)
 
     def n(self, name: str) -> NscNode:
         return NscNode(self, name)
@@ -194,60 +269,145 @@ class NscZone:
         n.A(*args, reverse=reverse)
         return n
 
-    def dump(self) -> None:
+    def dump(self, file: Optional[TextIO] = None) -> None:
         # Could use self.zone.to_file(sys.stdout), but we want better formatting
-        print(f'; Zone file for {self.name}')
+        file = file or sys.stdout
+        file.write(f'; Zone file for {self.name}\n\n')
         last_name = None
         for name, ttl, rec in self.zone.iterate_rdatas():
             if name == last_name:
                 print_name = ""
             else:
                 print_name = name
-            print(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}')
+            file.write(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
             last_name = name
 
     def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
-        # Called only for addresses from this reverse network
-        assert self.reverse_for is not None
+        assert isinstance(self.reverse_for, IPv4Network)
         parts = str(addr).split('.')
         parts = parts[self.reverse_for.prefixlen // 8:]
         name = '.'.join(reversed(parts))
         self.n(name).PTR(ptr_to)
 
     def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
-        # Called only for addresses from this reverse network
-        assert self.reverse_for is not None
+        assert isinstance(self.reverse_for, IPv6Network)
         parts = addr.exploded.replace(':', "")
         parts = parts[self.reverse_for.prefixlen // 4:]
         name = '.'.join(reversed(parts))
         self.n(name).PTR(ptr_to)
 
+    def gen_hash(self) -> None:
+        sha = hashlib.sha1()
+        for name, ttl, rec in self.zone.iterate_rdatas():
+            text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
+            sha.update(text.encode('us-ascii'))
+        self.state.hash = sha.hexdigest()[:16]
+
+    def gen_serial(self) -> None:
+        prev = self.prev_state.serial
+        if self.state.hash == self.prev_state.hash and prev > 0:
+            self.state.serial = self.prev_state.serial
+        else:
+            base = int(self.nsc.start_time.strftime('%Y%m%d00'))
+            if prev <= base:
+                self.state.serial = base + 1
+            else:
+                self.state.serial = prev + 1
+                if prev >= base + 99:
+                    print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')
+
+    def process(self) -> None:
+        if self.zone_type == ZoneType.primary:
+            self.gen_hash()
+            self.gen_serial()
+
+    def write_zone(self) -> None:
+        self.update_soa()
+        new_file = Path(str(self.zone_file) + '.new')
+        with open(new_file, 'w') as f:
+            self.dump(file=f)
+        new_file.replace(self.zone_file)
+
+    def write_state(self) -> None:
+        self.state.save(self.state_file)
+
+    def is_changed(self) -> bool:
+        return self.state.serial != self.prev_state.serial
+
+
+class NscZoneSecondary(NscZone):
+    primary_server: IPAddress
+    secondary_file: Path
+
+    def __init__(self, *args, primary_server=IPAddress, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.zone_type = ZoneType.secondary
+        self.primary_server = primary_server
+        self.secondary_file = self.nsc.secondary_dir / self.safe_name
+
 
 class Nsc:
+    start_time: datetime
     zones: Dict[str, NscZone]
     default_zone_config: NscZoneConfig
     ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
     ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
+    root_dir: Path
+    state_dir: Path
+    zone_dir: Path
+    secondary_dir: Path
+    daemon: 'NscDaemon'  # Set by DaemonConfig class
 
-    def __init__(self, **kwargs) -> None:
+    def __init__(self,
+                 directory: str = '.',
+                 daemon: Optional['NscDaemon'] = None,
+                 **kwargs) -> None:
+        self.start_time = datetime.now()
         self.zones = {}
         self.default_zone_config = NscZoneConfig(**kwargs)
         self.ipv4_reverse = defaultdict(list)
         self.ipv6_reverse = defaultdict(list)
 
-    def add_zone(self, *args, inherit_config: Optional[NscZoneConfig] = None, **kwargs) -> Zone:
+        self.root_dir = Path(directory)
+        self.state_dir = self.root_dir / 'state'
+        self.state_dir.mkdir(parents=True, exist_ok=True)
+        self.zone_dir = self.root_dir / 'zone'
+        self.zone_dir.mkdir(parents=True, exist_ok=True)
+        self.secondary_dir = self.root_dir / 'secondary'
+        self.secondary_dir.mkdir(parents=True, exist_ok=True)
+
+        if daemon is None:
+            from nsconfig.daemon import NscDaemonNull
+            daemon = NscDaemonNull()
+        self.daemon = daemon
+        daemon.setup(self)
+
+    def add_zone(self,
+                 name: Optional[str] = None,
+                 reverse_for: str | IPNetwork | None = None,
+                 follow_primary: str | IPAddress | None = None,
+                 inherit_config: Optional[NscZoneConfig] = None,
+                 **kwargs) -> Zone:
         if inherit_config is None:
             inherit_config = self.default_zone_config
-        z = NscZone(self, *args, inherit_config=inherit_config, **kwargs)
-        assert z.name not in self.zones
-        self.zones[z.name] = z
-        return z
 
-    def add_reverse_zone(self, net: str | IPNetwork, name: Optional[str] = None, **kwargs) -> Zone:
-        if not (isinstance(net, IPv4Network) or isinstance(net, IPv6Network)):
-            net = ip_network(net, strict=True)
-        name = name or self._reverse_zone_name(net)
-        return self.add_zone(name, reverse_for=net, **kwargs)
+        if reverse_for is not None:
+            if isinstance(reverse_for, str):
+                reverse_for = ip_network(reverse_for, strict=True)
+            name = name or self._reverse_zone_name(reverse_for)
+        assert name is not None
+        assert name not in self.zones
+
+        z: NscZone
+        if follow_primary is None:
+            z = NscZonePrimary(self, name, reverse_for=reverse_for, inherit_config=inherit_config, **kwargs)
+        else:
+            if isinstance(follow_primary, str):
+                follow_primary = ip_address(follow_primary)
+            z = NscZoneSecondary(self, name, reverse_for=reverse_for, primary_server=follow_primary, inherit_config=inherit_config, **kwargs)
+
+        self.zones[name] = z
+        return z
 
     def _reverse_zone_name(self, net: IPNetwork) -> str:
         if isinstance(net, IPv4Network):
@@ -279,7 +439,7 @@ class Nsc:
 
     def fill_reverse(self) -> None:
         for z in self.zones.values():
-            if z.reverse_for is not None:
+            if isinstance(z, NscZonePrimary) and z.reverse_for is not None:
                 if isinstance(z.reverse_for, IPv4Network):
                     for addr4, ptr_list in self.ipv4_reverse.items():
                         if addr4 in z.reverse_for:
@@ -291,6 +451,10 @@ class Nsc:
                             for ptr_to in ptr_list:
                                 z._add_ipv6_reverse(addr6, ptr_to)
 
-    def dump(self) -> None:
-        for z in self.zones.values():
-            z.dump()
+    def get_zones(self) -> List[NscZone]:
+        return [self.zones[k] for k in sorted(self.zones.keys())]
+
+    def process(self) -> None:
+        self.fill_reverse()
+        for z in self.get_zones():
+            z.process()