From e0266ad53b2241e01e20748847980dd09b363f7a Mon Sep 17 00:00:00 2001 From: Martin Mares Date: Sun, 21 Apr 2024 00:43:49 +0200 Subject: [PATCH] Updating zones --- TODO | 3 + nsconfig/cli.py | 58 ++++++++++++++++++- nsconfig/core.py | 141 ++++++++++++++++++++++++++++++++++++++++------- 3 files changed, 179 insertions(+), 23 deletions(-) diff --git a/TODO b/TODO index aa85749..eb31acb 100644 --- a/TODO +++ b/TODO @@ -1,3 +1,6 @@ - Names with dots +- E-mail addresses with dots in SOA - Classless reverse delegation - Blackhole zones +- Secondary zones +- Formatting of warnings diff --git a/nsconfig/cli.py b/nsconfig/cli.py index ae3e976..241ede4 100644 --- a/nsconfig/cli.py +++ b/nsconfig/cli.py @@ -1,11 +1,55 @@ import argparse +from pathlib import Path +from texttable import Texttable from nsconfig.core import Nsc def do_test(nsc: Nsc) -> None: - nsc.fill_reverse() - nsc.dump() + test_dir = Path('test') + test_dir.mkdir(exist_ok=True) + for z in nsc.get_zones(): + print(f'Zone: {z.name}') + print(f'Old serial: {z.prev_state.serial}') + print(f'Old hash: {z.prev_state.hash}') + print(f'New serial: {z.state.serial}') + print(f'New hash: {z.state.hash}') + out_file = test_dir / z.safe_name + print(f'Dumping to: {out_file}') + with open(out_file, 'w') as f: + z.dump(file=f) + print() + + +def do_status(nsc: Nsc) -> None: + table = Texttable(max_width=0) + table.header(['Zone', 'Old serial', 'Old hash', 'New serial', 'New hash', 'S']) + table.set_cols_dtype(['t', 'i', 't', 'i', 't', 't']) + table.set_deco(Texttable.HEADER) + + for z in nsc.get_zones(): + if z.state.serial == z.prev_state.serial: + action = "" + else: + action = '*' + table.add_row([ + z.name, + z.prev_state.serial, + z.prev_state.hash, + z.state.serial, + z.state.hash, + action, + ]) + + print(table.draw()) + + +def do_update(nsc: Nsc) -> None: + for z in nsc.get_zones(): + if z.state.serial != z.prev_state.serial: + print(f'Updating zone {z.name} (serial {z.state.serial})') + z.write_zone() + z.write_state() def main(nsc: Nsc) -> None: @@ -14,7 +58,17 @@ def main(nsc: Nsc) -> None: test_parser = subparsers.add_parser('test', help='test new configuration', description='Test new configuration') + status_parser = subparsers.add_parser('status', help='list status of zones', description='List status of zones') + + update_parser = subparsers.add_parser('update', help='update configuration', description='Update zone files and daemon configuration as needed') + args = parser.parse_args() + nsc.process() + if args.action == 'test': do_test(nsc) + elif args.action == 'status': + do_status(nsc) + elif args.action == 'update': + do_update(nsc) diff --git a/nsconfig/core.py b/nsconfig/core.py index 86f56ae..ce19bd9 100644 --- a/nsconfig/core.py +++ b/nsconfig/core.py @@ -1,5 +1,5 @@ from collections import defaultdict -from datetime import timedelta +from datetime import datetime, timedelta import dns.name from dns.name import Name from dns.node import Node @@ -14,9 +14,13 @@ import dns.rdtypes.ANY.TXT import dns.rdtypes.IN.A import dns.rdtypes.IN.AAAA from dns.zone import Zone +import hashlib from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network +import json +from pathlib import Path import socket -from typing import Optional, Dict, List, Self, Tuple, DefaultDict +import sys +from typing import Optional, Dict, List, Self, Tuple, DefaultDict, TextIO IPAddress = IPv4Address | IPv6Address @@ -153,35 +157,80 @@ NscZoneConfig.default_config = NscZoneConfig( ) +class NscZoneState: + serial: int + hash: str + + def __init__(self) -> None: + self.serial = 0 + self.hash = 'none' + + def load(self, file: Path) -> None: + try: + with open(file) as f: + js = json.load(f) + assert isinstance(js, dict) + if 'serial' in js: + self.serial = js['serial'] + if 'hash' in js: + self.hash = js['hash'] + except FileNotFoundError: + pass + + def save(self, file: Path) -> None: + new_file = Path(str(file) + '.new') + with open(new_file, 'w') as f: + js = { + 'serial': self.serial, + 'hash': self.hash, + } + json.dump(js, f, indent=4, sort_keys=True) + new_file.replace(file) + + class NscZone: nsc: 'Nsc' name: str + safe_name: str # For use in file names zone: Zone _min_ttl: int reverse_for: Optional[IPNetwork] + zone_file: Path + state_file: Path + state: NscZoneState + prev_state: NscZoneState def __init__(self, nsc: 'Nsc', name: str, reverse_for: Optional[IPNetwork] = None, **kwargs) -> None: self.nsc = nsc self.name = name + self.safe_name = name.replace('/', '@') self.config = NscZoneConfig(**kwargs).finalize() self.zone = dns.zone.Zone(origin=name, rdclass=RdataClass.IN) self._min_ttl = int(self.config.min_ttl.total_seconds()) self.reverse_for = reverse_for + self.zone_file = nsc.zone_dir / self.safe_name + self.state_file = nsc.state_dir / (self.safe_name + '.json') + self.state = NscZoneState() + self.prev_state = NscZoneState() + self.prev_state.load(self.state_file) + + self.update_soa() + + def update_soa(self) -> None: conf = self.config - root = self[""] - root._add( - dns.rdtypes.ANY.SOA.SOA( - RdataClass.IN, RdataType.SOA, - mname=conf.origin_server, - rname=conf.admin_email.replace('@', '.'), # FIXME: names with dots - serial=12345, - refresh=int(conf.refresh.total_seconds()), - retry=int(conf.retry.total_seconds()), - expire=int(conf.expire.total_seconds()), - minimum=int(conf.min_ttl.total_seconds()), - ) + soa = dns.rdtypes.ANY.SOA.SOA( + RdataClass.IN, RdataType.SOA, + mname=conf.origin_server, + rname=conf.admin_email.replace('@', '.'), # FIXME: names with dots + serial=self.state.serial, + refresh=int(conf.refresh.total_seconds()), + retry=int(conf.retry.total_seconds()), + expire=int(conf.expire.total_seconds()), + minimum=int(conf.min_ttl.total_seconds()), ) + self.zone.delete_rdataset("", RdataType.SOA) + self[""]._add(soa) def n(self, name: str) -> NscNode: return NscNode(self, name) @@ -194,16 +243,17 @@ class NscZone: n.A(*args, reverse=reverse) return n - def dump(self) -> None: + def dump(self, file: Optional[TextIO] = None) -> None: # Could use self.zone.to_file(sys.stdout), but we want better formatting - print(f'; Zone file for {self.name}') + file = file or sys.stdout + file.write(f'; Zone file for {self.name}\n\n') last_name = None for name, ttl, rec in self.zone.iterate_rdatas(): if name == last_name: print_name = "" else: print_name = name - print(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}') + file.write(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n') last_name = name def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None: @@ -222,19 +272,64 @@ class NscZone: name = '.'.join(reversed(parts)) self.n(name).PTR(ptr_to) + def gen_hash(self) -> None: + sha = hashlib.sha1() + for name, ttl, rec in self.zone.iterate_rdatas(): + text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n' + sha.update(text.encode('us-ascii')) + self.state.hash = sha.hexdigest()[:16] + + def gen_serial(self) -> None: + prev = self.prev_state.serial + if self.state.hash == self.prev_state.hash and prev > 0: + self.state.serial = self.prev_state.serial + else: + base = int(self.nsc.start_time.strftime('%Y%m%d00')) + if prev <= base: + self.state.serial = base + 1 + else: + self.state.serial = prev + 1 + if prev >= base + 99: + print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}') + + def process(self) -> None: + self.gen_hash() + self.gen_serial() + + def write_zone(self) -> None: + self.update_soa() + new_file = Path(str(self.zone_file) + '.new') + with open(new_file, 'w') as f: + self.dump(file=f) + new_file.replace(self.zone_file) + + def write_state(self) -> None: + self.state.save(self.state_file) + class Nsc: + start_time: datetime zones: Dict[str, NscZone] default_zone_config: NscZoneConfig ipv4_reverse: DefaultDict[IPv4Address, List[Name]] ipv6_reverse: DefaultDict[IPv6Address, List[Name]] + root_dir: Path + state_dir: Path + zone_dir: Path - def __init__(self, **kwargs) -> None: + def __init__(self, directory: str = '.', **kwargs) -> None: + self.start_time = datetime.now() self.zones = {} self.default_zone_config = NscZoneConfig(**kwargs) self.ipv4_reverse = defaultdict(list) self.ipv6_reverse = defaultdict(list) + self.root_dir = Path(directory) + self.state_dir = self.root_dir / 'state' + self.state_dir.mkdir(parents=True, exist_ok=True) + self.zone_dir = self.root_dir / 'zone' + self.zone_dir.mkdir(parents=True, exist_ok=True) + def add_zone(self, *args, inherit_config: Optional[NscZoneConfig] = None, **kwargs) -> Zone: if inherit_config is None: inherit_config = self.default_zone_config @@ -291,6 +386,10 @@ class Nsc: for ptr_to in ptr_list: z._add_ipv6_reverse(addr6, ptr_to) - def dump(self) -> None: - for z in self.zones.values(): - z.dump() + def get_zones(self) -> List[NscZone]: + return [self.zones[k] for k in sorted(self.zones.keys())] + + def process(self) -> None: + self.fill_reverse() + for z in self.get_zones(): + z.process() -- 2.39.2