]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Test] Test for TCP library
authorMikhail Galanin <mgalanin@mimecast.com>
Thu, 30 Aug 2018 15:51:55 +0000 (16:51 +0100)
committerMikhail Galanin <mgalanin@mimecast.com>
Thu, 30 Aug 2018 15:51:55 +0000 (16:51 +0100)
test/functional/cases/220_http.robot
test/functional/cases/230_tcp.robot [new file with mode: 0644]
test/functional/lua/http.lua
test/functional/lua/tcp.lua [new file with mode: 0644]
test/functional/util/dummy_http.py

index a8f47faa81f72125aab05fda62fbb42ebb653cf1..b1ac67bd91e85d757e002ef8ac779fce864c1161 100644 (file)
@@ -11,8 +11,7 @@ Variables       ${TESTDIR}/lib/vars.py
 ${URL_TLD}      ${TESTDIR}/../lua/unit/test_tld.dat
 ${CONFIG}       ${TESTDIR}/configs/lua_test.conf
 ${MESSAGE}      ${TESTDIR}/messages/spam_message.eml
-${REDIS_SCOPE}  Suite
-${RSPAMD_SCOPE}  Suite
+${RSPAMD_SCOPE}  Test
 
 *** Test Cases ***
 Simple HTTP request
diff --git a/test/functional/cases/230_tcp.robot b/test/functional/cases/230_tcp.robot
new file mode 100644 (file)
index 0000000..4d8b2fb
--- /dev/null
@@ -0,0 +1,63 @@
+*** Settings ***
+Test Setup      Http Setup
+Test Teardown   Http Teardown
+Library         Process
+Library         ${TESTDIR}/lib/rspamd.py
+Resource        ${TESTDIR}/lib/rspamd.robot
+Variables       ${TESTDIR}/lib/vars.py
+
+*** Variables ***
+# ${CONFIG}       ${TESTDIR}/configs/http.conf
+${URL_TLD}      ${TESTDIR}/../lua/unit/test_tld.dat
+${CONFIG}       ${TESTDIR}/configs/lua_test.conf
+${MESSAGE}      ${TESTDIR}/messages/spam_message.eml
+${RSPAMD_SCOPE}  Test
+
+*** Test Cases ***
+Simple TCP request
+  ${result} =  Scan Message With Rspamc  ${MESSAGE}
+  Check Rspamc  ${result}  HTTP_ASYNC_RESPONSE
+  Check Rspamc  ${result}  HTTP_ASYNC_RESPONSE_2
+
+
+Sync API TCP request
+  ${result} =  Scan Message With Rspamc  ${MESSAGE}
+  Check Rspamc  ${result}  HTTP_SYNC_RESPONSE
+  Check Rspamc  ${result}  HTTP_SYNC_RESPONSE_2
+  Check Rspamc  ${result}  hello world
+  Check Rspamc  ${result}  hello post
+
+Sync API TCP get request
+  Check url  /request  get  HTTP_SYNC_EOF_get (0.00)[hello world]
+  Check url  /content-length  get  HTTP_SYNC_CONTENT_get (0.00)[hello world]
+
+Sync API TCP post request
+  Check url  /request  post  HTTP_SYNC_EOF_post (0.00)[hello post]
+  Check url  /content-length  post  HTTP_SYNC_CONTENT_post (0.00)[hello post]
+
+*** Keywords ***
+Lua Setup
+  [Arguments]  ${LUA_SCRIPT}
+  Set Global Variable  ${LUA_SCRIPT}
+  Generic Setup
+
+Http Setup
+  Run Dummy Http
+  Lua Setup  ${TESTDIR}/lua/tcp.lua
+
+Http Teardown
+  ${http_pid} =  Get File  /tmp/dummy_http.pid
+  Shutdown Process With Children  ${http_pid}
+  Normal Teardown
+
+Run Dummy Http
+  [Arguments]
+  ${result} =  Start Process  ${TESTDIR}/util/dummy_http.py
+  Wait Until Created  /tmp/dummy_http.pid
+
+
+Check url
+  [Arguments]  ${url}  ${method}  @{expect_results}
+  ${result} =  Scan Message With Rspamc  --header=url:${url}  --header=method:${method}  ${MESSAGE}
+  : FOR  ${expect}  IN  @{expect_results}
+  \  Check Rspamc  ${result}  ${expect}
index 44a6c6fd3bf53a663fb63081ae0b9b257485175a..0c1eff8ba072629ec6b83a0dfceaca85f74433e6 100644 (file)
@@ -43,12 +43,15 @@ local function http_symbol(task)
     timeout = 1,
   })
 
