]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Add symbol categories for MetaDefender and VirusTotal 5656/head
authorVsevolod Stakhov <vsevolod@rspamd.com>
Fri, 3 Oct 2025 14:43:27 +0000 (15:43 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Fri, 3 Oct 2025 14:43:27 +0000 (15:43 +0100)
Implemented a category-based symbol system for hash lookup antivirus
scanners (MetaDefender and VirusTotal) to replace dynamic scoring:

- Added 4 symbol categories: CLEAN (-0.5), LOW (2.0), MEDIUM (5.0), HIGH (8.0)
- Replaced full_score_engines with threshold-based categorization (low_category, medium_category)
- Fixed symbol registration in antivirus.lua to use rule instead of config
- Updated cache format to preserve symbol category across requests
- Added backward compatibility for old cache format
- Added symbols registration and metric score assignment
- Updated configuration documentation with examples

The new system provides:
- Clear threat categorization instead of linear interpolation
- Proper symbol weights applied automatically
- Consistent behavior between MetaDefender and VirusTotal
- Cache that preserves symbol categories

Configuration example:
metadefender {
  apikey = "KEY";
  type = "metadefender";
  minimum_engines = 3;
  low_category = 5;
  medium_category = 10;
}

conf/local.d/antivirus.conf.example
conf/modules.d/antivirus.conf
lualib/lua_scanners/common.lua
lualib/lua_scanners/metadefender.lua
lualib/lua_scanners/virustotal.lua
src/plugins/lua/antivirus.lua

index 50e0bdbfd86de2511e0a22e89ee86aefab68f37e..d0e1ad4d8f31e2771298ab0ca92523eee4f85c79 100644 (file)
@@ -4,44 +4,73 @@
 metadefender {
   # Required: Your MetaDefender API key from https://metadefender.opswat.com/
   apikey = "YOUR_API_KEY_HERE";
-  
-  # Symbol name (default: METADEFENDER_VIRUS)
-  symbol = "METADEFENDER_VIRUS";
-  
+
+  # Main symbol name (for compatibility, usually not used directly)
+  symbol = "METADEFENDER";
+
   # Scanner type - must be "metadefender"
   type = "metadefender";
-  
+
   # Scan MIME parts separately instead of full message (recommended: true)
   scan_mime_parts = true;
-  
+
   # Don't scan text or image MIME parts (saves API quota)
   scan_text_mime = false;
   scan_image_mime = false;
-  
+
   # Maximum file size to scan (20MB default)
   max_size = 20000000;
-  
+
   # Log when files are clean (default: false to reduce noise)
   log_clean = false;
-  
+
   # Minimum AV engines that must detect malware before flagging (default: 3)
   # Lower value = more sensitive, may have more false positives
   minimum_engines = 3;
-  
-  # Number of engines at which maximum score is assigned (default: 7)
-  # Scores scale linearly between minimum_engines and full_score_engines
-  full_score_engines = 7;
-  
+
+  # Threshold for low category (default: 5)
+  # Detections from minimum_engines to low_category-1 = LOW
+  low_category = 5;
+
+  # Threshold for medium category (default: 10)
+  # Detections from low_category to medium_category-1 = MEDIUM
+  # Detections >= medium_category = HIGH
+  medium_category = 10;
+
   # HTTP request timeout in seconds
   timeout = 5.0;
-  
+
   # Redis cache expiration (2 hours = 7200 seconds)
   # Longer cache reduces API calls but may miss new detections
   cache_expire = 7200;
-  
+
+  # Symbol categories with scores (can be customized)
+  symbols = {
+    clean = {
+      symbol = "METADEFENDER_CLEAN";
+      score = -0.5;
+      description = "MetaDefender decided attachment to be clean";
+    };
+    low = {
+      symbol = "METADEFENDER_LOW";
+      score = 2.0;
+      description = "MetaDefender found low number of threats (3-4 engines)";
+    };
+    medium = {
+      symbol = "METADEFENDER_MEDIUM";
+      score = 5.0;
+      description = "MetaDefender found medium number of threats (5-9 engines)";
+    };
+    high = {
+      symbol = "METADEFENDER_HIGH";
+      score = 8.0;
+      description = "MetaDefender found high number of threats (10+ engines)";
+    };
+  }
+
   # Optional: Force an action when malware is detected
   # action = "reject";
-  
+
   # Optional: Custom message template
   # message = '${SCANNER}: virus found: "${VIRUS}"';
 }
index 2912f475f5ec4b2e578b9c6f8fb408ea80d85b16..ebe38b8cd7504c88a25fed962da2cf42839fd901 100644 (file)
@@ -59,8 +59,8 @@ antivirus {
     #
     # If `max_size` is set, messages > n bytes in size are not scanned
     #max_size = 20000000;
-    # symbol to add
-    #symbol = "METADEFENDER_VIRUS";
+    # Main symbol (for compatibility, usually not used directly)
+    #symbol = "METADEFENDER";
     # type of scanner
     #type = "metadefender";
     # Your MetaDefender API key (required)
@@ -71,12 +71,88 @@ antivirus {
     #log_clean = false;
     # Minimum number of engines detecting malware for a hit (default 3)
     #minimum_engines = 3;
-    # Number of engines at which we assign full score (default 7)
-    #full_score_engines = 7;
+    # Threshold for low category (default 5)
+    #low_category = 5;
+    # Threshold for medium category (default 10)
+    #medium_category = 10;
     # Request timeout
     #timeout = 5.0;
     # Redis cache expiration time in seconds (default 7200 = 2 hours)
     #cache_expire = 7200;
+    # Symbol categories with scores (can be overridden)
+    #symbols = {
+    #  clean = {
+    #    symbol = "METADEFENDER_CLEAN";
+    #    score = -0.5;
+    #    description = "MetaDefender decided attachment to be clean";
+    #  };
+    #  low = {
+    #    symbol = "METADEFENDER_LOW";
+    #    score = 2.0;
+    #    description = "MetaDefender found low number of threats";
+    #  };
+    #  medium = {
+    #    symbol = "METADEFENDER_MEDIUM";
+    #    score = 5.0;
+    #    description = "MetaDefender found medium number of threats";
+    #  };
+    #  high = {
+    #    symbol = "METADEFENDER_HIGH";
+    #    score = 8.0;
+    #    description = "MetaDefender found high number of threats";
+    #  };
+    #}
+  #}
+
+  #virustotal {
+    # VirusTotal API (hash lookup)
+    # Get your API key at https://www.virustotal.com/
+    #
+    # If `max_size` is set, messages > n bytes in size are not scanned
+    #max_size = 20000000;
+    # Main symbol (for compatibility, usually not used directly)
+    #symbol = "VIRUSTOTAL";
+    # type of scanner
+    #type = "virustotal";
+    # Your VirusTotal API key (required)
+    #apikey = "YOUR_API_KEY_HERE";
+    # Scan mime_parts separately (default true)
+    #scan_mime_parts = true;
+    # You can enable logging for clean messages
+    #log_clean = false;
+    # Minimum number of engines detecting malware for a hit (default 3)
+    #minimum_engines = 3;
+    # Threshold for low category (default 5)
+    #low_category = 5;
+    # Threshold for medium category (default 10)
+    #medium_category = 10;
+    # Request timeout
+    #timeout = 5.0;
+    # Redis cache expiration time in seconds (default 7200 = 2 hours)
+    #cache_expire = 7200;
+    # Symbol categories with scores (can be overridden)
+    #symbols = {
+    #  clean = {
+    #    symbol = "VIRUSTOTAL_CLEAN";
+    #    score = -0.5;
+    #    description = "VirusTotal decided attachment to be clean";
+    #  };
+    #  low = {
+    #    symbol = "VIRUSTOTAL_LOW";
+    #    score = 2.0;
+    #    description = "VirusTotal found low number of threats";
+    #  };
+    #  medium = {
+    #    symbol = "VIRUSTOTAL_MEDIUM";
+    #    score = 5.0;
+    #    description = "VirusTotal found medium number of threats";
+    #  };
+    #  high = {
+    #    symbol = "VIRUSTOTAL_HIGH";
+    #    score = 8.0;
+    #    description = "VirusTotal found high number of threats";
+    #  };
+    #}
   #}
 
   .include(try=true,priority=5) "${DBDIR}/dynamic/antivirus.conf"
index fa751f4f4782282d7bd75500a55c87db3dc06675..f5e760eec2d7f80cc5467cc30f0b804fb3360297 100644 (file)
@@ -13,7 +13,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
-]]--
+]] --
 
 --[[[
 -- @module lua_scanners_common
@@ -30,7 +30,6 @@ local fun = require "fun"
 local exports = {}
 
 local function log_clean(task, rule, msg)
-
   msg = msg or 'message or mime_part is clean'
 
   if rule.log_clean then
@@ -38,7 +37,6 @@ local function log_clean(task, rule, msg)
   else
     lua_util.debugm(rule.name, task, '%s: %s', rule.log_prefix, msg)
   end
-
 end
 
 local function match_patterns(default_sym, found, patterns, dyn_weight)
@@ -111,16 +109,15 @@ local function yield_result(task, rule, vname, dyn_weight, is_fail, maybe_part)
     else
       all_whitelisted = false
       rspamd_logger.infox(task, '%s: result - %s: "%s - score: %s"',
-          rule.log_prefix, threat_info, tm, symscore)
+        rule.log_prefix, threat_info, tm, symscore)
 
       if maybe_part and rule.show_attachments and maybe_part:get_filename() then
         local fname = maybe_part:get_filename()
         task:insert_result(symname, symscore, string.format("%s|%s",
-            tm, fname))
+          tm, fname))
       else
         task:insert_result(symname, symscore, tm)
       end
-
     end
   end
 
@@ -130,10 +127,10 @@ local function yield_result(task, rule, vname, dyn_weight, is_fail, maybe_part)
       flags = 'least'
     end
     task:set_pre_result(rule.action,
-        lua_util.template(rule.message or 'Rejected', {
-          SCANNER = rule.name,
-          VIRUS = threat_table,
-        }), rule.name, nil, nil, flags)
+      lua_util.template(rule.message or 'Rejected', {
+        SCANNER = rule.name,
+        VIRUS = threat_table,
+      }), rule.name, nil, nil, flags)
   end
 end
 
@@ -144,7 +141,7 @@ local function message_not_too_large(task, content, rule)
   end
   if #content > max_size then
     rspamd_logger.infox(task, "skip %s check as it is too large: %s (%s is allowed)",
-        rule.log_prefix, #content, max_size)
+      rule.log_prefix, #content, max_size)
     return false
   end
   return true
@@ -157,7 +154,7 @@ local function message_not_too_small(task, content, rule)
   end
   if #content < min_size then
     rspamd_logger.infox(task, "skip %s check as it is too small: %s (%s is allowed)",
-        rule.log_prefix, #content, min_size)
+      rule.log_prefix, #content, min_size)
     return false
   end
   return true
@@ -178,7 +175,7 @@ local function message_min_words(task, rule)
 
     if not text_part_above_limit then
       rspamd_logger.infox(task, '%s: #words in all text parts is below text_part_min_words limit: %s',
-          rule.log_prefix, rule.text_part_min_words)
+        rule.log_prefix, rule.text_part_min_words)
     end
 
     return text_part_above_limit
@@ -217,7 +214,6 @@ local function dynamic_scan(task, rule)
 end
 
 local function need_check(task, content, rule, digest, fn, maybe_part)
-
   local uncached = true
   local key = digest
 
@@ -231,19 +227,30 @@ local function need_check(task, content, rule, digest, fn, maybe_part)
       if threat_string[1] ~= 'OK' then
         if threat_string[1] == 'MACRO' then
           yield_result(task, rule, 'File contains macros',
-              0.0, 'macro', maybe_part)
+            0.0, 'macro', maybe_part)
         elseif threat_string[1] == 'ENCRYPTED' then
           yield_result(task, rule, 'File is encrypted',
-              0.0, 'encrypted', maybe_part)
+            0.0, 'encrypted', maybe_part)
         else
-          lua_util.debugm(rule.name, task, '%s: got cached threat result for %s: %s - score: %s',
+          -- Check if cached data contains symbol name (for category-based scanners)
+          -- Format: "SYMBOL_NAME\vdetails" or just "details"
+          if #threat_string >= 2 and rule.symbols then
+            -- New format with symbol name
+            local symbol_name = threat_string[1]
+            local details = threat_string[2]
+            lua_util.debugm(rule.name, task, '%s: got cached threat result for %s: %s - %s',
+              rule.log_prefix, key, symbol_name, details)
+            task:insert_result(symbol_name, 1.0, details)
+          else
+            -- Old format without symbol name
+            lua_util.debugm(rule.name, task, '%s: got cached threat result for %s: %s - score: %s',
               rule.log_prefix, key, threat_string[1], score)
-          yield_result(task, rule, threat_string, score, false, maybe_part)
+            yield_result(task, rule, threat_string, score, false, maybe_part)
+          end
         end
-
       else
         lua_util.debugm(rule.name, task, '%s: got cached negative result for %s: %s',
-            rule.log_prefix, key, threat_string[1])
+          rule.log_prefix, key, threat_string[1])
       end
       uncached = false
     else
@@ -262,31 +269,26 @@ local function need_check(task, content, rule, digest, fn, maybe_part)
         f_message_not_too_small and
         f_message_min_words and
         f_dynamic_scan then
-
       fn()
-
     end
-
   end
 
   if rule.redis_params and not rule.no_cache then
-
     key = rule.prefix .. key
 
     if lua_redis.redis_make_request(task,
-        rule.redis_params, -- connect params
-        key, -- hash key
-        false, -- is write
-        redis_av_cb, --callback
-        'GET', -- command
-        { key } -- arguments)
-    ) then
+          rule.redis_params, -- connect params
+          key,             -- hash key
+          false,           -- is write
+          redis_av_cb,     --callback
+          'GET',           -- command
+          { key }          -- arguments)
+        ) then
       return true
     end
   end
 
   return false
-
 end
 
 local function save_cache(task, digest, rule, to_save, dyn_weight, maybe_part)
@@ -299,10 +301,10 @@ local function save_cache(task, digest, rule, to_save, dyn_weight, maybe_part)
     -- Do nothing
     if err then
       rspamd_logger.errx(task, 'failed to save %s cache for %s -> "%s": %s',
-          rule.detection_category, to_save, key, err)
+        rule.detection_category, to_save, key, err)
     else
       lua_util.debugm(rule.name, task, '%s: saved cached result for %s: %s - score %s - ttl %s',
-          rule.log_prefix, key, to_save, dyn_weight, rule.cache_expire)
+        rule.log_prefix, key, to_save, dyn_weight, rule.cache_expire)
     end
   end
 
