]> git.ipfire.org Git - thirdparty/suricata-update.git/commitdiff
Put cache directory under the rules directory.
authorJason Ish <ish@unx.ca>
Tue, 14 Nov 2017 10:42:21 +0000 (11:42 +0100)
committerJason Ish <ish@unx.ca>
Tue, 14 Nov 2017 10:42:21 +0000 (11:42 +0100)
One less directory to manage permissions on.

suricata/update/main.py
tests/test_main.py

index 3e871533f4f65e096cd0262cf7aea22dfbf01d1e..ac2ef4ba095a0736f39d004b7607a52dd64d8f2c 100644 (file)
@@ -272,10 +272,15 @@ class DropRuleFilter(object):
         drop_rule.enabled = rule.enabled
         return drop_rule
 
-class Fetch(object):
+class Fetch:
 
-    def __init__(self, args):
-        self.args = args
+    def __init__(self, config):
+        self.config = config
+        if config is not None:
+            self.args = config.args
+        else:
+            # Should only happen in tests.
+            self.args = None
 
     def check_checksum(self, tmp_filename, url):
         try:
@@ -320,7 +325,7 @@ class Fetch(object):
     def get_tmp_filename(self, url):
         url_hash = hashlib.md5(url.encode("utf-8")).hexdigest()
         return os.path.join(
-            self.args.cache_dir,
+            self.config.get_cache_dir(),
             "%s-%s" % (url_hash, self.url_basename(url)))
 
     def fetch(self, url):
@@ -334,8 +339,8 @@ class Fetch(object):
             if self.check_checksum(tmp_filename, url):
                 logger.info("Remote checksum has not changed. Not fetching.")
                 return self.extract_files(tmp_filename)
-        if not os.path.exists(self.args.cache_dir):
-            os.makedirs(self.args.cache_dir, mode=0o770)
+        if not os.path.exists(self.config.get_cache_dir()):
+            os.makedirs(self.config.get_cache_dir(), mode=0o770)
         logger.info("Fetching %s." % (url))
         try:
             suricata.update.net.get(
@@ -803,6 +808,8 @@ class Config:
         self.config = {}
         self.config.update(self.DEFAULTS)
 
+        self.cache_dir = None
+
     def load(self):
         if self.args.config:
             with open(self.args.config) as fileobj:
@@ -850,6 +857,14 @@ class Config:
     def set(self, key, val):
         self.config[key] = val
 
+    def get_cache_dir(self):
+        if self.cache_dir:
+            return self.cache_dir
+        return os.path.join(self.args.output, ".cache")
+
+    def set_cache_dir(self, directory):
+        self.cache_dir = directory
+
 def test_suricata(config, suricata_path):
     if not suricata_path:
         logger.info("No suricata application binary found, skipping test.")
@@ -956,7 +971,7 @@ def load_sources(config, suricata_version):
 
     # Now download each URL.
     for url in urls:
-        Fetch(config.args).run(url, files)
+        Fetch(config).run(url, files)
 
     # Now load local rules specified in the configuration file.
     for local in config.config["local"]:
@@ -983,8 +998,6 @@ def main():
     parser.add_argument("-o", "--output", metavar="<directory>",
                         dest="output", default="/var/lib/suricata/rules",
                         help="Directory to write rules to")
-    parser.add_argument("--cache-dir", default="/var/lib/suricata/cache",
-                        metavar="<directory>", help="set the cache directory")
     parser.add_argument("--suricata", metavar="<path>",
                         help="Path to Suricata program")
     parser.add_argument("--suricata-version", metavar="<version>",
@@ -1167,13 +1180,13 @@ def main():
         drop_filters += load_drop_filters(drop_conf_filename)
 
     # Check that the cache directory exists and is writable.
-    if not os.path.exists(args.cache_dir):
+    if not os.path.exists(config.get_cache_dir()):
         try:
-            os.makedirs(args.cache_dir, mode=0o770)
+            os.makedirs(config.get_cache_dir(), mode=0o770)
         except Exception as err:
             logger.warning(
                 "Cache directory does exist and could not be created. /var/tmp will be used instead.")
-            args.cache_dir = "/var/tmp"
+            config.set_cache_dir("/var/tmp")
 
     files = load_sources(config, suricata_version)
 
index a9c160ba8e3172e02ec4c038521689d3fc74530c..fa035bc8da6d8488c4ffc44025b2a66319d1f5b4 100644 (file)
@@ -111,7 +111,6 @@ class TestRulecat(unittest.TestCase):
                  "file://%s/emerging.rules.tar.gz" % (
                      os.getcwd()),
                  "--local", "./rule-with-unicode.rules",
-                 "--cache-dir", "./tmp",
                  "--force",
                  "--output", "./tmp/rules/",
                  "--yaml-fragment", "./tmp/suricata-rules.yaml",
@@ -151,7 +150,6 @@ class TestRulecat(unittest.TestCase):
                  "file://%s/emerging.rules.tar.gz" % (
                      os.getcwd()),
                  "--local", "./rule-with-unicode.rules",
-                 "--cache-dir", "./tmp",
                  "--force",
                  "--output", "./tmp/rules/",
                  "--yaml-fragment", "./tmp/suricata-rules.yaml",