]> git.ipfire.org Git - location/libloc.git/commitdiff
python: Move tree flattening into C
authorMichael Tremer <michael.tremer@ipfire.org>
Fri, 13 Nov 2020 12:09:03 +0000 (12:09 +0000)
committerMichael Tremer <michael.tremer@ipfire.org>
Fri, 13 Nov 2020 12:09:03 +0000 (12:09 +0000)
Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
src/python/database.c
src/python/export.py

index 7f8c2c2ed9b8c27273fd486a1ff7d9205952dbd0..d169547dfa9782a00526c5e55738cfe8d4ef0dc5 100644 (file)
@@ -258,17 +258,19 @@ static PyObject* Database_networks_flattened(DatabaseObject *self) {
 }
 
 static PyObject* Database_search_networks(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
-       char* kwlist[] = { "country_code", "asn", "flags", "family", NULL };
+       char* kwlist[] = { "country_code", "asn", "flags", "family", "flatten", NULL };
        const char* country_code = NULL;
        unsigned int asn = 0;
        int flags = 0;
        int family = 0;
+       int flatten = 0;
 
-       if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|siii", kwlist, &country_code, &asn, &flags, &family))
+       if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|siiip", kwlist, &country_code, &asn, &flags, &family, &flatten))
                return NULL;
 
        struct loc_database_enumerator* enumerator;
-       int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_NETWORKS, 0);
+       int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_NETWORKS,
+               (flatten) ? LOC_DB_ENUMERATOR_FLAGS_FLATTEN : 0);
        if (r) {
                PyErr_SetFromErrno(PyExc_SystemError);
                return NULL;
index dd44332c292ab771b46d1ee4525946de7f9cfefd..be4a68efb2fc36b762288bdb9d8346f2e8223fd7 100644 (file)
@@ -87,58 +87,12 @@ class OutputWriter(object):
        def _write_network(self, network):
                self.f.write("%s\n" % network)
 
-       def write(self, network, subnets):
+       def write(self, network):
                if self.flatten and self._flatten(network):
                        log.debug("Skipping writing network %s (last one was %s)" % (network, self._last_network))
                        return
 
-               # Convert network into a Python object
-               _network = ipaddress.ip_network("%s" % network)
-
-               # Write the network when it has no subnets
-               if not subnets:
-                       log.debug("Writing %s to %s" % (_network, self.f))
-                       return self._write_network(_network)
-
-               # Convert subnets into Python objects
-               _subnets = [ipaddress.ip_network("%s" % subnet) for subnet in subnets]
-
-               # Split the network into smaller bits so that
-               # we can accomodate for any gaps in it later
-               to_check = set()
-               for _subnet in _subnets:
-                       to_check.update(
-                               _network.address_exclude(_subnet)
-                       )
-
-               # Clear the list of all subnets
-               subnets = []
-
-               # Check if all subnets to not overlap with anything else
-               while to_check:
-                       subnet_to_check = to_check.pop()
-
-                       for _subnet in _subnets:
-                               # Drop this subnet if it equals one of the subnets
-                               # or if it is subnet of one of them
-                               if subnet_to_check == _subnet or subnet_to_check.subnet_of(_subnet):
-                                       break
-
-                               # Break it down if it overlaps
-                               if subnet_to_check.overlaps(_subnet):
-                                       to_check.update(
-                                               subnet_to_check.address_exclude(_subnet)
-                                       )
-                                       break
-
-                       # Add the subnet again as it passed the check
-                       else:
-                               subnets.append(subnet_to_check)
-
-               # Write all networks as compact as possible
-               for network in ipaddress.collapse_addresses(subnets):
-                       log.debug("Writing %s to %s" % (network, self.f))
-                       self._write_network(network)
+               return self._write_network(network)
 
        def finish(self):
                """
@@ -188,7 +142,7 @@ class XTGeoIPOutputWriter(OutputWriter):
        mode = "wb"
 
        def _write_network(self, network):
-               for address in (network.network_address, network.broadcast_address):
+               for address in (network.first_address, network.last_address):
                        # Convert this into a string of bits
                        bytes = socket.inet_pton(
                                socket.AF_INET6 if network.version == 6 else socket.AF_INET, "%s" % address,
@@ -231,42 +185,21 @@ class Exporter(object):
                                writers[asn] = self.writer.open(self.db, filename, prefix="AS%s" % asn)
 
                        # Get all networks that match the family
-                       networks = self.db.search_networks(family=family)
-
-                       # Create a stack with all networks in order where we can put items back
-                       # again and retrieve them in the next iteration.
-                       networks = BufferedStack(networks)
+                       networks = self.db.search_networks(family=family, flatten=True)
 
                        # Walk through all networks
                        for network in networks:
-                               # Collect all networks which are a subnet of network
-                               subnets = []
-                               for subnet in networks:
-                                       # If the next subnet was not a subnet, we have to push
-                                       # it back on the stack and break this loop
-                                       if not subnet.is_subnet_of(network):
-                                               networks.push(subnet)
-                                               break
-
-                                       subnets.append(subnet)
-
                                # Write matching countries
-                               if network.country_code and network.country_code in writers:
-                                       # Mismatching subnets
-                                       gaps = [
-                                               subnet for subnet in subnets if not network.country_code == subnet.country_code
-                                       ]
-
-                                       writers[network.country_code].write(network, gaps)
+                               try:
+                                       writers[network.country_code].write(network)
+                               except KeyError:
+                                       pass
 
                                # Write matching ASNs
-                               if network.asn and network.asn in writers:
-                                       # Mismatching subnets
-                                       gaps = [
-                                               subnet for subnet in subnets if not network.asn == subnet.asn
-                                       ]
-
-                                       writers[network.asn].write(network, gaps)
+                               try:
+                                       writers[network.asn].write(network)
+                               except KeyError:
+                                       pass
 
                                # Handle flags
                                for flag in flags:
@@ -274,19 +207,10 @@ class Exporter(object):
                                                # Fetch the "fake" country code
                                                country = flags[flag]
 
-                                               if not country in writers:
-                                                       continue
-
-                                               gaps = [
-                                                       subnet for subnet in subnets
-                                                               if not subnet.has_flag(flag)
-                                               ]
-
-                                               writers[country].write(network, gaps)
-
-                               # Push all subnets back onto the stack
-                               for subnet in reversed(subnets):
-                                       networks.push(subnet)
+                                               try:
+                                                       writers[country].write(network)
+                                               except KeyError:
+                                                       pass
 
                        # Write everything to the filesystem
                        for writer in writers.values():
@@ -298,33 +222,3 @@ class Exporter(object):
                )
 
                return os.path.join(directory, filename)
-
-
-class BufferedStack(object):
-       """
-               This class takes an iterator and when being iterated
-               over it returns objects from that iterator for as long
-               as there are any.
-
-               It additionally has a function to put an item back on
-               the back so that it will be returned again at the next
-               iteration.
-       """
-       def __init__(self, iterator):
-               self.iterator = iterator
-               self.stack = []
-
-       def __iter__(self):
-               return self
-
-       def __next__(self):
-               if self.stack:
-                       return self.stack.pop(0)
-
-               return next(self.iterator)
-
-       def push(self, elem):
-               """
-                       Takes an element and puts it on the stack
-               """
-               self.stack.insert(0, elem)