@@ -321,12 +323,12 @@ local function save_cache(task, digest, rule, to_save, dyn_weight, maybe_part)
     key = rule.prefix .. key
 
     lua_redis.redis_make_request(task,
-        rule.redis_params, -- connect params
-        key, -- hash key
-        true, -- is write
-        redis_set_cb, --callback
-        'SETEX', -- command
-        { key, rule.cache_expire or 0, value }
+      rule.redis_params,   -- connect params
+      key,                 -- hash key
+      true,                -- is write
+      redis_set_cb,        --callback
+      'SETEX',             -- command
+      { key, rule.cache_expire or 0, value }
     )
   end
 
@@ -396,7 +398,6 @@ local function gen_extension(fname)
 end
 
 local function check_parts_match(task, rule)
-
   local filter_func = function(p)
     local mtype, msubtype = p:get_type()
     local detected_ext = p:get_detected_ext()
@@ -434,7 +435,7 @@ local function check_parts_match(task, rule)
           return true
         elseif magic.ct and match_filter(task, rule, magic.ct, rule.mime_parts_filter_regex, 'regex') then
           lua_util.debugm(rule.name, task, '%s: regex detected libmagic content-type: %s',
-              rule.log_prefix, magic.ct)
+            rule.log_prefix, magic.ct)
           return true
         end
       end
@@ -489,7 +490,6 @@ local function check_parts_match(task, rule)
 end
 
 local function check_metric_results(task, rule)
-
   if rule.action ~= 'reject' then
     local metric_result = task:get_metric_score()
     local metric_action = task:get_metric_action()
index 6e40d9f335b9e446814b208301abe9bf168eee63..e8345d6d68132108987028006fec032ee65d9bde 100644 (file)
@@ -12,7 +12,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
-]]--
+]] --
 
 --[[[
 -- @module metadefender
@@ -29,7 +29,6 @@ local common = require "lua_scanners/common"
 local N = 'metadefender'
 
 local function metadefender_config(opts)
-
   local default_conf = {
     name = N,
     url = 'https://api.metadefender.com/v4/hash',
@@ -44,10 +43,36 @@ local function metadefender_config(opts)
     scan_mime_parts = true,
     scan_text_mime = false,
     scan_image_mime = false,
-    apikey = nil, -- Required to set by user
+    apikey = nil,         -- Required to set by user
     -- Specific for metadefender
-    minimum_engines = 3, -- Minimum required to get scored
-    full_score_engines = 7, -- After this number we set max score
+    minimum_engines = 3,  -- Minimum required to get scored
+    -- Threshold-based categorization
+    low_category = 5,     -- Low threat: minimum_engines to low_category-1
+    medium_category = 10, -- Medium threat: low_category to medium_category-1
+    -- High threat: medium_category and above
+    -- Symbol categories
+    symbols = {
+      clean = {
+        symbol = 'METADEFENDER_CLEAN',
+        score = -0.5,
+        description = 'MetaDefender decided attachment to be clean'
+      },
+      low = {
+        symbol = 'METADEFENDER_LOW',
+        score = 2.0,
+        description = 'MetaDefender found low number of threats'
+      },
+      medium = {
+        symbol = 'METADEFENDER_MEDIUM',
+        score = 5.0,
+        description = 'MetaDefender found medium number of threats'
+      },
+      high = {
+        symbol = 'METADEFENDER_HIGH',
+        score = 8.0,
+        description = 'MetaDefender found high number of threats'
+      },
+    },
   }
 
   default_conf = lua_util.override_defaults(default_conf, opts)
@@ -102,17 +127,16 @@ local function metadefender_check(task, content, digest, rule, maybe_part)
         task:insert_result(rule.symbol_fail, 1.0, 'HTTP error: ' .. http_err)
       else
         local cached
-        local dyn_score
         -- Parse the response
         if code ~= 200 then
           if code == 404 then
             cached = 'OK'
             if rule['log_clean'] then
               rspamd_logger.infox(task, '%s: hash %s clean (not found)',
-                  rule.log_prefix, hash)
+                rule.log_prefix, hash)
             else
               lua_util.debugm(rule.name, task, '%s: hash %s clean (not found)',
-                  rule.log_prefix, hash)
+                rule.log_prefix, hash)
             end
           elseif code == 429 then
             -- Request rate limit exceeded
@@ -130,7 +154,7 @@ local function metadefender_check(task, content, digest, rule, maybe_part)
           local res, json_err = parser:parse_string(body)
 
           lua_util.debugm(rule.name, task, '%s: got reply data: "%s"',
-              rule.log_prefix, body)
+            rule.log_prefix, body)
 
           if res then
             local obj = parser:get_object()
@@ -152,48 +176,64 @@ local function metadefender_check(task, content, digest, rule, maybe_part)
             local total = scan_results.total_avs or 0
 
             if detected == 0 then
