]> git.ipfire.org Git - location/libloc.git/blobdiff - src/python/location.in
location update: Remove double conversion of timestamps
[location/libloc.git] / src / python / location.in
index 10618e2f2b74c44e78815e9665a4166fbf5bb52b..070640c2ae5cd5a07ce3875e76f444d9a2c6d719 100644 (file)
@@ -22,6 +22,7 @@ import datetime
 import ipaddress
 import logging
 import os
+import re
 import shutil
 import socket
 import sys
@@ -30,6 +31,8 @@ import time
 # Load our location module
 import location
 import location.downloader
+import location.export
+
 from location.i18n import _
 
 # Setup logging
@@ -37,88 +40,7 @@ log = logging.getLogger("location")
 
 # Output formatters
 
-class OutputFormatter(object):
-       def __init__(self, ns):
-               self.ns = ns
-
-       def __enter__(self):
-               # Open the output
-               self.open()
-
-               return self
-
-       def __exit__(self, type, value, tb):
-               if tb is None:
-                       self.close()
-
-       @property
-       def name(self):
-               if "country_code" in self.ns:
-                       return "networks_country_%s" % self.ns.country_code[0]
-
-               elif "asn" in self.ns:
-                       return "networks_AS%s" % self.ns.asn[0]
-
-       def open(self):
-               pass
-
-       def close(self):
-               pass
-
-       def network(self, network):
-               print(network)
-
-
-class IpsetOutputFormatter(OutputFormatter):
-       """
-               For nftables
-       """
-       def open(self):
-               print("create %s hash:net family inet hashsize 1024 maxelem 65536" % self.name)
-
-       def network(self, network):
-               print("add %s %s" % (self.name, network))
-
-
-class NftablesOutputFormatter(OutputFormatter):
-       """
-               For nftables
-       """
-       def open(self):
-               print("define %s = {" % self.name)
-
-       def close(self):
-               print("}")
-
-       def network(self, network):
-               print(" %s," % network)
-
-
-class XTGeoIPOutputFormatter(OutputFormatter):
-       """
-               Formats the output in that way, that it can be loaded by
-               the xt_geoip kernel module from xtables-addons.
-       """
-       def network(self, network):
-               n = ipaddress.ip_network("%s" % network)
-
-               for address in (n.network_address, n.broadcast_address):
-                       bytes = socket.inet_pton(
-                               socket.AF_INET6 if address.version == 6 else socket.AF_INET,
-                               "%s" % address,
-                       )
-
-                       os.write(1, bytes)
-
-
 class CLI(object):
-       output_formats = {
-               "ipset"    : IpsetOutputFormatter,
-               "list"     : OutputFormatter,
-               "nftables" : NftablesOutputFormatter,
-               "xt_geoip" : XTGeoIPOutputFormatter,
-       }
-
        def parse_cli(self):
                parser = argparse.ArgumentParser(
                        description=_("Location Database Command Line Interface"),
@@ -166,6 +88,10 @@ class CLI(object):
 
                # Update
                update = subparsers.add_parser("update", help=_("Update database"))
+               update.add_argument("--cron",
+                       help=_("Update the library only once per interval"),
+                       choices=("daily", "weekly", "monthly"),
+               )
                update.set_defaults(func=self.handle_update)
 
                # Verify
@@ -193,8 +119,8 @@ class CLI(object):
                )
                list_networks_by_as.add_argument("asn", nargs=1, type=int)
                list_networks_by_as.add_argument("--family", choices=("ipv6", "ipv4"))
-               list_networks_by_as.add_argument("--output-format",
-                       choices=self.output_formats.keys(), default="list")
+               list_networks_by_as.add_argument("--format",
+                       choices=location.export.formats.keys(), default="list")
                list_networks_by_as.set_defaults(func=self.handle_list_networks_by_as)
 
                # List all networks in a country
@@ -203,8 +129,8 @@ class CLI(object):
                )
                list_networks_by_cc.add_argument("country_code", nargs=1)
                list_networks_by_cc.add_argument("--family", choices=("ipv6", "ipv4"))
-               list_networks_by_cc.add_argument("--output-format",
-                       choices=self.output_formats.keys(), default="list")
+               list_networks_by_cc.add_argument("--format",
+                       choices=location.export.formats.keys(), default="list")
                list_networks_by_cc.set_defaults(func=self.handle_list_networks_by_cc)
 
                # List all networks with flags
