]> mj.ucw.cz Git - pynsc.git/commitdiff
Unify processing of arguments to record-generating methods
authorMartin Mares <mj@ucw.cz>
Sun, 21 Apr 2024 15:51:57 +0000 (17:51 +0200)
committerMartin Mares <mj@ucw.cz>
Sun, 21 Apr 2024 15:51:57 +0000 (17:51 +0200)
example/__init__.py
example/example_org.py
nsconfig/core.py
nsconfig/util.py [new file with mode: 0644]

index d7dd2192ad1344615c8772b742511d51b93798aa..4497c6b6ceddb291e8f6a6ddee34f0220c22be0b 100644 (file)
@@ -9,7 +9,7 @@ nsc = Nsc(
 
 for rev in ['10.1.0.0/16', '10.2.0.0/16', 'fd12:3456:789a::/48']:
     rz = nsc.add_zone(reverse_for=rev)
-    rz[""].NS(['ns1.example.org', 'ns2.example.org'])
+    rz[""].NS('ns1.example.org', 'ns2.example.org')
 
 nsc.add_zone('example.net', follow_primary='10.42.0.1')
 
index dfd95ddf48aa5f83375534cc1690fe2d605395cc..fa264e56d1857a0fc46121c4a86d0521504f9ff2 100644 (file)
@@ -3,7 +3,7 @@ from example import nsc
 z = nsc.add_zone('example.org')
 
 (z[""]
-    .NS(['ns1', 'ns2'])
+    .NS('ns1', 'ns2')
     .MX(0, 'mail')
     .MX(10, 'mail.example.net')
     .TXT('Litera scripta manet'))
index 8a3bb74f694453655887758f39063fa20182779e..7613061f591443885c21c3c746952a40ffa74ad5 100644 (file)
@@ -23,6 +23,8 @@ import socket
 import sys
 from typing import Optional, Dict, List, Self, Tuple, DefaultDict, TextIO, TYPE_CHECKING
 
+from nsconfig.util import flatten_list
+
 
 if TYPE_CHECKING:
     from nsconfig.daemon import NscDaemon
@@ -56,17 +58,13 @@ class NscNode:
         rds = self.node.find_rdataset(rec.rdclass, rec.rdtype, create=True)
         rds.add(rec, ttl=self._ttl)
 
-    def _parse_addrs(self, addrs: Tuple[IPAddr, ...]) -> List[IPAddress]:
-        out = []
-        for a in addrs:
-            if not isinstance(a, list):
-                a = [a]
-            for b in a:
-                if isinstance(b, IPv4Address) or isinstance(b, IPv6Address):
-                    out.append(b)
-                else:
-                    out.append(ip_address(b))
-        return out
+    def _parse_addr(self, addr: IPAddr | str) -> IPAddress:
+        if isinstance(addr, IPv4Address) or isinstance(addr, IPv6Address):
+            return addr
+        elif isinstance(addr, str):
+            return ip_address(addr)
+        else:
+            raise ValueError('Cannot parse IP address')
 
     def _parse_name(self, name: str) -> Name:
         # FIXME: Names with escaped dots
@@ -75,14 +73,8 @@ class NscNode:
         else:
             return dns.name.from_text(name, origin=None)
 
-    def _parse_names(self, names: str | List[str]) -> List[Name]:
-        if isinstance(names, str):
-            return [self._parse_name(names)]
-        else:
-            return [self._parse_name(n) for n in names]
-
     def A(self, *addrs: IPAddr, reverse: bool = True) -> Self:
-        for a in self._parse_addrs(addrs):
+        for a in map(self._parse_addr, flatten_list(addrs)):
             if isinstance(a, IPv4Address):
                 self._add(dns.rdtypes.IN.A.A(RdataClass.IN, RdataType.A, str(a)))
             else:
@@ -97,14 +89,14 @@ class NscNode:
         )
         return self
 
-    def NS(self, names: str | List[str]) -> Self:
-        # FIXME: Variadic?
-        for name in self._parse_names(names):
+    def NS(self, *names: str | List[str]) -> Self:
+        for name in map(self._parse_name, flatten_list(names)):
             self._add(dns.rdtypes.ANY.NS.NS(RdataClass.IN, RdataType.NS, name))
         return self
 
-    def TXT(self, text: str) -> Self:
-        self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, text))
+    def TXT(self, *text: str | List[str]) -> Self:
+        for txt in flatten_list(text):
+            self._add(dns.rdtypes.ANY.TXT.TXT(RdataClass.IN, RdataType.TXT, txt))
         return self
 
     def PTR(self, target: Name | str) -> Self:
diff --git a/nsconfig/util.py b/nsconfig/util.py
new file mode 100644 (file)
index 0000000..7b9370c
--- /dev/null
@@ -0,0 +1,14 @@
+from typing import Any, List
+
+
+def flatten_list(args: Any) -> List[Any]:
+    def flat(args):
+        if isinstance(args, list) or isinstance(args, tuple):
+            for a in args:
+                flat(a)
+        else:
+            out.append(args)
+
+    out: List[Any] = []
+    flat(args)
+    return out