from collections import defaultdict
-from datetime import timedelta
+from datetime import datetime, timedelta
import dns.name
from dns.name import Name
from dns.node import Node
import dns.rdtypes.IN.A
import dns.rdtypes.IN.AAAA
from dns.zone import Zone
+from enum import Enum, auto
+import hashlib
from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
+import json
+from pathlib import Path
import socket
-from typing import Optional, Dict, List, Self, Tuple, DefaultDict
+import sys
+from typing import Optional, Dict, List, Self, Tuple, DefaultDict, TextIO, TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+ from nsconfig.daemon import NscDaemon
IPAddress = IPv4Address | IPv6Address
class NscNode:
- nsc_zone: 'NscZone'
+ nsc_zone: 'NscZonePrimary'
name: str
node: Node
_ttl: int
- def __init__(self, nsc_zone: 'NscZone', name: str) -> None:
+ def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None:
self.nsc_zone = nsc_zone
self.name = name
self.node = nsc_zone.zone.find_node(name, create=True)
)
+class NscZoneState:
+ serial: int
+ hash: str
+
+ def __init__(self) -> None:
+ self.serial = 0
+ self.hash = 'none'
+
+ def load(self, file: Path) -> None:
+ try:
+ with open(file) as f:
+ js = json.load(f)
+ assert isinstance(js, dict)
+ if 'serial' in js:
+ self.serial = js['serial']
+ if 'hash' in js:
+ self.hash = js['hash']
+ except FileNotFoundError:
+ pass
+
+ def save(self, file: Path) -> None:
+ new_file = Path(str(file) + '.new')
+ with open(new_file, 'w') as f:
+ js = {
+ 'serial': self.serial,
+ 'hash': self.hash,
+ }
+ json.dump(js, f, indent=4, sort_keys=True)
+ new_file.replace(file)
+
+
+class ZoneType(Enum):
+ primary = auto()
+ secondary = auto()
+
+
class NscZone:
nsc: 'Nsc'
name: str
- zone: Zone
- _min_ttl: int
+ safe_name: str # For use in file names
+ zone_type: ZoneType
reverse_for: Optional[IPNetwork]
- def __init__(self, nsc: 'Nsc', name: str, reverse_for: Optional[IPNetwork] = None, **kwargs) -> None:
+ def __init__(self,
+ nsc: 'Nsc',
+ name: str,
+ reverse_for: Optional[IPNetwork],
+ **kwargs) -> None:
self.nsc = nsc
self.name = name
+ self.safe_name = name.replace('/', '@')
self.config = NscZoneConfig(**kwargs).finalize()
- self.zone = dns.zone.Zone(origin=name, rdclass=RdataClass.IN)
- self._min_ttl = int(self.config.min_ttl.total_seconds())
self.reverse_for = reverse_for
+ def process(self) -> None:
+ pass
+
+
+class NscZonePrimary(NscZone):
+ zone: Zone
+ _min_ttl: int
+ zone_file: Path
+ state_file: Path
+ state: NscZoneState
+ prev_state: NscZoneState
+
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+
+ self.zone_type = ZoneType.primary
+ self.zone_file = self.nsc.zone_dir / self.safe_name
+ self.state_file = self.nsc.state_dir / (self.safe_name + '.json')
+
+ self.state = NscZoneState()
+ self.prev_state = NscZoneState()
+ self.prev_state.load(self.state_file)
+
+ self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN)
+ self._min_ttl = int(self.config.min_ttl.total_seconds())
+ self.update_soa()
+
+ def update_soa(self) -> None:
conf = self.config
- root = self[""]
- root._add(
- dns.rdtypes.ANY.SOA.SOA(
- RdataClass.IN, RdataType.SOA,
- mname=conf.origin_server,
- rname=conf.admin_email.replace('@', '.'), # FIXME: names with dots
- serial=12345,
- refresh=int(conf.refresh.total_seconds()),
- retry=int(conf.retry.total_seconds()),
- expire=int(conf.expire.total_seconds()),
- minimum=int(conf.min_ttl.total_seconds()),
- )
+ soa = dns.rdtypes.ANY.SOA.SOA(
+ RdataClass.IN, RdataType.SOA,
+ mname=conf.origin_server,
+ rname=conf.admin_email.replace('@', '.'), # FIXME: names with dots
+ serial=self.state.serial,
+ refresh=int(conf.refresh.total_seconds()),
+ retry=int(conf.retry.total_seconds()),
+ expire=int(conf.expire.total_seconds()),
+ minimum=int(conf.min_ttl.total_seconds()),
)
+ self.zone.delete_rdataset("", RdataType.SOA)
+ self[""]._add(soa)
def n(self, name: str) -> NscNode:
return NscNode(self, name)
n.A(*args, reverse=reverse)
return n
- def dump(self) -> None:
+ def dump(self, file: Optional[TextIO] = None) -> None:
# Could use self.zone.to_file(sys.stdout), but we want better formatting
- print(f'; Zone file for {self.name}')
+ file = file or sys.stdout
+ file.write(f'; Zone file for {self.name}\n\n')
last_name = None
for name, ttl, rec in self.zone.iterate_rdatas():
if name == last_name:
print_name = ""
else:
print_name = name
- print(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}')
+ file.write(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
last_name = name
def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
- # Called only for addresses from this reverse network
- assert self.reverse_for is not None
+ assert isinstance(self.reverse_for, IPv4Network)
parts = str(addr).split('.')
parts = parts[self.reverse_for.prefixlen // 8:]
name = '.'.join(reversed(parts))
self.n(name).PTR(ptr_to)
def _add_ipv6_reverse(self, addr: IPv6Address, ptr_to: Name) -> None:
- # Called only for addresses from this reverse network
- assert self.reverse_for is not None
+ assert isinstance(self.reverse_for, IPv6Network)
parts = addr.exploded.replace(':', "")
parts = parts[self.reverse_for.prefixlen // 4:]
name = '.'.join(reversed(parts))
self.n(name).PTR(ptr_to)
+ def gen_hash(self) -> None:
+ sha = hashlib.sha1()
+ for name, ttl, rec in self.zone.iterate_rdatas():
+ text = f'{name}\t{ttl}\t{rec.rdtype.name}\t{rec.to_text()}\n'
+ sha.update(text.encode('us-ascii'))
+ self.state.hash = sha.hexdigest()[:16]
+
+ def gen_serial(self) -> None:
+ prev = self.prev_state.serial
+ if self.state.hash == self.prev_state.hash and prev > 0:
+ self.state.serial = self.prev_state.serial
+ else:
+ base = int(self.nsc.start_time.strftime('%Y%m%d00'))
+ if prev <= base:
+ self.state.serial = base + 1
+ else:
+ self.state.serial = prev + 1
+ if prev >= base + 99:
+ print(f'WARNING: Serial number overflow for zone {self.name}, current is {self.state.serial}')
+
+ def process(self) -> None:
+ if self.zone_type == ZoneType.primary:
+ self.gen_hash()
+ self.gen_serial()
+
+ def write_zone(self) -> None:
+ self.update_soa()
+ new_file = Path(str(self.zone_file) + '.new')
+ with open(new_file, 'w') as f:
+ self.dump(file=f)
+ new_file.replace(self.zone_file)
+
+ def write_state(self) -> None:
+ self.state.save(self.state_file)
+
+ def is_changed(self) -> bool:
+ return self.state.serial != self.prev_state.serial
+
+
+class NscZoneSecondary(NscZone):
+ primary_server: IPAddress
+ secondary_file: Path
+
+ def __init__(self, *args, primary_server=IPAddress, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self.zone_type = ZoneType.secondary
+ self.primary_server = primary_server
+ self.secondary_file = self.nsc.secondary_dir / self.safe_name
+
class Nsc:
+ start_time: datetime
zones: Dict[str, NscZone]
default_zone_config: NscZoneConfig
ipv4_reverse: DefaultDict[IPv4Address, List[Name]]
ipv6_reverse: DefaultDict[IPv6Address, List[Name]]
+ root_dir: Path
+ state_dir: Path
+ zone_dir: Path
+ secondary_dir: Path
+ daemon: 'NscDaemon' # Set by DaemonConfig class
- def __init__(self, **kwargs) -> None:
+ def __init__(self,
+ directory: str = '.',
+ daemon: Optional['NscDaemon'] = None,
+ **kwargs) -> None:
+ self.start_time = datetime.now()
self.zones = {}
self.default_zone_config = NscZoneConfig(**kwargs)
self.ipv4_reverse = defaultdict(list)
self.ipv6_reverse = defaultdict(list)
- def add_zone(self, *args, inherit_config: Optional[NscZoneConfig] = None, **kwargs) -> Zone:
+ self.root_dir = Path(directory)
+ self.state_dir = self.root_dir / 'state'
+ self.state_dir.mkdir(parents=True, exist_ok=True)
+ self.zone_dir = self.root_dir / 'zone'
+ self.zone_dir.mkdir(parents=True, exist_ok=True)
+ self.secondary_dir = self.root_dir / 'secondary'
+ self.secondary_dir.mkdir(parents=True, exist_ok=True)
+
+ if daemon is None:
+ from nsconfig.daemon import NscDaemonNull
+ daemon = NscDaemonNull()
+ self.daemon = daemon
+ daemon.setup(self)
+
+ def add_zone(self,
+ name: Optional[str] = None,
+ reverse_for: str | IPNetwork | None = None,
+ follow_primary: str | IPAddress | None = None,
+ inherit_config: Optional[NscZoneConfig] = None,
+ **kwargs) -> Zone:
if inherit_config is None:
inherit_config = self.default_zone_config
- z = NscZone(self, *args, inherit_config=inherit_config, **kwargs)
- assert z.name not in self.zones
- self.zones[z.name] = z
- return z
- def add_reverse_zone(self, net: str | IPNetwork, name: Optional[str] = None, **kwargs) -> Zone:
- if not (isinstance(net, IPv4Network) or isinstance(net, IPv6Network)):
- net = ip_network(net, strict=True)
- name = name or self._reverse_zone_name(net)
- return self.add_zone(name, reverse_for=net, **kwargs)
+ if reverse_for is not None:
+ if isinstance(reverse_for, str):
+ reverse_for = ip_network(reverse_for, strict=True)
+ name = name or self._reverse_zone_name(reverse_for)
+ assert name is not None
+ assert name not in self.zones
+
+ z: NscZone
+ if follow_primary is None:
+ z = NscZonePrimary(self, name, reverse_for=reverse_for, inherit_config=inherit_config, **kwargs)
+ else:
+ if isinstance(follow_primary, str):
+ follow_primary = ip_address(follow_primary)
+ z = NscZoneSecondary(self, name, reverse_for=reverse_for, primary_server=follow_primary, inherit_config=inherit_config, **kwargs)
+
+ self.zones[name] = z
+ return z
def _reverse_zone_name(self, net: IPNetwork) -> str:
if isinstance(net, IPv4Network):
def fill_reverse(self) -> None:
for z in self.zones.values():
- if z.reverse_for is not None:
+ if isinstance(z, NscZonePrimary) and z.reverse_for is not None:
if isinstance(z.reverse_for, IPv4Network):
for addr4, ptr_list in self.ipv4_reverse.items():
if addr4 in z.reverse_for:
for ptr_to in ptr_list:
z._add_ipv6_reverse(addr6, ptr_to)
- def dump(self) -> None:
- for z in self.zones.values():
- z.dump()
+ def get_zones(self) -> List[NscZone]:
+ return [self.zones[k] for k in sorted(self.zones.keys())]
+
+ def process(self) -> None:
+ self.fill_reverse()
+ for z in self.get_zones():
+ z.process()