from __future__ import print_function
-import os
-import logging
import io
-import yaml
+import logging
+import os
-from suricata.update import config
-from suricata.update import sources
-from suricata.update import net
-from suricata.update import exceptions
+import yaml
+from suricata.update import config, exceptions, net, sources
logger = logging.getLogger()
    parser.set_defaults(func=update_sources)
+def get_initial_content():
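+    """Load the locally cached source index, or return None if it is missing."""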
+    initial_content = None
+    if os.path.exists(local_index_filename):
+        with open(local_index_filename, "r") as stream:
+            initial_content = yaml.safe_load(stream)
+    return initial_content
+
+
+def get_sources(before, after):
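+    """Return the entries that are present in after but not in before."""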
+    all_sources = {source: after[source]
+                   for source in after if source not in before}
+    return all_sources
+
+
+def log_sources(sources_map):
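+    """Log every source in the map under its label, e.g. "added" or "removed"."""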
+    for name, all_sources in sources_map.items():
+        if not all_sources:
+            continue
+        for source in all_sources:
+            logger.info("Source %s was %s", source, name)
+
+
def compare_sources(initial_content, final_content):
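+    """Log how the freshly downloaded index differs from the cached one."""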
+    if not initial_content:
+        logger.info("Adding all sources")
+        return
    if initial_content == final_content:
        logger.info("No change in sources")
        return
    initial_sources = initial_content.get("sources")
    final_sources = final_content.get("sources")
-    added_sources = {source: final_sources[source]
-                     for source in final_sources if source not in initial_sources}
-    removed_sources = {source: initial_sources[source]
-                       for source in initial_sources if source not in final_sources}
-    if added_sources:
-        for source in added_sources:
-            logger.info("Source %s was added", source)
-    if removed_sources:
-        for source in removed_sources:
-            logger.info("Source %s was removed", source)
+    added_sources = get_sources(before=initial_sources, after=final_sources)
+    removed_sources = get_sources(before=final_sources, after=initial_sources)
+    log_sources(sources_map={"added": added_sources,
+                             "removed": removed_sources})
    for source in set(initial_sources) & set(final_sources):
        if initial_sources[source] != final_sources[source]:
            logger.info("Source %s was changed", source)
+def write_and_compare(initial_content, fileobj):
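+    """Write the downloaded index to disk and log what changed."""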
+    with open(local_index_filename, "wb") as outobj:
+        outobj.write(fileobj.getvalue())
+    with open(local_index_filename) as stream:
+        final_content = yaml.safe_load(stream)
+    compare_sources(initial_content, final_content)
+    logger.info("Saved %s", local_index_filename)
+
+
def update_sources():
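+    # local_index_filename is also read by the helpers defined above.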
+    global local_index_filename
    local_index_filename = sources.get_index_filename()
-    with open(local_index_filename) as stream:
-        initial_content = yaml.safe_load(stream)
+    initial_content = get_initial_content()
    with io.BytesIO() as fileobj:
        url = sources.get_source_index_url()
        logger.info("Downloading %s", url)
                logger.error("Failed to create directory %s: %s",
                             config.get_cache_dir(), err)
                return 1
-        with open(local_index_filename, "wb") as outobj:
-            outobj.write(fileobj.getvalue())
-        with open(local_index_filename) as stream:
-            final_content = yaml.safe_load(stream)
-        compare_sources(initial_content, final_content)
-        logger.info("Saved %s", local_index_filename)
+        write_and_compare(initial_content=initial_content, fileobj=fileobj)