]> mj.ucw.cz Git - pynsc.git/commitdiff
Updating zones
authorMartin Mares <mj@ucw.cz>
Sat, 20 Apr 2024 22:43:49 +0000 (00:43 +0200)
committerMartin Mares <mj@ucw.cz>
Sat, 20 Apr 2024 22:43:49 +0000 (00:43 +0200)
TODO
nsconfig/cli.py
nsconfig/core.py

diff --git a/TODO b/TODO
index aa85749ad4b578bb6d5fe642add0f23bf4a01fb1..eb31acb75e4d3692f965f508b86d9ff62db5a78b 100644 (file)
--- a/TODO
+++ b/TODO
@@ -1,3 +1,6 @@
 - Names with dots
+- E-mail addresses with dots in SOA
 - Classless reverse delegation
 - Blackhole zones
+- Secondary zones
+- Formatting of warnings
index ae3e97658e4177e3e4c02e61bbfdb47614abca8f..241ede47d051a92559b75cc36b1ca12fe750837c 100644 (file)
@@ -1,11 +1,55 @@
 import argparse
+from pathlib import Path
+from texttable import Texttable
 
 from nsconfig.core import Nsc
 
 
 def do_test(nsc: Nsc) -> None:
-    nsc.fill_reverse()
-    nsc.dump()
+    test_dir = Path('test')
+    test_dir.mkdir(exist_ok=True)
+    for z in nsc.get_zones():
+        print(f'Zone:        {z.name}')
+        print(f'Old serial:  {z.prev_state.serial}')
+        print(f'Old hash:    {z.prev_state.hash}')
+        print(f'New serial:  {z.state.serial}')
+        print(f'New hash:    {z.state.hash}')
+        out_file = test_dir / z.safe_name
+        print(f'Dumping to:  {out_file}')
+        with open(out_file, 'w') as f:
+            z.dump(file=f)
+        print()
+
+
+def do_status(nsc: Nsc) -> None:
+    table = Texttable(max_width=0)
+    table.header(['Zone', 'Old serial', 'Old hash', 'New serial', 'New hash', 'S'])
+    table.set_cols_dtype(['t', 'i', 't', 'i', 't', 't'])
+    table.set_deco(Texttable.HEADER)
+
+    for z in nsc.get_zones():
+        if z.state.serial == z.prev_state.serial:
+            action = ""
+        else:
+            action = '*'
+        table.add_row([
+            z.name,
+            z.prev_state.serial,
+            z.prev_state.hash,
+            z.state.serial,
+            z.state.hash,
+            action,
+        ])
+
+    print(table.draw())
+
+
+def do_update(nsc: Nsc) -> None:
+    for z in nsc.get_zones():
+        if z.state.serial != z.prev_state.serial:
+            print(f'Updating zone {z.name} (serial {z.state.serial})')
+            z.write_zone()
+            z.write_state()
 
 
 def main(nsc: Nsc) -> None:
@@ -14,7 +58,17 @@ def main(nsc: Nsc) -> None:
 
     test_parser = subparsers.add_parser('test', help='test new configuration', description='Test new configuration')
 
+    status_parser = subparsers.add_parser('status', help='list status of zones', description='List status of zones')
+
+    update_parser = subparsers.add_parser('update', help='update configuration', description='Update zone files and daemon configuration as needed')
+
     args = parser.parse_args()
 
+    nsc.process()
+
     if args.action == 'test':
         do_test(nsc)
+    elif args.action == 'status':
+        do_status(nsc)
+    elif args.action == 'update':
+        do_update(nsc)
index 86f56ae90f7cd9b8e403cb773f095819fa35e7ab..ce19bd98d73e759b2d9de2b1c3f30433a4954243 100644 (file)
@@ -1,5 +1,5 @@
 from collections import defaultdict
-from datetime import timedelta
+from datetime import datetime, timedelta
 import dns.name
 from dns.name import Name
 from dns.node import Node
@@ -14,9 +14,13 @@ import dns.rdtypes.ANY.TXT
 import dns.rdtypes.IN.A
 import dns.rdtypes.IN.AAAA
 from dns.zone import Zone
+import hashlib
 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
+import json
+from pathlib import Path
 import socket
-from typing import Optional, Dict, List, Self, Tuple, DefaultDict
+import sys
+from typing import Optional, Dict, List, Self, Tuple, DefaultDict, TextIO
 
 
 IPAddress = IPv4Address | IPv6Address
@@ -153,35 +157,80 @@ NscZoneConfig.default_config = NscZoneConfig(
 )
 
 
+class NscZoneState:
+    serial: int
+    hash: str
+
+    def __init__(self) -> None:
+        self.serial = 0
+        self.hash = 'none'
+
+    def load(self, file: Path) -> None:
+        try:
+            with open(file) as f:
+                js = json.load(f)
+                assert isinstance(js, dict)
+                if 'serial' in js:
+                    self.serial = js['serial']
+                if 'hash' in js:
+                    self.hash = js['hash']
+        except FileNotFoundError:
+            pass
+
+    def save(self, file: Path) -> None:
+        new_file = Path(str(file) + '.new')
+        with open(new_file, 'w') as f:
+            js = {
+                'serial': self.serial,
+                'hash': self.hash,
+            }
+            json.dump(js, f, indent=4, sort_keys=True)
+        new_file.replace(file)
+
+
 class NscZone:
     nsc: 'Nsc'
     name: str
+    safe_name: str      # For use in file names
     zone: Zone
     _min_ttl: int
     reverse_for: Optional[IPNetwork]
