import ipaddress
import logging
import os
+import re
import shutil
import socket
import sys
# Load our location module
import location
import location.downloader
+import location.export
+
from location.i18n import _
# Setup logging
# Output formatters
-class OutputFormatter(object):
- def __init__(self, ns):
- self.ns = ns
-
- def __enter__(self):
- # Open the output
- self.open()
-
- return self
-
- def __exit__(self, type, value, tb):
- if tb is None:
- self.close()
-
- @property
- def name(self):
- if "country_code" in self.ns:
- return "networks_country_%s" % self.ns.country_code[0]
-
- elif "asn" in self.ns:
- return "networks_AS%s" % self.ns.asn[0]
-
- def open(self):
- pass
-
- def close(self):
- pass
-
- def network(self, network):
- print(network)
-
-
-class IpsetOutputFormatter(OutputFormatter):
- """
- For nftables
- """
- def open(self):
- print("create %s hash:net family inet hashsize 1024 maxelem 65536" % self.name)
-
- def network(self, network):
- print("add %s %s" % (self.name, network))
-
-
-class NftablesOutputFormatter(OutputFormatter):
- """
- For nftables
- """
- def open(self):
- print("define %s = {" % self.name)
-
- def close(self):
- print("}")
-
- def network(self, network):
- print(" %s," % network)
-
-
-class XTGeoIPOutputFormatter(OutputFormatter):
- """
- Formats the output in that way, that it can be loaded by
- the xt_geoip kernel module from xtables-addons.
- """
- def network(self, network):
- n = ipaddress.ip_network("%s" % network)
-
- for address in (n.network_address, n.broadcast_address):
- bytes = socket.inet_pton(
- socket.AF_INET6 if address.version == 6 else socket.AF_INET,
- "%s" % address,
- )
-
- os.write(1, bytes)
-
-
class CLI(object):
- output_formats = {
- "ipset" : IpsetOutputFormatter,
- "list" : OutputFormatter,
- "nftables" : NftablesOutputFormatter,
- "xt_geoip" : XTGeoIPOutputFormatter,
- }
-
def parse_cli(self):
parser = argparse.ArgumentParser(
description=_("Location Database Command Line Interface"),
# Update
update = subparsers.add_parser("update", help=_("Update database"))
+ update.add_argument("--cron",
+ help=_("Update the library only once per interval"),
+ choices=("daily", "weekly", "monthly"),
+ )
update.set_defaults(func=self.handle_update)
# Verify
)
list_networks_by_as.add_argument("asn", nargs=1, type=int)
list_networks_by_as.add_argument("--family", choices=("ipv6", "ipv4"))
- list_networks_by_as.add_argument("--output-format",
- choices=self.output_formats.keys(), default="list")
+ list_networks_by_as.add_argument("--format",
+ choices=location.export.formats.keys(), default="list")
list_networks_by_as.set_defaults(func=self.handle_list_networks_by_as)
# List all networks in a country
)
list_networks_by_cc.add_argument("country_code", nargs=1)
list_networks_by_cc.add_argument("--family", choices=("ipv6", "ipv4"))
- list_networks_by_cc.add_argument("--output-format",
- choices=self.output_formats.keys(), default="list")
+ list_networks_by_cc.add_argument("--format",
+ choices=location.export.formats.keys(), default="list")
list_networks_by_cc.set_defaults(func=self.handle_list_networks_by_cc)
# List all networks with flags
action="store_true", help=_("Anycasts"),
)
list_networks_by_flags.add_argument("--family", choices=("ipv6", "ipv4"))
- list_networks_by_flags.add_argument("--output-format",
- choices=self.output_formats.keys(), default="list")
+ list_networks_by_flags.add_argument("--format",
+ choices=location.export.formats.keys(), default="list")
list_networks_by_flags.set_defaults(func=self.handle_list_networks_by_flags)
+ # List countries
+ list_countries = subparsers.add_parser("list-countries",
+ help=_("Lists all countries"),
+ )
+ list_countries.add_argument("--show-name",
+ action="store_true", help=_("Show the name of the country"),
+ )
+ list_countries.add_argument("--show-continent",
+ action="store_true", help=_("Show the continent"),
+ )
+ list_countries.set_defaults(func=self.handle_list_countries)
+
+ # Export
+ export = subparsers.add_parser("export",
+ help=_("Exports data in many formats to load it into packet filters"),
+ )
+ export.add_argument("--format", help=_("Output format"),
+ choices=location.export.formats.keys(), default="list")
+ export.add_argument("--directory", help=_("Output directory"), required=True)
+ export.add_argument("--family",
+ help=_("Specify address family"), choices=("ipv6", "ipv4"),
+ )
+ export.add_argument("objects", nargs="*", help=_("List country codes or ASNs to export"))
+ export.set_defaults(func=self.handle_export)
+
args = parser.parse_args()
# Configure logging
try:
db = location.Database(args.database)
except FileNotFoundError as e:
- sys.stderr.write("location: Could not open database %s: %s\n" \
- % (args.database, e))
- sys.exit(1)
+ # Allow continuing without a database
+ if args.func == self.handle_update:
+ db = None
+
+ else:
+ sys.stderr.write("location: Could not open database %s: %s\n" \
+ % (args.database, e))
+ sys.exit(1)
# Translate family (if present)
if "family" in args:
if db.description:
for line in db.description.splitlines():
- f.write("# %s\n" % line)
+ line = "# %s" % line
+ f.write("%s\n" % line.rstrip())
f.write("#\n")
print(a)
def handle_update(self, db, ns):
+ if ns.cron and db:
+ now = datetime.datetime.utcnow()
+
+ # Parse the database timestamp
+ t = datetime.datetime.utcfromtimestamp(db.created_at)
+
+ if ns.cron == "daily":
+ delta = datetime.timedelta(days=1)
+ elif ns.cron == "weekly":
+ delta = datetime.timedelta(days=7)
+ elif ns.cron == "monthly":
+ delta = datetime.timedelta(days=30)
+
+ # Check if the database has recently been updated
+ if t >= (now - delta):
+ log.info(
+ _("The database has been updated recently (%s)") % \
+ format_timedelta(now - t),
+ )
+ return 3
+
# Fetch the timestamp we need from DNS
t = location.discover_latest_version()
- # Parse timestamp into datetime format
- timestamp = datetime.datetime.fromtimestamp(t) if t else None
-
# Check the version of the local database
- if db and timestamp and db.created_at >= timestamp.timestamp():
+ if db and t and db.created_at >= t:
log.info("Already on the latest version")
return
return 0
- def handle_verify(self, ns):
- try:
- db = location.Database(ns.database)
- except FileNotFoundError as e:
- log.error("%s: %s" % (ns.database, e))
- return 127
-
+ def handle_verify(self, db, ns):
# Verify the database
with open(ns.public_key, "r") as f:
if not db.verify(f):
def __get_output_formatter(self, ns):
try:
- cls = self.output_formats[ns.output_format]
+ cls = location.export.formats[ns.format]
except KeyError:
- cls = OutputFormatter
+ cls = location.export.OutputFormatter
+
+ return cls
+
+ def handle_list_countries(self, db, ns):
+ for country in db.countries:
+ line = [
+ country.code,
+ ]
+
+ if ns.show_continent:
+ line.append(country.continent_code)
+
+ if ns.show_name:
+ line.append(country.name)
- return cls(ns)
+ # Format the output
+ line = " ".join(line)
+
+ # Print the output
+ print(line)
def handle_list_networks_by_as(self, db, ns):
- with self.__get_output_formatter(ns) as f:
- for asn in ns.asn:
- # Print all matching networks
- for n in db.search_networks(asn=asn, family=ns.family):
- f.network(n)
+ writer = self.__get_output_formatter(ns)
+
+ for asn in ns.asn:
+ f = writer(sys.stdout, prefix="AS%s" % asn)
+
+ # Print all matching networks
+ for n in db.search_networks(asn=asn, family=ns.family):
+ f.write(n)
+
+ f.finish()
def handle_list_networks_by_cc(self, db, ns):
- with self.__get_output_formatter(ns) as f:
- for country_code in ns.country_code:
- # Print all matching networks
- for n in db.search_networks(country_code=country_code, family=ns.family):
- f.network(n)
+ writer = self.__get_output_formatter(ns)
+
+ for country_code in ns.country_code:
+ # Open standard output
+ f = writer(sys.stdout, prefix=country_code)
+
+ # Print all matching networks
+ for n in db.search_networks(country_code=country_code, family=ns.family):
+ f.write(n)
+
+ f.finish()
def handle_list_networks_by_flags(self, db, ns):
flags = 0
if not flags:
raise ValueError(_("You must at least pass one flag"))
- with self.__get_output_formatter(ns) as f:
- for n in db.search_networks(flags=flags, family=ns.family):
- f.network(n)
+ writer = self.__get_output_formatter(ns)
+ f = writer(sys.stdout, prefix="custom")
+
+ for n in db.search_networks(flags=flags, family=ns.family):
+ f.write(n)
+
+ f.finish()
+
+ def handle_export(self, db, ns):
+ countries, asns = [], []
+
+ # Translate family
+ if ns.family:
+ families = [ ns.family ]
+ else:
+ families = [ socket.AF_INET6, socket.AF_INET ]
+
+ for object in ns.objects:
+ m = re.match("^AS(\d+)$", object)
+ if m:
+ object = int(m.group(1))
+
+ asns.append(object)
+
+ elif location.country_code_is_valid(object) \
+ or object in ("A1", "A2", "A3"):
+ countries.append(object)
+
+ else:
+ log.warning("Invalid argument: %s" % object)
+ continue
+
+ # Default to exporting all countries
+ if not countries and not asns:
+ countries = ["A1", "A2", "A3"] + [country.code for country in db.countries]
+
+ # Select the output format
+ writer = self.__get_output_formatter(ns)
+
+ e = location.export.Exporter(db, writer)
+ e.export(ns.directory, countries=countries, asns=asns, families=families)
+
+
+def format_timedelta(t):
+ s = []
+
+ if t.days:
+ s.append(
+ _("One Day", "%(days)s Days", t.days) % { "days" : t.days, }
+ )
+
+ hours = t.seconds // 3600
+ if hours:
+ s.append(
+ _("One Hour", "%(hours)s Hours", hours) % { "hours" : hours, }
+ )
+
+ minutes = (t.seconds % 3600) // 60
+ if minutes:
+ s.append(
+ _("One Minute", "%(minutes)s Minutes", minutes) % { "minutes" : minutes, }
+ )
+
+ seconds = t.seconds % 60
+ if t.seconds:
+ s.append(
+ _("One Second", "%(seconds)s Seconds", seconds) % { "seconds" : seconds, }
+ )
+
+ if not s:
+ return _("Now")
+ return _("%s ago") % ", ".join(s)
def main():
# Run the command line interface