From: Vagisha Gupta Date: Fri, 13 Sep 2019 05:48:13 +0000 (+0530) Subject: Cleanup scattered main imports X-Git-Tag: 1.2.0rc1~31 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e9a691447cf18ee557a4634e276c58f545a02c51;p=thirdparty%2Fsuricata-update.git Cleanup scattered main imports Currently, lot of names of a module are imported from a package by writing multiple import statements in main.py. Instead, Python's standard grouping mechanism (parentheses) is used to write the import statement to make them compact. Redmine issue: https://redmine.openinfosecfoundation.org/issues/2872 --- diff --git a/suricata/update/main.py b/suricata/update/main.py index 1f91d2d..ecc8b7f 100644 --- a/suricata/update/main.py +++ b/suricata/update/main.py @@ -46,22 +46,20 @@ except: print("error: pyyaml is required") sys.exit(1) -if sys.argv[0] == __file__: - sys.path.insert( - 0, os.path.abspath(os.path.join(__file__, "..", "..", ".."))) - -import suricata.update.rule -import suricata.update.engine -import suricata.update.net -import suricata.update.loghandler -from suricata.update import config -from suricata.update import configs -from suricata.update import extract -from suricata.update import util -from suricata.update import sources -from suricata.update import commands -from suricata.update import exceptions -from suricata.update import notes +from suricata.update import ( + commands, + config, + configs, + engine, + exceptions, + extract, + loghandler, + net, + notes, + rule as rule_mod, + sources, + util, +) from suricata.update.version import version try: @@ -69,10 +67,14 @@ try: except: revision = None +if sys.argv[0] == __file__: + sys.path.insert( + 0, os.path.abspath(os.path.join(__file__, "..", "..", ".."))) + # Initialize logging, use colour if on a tty. if len(logging.root.handlers) == 0: logger = logging.getLogger() - suricata.update.loghandler.configure_logging() + loghandler.configure_logging() logger.setLevel(level=logging.INFO) else: logging.basicConfig( @@ -256,7 +258,7 @@ class ModifyRuleFilter(object): def run(self, rule): modified_rule = self.pattern.sub(self.repl, rule.format()) - parsed = suricata.update.rule.parse(modified_rule, rule.group) + parsed = rule_mod.parse(modified_rule, rule.group) if parsed is None: logger.error("Modification of rule %s results in invalid rule: %s", rule.idstr, modified_rule) @@ -294,7 +296,7 @@ class DropRuleFilter(object): return self.matcher.match(rule) def run(self, rule): - drop_rule = suricata.update.rule.parse(re.sub("^\w+", "drop", rule.raw)) + drop_rule = rule_mod.parse(re.sub("^\w+", "drop", rule.raw)) drop_rule.enabled = rule.enabled return drop_rule @@ -310,7 +312,7 @@ class Fetch: open(tmp_filename, "rb").read()).hexdigest().strip() remote_checksum_buf = io.BytesIO() logger.info("Checking %s." % (checksum_url)) - suricata.update.net.get(checksum_url, remote_checksum_buf) + net.get(checksum_url, remote_checksum_buf) remote_checksum = remote_checksum_buf.getvalue().decode().strip() logger.debug("Local checksum=|%s|; remote checksum=|%s|" % ( local_checksum, remote_checksum)) @@ -382,7 +384,7 @@ class Fetch: logger.info("Fetching %s." % (url)) try: tmp_fileobj = tempfile.NamedTemporaryFile() - suricata.update.net.get( + net.get( net_arg, tmp_fileobj, progress_hook=self.progress_hook) @@ -597,7 +599,7 @@ def write_merged(filename, rulemap): oldset = {} if os.path.exists(filename): - for rule in suricata.update.rule.parse_file(filename): + for rule in rule_mod.parse_file(filename): oldset[rule.id] = True if not rule.id in rulemap: removed.append(rule) @@ -636,7 +638,7 @@ def write_to_directory(directory, files, rulemap): directory, os.path.basename(filename)) if os.path.exists(outpath): - for rule in suricata.update.rule.parse_file(outpath): + for rule in rule_mod.parse_file(outpath): oldset[rule.id] = True if not rule.id in rulemap: removed.append(rule) @@ -665,7 +667,7 @@ def write_to_directory(directory, files, rulemap): else: content = [] for line in io.StringIO(files[filename].decode("utf-8")): - rule = suricata.update.rule.parse(line) + rule = rule_mod.parse(line) if not rule: content.append(line.strip()) else: @@ -690,11 +692,11 @@ def write_sid_msg_map(filename, rulemap, version=1): for key in rulemap: rule = rulemap[key] if version == 2: - formatted = suricata.update.rule.format_sidmsgmap_v2(rule) + formatted = rule_mod.format_sidmsgmap_v2(rule) if formatted: print(formatted, file=fileobj) else: - formatted = suricata.update.rule.format_sidmsgmap(rule) + formatted = rule_mod.format_sidmsgmap(rule) if formatted: print(formatted, file=fileobj) @@ -732,7 +734,7 @@ def dump_sample_configs(): shutil.copy(os.path.join(configs.directory, filename), filename) def resolve_flowbits(rulemap, disabled_rules): - flowbit_resolver = suricata.update.rule.FlowbitResolver() + flowbit_resolver = rule_mod.FlowbitResolver() flowbit_enabled = set() while True: flowbits = flowbit_resolver.get_required_flowbits(rulemap) @@ -854,32 +856,28 @@ def check_vars(suriconf, rulemap): for rule_id in rulemap: rule = rulemap[rule_id] disable = False - for var in suricata.update.rule.parse_var_names( - rule["source_addr"]): + for var in rule_mod.parse_var_names(rule["source_addr"]): if not suriconf.has_key("vars.address-groups.%s" % (var)): logger.warning( "Rule has unknown source address var and will be disabled: %s: %s" % ( var, rule.brief())) notes.address_group_vars.add(var) disable = True - for var in suricata.update.rule.parse_var_names( - rule["dest_addr"]): + for var in rule_mod.parse_var_names(rule["dest_addr"]): if not suriconf.has_key("vars.address-groups.%s" % (var)): logger.warning( "Rule has unknown dest address var and will be disabled: %s: %s" % ( var, rule.brief())) notes.address_group_vars.add(var) disable = True - for var in suricata.update.rule.parse_var_names( - rule["source_port"]): + for var in rule_mod.parse_var_names(rule["source_port"]): if not suriconf.has_key("vars.port-groups.%s" % (var)): logger.warning( "Rule has unknown source port var and will be disabled: %s: %s" % ( var, rule.brief())) notes.port_group_vars.add(var) disable = True - for var in suricata.update.rule.parse_var_names( - rule["dest_port"]): + for var in rule_mod.parse_var_names(rule["dest_port"]): if not suriconf.has_key("vars.port-groups.%s" % (var)): logger.warning( "Rule has unknown dest port var and will be disabled: %s: %s" % ( @@ -917,14 +915,14 @@ def test_suricata(suricata_path): logger.info("Testing with suricata -T.") suricata_conf = config.get("suricata-conf") if not config.get("no-merge"): - if not suricata.update.engine.test_configuration( + if not engine.test_configuration( suricata_path, suricata_conf, os.path.join( - config.get_output_dir(), DEFAULT_OUTPUT_RULE_FILENAME)): + config.get_output_dir(), + DEFAULT_OUTPUT_RULE_FILENAME)): return False else: - if not suricata.update.engine.test_configuration( - suricata_path, suricata_conf): + if not engine.test_configuration(suricata_path, suricata_conf): return False return True @@ -1300,15 +1298,14 @@ def _main(): # use that, otherwise attempt to get it from Suricata. if args.suricata_version: # The Suricata version was passed on the command line, parse it. - suricata_version = suricata.update.engine.parse_version( - args.suricata_version) + suricata_version = engine.parse_version(args.suricata_version) if not suricata_version: logger.error("Failed to parse provided Suricata version: %s" % ( args.suricata_version)) return 1 logger.info("Forcing Suricata version to %s." % (suricata_version.full)) elif suricata_path: - suricata_version = suricata.update.engine.get_version(suricata_path) + suricata_version = engine.get_version(suricata_path) if suricata_version: logger.info("Found Suricata version %s at %s." % ( str(suricata_version.full), suricata_path)) @@ -1318,12 +1315,11 @@ def _main(): else: logger.info( "Using default Suricata version of %s", DEFAULT_SURICATA_VERSION) - suricata_version = suricata.update.engine.parse_version( - DEFAULT_SURICATA_VERSION) + suricata_version = engine.parse_version(DEFAULT_SURICATA_VERSION) # Provide the Suricata version to the net module to add to the # User-Agent. - suricata.update.net.set_user_agent_suricata_version(suricata_version.full) + net.set_user_agent_suricata_version(suricata_version.full) if args.subcommand: if args.subcommand == "check-versions" and hasattr(args, "func"): @@ -1380,7 +1376,7 @@ def _main(): suricata_path and os.path.exists(suricata_path): logger.info("Loading %s",config.get("suricata-conf")) try: - suriconf = suricata.update.engine.Configuration.load( + suriconf = engine.Configuration.load( config.get("suricata-conf"), suricata_path=suricata_path) except subprocess.CalledProcessError: return 1 @@ -1420,8 +1416,7 @@ def _main(): if not filename.endswith(".rules"): continue logger.debug("Parsing %s." % (filename)) - rules += suricata.update.rule.parse_fileobj( - io.BytesIO(files[filename]), filename) + rules += rule_mod.parse_fileobj(io.BytesIO(files[filename]), filename) rulemap = build_rule_map(rules) logger.info("Loaded %d rules." % (len(rules)))