]> git.ipfire.org Git - location/libloc.git/commitdiff
python: Implement importing override files into the database
authorMichael Tremer <michael.tremer@ipfire.org>
Wed, 13 May 2020 17:19:31 +0000 (17:19 +0000)
committerMichael Tremer <michael.tremer@ipfire.org>
Wed, 13 May 2020 17:20:04 +0000 (17:20 +0000)
Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
src/python/importer.py
src/python/location-importer.in

index a4a11f51ff4c4461ad442833a246a96dab1f2f5a..e24684cd54db40eb512f5d365e979cc5b95eebc4 100644 (file)
@@ -152,6 +152,23 @@ class DownloaderContext(object):
                return self.response
 
 
+def read_blocks(f):
+       for block in iterate_over_blocks(f):
+               type = None
+               data = {}
+
+               for i, line in enumerate(block):
+                       key, value = line.split(":", 1)
+
+                       # The key of the first line defines the type
+                       if i == 0:
+                               type = key
+
+                       # Store value
+                       data[key] = value.strip()
+
+               yield type, data
+
 def iterate_over_blocks(f, charsets=("utf-8", "latin1")):
        block = []
 
index c2413e54ba9c9c519579a8c196717724e941af91..a4e24da1841601171d64d3f2ee3a14e14202b503 100644 (file)
@@ -71,6 +71,15 @@ class CLI(object):
                update_announcements.add_argument("server", nargs=1,
                        help=_("Route Server to connect to"), metavar=_("SERVER"))
 
+               # Update overrides
+               update_overrides = subparsers.add_parser("update-overrides",
+                       help=_("Update overrides"),
+               )
+               update_overrides.add_argument(
+                       "files", nargs="+", help=_("Files to import"),
+               )
+               update_overrides.set_defaults(func=self.handle_update_overrides)
+
                args = parser.parse_args()
 
                # Enable debug logging
@@ -128,6 +137,27 @@ class CLI(object):
                                CREATE TABLE IF NOT EXISTS networks(network inet, country text);
                                CREATE UNIQUE INDEX IF NOT EXISTS networks_network ON networks(network);
                                CREATE INDEX IF NOT EXISTS networks_search ON networks USING GIST(network inet_ops);
+
+                               -- overrides
+                               CREATE TABLE IF NOT EXISTS autnum_overrides(
+                                       number bigint NOT NULL,
+                                       name text,
+                                       is_anonymous_proxy boolean DEFAULT FALSE,
+                                       is_satellite_provider boolean DEFAULT FALSE,
+                                       is_anycast boolean DEFAULT FALSE
+                               );
+                               CREATE UNIQUE INDEX IF NOT EXISTS autnum_overrides_number
+                                       ON autnum_overrides(number);
+
+                               CREATE TABLE IF NOT EXISTS network_overrides(
+                                       network inet NOT NULL,
+                                       country text,
+                                       is_anonymous_proxy boolean DEFAULT FALSE,
+                                       is_satellite_provider boolean DEFAULT FALSE,
+                                       is_anycast boolean DEFAULT FALSE
+                               );
+                               CREATE UNIQUE INDEX IF NOT EXISTS network_overrides_network
+                                       ON network_overrides(network);
                        """)
 
                return db
@@ -380,6 +410,73 @@ class CLI(object):
                                        DELETE FROM announcements WHERE last_seen_at <= CURRENT_TIMESTAMP - INTERVAL '14 days';
                                """)
 
+       def handle_update_overrides(self, ns):
+               with self.db.transaction():
+                       # Drop all data that we have
+                       self.db.execute("""
+                               TRUNCATE TABLE autnum_overrides;
+                               TRUNCATE TABLE network_overrides;
+                       """)
+
+                       for file in ns.files:
+                               log.info("Reading %s..." % file)
+
+                               with open(file, "rb") as f:
+                                       for type, block in location.importer.read_blocks(f):
+                                               if type == "net":
+                                                       network = block.get("net")
+                                                       # Try to parse and normalise the network
+                                                       try:
+                                                               network = ipaddress.ip_network(network, strict=False)
+                                                       except ValueError as e:
+                                                               log.warning("Invalid IP network: %s: %s" % (network, e))
+                                                               continue
+
+                                                       self.db.execute("""
+                                                               INSERT INTO network_overrides(
+                                                                       network,
+                                                                       country,
+                                                                       is_anonymous_proxy,
+                                                                       is_satellite_provider,
+                                                                       is_anycast
+                                                               ) VALUES (%s, %s, %s, %s)
+                                                               ON CONFLICT (network) DO NOTHING""",
+                                                               "%s" % network,
+                                                               block.get("country"),
+                                                               block.get("is-anonymous-proxy") == "yes",
+                                                               block.get("is-satellite-provider") == "yes",
+                                                               block.get("is-anycast") == "yes",
+                                                       )
+
+                                               elif type == "autnum":
+                                                       autnum = block.get("autnum")
+
+                                                       # Check if AS number begins with "AS"
+                                                       if not autnum.startswith("AS"):
+                                                               log.warning("Invalid AS number: %s" % autnum)
+                                                               continue
+
+                                                       # Strip "AS"
+                                                       autnum = autnum[2:]
+
+                                                       self.db.execute("""
+                                                               INSERT INTO autnum_overrides(
+                                                                       number,
+                                                                       name,
+                                                                       is_anonymous_proxy,
+                                                                       is_satellite_provider,
+                                                                       is_anycast
+                                                               ) VALUES(%s, %s, %s, %s, %s)
+                                                               ON CONFLICT DO NOTHING""",
+                                                               autnum, block.get("name"),
+                                                               block.get("is-anonymous-proxy") == "yes",
+                                                               block.get("is-satellite-provider") == "yes",
+                                                               block.get("is-anycast") == "yes",
+                                                       )
+
+                                               else:
+                                                       log.warning("Unsupport type: %s" % type)
+
 
 def split_line(line):
        key, colon, val = line.partition(":")