]> git.ipfire.org Git - location/libloc.git/commitdiff
Merge location-exporter(8) into location(8)
authorMichael Tremer <michael.tremer@ipfire.org>
Wed, 3 Jun 2020 18:36:28 +0000 (18:36 +0000)
committerMichael Tremer <michael.tremer@ipfire.org>
Wed, 3 Jun 2020 19:01:31 +0000 (19:01 +0000)
Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
.gitignore
Makefile.am
src/python/export.py [new file with mode: 0644]
src/python/location-exporter.in [deleted file]
src/python/location.in

index 6ab8e9165eec48b97f36af3f65784e6b5fb13d0a..7f4e33b0e789c6828b02d5059c3cbb0a4a3d2ef1 100644 (file)
@@ -15,7 +15,6 @@ Makefile.in
 /libtool
 /stamp-h1
 /src/python/location
-/src/python/location-exporter
 /src/python/location-importer
 /src/systemd/location-update.service
 /src/systemd/location-update.timer
index 59870b1bee1127f0f6fca1022770ea1026e40353..9f520cc6079bd4b91d7414257673fc7c2d024931 100644 (file)
@@ -150,6 +150,7 @@ dist_pkgpython_PYTHON = \
        src/python/__init__.py \
        src/python/database.py \
        src/python/downloader.py \
+       src/python/export.py \
        src/python/i18n.py \
        src/python/importer.py \
        src/python/logger.py
@@ -239,17 +240,14 @@ uninstall-perl:
 
 bin_SCRIPTS = \
        src/python/location \
-       src/python/location-exporter \
        src/python/location-importer
 
 EXTRA_DIST += \
        src/python/location.in \
-       src/python/location-exporter.in \
        src/python/location-importer.in
 
 CLEANFILES += \
        src/python/location \
-       src/python/location-exporter \
        src/python/location-importer
 
 # ------------------------------------------------------------------------------
diff --git a/src/python/export.py b/src/python/export.py
new file mode 100644 (file)
index 0000000..69fe964
--- /dev/null
@@ -0,0 +1,185 @@
+#!/usr/bin/python3
+###############################################################################
+#                                                                             #
+# libloc - A library to determine the location of someone on the Internet     #
+#                                                                             #
+# Copyright (C) 2020 IPFire Development Team <info@ipfire.org>                #
+#                                                                             #
+# This library is free software; you can redistribute it and/or               #
+# modify it under the terms of the GNU Lesser General Public                  #
+# License as published by the Free Software Foundation; either                #
+# version 2.1 of the License, or (at your option) any later version.          #
+#                                                                             #
+# This library is distributed in the hope that it will be useful,             #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU           #
+# Lesser General Public License for more details.                             #
+#                                                                             #
+###############################################################################
+
+import io
+import ipaddress
+import logging
+import os
+import socket
+
+# Initialise logging
+log = logging.getLogger("location.export")
+log.propagate = 1
+
+class OutputWriter(object):
+       suffix = "networks"
+       mode = "w"
+
+       def __init__(self, f, prefix=None):
+               self.f, self.prefix = f, prefix
+
+               # Immediately write the header
+               self._write_header()
+
+       @classmethod
+       def open(cls, filename, **kwargs):
+               """
+                       Convenience function to open a file
+               """
+               f = open(filename, cls.mode)
+
+               return cls(f, **kwargs)
+
+       def __repr__(self):
+               return "<%s f=%s>" % (self.__class__.__name__, self.f)
+
+       def _write_header(self):
+               """
+                       The header of the file
+               """
+               pass
+
+       def _write_footer(self):
+               """
+                       The footer of the file
+               """
+               pass
+
+       def write(self, network):
+               self.f.write("%s\n" % network)
+
+       def finish(self):
+               """
+                       Called when all data has been written
+               """
+               self._write_footer()
+
+               # Close the file
+               self.f.close()
+
+
+class IpsetOutputWriter(OutputWriter):
+       """
+               For ipset
+       """
+       suffix = "ipset"
+
+       def _write_header(self):
+               self.f.write("create %s hash:net family inet hashsize 1024 maxelem 65536\n" % self.prefix)
+
+       def write(self, network):
+               self.f.write("add %s %s\n" % (self.prefix, network))
+
+
+class NftablesOutputWriter(OutputWriter):
+       """
+               For nftables
+       """
+       suffix = "set"
+
+       def _write_header(self):
+               self.f.write("define %s = {\n" % self.prefix)
+
+       def _write_footer(self):
+               self.f.write("}\n")
+
+       def write(self, network):
+               self.f.write("  %s,\n" % network)
+
+
+class XTGeoIPOutputWriter(OutputWriter):
+       """
+               Formats the output in that way, that it can be loaded by
+               the xt_geoip kernel module from xtables-addons.
+       """
+       suffix = "iv"
+       mode = "wb"
+
+       def write(self, network):
+               n = ipaddress.ip_network("%s" % network)
+
+               for address in (n.network_address, n.broadcast_address):
+                       bytes = socket.inet_pton(
+                               socket.AF_INET6 if address.version == 6 else socket.AF_INET,
+                               "%s" % address,
+                       )
+
+                       self.f.write(bytes)
+
+
+formats = {
+       "ipset"    : IpsetOutputWriter,
+       "list"     : OutputWriter,
+       "nftables" : NftablesOutputWriter,
+       "xt_geoip" : XTGeoIPOutputWriter,
+}
+
+class Exporter(object):
+       def __init__(self, db, writer):
+               self.db, self.writer = db, writer
+
+       def export(self, directory, families, countries, asns):
+               for family in families:
+                       log.debug("Exporting family %s" % family)
+
+                       writers = {}
+
+                       # Create writers for countries
+                       for country_code in countries:
+                               filename = self._make_filename(
+                                       directory, prefix=country_code, suffix=self.writer.suffix, family=family,
+                               )
+
+                               writers[country_code] = self.writer.open(filename, prefix="CC_%s" % country_code)
+
+                       # Create writers for ASNs
+                       for asn in asns:
+                               filename = self._make_filename(
+                                       directory, "AS%s" % asn, suffix=self.writer.suffix, family=family,
+                               )
+
+                               writers[asn] = self.writer.open(filename, prefix="AS%s" % asn)
+
+                       # Get all networks that match the family
+                       networks = self.db.search_networks(family=family)
+
+                       # Walk through all networks
+                       for network in networks:
+                               # Write matching countries
+                               try:
+                                       writers[network.country_code].write(network)
+                               except KeyError:
+                                       pass
+
+                               # Write matching ASNs
+                               try:
+                                       writers[network.asn].write(network)
+                               except KeyError:
+                                       pass
+
+                       # Write everything to the filesystem
+                       for writer in writers.values():
+                               writer.finish()
+
+       def _make_filename(self, directory, prefix, suffix, family):
+               filename = "%s.%s%s" % (
+                       prefix, suffix, "6" if family == socket.AF_INET6 else "4"
+               )
+
+               return os.path.join(directory, filename)
diff --git a/src/python/location-exporter.in b/src/python/location-exporter.in
deleted file mode 100644 (file)
index d82f1d3..0000000
+++ /dev/null
@@ -1,300 +0,0 @@
-#!/usr/bin/python3
-###############################################################################
-#                                                                             #
-# libloc - A library to determine the location of someone on the Internet     #
-#                                                                             #
-# Copyright (C) 2019 IPFire Development Team <info@ipfire.org>                #
-#                                                                             #
-# This library is free software; you can redistribute it and/or               #
-# modify it under the terms of the GNU Lesser General Public                  #
-# License as published by the Free Software Foundation; either                #
-# version 2.1 of the License, or (at your option) any later version.          #
-#                                                                             #
-# This library is distributed in the hope that it will be useful,             #
-# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU           #
-# Lesser General Public License for more details.                             #
-#                                                                             #
-###############################################################################
-
-import argparse
-import io
-import ipaddress
-import logging
-import os.path
-import re
-import socket
-import sys
-
-# Load our location module
-import location
-from location.i18n import _
-
-# Initialise logging
-log = logging.getLogger("location.exporter")
-log.propagate = 1
-
-class OutputWriter(object):
-       suffix = "networks"
-
-       def __init__(self, family, country_code=None, asn=None):
-               self.family, self.country_code, self.asn = family, country_code, asn
-
-               self.f = io.BytesIO()
-
-       def write_out(self, directory):
-               # Make the output filename
-               filename = os.path.join(
-                       directory, self._make_filename(),
-               )
-
-               with open(filename, "wb") as f:
-                       self._write_header(f)
-
-                       # Copy all data into the file
-                       f.write(self.f.getbuffer())
-
-                       self._write_footer(f)
-
-       def _make_filename(self):
-               return "%s.%s%s" % (
-                       self.country_code or "AS%s" % self.asn,
-                       self.suffix,
-                       "6" if self.family == socket.AF_INET6 else "4"
-               )
-
-       @property
-       def name(self):
-               if self.country_code:
-                       return "CC_%s" % self.country_code
-
-               if self.asn:
-                       return "AS%s" % self.asn
-
-       def _write_header(self, f):
-               """
-                       The header of the file
-               """
-               pass
-
-       def _write_footer(self, f):
-               """
-                       The footer of the file
-               """
-               pass
-
-       def write(self, network):
-               s = "%s\n" % network
-
-               self.f.write(s.encode("ascii"))
-
-
-class IpsetOutputWriter(OutputWriter):
-       """
-               For ipset
-       """
-       suffix = "ipset"
-
-       def _write_header(self, f):
-               h = "create %s hash:net family inet hashsize 1024 maxelem 65536\n" % self.name
-
-               f.write(h.encode("ascii"))
-
-       def write(self, network):
-               s = "add %s %s\n" % (self.name, network)
-
-               self.f.write(s.encode("ascii"))
-
-
-class NftablesOutputWriter(OutputWriter):
-       """
-               For nftables
-       """
-       suffix = "set"
-
-       def _write_header(self, f):
-               h = "define %s = {\n" % self.name
-
-               f.write(h.encode("ascii"))
-
-       def _write_footer(self, f):
-               f.write(b"}")
-
-       def write(self, network):
-               s = "   %s,\n" % network
-
-               self.f.write(s.encode("ascii"))
-
-
-class XTGeoIPOutputWriter(OutputWriter):
-       """
-               Formats the output in that way, that it can be loaded by
-               the xt_geoip kernel module from xtables-addons.
-       """
-       suffix = "iv"
-
-       def write(self, network):
-               n = ipaddress.ip_network("%s" % network)
-
-               for address in (n.network_address, n.broadcast_address):
-                       bytes = socket.inet_pton(
-                               socket.AF_INET6 if address.version == 6 else socket.AF_INET,
-                               "%s" % address,
-                       )
-
-                       self.f.write(bytes)
-
-
-class Exporter(object):
-       def __init__(self, db, writer):
-               self.db = db
-               self.writer = writer
-
-       def export(self, directory, families, countries, asns):
-               for family in families:
-                       log.debug("Exporting family %s" % family)
-
-                       writers = {}
-
-                       # Create writers for countries
-                       for country_code in countries:
-                               writers[country_code] = self.writer(family, country_code=country_code)
-
-                       # Create writers for ASNs
-                       for asn in asns:
-                               writers[asn] = self.writer(family, asn=asn)
-
-                       # Get all networks that match the family
-                       networks = self.db.search_networks(family=family)
-
-                       # Walk through all networks
-                       for network in networks:
-                               # Write matching countries
-                               if network.country_code in countries:
-                                       writers[network.country_code].write(network)
-
-                               # Write matching ASNs
-                               if network.asn in asns:
-                                       writers[network.asn].write(network)
-
-                       # Write everything to the filesystem
-                       for writer in writers.values():
-                               writer.write_out(directory)
-
-
-class CLI(object):
-       output_formats = {
-               "ipset"    : IpsetOutputWriter,
-               "list"     : OutputWriter,
-               "nftables" : NftablesOutputWriter,
-               "xt_geoip" : XTGeoIPOutputWriter,
-       }
-
-       def parse_cli(self):
-               parser = argparse.ArgumentParser(
-                       description=_("Location Exporter Command Line Interface"),
-               )
-
-               # Global configuration flags
-               parser.add_argument("--debug", action="store_true",
-                       help=_("Enable debug output"))
-               parser.add_argument("--quiet", action="store_true",
-                       help=_("Enable quiet mode"))
-
-               # version
-               parser.add_argument("--version", action="version",
-                       version="%(prog)s @VERSION@")
-
-               # database
-               parser.add_argument("--database", "-d",
-                       default="@databasedir@/database.db", help=_("Path to database"),
-               )
-
-               # format
-               parser.add_argument("--format", help=_("Output format"),
-                       default="list", choices=self.output_formats.keys())
-
-               # directory
-               parser.add_argument("--directory", help=_("Output directory"), required=True)
-
-               # family
-               parser.add_argument("--family", help=_("Specify address family"), choices=("ipv6", "ipv4"))
-
-               # Countries and Autonomous Systems
-               parser.add_argument("objects", nargs="+")
-
-               args = parser.parse_args()
-
-               # Configure logging
-               if args.debug:
-                       location.logger.set_level(logging.DEBUG)
-               elif args.quiet:
-                       location.logger.set_level(logging.WARNING)
-
-               return args
-
-       def run(self):
-               # Parse command line arguments
-               args = self.parse_cli()
-
-               # Call function
-               ret = self.handle_export(args)
-
-               # Return with exit code
-               if ret:
-                       sys.exit(ret)
-
-               # Otherwise just exit
-               sys.exit(0)
-
-       def handle_export(self, ns):
-               countries, asns = [], []
-
-               # Translate family
-               if ns.family == "ipv6":
-                       families = [ socket.AF_INET6 ]
-               elif ns.family == "ipv4":
-                       families = [ socket.AF_INET ]
-               else:
-                       families = [ socket.AF_INET6, socket.AF_INET ]
-
-               for object in ns.objects:
-                       m = re.match("^AS(\d+)$", object)
-                       if m:
-                               object = int(m.group(1))
-
-                               asns.append(object)
-
-                       elif location.country_code_is_valid(object) \
-                                       or object in ("A1", "A2", "A3"):
-                               countries.append(object)
-
-                       else:
-                               log.warning("Invalid argument: %s" % object)
-                               continue
-
-               if not countries and not asns:
-                       log.error("Nothing to export")
-                       return 2
-
-               # Open the database
-               try:
-                       db = location.Database(ns.database)
-               except FileNotFoundError as e:
-                       log.error("Count not open database: %s" % ns.database)
-                       return 1
-
-               # Select the output format
-               writer = self.output_formats.get(ns.format)
-               assert writer
-
-               e = Exporter(db, writer)
-               e.export(ns.directory, countries=countries, asns=asns, families=families)
-
-
-def main():
-       # Run the command line interface
-       c = CLI()
-       c.run()
-
-main()
index 10618e2f2b74c44e78815e9665a4166fbf5bb52b..7614cae29d43f01d5a96d719dceb730e33a3e331 100644 (file)
@@ -22,6 +22,7 @@ import datetime
 import ipaddress
 import logging
 import os
