import argparse
+from pathlib import Path
+from texttable import Texttable
from nsconfig.core import Nsc
def do_test(nsc: Nsc) -> None:
- nsc.fill_reverse()
- nsc.dump()
+ test_dir = Path('test')
+ test_dir.mkdir(exist_ok=True)
+ for z in nsc.get_zones():
+ print(f'Zone: {z.name}')
+ print(f'Old serial: {z.prev_state.serial}')
+ print(f'Old hash: {z.prev_state.hash}')
+ print(f'New serial: {z.state.serial}')
+ print(f'New hash: {z.state.hash}')
+ out_file = test_dir / z.safe_name
+ print(f'Dumping to: {out_file}')
+ with open(out_file, 'w') as f:
+ z.dump(file=f)
+ print()
+
+
+def do_status(nsc: Nsc) -> None:
+ table = Texttable(max_width=0)
+ table.header(['Zone', 'Old serial', 'Old hash', 'New serial', 'New hash', 'S'])
+ table.set_cols_dtype(['t', 'i', 't', 'i', 't', 't'])
+ table.set_deco(Texttable.HEADER)
+
+ for z in nsc.get_zones():
+ if z.state.serial == z.prev_state.serial:
+ action = ""
+ else:
+ action = '*'
+ table.add_row([
+ z.name,
+ z.prev_state.serial,
+ z.prev_state.hash,
+ z.state.serial,
+ z.state.hash,
+ action,
+ ])
+
+ print(table.draw())
+
+
+def do_update(nsc: Nsc) -> None:
+ for z in nsc.get_zones():
+ if z.state.serial != z.prev_state.serial:
+ print(f'Updating zone {z.name} (serial {z.state.serial})')
+ z.write_zone()
+ z.write_state()
def main(nsc: Nsc) -> None:
test_parser = subparsers.add_parser('test', help='test new configuration', description='Test new configuration')
+ status_parser = subparsers.add_parser('status', help='list status of zones', description='List status of zones')
+
+ update_parser = subparsers.add_parser('update', help='update configuration', description='Update zone files and daemon configuration as needed')
+
args = parser.parse_args()
+ nsc.process()
+
if args.action == 'test':
do_test(nsc)
+ elif args.action == 'status':
+ do_status(nsc)
+ elif args.action == 'update':
+ do_update(nsc)
from collections import defaultdict
-from datetime import timedelta
+from datetime import datetime, timedelta
import dns.name
from dns.name import Name
from dns.node import Node
import dns.rdtypes.IN.A
import dns.rdtypes.IN.AAAA
from dns.zone import Zone
+import hashlib
from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
+import json
+from pathlib import Path
import socket
-from typing import Optional, Dict, List, Self, Tuple, DefaultDict
+import sys
+from typing import Optional, Dict, List, Self, Tuple, DefaultDict, TextIO
IPAddress = IPv4Address | IPv6Address
)
+class NscZoneState:
+ serial: int
+ hash: str
+
+ def __init__(self) -> None:
+ self.serial = 0
+ self.hash = 'none'
+
+ def load(self, file: Path) -> None:
+ try:
+ with open(file) as f:
+ js = json.load(f)
+ assert isinstance(js, dict)
+ if 'serial' in js:
+ self.serial = js['serial']
+ if 'hash' in js:
+ self.hash = js['hash']
+ except FileNotFoundError:
+ pass
+
+ def save(self, file: Path) -> None:
+ new_file = Path(str(file) + '.new')
+ with open(new_file, 'w') as f:
+ js = {
+ 'serial': self.serial,
+ 'hash': self.hash,
+ }
+ json.dump(js, f, indent=4, sort_keys=True)
+ new_file.replace(file)
+
+
class NscZone:
nsc: 'Nsc'
name: str
+ safe_name: str # For use in file names
zone: Zone
_min_ttl: int
reverse_for: Optional[IPNetwork]
+ zone_file: Path
+ state_file: Path
+ state: NscZoneState
+ prev_state: NscZoneState
def __init__(self, nsc: 'Nsc', name: str, reverse_for: Optional[IPNetwork] = None, **kwargs) -> None:
self.nsc = nsc
self.name = name
+ self.safe_name = name.replace('/', '@')
self.config = NscZoneConfig(**kwargs).finalize()
self.zone = dns.zone.Zone(origin=name, rdclass=RdataClass.IN)
self._min_ttl = int(self.config.min_ttl.total_seconds())
self.reverse_for = reverse_for
+ self.zone_file = nsc.zone_dir / self.safe_name
+ self.state_file = nsc.state_dir / (self.safe_name + '.json')
+ self.state = NscZoneState()
+ self.prev_state = NscZoneState()
+ self.prev_state.load(self.state_file)
+
+ self.update_soa()
+
+ def update_soa(self) -> None:
conf = self.config
- root = self[""]
- root._add(
- dns.rdtypes.ANY.SOA.SOA(
- RdataClass.IN, RdataType.SOA,
- mname=conf.origin_server,
- rname=conf.admin_email.replace('@', '.'), # FIXME: names with dots
- serial=12345,
- refresh=int(conf.refresh.total_seconds()),
- retry=int(conf.retry.total_seconds()),
- expire=int(conf.expire.total_seconds()),
- minimum=int(conf.min_ttl.total_seconds()),
- )
+ soa = dns.rdtypes.ANY.SOA.SOA(
+ RdataClass.IN, RdataType.SOA,
+ mname=conf.origin_server,
+ rname=conf.admin_email.replace('@', '.'), # FIXME: names with dots
+ serial=self.state.serial,
+ refresh=int(conf.refresh.total_seconds()),
+ retry=int(conf.retry.total_seconds()),
+ expire=int(conf.expire.total_seconds()),
+ minimum=int(conf.min_ttl.total_seconds()),
)
+ self.zone.delete_rdataset("", RdataType.SOA)
+ self[""]._add(soa)
def n(self, name: str) -> NscNode:
return NscNode(self, name)
n.A(*args, reverse=reverse)
return n
- def dump(self) -> None:
+ def dump(self, file: Optional[TextIO] = None) -> None:
# Could use self.zone.to_file(sys.stdout), but we want better formatting
- print(f'; Zone file for {self.name}')
+ file = file or sys.stdout
+ file.write(f'; Zone file for {self.name}\n\n')
last_name = None
for name, ttl, rec in self.zone.iterate_rdatas():
if name == last_name:
print_name = ""
else:
print_name = name
- print(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}')
+ file.write(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
last_name = name
def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
name = '.'.join(reversed(parts))
self.n(name).PTR(ptr_to)
+ def gen_hash(self) -> None:
+ sha = hashlib.sha1()
+ for name, ttl, rec in self.zone.iterate_rdatas():
+ text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
+ sha.update(text.encode('us-ascii'))
+ self.state.hash = sha.hexdigest()[:16]
+
+ def gen_serial(self) -> None:
+ prev = self.prev_state.serial
+ if self.state.hash == self.prev_state.hash and prev > 0:
+ self.state.serial = self.prev_state.serial
+ else:
+ base = int(self.nsc.start_time.strftime('%Y%m%d00'))
+ if prev <= base:
+ self.state.serial = base + 1
+ else:
+ self.state.serial = prev + 1
+ if prev >= base + 99:
+ print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')
+
+ def process(self) -> None:
+ self.gen_hash()
+ self.gen_serial()
+
+ def write_zone(self) -> None:
+ self.update_soa()
+ new_file = Path(str(self.zone_file) + '.new')
+ with open(new_file, 'w') as f:
+ self.dump(file=f)
+ new_file.replace(self.zone_file)
+
+ def write_state(self) -> None:
+ self.state.save(self.state_file)
+
class Nsc:
+ start_time: datetime
zones: Dict[str, NscZone]
default_zone_config: NscZoneConfig
ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
+ root_dir: Path
+ state_dir: Path
+ zone_dir: Path
- def __init__(self, **kwargs) -> None:
+ def __init__(self, directory: str = '.', **kwargs) -> None:
+ self.start_time = datetime.now()
self.zones = {}
self.default_zone_config = NscZoneConfig(**kwargs)
self.ipv4_reverse = defaultdict(list)
self.ipv6_reverse = defaultdict(list)
+ self.root_dir = Path(directory)
+ self.state_dir = self.root_dir / 'state'
+ self.state_dir.mkdir(parents=True, exist_ok=True)
+ self.zone_dir = self.root_dir / 'zone'
+ self.zone_dir.mkdir(parents=True, exist_ok=True)
+
def add_zone(self, *args, inherit_config: Optional[NscZoneConfig] = None, **kwargs) -> Zone:
if inherit_config is None:
inherit_config = self.default_zone_config
for ptr_to in ptr_list:
z._add_ipv6_reverse(addr6, ptr_to)
- def dump(self) -> None:
- for z in self.zones.values():
- z.dump()
+ def get_zones(self) -> List[NscZone]:
+ return [self.zones[k] for k in sorted(self.zones.keys())]
+
+ def process(self) -> None:
+ self.fill_reverse()
+ for z in self.get_zones():
+ z.process()