X-Git-Url: http://mj.ucw.cz/gitweb/?a=blobdiff_plain;f=nsconfig%2Fcore.py;h=8a3bb74f694453655887758f39063fa20182779e;hb=e5c27231152e92900bab1ed9f8f0f47f31e4aadb;hp=86f56ae90f7cd9b8e403cb773f095819fa35e7ab;hpb=69619552720f2ab09ff03de6521ad8b54edc87ae;p=pynsc.git diff --git a/nsconfig/core.py b/nsconfig/core.py index 86f56ae..8a3bb74 100644 --- a/nsconfig/core.py +++ b/nsconfig/core.py @@ -1,5 +1,5 @@ from collections import defaultdict -from datetime import timedelta +from datetime import datetime, timedelta import dns.name from dns.name import Name from dns.node import Node @@ -14,9 +14,18 @@ import dns.rdtypes.ANY.TXT import dns.rdtypes.IN.A import dns.rdtypes.IN.AAAA from dns.zone import Zone +from enum import Enum, auto +import hashlib from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network +import json +from pathlib import Path import socket -from typing import Optional, Dict, List, Self, Tuple, DefaultDict +import sys +from typing import Optional, Dict, List, Self, Tuple, DefaultDict, TextIO, TYPE_CHECKING + + +if TYPE_CHECKING: + from nsconfig.daemon import NscDaemon IPAddress = IPv4Address | IPv6Address @@ -25,12 +34,12 @@ IPAddr = str | IPAddress | List[str | IPAddress] class NscNode: - nsc_zone: 'NscZone' + nsc_zone: 'NscZonePrimary' name: str node: Node _ttl: int - def __init__(self, nsc_zone: 'NscZone', name: str) -> None: + def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None: self.nsc_zone = nsc_zone self.name = name self.node = nsc_zone.zone.find_node(name, create=True) @@ -153,35 +162,101 @@ NscZoneConfig.default_config = NscZoneConfig( ) +class NscZoneState: + serial: int + hash: str + + def __init__(self) -> None: + self.serial = 0 + self.hash = 'none' + + def load(self, file: Path) -> None: + try: + with open(file) as f: + js = json.load(f) + assert isinstance(js, dict) + if 'serial' in js: + self.serial = js['serial'] + if 'hash' in js: + self.hash = js['hash'] + except FileNotFoundError: + pass + + def save(self, file: Path) -> None: + new_file = Path(str(file) + '.new') + with open(new_file, 'w') as f: + js = { + 'serial': self.serial, + 'hash': self.hash, + } + json.dump(js, f, indent=4, sort_keys=True) + new_file.replace(file) + + +class ZoneType(Enum): + primary = auto() + secondary = auto() + + class NscZone: nsc: 'Nsc' name: str - zone: Zone - _min_ttl: int + safe_name: str # For use in file names + zone_type: ZoneType reverse_for: Optional[IPNetwork] - def __init__(self, nsc: 'Nsc', name: str, reverse_for: Optional[IPNetwork] = None, **kwargs) -> None: + def __init__(self, + nsc: 'Nsc', + name: str, + reverse_for: Optional[IPNetwork], + **kwargs) -> None: self.nsc = nsc self.name = name + self.safe_name = name.replace('/', '@') self.config = NscZoneConfig(**kwargs).finalize() - self.zone = dns.zone.Zone(origin=name, rdclass=RdataClass.IN) - self._min_ttl = int(self.config.min_ttl.total_seconds()) self.reverse_for = reverse_for + def process(self) -> None: + pass + + +class NscZonePrimary(NscZone): + zone: Zone + _min_ttl: int + zone_file: Path + state_file: Path + state: NscZoneState + prev_state: NscZoneState + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.zone_type = ZoneType.primary + self.zone_file = self.nsc.zone_dir / self.safe_name + self.state_file = self.nsc.state_dir / (self.safe_name + '.json') + + self.state = NscZoneState() + self.prev_state = NscZoneState() + self.prev_state.load(self.state_file) + + self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN) + self._min_ttl = int(self.config.min_ttl.total_seconds()) + self.update_soa() + + def update_soa(self) -> None: conf = self.config - root = self[""] - root._add( - dns.rdtypes.ANY.SOA.SOA( - RdataClass.IN, RdataType.SOA, - mname=conf.origin_server, - rname=conf.admin_email.replace('@', '.'), # FIXME: names with dots - serial=12345, - refresh=int(conf.refresh.total_seconds()), - retry=int(conf.retry.total_seconds()), - expire=int(conf.expire.total_seconds()), - minimum=int(conf.min_ttl.total_seconds()), - ) + soa = dns.rdtypes.ANY.SOA.SOA( + RdataClass.IN, RdataType.SOA, + mname=conf.origin_server, + rname=conf.admin_email.replace('@', '.'), # FIXME: names with dots + serial=self.state.serial, + refresh=int(conf.refresh.total_seconds()), + retry=int(conf.retry.total_seconds()), + expire=int(conf.expire.total_seconds()), + minimum=int(conf.min_ttl.total_seconds()), ) + self.zone.delete_rdataset("", RdataType.SOA) + self[""]._add(soa) def n(self, name: str) -> NscNode: return NscNode(self, name) @@ -194,60 +269,145 @@ class NscZone: n.A(*args, reverse=reverse) return n - def dump(self) -> None: + def dump(self, file: Optional[TextIO] = None) -> None: # Could use self.zone.to_file(sys.stdout), but we want better formatting - print(f'; Zone file for {self.name}') + file = file or sys.stdout + file.write(f'; Zone file for {self.name}\n\n') last_name = None for name, ttl, rec in self.zone.iterate_rdatas(): if name == last_name: print_name = "" else: print_name = name - print(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}') + file.write(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n') last_name = name def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None: - # Called only for addresses from this reverse network - assert self.reverse_for is not None + assert isinstance(self.reverse_for, IPv4Network) parts = str(addr).split('.') parts = parts[self.reverse_for.prefixlen // 8:] name = '.'.join(reversed(parts)) self.n(name).PTR(ptr_to) def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None: - # Called only for addresses from this reverse network - assert self.reverse_for is not None + assert isinstance(self.reverse_for, IPv6Network) parts = addr.exploded.replace(':', "") parts = parts[self.reverse_for.prefixlen // 4:] name = '.'.join(reversed(parts)) self.n(name).PTR(ptr_to) + def gen_hash(self) -> None: + sha = hashlib.sha1() + for name, ttl, rec in self.zone.iterate_rdatas(): + text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n' + sha.update(text.encode('us-ascii')) + self.state.hash = sha.hexdigest()[:16] + + def gen_serial(self) -> None: + prev = self.prev_state.serial + if self.state.hash == self.prev_state.hash and prev > 0: + self.state.serial = self.prev_state.serial + else: + base = int(self.nsc.start_time.strftime('%Y%m%d00')) + if prev <= base: + self.state.serial = base + 1 + else: + self.state.serial = prev + 1 + if prev >= base + 99: + print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}') + + def process(self) -> None: + if self.zone_type == ZoneType.primary: + self.gen_hash() + self.gen_serial() + + def write_zone(self) -> None: + self.update_soa() + new_file = Path(str(self.zone_file) + '.new') + with open(new_file, 'w') as f: + self.dump(file=f) + new_file.replace(self.zone_file) + + def write_state(self) -> None: + self.state.save(self.state_file) + + def is_changed(self) -> bool: + return self.state.serial != self.prev_state.serial + + +class NscZoneSecondary(NscZone): + primary_server: IPAddress + secondary_file: Path + + def __init__(self, *args, primary_server=IPAddress, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.zone_type = ZoneType.secondary + self.primary_server = primary_server + self.secondary_file = self.nsc.secondary_dir / self.safe_name + class Nsc: + start_time: datetime zones: Dict[str, NscZone] default_zone_config: NscZoneConfig ipv4_reverse: DefaultDict[IPv4Address, List[Name]] ipv6_reverse: DefaultDict[IPv6Address, List[Name]] + root_dir: Path + state_dir: Path + zone_dir: Path + secondary_dir: Path + daemon: 'NscDaemon' # Set by DaemonConfig class - def __init__(self, **kwargs) -> None: + def __init__(self, + directory: str = '.', + daemon: Optional['NscDaemon'] = None, + **kwargs) -> None: + self.start_time = datetime.now() self.zones = {} self.default_zone_config = NscZoneConfig(**kwargs) self.ipv4_reverse = defaultdict(list) self.ipv6_reverse = defaultdict(list) - def add_zone(self, *args, inherit_config: Optional[NscZoneConfig] = None, **kwargs) -> Zone: + self.root_dir = Path(directory) + self.state_dir = self.root_dir / 'state' + self.state_dir.mkdir(parents=True, exist_ok=True) + self.zone_dir = self.root_dir / 'zone' + self.zone_dir.mkdir(parents=True, exist_ok=True) + self.secondary_dir = self.root_dir / 'secondary' + self.secondary_dir.mkdir(parents=True, exist_ok=True) + + if daemon is None: + from nsconfig.daemon import NscDaemonNull + daemon = NscDaemonNull() + self.daemon = daemon + daemon.setup(self) + + def add_zone(self, + name: Optional[str] = None, + reverse_for: str | IPNetwork | None = None, + follow_primary: str | IPAddress | None = None, + inherit_config: Optional[NscZoneConfig] = None, + **kwargs) -> Zone: if inherit_config is None: inherit_config = self.default_zone_config - z = NscZone(self, *args, inherit_config=inherit_config, **kwargs) - assert z.name not in self.zones - self.zones[z.name] = z - return z - def add_reverse_zone(self, net: str | IPNetwork, name: Optional[str] = None, **kwargs) -> Zone: - if not (isinstance(net, IPv4Network) or isinstance(net, IPv6Network)): - net = ip_network(net, strict=True) - name = name or self._reverse_zone_name(net) - return self.add_zone(name, reverse_for=net, **kwargs) + if reverse_for is not None: + if isinstance(reverse_for, str): + reverse_for = ip_network(reverse_for, strict=True) + name = name or self._reverse_zone_name(reverse_for) + assert name is not None + assert name not in self.zones + + z: NscZone + if follow_primary is None: + z = NscZonePrimary(self, name, reverse_for=reverse_for, inherit_config=inherit_config, **kwargs) + else: + if isinstance(follow_primary, str): + follow_primary = ip_address(follow_primary) + z = NscZoneSecondary(self, name, reverse_for=reverse_for, primary_server=follow_primary, inherit_config=inherit_config, **kwargs) + + self.zones[name] = z + return z def _reverse_zone_name(self, net: IPNetwork) -> str: if isinstance(net, IPv4Network): @@ -279,7 +439,7 @@ class Nsc: def fill_reverse(self) -> None: for z in self.zones.values(): - if z.reverse_for is not None: + if isinstance(z, NscZonePrimary) and z.reverse_for is not None: if isinstance(z.reverse_for, IPv4Network): for addr4, ptr_list in self.ipv4_reverse.items(): if addr4 in z.reverse_for: @@ -291,6 +451,10 @@ class Nsc: for ptr_to in ptr_list: z._add_ipv6_reverse(addr6, ptr_to) - def dump(self) -> None: - for z in self.zones.values(): - z.dump() + def get_zones(self) -> List[NscZone]: + return [self.zones[k] for k in sorted(self.zones.keys())] + + def process(self) -> None: + self.fill_reverse() + for z in self.get_zones(): + z.process()