]> git.ipfire.org Git - people/ms/libloc.git/blobdiff - src/python/location-query.in
Implement signing/verifying databases
[people/ms/libloc.git] / src / python / location-query.in
index 733f0687d9f15425b9bf3bebb8a2c3fb716cea7a..c455d1e5c660c98a3de1d2e6caaab311991475d0 100644 (file)
@@ -19,6 +19,9 @@
 
 import argparse
 import gettext
+import ipaddress
+import os
+import socket
 import sys
 import syslog
 
@@ -32,7 +35,90 @@ def _(singular, plural=None, n=None):
 
        return gettext.dgettext("libloc", singular)
 
+# Output formatters
+
+class OutputFormatter(object):
+       def __init__(self, ns):
+               self.ns = ns
+
+       def __enter__(self):
+               # Open the output
+               self.open()
+
+               return self
+
+       def __exit__(self, type, value, tb):
+               if tb is None:
+                       self.close()
+
+       @property
+       def name(self):
+               if "country_code" in self.ns:
+                       return "networks_country_%s" % self.ns.country_code[0]
+
+               elif "asn" in self.ns:
+                       return "networks_AS%s" % self.ns.asn[0]
+
+       def open(self):
+               pass
+
+       def close(self):
+               pass
+
+       def network(self, network):
+               print(network)
+
+
+class IpsetOutputFormatter(OutputFormatter):
+       """
+               For nftables
+       """
+       def open(self):
+               print("create %s hash:net family inet hashsize 1024 maxelem 65536" % self.name)
+
+       def network(self, network):
+               print("add %s %s" % (self.name, network))
+
+
+class NftablesOutputFormatter(OutputFormatter):
+       """
+               For nftables
+       """
+       def open(self):
+               print("define %s = {" % self.name)
+
+       def close(self):
+               print("}")
+
+       def network(self, network):
+               print(" %s," % network)
+
+
+class XTGeoIPOutputFormatter(OutputFormatter):
+       """
+               Formats the output in that way, that it can be loaded by
+               the xt_geoip kernel module from xtables-addons.
+       """
+       def network(self, network):
+               n = ipaddress.ip_network("%s" % network)
+
+               for address in (n.network_address, n.broadcast_address):
+                       bytes = socket.inet_pton(
+                               socket.AF_INET6 if address.version == 6 else socket.AF_INET,
+                               "%s" % address,
+                       )
+
+                       os.write(1, bytes)
+
+
 class CLI(object):
+       output_formats = {
+               "ipset"    : IpsetOutputFormatter,
+               "list"     : OutputFormatter,
+               "nftables" : NftablesOutputFormatter,
+               "xt_geoip" : XTGeoIPOutputFormatter,
+       }
+
        def parse_cli(self):
                parser = argparse.ArgumentParser(
                        description=_("Location Database Command Line Interface"),
@@ -52,6 +138,11 @@ class CLI(object):
                        default="@databasedir@/database.db", help=_("Path to database"),
                )
 