+import re
 import shutil
 import socket
 import sys
@@ -30,6 +31,8 @@ import time
 # Load our location module
 import location
 import location.downloader
+import location.export
+
 from location.i18n import _
 
 # Setup logging
@@ -37,88 +40,7 @@ log = logging.getLogger("location")
 
 # Output formatters
 
-class OutputFormatter(object):
-       def __init__(self, ns):
-               self.ns = ns
-
-       def __enter__(self):
-               # Open the output
-               self.open()
-
-               return self
-
-       def __exit__(self, type, value, tb):
-               if tb is None:
-                       self.close()
-
-       @property
-       def name(self):
-               if "country_code" in self.ns:
-                       return "networks_country_%s" % self.ns.country_code[0]
-
-               elif "asn" in self.ns:
-                       return "networks_AS%s" % self.ns.asn[0]
-
-       def open(self):
-               pass
-
-       def close(self):
-               pass
-
-       def network(self, network):
-               print(network)
-
-
-class IpsetOutputFormatter(OutputFormatter):
-       """
-               For nftables
-       """
-       def open(self):
-               print("create %s hash:net family inet hashsize 1024 maxelem 65536" % self.name)
-
-       def network(self, network):
-               print("add %s %s" % (self.name, network))
-
-
-class NftablesOutputFormatter(OutputFormatter):
-       """
-               For nftables
-       """
-       def open(self):
-               print("define %s = {" % self.name)
-
-       def close(self):
-               print("}")
-
-       def network(self, network):
-               print(" %s," % network)
-
-
-class XTGeoIPOutputFormatter(OutputFormatter):
-       """
-               Formats the output in that way, that it can be loaded by
-               the xt_geoip kernel module from xtables-addons.
-       """
-       def network(self, network):
-               n = ipaddress.ip_network("%s" % network)
-
-               for address in (n.network_address, n.broadcast_address):
-                       bytes = socket.inet_pton(
-                               socket.AF_INET6 if address.version == 6 else socket.AF_INET,
-                               "%s" % address,
-                       )
-
-                       os.write(1, bytes)
-
-
 class CLI(object):
