Also make the verbose logging info.
from suricata.update.commands import listenabledsources
from suricata.update.commands import addsource
from suricata.update.commands import listsources
+from suricata.update.commands import updatesources
--- /dev/null
+# Copyright (C) 2017 Open Information Security Foundation
+#
+# You can copy, redistribute or modify this Program under the terms of
+# the GNU General Public License version 2 as published by the Free
+# Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# version 2 along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+# 02110-1301, USA.
+
+from __future__ import print_function
+
+import os
+import logging
+import io
+
+from suricata.update import sources
+from suricata.update import net
+
+logger = logging.getLogger()
+
+def register(parser):
+ parser.set_defaults(func=update_sources)
+
+def update_sources(config):
+ local_index_filename = sources.get_index_filename(config)
+ with io.BytesIO() as fileobj:
+ try:
+ url = sources.get_source_index_url(config)
+ logger.info("Downloading %s", url)
+ net.get(url, fileobj)
+ except Exception as err:
+ raise Exception("Failed to download index: %s: %s" % (url, err))
+ if not os.path.exists(config.get_cache_dir()):
+ try:
+ os.makedirs(config.get_cache_dir())
+ except Exception as err:
+ logger.error("Failed to create directory %s: %s",
+ config.get_cache_dir(), err)
+ return 1
+ with open(local_index_filename, "w") as outobj:
+ outobj.write(fileobj.getvalue())
+ logger.info("Saved %s", local_index_filename)
enable_source_parser.add_argument("name")
enable_source_parser.add_argument("params", nargs="*", metavar="param=val")
- update_sources_parser = subparsers.add_parser(
- "update-sources", parents=[common_parser])
-
commands.listsources.register(subparsers.add_parser(
"list-sources", parents=[common_parser]))
commands.listenabledsources.register(subparsers.add_parser(
"list-enabled-sources", parents=[common_parser]))
commands.addsource.register(subparsers.add_parser(
"add-source", parents=[common_parser]))
+ commands.updatesources.register(subparsers.add_parser(
+ "update-sources", parents=[common_parser]))
args = parser.parse_args()
return 1
if args.subcommand:
- if args.subcommand == "update-sources":
- return sources.update_sources(config)
- elif args.subcommand == "enable-source":
+ if args.subcommand == "enable-source":
return sources.enable_source(config)
elif args.subcommand == "disable-source":
return sources.disable_source(config)
"""Return True if the source index file exists."""
return os.path.exists(get_index_filename(config))
+def get_source_index_url(config):
+ if os.getenv("SOURCE_INDEX_URL"):
+ return os.getenv("SOURCE_INDEX_URL")
+ return DEFAULT_SOURCE_INDEX_URL
+
def save_source_config(source_config):
with open(get_enabled_source_filename(source_config.name), "wb") as fileobj:
fileobj.write(yaml.safe_dump(
return sources
-def get_source_index_url(config):
- if os.getenv("SOURCE_INDEX_URL"):
- return os.getenv("SOURCE_INDEX_URL")
- return DEFAULT_SOURCE_INDEX_URL
-
-def update_sources(config):
- source_cache_filename = os.path.join(
- config.get_cache_dir(), SOURCE_INDEX_FILENAME)
- source_templates = {}
- with io.BytesIO() as fileobj:
- try:
- url = get_source_index_url(config)
- logger.debug("Downloading %s", url)
- net.get(get_source_index_url(config), fileobj)
- except Exception as err:
- raise Exception("Failed to download index: %s: %s" % (url, err))
- if not os.path.exists(config.get_cache_dir()):
- try:
- os.makedirs(config.get_cache_dir())
- except Exception as err:
- logger.error("Failed to create directory %s: %s",
- config.get_cache_dir(), err)
- return 1
- with open(source_cache_filename, "w") as outobj:
- outobj.write(fileobj.getvalue())
- logger.debug("Saved %s", source_cache_filename)
-
def load_sources(config):
sources_cache_filename = os.path.join(
config.get_cache_dir(), SOURCE_INDEX_FILENAME)