+  rspamd_logger.errx(task, 'rspamd_http.request[before]')
+
   local err, response = rspamd_http.request({
     url = 'http://127.0.0.1:18080' .. url,
     task = task,
     method = method,
     timeout = 1,
   })
+  rspamd_logger.errx(task, 'rspamd_http.request[done] err: %1 response:%2', err, response)
 
   if not err then
     task:insert_result('HTTP_CORO_' .. response.code, 1.0, response.content)
diff --git a/test/functional/lua/tcp.lua b/test/functional/lua/tcp.lua
new file mode 100644 (file)
index 0000000..5b0f474
--- /dev/null
@@ -0,0 +1,170 @@
+--[[[
+-- Just a test for TCP API
+--]]
+
+local rspamd_tcp = require "rspamd_tcp"
+local logger = require "rspamd_logger"
+local tcp_sync = require "lua_tcp_sync"
+
+-- [[ old fashioned callback api ]]
+local function http_simple_tcp_async_symbol(task)
+  logger.errx(task, 'http_tcp_symbol: begin')
+  local function http_get_cb(err, data, conn)
+    logger.errx(task, 'http_get_cb: got reply: %s, error: %s, conn: %s', data, err, conn)
+    task:insert_result('HTTP_ASYNC_RESPONSE_2', 1.0, data)
+  end
+  local function http_read_post_cb(err, conn)
+    logger.errx(task, 'http_read_post_cb: write done: error: %s, conn: %s', err, conn)
+    conn:add_read(http_get_cb)
+  end
+  local function http_read_cb(err, data, conn)
+    logger.errx(task, 'http_read_cb: got reply: %s, error: %s, conn: %s', data, err, conn)
+    conn:add_write(http_read_post_cb, "POST /request2 HTTP/1.1\r\n\r\n")
+    task:insert_result('HTTP_ASYNC_RESPONSE', 1.0, data)
+  end
+  rspamd_tcp:request({
+    task = task,
+    callback = http_read_cb,
+    host = '127.0.0.1',
+    data = {'GET /request HTTP/1.1\r\nConnection: keep-alive\r\n\r\n'},
+    read = true,
+    port = 18080,
+  })
+end
+
+local function http_simple_tcp_symbol(task)
+  logger.errx(task, 'connect_sync, before')
+
+  local err
+  local is_ok, connection = tcp_sync.connect {
+    task = task,
+    host = '127.0.0.1',
+    timeout = 20,
+    port = 18080,
+  }
+
+  logger.errx(task, 'connect_sync %1, %2', is_ok, tostring(connection))
+
+  is_ok, err = connection:write('GET /request_sync HTTP/1.1\r\nConnection: keep-alive\r\n\r\n')
+
+  logger.errx(task, 'write %1, %2', is_ok, err)
+  if not is_ok then
+    task:insert_result('HTTP_SYNC_WRITE_ERROR', 1.0, err)
+    logger.errx(task, 'write error: %1', err)
+  end
+
+  local data
+  is_ok, data = connection:read_once();
+
+  logger.errx(task, 'read_once: is_ok: %1, data: %2', is_ok, data)
+  if not is_ok then
+    task:insert_result('HTTP_SYNC_ERROR', 1.0, data)
+  else
+    task:insert_result('HTTP_SYNC_RESPONSE', 1.0, data)
+  end
+
+  is_ok, err = connection:write("POST /request2 HTTP/1.1\r\n\r\n")
+  logger.errx(task, 'write[2] %1, %2', is_ok, err)
+
+  is_ok, data = connection:read_once();
+  logger.errx(task, 'read_once[2]: is_ok %1, data: %2', is_ok, data)
+  if not is_ok then
+    task:insert_result('HTTP_SYNC_ERROR_2', 1.0, data)
+  else
+    task:insert_result('HTTP_SYNC_RESPONSE_2', 1.0, data)
+  end
+
+  connection:close()
+end
+
+local function http_tcp_symbol(task)
+  local url = tostring(task:get_request_header('url'))
+  local method = tostring(task:get_request_header('method'))
+
+  if url == 'nil' then
+    return
+  end
+
+  local err
+  local is_ok, connection = tcp_sync.connect {
+    task = task,
+    host = '127.0.0.1',
+    timeout = 20,
+    port = 18080,
+  }
+
+  logger.errx(task, 'connect_sync %1, %2', is_ok, tostring(connection))
+  if not is_ok then
+    logger.errx(task, 'connect error: %1', connection)
+    return
+  end
+
+  is_ok, err = connection:write(string.format('%s %s HTTP/1.1\r\nConnection: close\r\n\r\n', method:upper(), url))
+
+  logger.errx(task, 'write %1, %2', is_ok, err)
+  if not is_ok then
+    logger.errx(task, 'write error: %1', err)
+    return
+  end
+
+  local content_length, content
+
+  while true do
+    local header_line
+    is_ok, header_line = connection:read_until("\r\n")
+    if not is_ok then
+      logger.errx(task, 'failed to get header: %1', header_line)
+      return
+    end
+
+    if header_line == "" then
+      logger.errx(task, 'headers done')
+      break
+    end
+
+    local value
+    local header = header_line:gsub("([%w-]+): (.*)", 
+        function (h, v) value = v; return h:lower() end)
+
+    logger.errx(task, 'parsed header: %1 -> "%2"', header, value)
+
+    if header == "content-length" then
+      content_length = tonumber(value)
+    end
+
+  end
+
+  if content_length then
+    is_ok, content = connection:read_bytes(content_length)
+    if is_ok then
+      task:insert_result('HTTP_SYNC_CONTENT_' .. method, 1.0, content)
+    end
+  else
+    is_ok, content = connection:read_until_eof()
+    if is_ok then
+      task:insert_result('HTTP_SYNC_EOF_' .. method, 1.0, content)
+    end
+  end
+  logger.errx(task, '(is_ok: %1) content [%2 bytes] %3', is_ok, content_length, content)
+end
+
+rspamd_config:register_symbol({
+  name = 'SIMPLE_TCP_ASYNC_TEST',
+  score = 1.0,
+  callback = http_simple_tcp_async_symbol,
+  no_squeeze = true
+})
+rspamd_config:register_symbol({
+  name = 'SIMPLE_TCP_TEST',
+  score = 1.0,
+  callback = http_simple_tcp_symbol,
+  no_squeeze = true
+})
+
+rspamd_config:register_symbol({
+  name = 'HTTP_TCP_TEST',
+  score = 1.0,
+  callback = http_tcp_symbol,
+  no_squeeze = true
+})
+-- ]]
index 4f8e67ffd1750c6a68d7f20d9b74b43d29107d61..4814613ea85b62b43aeb0df1e22727611be66302 100755 (executable)
@@ -1,10 +1,14 @@
 #!/usr/bin/env python
 
 import BaseHTTPServer
