Fix failure in case of missing index.yaml, cleanup
author Shivani Bhardwaj <shivanib134@gmail.com>
Mon, 25 Feb 2019 10:29:54 +0000 (15:59 +0530)
committer Jason Ish <ish@unx.ca>
Thu, 17 Oct 2019 23:06:29 +0000 (17:06 -0600)
If index.yaml was not present in the data directory, the update-sources
command would fail with an IOError. Fix this by handling that case: when
sources are updated in a new data directory, the info message "Adding all
sources" is now logged.

Modularize the current structure so that it is more readable and each
function does one thing. Sort the imports and clean them up.
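
As an illustration of the new control flow, here is a minimal standalone
sketch that lifts the diffing helpers introduced by this patch (get_sources,
log_sources, compare_sources) and runs them against hypothetical before/after
index contents; in the real command these dictionaries come from the local
and the freshly downloaded index.yaml:

import logging

logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger()


def get_sources(before, after):
    # Sources present in `after` but missing from `before`.
    return {source: after[source]
            for source in after if source not in before}


def log_sources(sources_map):
    # `sources_map` maps a change label ("added"/"removed") to the
    # dict of sources that changed in that way.
    for name, all_sources in sources_map.items():
        if not all_sources:
            continue
        for source in all_sources:
            logger.info("Source %s was %s", source, name)


def compare_sources(initial_content, final_content):
    # A missing local index (initial_content is None) means every
    # source in the downloaded index is new.
    if not initial_content:
        logger.info("Adding all sources")
        return
    if initial_content == final_content:
        logger.info("No change in sources")
        return
    initial_sources = initial_content.get("sources")
    final_sources = final_content.get("sources")
    added = get_sources(before=initial_sources, after=final_sources)
    removed = get_sources(before=final_sources, after=initial_sources)
    log_sources(sources_map={"added": added, "removed": removed})
    for source in set(initial_sources) & set(final_sources):
        if initial_sources[source] != final_sources[source]:
            logger.info("Source %s was changed", source)


# Hypothetical index contents, for illustration only.
old_index = {"sources": {"a/rules": {"url": "http://example.com/a"},
                         "b/rules": {"url": "http://example.com/b"}}}
new_index = {"sources": {"a/rules": {"url": "http://example.com/a2"},
                         "c/rules": {"url": "http://example.com/c"}}}

compare_sources(None, new_index)       # -> "Adding all sources"
compare_sources(old_index, new_index)  # -> c/rules added, b/rules removed,
                                       #    a/rules changed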

suricata/update/commands/updatesources.py

index 9e22e735c9075f7f9332b71310284c1499217e1a..25f6e3ebfb27c87f3237b43933e78c30ea383365 100644
 
 from __future__ import print_function
 
-import os
-import logging
 import io
-import yaml
+import logging
+import os
 
-from suricata.update import config
-from suricata.update import sources
-from suricata.update import net
-from suricata.update import exceptions
+import yaml
+from suricata.update import config, exceptions, net, sources
 
 logger = logging.getLogger()
 
@@ -33,31 +30,59 @@ def register(parser):
     parser.set_defaults(func=update_sources)
 
 
+def get_initial_content():
+    initial_content = None
+    if os.path.exists(local_index_filename):
+        with open(local_index_filename, "r") as stream:
+            initial_content = yaml.safe_load(stream)
+    return initial_content
+
+
+def get_sources(before, after):
+    all_sources = {source: after[source]
+        for source in after if source not in before}
+    return all_sources
+
+
+def log_sources(sources_map):
+    for name, all_sources in sources_map.items():
+        if not all_sources:
+            continue
+        for source in all_sources:
+            logger.info("Source %s was %s", source, name)
+
+
 def compare_sources(initial_content, final_content):
+    if not initial_content:
+        logger.info("Adding all sources")
+        return
     if initial_content == final_content:
         logger.info("No change in sources")
         return
     initial_sources = initial_content.get("sources")
     final_sources = final_content.get("sources")
-    added_sources = {source: final_sources[source]
-                     for source in final_sources if source not in initial_sources}
-    removed_sources = {source: initial_sources[source]
-                       for source in initial_sources if source not in final_sources}
-    if added_sources:
-        for source in added_sources:
-            logger.info("Source %s was added", source)
-    if removed_sources:
-        for source in removed_sources:
-            logger.info("Source %s was removed", source)
+    added_sources = get_sources(before=initial_sources, after=final_sources)
+    removed_sources = get_sources(before=final_sources, after=initial_sources)
+    log_sources(sources_map={"added": added_sources,
+                        "removed": removed_sources})
     for source in set(initial_sources) & set(final_sources):
         if initial_sources[source] != final_sources[source]:
             logger.info("Source %s was changed", source)
 
 
+def write_and_compare(initial_content, fileobj):
+    with open(local_index_filename, "wb") as outobj:
+        outobj.write(fileobj.getvalue())
+    with open(local_index_filename) as stream:
+        final_content = yaml.safe_load(stream)
+    compare_sources(initial_content, final_content)
+    logger.info("Saved %s", local_index_filename)
+
+
 def update_sources():
+    global local_index_filename
     local_index_filename = sources.get_index_filename()
-    with open(local_index_filename) as stream:
-        initial_content = yaml.safe_load(stream)
+    initial_content = get_initial_content()
     with io.BytesIO() as fileobj:
         url = sources.get_source_index_url()
         logger.info("Downloading %s", url)
@@ -73,9 +98,4 @@ def update_sources():
                 logger.error("Failed to create directory %s: %s",
                              config.get_cache_dir(), err)
                 return 1
-        with open(local_index_filename, "wb") as outobj:
-            outobj.write(fileobj.getvalue())
-        with open(local_index_filename) as stream:
-            final_content = yaml.safe_load(stream)
-        compare_sources(initial_content, final_content)
-        logger.info("Saved %s", local_index_filename)
+        write_and_compare(initial_content=initial_content, fileobj=fileobj)