]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Fix] Add a database check function unless we have anything from Hyperscan
authorVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 2 Mar 2023 11:07:28 +0000 (11:07 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 2 Mar 2023 11:07:28 +0000 (11:07 +0000)
Related: https://github.com/intel/hyperscan/issues/389

src/libserver/hyperscan_tools.cxx

index 1953e9b32b22a231a3f397ea3ce7c6a14028a256..5748138412670f7464e7e3bee9b2de600f81235c 100644 (file)
 
 #define HYPERSCAN_LOG_TAG "hsxxxx"
 
+// Hyperscan does not provide any API to check validity of it's databases
+// However, it is required for us to perform migrations properly without
+// failing at `hs_alloc_scratch` phase or even `hs_scan` which is **way too late**
+// Hence, we have to check hyperscan internal guts to prevent that situation...
+
+#ifndef HS_VERSION_32BIT
+#define HS_VERSION_32BIT ((HS_MAJOR << 24) | (HS_MINOR << 16) | (HS_PATCH << 8) | 0)
+#endif
+#ifndef HS_DB_VERSION
+#define HS_DB_VERSION HS_VERSION_32BIT
+#endif
+
+#ifndef HS_DB_MAGIC
+#define HS_DB_MAGIC   (0xdbdbdbdbU)
+#endif
+
 #define msg_info_hyperscan(...)   rspamd_default_log_function (G_LOG_LEVEL_INFO, \
         "hyperscan", HYPERSCAN_LOG_TAG, \
         RSPAMD_LOG_FUNC, \
@@ -141,8 +157,8 @@ public:
                }
 
                auto is_known = known_cached_files.insert(file.get_name());
-               msg_debug_hyperscan("added %s known hyperscan file: %*s",
-                       is_known.second ? "new" : "already",
+               msg_debug_hyperscan("added %s hyperscan file: %*s",
+                       is_known.second ? "new" : "already known",
                        (int)file.get_name().size(),
                        file.get_name().data());
        }
@@ -178,8 +194,8 @@ public:
                }
 
                auto is_known = known_cached_files.insert(mut_fname);
-               msg_debug_hyperscan("added %s known hyperscan file: %s",
-                       is_known.second ? "new" : "already",
+               msg_debug_hyperscan("added %s hyperscan file: %s",
+                       is_known.second ? "new" : "already known",
                        mut_fname.c_str());
        }
 
@@ -314,36 +330,54 @@ struct hs_shared_database {
        }
 };
 
+struct real_hs_db {
+       std::uint32_t magic;
+       std::uint32_t version;
+       std::uint32_t length;
+       std::uint64_t platform;
+       std::uint32_t crc32;
+};
 static auto
-hs_shared_from_unserialized(raii_mmaped_file &&map) -> tl::expected<hs_shared_database, error>
+hs_is_valid_database(void *raw, std::size_t len, std::string_view fname) -> tl::expected<bool, std::string>
 {
-       auto ptr = map.get_map();
-       auto db = (hs_database_t *)ptr;
+       if (len < sizeof(real_hs_db)) {
+               return tl::make_unexpected(fmt::format("cannot load hyperscan database from {}: too short", fname));
+       }
 
-       char *info = nullptr;
-       // Check HS database sanity (see #4409 for details)
-       auto ret = hs_database_info(db, &info);
+       static real_hs_db test;
 
-       if (ret != HS_SUCCESS) {
-               if (info) {
-                       g_free (info);
-               }
-               return tl::make_unexpected(
-                       error{fmt::format("cannot use database {}: error code: {}", map.get_file().get_name(), ret),
-                                                                                ret, error_category::IMPORTANT});
+       memcpy(&test, raw, sizeof(test));
+
+       if (test.magic != HS_DB_MAGIC) {
+               return tl::make_unexpected(fmt::format("cannot load hyperscan database from {}: invalid magic: {} ({} expected)",
+                       fname, test.magic, HS_DB_MAGIC));
+       }
+
+       if (test.version != HS_DB_VERSION) {
+               return tl::make_unexpected(fmt::format("cannot load hyperscan database from {}: invalid version: {} ({} expected)",
+                       fname, test.version, HS_DB_VERSION));
        }
 
-       msg_debug_hyperscan("database: %s, info: %s", map.get_file().get_name(), info);
+       return true;
+}
+
+static auto
+hs_shared_from_unserialized(hs_known_files_cache &hs_cache, raii_mmaped_file &&map) -> tl::expected<hs_shared_database, error>
+{
+       auto ptr = map.get_map();
+       auto db = (hs_database_t *)ptr;
 
-       if (info) {
-               g_free (info);
+       auto is_valid = hs_is_valid_database(map.get_map(), map.get_size(), map.get_file().get_name());
+       if (!is_valid) {
+               return tl::make_unexpected(error{is_valid.error(), -1, error_category::IMPORTANT});
        }
 
+       hs_cache.add_cached_file(map.get_file());
        return tl::expected<hs_shared_database, error>{tl::in_place, std::move(map), db};
 }
 
 static auto
-hs_shared_from_serialized(raii_mmaped_file &&map, std::int64_t offset) -> tl::expected<hs_shared_database, error>
+hs_shared_from_serialized(hs_known_files_cache &hs_cache, raii_mmaped_file &&map, std::int64_t offset) -> tl::expected<hs_shared_database, error>
 {
        hs_database_t *target = nullptr;
 
@@ -352,6 +386,7 @@ hs_shared_from_serialized(raii_mmaped_file &&map, std::int64_t offset) -> tl::ex
                return tl::make_unexpected(error {"cannot deserialize database", ret});
        }
 
+       hs_cache.add_cached_file(map.get_file());
        return tl::expected<hs_shared_database, error>{tl::in_place, target, map.get_file().get_name().data()};
 }
 
@@ -463,22 +498,20 @@ auto load_cached_hs_file(const char *fname, std::int64_t offset = 0) -> tl::expe
                                         * being created by another process.
                                         * We cannot use it!
                                         */
-                                       return hs_shared_from_serialized(std::forward<T>(cached_serialized), offset);
+                                       return hs_shared_from_serialized(hs_cache, std::forward<T>(cached_serialized), offset);
                                }
                                else {
-                                       hs_cache.add_cached_file(unserialized_checked);
                                        return raii_mmaped_file::mmap_shared(std::move(unserialized_checked), PROT_READ)
                                                .and_then([&]<class U>(U &&mmapped_unserialized) -> auto {
-                                                       return hs_shared_from_unserialized(std::forward<U>(mmapped_unserialized));
+                                                       return hs_shared_from_unserialized(hs_cache, std::forward<U>(mmapped_unserialized));
                                                });
                                }
                        }
                        else {
-                               return hs_shared_from_serialized(std::forward<T>(cached_serialized), offset);
+                               return hs_shared_from_serialized(hs_cache, std::forward<T>(cached_serialized), offset);
                        }
 #else // defined(HS_MAJOR) && defined(HS_MINOR) && HS_MAJOR >= 5 && HS_MINOR >= 4
-                       hs_cache.add_cached_file(cached_serialized.get_file());
-                       return hs_shared_from_serialized(std::forward<T>(cached_serialized), offset);
+                       return hs_shared_from_serialized(hs_cache, std::forward<T>(cached_serialized), offset);
 #endif // defined(HS_MAJOR) && defined(HS_MINOR) && HS_MAJOR >= 5 && HS_MINOR >= 4
                });
 }