]> git.ipfire.org Git - people/ms/libloc.git/blobdiff - src/python/location-query.in
location-query: Require at least one flag
[people/ms/libloc.git] / src / python / location-query.in
index 933024eb1c030adf822238f60571d4bf7c8d7f67..dfdff8c2b804b08732d7734c310a1614cc35d3e9 100644 (file)
 ###############################################################################
 
 import argparse
-import gettext
+import ipaddress
+import os
+import socket
 import sys
-import syslog
+import time
 
 # Load our location module
 import location
+from location.i18n import _
 
-# i18n
-def _(singular, plural=None, n=None):
-       if plural:
-               return gettext.dngettext("libloc", singular, plural, n)
+# Output formatters
+
+class OutputFormatter(object):
+       def __init__(self, ns):
+               self.ns = ns
+
+       def __enter__(self):
+               # Open the output
+               self.open()
+
+               return self
+
+       def __exit__(self, type, value, tb):
+               if tb is None:
+                       self.close()
+
+       @property
+       def name(self):
+               if "country_code" in self.ns:
+                       return "networks_country_%s" % self.ns.country_code[0]
+
+               elif "asn" in self.ns:
+                       return "networks_AS%s" % self.ns.asn[0]
+
+       def open(self):
+               pass
+
+       def close(self):
+               pass
+
+       def network(self, network):
+               print(network)
+
+
+class IpsetOutputFormatter(OutputFormatter):
+       """
+               For nftables
+       """
+       def open(self):
+               print("create %s hash:net family inet hashsize 1024 maxelem 65536" % self.name)
+
+       def network(self, network):
+               print("add %s %s" % (self.name, network))
+
+
+class NftablesOutputFormatter(OutputFormatter):
+       """
+               For nftables
+       """
+       def open(self):
+               print("define %s = {" % self.name)
+
+       def close(self):
+               print("}")
+
+       def network(self, network):
+               print(" %s," % network)
+
+
+class XTGeoIPOutputFormatter(OutputFormatter):
+       """
+               Formats the output in that way, that it can be loaded by
+               the xt_geoip kernel module from xtables-addons.
+       """
+       def network(self, network):
+               n = ipaddress.ip_network("%s" % network)
+
+               for address in (n.network_address, n.broadcast_address):
+                       bytes = socket.inet_pton(
+                               socket.AF_INET6 if address.version == 6 else socket.AF_INET,
+                               "%s" % address,
+                       )
+
+                       os.write(1, bytes)
 
-       return gettext.dgettext("libloc", singular)
 
 class CLI(object):
