]> git.ipfire.org Git - location/libloc.git/blobdiff - src/python/location.in
location: Print proper error message for any uncaught exceptions
[location/libloc.git] / src / python / location.in
index 5c1effd2d920def5c82b4e24f74576c1752ea44f..eec32e4ca966c8973be32d40d086557fce177ef7 100644 (file)
@@ -3,7 +3,7 @@
 #                                                                             #
 # libloc - A library to determine the location of someone on the Internet     #
 #                                                                             #
-# Copyright (C) 2017 IPFire Development Team <info@ipfire.org>                #
+# Copyright (C) 2017-2021 IPFire Development Team <info@ipfire.org>           #
 #                                                                             #
 # This library is free software; you can redistribute it and/or               #
 # modify it under the terms of the GNU Lesser General Public                  #
@@ -88,6 +88,10 @@ class CLI(object):
 
                # Update
                update = subparsers.add_parser("update", help=_("Update database"))
+               update.add_argument("--cron",
+                       help=_("Update the library only once per interval"),
+                       choices=("daily", "weekly", "monthly"),
+               )
                update.set_defaults(func=self.handle_update)
 
                # Verify
@@ -142,11 +146,23 @@ class CLI(object):
                list_networks_by_flags.add_argument("--anycast",
                        action="store_true", help=_("Anycasts"),
                )
+               list_networks_by_flags.add_argument("--drop",
+                       action="store_true", help=_("Hostile Networks safe to drop"),
+               )
                list_networks_by_flags.add_argument("--family", choices=("ipv6", "ipv4"))
                list_networks_by_flags.add_argument("--format",
                        choices=location.export.formats.keys(), default="list")
                list_networks_by_flags.set_defaults(func=self.handle_list_networks_by_flags)
 
+               # List bogons
+               list_bogons = subparsers.add_parser("list-bogons",
+                       help=_("Lists all bogons"),
+               )
+               list_bogons.add_argument("--family", choices=("ipv6", "ipv4"))
+               list_bogons.add_argument("--format",
+                       choices=location.export.formats.keys(), default="list")
+               list_bogons.set_defaults(func=self.handle_list_bogons)
+
                # List countries
                list_countries = subparsers.add_parser("list-countries",
                        help=_("Lists all countries"),
@@ -169,7 +185,7 @@ class CLI(object):
                export.add_argument("--family",
                        help=_("Specify address family"), choices=("ipv6", "ipv4"),
                )
-               export.add_argument("objects", nargs="+", help=_("List country codes or ASNs to export"))
+               export.add_argument("objects", nargs="*", help=_("List country codes or ASNs to export"))
                export.set_defaults(func=self.handle_export)
 
                args = parser.parse_args()
@@ -195,9 +211,14 @@ class CLI(object):
                try:
                        db = location.Database(args.database)
                except FileNotFoundError as e:
-                       sys.stderr.write("location: Could not open database %s: %s\n" \
-                               % (args.database, e))
-                       sys.exit(1)
+                       # Allow continuing without a database
+                       if args.func == self.handle_update:
+                               db = None
+
+                       else:
+                               sys.stderr.write("location: Could not open database %s: %s\n" \
+                                       % (args.database, e))
+                               sys.exit(1)
 
                # Translate family (if present)
                if "family" in args:
@@ -217,6 +238,11 @@ class CLI(object):
                        sys.stderr.write("%s\n" % e)
                        ret = 2
 
+               # Catch any other exceptions
+               except Exception as e:
+                       sys.stderr.write("%s\n" % e)
+                       ret = 1
+
                # Return with exit code
                if ret:
                        sys.exit(ret)
@@ -244,6 +270,7 @@ class CLI(object):
                                network = db.lookup(address)
                        except ValueError:
                                print(_("Invalid IP address: %s") % address, file=sys.stderr)
+                               return 2
 
                        args = {
                                "address" : address,
@@ -295,6 +322,12 @@ class CLI(object):
                                        _("Anycast"), _("yes"),
                                ))
 
+                       # Hostile Network
+                       if network.has_flag(location.NETWORK_FLAG_DROP):
+                               print(format % (
+                                       _("Hostile Network safe to drop"), _("yes"),
+                               ))
+
                return ret
 
        def handle_dump(self, db, ns):
@@ -321,7 +354,8 @@ class CLI(object):
 
                if db.description:
                        for line in db.description.splitlines():
-                               f.write("# %s\n" % line)
+                               line = "# %s" % line
+                               f.write("%s\n" % line.rstrip())
 
                        f.write("#\n")
 
@@ -335,6 +369,7 @@ class CLI(object):
                        location.NETWORK_FLAG_ANONYMOUS_PROXY    : "is-anonymous-proxy:",
                        location.NETWORK_FLAG_SATELLITE_PROVIDER : "is-satellite-provider:",
                        location.NETWORK_FLAG_ANYCAST            : "is-anycast:",
+                       location.NETWORK_FLAG_DROP               : "drop:",
                }
 
                # Iterate over all networks
@@ -387,14 +422,30 @@ class CLI(object):
                                print(a)
 
        def handle_update(self, db, ns):
