From c242f7325bd6fc4ba26047ac24196d1c161c6e01 Mon Sep 17 00:00:00 2001 From: Michael Tremer Date: Fri, 13 Nov 2020 12:09:03 +0000 Subject: [PATCH] python: Move tree flattening into C Signed-off-by: Michael Tremer --- src/python/database.c | 8 ++- src/python/export.py | 138 +++++------------------------------------- 2 files changed, 21 insertions(+), 125 deletions(-) diff --git a/src/python/database.c b/src/python/database.c index 7f8c2c2..d169547 100644 --- a/src/python/database.c +++ b/src/python/database.c @@ -258,17 +258,19 @@ static PyObject* Database_networks_flattened(DatabaseObject *self) { } static PyObject* Database_search_networks(DatabaseObject* self, PyObject* args, PyObject* kwargs) { - char* kwlist[] = { "country_code", "asn", "flags", "family", NULL }; + char* kwlist[] = { "country_code", "asn", "flags", "family", "flatten", NULL }; const char* country_code = NULL; unsigned int asn = 0; int flags = 0; int family = 0; + int flatten = 0; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|siii", kwlist, &country_code, &asn, &flags, &family)) + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|siiip", kwlist, &country_code, &asn, &flags, &family, &flatten)) return NULL; struct loc_database_enumerator* enumerator; - int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_NETWORKS, 0); + int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_NETWORKS, + (flatten) ? LOC_DB_ENUMERATOR_FLAGS_FLATTEN : 0); if (r) { PyErr_SetFromErrno(PyExc_SystemError); return NULL; diff --git a/src/python/export.py b/src/python/export.py index dd44332..be4a68e 100644 --- a/src/python/export.py +++ b/src/python/export.py @@ -87,58 +87,12 @@ class OutputWriter(object): def _write_network(self, network): self.f.write("%s\n" % network) - def write(self, network, subnets): + def write(self, network): if self.flatten and self._flatten(network): log.debug("Skipping writing network %s (last one was %s)" % (network, self._last_network)) return - # Convert network into a Python object - _network = ipaddress.ip_network("%s" % network) - - # Write the network when it has no subnets - if not subnets: - log.debug("Writing %s to %s" % (_network, self.f)) - return self._write_network(_network) - - # Convert subnets into Python objects - _subnets = [ipaddress.ip_network("%s" % subnet) for subnet in subnets] - - # Split the network into smaller bits so that - # we can accomodate for any gaps in it later - to_check = set() - for _subnet in _subnets: - to_check.update( - _network.address_exclude(_subnet) - ) - - # Clear the list of all subnets - subnets = [] - - # Check if all subnets to not overlap with anything else - while to_check: - subnet_to_check = to_check.pop() - - for _subnet in _subnets: - # Drop this subnet if it equals one of the subnets - # or if it is subnet of one of them - if subnet_to_check == _subnet or subnet_to_check.subnet_of(_subnet): - break - - # Break it down if it overlaps - if subnet_to_check.overlaps(_subnet): - to_check.update( - subnet_to_check.address_exclude(_subnet) - ) - break - - # Add the subnet again as it passed the check - else: - subnets.append(subnet_to_check) - - # Write all networks as compact as possible - for network in ipaddress.collapse_addresses(subnets): - log.debug("Writing %s to %s" % (network, self.f)) - self._write_network(network) + return self._write_network(network) def finish(self): """ @@ -188,7 +142,7 @@ class XTGeoIPOutputWriter(OutputWriter): mode = "wb" def _write_network(self, network): - for address in (network.network_address, network.broadcast_address): + for address in (network.first_address, network.last_address): # Convert this into a string of bits bytes = socket.inet_pton( socket.AF_INET6 if network.version == 6 else socket.AF_INET, "%s" % address, @@ -231,42 +185,21 @@ class Exporter(object): writers[asn] = self.writer.open(self.db, filename, prefix="AS%s" % asn) # Get all networks that match the family - networks = self.db.search_networks(family=family) - - # Create a stack with all networks in order where we can put items back - # again and retrieve them in the next iteration. - networks = BufferedStack(networks) + networks = self.db.search_networks(family=family, flatten=True) # Walk through all networks for network in networks: - # Collect all networks which are a subnet of network - subnets = [] - for subnet in networks: - # If the next subnet was not a subnet, we have to push - # it back on the stack and break this loop - if not subnet.is_subnet_of(network): - networks.push(subnet) - break - - subnets.append(subnet) - # Write matching countries - if network.country_code and network.country_code in writers: - # Mismatching subnets - gaps = [ - subnet for subnet in subnets if not network.country_code == subnet.country_code - ] - - writers[network.country_code].write(network, gaps) + try: + writers[network.country_code].write(network) + except KeyError: + pass # Write matching ASNs - if network.asn and network.asn in writers: - # Mismatching subnets - gaps = [ - subnet for subnet in subnets if not network.asn == subnet.asn - ] - - writers[network.asn].write(network, gaps) + try: + writers[network.asn].write(network) + except KeyError: + pass # Handle flags for flag in flags: @@ -274,19 +207,10 @@ class Exporter(object): # Fetch the "fake" country code country = flags[flag] - if not country in writers: - continue - - gaps = [ - subnet for subnet in subnets - if not subnet.has_flag(flag) - ] - - writers[country].write(network, gaps) - - # Push all subnets back onto the stack - for subnet in reversed(subnets): - networks.push(subnet) + try: + writers[country].write(network) + except KeyError: + pass # Write everything to the filesystem for writer in writers.values(): @@ -298,33 +222,3 @@ class Exporter(object): ) return os.path.join(directory, filename) - - -class BufferedStack(object): - """ - This class takes an iterator and when being iterated - over it returns objects from that iterator for as long - as there are any. - - It additionally has a function to put an item back on - the back so that it will be returned again at the next - iteration. - """ - def __init__(self, iterator): - self.iterator = iterator - self.stack = [] - - def __iter__(self): - return self - - def __next__(self): - if self.stack: - return self.stack.pop(0) - - return next(self.iterator) - - def push(self, elem): - """ - Takes an element and puts it on the stack - """ - self.stack.insert(0, elem) -- 2.39.2