-       def __init__(self):
-               # Open database
-               self.db = location.Database("@databasedir@/database.db")
+       output_formats = {
+               "ipset"    : IpsetOutputFormatter,
+               "list"     : OutputFormatter,
+               "nftables" : NftablesOutputFormatter,
+               "xt_geoip" : XTGeoIPOutputFormatter,
+       }
 
        def parse_cli(self):
                parser = argparse.ArgumentParser(
@@ -46,10 +121,27 @@ class CLI(object):
                # Global configuration flags
                parser.add_argument("--debug", action="store_true",
                        help=_("Enable debug output"))
+               parser.add_argument("--quiet", action="store_true",
+                       help=_("Enable quiet mode"))
 
                # version
                parser.add_argument("--version", action="version",
-                       version="%%(prog)s %s" % location.__version__)
+                       version="%(prog)s @VERSION@")
+
+               # database
+               parser.add_argument("--database", "-d",
+                       default="@databasedir@/database.db", help=_("Path to database"),
+               )
+
+               # public key
+               parser.add_argument("--public-key", "-k",
+                       default="@databasedir@/signing-key.pem", help=_("Public Signing Key"),
+               )
+
+               # Show the database version
+               version = subparsers.add_parser("version",
+                       help=_("Show database version"))
+               version.set_defaults(func=self.handle_version)
 
                # lookup an IP address
                lookup = subparsers.add_parser("lookup",
@@ -58,6 +150,13 @@ class CLI(object):
                lookup.add_argument("address", nargs="+")
                lookup.set_defaults(func=self.handle_lookup)
 
+               # Dump the whole database
+               dump = subparsers.add_parser("dump",
+                       help=_("Dump the entire database"),
+               )
+               dump.add_argument("output", nargs="?", type=argparse.FileType("w"))
+               dump.set_defaults(func=self.handle_dump)
+
                # Get AS
                get_as = subparsers.add_parser("get-as",
                        help=_("Get information about one or multiple Autonomous Systems"),
@@ -65,17 +164,95 @@ class CLI(object):
                get_as.add_argument("asn", nargs="+")
                get_as.set_defaults(func=self.handle_get_as)
 
-               return parser.parse_args()
+               # Search for AS
+               search_as = subparsers.add_parser("search-as",
+                       help=_("Search for Autonomous Systems that match the string"),
+               )
+               search_as.add_argument("query", nargs=1)
+               search_as.set_defaults(func=self.handle_search_as)
+
+               # List all networks in an AS
+               list_networks_by_as = subparsers.add_parser("list-networks-by-as",
+                       help=_("Lists all networks in an AS"),
+               )
+               list_networks_by_as.add_argument("asn", nargs=1, type=int)
+               list_networks_by_as.add_argument("--family", choices=("ipv6", "ipv4"))
+               list_networks_by_as.add_argument("--output-format",
+                       choices=self.output_formats.keys(), default="list")
+               list_networks_by_as.set_defaults(func=self.handle_list_networks_by_as)
+
+               # List all networks in a country
+               list_networks_by_cc = subparsers.add_parser("list-networks-by-cc",
+                       help=_("Lists all networks in a country"),
+               )
+               list_networks_by_cc.add_argument("country_code", nargs=1)
+               list_networks_by_cc.add_argument("--family", choices=("ipv6", "ipv4"))
+               list_networks_by_cc.add_argument("--output-format",
+                       choices=self.output_formats.keys(), default="list")
+               list_networks_by_cc.set_defaults(func=self.handle_list_networks_by_cc)
+
+               # List all networks with flags
+               list_networks_by_flags = subparsers.add_parser("list-networks-by-flags",
+                       help=_("Lists all networks with flags"),
+               )
+               list_networks_by_flags.add_argument("--anonymous-proxy",
+                       action="store_true", help=_("Anonymous Proxies"),
+               )
+               list_networks_by_flags.add_argument("--satellite-provider",
+                       action="store_true", help=_("Satellite Providers"),
+               )
+               list_networks_by_flags.add_argument("--anycast",
+                       action="store_true", help=_("Anycasts"),
+               )
+               list_networks_by_flags.add_argument("--family", choices=("ipv6", "ipv4"))
+               list_networks_by_flags.add_argument("--output-format",
+                       choices=self.output_formats.keys(), default="list")
+               list_networks_by_flags.set_defaults(func=self.handle_list_networks_by_flags)
+
+               args = parser.parse_args()
+
+               # Configure logging
+               if args.debug:
+                       location.logger.set_level(logging.DEBUG)
+               elif args.quiet:
+                       location.logger.set_level(logging.WARNING)
+
+               # Print usage if no action was given
+               if not "func" in args:
+                       parser.print_usage()
+                       sys.exit(2)
+
+               return args
 
        def run(self):
                # Parse command line arguments
                args = self.parse_cli()
 
-               # Callback function must be defined
-               assert args.func, "Callback function not defined"
+               # Open database
+               try:
+                       db = location.Database(args.database)
+               except FileNotFoundError as e:
+                       sys.stderr.write("location-query: Could not open database %s: %s\n" \
+                               % (args.database, e))
+                       sys.exit(1)
+
+               # Translate family (if present)
+               if "family" in args:
+                       if args.family == "ipv6":
+                               args.family = socket.AF_INET6
+                       elif args.family == "ipv4":
+                               args.family = socket.AF_INET
+                       else:
+                               args.family = 0
 
                # Call function
-               ret = args.func(args)
+               try:
+                       ret = args.func(db, args)
+
+               # Catch invalid inputs
+               except ValueError as e:
+                       sys.stderr.write("%s\n" % e)
+                       ret = 2
 
                # Return with exit code
                if ret:
@@ -84,44 +261,136 @@ class CLI(object):
                # Otherwise just exit
                sys.exit(0)
 
-       def handle_lookup(self, ns):
+       def handle_version(self, db, ns):
+               """
+                       Print the version of the database
+               """
+               t = time.strftime(
+                       "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(db.created_at),
+               )
+
+               print(t)
+
+       def handle_lookup(self, db, ns):
                ret = 0
 
+               format = "  %-24s: %s"
+
                for address in ns.address:
                        try:
-                               n = self.db.lookup(address)
+                               network = db.lookup(address)
                        except ValueError:
-                               sys.stderr.write(_("Invalid IP address: %s") % address)
+                               print(_("Invalid IP address: %s") % address, file=sys.stderr)
 
                        args = {
                                "address" : address,
-                               "network" : n,
+                               "network" : network,
                        }
 
                        # Nothing found?
-                       if not n:
-                               print(_("Nothing found for %(address)s") % args)
+                       if not network:
+                               print(_("Nothing found for %(address)s") % args, file=sys.stderr)
                                ret = 1
                                continue
 
-                       # Try to retrieve the AS if we have an AS number
-                       if n.asn:
-                               a = self.db.get_as(n.asn)
+                       print("%s:" % address)
+                       print(format % (_("Network"), network))
+
+                       # Print country
+                       if network.country_code:
+                               country = db.get_country(network.country_code)
+
+                               print(format % (
+                                       _("Country"),
+                                       country.name if country else network.country_code),
+                               )
+
+                       # Print AS information
+                       if network.asn:
+                               autonomous_system = db.get_as(network.asn)
+
+                               print(format % (
+                                       _("Autonomous System"),
+                                       autonomous_system or "AS%s" % network.asn),
+                               )
+
+                       # Anonymous Proxy
+                       if network.has_flag(location.NETWORK_FLAG_ANONYMOUS_PROXY):
+                               print(format % (
+                                       _("Anonymous Proxy"), _("yes"),
+                               ))
+
+                       # Satellite Provider
+                       if network.has_flag(location.NETWORK_FLAG_SATELLITE_PROVIDER):
+                               print(format % (
+                                       _("Satellite Provider"), _("yes"),
+                               ))
+
+                       # Anycast
+                       if network.has_flag(location.NETWORK_FLAG_ANYCAST):
+                               print(format % (
+                                       _("Anycast"), _("yes"),
+                               ))
 
-                               # If we have found an AS we will print it in the message
-                               if a:
-                                       args.update({
-                                               "as" : a,
-                                       })
+               return ret
 
-                                       print(_("%(address)s belongs to %(network)s which is a part of %(as)s") % args)
-                                       continue
+       def handle_dump(self, db, ns):
+               # Use output file or write to stdout
+               f = ns.output or sys.stdout
 
-                       print(_("%(address)s belongs to %(network)s") % args)
+               # Format everything like this
+               format = "%-24s %s\n"
 
-               return ret
+               # Write metadata
+               f.write("#\n# Location Database Export\n#\n")
+
+               f.write("# Generated: %s\n" % time.strftime(
+                       "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(db.created_at),
+               ))
+
+               if db.vendor:
+                       f.write("# Vendor:    %s\n" % db.vendor)
+
+               if db.license:
+                       f.write("# License:   %s\n" % db.license)
+
+               f.write("#\n")
+
+               if db.description:
+                       for line in db.description.splitlines():
+                               f.write("# %s\n" % line)
+
+                       f.write("#\n")
 
-       def handle_get_as(self, ns):
+               # Iterate over all ASes
+               for a in db.ases:
+                       f.write("\n")
+                       f.write(format % ("aut-num:", "AS%s" % a.number))
+                       f.write(format % ("name:", a.name))
+
+               flags = {
+                       location.NETWORK_FLAG_ANONYMOUS_PROXY    : "is-anonymous-proxy:",
+                       location.NETWORK_FLAG_SATELLITE_PROVIDER : "is-satellite-provider:",
+                       location.NETWORK_FLAG_ANYCAST            : "is-anycast:",
+               }
+
+               # Iterate over all networks
+               for n in db.networks:
+                       f.write("\n")
+                       f.write(format % ("net:", n))
+
+                       if n.country_code:
+                               f.write(format % ("country:", n.country_code))
+
+                       if n.asn:
+                               f.write(format % ("aut-num:", n.asn))
+
+                       # Print all flags
+                       for flag in flags:
+                               if n.has_flag(flag):
+                                       f.write(format % (flags[flag], "yes"))
+
+       def handle_get_as(self, db, ns):
                """
                        Gets information about Autonomous Systems
                """
@@ -131,16 +400,16 @@ class CLI(object):
                        try:
                                asn = int(asn)
                        except ValueError:
-                               sys.stderr.write("Invalid ASN: %s" %asn)
+                               print(_("Invalid ASN: %s") % asn, file=sys.stderr)
                                ret = 1
                                continue
 
                        # Fetch AS from database
-                       a = self.db.get_as(asn)
+                       a = db.get_as(asn)
 
                        # Nothing found
                        if not a:
-                               print(_("Could not find AS%s") % asn)
+                               print(_("Could not find AS%s") % asn, file=sys.stderr)
                                ret = 1
                                continue
 
@@ -148,6 +417,54 @@ class CLI(object):
 
                return ret
 
+       def handle_search_as(self, db, ns):
+               for query in ns.query:
+                       # Print all matches ASes
+                       for a in db.search_as(query):
+                               print(a)
+
+       def __get_output_formatter(self, ns):
+               try:
+                       cls = self.output_formats[ns.output_format]
+               except KeyError:
+                       cls = OutputFormatter
+
+               return cls(ns)
+
+       def handle_list_networks_by_as(self, db, ns):
+               with self.__get_output_formatter(ns) as f:
+                       for asn in ns.asn:
+                               # Print all matching networks
+                               for n in db.search_networks(asn=asn, family=ns.family):
+                                       f.network(n)
+
+       def handle_list_networks_by_cc(self, db, ns):
+               with self.__get_output_formatter(ns) as f:
+                       for country_code in ns.country_code:
+                               # Print all matching networks
+                               for n in db.search_networks(country_code=country_code, family=ns.family):
+                                       f.network(n)
+
+       def handle_list_networks_by_flags(self, db, ns):
+               flags = 0
+
+               if ns.anonymous_proxy:
+                       flags |= location.NETWORK_FLAG_ANONYMOUS_PROXY
+
+               if ns.satellite_provider:
+                       flags |= location.NETWORK_FLAG_SATELLITE_PROVIDER
+
+               if ns.anycast:
+                       flags |= location.NETWORK_FLAG_ANYCAST
+
+               if not flags:
+                       raise ValueError(_("You must at least pass one flag"))
+
+               with self.__get_output_formatter(ns) as f:
+                       for n in db.search_networks(flags=flags, family=ns.family):
+                               f.network(n)
+
+
 def main():
        # Run the command line interface
        c = CLI()