]> git.ipfire.org Git - people/ms/libloc.git/commitdiff
export: Sightly refactor export logic
authorMichael Tremer <michael.tremer@ipfire.org>
Thu, 3 Mar 2022 09:33:42 +0000 (09:33 +0000)
committerMichael Tremer <michael.tremer@ipfire.org>
Thu, 3 Mar 2022 09:33:42 +0000 (09:33 +0000)
This patch moves creating the "tag" (formerly known as prefix) into the
writer class, so that we can modify it based on what output format we
have.

ipset and nftables will need disjunct names for IPv6 and IPv4 because
they cannot handle mixed sets.

Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
src/python/export.py

index 3ee22946bccd09f48805a8f4943e6b27d945c41e..5b8b75ca88e8b69d90534ea6b3af5faf12109046 100644 (file)
@@ -17,6 +17,7 @@
 #                                                                             #
 ###############################################################################
 
+import functools
 import io
 import ipaddress
 import logging
@@ -46,10 +47,18 @@ class OutputWriter(object):
        # Enable network flattening (i.e. networks cannot overlap)
        flatten = False
 
-       def __init__(self, f, family=None, prefix=None):
-               self.f = f
-               self.prefix = prefix
+       def __init__(self, name, family=None, directory=None):
+               self.name = name
                self.family = family
+               self.directory = directory
+
+               # Open output file
+               if self.directory:
+                       self.f = open(self.filename, self.mode)
+               elif "b" in self.mode:
+                       self.f = io.BytesIO()
+               else:
+                       self.f = io.StringIO()
 
                # Call any custom initialization
                self.init()
@@ -63,22 +72,22 @@ class OutputWriter(object):
                """
                pass
 
-       @classmethod
-       def open(cls, filename, *args, **kwargs):
-               """
-                       Convenience function to open a file
-               """
-               if filename:
-                       f = open(filename, cls.mode)
-               elif "b" in cls.mode:
-                       f = io.BytesIO()
-               else:
-                       f = io.StringIO()
+       def __repr__(self):
+               return "<%s %s f=%s>" % (self.__class__.__name__, self, self.f)
 
-               return cls(f, *args, **kwargs)
+       @functools.cached_property
+       def tag(self):
+               families = {
+                       socket.AF_INET6 : "6",
+                       socket.AF_INET  : "4",
+               }
 
-       def __repr__(self):
-               return "<%s f=%s>" % (self.__class__.__name__, self.f)
+               return "%sv%s" % (self.name, families.get(self.family, "?"))
+
+       @functools.cached_property
+       def filename(self):
+               if self.directory:
+                       return os.path.join(self.directory, "%s.%s" % (self.tag, self.suffix))
 
        def _write_header(self):
                """
@@ -158,14 +167,14 @@ class IpsetOutputWriter(OutputWriter):
        def _write_header(self):
                # This must have a fixed size, because we will write the header again in the end
                self.f.write("create %s hash:net family inet%s" % (
-                       self.prefix,
+                       self.tag,
                        "6" if self.family == socket.AF_INET6 else ""
                ))
                self.f.write(" hashsize %8d maxelem 1048576 -exist\n" % self.hashsize)
-               self.f.write("flush %s\n" % self.prefix)
+               self.f.write("flush %s\n" % self.tag)
 
        def write(self, network):
-               self.f.write("add %s %s\n" % (self.prefix, network))
+               self.f.write("add %s %s\n" % (self.tag, network))
 
                # Increment network counter
                self.networks += 1
@@ -185,7 +194,7 @@ class NftablesOutputWriter(OutputWriter):
        suffix = "set"
 
        def _write_header(self):
-               self.f.write("define %s = {\n" % self.prefix)
+               self.f.write("define %s = {\n" % self.tag)
 
        def _write_footer(self):
                self.f.write("}\n")
@@ -199,10 +208,17 @@ class XTGeoIPOutputWriter(OutputWriter):
                Formats the output in that way, that it can be loaded by
                the xt_geoip kernel module from xtables-addons.
        """
-       suffix = "iv"
        mode = "wb"
        flatten = True
 
+       @property
+       def tag(self):
+               return self.name
+
+       @property
+       def suffix(self):
+               return "iv%s" % ("6" if self.family == socket.AF_INET6 else "4")
+
        def write(self, network):
                self.f.write(network._first_address)
                self.f.write(network._last_address)
@@ -220,8 +236,6 @@ class Exporter(object):
                self.db, self.writer = db, writer
 
        def export(self, directory, families, countries, asns):
-               filename = None
-
                for family in families:
                        log.debug("Exporting family %s" % family)
 
@@ -229,27 +243,11 @@ class Exporter(object):
 
                        # Create writers for countries
                        for country_code in countries:
-                               if directory:
-                                       filename = self._make_filename(
-                                               directory,
-                                               prefix=country_code,
-                                               suffix=self.writer.suffix,
-                                               family=family,
-                                       )
-
-                               writers[country_code] = self.writer.open(filename, family, prefix=country_code)
+                               writers[country_code] = self.writer(country_code, family=family, directory=directory)
 
                        # Create writers for ASNs
                        for asn in asns:
-                               if directory:
-                                       filename = self._make_filename(
-                                               directory,
-                                               prefix="AS%s" % asn,
-                                               suffix=self.writer.suffix,
-                                               family=family,
-                                       )
-
-                               writers[asn] = self.writer.open(filename, family, prefix="AS%s" % asn)
+                               writers[asn] = self.writer("AS%s" % asn, family=family, directory=directory)
 
                        # Filter countries from special country codes
                        country_codes = [
@@ -293,10 +291,3 @@ class Exporter(object):
                        if not directory:
                                for writer in writers.values():
                                        writer.print()
-
-       def _make_filename(self, directory, prefix, suffix, family):
-               filename = "%s.%s%s" % (
-                       prefix, suffix, "6" if family == socket.AF_INET6 else "4"
-               )
-
-               return os.path.join(directory, filename)