@@ -221,10 +147,35 @@ class CLI(object):
                        action="store_true", help=_("Anycasts"),
                )
                list_networks_by_flags.add_argument("--family", choices=("ipv6", "ipv4"))
-               list_networks_by_flags.add_argument("--output-format",
-                       choices=self.output_formats.keys(), default="list")
+               list_networks_by_flags.add_argument("--format",
+                       choices=location.export.formats.keys(), default="list")
                list_networks_by_flags.set_defaults(func=self.handle_list_networks_by_flags)
 
+               # List countries
+               list_countries = subparsers.add_parser("list-countries",
+                       help=_("Lists all countries"),
+               )
+               list_countries.add_argument("--show-name",
+                       action="store_true", help=_("Show the name of the country"),
+               )
+               list_countries.add_argument("--show-continent",
+                       action="store_true", help=_("Show the continent"),
+               )
+               list_countries.set_defaults(func=self.handle_list_countries)
+
+               # Export
+               export = subparsers.add_parser("export",
+                       help=_("Exports data in many formats to load it into packet filters"),
+               )
+               export.add_argument("--format", help=_("Output format"),
+                       choices=location.export.formats.keys(), default="list")
+               export.add_argument("--directory", help=_("Output directory"), required=True)
+               export.add_argument("--family",
+                       help=_("Specify address family"), choices=("ipv6", "ipv4"),
+               )
+               export.add_argument("objects", nargs="*", help=_("List country codes or ASNs to export"))
+               export.set_defaults(func=self.handle_export)
+
                args = parser.parse_args()
 
                # Configure logging
@@ -248,9 +199,14 @@ class CLI(object):
                try:
                        db = location.Database(args.database)
                except FileNotFoundError as e:
-                       sys.stderr.write("location: Could not open database %s: %s\n" \
-                               % (args.database, e))
-                       sys.exit(1)
+                       # Allow continuing without a database
+                       if args.func == self.handle_update:
+                               db = None
+
+                       else:
+                               sys.stderr.write("location: Could not open database %s: %s\n" \
+                                       % (args.database, e))
+                               sys.exit(1)
 
                # Translate family (if present)
                if "family" in args:
@@ -374,7 +330,8 @@ class CLI(object):
 
                if db.description:
                        for line in db.description.splitlines():
-                               f.write("# %s\n" % line)
+                               line = "# %s" % line
+                               f.write("%s\n" % line.rstrip())
 
                        f.write("#\n")
 
@@ -440,14 +397,32 @@ class CLI(object):
                                print(a)
 
        def handle_update(self, db, ns):
+               if ns.cron and db:
+                       now = datetime.datetime.utcnow()
+
+                       # Parse the database timestamp
+                       t = datetime.datetime.utcfromtimestamp(db.created_at)
+
+                       if ns.cron == "daily":
+                               delta = datetime.timedelta(days=1)
+                       elif ns.cron == "weekly":
+                               delta = datetime.timedelta(days=7)
+                       elif ns.cron == "monthly":
+                               delta = datetime.timedelta(days=30)
+
+                       # Check if the database has recently been updated
+                       if t >= (now - delta):
+                               log.info(
+                                       _("The database has been updated recently (%s)") % \
+                                               format_timedelta(now - t),
+                               )
+                               return 3
+
                # Fetch the timestamp we need from DNS
                t = location.discover_latest_version()
 
-               # Parse timestamp into datetime format
-               timestamp = datetime.datetime.fromtimestamp(t) if t else None
-
                # Check the version of the local database
-               if db and timestamp and db.created_at >= timestamp.timestamp():
+               if db and t and db.created_at >= t:
                        log.info("Already on the latest version")
                        return
 
@@ -475,13 +450,7 @@ class CLI(object):
 
                return 0
 
-       def handle_verify(self, ns):
-               try:
-                       db = location.Database(ns.database)
-               except FileNotFoundError as e:
-                       log.error("%s: %s" % (ns.database, e))
-                       return 127
-
+       def handle_verify(self, db, ns):
                # Verify the database
                with open(ns.public_key, "r") as f:
                        if not db.verify(f):
@@ -494,25 +463,54 @@ class CLI(object):
 
        def __get_output_formatter(self, ns):
                try:
-                       cls = self.output_formats[ns.output_format]
+                       cls = location.export.formats[ns.format]
                except KeyError:
-                       cls = OutputFormatter
+                       cls = location.export.OutputFormatter
+
+               return cls
+
+       def handle_list_countries(self, db, ns):
+               for country in db.countries:
+                       line = [
+                               country.code,
+                       ]
+
+                       if ns.show_continent:
+                               line.append(country.continent_code)
+
+                       if ns.show_name:
+                               line.append(country.name)
 
-               return cls(ns)
+                       # Format the output
+                       line = " ".join(line)
+
+                       # Print the output
+                       print(line)
 
        def handle_list_networks_by_as(self, db, ns):
-               with self.__get_output_formatter(ns) as f:
-                       for asn in ns.asn:
-                               # Print all matching networks
-                               for n in db.search_networks(asn=asn, family=ns.family):
-                                       f.network(n)
+               writer = self.__get_output_formatter(ns)
+
+               for asn in ns.asn:
+                       f = writer(sys.stdout, prefix="AS%s" % asn)
+
+                       # Print all matching networks
+                       for n in db.search_networks(asn=asn, family=ns.family):
+                               f.write(n)
+
+                       f.finish()
 
        def handle_list_networks_by_cc(self, db, ns):
-               with self.__get_output_formatter(ns) as f:
-                       for country_code in ns.country_code:
-                               # Print all matching networks
-                               for n in db.search_networks(country_code=country_code, family=ns.family):
-                                       f.network(n)
+               writer = self.__get_output_formatter(ns)
+
+               for country_code in ns.country_code:
+                       # Open standard output
+                       f = writer(sys.stdout, prefix=country_code)
+
+                       # Print all matching networks
+                       for n in db.search_networks(country_code=country_code, family=ns.family):
+                               f.write(n)
+
+                       f.finish()
 
        def handle_list_networks_by_flags(self, db, ns):
                flags = 0
@@ -529,10 +527,79 @@ class CLI(object):
                if not flags:
                        raise ValueError(_("You must at least pass one flag"))
 
-               with self.__get_output_formatter(ns) as f:
-                       for n in db.search_networks(flags=flags, family=ns.family):
-                               f.network(n)
+               writer = self.__get_output_formatter(ns)
+               f = writer(sys.stdout, prefix="custom")
+
+               for n in db.search_networks(flags=flags, family=ns.family):
+                       f.write(n)
+
+               f.finish()
+
+       def handle_export(self, db, ns):
+               countries, asns = [], []
+
+               # Translate family
+               if ns.family:
+                       families = [ ns.family ]
+               else:
+                       families = [ socket.AF_INET6, socket.AF_INET ]
+
+               for object in ns.objects:
+                       m = re.match("^AS(\d+)$", object)
+                       if m:
+                               object = int(m.group(1))
+
+                               asns.append(object)
+
+                       elif location.country_code_is_valid(object) \
+                                       or object in ("A1", "A2", "A3"):
+                               countries.append(object)
+
+                       else:
+                               log.warning("Invalid argument: %s" % object)
+                               continue
+
+               # Default to exporting all countries
+               if not countries and not asns:
+                       countries = ["A1", "A2", "A3"] + [country.code for country in db.countries]
+
+               # Select the output format
+               writer = self.__get_output_formatter(ns)
+
+               e = location.export.Exporter(db, writer)
+               e.export(ns.directory, countries=countries, asns=asns, families=families)
+
+
+def format_timedelta(t):
+       s = []
+
+       if t.days:
+               s.append(
+                       _("One Day", "%(days)s Days", t.days) % { "days" : t.days, }
+               )
+
+       hours = t.seconds // 3600
+       if hours:
+               s.append(
+                       _("One Hour", "%(hours)s Hours", hours) % { "hours" : hours, }
+               )
+
+       minutes = (t.seconds % 3600) // 60
+       if minutes:
+               s.append(
+                       _("One Minute", "%(minutes)s Minutes", minutes) % { "minutes" : minutes, }
+               )
+
+       seconds = t.seconds % 60
+       if t.seconds:
+               s.append(
+                       _("One Second", "%(seconds)s Seconds", seconds) % { "seconds" : seconds, }
+               )
+
+       if not s:
+               return _("Now")
 
+       return _("%s ago") % ", ".join(s)
 
 def main():
        # Run the command line interface