]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
Fix pcre post-filtering
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 9 Dec 2015 17:46:26 +0000 (17:46 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 9 Dec 2015 17:46:26 +0000 (17:46 +0000)
src/libserver/re_cache.c
src/lua/lua_regexp.c
src/plugins/lua/spamassassin.lua

index 4f4a1881c505aab997a630eeb9d4c392d6b46101..f38b254efb7726b3113c91303a00159160a50055 100644 (file)
@@ -61,7 +61,7 @@
 
 #ifdef WITH_HYPERSCAN
 #define RSPAMD_HS_MAGIC_LEN (sizeof (rspamd_hs_magic))
-static const guchar rspamd_hs_magic[] = {'r', 's', 'h', 's', 'r', 'e', '1', '0'};
+static const guchar rspamd_hs_magic[] = {'r', 's', 'h', 's', 'r', 'e', '1', '1'};
 #endif
 
 struct rspamd_re_class {
@@ -462,27 +462,23 @@ rspamd_re_cache_hyperscan_cb (unsigned int id,
        rt = cbdata->rt;
 
        pcre_elt = g_ptr_array_index (rt->cache->re, id);
-
-       if (flags & HS_FLAG_PREFILTER) {
-               if (!isset (rt->checked, id)) {
-                       /* We need to match the corresponding pcre first */
-                       ret = rspamd_re_cache_process_pcre (rt,
-                                       pcre_elt->re,
-                                       cbdata->in + from,
-                                       to - from,
-                                       FALSE);
-
-                       setbit (rt->checked, id);
-                       rt->results[id] = ret;
-               }
+       maxhits = rspamd_regexp_get_maxhits (pcre_elt->re);
+       ret = 1;
+
+       if (pcre_elt->match_type == RSPAMD_RE_CACHE_HYPERSCAN_PRE) {
+               /* We need to match the corresponding pcre first */
+               ret = rspamd_re_cache_process_pcre (rt,
+                               pcre_elt->re,
+                               cbdata->in + from,
+                               to - from,
+                               FALSE);
+               msg_info ("pcre: %s", rspamd_regexp_get_pattern (pcre_elt->re));
        }
-       else {
-               maxhits = rspamd_regexp_get_maxhits (pcre_elt->re);
-               setbit (rt->checked, id);
 
-               if (maxhits == 0 || rt->results[id] < maxhits) {
-                       rt->results[id]++;
-               }
+       setbit (rt->checked, id);
+
+       if (maxhits == 0 || rt->results[id] < maxhits) {
+               rt->results[id] += ret;
        }
 
        return 0;
@@ -962,7 +958,7 @@ rspamd_re_cache_compile_hyperscan (struct rspamd_re_cache *cache,
        const gchar **hs_pats = NULL;
        gchar *hs_serialized;
        gsize serialized_len, total = 0;
-       struct iovec iov[6];
+       struct iovec iov[7];
 
        g_hash_table_iter_init (&it, cache->re_classes);
 
@@ -1074,7 +1070,6 @@ rspamd_re_cache_compile_hyperscan (struct rspamd_re_cache *cache,
                                return -1;
                        }
 
-                       g_free (hs_flags);
                        g_free (hs_pats);
 
                        if (hs_serialize_database (test_db, &hs_serialized,
@@ -1087,6 +1082,7 @@ rspamd_re_cache_compile_hyperscan (struct rspamd_re_cache *cache,
 
                                close (fd);
                                g_free (hs_ids);
+                               g_free (hs_flags);
                                hs_free_database (test_db);
 
                                return -1;
@@ -1099,6 +1095,7 @@ rspamd_re_cache_compile_hyperscan (struct rspamd_re_cache *cache,
                         * Platform - sizeof (platform)
                         * n - number of regexps
                         * n * <regexp ids>
+                        * n * <regexp flags>
                         * crc - 8 bytes checksum
                         * <hyperscan blob>
                         */
@@ -1111,10 +1108,12 @@ rspamd_re_cache_compile_hyperscan (struct rspamd_re_cache *cache,
                        iov[2].iov_len = sizeof (n);
                        iov[3].iov_base = hs_ids;
                        iov[3].iov_len = sizeof (*hs_ids) * n;
-                       iov[4].iov_base = &crc;
-                       iov[4].iov_len = sizeof (crc);
-                       iov[5].iov_base = hs_serialized;
-                       iov[5].iov_len = serialized_len;
+                       iov[4].iov_base = hs_flags;
+                       iov[4].iov_len = sizeof (*hs_flags) * n;
+                       iov[5].iov_base = &crc;
+                       iov[5].iov_len = sizeof (crc);
+                       iov[6].iov_base = hs_serialized;
+                       iov[6].iov_len = serialized_len;
 
                        if (writev (fd, iov, G_N_ELEMENTS (iov)) == -1) {
                                g_set_error (err,
@@ -1124,6 +1123,7 @@ rspamd_re_cache_compile_hyperscan (struct rspamd_re_cache *cache,
                                                path, strerror (errno));
                                close (fd);
                                g_free (hs_ids);
+                               g_free (hs_flags);
                                g_free (hs_serialized);
 
                                return -1;
@@ -1133,6 +1133,7 @@ rspamd_re_cache_compile_hyperscan (struct rspamd_re_cache *cache,
 
                        g_free (hs_serialized);
                        g_free (hs_ids);
+                       g_free (hs_flags);
                }
 
                close (fd);
@@ -1249,7 +1250,7 @@ rspamd_re_cache_load_hyperscan (struct rspamd_re_cache *cache,
        return FALSE;
 #else
        gchar path[PATH_MAX];
-       gint fd, i, n, *hs_ids = NULL, total = 0;
+       gint fd, i, n, *hs_ids = NULL, *hs_flags = NULL, total = 0;
        GHashTableIter it;
        gpointer k, v;
        guint8 *map, *p, *end;
@@ -1287,7 +1288,7 @@ rspamd_re_cache_load_hyperscan (struct rspamd_re_cache *cache,
                        p = map + RSPAMD_HS_MAGIC_LEN + sizeof (cache->plt);
                        n = *(gint *)p;
 
-                       if (n <= 0 || n * sizeof (gint) + /* IDs */
+                       if (n <= 0 || 2 * n * sizeof (gint) + /* IDs + flags */
                                                        sizeof (guint64) + /* crc */
                                                        RSPAMD_HS_MAGIC_LEN + /* header */
                                                        sizeof (cache->plt) > (gsize)st.st_size) {
@@ -1302,6 +1303,9 @@ rspamd_re_cache_load_hyperscan (struct rspamd_re_cache *cache,
                        p += sizeof (n);
                        hs_ids = g_malloc (n * sizeof (*hs_ids));
                        memcpy (hs_ids, p, n * sizeof (*hs_ids));
+                       p += n * sizeof (*hs_ids);
+                       hs_flags = g_malloc (n * sizeof (*hs_flags));
+                       memcpy (hs_flags, p, n * sizeof (*hs_flags));
 
                        /* Skip crc */
                        p += n * sizeof (*hs_ids) + sizeof (guint64);
@@ -1311,6 +1315,7 @@ rspamd_re_cache_load_hyperscan (struct rspamd_re_cache *cache,
                                msg_err_re_cache ("bad hs database in %s", path);
                                munmap (map, st.st_size);
                                g_free (hs_ids);
+                               g_free (hs_flags);
 
                                return FALSE;
                        }
@@ -1327,10 +1332,17 @@ rspamd_re_cache_load_hyperscan (struct rspamd_re_cache *cache,
                        for (i = 0; i < n; i ++) {
                                g_assert ((gint)cache->re->len > hs_ids[i] && hs_ids[i] >= 0);
                                elt = g_ptr_array_index (cache->re, hs_ids[i]);
-                               elt->match_type = RSPAMD_RE_CACHE_HYPERSCAN;
+
+                               if (hs_flags[i] & HS_FLAG_PREFILTER) {
+                                       elt->match_type = RSPAMD_RE_CACHE_HYPERSCAN_PRE;
+                               }
+                               else {
+                                       elt->match_type = RSPAMD_RE_CACHE_HYPERSCAN;
+                               }
                        }
 
                        re_class->hs_ids = hs_ids;
+                       g_free (hs_flags);
                        re_class->nhs = n;
                }
                else {
index d19ed83e6f84bf2a552c9798945ced1d22e1361b..385e974cc834d394e591eb345139def6b16095c2 100644 (file)
@@ -43,6 +43,7 @@ LUA_FUNCTION_DEF (regexp, get_cached);
 LUA_FUNCTION_DEF (regexp, get_pattern);
 LUA_FUNCTION_DEF (regexp, set_limit);
 LUA_FUNCTION_DEF (regexp, set_max_hits);
+LUA_FUNCTION_DEF (regexp, get_max_hits);
 LUA_FUNCTION_DEF (regexp, search);
 LUA_FUNCTION_DEF (regexp, match);
 LUA_FUNCTION_DEF (regexp, matchn);
@@ -54,6 +55,7 @@ static const struct luaL_reg regexplib_m[] = {
        LUA_INTERFACE_DEF (regexp, get_pattern),
        LUA_INTERFACE_DEF (regexp, set_limit),
        LUA_INTERFACE_DEF (regexp, set_max_hits),
+       LUA_INTERFACE_DEF (regexp, get_max_hits),
        LUA_INTERFACE_DEF (regexp, match),
        LUA_INTERFACE_DEF (regexp, matchn),
        LUA_INTERFACE_DEF (regexp, search),
@@ -271,7 +273,8 @@ lua_regexp_set_limit (lua_State *L)
 /***
  * @method re:set_max_hits(lim)
  * Set maximum number of hits returned by a regexp
- * @param {number} lim limit in bytes
+ * @param {number} lim limit in hits count
+ * @return {number} old number of max hits
  */
 static int
 lua_regexp_set_max_hits (lua_State *L)
@@ -291,6 +294,26 @@ lua_regexp_set_max_hits (lua_State *L)
        return 1;
 }
 
+/***
+ * @method re:get_max_hits(lim)
+ * Get maximum number of hits returned by a regexp
+ * @return {number} number of max hits
+ */
+static int
+lua_regexp_get_max_hits (lua_State *L)
+{
+       struct rspamd_lua_regexp *re = lua_check_regexp (L);
+
+       if (re && re->re && !IS_DESTROYED (re)) {
+               lua_pushnumber (L, rspamd_regexp_get_maxhits (re->re));
+       }
+       else {
+               lua_pushnumber (L, 1);
+       }
+
+       return 1;
+}
+
 /***
  * @method re:search(line[, raw[, capture]])
  * Search line in regular expression object. If line matches then this
index 7fbf00942ea8496037d93e95e498d0a548cf7f38..e2708bbeb3ae719bf9d5b25fff47dc56ecfd1937 100644 (file)
@@ -940,7 +940,7 @@ _.each(function(r)
         rspamd_logger.errx(rspamd_config, 'cannot apply replacement for rule %1', r)
         rule['re'] = nil
       else
-        local old_max_hits = rule['re']:set_limit(0)
+        local old_max_hits = rule['re']:get_max_hits()
         rspamd_logger.debugx(rspamd_config, 'replace %1 -> %2', r, nexpr)
         rspamd_config:replace_regexp({
           old_re = rule['re'],