]> git.ipfire.org Git - thirdparty/suricata-update.git/commitdiff
Use header in the request for sources
authorShivani Bhardwaj <shivanib134@gmail.com>
Wed, 5 Dec 2018 19:04:26 +0000 (00:34 +0530)
committerJason Ish <ish@unx.ca>
Fri, 14 Dec 2018 17:57:43 +0000 (11:57 -0600)
Allows the use of header while requesting for a source.

Alongwith 75f6c327, closes redmine ticket #2577

suricata/update/commands/addsource.py
suricata/update/main.py
suricata/update/net.py
suricata/update/sources.py

index 6dbf400f8c45430aa7bddca976e2bc69ecd4318e..42d4c638cc47a8c07a93bc30cb0006f16f9e58fa 100644 (file)
@@ -32,12 +32,12 @@ def register(parser):
     parser.add_argument("name", metavar="<name>", nargs="?",
                         help="Name of source")
     parser.add_argument("url", metavar="<url>", nargs="?", help="Source URL")
-    parser.add_argument("--header", metavar="<header>", help="HTTP Header")
+    parser.add_argument("--http-header", metavar="<http-header>",
+                        help="Additional HTTP header to add to requests")
     parser.set_defaults(func=add_source)
 
 def add_source():
     args = config.args()
-    header = None
 
     if args.name:
         name = args.name
@@ -59,8 +59,7 @@ def add_source():
             if url:
                 break
 
-    if args.header:
-        header = args.header
+    header = args.http_header if args.http_header else None
 
     source_config = sources.SourceConfiguration(name, header=header, url=url)
     sources.save_source_config(source_config)
index fe6e3332ae2c22ae1ceda9fd78b9717aa7dbb507..e048c679bf24d581eb5c55b64221219d680a6880 100644 (file)
@@ -356,6 +356,8 @@ class Fetch:
             "%s-%s" % (url_hash, self.url_basename(url)))
 
     def fetch(self, url):
+        net_arg = url
+        url = url[0] if isinstance(url, tuple) else url
         tmp_filename = self.get_tmp_filename(url)
         if not config.args().force and os.path.exists(tmp_filename):
             if not config.args().now and \
@@ -373,7 +375,7 @@ class Fetch:
         try:
             tmp_fileobj = tempfile.NamedTemporaryFile()
             suricata.update.net.get(
-                url,
+                net_arg,
                 tmp_fileobj,
                 progress_hook=self.progress_hook)
             shutil.copyfile(tmp_fileobj.name, tmp_filename)
@@ -399,6 +401,7 @@ class Fetch:
                 fetched = self.fetch(url)
                 files.update(fetched)
             except URLError as err:
+                url = url[0] if isinstance(url, tuple) else url
                 logger.error("Failed to fetch %s: %s", url, err)
         else:
             for url in self.args.url:
@@ -963,14 +966,15 @@ def load_sources(suricata_version):
             params.update(internal_params)
             if "url" in source:
                 # No need to go off to the index.
-                url = source["url"] % params
+                url = (source["url"] % params, source.get("http-header"))
+                logger.debug("Resolved source %s to URL %s.", name, url[0])
             else:
                 if not index:
                     raise exceptions.ApplicationError(
                         "Source index is required for source %s; "
                         "run suricata-update update-sources" % (source["source"]))
                 url = index.resolve_url(name, params)
-            logger.debug("Resolved source %s to URL %s.", name, url)
+                logger.debug("Resolved source %s to URL %s.", name, url)
             urls.append(url)
 
     if config.get("sources"):
index 394bb10e997f5b3dcb6c6a4715cdfc02dfbb14fe..1392bc55514bac752ef093b433ac06df487ed1b5 100644 (file)
@@ -20,6 +20,7 @@
 import platform
 import logging
 import ssl
+import re
 
 try:
     # Python 3.3...
@@ -76,6 +77,16 @@ def build_user_agent():
     return "Suricata-Update/%s (%s)" % (
         version, "; ".join(params))
 
+
+def is_header_clean(header):
+    if len(header) != 2:
+        return False
+    name, val = header[0].strip(), header[1].strip()
+    if re.match( r"^[\w-]+$", name) and re.match(r"^[\w-]+$", val):
+        return True
+    return False
+
+
 def get(url, fileobj, progress_hook=None):
     """ Perform a GET request against a URL writing the contents into
     the provideded file like object.
@@ -107,29 +118,43 @@ def get(url, fileobj, progress_hook=None):
 
     if user_agent:
         logger.debug("Setting HTTP User-Agent to %s", user_agent)
-        opener.addheaders = [("User-Agent", user_agent),]
+        http_headers = [("User-Agent", user_agent)]
     else:
-        opener.addheaders = [(header, value) for header,
-                             value in opener.addheaders if header.lower() != "user-agent"]
-    remote = opener.open(url)
-    info = remote.info()
+        http_headers = [(header, value) for header,
+                        value in opener.addheaders if header.lower() != "user-agent"]
+    if isinstance(url, tuple):
+        header = url[1].split(":") if url[1] is not None else None
+        if header and is_header_clean(header=header):
+            name, val = header[0].strip(), header[1].strip()
+            logger.debug("Setting HTTP header %s to %s", name, val)
+            http_headers.append((name, val))
+        elif header:
+            logger.error("Header not set as it does not meet the criteria")
+        url = url[0]
+    opener.addheaders = http_headers
+
     try:
-        content_length = int(info["content-length"])
-    except:
-        content_length = 0
-    bytes_read = 0
-    while True:
-        buf = remote.read(GET_BLOCK_SIZE)
-        if not buf:
-            # EOF
-            break
-        bytes_read += len(buf)
-        fileobj.write(buf)
-        if progress_hook:
-            progress_hook(content_length, bytes_read)
-    remote.close()
-    fileobj.flush()
-    return bytes_read, info
+        remote = opener.open(url)
+    except ValueError as ve:
+        logger.error(ve)
+    else:
+        info = remote.info()
+        content_length = info.get("content-length")
+        content_length = int(content_length) if content_length else 0
+        bytes_read = 0
+        while True:
+            buf = remote.read(GET_BLOCK_SIZE)
+            if not buf:
+                # EOF
+                break
+            bytes_read += len(buf)
+            fileobj.write(buf)
+            if progress_hook:
+                progress_hook(content_length, bytes_read)
+        remote.close()
+        fileobj.flush()
+        return bytes_read, info
+
 
 if __name__ == "__main__":
 
index d00193e89bad70218cb68b7bf7dd1b08ac34d122..87e90cf528a7419fe4192784144c4d61d1a42edc 100644 (file)
@@ -92,7 +92,7 @@ class SourceConfiguration:
         if self.params:
             d["params"] = self.params
         if self.header:
-            d["header"] = self.header
+            d["http-header"] = self.header
         return d
 
 class Index: