From: Jason Ish Date: Mon, 4 Dec 2017 19:18:01 +0000 (-0600) Subject: When enabling source, also enable et/open... X-Git-Tag: 1.0.0a1~14 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4f5f6060d91dbae0107754ba3ed91712312f2456;p=thirdparty%2Fsuricata-update.git When enabling source, also enable et/open... But only if the source being enabled is not et/open, or the source being enabled does not replace et/open. This is also only done on creation of the directory: /var/lib/suricata/update/sources --- diff --git a/suricata/update/commands/enablesource.py b/suricata/update/commands/enablesource.py index f797925..3164944 100644 --- a/suricata/update/commands/enablesource.py +++ b/suricata/update/commands/enablesource.py @@ -25,11 +25,13 @@ from suricata.update import sources logger = logging.getLogger() +default_source = "et/open" + def register(parser): parser.add_argument("name") parser.add_argument("params", nargs="*", metavar="param=val") parser.set_defaults(func=enable_source) - + def enable_source(config): name = config.args.name @@ -54,7 +56,7 @@ def enable_source(config): return 1 source_index = sources.load_source_index(config) - + if not name in source_index.get_sources(): logger.error("Unknown source: %s", name) return 1 @@ -84,19 +86,44 @@ def enable_source(config): if r: break params[param] = r.strip() - new_source = sources.SourceConfiguration(name, params=params).dict() + new_source = sources.SourceConfiguration(name, params=params) - if not os.path.exists(sources.get_source_directory()): + # If the source directory does not exist, create it. Also create + # the default rule-source of et/open, unless the source being + # enabled replaces it. + source_directory = sources.get_source_directory() + if not os.path.exists(source_directory): try: - logger.info("Creating directory %s", sources.get_source_directory()) - os.makedirs(sources.get_source_directory()) + logger.info("Creating directory %s", source_directory) + os.makedirs(source_directory) except Exception as err: - logger.error("Failed to create directory %s: %s", - sources.get_source_directory(), err) + logger.error( + "Failed to create directory %s: %s", source_directory, err) return 1 - filename = os.path.join( - sources.get_source_directory(), "%s.yaml" % (sources.safe_filename(name))) - logger.info("Writing %s", filename) - with open(filename, "w") as fileobj: - fileobj.write(yaml.dump(new_source, default_flow_style=False)) + if "replaces" in source and default_source in source["replaces"]: + logger.debug( + "Not enabling default source as selected source replaces it") + elif new_source.name == default_source: + logger.debug( + "Not enabling default source as selected source is the default") + else: + logger.info("Enabling default source %s", default_source) + if not source_index.get_source_by_name(default_source): + logger.error("Default source %s not in index", default_source) + else: + default_source_config = sources.SourceConfiguration( + default_source) + write_source_config(default_source_config, True) + + write_source_config(new_source, True) + logger.info("Source %s enabled", new_source.name) + +def write_source_config(config, enabled): + if enabled: + filename = sources.get_enabled_source_filename(config.name) + else: + filename = sources.get_disabled_source_filename(config.name) + with open(filename, "wb") as fileobj: + logger.debug("Writing %s", filename) + fileobj.write(yaml.safe_dump(config.dict(), default_flow_style=False)) diff --git a/suricata/update/sources.py b/suricata/update/sources.py index b5524b4..a15ae49 100644 --- a/suricata/update/sources.py +++ b/suricata/update/sources.py @@ -112,6 +112,11 @@ class Index: def get_sources(self): return self.index["sources"] + def get_source_by_name(self, name): + if name in self.index["sources"]: + return self.index["sources"][name] + return None + def load_source_index(config): return Index(get_index_filename(config)) @@ -161,7 +166,7 @@ def remove_source(config): os.remove(disabled_source_filename) logger.info("Source %s removed, previously disabled.", name) return 0 - + logger.warning("Source %s does not exist.", name) return 1