]> git.ipfire.org Git - thirdparty/suricata-update.git/commitdiff
enable-source: move to own source files
authorJason Ish <ish@unx.ca>
Mon, 4 Dec 2017 13:40:09 +0000 (07:40 -0600)
committerJason Ish <ish@unx.ca>
Mon, 4 Dec 2017 13:40:09 +0000 (07:40 -0600)
suricata/update/commands/__init__.py
suricata/update/commands/enablesource.py [new file with mode: 0644]
suricata/update/main.py
suricata/update/sources.py

index a090177c7e3b8ddd1a76dcbdc23794f8649bac5b..c24cd5334d1b7c74efed39bf0379ea60132cbb53 100644 (file)
@@ -18,3 +18,4 @@ from suricata.update.commands import listenabledsources
 from suricata.update.commands import addsource
 from suricata.update.commands import listsources
 from suricata.update.commands import updatesources
+from suricata.update.commands import enablesource
diff --git a/suricata/update/commands/enablesource.py b/suricata/update/commands/enablesource.py
new file mode 100644 (file)
index 0000000..f797925
--- /dev/null
@@ -0,0 +1,102 @@
+# Copyright (C) 2017 Open Information Security Foundation
+#
+# You can copy, redistribute or modify this Program under the terms of
+# the GNU General Public License version 2 as published by the Free
+# Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# version 2 along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+# 02110-1301, USA.
+
+from __future__ import print_function
+
+import os
+import logging
+
+import yaml
+
+from suricata.update import sources
+
+logger = logging.getLogger()
+
+def register(parser):
+    parser.add_argument("name")
+    parser.add_argument("params", nargs="*", metavar="param=val")
+    parser.set_defaults(func=enable_source)
+    
+def enable_source(config):
+    name = config.args.name
+
+    # Check if source is already enabled.
+    enabled_source_filename = sources.get_enabled_source_filename(name)
+    if os.path.exists(enabled_source_filename):
+        logger.error("The source %s is already enabled.", name)
+        return 1
+
+    # First check if this source was previous disabled and then just
+    # re-enable it.
+    disabled_source_filename = sources.get_disabled_source_filename(name)
+    if os.path.exists(disabled_source_filename):
+        logger.info("Re-enabling previous disabled source for %s.", name)
+        os.rename(disabled_source_filename, enabled_source_filename)
+        return 0
+
+    if not os.path.exists(sources.get_index_filename(config)):
+        logger.warning(
+            "Source index does not exist, "
+            "try running suricata-update update-sources.")
+        return 1
+
+    source_index = sources.load_source_index(config)
+    
+    if not name in source_index.get_sources():
+        logger.error("Unknown source: %s", name)
+        return 1
+
+    # Parse key=val options.
+    opts = {}
+    for param in config.args.params:
+        key, val = param.split("=", 1)
+        opts[key] = val
+
+    source = source_index.get_sources()[name]
+
+    if "subscribe-url" in source:
+        print("The source %s requires a subscription. Subscribe here:" % (name))
+        print("  %s" % source["subscribe-url"])
+
+    params = {}
+    if "parameters" in source:
+        for param in source["parameters"]:
+            if param in opts:
+                params[param] = opts[param]
+            else:
+                prompt = source["parameters"][param]["prompt"]
+                while True:
+                    r = raw_input("%s (%s): " % (prompt, param))
+                    r = r.strip()
+                    if r:
+                        break
+                params[param] = r.strip()
+    new_source = sources.SourceConfiguration(name, params=params).dict()
+
+    if not os.path.exists(sources.get_source_directory()):
+        try:
+            logger.info("Creating directory %s", sources.get_source_directory())
+            os.makedirs(sources.get_source_directory())
+        except Exception as err:
+            logger.error("Failed to create directory %s: %s",
+                         sources.get_source_directory(), err)
+            return 1
+
+    filename = os.path.join(
+        sources.get_source_directory(), "%s.yaml" % (sources.safe_filename(name)))
+    logger.info("Writing %s", filename)
+    with open(filename, "w") as fileobj:
+        fileobj.write(yaml.dump(new_source, default_flow_style=False))
index 59e78e6884d783048a7f82f8e394e4f1baba4b53..d264042e007d278afc6e8252569ab8179e8e07de 100644 (file)
@@ -1140,11 +1140,6 @@ def _main():
         "remove-source", help="Remove a source", parents=[common_parser])
     remove_source_parser.add_argument("name")
 
