From 2538ed9a602264b473f108b8f99a465771e9abfe Mon Sep 17 00:00:00 2001 From: Michael Tremer Date: Wed, 2 Oct 2019 13:55:33 +0000 Subject: [PATCH] location-query: Allow passing the database path Signed-off-by: Michael Tremer --- src/python/location-query.in | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/src/python/location-query.in b/src/python/location-query.in index 1433f1a..7a7492c 100644 --- a/src/python/location-query.in +++ b/src/python/location-query.in @@ -33,10 +33,6 @@ def _(singular, plural=None, n=None): return gettext.dgettext("libloc", singular) class CLI(object): - def __init__(self): - # Open database - self.db = location.Database("@databasedir@/database.db") - def parse_cli(self): parser = argparse.ArgumentParser( description=_("Location Database Command Line Interface"), @@ -51,6 +47,11 @@ class CLI(object): parser.add_argument("--version", action="version", version="%%(prog)s %s" % location.__version__) + # database + parser.add_argument("--database", "-d", + default="@databasedir@/database.db", help=_("Path to database"), + ) + # lookup an IP address lookup = subparsers.add_parser("lookup", help=_("Lookup one or multiple IP addresses"), @@ -81,8 +82,16 @@ class CLI(object): # Callback function must be defined assert args.func, "Callback function not defined" + # Open database + try: + db = location.Database(args.database) + except FileNotFoundError as e: + sys.stderr.write("location-query: Could not open database %s: %s\n" \ + % (args.database, e)) + sys.exit(1) + # Call function - ret = args.func(args) + ret = args.func(db, args) # Return with exit code if ret: @@ -91,12 +100,12 @@ class CLI(object): # Otherwise just exit sys.exit(0) - def handle_lookup(self, ns): + def handle_lookup(self, db, ns): ret = 0 for address in ns.address: try: - n = self.db.lookup(address) + n = db.lookup(address) except ValueError: print(_("Invalid IP address: %s") % address, file=sys.stderr) @@ -113,7 +122,7 @@ class CLI(object): # Try to retrieve the AS if we have an AS number if n.asn: - a = self.db.get_as(n.asn) + a = db.get_as(n.asn) # If we have found an AS we will print it in the message if a: @@ -128,7 +137,7 @@ class CLI(object): return ret - def handle_get_as(self, ns): + def handle_get_as(self, db, ns): """ Gets information about Autonomous Systems """ @@ -143,7 +152,7 @@ class CLI(object): continue # Fetch AS from database - a = self.db.get_as(asn) + a = db.get_as(asn) # Nothing found if not a: @@ -155,10 +164,10 @@ class CLI(object): return ret - def handle_search_as(self, ns): + def handle_search_as(self, db, ns): for query in ns.query: # Print all matches ASes - for a in self.db.search_as(query): + for a in db.search_as(query): print(a) def main(): -- 2.39.2