]> git.ipfire.org Git - location/libloc.git/commitdiff
location-importer: Implement importing/exporting countries
authorMichael Tremer <michael.tremer@ipfire.org>
Tue, 19 May 2020 17:44:59 +0000 (17:44 +0000)
committerMichael Tremer <michael.tremer@ipfire.org>
Tue, 19 May 2020 17:44:59 +0000 (17:44 +0000)
Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
src/python/location-importer.in

index 3b7e37f5923792d62983e790dc30ecfc7ac0f32a..d412f104edb63aa4b7265a5a13cc08e0142155ec 100644 (file)
@@ -93,6 +93,14 @@ class CLI(object):
                )
                update_overrides.set_defaults(func=self.handle_update_overrides)
 
+               # Import countries
+               import_countries = subparsers.add_parser("import-countries",
+                       help=_("Import countries"),
+               )
+               import_countries.add_argument("file", nargs=1, type=argparse.FileType("r"),
+                       help=_("File to import"))
+               import_countries.set_defaults(func=self.handle_import_countries)
+
                args = parser.parse_args()
 
                # Configure logging
@@ -148,6 +156,11 @@ class CLI(object):
                                CREATE TABLE IF NOT EXISTS autnums(number bigint, name text NOT NULL);
                                CREATE UNIQUE INDEX IF NOT EXISTS autnums_number ON autnums(number);
 
+                               -- countries
+                               CREATE TABLE IF NOT EXISTS countries(
+                                       country_code text NOT NULL, name text NOT NULL, continent_code text NOT NULL);
+                               CREATE UNIQUE INDEX IF NOT EXISTS countries_country_code ON countries(country_code);
+
                                -- networks
                                CREATE TABLE IF NOT EXISTS networks(network inet, country text);
                                CREATE UNIQUE INDEX IF NOT EXISTS networks_network ON networks(network);
@@ -301,6 +314,15 @@ class CLI(object):
                        if row.is_anycast:
                                network.set_flag(location.NETWORK_FLAG_ANYCAST)
 
+               # Add all countries
+               log.info("Writing countries...")
+               rows = self.db.query("SELECT * FROM countries ORDER BY country_code")
+
+               for row in rows:
+                       c = writer.add_country(row.country_code)
+                       c.continent_code = row.continent_code
+                       c.name = row.name
+
                # Write everything to file
                log.info("Writing database to file...")
                for file in ns.file:
@@ -661,6 +683,28 @@ class CLI(object):
                                                else:
                                                        log.warning("Unsupport type: %s" % type)
 
+       def handle_import_countries(self, ns):
+               with self.db.transaction():
+                       # Drop all data that we have
+                       self.db.execute("TRUNCATE TABLE countries")
+
+                       for file in ns.file:
+                               for line in file:
+                                       line = line.rstrip()
+
+                                       # Ignore any comments
+                                       if line.startswith("#"):
+                                               continue
+
+                                       try:
+                                               country_code, continent_code, name = line.split(maxsplit=2)
+                                       except:
+                                               log.warning("Could not parse line: %s" % line)
+                                               continue
+
+                                       self.db.execute("INSERT INTO countries(country_code, name, continent_code) \
+                                               VALUES(%s, %s, %s) ON CONFLICT DO NOTHING", country_code, name, continent_code)
+
 
 def split_line(line):
        key, colon, val = line.partition(":")