]> mj.ucw.cz Git - pynsc.git/commitdiff
Initial commit
authorMartin Mares <mj@ucw.cz>
Sat, 20 Apr 2024 18:33:54 +0000 (20:33 +0200)
committerMartin Mares <mj@ucw.cz>
Sat, 20 Apr 2024 18:33:54 +0000 (20:33 +0200)
nsc.py [new file with mode: 0755]

diff --git a/nsc.py b/nsc.py
new file mode 100755 (executable)
index 0000000..41abe3c
--- /dev/null
+++ b/nsc.py
@@ -0,0 +1,200 @@
+#!/usr/bin/env python3
+
+from dataclasses import dataclass
+from datetime import timedelta
+import dns.name
+from dns.name import Name
+from dns.node import Node
+from dns.rdata import Rdata
+from dns.rdataclass import RdataClass
+from dns.rdatatype import RdataType
+import dns.rdtypes.ANY.MX
+import dns.rdtypes.ANY.NS
+import dns.rdtypes.ANY.SOA
+import dns.rdtypes.ANY.TXT
+import dns.rdtypes.IN.A
+import dns.rdtypes.IN.AAAA
+from dns.zone import Zone
+from ipaddress import ip_address, IPv4Address, IPv6Address
+import socket
+import sys
+from typing import Optional, Dict, List, Self
+
+
+IPAddress = IPv4Address | IPv6Address
+
+
+class NscNode:
+    nsc_zone: 'NscZone'
+    name: str
+    node: Node
+    _ttl: int
+
+    def __init__(self, nsc_zone: 'NscZone', name: str) -> None:
+        self.nsc_zone = nsc_zone
+        self.name = name
+        self.node = nsc_zone.zone.find_node(name, create=True)
+        self._ttl = int(nsc_zone.min_ttl.total_seconds())
+
+    def ttl(self, *args, **kwargs) -> Self:
+        if not args and not kwargs:
+            self._ttl = int(self.nsc_zone.min_ttl.total_seconds())
+        else:
+            self._ttl = int(timedelta(*args, **kwargs).total_seconds())
+        return self
+
+    def _add(self, rec: Rdata) -> None:
+        rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
+        rds.add(rec, ttl=self._ttl)
+
+    def _parse_addrs(self, addrs: str | IPAddress | List[str | IPAddress]) -> List[IPAddress]:
+        if not isinstance(addrs, list):
+            addrs = [addrs]
+        out = []
+        for a in addrs:
+            if isinstance(a, IPv4Address) or isinstance(a, IPv6Address):
+                out.append(a)
+            else:
+                out.append(ip_address(a))
+        return out
+
+    def _parse_name(self, name: str) -> Name:
+        # FIXME: Names with escaped dots
+        if '.' in name:
+            return dns.name.from_text(name)
+        else:
+            return dns.name.from_text(name, origin=None)
+
+    def _parse_names(self, names: str | List[str]) -> List[Name]:
+        if isinstance(names, str):
+            return [self._parse_name(names)]
+        else:
+            return [self._parse_name(n) for n in names]
+
+    def A(self, addrs: str | IPAddress | List[str | IPAddress]) -> Self:
+        for a in self._parse_addrs(addrs):
+            if isinstance(a, IPv4Address):
+                self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
+            else:
+                self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
+        return self
+
+    def MX(self, pri: int, name: str) -> Self:
+        self._add(
+            dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, self._parse_name(name))
+        )
+        return self
+
+    def NS(self, names: str | List[str]) -> Self:
+        for name in self._parse_names(names):
+            self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
+        return self
+
+    def TXT(self, text: str) -> Self:
+        self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, text))
+        return self
+
+    def generic(self, typ: str, text: str) -> Self:
+        self._add(dns.rdata.from_text(RdataClass.IN, typ, text))
+
+
+class NscZone:
+    name: str
+    admin_email: Optional[str] = None
+    refresh: timedelta = timedelta(hours=8)
+    retry: timedelta = timedelta(hours=2)
+    expire: timedelta = timedelta(days=14)
+    min_ttl: timedelta = timedelta(days=1)
+    origin_server: Optional[str] = None
+    zone: Zone
+
+    def __init__(self,
+                 name: str,
+                 admin_email: Optional[str] = None,
+                 refresh: Optional[timedelta] = None,
+                 retry: Optional[timedelta] = None,
+                 expire: Optional[timedelta] = None,
+                 min_ttl: Optional[timedelta] = None,
+                 origin_server: Optional[str] = None,
+                 ) -> None:
+        self.name = name
+        self.admin_email = admin_email if admin_email is not None else self.admin_email
+        self.refresh = refresh if refresh is not None else self.refresh
+        self.retry = retry if retry is not None else self.retry
+        self.expire = expire if expire is not None else self.expire
+        self.min_ttl = min_ttl if min_ttl is not None else self.min_ttl
+        self.origin_server = origin_server if origin_server is not None else self.origin_server
+        self.zone = dns.zone.Zone(origin=name, rdclass=RdataClass.IN)
+
+        if self.origin_server is None:
+            self.origin_server = socket.getfqdn()
+
+        if self.admin_email is None:
+            self.admin_email = f'root@{self.origin_server}'
+
+        root = self[""]
+        root._add(
+            dns.rdtypes.ANY.SOA.SOA(
+                RdataClass.IN, RdataType.SOA,
+                mname=self.origin_server,
+                rname=self.admin_email.replace('@', '.'),   # FIXME: names with dots
+                serial=12345,
+                refresh=int(self.refresh.total_seconds()),
+                retry=int(self.retry.total_seconds()),
+                expire=int(self.expire.total_seconds()),
+                minimum=int(self.min_ttl.total_seconds()),
+            )
+        )
+
+    def n(self, name: str) -> NscNode:
+        return NscNode(self, name)
+
+    def __getitem__(self, name: str) -> NscNode:
+        return NscNode(self, name)
+
+    def dump(self) -> None:
+        # Could use self.zone.to_file(sys.stdout), but we want better formatting
+        last_name = None
+        min_ttl = int(self.min_ttl.total_seconds())
+        for name, ttl, rec in self.zone.iterate_rdatas():
+            if name == last_name:
+                print_name = ""
+            else:
+                print_name = name
+            print(f'{print_name}\t{ttl if ttl != min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}')
+            last_name = name
+
+
+class Config:
+    zones: Dict[str, Zone]
+
+    def __init__(self) -> None:
+        self.zones = {}
+
+    def add_zone(self, *args, **kwargs) -> Zone:
+        dom = NscZone(*args, **kwargs)
+        assert dom.name not in self.zones
+        self.zones[dom.name] = dom
+        return dom
+
+
+class MyZone(Zone):
+    admin_email = 'admin@ucw.cz'
+    origin_server = 'ns.ucw.cz'
+
+
+c = Config()
+z = c.add_zone('ucw.cz')  # origin_server='jabberwock.ucw.cz')
+
+z[""].NS(['jabberwock', 'chirigo.gebbeth.cz', 'drak.ucw.cz'])
+
+z['jabberwock'].A(['1.2.3.4', '2a00:da80:fff0:2::2'])
+
+(z['mnau']
+    .A('195.113.31.123')
+    .MX(0, 'jabberwock')
+    .ttl(minutes=15)
+    .TXT('hey?')
+    .generic('HINFO', 'Something fishy'))
+
+z.dump()