+    zone_file: Path
+    state_file: Path
+    state: NscZoneState
+    prev_state: NscZoneState
 
     def __init__(self, nsc: 'Nsc', name: str, reverse_for: Optional[IPNetwork] = None, **kwargs) -> None:
         self.nsc = nsc
         self.name = name
+        self.safe_name = name.replace('/', '@')
         self.config = NscZoneConfig(**kwargs).finalize()
         self.zone = dns.zone.Zone(origin=name, rdclass=RdataClass.IN)
         self._min_ttl = int(self.config.min_ttl.total_seconds())
         self.reverse_for = reverse_for
 
+        self.zone_file = nsc.zone_dir / self.safe_name
+        self.state_file = nsc.state_dir / (self.safe_name + '.json')
+        self.state = NscZoneState()
+        self.prev_state = NscZoneState()
+        self.prev_state.load(self.state_file)
+
+        self.update_soa()
+
+    def update_soa(self) -> None:
         conf = self.config
-        root = self[""]
-        root._add(
-            dns.rdtypes.ANY.SOA.SOA(
-                RdataClass.IN, RdataType.SOA,
-                mname=conf.origin_server,
-                rname=conf.admin_email.replace('@', '.'),   # FIXME: names with dots
-                serial=12345,
-                refresh=int(conf.refresh.total_seconds()),
-                retry=int(conf.retry.total_seconds()),
-                expire=int(conf.expire.total_seconds()),
-                minimum=int(conf.min_ttl.total_seconds()),
-            )
+        soa = dns.rdtypes.ANY.SOA.SOA(
+            RdataClass.IN, RdataType.SOA,
+            mname=conf.origin_server,
+            rname=conf.admin_email.replace('@', '.'),   # FIXME: names with dots
+            serial=self.state.serial,
+            refresh=int(conf.refresh.total_seconds()),
+            retry=int(conf.retry.total_seconds()),
+            expire=int(conf.expire.total_seconds()),
+            minimum=int(conf.min_ttl.total_seconds()),
         )
+        self.zone.delete_rdataset("", RdataType.SOA)
+        self[""]._add(soa)
 
     def n(self, name: str) -> NscNode:
         return NscNode(self, name)
@@ -194,16 +243,17 @@ class NscZone:
         n.A(*args, reverse=reverse)
         return n
 
-    def dump(self) -> None:
+    def dump(self, file: Optional[TextIO] = None) -> None:
         # Could use self.zone.to_file(sys.stdout), but we want better formatting
-        print(f'; Zone file for {self.name}')
+        file = file or sys.stdout
+        file.write(f'; Zone file for {self.name}\n\n')
         last_name = None
         for name, ttl, rec in self.zone.iterate_rdatas():
             if name == last_name:
                 print_name = ""
             else:
                 print_name = name
-            print(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}')
+            file.write(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
             last_name = name
 
     def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
@@ -222,19 +272,64 @@ class NscZone:
         name = '.'.join(reversed(parts))
         self.n(name).PTR(ptr_to)
 
+    def gen_hash(self) -> None:
+        sha = hashlib.sha1()
+        for name, ttl, rec in self.zone.iterate_rdatas():
+            text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
+            sha.update(text.encode('us-ascii'))
+        self.state.hash = sha.hexdigest()[:16]
+
+    def gen_serial(self) -> None:
+        prev = self.prev_state.serial
+        if self.state.hash == self.prev_state.hash and prev > 0:
+            self.state.serial = self.prev_state.serial
+        else:
+            base = int(self.nsc.start_time.strftime('%Y%m%d00'))
+            if prev <= base:
+                self.state.serial = base + 1
+            else:
+                self.state.serial = prev + 1
+                if prev >= base + 99:
+                    print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')
+
+    def process(self) -> None:
+        self.gen_hash()
+        self.gen_serial()
+
+    def write_zone(self) -> None:
+        self.update_soa()
+        new_file = Path(str(self.zone_file) + '.new')
+        with open(new_file, 'w') as f:
+            self.dump(file=f)
+        new_file.replace(self.zone_file)
+
+    def write_state(self) -> None:
+        self.state.save(self.state_file)
+
 
 class Nsc:
+    start_time: datetime
     zones: Dict[str, NscZone]
     default_zone_config: NscZoneConfig
     ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
     ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
+    root_dir: Path
+    state_dir: Path
+    zone_dir: Path
 
-    def __init__(self, **kwargs) -> None:
+    def __init__(self, directory: str = '.', **kwargs) -> None:
+        self.start_time = datetime.now()
         self.zones = {}
         self.default_zone_config = NscZoneConfig(**kwargs)
         self.ipv4_reverse = defaultdict(list)
         self.ipv6_reverse = defaultdict(list)
 
+        self.root_dir = Path(directory)
+        self.state_dir = self.root_dir / 'state'
+        self.state_dir.mkdir(parents=True, exist_ok=True)
+        self.zone_dir = self.root_dir / 'zone'
+        self.zone_dir.mkdir(parents=True, exist_ok=True)
+
     def add_zone(self, *args, inherit_config: Optional[NscZoneConfig] = None, **kwargs) -> Zone:
         if inherit_config is None:
             inherit_config = self.default_zone_config
@@ -291,6 +386,10 @@ class Nsc:
                             for ptr_to in ptr_list:
                                 z._add_ipv6_reverse(addr6, ptr_to)
 
-    def dump(self) -> None:
-        for z in self.zones.values():
-            z.dump()
+    def get_zones(self) -> List[NscZone]:
+        return [self.zones[k] for k in sorted(self.zones.keys())]
+
+    def process(self) -> None:
+        self.fill_reverse()
+        for z in self.get_zones():
+            z.process()