]> mj.ucw.cz Git - pynsc.git/commitdiff
Decouple minimum TTL from default TTL
authorMartin Mares <mj@ucw.cz>
Mon, 22 Apr 2024 09:15:40 +0000 (11:15 +0200)
committerMartin Mares <mj@ucw.cz>
Mon, 22 Apr 2024 09:19:56 +0000 (11:19 +0200)
Also clean up processing of time duration and always use seconds
as internal representation.

TODO
example/example_org.py
nsconfig/core.py
nsconfig/util.py

diff --git a/TODO b/TODO
index 36631b1c8fa28a1ebfc116f355d0cfb4c93d0f70..9397d48b71eee55287eaf2411d7debb8463322f6 100644 (file)
--- a/TODO
+++ b/TODO
@@ -2,7 +2,5 @@
 - E-mail addresses with dots in SOA
 - Blackhole zones
 - DNSSEC
 - E-mail addresses with dots in SOA
 - Blackhole zones
 - DNSSEC
-- Automated generation of Null MX
 - Logging
 - Use dns.reversename.from_address?
 - Logging
 - Use dns.reversename.from_address?
-- Decouple min_ttl from default TTL
index 02c5afee3a840233c9506d465343d62ba16317bf..ae034cff69f568983361c010492d3ec111dbfa63 100644 (file)
@@ -1,15 +1,19 @@
+from datetime import timedelta
 from example import nsc
 
 z = nsc.add_zone(
     'example.org',
     daemon_options=['check-integrity yes;'],
     add_null_mx=True,
 from example import nsc
 
 z = nsc.add_zone(
     'example.org',
     daemon_options=['check-integrity yes;'],
     add_null_mx=True,
+    default_ttl=timedelta(hours=8),
 )
 
 (z[""]
     .NS('ns1', 'ns2')
 )
 
 (z[""]
     .NS('ns1', 'ns2')
+    .ttl(60)
     .MX(0, 'mail')
     .MX(10, 'mail.example.net')
     .MX(0, 'mail')
     .MX(10, 'mail.example.net')
+    .ttl()
     .TXT('Litera scripta manet'))
 
 z.host('ns1', '10.1.0.1', 'fd12:3456:789a:1::1')
     .TXT('Litera scripta manet'))
 
 z.host('ns1', '10.1.0.1', 'fd12:3456:789a:1::1')
index a091ffeedb32851c234eb50b5e58e3e92acd29ca..f8c96615c1246d0dfb4df1d9800189e42c1effd8 100644 (file)
@@ -24,7 +24,7 @@ import socket
 import sys
 from typing import Optional, Dict, List, Self, DefaultDict, TextIO, TYPE_CHECKING
 
 import sys
 from typing import Optional, Dict, List, Self, DefaultDict, TextIO, TYPE_CHECKING
 
-from nsconfig.util import flatten_list, parse_address, parse_network, parse_name
+from nsconfig.util import flatten_list, parse_address, parse_network, parse_name, parse_duration
 from nsconfig.util import IPAddress, IPNetwork, IPAddr
 
 
 from nsconfig.util import IPAddress, IPNetwork, IPAddr
 
 
@@ -42,13 +42,15 @@ class NscNode:
         self.nsc_zone = nsc_zone
         self.name = name
         self.node = nsc_zone.zone.find_node(name, create=True)
         self.nsc_zone = nsc_zone
         self.name = name
         self.node = nsc_zone.zone.find_node(name, create=True)
-        self._ttl = nsc_zone._min_ttl
+        self._ttl = nsc_zone.config.default_ttl
 
 
-    def ttl(self, *args, **kwargs) -> Self:
-        if not args and not kwargs:
-            self._ttl = self.nsc_zone._min_ttl
+    def ttl(self, seconds: Optional[int] = None, **kwargs) -> Self:
+        if seconds is not None:
+            self._ttl = seconds
+        elif kwargs:
+            self._ttl = parse_duration(timedelta(**kwargs))
         else:
         else:
-            self._ttl = int(timedelta(*args, **kwargs).total_seconds())
+            self._ttl = self.nsc_zone.config.default_ttl
         return self
 
     def _add(self, rec: Rdata) -> None:
         return self
 
     def _add(self, rec: Rdata) -> None:
