Load files into a list, not a dict, to keep duplicate filenames
author Jason Ish <jason.ish@oisf.net>
Wed, 8 Jul 2020 22:52:12 +0000 (16:52 -0600)
committer Shivani Bhardwaj <shivanib134@gmail.com>
Thu, 3 Sep 2020 15:51:33 +0000 (21:21 +0530)
By loading all downloaded rule files into the same dict, files that
share a filename overwrite each other, so content is lost even
though the contents may differ.

Instead, use a list of objects that track the filename and the
content, so a file whose name already exists does not lose its
contents.

If the contents are duplicated, the rule deduplication process
will catch that.

Redmine ticket:
https://redmine.openinfosecfoundation.org/issues/3174
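
A minimal standalone sketch of the idea (not the patch itself; the
file names and rule strings below are made up for illustration):
collect files as (filename, content) pairs in a plain list so that
two sources shipping the same filename no longer clobber each other,
and leave identical rule content to the later deduplication step.

    from collections import namedtuple

    SourceFile = namedtuple("SourceFile", ["filename", "content"])

    def collect(sources):
        # Keep every (filename, content) pair; a dict keyed by the
        # filename would silently drop all but the last duplicate.
        files = []
        for filename, content in sources:
            files.append(SourceFile(filename, content))
        return files

    # Hypothetical example: two sources both ship "local.rules" with
    # different content; both entries survive in the list.
    files = collect([
        ("local.rules", b"alert tcp any any -> any any (sid:1;)"),
        ("local.rules", b"alert tcp any any -> any any (sid:2;)"),
    ])
    assert len(files) == 2 and files[0].filename == files[1].filename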

suricata/update/main.py

index c7d59318814bd7cf02de555a84a846b298ff8adb..5f263dfca6ea100a91bee9a0d63404ca298a52ea 100644
@@ -32,6 +32,7 @@ import io
 import tempfile
 import signal
 import errno
+from collections import namedtuple
 
 try:
     # Python 3.
@@ -69,6 +70,8 @@ try:
 except:
     revision = None
 
+SourceFile = namedtuple("SourceFile", ["filename", "content"])
+
 if sys.argv[0] == __file__:
     sys.path.insert(
         0, os.path.abspath(os.path.join(__file__, "..", "..", "..")))
@@ -203,9 +206,8 @@ class Fetch:
         logger.info("Done.")
         return self.extract_files(tmp_filename)
 
-    def run(self, url=None, files=None):
-        if files is None:
-            files = {}
+    def run(self, url=None):
+        files = {}
         if url:
             try:
                 fetched = self.fetch(url)
@@ -300,7 +302,7 @@ def load_local(local, files):
                         filename))
             try:
                 with open(filename, "rb") as fileobj:
-                    files[filename] = fileobj.read()
+                    files.append(SourceFile(filename, fileobj.read()))
             except Exception as err:
                 logger.error("Failed to open %s: %s" % (filename, err))
 
@@ -355,7 +357,7 @@ def load_dist_rules(files):
             logger.info("Loading distribution rule file %s", path)
             try:
                 with open(path, "rb") as fileobj:
-                    files[path] = fileobj.read()
+                    files.append(SourceFile(path, fileobj.read()))
             except Exception as err:
                 logger.error("Failed to open %s: %s" % (path, err))
                 sys.exit(1)
@@ -833,8 +835,6 @@ def copytree(src, dst):
                     dst_path)
 
 def load_sources(suricata_version):
-    files = {}
-
     urls = []
 
     http_header = None
@@ -915,8 +915,11 @@ def load_sources(suricata_version):
     urls = set(urls)
 
     # Now download each URL.
+    files = []
     for url in urls:
-        Fetch().run(url, files)
+        source_files = Fetch().run(url)
+        for key in source_files:
+            files.append(SourceFile(key, source_files[key]))
 
     # Now load local rules.
     if config.get("local") is not None:
@@ -1131,24 +1134,21 @@ def _main():
 
     load_dist_rules(files)
 
-    # Remove ignored files.
-    for filename in list(files.keys()):
-        if ignore_file(config.get("ignore"), filename):
-            logger.info("Ignoring file %s" % (filename))
-            del(files[filename])
-
     rules = []
     classification_files = []
     dep_files = {}
-    for filename in sorted(files):
-        if "classification.config" in filename:
-            classification_files.append((filename, files[filename]))
+    for entry in sorted(files, key = lambda e: e.filename):
+        if "classification.config" in entry.filename:
+            classification_files.append((entry.filename, entry.content))
             continue
-        if not filename.endswith(".rules"):
-            dep_files.update({filename: files[filename]})
+        if not entry.filename.endswith(".rules"):
+            dep_files.update({entry.filename: entry.content})
+            continue
+        if ignore_file(config.get("ignore"), entry.filename):
+            logger.info("Ignoring file {}".format(entry.filename))
             continue
-        logger.debug("Parsing %s." % (filename))
-        rules += rule_mod.parse_fileobj(io.BytesIO(files[filename]), filename)
+        logger.debug("Parsing {}".format(entry.filename))
+        rules += rule_mod.parse_fileobj(io.BytesIO(entry.content), entry.filename)
 
     rulemap = build_rule_map(rules)
     logger.info("Loaded %d rules." % (len(rules)))