+import SocketServer
+import SimpleHTTPServer
+
 import time
 import os
 import sys
 import signal
+import socket
 
 PORT = 18080
 HOST_NAME = '127.0.0.1'
@@ -14,12 +18,19 @@ PID = "/tmp/dummy_http.pid"
 
 class MyHandler(BaseHTTPServer.BaseHTTPRequestHandler):
 
+    def setup(self):
+        BaseHTTPServer.BaseHTTPRequestHandler.setup(self)
+        self.protocol_version = "HTTP/1.1" # allow connection: keep-alive
+
     def do_HEAD(self):
         self.send_response(200)
         self.send_header("Content-type", "text/html")
         self.end_headers()
+        self.log_message("to be closed: " + self.close_connection)
 
     def do_GET(self):
+        response = "hello world"
+        
         """Respond to a GET request."""
         if self.path == "/empty":
             self.finish()
@@ -33,11 +44,23 @@ class MyHandler(BaseHTTPServer.BaseHTTPRequestHandler):
         else:
             self.send_response(200)
 
+        if self.path == "/content-length":
+            self.send_header("Content-Length", str(len(response)))
+
         self.send_header("Content-type", "text/plain")
         self.end_headers()
-        self.wfile.write("hello world")
+        self.wfile.write(response)
+        self.log_message("to be closed: %d, headers: %s, conn:'%s'" % (self.close_connection, str(self.headers), self.headers.get('Connection', "").lower()))
+
+        conntype = self.headers.get('Connection', "").lower()
+        if conntype != 'keep-alive':
+            self.close_connection = True
+        
+        self.log_message("ka:'%s', pv:%s[%s]" % (str(conntype == 'keep-alive'), str(self.protocol_version >= "HTTP/1.1"), self.protocol_version))
+
 
     def do_POST(self):