-              cached = 'OK'
               if rule['log_clean'] then
                 rspamd_logger.infox(task, '%s: hash %s clean',
-                    rule.log_prefix, hash)
+                  rule.log_prefix, hash)
               else
                 lua_util.debugm(rule.name, task, '%s: hash %s clean',
-                    rule.log_prefix, hash)
+                  rule.log_prefix, hash)
+              end
+              -- Insert CLEAN symbol
+              if rule.symbols and rule.symbols.clean then
+                local clean_sym = rule.symbols.clean.symbol or 'METADEFENDER_CLEAN'
+                local sopt = string.format("%s:0/%s", hash, total)
+                task:insert_result(clean_sym, 1.0, sopt)
+                -- Save with symbol name for proper cache retrieval
+                cached = string.format("%s\v%s", clean_sym, sopt)
+              else
+                cached = 'OK'
               end
             else
               if detected < rule.minimum_engines then
                 lua_util.debugm(rule.name, task, '%s: hash %s has not enough hits: %s where %s is min',
-                    rule.log_prefix, hash, detected, rule.minimum_engines)
+                  rule.log_prefix, hash, detected, rule.minimum_engines)
                 cached = 'OK'
               else
-                if detected >= rule.full_score_engines then
-                  dyn_score = 1.0
+                -- Determine category based on detection count
+                local category
+                local category_sym
+                local sopt = string.format("%s:%s/%s", hash, detected, total)
+
+                if detected >= rule.medium_category then
+                  category = 'high'
+                  category_sym = rule.symbols.high.symbol or 'METADEFENDER_HIGH'
+                elseif detected >= rule.low_category then
+                  category = 'medium'
+                  category_sym = rule.symbols.medium.symbol or 'METADEFENDER_MEDIUM'
                 else
-                  local norm_detected = detected - rule.minimum_engines
-                  dyn_score = norm_detected / (rule.full_score_engines - rule.minimum_engines)
+                  category = 'low'
+                  category_sym = rule.symbols.low.symbol or 'METADEFENDER_LOW'
                 end
 
-                if dyn_score < 0 or dyn_score > 1 then
-                  dyn_score = 1.0
-                end
+                rspamd_logger.infox(task, '%s: result - %s: "%s" - category: %s',
+                  rule.log_prefix, rule.detection_category .. 'found', sopt, category)
 
-                local sopt = string.format("%s:%s/%s",
-                    hash, detected, total)
-                common.yield_result(task, rule, sopt, dyn_score, nil, maybe_part)
-                cached = sopt
+                task:insert_result(category_sym, 1.0, sopt)
+                -- Save with symbol name for proper cache retrieval
+                cached = string.format("%s\v%s", category_sym, sopt)
               end
             end
           else
             -- not res
             rspamd_logger.errx(task, 'invalid JSON reply: %s, body: %s, headers: %s',
-                json_err, body, headers)
+              json_err, body, headers)
             task:insert_result(rule.symbol_fail, 1.0, 'Bad JSON reply: ' .. json_err)
             return
           end
         end
 
         if cached then
-          common.save_cache(task, digest, rule, cached, dyn_score, maybe_part)
+          common.save_cache(task, digest, rule, cached, 1.0, maybe_part)
         end
       end
     end
@@ -203,13 +243,11 @@ local function metadefender_check(task, content, digest, rule, maybe_part)
   end
 
   if common.condition_check_and_continue(task, content, rule, digest,
-      metadefender_check_uncached) then
+        metadefender_check_uncached) then
     return
   else
-
     metadefender_check_uncached()
   end
-
 end
 
 return {
index d937c41288663924e78b703d69abdcbd0015c357..5893e70fb09f332c0085afb12ed0c2735914d2d5 100644 (file)
@@ -12,7 +12,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
-]]--
+]] --
 
 --[[[
 -- @module virustotal
@@ -29,7 +29,6 @@ local common = require "lua_scanners/common"
 local N = 'virustotal'
 
 local function virustotal_config(opts)
-
   local default_conf = {
     name = N,
     url = 'https://www.virustotal.com/vtapi/v2/file',
@@ -44,10 +43,36 @@ local function virustotal_config(opts)
     scan_mime_parts = true,
     scan_text_mime = false,
     scan_image_mime = false,
-    apikey = nil, -- Required to set by user
+    apikey = nil,         -- Required to set by user
     -- Specific for virustotal
-    minimum_engines = 3, -- Minimum required to get scored
-    full_score_engines = 7, -- After this number we set max score
+    minimum_engines = 3,  -- Minimum required to get scored
+    -- Threshold-based categorization
+    low_category = 5,     -- Low threat: minimum_engines to low_category-1
+    medium_category = 10, -- Medium threat: low_category to medium_category-1
+    -- High threat: medium_category and above
+    -- Symbol categories
+    symbols = {
+      clean = {
+        symbol = 'VIRUSTOTAL_CLEAN',
+        score = -0.5,
+        description = 'VirusTotal decided attachment to be clean'
+      },
+      low = {
+        symbol = 'VIRUSTOTAL_LOW',
+        score = 2.0,
+        description = 'VirusTotal found low number of threats'
+      },
+      medium = {
+        symbol = 'VIRUSTOTAL_MEDIUM',
+        score = 5.0,
+        description = 'VirusTotal found medium number of threats'
+      },
+      high = {
+        symbol = 'VIRUSTOTAL_HIGH',
+        score = 8.0,
+        description = 'VirusTotal found high number of threats'
+      },
+    },
   }
 
   default_conf = lua_util.override_defaults(default_conf, opts)
@@ -78,7 +103,7 @@ local function virustotal_check(task, content, digest, rule, maybe_part)
   local function virustotal_check_uncached()
     local function make_url(hash)
       return string.format('%s/report?apikey=%s&resource=%s',
-          rule.url, rule.apikey, hash)
+        rule.url, rule.apikey, hash)
     end
 
     local hash = rspamd_cryptobox_hash.create_specific('md5')
@@ -98,17 +123,16 @@ local function virustotal_check(task, content, digest, rule, maybe_part)
         rspamd_logger.errx(task, 'HTTP error: %s, body: %s, headers: %s', http_err, body, headers)
       else
         local cached
-        local dyn_score
         -- Parse the response
         if code ~= 200 then
           if code == 404 then
             cached = 'OK'
             if rule['log_clean'] then
               rspamd_logger.infox(task, '%s: hash %s clean (not found)',
-                  rule.log_prefix, hash)
+                rule.log_prefix, hash)
             else
               lua_util.debugm(rule.name, task, '%s: hash %s clean (not found)',
-                  rule.log_prefix, hash)
+                rule.log_prefix, hash)
             end
           elseif code == 204 then
             -- Request rate limit exceeded
