From: Shivani Bhardwaj Date: Mon, 25 Feb 2019 10:29:54 +0000 (+0530) Subject: Fix failure in case of missing index.yaml, cleanup X-Git-Tag: 1.2.0rc1~33 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=65df71fdf49d81d78ece24eb7ea5e165cf01a103;p=thirdparty%2Fsuricata-update.git Fix failure in case of missing index.yaml, cleanup If the index.yaml was not present in data directory, update-sources command would fail with an IOError. Fix this by handling this case. Now, if sources are updated on a new data directory, there is an info message "Adding all sources in the log". Modularize the current structure to make it more readable and perform one thing per function. Sort the imports and clean them. --- diff --git a/suricata/update/commands/updatesources.py b/suricata/update/commands/updatesources.py index 9e22e73..25f6e3e 100644 --- a/suricata/update/commands/updatesources.py +++ b/suricata/update/commands/updatesources.py @@ -16,15 +16,12 @@ from __future__ import print_function -import os -import logging import io -import yaml +import logging +import os -from suricata.update import config -from suricata.update import sources -from suricata.update import net -from suricata.update import exceptions +import yaml +from suricata.update import config, exceptions, net, sources logger = logging.getLogger() @@ -33,31 +30,59 @@ def register(parser): parser.set_defaults(func=update_sources) +def get_initial_content(): + initial_content = None + if os.path.exists(local_index_filename): + with open(local_index_filename, "r") as stream: + initial_content = yaml.safe_load(stream) + return initial_content + + +def get_sources(before, after): + all_sources = {source: after[source] + for source in after if source not in before} + return all_sources + + +def log_sources(sources_map): + for name, all_sources in sources_map.items(): + if not all_sources: + continue + for source in all_sources: + logger.info("Source %s was %s", source, name) + + def compare_sources(initial_content, final_content): + if not initial_content: + logger.info("Adding all sources") + return if initial_content == final_content: logger.info("No change in sources") return initial_sources = initial_content.get("sources") final_sources = final_content.get("sources") - added_sources = {source: final_sources[source] - for source in final_sources if source not in initial_sources} - removed_sources = {source: initial_sources[source] - for source in initial_sources if source not in final_sources} - if added_sources: - for source in added_sources: - logger.info("Source %s was added", source) - if removed_sources: - for source in removed_sources: - logger.info("Source %s was removed", source) + added_sources = get_sources(before=initial_sources, after=final_sources) + removed_sources = get_sources(before=final_sources, after=initial_sources) + log_sources(sources_map={"added": added_sources, + "removed": removed_sources}) for source in set(initial_sources) & set(final_sources): if initial_sources[source] != final_sources[source]: logger.info("Source %s was changed", source) +def write_and_compare(initial_content, fileobj): + with open(local_index_filename, "wb") as outobj: + outobj.write(fileobj.getvalue()) + with open(local_index_filename) as stream: + final_content = yaml.safe_load(stream) + compare_sources(initial_content, final_content) + logger.info("Saved %s", local_index_filename) + + def update_sources(): + global local_index_filename local_index_filename = sources.get_index_filename() - with open(local_index_filename) as stream: - initial_content = yaml.safe_load(stream) + initial_content = get_initial_content() with io.BytesIO() as fileobj: url = sources.get_source_index_url() logger.info("Downloading %s", url) @@ -73,9 +98,4 @@ def update_sources(): logger.error("Failed to create directory %s: %s", config.get_cache_dir(), err) return 1 - with open(local_index_filename, "wb") as outobj: - outobj.write(fileobj.getvalue()) - with open(local_index_filename) as stream: - final_content = yaml.safe_load(stream) - compare_sources(initial_content, final_content) - logger.info("Saved %s", local_index_filename) + write_and_compare(initial_content=initial_content, fileobj=fileobj)