-    enable_source_parser = subparsers.add_parser(
-        "enable-source", parents=[common_parser])
-    enable_source_parser.add_argument("name")
-    enable_source_parser.add_argument("params", nargs="*", metavar="param=val")
-
     commands.listsources.register(subparsers.add_parser(
         "list-sources", parents=[common_parser]))
     commands.listenabledsources.register(subparsers.add_parser(
@@ -1153,6 +1148,8 @@ def _main():
         "add-source", parents=[common_parser]))
     commands.updatesources.register(subparsers.add_parser(
         "update-sources", parents=[common_parser]))
+    commands.enablesource.register(subparsers.add_parser(
+        "enable-source", parents=[common_parser]))
 
     args = parser.parse_args()
 
@@ -1184,9 +1181,7 @@ def _main():
         return 1
 
     if args.subcommand:
-        if args.subcommand == "enable-source":
-            return sources.enable_source(config)
-        elif args.subcommand == "disable-source":
+        if args.subcommand == "disable-source":
             return sources.disable_source(config)
         elif args.subcommand == "remove-source":
             return sources.remove_source(config)
index edb889ccbf854eca11e7ef402b0131e25d8f5bd5..b5524b49ca8495f824abc85a7ee100a10fc8433e 100644 (file)
@@ -134,86 +134,6 @@ def get_enabled_sources():
 
     return sources
 
-def load_sources(config):
-    sources_cache_filename = os.path.join(
-        config.get_cache_dir(), SOURCE_INDEX_FILENAME)
-    if os.path.exists(sources_cache_filename):
-        index = yaml.load(open(sources_cache_filename).read())
-        return index["sources"]
-    return {}
-
-def enable_source(config):
-    name = config.args.name
-
-    # Check if source is already enabled.
-    enabled_source_filename = os.path.join(
-        get_source_directory(), "%s.yaml" % (safe_filename(name)))
-    if os.path.exists(enabled_source_filename):
-        logger.error("The source %s is already enabled.", name)
-        return 1
-
-    # First check if this source was previous disabled and then just
-    # re-enable it.
-    disabled_source_filename = os.path.join(
-        get_source_directory(), "%s.yaml.disabled" % (safe_filename(name)))
-    if os.path.exists(disabled_source_filename):
-        logger.info("Re-enabling previous disabled source for %s.", name)
-        os.rename(disabled_source_filename, enabled_source_filename)
-        return 0
-
-    if not os.path.exists(get_index_filename(config)):
-        logger.warning(
-            "Source index does not exist, "
-            "try running suricata-update update-sources.")
-        return 1
-
-    sources = load_sources(config)
-    if not config.args.name in sources:
-        logger.error("Unknown source: %s", config.args.name)
-        return 1
-
-    # Parse key=val options.
-    opts = {}
-    for param in config.args.params:
-        key, val = param.split("=", 1)
-        opts[key] = val
-
-    source = sources[config.args.name]
-
-    if "subscribe-url" in source:
-        print("The source %s requires a subscription. Subscribe here:" % (name))
-        print("  %s" % source["subscribe-url"])
-
-    params = {}
-    if "parameters" in source:
-        for param in source["parameters"]:
-            if param in opts:
-                params[param] = opts[param]
-            else:
-                prompt = source["parameters"][param]["prompt"]
-                while True:
-                    r = raw_input("%s (%s): " % (prompt, param))
-                    r = r.strip()
-                    if r:
-                        break
-                params[param] = r.strip()
-    new_source = SourceConfiguration(name, params=params).dict()
-
-    if not os.path.exists(get_source_directory()):
-        try:
-            logger.info("Creating directory %s", get_source_directory())
-            os.makedirs(get_source_directory())
-        except Exception as err:
-            logger.error("Failed to create directory %s: %s",
-                         get_source_directory(), err)
-            return 1
-
-    filename = os.path.join(
-        get_source_directory(), "%s.yaml" % (safe_filename(name)))
-    logger.info("Writing %s", filename)
-    with open(filename, "w") as fileobj:
-        fileobj.write(yaml.dump(new_source, default_flow_style=False))
-
 def disable_source(config):
     name = config.args.name
     filename = os.path.join(get_source_directory(), "%s.yaml" % (