-       output_formats = {
-               "ipset"    : IpsetOutputFormatter,
-               "list"     : OutputFormatter,
-               "nftables" : NftablesOutputFormatter,
-               "xt_geoip" : XTGeoIPOutputFormatter,
-       }
-
        def parse_cli(self):
                parser = argparse.ArgumentParser(
                        description=_("Location Database Command Line Interface"),
@@ -193,8 +115,8 @@ class CLI(object):
                )
                list_networks_by_as.add_argument("asn", nargs=1, type=int)
                list_networks_by_as.add_argument("--family", choices=("ipv6", "ipv4"))
-               list_networks_by_as.add_argument("--output-format",
-                       choices=self.output_formats.keys(), default="list")
+               list_networks_by_as.add_argument("--format",
+                       choices=location.export.formats.keys(), default="list")
                list_networks_by_as.set_defaults(func=self.handle_list_networks_by_as)
 
                # List all networks in a country
@@ -203,8 +125,8 @@ class CLI(object):
                )
                list_networks_by_cc.add_argument("country_code", nargs=1)
                list_networks_by_cc.add_argument("--family", choices=("ipv6", "ipv4"))
-               list_networks_by_cc.add_argument("--output-format",
-                       choices=self.output_formats.keys(), default="list")
+               list_networks_by_cc.add_argument("--format",
+                       choices=location.export.formats.keys(), default="list")
                list_networks_by_cc.set_defaults(func=self.handle_list_networks_by_cc)
 
                # List all networks with flags
