]> git.ipfire.org Git - thirdparty/suricata-update.git/commitdiff
update-sources: catch network errors and error out
authorJason Ish <ish@unx.ca>
Thu, 14 Dec 2017 20:35:45 +0000 (14:35 -0600)
committerJason Ish <ish@unx.ca>
Thu, 14 Dec 2017 20:35:45 +0000 (14:35 -0600)
Issue:
https://redmine.openinfosecfoundation.org/issues/2348

suricata/update/commands/updatesources.py
suricata/update/exceptions.py [new file with mode: 0644]
suricata/update/main.py
suricata/update/sources.py

index 0b9e4b260f6142ba6d47a156db46ca18860962c0..7f6bfedef4e76e3175b4cb0b41ef98df2506b53f 100644 (file)
@@ -23,6 +23,7 @@ import io
 from suricata.update import config
 from suricata.update import sources
 from suricata.update import net
+from suricata.update import exceptions
 
 logger = logging.getLogger()
 
@@ -32,12 +33,13 @@ def register(parser):
 def update_sources():
     local_index_filename = sources.get_index_filename()
     with io.BytesIO() as fileobj:
+        url = sources.get_source_index_url()
+        logger.info("Downloading %s", url)
         try:
-            url = sources.get_source_index_url()
-            logger.info("Downloading %s", url)
             net.get(url, fileobj)
         except Exception as err:
-            raise Exception("Failed to download index: %s: %s" % (url, err))
+            raise exceptions.ApplicationError(
+                "Failed to download index: %s: %s" % (url, err))
         if not os.path.exists(config.get_cache_dir()):
             try:
                 os.makedirs(config.get_cache_dir())
diff --git a/suricata/update/exceptions.py b/suricata/update/exceptions.py
new file mode 100644 (file)
index 0000000..1f2c547
--- /dev/null
@@ -0,0 +1,21 @@
+# Copyright (C) 2017 Open Information Security Foundation
+#
+# You can copy, redistribute or modify this Program under the terms of
+# the GNU General Public License version 2 as published by the Free
+# Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# version 2 along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+# 02110-1301, USA.
+
+class ApplicationError(Exception):
+    pass
+
+class InvalidConfigurationError(ApplicationError):
+    pass
index d66726d11c7e353bf591631d2504b2502839f557..a46935d1a091a6b99ab91599c1fc2da65f6d6e50 100644 (file)
@@ -59,6 +59,7 @@ from suricata.update import extract
 from suricata.update import util
 from suricata.update import sources
 from suricata.update import commands
+from suricata.update import exceptions
 
 from suricata.update.version import version
 try:
@@ -84,12 +85,6 @@ DEFAULT_SURICATA_VERSION = "4.0.0"
 # single file concatenating all input rule files together.
 DEFAULT_OUTPUT_RULE_FILENAME = "suricata.rules"
 
-class ApplicationError(Exception):
-    pass
-
-class InvalidConfigurationError(ApplicationError):
-    pass
-
 class AllRuleMatcher(object):
     """Matcher object to match all rules. """
 
@@ -881,7 +876,7 @@ def load_sources(suricata_version):
                 url = source["url"] % params
             else:
                 if not index:
-                    raise ApplicationError(
+                    raise exceptions.ApplicationError(
                         "Source index is required for source %s; "
                         "run suricata-update update-sources" % (source["source"]))
                 url = index.resolve_url(name, params)
@@ -891,7 +886,7 @@ def load_sources(suricata_version):
     if config.get("sources"):
         for url in config.get("sources"):
             if type(url) not in [type("")]:
-                raise InvalidConfigurationError(
+                raise exceptions.InvalidConfigurationError(
                     "Invalid datatype for source URL: %s" % (str(url)))
             url = url % internal_params
             logger.debug("Adding source %s.", url)
@@ -1363,7 +1358,7 @@ def _main():
 def main():
     try:
         sys.exit(_main())
-    except ApplicationError as err:
+    except exceptions.ApplicationError as err:
         logger.error(err)
     sys.exit(1)
 
index ac64ce1a459a17fc4e989e0dc41d8e72f9703c25..9368d6cdba3dfe64b677562d0ce04856c9567072 100644 (file)
@@ -67,6 +67,9 @@ def get_source_index_url():
     return DEFAULT_SOURCE_INDEX_URL
 
 def save_source_config(source_config):
+    if not os.path.exists(get_source_directory()):
+        logger.info("Creating directory %s", get_source_directory())
+        os.makedirs(get_source_directory())
     with open(get_enabled_source_filename(source_config.name), "w") as fileobj:
         fileobj.write(yaml.safe_dump(
             source_config.dict(), default_flow_style=False))