@@ -126,67 +150,101 @@ local function virustotal_check(task, content, digest, rule, maybe_part)
           local res, json_err = parser:parse_string(body)
 
           lua_util.debugm(rule.name, task, '%s: got reply data: "%s"',
-              rule.log_prefix, body)
+            rule.log_prefix, body)
 
           if res then
             local obj = parser:get_object()
             if not obj.positives or type(obj.positives) ~= 'number' then
               if obj.response_code then
                 if obj.response_code == 0 then
-                  cached = 'OK'
                   if rule['log_clean'] then
                     rspamd_logger.infox(task, '%s: hash %s clean (not found)',
-                        rule.log_prefix, hash)
+                      rule.log_prefix, hash)
                   else
                     lua_util.debugm(rule.name, task, '%s: hash %s clean (not found)',
-                        rule.log_prefix, hash)
+                      rule.log_prefix, hash)
+                  end
+                  -- Insert CLEAN symbol
+                  if rule.symbols and rule.symbols.clean then
+                    local clean_sym = rule.symbols.clean.symbol or 'VIRUSTOTAL_CLEAN'
+                    local sopt = string.format("%s:0", hash)
+                    task:insert_result(clean_sym, 1.0, sopt)
+                    -- Save with symbol name for proper cache retrieval
+                    cached = string.format("%s\v%s", clean_sym, sopt)
+                  else
+                    cached = 'OK'
                   end
                 else
                   rspamd_logger.errx(task, 'invalid JSON reply: %s, body: %s, headers: %s',
-                      'bad response code: ' .. tostring(obj.response_code), body, headers)
+                    'bad response code: ' .. tostring(obj.response_code), body, headers)
                   task:insert_result(rule.symbol_fail, 1.0, 'Bad JSON reply: no `positives` element')
                   return
                 end
               else
                 rspamd_logger.errx(task, 'invalid JSON reply: %s, body: %s, headers: %s',
-                    'no response_code', body, headers)
+                  'no response_code', body, headers)
                 task:insert_result(rule.symbol_fail, 1.0, 'Bad JSON reply: no `positives` element')
                 return
               end
             else
-              if obj.positives < rule.minimum_engines then
+              if obj.positives == 0 then
+                if rule['log_clean'] then
+                  rspamd_logger.infox(task, '%s: hash %s clean',
+                    rule.log_prefix, hash)
+                else
+                  lua_util.debugm(rule.name, task, '%s: hash %s clean',
+                    rule.log_prefix, hash)
+                end
+                -- Insert CLEAN symbol
+                if rule.symbols and rule.symbols.clean then
+                  local clean_sym = rule.symbols.clean.symbol or 'VIRUSTOTAL_CLEAN'
+                  local sopt = string.format("%s:0/%s", hash, obj.total or 0)
+                  task:insert_result(clean_sym, 1.0, sopt)
+                  -- Save with symbol name for proper cache retrieval
+                  cached = string.format("%s\v%s", clean_sym, sopt)
+                else
+                  cached = 'OK'
+                end
+              elseif obj.positives < rule.minimum_engines then
                 lua_util.debugm(rule.name, task, '%s: hash %s has not enough hits: %s where %s is min',
-                    rule.log_prefix, obj.positives, rule.minimum_engines)
-                -- TODO: add proper hashing!
+                  rule.log_prefix, hash, obj.positives, rule.minimum_engines)
                 cached = 'OK'
               else
-                if obj.positives > rule.full_score_engines then
-                  dyn_score = 1.0
+                -- Determine category based on detection count
+                local category
+                local category_sym
+                local sopt = string.format("%s:%s/%s", hash, obj.positives, obj.total)
+
+                if obj.positives >= rule.medium_category then
+                  category = 'high'
+                  category_sym = rule.symbols.high.symbol or 'VIRUSTOTAL_HIGH'
+                elseif obj.positives >= rule.low_category then
+                  category = 'medium'
+                  category_sym = rule.symbols.medium.symbol or 'VIRUSTOTAL_MEDIUM'
                 else
-                  local norm_pos = obj.positives - rule.minimum_engines
-                  dyn_score = norm_pos / (rule.full_score_engines - rule.minimum_engines)
+                  category = 'low'
+                  category_sym = rule.symbols.low.symbol or 'VIRUSTOTAL_LOW'
                 end
 
-                if dyn_score < 0 or dyn_score > 1 then
-                  dyn_score = 1.0
-                end
-                local sopt = string.format("%s:%s/%s",
-                    hash, obj.positives, obj.total)
-                common.yield_result(task, rule, sopt, dyn_score, nil, maybe_part)
-                cached = sopt
+                rspamd_logger.infox(task, '%s: result - %s: "%s" - category: %s',
+                  rule.log_prefix, rule.detection_category .. 'found', sopt, category)
+
+                task:insert_result(category_sym, 1.0, sopt)
+                -- Save with symbol name for proper cache retrieval
+                cached = string.format("%s\v%s", category_sym, sopt)
               end
             end
           else
             -- not res
             rspamd_logger.errx(task, 'invalid JSON reply: %s, body: %s, headers: %s',
-                json_err, body, headers)
+              json_err, body, headers)
             task:insert_result(rule.symbol_fail, 1.0, 'Bad JSON reply: ' .. json_err)
             return
           end
         end
 
         if cached then
-          common.save_cache(task, digest, rule, cached, dyn_score, maybe_part)
+          common.save_cache(task, digest, rule, cached, 1.0, maybe_part)
         end
       end
     end
@@ -196,13 +254,11 @@ local function virustotal_check(task, content, digest, rule, maybe_part)
   end
 
   if common.condition_check_and_continue(task, content, rule, digest,
-      virustotal_check_uncached) then
+        virustotal_check_uncached) then
     return
   else
-
     virustotal_check_uncached()
   end
