]> git.ipfire.org Git - thirdparty/suricata-update.git/commitdiff
Cleanup scattered main imports
authorVagisha Gupta <vagishagupta23@gmail.com>
Fri, 13 Sep 2019 05:48:13 +0000 (11:18 +0530)
committerJason Ish <ish@unx.ca>
Thu, 17 Oct 2019 23:06:29 +0000 (17:06 -0600)
Currently, lot of names of a module are imported from a package by
writing multiple import statements in main.py. Instead, Python's
standard grouping mechanism (parentheses) is used to write the import
statement to make them compact.

Redmine issue:
    https://redmine.openinfosecfoundation.org/issues/2872

suricata/update/main.py

index 1f91d2deb67eacccb925abae1cb6f23df126d4b8..ecc8b7f4cc37c6f22148e6d735f7f32d6e74a172 100644 (file)
@@ -46,22 +46,20 @@ except:
     print("error: pyyaml is required")
     sys.exit(1)
 
-if sys.argv[0] == __file__:
-    sys.path.insert(
-        0, os.path.abspath(os.path.join(__file__, "..", "..", "..")))
-
-import suricata.update.rule
-import suricata.update.engine
-import suricata.update.net
-import suricata.update.loghandler
-from suricata.update import config
-from suricata.update import configs
-from suricata.update import extract
-from suricata.update import util
-from suricata.update import sources
-from suricata.update import commands
-from suricata.update import exceptions
-from suricata.update import notes
+from suricata.update import (
+    commands,
+    config,
+    configs,
+    engine,
+    exceptions,
+    extract,
+    loghandler,
+    net,
+    notes,
+    rule as rule_mod,
+    sources,
+    util,
+)
 
 from suricata.update.version import version
 try:
@@ -69,10 +67,14 @@ try:
 except:
     revision = None
 
+if sys.argv[0] == __file__:
+    sys.path.insert(
+        0, os.path.abspath(os.path.join(__file__, "..", "..", "..")))
+
 # Initialize logging, use colour if on a tty.
 if len(logging.root.handlers) == 0:
     logger = logging.getLogger()
-    suricata.update.loghandler.configure_logging()
+    loghandler.configure_logging()
     logger.setLevel(level=logging.INFO)
 else:
     logging.basicConfig(
@@ -256,7 +258,7 @@ class ModifyRuleFilter(object):
 
     def run(self, rule):
         modified_rule = self.pattern.sub(self.repl, rule.format())
-        parsed = suricata.update.rule.parse(modified_rule, rule.group)
+        parsed = rule_mod.parse(modified_rule, rule.group)
         if parsed is None:
             logger.error("Modification of rule %s results in invalid rule: %s",
                          rule.idstr, modified_rule)
@@ -294,7 +296,7 @@ class DropRuleFilter(object):
         return self.matcher.match(rule)
 
     def run(self, rule):
-        drop_rule = suricata.update.rule.parse(re.sub("^\w+", "drop", rule.raw))
+        drop_rule = rule_mod.parse(re.sub("^\w+", "drop", rule.raw))
         drop_rule.enabled = rule.enabled
         return drop_rule
 
@@ -310,7 +312,7 @@ class Fetch:
                 open(tmp_filename, "rb").read()).hexdigest().strip()
             remote_checksum_buf = io.BytesIO()
             logger.info("Checking %s." % (checksum_url))
-            suricata.update.net.get(checksum_url, remote_checksum_buf)
+            net.get(checksum_url, remote_checksum_buf)
             remote_checksum = remote_checksum_buf.getvalue().decode().strip()
             logger.debug("Local checksum=|%s|; remote checksum=|%s|" % (
                 local_checksum, remote_checksum))
@@ -382,7 +384,7 @@ class Fetch:
         logger.info("Fetching %s." % (url))
         try:
             tmp_fileobj = tempfile.NamedTemporaryFile()
-            suricata.update.net.get(
+            net.get(
                 net_arg,
                 tmp_fileobj,
                 progress_hook=self.progress_hook)
@@ -597,7 +599,7 @@ def write_merged(filename, rulemap):
 
         oldset = {}
         if os.path.exists(filename):
-            for rule in suricata.update.rule.parse_file(filename):
+            for rule in rule_mod.parse_file(filename):
                 oldset[rule.id] = True
                 if not rule.id in rulemap:
                     removed.append(rule)
@@ -636,7 +638,7 @@ def write_to_directory(directory, files, rulemap):
                 directory, os.path.basename(filename))
 
             if os.path.exists(outpath):
-                for rule in suricata.update.rule.parse_file(outpath):
+                for rule in rule_mod.parse_file(outpath):
                     oldset[rule.id] = True
                     if not rule.id in rulemap:
                         removed.append(rule)
@@ -665,7 +667,7 @@ def write_to_directory(directory, files, rulemap):
         else:
             content = []
             for line in io.StringIO(files[filename].decode("utf-8")):
-                rule = suricata.update.rule.parse(line)
+                rule = rule_mod.parse(line)
                 if not rule:
                     content.append(line.strip())
                 else:
@@ -690,11 +692,11 @@ def write_sid_msg_map(filename, rulemap, version=1):
         for key in rulemap:
             rule = rulemap[key]
             if version == 2:
-                formatted = suricata.update.rule.format_sidmsgmap_v2(rule)
+                formatted = rule_mod.format_sidmsgmap_v2(rule)
                 if formatted:
                     print(formatted, file=fileobj)
             else:
