From: Jason Ish Date: Thu, 30 Nov 2017 19:27:12 +0000 (-0600) Subject: env var SOURCE_DIRECTORY to override default... X-Git-Tag: 1.0.0a1~29 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=8cb4bce5e4b32bd5abe4e16e35a44fae5f35d506;p=thirdparty%2Fsuricata-update.git env var SOURCE_DIRECTORY to override default... So tests won't pick up enabled sources... --- diff --git a/suricata/update/sources.py b/suricata/update/sources.py index bb8e9ed..68f20b9 100644 --- a/suricata/update/sources.py +++ b/suricata/update/sources.py @@ -31,17 +31,24 @@ logger = logging.getLogger() DEFAULT_SOURCE_INDEX_URL = "https://raw.githubusercontent.com/jasonish/suricata-intel-index/master/index.yaml" SOURCE_INDEX_FILENAME = "index.yaml" -ENABLED_SOURCE_DIRECTORY = "/var/lib/suricata/update/sources" +DEFAULT_SOURCE_DIRECTORY = "/var/lib/suricata/update/sources" + +def get_source_directory(): + """Return the directory where source configuration files are kept.""" + if os.getenv("SOURCE_DIRECTORY"): + return os.getenv("SOURCE_DIRECTORY") + else: + return DEFAULT_SOURCE_DIRECTORY def get_index_filename(config): return os.path.join(config.get_cache_dir(), SOURCE_INDEX_FILENAME) def get_enabled_source_filename(name): - return os.path.join(ENABLED_SOURCE_DIRECTORY, "%s.yaml" % ( + return os.path.join(get_source_directory(), "%s.yaml" % ( safe_filename(name))) def get_disabled_source_filename(name): - return os.path.join(ENABLED_SOURCE_DIRECTORY, "%s.yaml.disabled" % ( + return os.path.join(get_source_directory(), "%s.yaml.disabled" % ( safe_filename(name))) def source_name_exists(name): @@ -105,10 +112,10 @@ def load_source_index(config): def get_enabled_sources(): """Return a map of enabled sources, keyed by name.""" - if not os.path.exists(ENABLED_SOURCE_DIRECTORY): + if not os.path.exists(get_source_directory()): return {} sources = {} - for dirpath, dirnames, filenames in os.walk(ENABLED_SOURCE_DIRECTORY): + for dirpath, dirnames, filenames in os.walk(get_source_directory()): for filename in filenames: if filename.endswith(".yaml"): path = os.path.join(dirpath, filename) @@ -162,7 +169,7 @@ def enable_source(config): # Check if source is already enabled. enabled_source_filename = os.path.join( - ENABLED_SOURCE_DIRECTORY, "%s.yaml" % (safe_filename(name))) + get_source_directory(), "%s.yaml" % (safe_filename(name))) if os.path.exists(enabled_source_filename): logger.error("The source %s is already enabled.", name) return 1 @@ -170,7 +177,7 @@ def enable_source(config): # First check if this source was previous disabled and then just # re-enable it. disabled_source_filename = os.path.join( - ENABLED_SOURCE_DIRECTORY, "%s.yaml.disabled" % (safe_filename(name))) + get_source_directory(), "%s.yaml.disabled" % (safe_filename(name))) if os.path.exists(disabled_source_filename): logger.info("Re-enabling previous disabled source for %s.", name) os.rename(disabled_source_filename, enabled_source_filename) @@ -214,24 +221,24 @@ def enable_source(config): params[param] = r.strip() new_source = SourceConfiguration(name, params=params).dict() - if not os.path.exists(ENABLED_SOURCE_DIRECTORY): + if not os.path.exists(get_source_directory()): try: - logger.info("Creating directory %s", ENABLED_SOURCE_DIRECTORY) - os.makedirs(ENABLED_SOURCE_DIRECTORY) + logger.info("Creating directory %s", get_source_directory()) + os.makedirs(get_source_directory()) except Exception as err: logger.error("Failed to create directory %s: %s", - ENABLED_SOURCE_DIRECTORY, err) + get_source_directory(), err) return 1 filename = os.path.join( - ENABLED_SOURCE_DIRECTORY, "%s.yaml" % (safe_filename(name))) + get_source_directory(), "%s.yaml" % (safe_filename(name))) logger.info("Writing %s", filename) with open(filename, "w") as fileobj: fileobj.write(yaml.dump(new_source, default_flow_style=False)) def disable_source(config): name = config.args.name - filename = os.path.join(ENABLED_SOURCE_DIRECTORY, "%s.yaml" % ( + filename = os.path.join(get_source_directory(), "%s.yaml" % ( safe_filename(name))) if not os.path.exists(filename): logger.debug("Filename %s does not exist.", filename) diff --git a/tests/test_main.py b/tests/test_main.py index 18af6d8..7972ace 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -95,6 +95,10 @@ class TestRulecat(unittest.TestCase): "--no-test", "--reload-command", "true", ], + env={ + "PATH": os.getenv("PATH"), + "SOURCE_DIRECTORY": "/tmp", + }, stdout=open("./tmp/stdout", "wb"), stderr=open("./tmp/stderr", "wb"), ) @@ -134,6 +138,10 @@ class TestRulecat(unittest.TestCase): "--no-test", "--reload-command", "true", ], + env={ + "PATH": os.getenv("PATH"), + "SOURCE_DIRECTORY": "/tmp", + }, stdout=open("./tmp/stdout", "wb"), stderr=open("./tmp/stderr", "wb"), )