X-Git-Url: http://git.ipfire.org/?p=location%2Flibloc.git;a=blobdiff_plain;f=src%2Fpython%2Flocation.in;h=070640c2ae5cd5a07ce3875e76f444d9a2c6d719;hp=10618e2f2b74c44e78815e9665a4166fbf5bb52b;hb=0c74f6b1a3bdce5ebdc2ee452b9baf3e421dd3d1;hpb=1d237439676e8b9ee10a6dde2c64f5ba3a057210 diff --git a/src/python/location.in b/src/python/location.in index 10618e2..070640c 100644 --- a/src/python/location.in +++ b/src/python/location.in @@ -22,6 +22,7 @@ import datetime import ipaddress import logging import os +import re import shutil import socket import sys @@ -30,6 +31,8 @@ import time # Load our location module import location import location.downloader +import location.export + from location.i18n import _ # Setup logging @@ -37,88 +40,7 @@ log = logging.getLogger("location") # Output formatters -class OutputFormatter(object): - def __init__(self, ns): - self.ns = ns - - def __enter__(self): - # Open the output - self.open() - - return self - - def __exit__(self, type, value, tb): - if tb is None: - self.close() - - @property - def name(self): - if "country_code" in self.ns: - return "networks_country_%s" % self.ns.country_code[0] - - elif "asn" in self.ns: - return "networks_AS%s" % self.ns.asn[0] - - def open(self): - pass - - def close(self): - pass - - def network(self, network): - print(network) - - -class IpsetOutputFormatter(OutputFormatter): - """ - For nftables - """ - def open(self): - print("create %s hash:net family inet hashsize 1024 maxelem 65536" % self.name) - - def network(self, network): - print("add %s %s" % (self.name, network)) - - -class NftablesOutputFormatter(OutputFormatter): - """ - For nftables - """ - def open(self): - print("define %s = {" % self.name) - - def close(self): - print("}") - - def network(self, network): - print(" %s," % network) - - -class XTGeoIPOutputFormatter(OutputFormatter): - """ - Formats the output in that way, that it can be loaded by - the xt_geoip kernel module from xtables-addons. - """ - def network(self, network): - n = ipaddress.ip_network("%s" % network) - - for address in (n.network_address, n.broadcast_address): - bytes = socket.inet_pton( - socket.AF_INET6 if address.version == 6 else socket.AF_INET, - "%s" % address, - ) - - os.write(1, bytes) - - class CLI(object): - output_formats = { - "ipset" : IpsetOutputFormatter, - "list" : OutputFormatter, - "nftables" : NftablesOutputFormatter, - "xt_geoip" : XTGeoIPOutputFormatter, - } - def parse_cli(self): parser = argparse.ArgumentParser( description=_("Location Database Command Line Interface"), @@ -166,6 +88,10 @@ class CLI(object): # Update update = subparsers.add_parser("update", help=_("Update database")) + update.add_argument("--cron", + help=_("Update the library only once per interval"), + choices=("daily", "weekly", "monthly"), + ) update.set_defaults(func=self.handle_update) # Verify @@ -193,8 +119,8 @@ class CLI(object): ) list_networks_by_as.add_argument("asn", nargs=1, type=int) list_networks_by_as.add_argument("--family", choices=("ipv6", "ipv4")) - list_networks_by_as.add_argument("--output-format", - choices=self.output_formats.keys(), default="list") + list_networks_by_as.add_argument("--format", + choices=location.export.formats.keys(), default="list") list_networks_by_as.set_defaults(func=self.handle_list_networks_by_as) # List all networks in a country @@ -203,8 +129,8 @@ class CLI(object): ) list_networks_by_cc.add_argument("country_code", nargs=1) list_networks_by_cc.add_argument("--family", choices=("ipv6", "ipv4")) - list_networks_by_cc.add_argument("--output-format", - choices=self.output_formats.keys(), default="list") + list_networks_by_cc.add_argument("--format", + choices=location.export.formats.keys(), default="list") list_networks_by_cc.set_defaults(func=self.handle_list_networks_by_cc) # List all networks with flags @@ -221,10 +147,35 @@ class CLI(object): action="store_true", help=_("Anycasts"), ) list_networks_by_flags.add_argument("--family", choices=("ipv6", "ipv4")) - list_networks_by_flags.add_argument("--output-format", - choices=self.output_formats.keys(), default="list") + list_networks_by_flags.add_argument("--format", + choices=location.export.formats.keys(), default="list") list_networks_by_flags.set_defaults(func=self.handle_list_networks_by_flags) + # List countries + list_countries = subparsers.add_parser("list-countries", + help=_("Lists all countries"), + ) + list_countries.add_argument("--show-name", + action="store_true", help=_("Show the name of the country"), + ) + list_countries.add_argument("--show-continent", + action="store_true", help=_("Show the continent"), + ) + list_countries.set_defaults(func=self.handle_list_countries) + + # Export + export = subparsers.add_parser("export", + help=_("Exports data in many formats to load it into packet filters"), + ) + export.add_argument("--format", help=_("Output format"), + choices=location.export.formats.keys(), default="list") + export.add_argument("--directory", help=_("Output directory"), required=True) + export.add_argument("--family", + help=_("Specify address family"), choices=("ipv6", "ipv4"), + ) + export.add_argument("objects", nargs="*", help=_("List country codes or ASNs to export")) + export.set_defaults(func=self.handle_export) + args = parser.parse_args() # Configure logging @@ -248,9 +199,14 @@ class CLI(object): try: db = location.Database(args.database) except FileNotFoundError as e: - sys.stderr.write("location: Could not open database %s: %s\n" \ - % (args.database, e)) - sys.exit(1) + # Allow continuing without a database + if args.func == self.handle_update: + db = None + + else: + sys.stderr.write("location: Could not open database %s: %s\n" \ + % (args.database, e)) + sys.exit(1) # Translate family (if present) if "family" in args: @@ -374,7 +330,8 @@ class CLI(object): if db.description: for line in db.description.splitlines(): - f.write("# %s\n" % line) + line = "# %s" % line + f.write("%s\n" % line.rstrip()) f.write("#\n") @@ -440,14 +397,32 @@ class CLI(object): print(a) def handle_update(self, db, ns): + if ns.cron and db: + now = datetime.datetime.utcnow() + + # Parse the database timestamp + t = datetime.datetime.utcfromtimestamp(db.created_at) + + if ns.cron == "daily": + delta = datetime.timedelta(days=1) + elif ns.cron == "weekly": + delta = datetime.timedelta(days=7) + elif ns.cron == "monthly": + delta = datetime.timedelta(days=30) + + # Check if the database has recently been updated + if t >= (now - delta): + log.info( + _("The database has been updated recently (%s)") % \ + format_timedelta(now - t), + ) + return 3 + # Fetch the timestamp we need from DNS t = location.discover_latest_version() - # Parse timestamp into datetime format - timestamp = datetime.datetime.fromtimestamp(t) if t else None - # Check the version of the local database - if db and timestamp and db.created_at >= timestamp.timestamp(): + if db and t and db.created_at >= t: log.info("Already on the latest version") return @@ -475,13 +450,7 @@ class CLI(object): return 0 - def handle_verify(self, ns): - try: - db = location.Database(ns.database) - except FileNotFoundError as e: - log.error("%s: %s" % (ns.database, e)) - return 127 - + def handle_verify(self, db, ns): # Verify the database with open(ns.public_key, "r") as f: if not db.verify(f): @@ -494,25 +463,54 @@ class CLI(object): def __get_output_formatter(self, ns): try: - cls = self.output_formats[ns.output_format] + cls = location.export.formats[ns.format] except KeyError: - cls = OutputFormatter + cls = location.export.OutputFormatter + + return cls + + def handle_list_countries(self, db, ns): + for country in db.countries: + line = [ + country.code, + ] + + if ns.show_continent: + line.append(country.continent_code) + + if ns.show_name: + line.append(country.name) - return cls(ns) + # Format the output + line = " ".join(line) + + # Print the output + print(line) def handle_list_networks_by_as(self, db, ns): - with self.__get_output_formatter(ns) as f: - for asn in ns.asn: - # Print all matching networks - for n in db.search_networks(asn=asn, family=ns.family): - f.network(n) + writer = self.__get_output_formatter(ns) + + for asn in ns.asn: + f = writer(sys.stdout, prefix="AS%s" % asn) + + # Print all matching networks + for n in db.search_networks(asn=asn, family=ns.family): + f.write(n) + + f.finish() def handle_list_networks_by_cc(self, db, ns): - with self.__get_output_formatter(ns) as f: - for country_code in ns.country_code: - # Print all matching networks - for n in db.search_networks(country_code=country_code, family=ns.family): - f.network(n) + writer = self.__get_output_formatter(ns) + + for country_code in ns.country_code: + # Open standard output + f = writer(sys.stdout, prefix=country_code) + + # Print all matching networks + for n in db.search_networks(country_code=country_code, family=ns.family): + f.write(n) + + f.finish() def handle_list_networks_by_flags(self, db, ns): flags = 0 @@ -529,10 +527,79 @@ class CLI(object): if not flags: raise ValueError(_("You must at least pass one flag")) - with self.__get_output_formatter(ns) as f: - for n in db.search_networks(flags=flags, family=ns.family): - f.network(n) + writer = self.__get_output_formatter(ns) + f = writer(sys.stdout, prefix="custom") + + for n in db.search_networks(flags=flags, family=ns.family): + f.write(n) + + f.finish() + + def handle_export(self, db, ns): + countries, asns = [], [] + + # Translate family + if ns.family: + families = [ ns.family ] + else: + families = [ socket.AF_INET6, socket.AF_INET ] + + for object in ns.objects: + m = re.match("^AS(\d+)$", object) + if m: + object = int(m.group(1)) + + asns.append(object) + + elif location.country_code_is_valid(object) \ + or object in ("A1", "A2", "A3"): + countries.append(object) + + else: + log.warning("Invalid argument: %s" % object) + continue + + # Default to exporting all countries + if not countries and not asns: + countries = ["A1", "A2", "A3"] + [country.code for country in db.countries] + + # Select the output format + writer = self.__get_output_formatter(ns) + + e = location.export.Exporter(db, writer) + e.export(ns.directory, countries=countries, asns=asns, families=families) + + +def format_timedelta(t): + s = [] + + if t.days: + s.append( + _("One Day", "%(days)s Days", t.days) % { "days" : t.days, } + ) + + hours = t.seconds // 3600 + if hours: + s.append( + _("One Hour", "%(hours)s Hours", hours) % { "hours" : hours, } + ) + + minutes = (t.seconds % 3600) // 60 + if minutes: + s.append( + _("One Minute", "%(minutes)s Minutes", minutes) % { "minutes" : minutes, } + ) + + seconds = t.seconds % 60 + if t.seconds: + s.append( + _("One Second", "%(seconds)s Seconds", seconds) % { "seconds" : seconds, } + ) + + if not s: + return _("Now") + return _("%s ago") % ", ".join(s) def main(): # Run the command line interface