+               # public key
+               parser.add_argument("--public-key", "-k",
+                       default="@databasedir@/signing-key.pem", help=_("Public Signing Key"),
+               )
+
                # lookup an IP address
                lookup = subparsers.add_parser("lookup",
                        help=_("Lookup one or multiple IP addresses"),
@@ -73,22 +164,54 @@ class CLI(object):
                search_as.add_argument("query", nargs=1)
                search_as.set_defaults(func=self.handle_search_as)
 
+               # List all networks in an AS
+               list_networks_by_as = subparsers.add_parser("list-networks-by-as",
+                       help=_("Lists all networks in an AS"),
+               )
+               list_networks_by_as.add_argument("asn", nargs=1, type=int)
+               list_networks_by_as.add_argument("--output-format",
+                       choices=self.output_formats.keys(), default="list")
+               list_networks_by_as.set_defaults(func=self.handle_list_networks_by_as)
+
                # List all networks in a country
-               search_as = subparsers.add_parser("list-networks-by-cc",
+               list_networks_by_cc = subparsers.add_parser("list-networks-by-cc",
                        help=_("Lists all networks in a country"),
                )
-               search_as.add_argument("country_code", nargs=1)
-               search_as.set_defaults(func=self.handle_list_networks_by_cc)
+               list_networks_by_cc.add_argument("country_code", nargs=1)
+               list_networks_by_cc.add_argument("--output-format",
+                       choices=self.output_formats.keys(), default="list")
+               list_networks_by_cc.set_defaults(func=self.handle_list_networks_by_cc)
+
+               # List all networks with flags
+               list_networks_by_flags = subparsers.add_parser("list-networks-by-flags",
+                       help=_("Lists all networks with flags"),
+               )
+               list_networks_by_flags.add_argument("--anonymous-proxy",
+                       action="store_true", help=_("Anonymous Proxies"),
+               )
+               list_networks_by_flags.add_argument("--satellite-provider",
+                       action="store_true", help=_("Satellite Providers"),
+               )
+               list_networks_by_flags.add_argument("--anycast",
+                       action="store_true", help=_("Anycasts"),
+               )
+               list_networks_by_flags.add_argument("--output-format",
+                       choices=self.output_formats.keys(), default="list")
+               list_networks_by_flags.set_defaults(func=self.handle_list_networks_by_flags)
+
+               args = parser.parse_args()
+
+               # Print usage if no action was given
+               if not "func" in args:
+                       parser.print_usage()
+                       sys.exit(2)
 
-               return parser.parse_args()
+               return args
 
        def run(self):
                # Parse command line arguments
                args = self.parse_cli()
 
-               # Callback function must be defined
-               assert args.func, "Callback function not defined"
-
                # Open database
                try:
                        db = location.Database(args.database)
@@ -97,6 +220,18 @@ class CLI(object):
                                % (args.database, e))
                        sys.exit(1)
 
+               # Verify the database
+               try:
+                       with open(args.public_key, "r") as f:
+                               if not db.verify(f):
+                                       sys.stderr.write("location-query: Could not verify the database\n")
+                                       sys.exit(1)
+
+               # Catch any errors when loading the public key
+               except (FileNotFoundError, OSError) as e:
+                       sys.stderr.write("Could not read the public key: %s\n" % e)
+                       sys.exit(1)
+
                # Call function
                ret = args.func(db, args)
 
@@ -177,11 +312,44 @@ class CLI(object):
                        for a in db.search_as(query):
                                print(a)
 
+       def __get_output_formatter(self, ns):
+               try:
+                       cls = self.output_formats[ns.output_format]
+               except KeyError:
+                       cls = OutputFormatter
+
+               return cls(ns)
+
+       def handle_list_networks_by_as(self, db, ns):
+               with self.__get_output_formatter(ns) as f:
+                       for asn in ns.asn:
+                               # Print all matching networks
+                               for n in db.search_networks(asn=asn):
+                                       f.network(n)
+
        def handle_list_networks_by_cc(self, db, ns):
-               for country_code in ns.country_code:
-                       # Print all matching networks
-                       for n in db.search_networks(country_code=country_code):
-                               print(n)
+               with self.__get_output_formatter(ns) as f:
+                       for country_code in ns.country_code:
+                               # Print all matching networks
+                               for n in db.search_networks(country_code=country_code):
+                                       f.network(n)
+
+       def handle_list_networks_by_flags(self, db, ns):
+               flags = 0
+
+               if ns.anonymous_proxy:
+                       flags |= location.NETWORK_FLAG_ANONYMOUS_PROXY
+
+               if ns.satellite_provider:
+                       flags |= location.NETWORK_FLAG_SATELLITE_PROVIDER
+
+               if ns.anycast:
+                       flags |= location.NETWORK_FLAG_ANYCAST
+
+               with self.__get_output_formatter(ns) as f:
+                       for n in db.search_networks(flags=flags):
+                               f.network(n)
+
 
 def main():
        # Run the command line interface