+        response = "hello post"
         """Respond to a GET request."""
         if self.path == "/empty":
             self.finish()
@@ -50,30 +73,35 @@ class MyHandler(BaseHTTPServer.BaseHTTPRequestHandler):
             self.send_response(403)
         else:
             self.send_response(200)
+            
+        if self.path == "/content-length":
+            self.send_header("Content-Length", str(len(response)))
 
         self.send_header("Content-type", "text/plain")
         self.end_headers()
-        self.wfile.write("hello post")
-
+        self.wfile.write(response)
 
-class MyHttp(BaseHTTPServer.HTTPServer):
-    def __init__(self, server_address, RequestHandlerClass, bind_and_activate=False):
-        BaseHTTPServer.HTTPServer.__init__(self, server_address, RequestHandlerClass, bind_and_activate)
-        self.keep_running = True
 
+class ThreadingSimpleServer(SocketServer.ThreadingMixIn,
+                   BaseHTTPServer.HTTPServer):
+    def __init__(self):
+        BaseHTTPServer.HTTPServer.__init__(self, (HOST_NAME, PORT), MyHandler)
+        self.allow_reuse_address = True
+        self.timeout = 1
+        
     def run(self):
-        self.server_bind()
-        self.server_activate()
-
         with open(PID, 'w+') as f:
             f.write(str(os.getpid()))
             f.close()
-
-        while self.keep_running:
-            try:
-                self.handle_request()
-            except Exception:
-                pass
+        try:
+            while 1:
+                sys.stdout.flush()
+                server.handle_request()
+        except KeyboardInterrupt:
+            print "Interrupt"
+        except socket.error:
+            print "Socket closed"
+            pass
 
     def stop(self):
         self.keep_running = False
@@ -81,20 +109,13 @@ class MyHttp(BaseHTTPServer.HTTPServer):
 
 
 if __name__ == '__main__':
-    server_class = BaseHTTPServer.HTTPServer
-    httpd = MyHttp((HOST_NAME, PORT), MyHandler)
-    httpd.allow_reuse_address = True
-    httpd.timeout = 1
+    server = ThreadingSimpleServer()
 
     def alarm_handler(signum, frame):
-        httpd.stop()
+        server.stop()
 
     signal.signal(signal.SIGALRM, alarm_handler)
     signal.signal(signal.SIGTERM, alarm_handler)
-    signal.alarm(10)
+    signal.alarm(1000)
 
-    try:
-        httpd.run()
-    except KeyboardInterrupt:
-        pass
-    httpd.server_close()
+    server.run()