@@ -96,10 +98,11 @@ class NscNode:
 
 class NscZoneConfig:
     admin_email: str
 
 class NscZoneConfig:
     admin_email: str
-    refresh: timedelta
-    retry: timedelta
-    expire: timedelta
-    min_ttl: timedelta
+    refresh: int
+    retry: int
+    expire: int
+    min_ttl: int
+    default_ttl: int
     origin_server: str
     daemon_options: List[str]
     add_null_mx: bool
     origin_server: str
     daemon_options: List[str]
     add_null_mx: bool
@@ -108,10 +111,11 @@ class NscZoneConfig:
 
     def __init__(self,
                  admin_email: Optional[str] = None,
 
     def __init__(self,
                  admin_email: Optional[str] = None,
-                 refresh: Optional[timedelta] = None,
-                 retry: Optional[timedelta] = None,
-                 expire: Optional[timedelta] = None,
-                 min_ttl: Optional[timedelta] = None,
+                 refresh: Optional[int | timedelta] = None,
+                 retry: Optional[int | timedelta] = None,
+                 expire: Optional[int | timedelta] = None,
+                 min_ttl: Optional[int | timedelta] = None,
+                 default_ttl: Optional[int | timedelta] = None,
                  origin_server: Optional[str] = None,
                  daemon_options: Optional[List[str]] = None,
                  add_daemon_options: Optional[List[str]] = None,
                  origin_server: Optional[str] = None,
                  daemon_options: Optional[List[str]] = None,
                  add_daemon_options: Optional[List[str]] = None,
@@ -121,10 +125,11 @@ class NscZoneConfig:
         if inherit_config is None:
             inherit_config = NscZoneConfig.default_config or self   # to satisfy the type checker
         self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
         if inherit_config is None:
             inherit_config = NscZoneConfig.default_config or self   # to satisfy the type checker
         self.admin_email = admin_email if admin_email is not None else inherit_config.admin_email
-        self.refresh = refresh if refresh is not None else inherit_config.refresh
-        self.retry = retry if retry is not None else inherit_config.retry
-        self.expire = expire if expire is not None else inherit_config.expire
-        self.min_ttl = min_ttl if min_ttl is not None else inherit_config.min_ttl
+        self.refresh = parse_duration(refresh) if refresh is not None else inherit_config.refresh
+        self.retry = parse_duration(retry) if retry is not None else inherit_config.retry
+        self.expire = parse_duration(expire) if expire is not None else inherit_config.expire
+        self.min_ttl = parse_duration(min_ttl) if min_ttl is not None else inherit_config.min_ttl
+        self.default_ttl = parse_duration(default_ttl) if default_ttl is not None else inherit_config.default_ttl
         self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
         self.daemon_options = daemon_options if daemon_options is not None else inherit_config.daemon_options
         self.add_null_mx = add_null_mx if add_null_mx is not None else inherit_config.add_null_mx
         self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
         self.daemon_options = daemon_options if daemon_options is not None else inherit_config.daemon_options
         self.add_null_mx = add_null_mx if add_null_mx is not None else inherit_config.add_null_mx
@@ -136,6 +141,8 @@ class NscZoneConfig:
             self.origin_server = socket.getfqdn()
         if not self.admin_email:
             self.admin_email = f'hostmaster@{self.origin_server}'
             self.origin_server = socket.getfqdn()
         if not self.admin_email:
             self.admin_email = f'hostmaster@{self.origin_server}'
+        if self.default_ttl == 0:
+            self.default_ttl = self.min_ttl
         return self
 
 
         return self
 
 
@@ -145,6 +152,7 @@ NscZoneConfig.default_config = NscZoneConfig(
     retry=timedelta(hours=2),
     expire=timedelta(days=14),
     min_ttl=timedelta(days=1),
     retry=timedelta(hours=2),
     expire=timedelta(days=14),
     min_ttl=timedelta(days=1),
+    default_ttl=0,
     origin_server="",
     daemon_options=[],
     add_null_mx=False,
     origin_server="",
     daemon_options=[],
     add_null_mx=False,
@@ -211,7 +219,6 @@ class NscZone:
 
 class NscZonePrimary(NscZone):
     zone: Zone
 
 class NscZonePrimary(NscZone):
     zone: Zone
-    _min_ttl: int
     zone_file: Path
     state_file: Path
     state: NscZoneState
     zone_file: Path
     state_file: Path
     state: NscZoneState
@@ -229,7 +236,6 @@ class NscZonePrimary(NscZone):
         self.prev_state.load(self.state_file)
 
         self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN)
         self.prev_state.load(self.state_file)
 
         self.zone = dns.zone.Zone(origin=self.name, rdclass=RdataClass.IN)
-        self._min_ttl = int(self.config.min_ttl.total_seconds())
         self.update_soa()
 
     def update_soa(self) -> None:
         self.update_soa()
 
     def update_soa(self) -> None:
@@ -239,10 +245,10 @@ class NscZonePrimary(NscZone):
             mname=conf.origin_server,
             rname=conf.admin_email.replace('@', '.'),   # FIXME: names with dots
             serial=self.state.serial,
             mname=conf.origin_server,
             rname=conf.admin_email.replace('@', '.'),   # FIXME: names with dots
             serial=self.state.serial,
-            refresh=int(conf.refresh.total_seconds()),
-            retry=int(conf.retry.total_seconds()),
-            expire=int(conf.expire.total_seconds()),
-            minimum=int(conf.min_ttl.total_seconds()),
+            refresh=conf.refresh,
+            retry=conf.retry,
+            expire=conf.expire,
+            minimum=conf.min_ttl,
         )
         self.zone.delete_rdataset("", RdataType.SOA)
         self[""]._add(soa)
         )
         self.zone.delete_rdataset("", RdataType.SOA)
         self[""]._add(soa)
@@ -268,13 +274,14 @@ class NscZonePrimary(NscZone):
         # Could use self.zone.to_file(sys.stdout), but we want better formatting
         file = file or sys.stdout
         file.write(self.zone_header())
         # Could use self.zone.to_file(sys.stdout), but we want better formatting
         file = file or sys.stdout
         file.write(self.zone_header())
+        file.write(f'$TTL\t\t{self.config.default_ttl}\n\n')
         last_name = None
         for name, ttl, rec in self.zone.iterate_rdatas():
             if name == last_name:
                 print_name = ""
             else:
                 print_name = name
         last_name = None
         for name, ttl, rec in self.zone.iterate_rdatas():
             if name == last_name:
                 print_name = ""
             else:
                 print_name = name
-            file.write(f'{print_name}\t{ttl if ttl != self._min_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
+            file.write(f'{print_name}\t{ttl if ttl != self.config.default_ttl else ""}\t{rec.rdtype.name}\t{rec.to_text()}\n')
             last_name = name
 
     def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
             last_name = name
 
     def _add_ipv4_reverse(self, addr: IPv4Address, ptr_to: Name) -> None:
@@ -361,7 +368,7 @@ class NscZonePrimary(NscZone):
                 if not mx_rds:
                     mx_rds.add(
                         dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, 0, dns.name.root),
                 if not mx_rds:
                     mx_rds.add(
                         dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, 0, dns.name.root),
-                        ttl=self._min_ttl,
+                        ttl=self.config.default_ttl,
                     )
 
 
                     )
 
 
index 48e2e6f7d39bc70fd598b2aae04f3fdebffe3eca..6acc1d8a2d183e853bbf3af09c81e9c5f0babd83 100644 (file)
@@ -1,6 +1,7 @@
 import dns.name
 from dns.name import Name
 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
 import dns.name
 from dns.name import Name
 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
+from datetime import timedelta
 from typing import Any, List
 
 
 from typing import Any, List
 
 
@@ -46,3 +47,12 @@ def parse_name(name: str, relative: bool = False) -> Name:
         return dns.name.from_text(name)
     else:
         return dns.name.from_text(name, origin=None)
         return dns.name.from_text(name)
     else:
         return dns.name.from_text(name, origin=None)
+
+
+def parse_duration(delta: timedelta | int) -> int:
+    if isinstance(delta, timedelta):
+        return int(delta.total_seconds())
+    elif isinstance(delta, int):
+        return delta
+    else:
+        raise ValueError('Cannot parse time duration')