-                formatted = suricata.update.rule.format_sidmsgmap(rule)
+                formatted = rule_mod.format_sidmsgmap(rule)
                 if formatted:
                     print(formatted, file=fileobj)
 
@@ -732,7 +734,7 @@ def dump_sample_configs():
             shutil.copy(os.path.join(configs.directory, filename), filename)
 
 def resolve_flowbits(rulemap, disabled_rules):
-    flowbit_resolver = suricata.update.rule.FlowbitResolver()
+    flowbit_resolver = rule_mod.FlowbitResolver()
     flowbit_enabled = set()
     while True:
         flowbits = flowbit_resolver.get_required_flowbits(rulemap)
@@ -854,32 +856,28 @@ def check_vars(suriconf, rulemap):
     for rule_id in rulemap:
         rule = rulemap[rule_id]
         disable = False
-        for var in suricata.update.rule.parse_var_names(
-                rule["source_addr"]):
+        for var in rule_mod.parse_var_names(rule["source_addr"]):
             if not suriconf.has_key("vars.address-groups.%s" % (var)):
                 logger.warning(
                     "Rule has unknown source address var and will be disabled: %s: %s" % (
                         var, rule.brief()))
                 notes.address_group_vars.add(var)
                 disable = True
-        for var in suricata.update.rule.parse_var_names(
-                rule["dest_addr"]):
+        for var in rule_mod.parse_var_names(rule["dest_addr"]):
             if not suriconf.has_key("vars.address-groups.%s" % (var)):
                 logger.warning(
                     "Rule has unknown dest address var and will be disabled: %s: %s" % (
                         var, rule.brief()))
                 notes.address_group_vars.add(var)
                 disable = True
-        for var in suricata.update.rule.parse_var_names(
-                rule["source_port"]):
+        for var in rule_mod.parse_var_names(rule["source_port"]):
             if not suriconf.has_key("vars.port-groups.%s" % (var)):
                 logger.warning(
                     "Rule has unknown source port var and will be disabled: %s: %s" % (
                         var, rule.brief()))
                 notes.port_group_vars.add(var)
                 disable = True
-        for var in suricata.update.rule.parse_var_names(
-                rule["dest_port"]):
+        for var in rule_mod.parse_var_names(rule["dest_port"]):
             if not suriconf.has_key("vars.port-groups.%s" % (var)):
                 logger.warning(
                     "Rule has unknown dest port var and will be disabled: %s: %s" % (
@@ -917,14 +915,14 @@ def test_suricata(suricata_path):
         logger.info("Testing with suricata -T.")
         suricata_conf = config.get("suricata-conf")
         if not config.get("no-merge"):
-            if not suricata.update.engine.test_configuration(
+            if not engine.test_configuration(
                     suricata_path, suricata_conf,
                     os.path.join(
-                        config.get_output_dir(), DEFAULT_OUTPUT_RULE_FILENAME)):
+                        config.get_output_dir(),
+                        DEFAULT_OUTPUT_RULE_FILENAME)):
                 return False
         else:
-            if not suricata.update.engine.test_configuration(
-                    suricata_path, suricata_conf):
+            if not engine.test_configuration(suricata_path, suricata_conf):
                 return False
 
     return True
@@ -1300,15 +1298,14 @@ def _main():
     # use that, otherwise attempt to get it from Suricata.
     if args.suricata_version:
         # The Suricata version was passed on the command line, parse it.
-        suricata_version = suricata.update.engine.parse_version(
-            args.suricata_version)
+        suricata_version = engine.parse_version(args.suricata_version)
         if not suricata_version:
             logger.error("Failed to parse provided Suricata version: %s" % (
                 args.suricata_version))
             return 1
         logger.info("Forcing Suricata version to %s." % (suricata_version.full))
     elif suricata_path:
-        suricata_version = suricata.update.engine.get_version(suricata_path)
+        suricata_version = engine.get_version(suricata_path)
         if suricata_version:
             logger.info("Found Suricata version %s at %s." % (
                 str(suricata_version.full), suricata_path))
@@ -1318,12 +1315,11 @@ def _main():
     else:
         logger.info(
             "Using default Suricata version of %s", DEFAULT_SURICATA_VERSION)
-        suricata_version = suricata.update.engine.parse_version(
-            DEFAULT_SURICATA_VERSION)
+        suricata_version = engine.parse_version(DEFAULT_SURICATA_VERSION)
 
     # Provide the Suricata version to the net module to add to the
     # User-Agent.
-    suricata.update.net.set_user_agent_suricata_version(suricata_version.full)
+    net.set_user_agent_suricata_version(suricata_version.full)
 
     if args.subcommand:
         if args.subcommand == "check-versions" and hasattr(args, "func"):
@@ -1380,7 +1376,7 @@ def _main():
        suricata_path and os.path.exists(suricata_path):
         logger.info("Loading %s",config.get("suricata-conf"))
         try:
-            suriconf = suricata.update.engine.Configuration.load(
+            suriconf = engine.Configuration.load(
                 config.get("suricata-conf"), suricata_path=suricata_path)
         except subprocess.CalledProcessError:
             return 1
@@ -1420,8 +1416,7 @@ def _main():
         if not filename.endswith(".rules"):
             continue
         logger.debug("Parsing %s." % (filename))
-        rules += suricata.update.rule.parse_fileobj(
-            io.BytesIO(files[filename]), filename)
+        rules += rule_mod.parse_fileobj(io.BytesIO(files[filename]), filename)
 
     rulemap = build_rule_map(rules)
     logger.info("Loaded %d rules." % (len(rules)))