@@ -221,10 +143,23 @@ class CLI(object):
                        action="store_true", help=_("Anycasts"),
                )
                list_networks_by_flags.add_argument("--family", choices=("ipv6", "ipv4"))
-               list_networks_by_flags.add_argument("--output-format",
-                       choices=self.output_formats.keys(), default="list")
+               list_networks_by_flags.add_argument("--format",
+                       choices=location.export.formats.keys(), default="list")
                list_networks_by_flags.set_defaults(func=self.handle_list_networks_by_flags)
 
+               # Export
+               export = subparsers.add_parser("export",
+                       help=_("Exports data in many formats to load it into packet filters"),
+               )
+               export.add_argument("--format", help=_("Output format"),
+                       choices=location.export.formats.keys(), default="list")
+               export.add_argument("--directory", help=_("Output directory"), required=True)
+               export.add_argument("--family",
+                       help=_("Specify address family"), choices=("ipv6", "ipv4"),
+               )
+               export.add_argument("objects", nargs="+", help=_("List country codes or ASNs to export"))
+               export.set_defaults(func=self.handle_export)
+
                args = parser.parse_args()
 
                # Configure logging
@@ -494,25 +429,36 @@ class CLI(object):
 
        def __get_output_formatter(self, ns):
                try:
-                       cls = self.output_formats[ns.output_format]
+                       cls = location.export.formats[ns.format]
                except KeyError:
