]> git.ipfire.org Git - thirdparty/suricata-update.git/commitdiff
When enabling source, also enable et/open...
authorJason Ish <ish@unx.ca>
Mon, 4 Dec 2017 19:18:01 +0000 (13:18 -0600)
committerJason Ish <ish@unx.ca>
Mon, 4 Dec 2017 19:18:01 +0000 (13:18 -0600)
But only if the source being enabled is not et/open, or the
source being enabled does not replace et/open.

This is also only done on creation of the directory:
  /var/lib/suricata/update/sources

suricata/update/commands/enablesource.py
suricata/update/sources.py

index f797925567b33e067ed09247bcf870e435d21864..316494493d911253d1acb8114b918578f90ad2cd 100644 (file)
@@ -25,11 +25,13 @@ from suricata.update import sources
 
 logger = logging.getLogger()
 
+default_source = "et/open"
+
 def register(parser):
     parser.add_argument("name")
     parser.add_argument("params", nargs="*", metavar="param=val")
     parser.set_defaults(func=enable_source)
-    
+
 def enable_source(config):
     name = config.args.name
 
@@ -54,7 +56,7 @@ def enable_source(config):
         return 1
 
     source_index = sources.load_source_index(config)
-    
+
     if not name in source_index.get_sources():
         logger.error("Unknown source: %s", name)
         return 1
@@ -84,19 +86,44 @@ def enable_source(config):
                     if r:
                         break
                 params[param] = r.strip()
-    new_source = sources.SourceConfiguration(name, params=params).dict()
+    new_source = sources.SourceConfiguration(name, params=params)
 
-    if not os.path.exists(sources.get_source_directory()):
+    # If the source directory does not exist, create it. Also create
+    # the default rule-source of et/open, unless the source being
+    # enabled replaces it.
+    source_directory = sources.get_source_directory()
+    if not os.path.exists(source_directory):
         try:
-            logger.info("Creating directory %s", sources.get_source_directory())
-            os.makedirs(sources.get_source_directory())
+            logger.info("Creating directory %s", source_directory)
+            os.makedirs(source_directory)
         except Exception as err:
-            logger.error("Failed to create directory %s: %s",
-                         sources.get_source_directory(), err)
+            logger.error(
+                "Failed to create directory %s: %s", source_directory, err)
             return 1
 
-    filename = os.path.join(
-        sources.get_source_directory(), "%s.yaml" % (sources.safe_filename(name)))
-    logger.info("Writing %s", filename)
-    with open(filename, "w") as fileobj:
-        fileobj.write(yaml.dump(new_source, default_flow_style=False))
+        if "replaces" in source and default_source in source["replaces"]:
+            logger.debug(
+                "Not enabling default source as selected source replaces it")
+        elif new_source.name == default_source:
+            logger.debug(
+                "Not enabling default source as selected source is the default")
+        else:
+            logger.info("Enabling default source %s", default_source)
+            if not source_index.get_source_by_name(default_source):
+                logger.error("Default source %s not in index", default_source)
+            else:
+                default_source_config = sources.SourceConfiguration(
+                    default_source)
+                write_source_config(default_source_config, True)
+
+    write_source_config(new_source, True)
+    logger.info("Source %s enabled", new_source.name)
+
+def write_source_config(config, enabled):
+    if enabled:
+        filename = sources.get_enabled_source_filename(config.name)
+    else:
+        filename = sources.get_disabled_source_filename(config.name)
+    with open(filename, "wb") as fileobj:
+        logger.debug("Writing %s", filename)
+        fileobj.write(yaml.safe_dump(config.dict(), default_flow_style=False))
index b5524b49ca8495f824abc85a7ee100a10fc8433e..a15ae49c77c35ab9b603c06caa1158521e921d70 100644 (file)
@@ -112,6 +112,11 @@ class Index:
     def get_sources(self):
         return self.index["sources"]
 
+    def get_source_by_name(self, name):
+        if name in self.index["sources"]:
+            return self.index["sources"][name]
+        return None
+
 def load_source_index(config):
     return Index(get_index_filename(config))
 
@@ -161,7 +166,7 @@ def remove_source(config):
         os.remove(disabled_source_filename)
         logger.info("Source %s removed, previously disabled.", name)
         return 0
-    
+
     logger.warning("Source %s does not exist.", name)
     return 1