+               if ns.cron and db:
+                       now = time.time()
+
+                       if ns.cron == "daily":
+                               delta = datetime.timedelta(days=1)
+                       elif ns.cron == "weekly":
+                               delta = datetime.timedelta(days=7)
+                       elif ns.cron == "monthly":
+                               delta = datetime.timedelta(days=30)
+
+                       delta = delta.total_seconds()
+
+                       # Check if the database has recently been updated
+                       if db.created_at >= (now - delta):
+                               log.info(
+                                       _("The database has been updated recently"),
+                               )
+                               return 3
+
                # Fetch the timestamp we need from DNS
                t = location.discover_latest_version()
 
-               # Parse timestamp into datetime format
-               timestamp = datetime.datetime.fromtimestamp(t) if t else None
-
                # Check the version of the local database
-               if db and timestamp and db.created_at >= timestamp.timestamp():
+               if db and t and db.created_at >= t:
                        log.info("Already on the latest version")
                        return
 
@@ -406,7 +457,7 @@ class CLI(object):
 
                # Try downloading a new database
                try:
-                       t = d.download(public_key=ns.public_key, timestamp=timestamp, tmpdir=tmpdir)
+                       t = d.download(public_key=ns.public_key, timestamp=t, tmpdir=tmpdir)
 
                # If no file could be downloaded, log a message
                except FileNotFoundError as e:
@@ -422,13 +473,7 @@ class CLI(object):
 
                return 0
 
-       def handle_verify(self, ns):
-               try:
-                       db = location.Database(ns.database)
-               except FileNotFoundError as e:
-                       log.error("%s: %s" % (ns.database, e))
-                       return 127
-
+       def handle_verify(self, db, ns):
                # Verify the database
                with open(ns.public_key, "r") as f:
                        if not db.verify(f):
@@ -472,7 +517,7 @@ class CLI(object):
                        f = writer(sys.stdout, prefix="AS%s" % asn)
 
                        # Print all matching networks
-                       for n in db.search_networks(asn=asn, family=ns.family):
+                       for n in db.search_networks(asns=[asn], family=ns.family):
                                f.write(n)
 
                        f.finish()
@@ -485,7 +530,7 @@ class CLI(object):
                        f = writer(sys.stdout, prefix=country_code)
 
                        # Print all matching networks
-                       for n in db.search_networks(country_code=country_code, family=ns.family):
+                       for n in db.search_networks(country_codes=[country_code], family=ns.family):
                                f.write(n)
 
                        f.finish()
@@ -502,6 +547,9 @@ class CLI(object):
                if ns.anycast:
                        flags |= location.NETWORK_FLAG_ANYCAST
 
+               if ns.drop:
+                       flags |= location.NETWORK_FLAG_DROP
+
                if not flags:
                        raise ValueError(_("You must at least pass one flag"))
 
@@ -513,14 +561,21 @@ class CLI(object):
 
                f.finish()
 
+       def handle_list_bogons(self, db, ns):
+               writer = self.__get_output_formatter(ns)
+               f = writer(sys.stdout, prefix="bogons")
+
+               for n in db.list_bogons(family=ns.family):
+                       f.write(n)
+
+               f.finish()
+
        def handle_export(self, db, ns):
                countries, asns = [], []
 
                # Translate family
-               if ns.family == "ipv6":
-                       families = [ socket.AF_INET6 ]
-               elif ns.family == "ipv4":
-                       families = [ socket.AF_INET ]
+               if ns.family:
+                       families = [ ns.family ]
                else:
                        families = [ socket.AF_INET6, socket.AF_INET ]
 
@@ -532,16 +587,16 @@ class CLI(object):
                                asns.append(object)
 
                        elif location.country_code_is_valid(object) \
-                                       or object in ("A1", "A2", "A3"):
+                                       or object in ("A1", "A2", "A3", "XD"):
                                countries.append(object)
 
                        else:
                                log.warning("Invalid argument: %s" % object)
                                continue
 
+               # Default to exporting all countries
                if not countries and not asns:
-                       log.error("Nothing to export")
-                       return 2
+                       countries = ["A1", "A2", "A3", "XD"] + [country.code for country in db.countries]
 
                # Select the output format
                writer = self.__get_output_formatter(ns)
@@ -550,6 +605,37 @@ class CLI(object):
                e.export(ns.directory, countries=countries, asns=asns, families=families)
 
 
+def format_timedelta(t):
+       s = []
+
+       if t.days:
+               s.append(
+                       _("One Day", "%(days)s Days", t.days) % { "days" : t.days, }
+               )
+
+       hours = t.seconds // 3600
+       if hours:
+               s.append(
+                       _("One Hour", "%(hours)s Hours", hours) % { "hours" : hours, }
+               )
+
+       minutes = (t.seconds % 3600) // 60
+       if minutes:
+               s.append(
+                       _("One Minute", "%(minutes)s Minutes", minutes) % { "minutes" : minutes, }
+               )
+
+       seconds = t.seconds % 60
+       if t.seconds:
+               s.append(
+                       _("One Second", "%(seconds)s Seconds", seconds) % { "seconds" : seconds, }
+               )
+
+       if not s:
+               return _("Now")
+
+       return _("%s ago") % ", ".join(s)
+
 def main():
        # Run the command line interface
        c = CLI()