]> mj.ucw.cz Git - pynsc.git/commitdiff
Interpretation of names with dots made configurable
authorMartin Mares <mj@ucw.cz>
Wed, 7 Aug 2024 20:23:58 +0000 (22:23 +0200)
committerMartin Mares <mj@ucw.cz>
Wed, 7 Aug 2024 20:23:58 +0000 (22:23 +0200)
TODO
example/example_org.py
nsconfig/__init__.py
nsconfig/core.py
nsconfig/util.py

diff --git a/TODO b/TODO
index 9009b862c73953508b74758f74d92f604f2ef733..a45f602bbc8f6fe77191c32777b37cae6b178d3d 100644 (file)
--- a/TODO
+++ b/TODO
@@ -1,5 +1,5 @@
-- Names with dots
 - E-mail addresses with dots in SOA
 - DNSSEC
 - Logging
 - More records
+- follow_primary -> secondary_to?
index 905f9350832233915ea6793a7e5ace49cc59818b..2f7fe12efc3f882abdb0f09704de13e6abb9b60e 100644 (file)
@@ -12,6 +12,7 @@ z = nsc.add_zone(
     .NS('ns1', 'ns2')
     .ttl(60)
     .MX(0, 'mail')
+    .MX(1, 'mail.backup.@')
     .MX(10, 'mail.example.net')
     .ttl()
     .TXT('Litera scripta manet'))
@@ -24,4 +25,6 @@ z.host('ns2', '10.2.0.1', 'fd12:3456:789a:2::1')
     .MX(0, 'mail')
     .MX(10, 'mail.example.net'))
 
+z.host('mail.backup', '10.1.0.3', 'fd12:3456:789a:1::3')
+
 nsc.add_zone('example.com', alias_for=z)
index f9a05410252ae8d89aa961ae31c7e3fed791e594..c4c15dc16d1797f3fc0dc3b449a2a3018287b5e7 100644 (file)
@@ -2,3 +2,4 @@
 # (c) 2024 Martin Mareš <mj@ucw.cz>
 
 from nsconfig.core import Nsc, NscZone, NscZoneConfig, NscNode
+from nsconfig.util import NameParseMode
index 5ffc534d458d4c1a460ec69fe8c70d7c1c12f0ef..b10b53f94740c0fd71421fd2f047d177ff0e8291 100644 (file)
@@ -30,7 +30,7 @@ import sys
 from typing import Optional, Dict, List, Self, DefaultDict, TextIO, Tuple, TYPE_CHECKING
 
 from nsconfig.util import flatten_list, parse_address, parse_network, parse_name, parse_duration
-from nsconfig.util import IPAddress, IPNetwork, IPAddr
+from nsconfig.util import IPAddress, IPNetwork, IPAddr, NameParseMode
 
 
 if TYPE_CHECKING:
@@ -46,7 +46,7 @@ class NscNode:
     def __init__(self, nsc_zone: 'NscZonePrimary', name: str) -> None:
         self.nsc_zone = nsc_zone
         self.name = name
-        self.node = nsc_zone.zone.find_node(name, create=True)
+        self.node = nsc_zone.zone.find_node(parse_name(name, NameParseMode.relative), create=True)
         self._ttl = nsc_zone.config.default_ttl
 
     def ttl(self, seconds: Optional[int] = None, **kwargs) -> Self:
@@ -62,6 +62,9 @@ class NscNode:
         rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
         rds.add(rec, ttl=self._ttl)
 
+    def _parse_name(self, name, **kwargs):
+        return parse_name(name, mode=self.nsc_zone.config.name_parse_mode, **kwargs)
+
     def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
         for a in map(parse_address, flatten_list(addrs)):
             if isinstance(a, IPv4Address):
@@ -69,7 +72,7 @@ class NscNode:
             else:
                 self._add(dns.rdtypes.IN.AAAA.AAAA(RdataClass.IN, RdataType.AAAA, str(a)))
             if reverse:
-                self.nsc_zone.nsc._add_reverse_mapping(a, parse_name(self.name, origin=self.nsc_zone.dns_name))
+                self.nsc_zone.nsc._add_reverse_mapping(a, parse_name(self.name, mode=NameParseMode.relative, origin=self.nsc_zone.dns_name))
         return self
 
     def CNAME(self, target: Name | str) -> Self:
@@ -82,7 +85,7 @@ class NscNode:
 
     def MX(self, pri: int, name: str) -> Self:
         self._add(
-            dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, parse_name(name))
+            dns.rdtypes.ANY.MX.MX(RdataClass.IN, RdataType.MX, pri, self._parse_name(name))
         )
         return self
 
@@ -92,8 +95,8 @@ class NscNode:
         return self
 
     def NS(self, *names: str | List[str]) -> Self:
-        for name in map(parse_name, flatten_list(names)):
-            self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
+        for name in flatten_list(names):
+            self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, self._parse_name(name)))
         return self
 
     def PTR(self, target: Name | str) -> Self:
@@ -101,7 +104,7 @@ class NscNode:
         return self
 
     def SRV(self, priority: int, weight: int, port: int, target: Name | str) -> Self:
-        self._add(dns.rdtypes.IN.SRV.SRV(RdataClass.IN, RdataType.SRV, priority, weight, port, parse_name(target)))
+        self._add(dns.rdtypes.IN.SRV.SRV(RdataClass.IN, RdataType.SRV, priority, weight, port, self._parse_name(target)))
         return self
 
     def TXT(self, *text: str | List[str]) -> Self:
@@ -130,6 +133,7 @@ class NscZoneConfig:
     origin_server: str
     daemon_options: List[str]
     add_null_mx: bool
+    name_parse_mode: NameParseMode
 
     default_config: Optional['NscZoneConfig'] = None
 
@@ -145,6 +149,7 @@ class NscZoneConfig:
                  daemon_options: Optional[List[str]] = None,
                  add_daemon_options: Optional[List[str]] = None,
                  add_null_mx: Optional[bool] = None,
+                 name_parse_mode: Optional[NameParseMode] = None,
                  inherit_config: Optional['NscZoneConfig'] = None,
                  ) -> None:
         if inherit_config is None:
@@ -158,6 +163,7 @@ class NscZoneConfig:
         self.origin_server = origin_server if origin_server is not None else inherit_config.origin_server
         self.daemon_options = daemon_options if daemon_options is not None else inherit_config.daemon_options
         self.add_null_mx = add_null_mx if add_null_mx is not None else inherit_config.add_null_mx
+        self.name_parse_mode = name_parse_mode if name_parse_mode is not None else inherit_config.name_parse_mode
         if add_daemon_options is not None:
             self.daemon_options += add_daemon_options
 
@@ -181,6 +187,7 @@ NscZoneConfig.default_config = NscZoneConfig(
     origin_server="",
     daemon_options=[],
     add_null_mx=False,
+    name_parse_mode=NameParseMode.absolute,
 )
 
 
@@ -227,6 +234,7 @@ class NscZone:
     dns_name: Name
     safe_name: str                          # For use in file names
     zone_type: ZoneType
+    config: NscZoneConfig
     reverse_for: Optional[IPNetwork]
 
     def __init__(self,
@@ -393,7 +401,7 @@ class NscZonePrimary(NscZone):
 
         for i in range(start, start + num):
             target = f'{i}.{subdomain}'
-            self[str(i)].CNAME(parse_name(target, relative=True))
+            self[str(i)].CNAME(parse_name(target, mode=NameParseMode.relative))
 
         return self[subdomain]
 
index 21cff80f47cb129322f8975a72e13c7a2a949e67..0faf10c6ced2722f933474a2e701f477e651745a 100644 (file)
@@ -3,6 +3,7 @@
 
 import dns.name
 from dns.name import Name
+from enum import Enum, auto
 from ipaddress import ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network, IPv6Network
 from datetime import timedelta
 from typing import Any, List, Optional
@@ -13,6 +14,17 @@ IPNetwork = IPv4Network | IPv6Network
 IPAddr = str | IPAddress | List[str | IPAddress]
 
 
+class NameParseMode(Enum):
+    # How to parse DNS names (first matching rule wins):
+    #    - names with no dots are always relative
+    #    - names ending with ".@" are also relative
+    #    - names ending with "." are always absolute
+    #    - names on the left-hand side of records are relative
+    #    - other names are interpreted according to the parsing mode
+    absolute = auto()    # default
+    relative = auto()
+
+
 def flatten_list(args: Any) -> List[Any]:
     def flat(args):
         if isinstance(args, list) or isinstance(args, tuple):
@@ -44,12 +56,18 @@ def parse_network(addr: IPNetwork | str) -> IPNetwork:
         raise ValueError('Cannot parse IP network')
 
 
-def parse_name(name: str, relative: bool = False, origin: Optional[Name] = None) -> Name:
-    # FIXME: Names with escaped dots
-    if '.' in name and not relative and origin is None:
-        return dns.name.from_text(name)
-    else:
+def parse_name(name: str, mode: NameParseMode = NameParseMode.relative, origin: Optional[Name] = None) -> Name:
+    if name.endswith('.@'):
+        return dns.name.from_text(name[:-2], origin=origin)
+    if mode == NameParseMode.relative:
         return dns.name.from_text(name, origin=origin)
+    elif mode == NameParseMode.absolute:
+        if '.' in name:
+            return dns.name.from_text(name)
+        else:
+            return dns.name.from_text(name, origin=origin)
+    else:
+        ...
 
 
 def parse_duration(delta: timedelta | int) -> int: