export: Flatten the tree before exporting it

Message ID 20201027171430.20785-1-peter.mueller@ipfire.org
State Accepted
Commit bbed1fd2330e8efa6b413dc152a1a6ef2d771aac
Headers show
Series
  • export: Flatten the tree before exporting it
Related show

Commit Message

Peter Müller Oct. 27, 2020, 5:14 p.m. UTC
From: Michael Tremer <michael.tremer@ipfire.org>

This patch removes the possibility that any IP address ranges
might show up in multiple exported files.

If this was content from the database:

  * 10.0.0.0/16 - DE
  * 10.0.1.0/24 - FR

Then the IP address 10.0.1.1 would match for both countries.

The algorithm will now break the larger /16 subnet apart into
smaller subnets so that 10.0.1.0/24 is no longer overlapped.

There was some time spent on this to make this as efficient
as possible.

Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
---
 src/python/export.py | 154 ++++++++++++++++++++++++++++++-------------
 1 file changed, 110 insertions(+), 44 deletions(-)

Patch

diff --git a/src/python/export.py b/src/python/export.py
index 5eaf43f..dd44332 100644
--- a/src/python/export.py
+++ b/src/python/export.py
@@ -89,28 +89,55 @@  class OutputWriter(object):
 
 	def write(self, network, subnets):
 		if self.flatten and self._flatten(network):
-			log.debug("Skipping writing network %s" % network)
+			log.debug("Skipping writing network %s (last one was %s)" % (network, self._last_network))
 			return
 
+		# Convert network into a Python object
+		_network = ipaddress.ip_network("%s" % network)
+
 		# Write the network when it has no subnets
 		if not subnets:
-			network = ipaddress.ip_network("%s" % network)
-			return self._write_network(network)
+			log.debug("Writing %s to %s" % (_network, self.f))
+			return self._write_network(_network)
+
+		# Convert subnets into Python objects
+		_subnets = [ipaddress.ip_network("%s" % subnet) for subnet in subnets]
+
+		# Split the network into smaller bits so that
+		# we can accomodate for any gaps in it later
+		to_check = set()
+		for _subnet in _subnets:
+			to_check.update(
+				_network.address_exclude(_subnet)
+			)
+
+		# Clear the list of all subnets
+		subnets = []
+
+		# Check if all subnets to not overlap with anything else
+		while to_check:
+			subnet_to_check = to_check.pop()
 
-		# Collect all matching subnets
-		matching_subnets = []
+			for _subnet in _subnets:
+				# Drop this subnet if it equals one of the subnets
+				# or if it is subnet of one of them
+				if subnet_to_check == _subnet or subnet_to_check.subnet_of(_subnet):
+					break
 
-		for subnet in sorted(subnets):
-			# Try to find the subnet in the database
-			n = self.db.lookup("%s" % subnet.network_address)
+				# Break it down if it overlaps
+				if subnet_to_check.overlaps(_subnet):
+					to_check.update(
+						subnet_to_check.address_exclude(_subnet)
+					)
+					break
 
-			# No entry or matching country means those IP addresses belong here
-			if not n or n.country_code == network.country_code:
-				matching_subnets.append(subnet)
+			# Add the subnet again as it passed the check
+			else:
+				subnets.append(subnet_to_check)
 
 		# Write all networks as compact as possible
-		for network in ipaddress.collapse_addresses(matching_subnets):
-			log.debug("Writing %s to database" % network)
+		for network in ipaddress.collapse_addresses(subnets):
+			log.debug("Writing %s to %s" % (network, self.f))
 			self._write_network(network)
 
 	def finish(self):
@@ -206,40 +233,40 @@  class Exporter(object):
 			# Get all networks that match the family
 			networks = self.db.search_networks(family=family)
 
-			# Materialise the generator into a list (uses quite some memory)
-			networks = list(networks)
+			# Create a stack with all networks in order where we can put items back
+			# again and retrieve them in the next iteration.
+			networks = BufferedStack(networks)
 
 			# Walk through all networks
-			for i, network in enumerate(networks):
-				_network = ipaddress.ip_network("%s" % network)
-
-				# Search for all subnets
-				subnets = set()
-
-				while i < len(networks):
-					subnet = networks[i+1]
-
-					if subnet.is_subnet_of(network):
-						_subnet = ipaddress.ip_network("%s" % subnet)
-
-						subnets.add(_subnet)
-						subnets.update(_network.address_exclude(_subnet))
-
-						i += 1
-					else:
+			for network in networks:
+				# Collect all networks which are a subnet of network
+				subnets = []
+				for subnet in networks:
+					# If the next subnet was not a subnet, we have to push
+					# it back on the stack and break this loop
+					if not subnet.is_subnet_of(network):
+						networks.push(subnet)
 						break
 
+					subnets.append(subnet)
+
 				# Write matching countries
-				try:
-					writers[network.country_code].write(network, subnets)
-				except KeyError:
-					pass
+				if network.country_code and network.country_code in writers:
+					# Mismatching subnets
+					gaps = [
+						subnet for subnet in subnets if not network.country_code == subnet.country_code
+					]
+
+					writers[network.country_code].write(network, gaps)
 
 				# Write matching ASNs
-				try:
-					writers[network.asn].write(network, subnets)
-				except KeyError:
-					pass
+				if network.asn and network.asn in writers:
+					# Mismatching subnets
+					gaps = [
+						subnet for subnet in subnets if not network.asn == subnet.asn
+					]
+
+					writers[network.asn].write(network, gaps)
 
 				# Handle flags
 				for flag in flags:
@@ -247,10 +274,19 @@  class Exporter(object):
 						# Fetch the "fake" country code
 						country = flags[flag]
 
-						try:
-							writers[country].write(network, subnets)
-						except KeyError:
-							pass
+						if not country in writers:
+							continue
+
+						gaps = [
+							subnet for subnet in subnets
+								if not subnet.has_flag(flag)
+						]
+
+						writers[country].write(network, gaps)
+
+				# Push all subnets back onto the stack
+				for subnet in reversed(subnets):
+					networks.push(subnet)
 
 			# Write everything to the filesystem
 			for writer in writers.values():
@@ -262,3 +298,33 @@  class Exporter(object):
 		)
 
 		return os.path.join(directory, filename)
+
+
+class BufferedStack(object):
+	"""
+		This class takes an iterator and when being iterated
+		over it returns objects from that iterator for as long
+		as there are any.
+
+		It additionally has a function to put an item back on
+		the back so that it will be returned again at the next
+		iteration.
+	"""
+	def __init__(self, iterator):
+		self.iterator = iterator
+		self.stack = []
+
+	def __iter__(self):
+		return self
+
+	def __next__(self):
+		if self.stack:
+			return self.stack.pop(0)
+
+		return next(self.iterator)
+
+	def push(self, elem):
+		"""
+			Takes an element and puts it on the stack
+		"""
+		self.stack.insert(0, elem)