-
 end
 
 return {
index 5337f6666b50ff9ebd41758e3025a17737b101f5..1d4b8349359c591fa9f921ad155ace0a3e1243f3 100644 (file)
@@ -27,8 +27,8 @@ local N = "antivirus"
 
 if confighelp then
   rspamd_config:add_example(nil, 'antivirus',
-      "Check messages for viruses",
-      [[
+    "Check messages for viruses",
+    [[
   antivirus {
     # multiple scanners could be checked, for each we create a configuration block with an arbitrary name
     clamav {
@@ -75,7 +75,7 @@ end
 
 -- Encode as base32 in the source to avoid crappy stuff
 local eicar_pattern = rspamd_util.decode_base32(
-    [[akp6woykfbonrepmwbzyfpbmibpone3mj3pgwbffzj9e1nfjdkorisckwkohrnfe1nt41y3jwk1cirjki4w4nkieuni4ndfjcktnn1yjmb1wn]]
+  [[akp6woykfbonrepmwbzyfpbmibpone3mj3pgwbffzj9e1nfjdkorisckwkohrnfe1nt41y3jwk1cirjki4w4nkieuni4ndfjcktnn1yjmb1wn]]
 )
 
 local function add_antivirus_rule(sym, opts)
@@ -91,7 +91,7 @@ local function add_antivirus_rule(sym, opts)
 
   if not cfg then
     rspamd_logger.errx(rspamd_config, 'unknown antivirus type: %s',
-        opts.type)
+      opts.type)
     return nil
   end
 
@@ -109,7 +109,7 @@ local function add_antivirus_rule(sym, opts)
   if opts.attachments_only ~= nil then
     opts.scan_mime_parts = opts.attachments_only
     rspamd_logger.warnx(rspamd_config, '%s [%s]: Using attachments_only is deprecated. ' ..
-        'Please use scan_mime_parts = %s instead', opts.symbol, opts.type, opts.attachments_only)
+      'Please use scan_mime_parts = %s instead', opts.symbol, opts.type, opts.attachments_only)
   end
   -- WORKAROUND for deprecated attachments_only
 
@@ -123,9 +123,12 @@ local function add_antivirus_rule(sym, opts)
   rule.symbol_encrypted = opts.symbol_encrypted
   rule.redis_params = redis_params
 
+  -- Store rule for symbol registration later
+  rule.symbol_main = opts.symbol
+
   if not rule then
     rspamd_logger.errx(rspamd_config, 'cannot configure %s for %s',
-        opts.type, opts.symbol)
+      opts.type, opts.symbol)
     return nil
   end
 
@@ -133,10 +136,10 @@ local function add_antivirus_rule(sym, opts)
   rule.patterns_fail = common.create_regex_table(opts.patterns_fail or {})
 
   lua_redis.register_prefix(rule.prefix .. '_*', N,
-      string.format('Antivirus cache for rule "%s"',
-          rule.type), {
-        type = 'string',
-      })
+    string.format('Antivirus cache for rule "%s"',
+      rule.type), {
+      type = 'string',
+    })
 
   -- if any mime_part filter defined, do not scan all attachments
   if opts.mime_parts_filter_regex ~= nil
@@ -157,9 +160,9 @@ local function add_antivirus_rule(sym, opts)
     rule.whitelist = rspamd_config:add_hash_map(opts.whitelist)
   end
 
-  return function(task)
+  -- Return both callback and rule for symbol registration
+  local cb = function(task)
     if rule.scan_mime_parts then
-
       fun.each(function(p)
         local content = p:get_content()
         local clen = #content
@@ -173,18 +176,19 @@ local function add_antivirus_rule(sym, opts)
 
             if clen == #opts.eicar_fake_pattern and content == opts.eicar_fake_pattern then
               rspamd_logger.infox(task, 'found eicar fake replacement part in the part (filename="%s")',
-                  p:get_filename())
+                p:get_filename())
               content = eicar_pattern
             end
           end
           cfg.check(task, content, p:get_digest(), rule, p)
         end
       end, common.check_parts_match(task, rule))
-
     else
       cfg.check(task, task:get_content(), task:get_digest(), rule)
     end
   end
+
+  return cb, rule
 end
 
 -- Registration
@@ -200,15 +204,15 @@ if opts and type(opts) == 'table' then
       if not m.name then
         m.name = k
       end
-      local cb = add_antivirus_rule(k, m)
+      local cb, rule = add_antivirus_rule(k, m)
 
       if not cb then
         rspamd_logger.errx(rspamd_config, 'cannot add rule: "' .. k .. '"')
         lua_util.config_utils.push_config_error(N, 'cannot add AV rule: "' .. k .. '"')
       else
-        rspamd_logger.infox(rspamd_config, 'added antivirus engine %s -> %s', k, m.symbol)
+        rspamd_logger.infox(rspamd_config, 'added antivirus engine %s -> %s', k, rule.symbol or m.symbol)
         local t = {
-          name = m.symbol,
+          name = rule.symbol or m.symbol,
           callback = cb,
           score = 0.0,
           group = N
@@ -233,27 +237,27 @@ if opts and type(opts) == 'table' then
 
         rspamd_config:register_symbol({
           type = 'virtual',
-          name = m['symbol_fail'],
+          name = rule.symbol_fail or m['symbol_fail'],
           parent = id,
           score = 0.0,
           group = N
         })
         rspamd_config:register_symbol({
           type = 'virtual',
-          name = m['symbol_encrypted'],
+          name = rule.symbol_encrypted or m['symbol_encrypted'],
           parent = id,
           score = 0.0,
           group = N
         })
         rspamd_config:register_symbol({
           type = 'virtual',
-          name = m['symbol_macro'],
+          name = rule.symbol_macro or m['symbol_macro'],
           parent = id,
           score = 0.0,
           group = N
         })
         has_valid = true
-        if type(m['patterns']) == 'table' then
+        if type(rule.patterns) == 'table' and type(m['patterns']) == 'table' then
           if m['patterns'][1] then
             for _, p in ipairs(m['patterns']) do
               if type(p) == 'table' then
@@ -321,6 +325,48 @@ if opts and type(opts) == 'table' then
             end
           end
         end
+        if rule.symbols then
+          rspamd_logger.infox(rspamd_config, 'registering category symbols for %s', rule.name)
+          local function reg_symbols(tbl)
+            for _, sym in pairs(tbl) do
+              if type(sym) == 'string' then
+                rspamd_logger.infox(rspamd_config, 'registering symbol: %s (string)', sym)
+                rspamd_config:register_symbol({
+                  type = 'virtual',
+                  name = sym,
+                  parent = id,
+                  group = N
+                })
+              elseif type(sym) == 'table' then
+                if sym.symbol then
+                  rspamd_logger.infox(rspamd_config, 'registering symbol: %s with score %s',
+                    sym.symbol, sym.score or 'default')
+                  rspamd_config:register_symbol({
+                    type = 'virtual',
+                    name = sym.symbol,
+                    parent = id,
+                    group = N
+                  })
+
+                  if sym.score then
+                    rspamd_config:set_metric_symbol({
+                      name = sym.symbol,
+                      score = sym.score,
+                      description = sym.description,
+                      group = sym.group or N,
+                    })
+                  end
+                else
+                  reg_symbols(sym)
+                end
+              end
+            end
+          end
+
+          reg_symbols(rule.symbols)
+        else
+          rspamd_logger.infox(rspamd_config, 'no category symbols defined for %s', rule.name)
+        end
         if m['score'] then
           -- Register metric symbol
           local description = 'antivirus symbol'