]> git.ipfire.org Git - thirdparty/suricata-update.git/commitdiff
env var SOURCE_DIRECTORY to override default...
authorJason Ish <ish@unx.ca>
Thu, 30 Nov 2017 19:27:12 +0000 (13:27 -0600)
committerJason Ish <ish@unx.ca>
Fri, 1 Dec 2017 17:33:18 +0000 (11:33 -0600)
So tests won't pick up enabled sources...

suricata/update/sources.py
tests/test_main.py

index bb8e9ed4e7b5950847c0cc6d52c99c0740a1aa8a..68f20b9f719a5f0addc93fdaf0e419f8c5fefb04 100644 (file)
@@ -31,17 +31,24 @@ logger = logging.getLogger()
 
 DEFAULT_SOURCE_INDEX_URL = "https://raw.githubusercontent.com/jasonish/suricata-intel-index/master/index.yaml"
 SOURCE_INDEX_FILENAME = "index.yaml"
-ENABLED_SOURCE_DIRECTORY = "/var/lib/suricata/update/sources"
+DEFAULT_SOURCE_DIRECTORY = "/var/lib/suricata/update/sources"
+
+def get_source_directory():
+    """Return the directory where source configuration files are kept."""
+    if os.getenv("SOURCE_DIRECTORY"):
+        return os.getenv("SOURCE_DIRECTORY")
+    else:
+        return DEFAULT_SOURCE_DIRECTORY
 
 def get_index_filename(config):
     return os.path.join(config.get_cache_dir(), SOURCE_INDEX_FILENAME)
 
 def get_enabled_source_filename(name):
-    return os.path.join(ENABLED_SOURCE_DIRECTORY, "%s.yaml" % (
+    return os.path.join(get_source_directory(), "%s.yaml" % (
         safe_filename(name)))
 
 def get_disabled_source_filename(name):
-    return os.path.join(ENABLED_SOURCE_DIRECTORY, "%s.yaml.disabled" % (
+    return os.path.join(get_source_directory(), "%s.yaml.disabled" % (
         safe_filename(name)))
 
 def source_name_exists(name):
@@ -105,10 +112,10 @@ def load_source_index(config):
 
 def get_enabled_sources():
     """Return a map of enabled sources, keyed by name."""
-    if not os.path.exists(ENABLED_SOURCE_DIRECTORY):
+    if not os.path.exists(get_source_directory()):
         return {}
     sources = {}
-    for dirpath, dirnames, filenames in os.walk(ENABLED_SOURCE_DIRECTORY):
+    for dirpath, dirnames, filenames in os.walk(get_source_directory()):
         for filename in filenames:
             if filename.endswith(".yaml"):
                 path = os.path.join(dirpath, filename)
@@ -162,7 +169,7 @@ def enable_source(config):
 
     # Check if source is already enabled.
     enabled_source_filename = os.path.join(
-        ENABLED_SOURCE_DIRECTORY, "%s.yaml" % (safe_filename(name)))
+        get_source_directory(), "%s.yaml" % (safe_filename(name)))
     if os.path.exists(enabled_source_filename):
         logger.error("The source %s is already enabled.", name)
         return 1
@@ -170,7 +177,7 @@ def enable_source(config):
     # First check if this source was previous disabled and then just
     # re-enable it.
     disabled_source_filename = os.path.join(
-        ENABLED_SOURCE_DIRECTORY, "%s.yaml.disabled" % (safe_filename(name)))
+        get_source_directory(), "%s.yaml.disabled" % (safe_filename(name)))
     if os.path.exists(disabled_source_filename):
         logger.info("Re-enabling previous disabled source for %s.", name)
         os.rename(disabled_source_filename, enabled_source_filename)
@@ -214,24 +221,24 @@ def enable_source(config):
                 params[param] = r.strip()
     new_source = SourceConfiguration(name, params=params).dict()
 
-    if not os.path.exists(ENABLED_SOURCE_DIRECTORY):
+    if not os.path.exists(get_source_directory()):
         try:
-            logger.info("Creating directory %s", ENABLED_SOURCE_DIRECTORY)
-            os.makedirs(ENABLED_SOURCE_DIRECTORY)
+            logger.info("Creating directory %s", get_source_directory())
+            os.makedirs(get_source_directory())
         except Exception as err:
             logger.error("Failed to create directory %s: %s",
-                         ENABLED_SOURCE_DIRECTORY, err)
+                         get_source_directory(), err)
             return 1
 
     filename = os.path.join(
-        ENABLED_SOURCE_DIRECTORY, "%s.yaml" % (safe_filename(name)))
+        get_source_directory(), "%s.yaml" % (safe_filename(name)))
     logger.info("Writing %s", filename)
     with open(filename, "w") as fileobj:
         fileobj.write(yaml.dump(new_source, default_flow_style=False))
 
 def disable_source(config):
     name = config.args.name
-    filename = os.path.join(ENABLED_SOURCE_DIRECTORY, "%s.yaml" % (
+    filename = os.path.join(get_source_directory(), "%s.yaml" % (
         safe_filename(name)))
     if not os.path.exists(filename):
         logger.debug("Filename %s does not exist.", filename)
index 18af6d80710d7362eee4c4bd4236003e680ec922..7972ace8ad9849e47e1b03582b7686031f43725d 100644 (file)
@@ -95,6 +95,10 @@ class TestRulecat(unittest.TestCase):
                  "--no-test",
                  "--reload-command", "true",
                 ],
+                env={
+                    "PATH": os.getenv("PATH"),
+                    "SOURCE_DIRECTORY": "/tmp",
+                },
                 stdout=open("./tmp/stdout", "wb"),
                 stderr=open("./tmp/stderr", "wb"),
             )
@@ -134,6 +138,10 @@ class TestRulecat(unittest.TestCase):
                  "--no-test",
                  "--reload-command", "true",
                 ],
+                env={
+                    "PATH": os.getenv("PATH"),
+                    "SOURCE_DIRECTORY": "/tmp",
+                },
                 stdout=open("./tmp/stdout", "wb"),
                 stderr=open("./tmp/stderr", "wb"),
             )