-                       cls = OutputFormatter
+                       cls = location.export.OutputFormatter
 
-               return cls(ns)
+               return cls
 
        def handle_list_networks_by_as(self, db, ns):
-               with self.__get_output_formatter(ns) as f:
-                       for asn in ns.asn:
-                               # Print all matching networks
-                               for n in db.search_networks(asn=asn, family=ns.family):
-                                       f.network(n)
+               writer = self.__get_output_formatter(ns)
+
+               for asn in ns.asn:
+                       f = writer(sys.stdout, prefix="AS%s" % asn)
+
+                       # Print all matching networks
+                       for n in db.search_networks(asn=asn, family=ns.family):
+                               f.write(n)
+
+                       f.finish()
 
        def handle_list_networks_by_cc(self, db, ns):
-               with self.__get_output_formatter(ns) as f:
-                       for country_code in ns.country_code:
-                               # Print all matching networks
-                               for n in db.search_networks(country_code=country_code, family=ns.family):
-                                       f.network(n)
+               writer = self.__get_output_formatter(ns)
+
+               for country_code in ns.country_code:
+                       # Open standard output
+                       f = writer(sys.stdout, prefix=country_code)
+
+                       # Print all matching networks
+                       for n in db.search_networks(country_code=country_code, family=ns.family):
+                               f.write(n)
+
+                       f.finish()
 
        def handle_list_networks_by_flags(self, db, ns):
                flags = 0
@@ -529,9 +475,49 @@ class CLI(object):
                if not flags:
                        raise ValueError(_("You must at least pass one flag"))
 
-               with self.__get_output_formatter(ns) as f:
-                       for n in db.search_networks(flags=flags, family=ns.family):
-                               f.network(n)
+               writer = self.__get_output_formatter(ns)
+               f = writer(sys.stdout, prefix="custom")
+
+               for n in db.search_networks(flags=flags, family=ns.family):
+                       f.write(n)
+
+               f.finish()
+
+       def handle_export(self, db, ns):
+               countries, asns = [], []
+
+               # Translate family
+               if ns.family == "ipv6":
+                       families = [ socket.AF_INET6 ]
+               elif ns.family == "ipv4":
+                       families = [ socket.AF_INET ]
+               else:
+                       families = [ socket.AF_INET6, socket.AF_INET ]
+
+               for object in ns.objects:
+                       m = re.match("^AS(\d+)$", object)
+                       if m:
+                               object = int(m.group(1))
+
+                               asns.append(object)
+
+                       elif location.country_code_is_valid(object) \
+                                       or object in ("A1", "A2", "A3"):
+                               countries.append(object)
+
+                       else:
+                               log.warning("Invalid argument: %s" % object)
+                               continue
+
+               if not countries and not asns:
+                       log.error("Nothing to export")
+                       return 2
+
+               # Select the output format
+               writer = self.__get_output_formatter(ns)
+
+               e = location.export.Exporter(db, writer)
+               e.export(ns.directory, countries=countries, asns=asns, families=families)
 
 
 def main():