From: Eric Leblond Date: Wed, 24 Aug 2022 11:57:56 +0000 (+0200) Subject: tld: add new transform X-Git-Tag: suricata-8.0.0-beta1~13 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=585c624482e7b1f0b5a944b2d3c093d3cfab2b2f;p=thirdparty%2Fsuricata.git tld: add new transform Extract the TLD from a buffer. Ticket: #5639 --- diff --git a/rust/src/domain.rs b/rust/src/domain.rs index f55f07075b..bfdccdc77e 100644 --- a/rust/src/domain.rs +++ b/rust/src/domain.rs @@ -31,3 +31,18 @@ pub unsafe extern "C" fn rs_get_domain(input: *const u8, len: u32, output: *mut None => false } } + +#[no_mangle] +pub unsafe extern "C" fn rs_get_tld(input: *const u8, len: u32, output: *mut u8, olen: *mut u64) -> bool { + let slice: &[u8] = std::slice::from_raw_parts(input as *mut u8, len as usize); + let result = psl::domain(slice); + match result { + Some(x) => { + let tld = x.suffix().as_bytes(); + ptr::copy(tld.as_ptr(), output, tld.len()); + *olen = tld.len() as u64; + true + }, + None => false + } +} diff --git a/src/detect-engine-register.h b/src/detect-engine-register.h index 97e2485728..08246a4f19 100644 --- a/src/detect-engine-register.h +++ b/src/detect-engine-register.h @@ -319,6 +319,7 @@ enum DetectKeywordId { DETECT_TRANSFORM_HEADER_LOWERCASE, DETECT_TRANSFORM_FROM_BASE64, DETECT_TRANSFORM_DOMAIN, + DETECT_TRANSFORM_TLD, DETECT_IKE_EXCH_TYPE, DETECT_IKE_SPI_INITIATOR, diff --git a/src/detect-transform-domain.c b/src/detect-transform-domain.c index a66773b875..55bb1e4fdb 100644 --- a/src/detect-transform-domain.c +++ b/src/detect-transform-domain.c @@ -39,10 +39,13 @@ #include "rust.h" static int DetectTransformDomainSetup(DetectEngineCtx *, Signature *, const char *); +static int DetectTransformTLDSetup(DetectEngineCtx *, Signature *, const char *); #ifdef UNITTESTS static void DetectTransformDomainRegisterTests(void); +static void DetectTransformTLDRegisterTests(void); #endif static void TransformDomain(DetectEngineThreadCtx *ctx, InspectionBuffer *buffer, void *options); +static void TransformTLD(DetectEngineThreadCtx *ctx, InspectionBuffer *buffer, void *options); void DetectTransformDomainRegister(void) { @@ -55,6 +58,16 @@ void DetectTransformDomainRegister(void) sigmatch_table[DETECT_TRANSFORM_DOMAIN].RegisterTests = DetectTransformDomainRegisterTests; #endif sigmatch_table[DETECT_TRANSFORM_DOMAIN].flags |= SIGMATCH_NOOPT; + + sigmatch_table[DETECT_TRANSFORM_TLD].name = "tld"; + sigmatch_table[DETECT_TRANSFORM_TLD].desc = "modify buffer to extract the tld"; + sigmatch_table[DETECT_TRANSFORM_TLD].url = "/rules/transforms.html#tld"; + sigmatch_table[DETECT_TRANSFORM_TLD].Transform = TransformTLD; + sigmatch_table[DETECT_TRANSFORM_TLD].Setup = DetectTransformTLDSetup; +#ifdef UNITTESTS + sigmatch_table[DETECT_TRANSFORM_TLD].RegisterTests = DetectTransformTLDRegisterTests; +#endif + sigmatch_table[DETECT_TRANSFORM_TLD].flags |= SIGMATCH_NOOPT; } /** @@ -92,6 +105,41 @@ static void TransformDomain(DetectEngineThreadCtx *ctx, InspectionBuffer *buffer } } +/** + * \internal + * \brief Extract the dotprefix, if any, the last pattern match, either content or uricontent + * \param det_ctx detection engine ctx + * \param s signature + * \param nullstr should be null + * \retval 0 ok + * \retval -1 failure + */ +static int DetectTransformTLDSetup(DetectEngineCtx *de_ctx, Signature *s, const char *nullstr) +{ + SCEnter(); + int r = DetectSignatureAddTransform(s, DETECT_TRANSFORM_TLD, NULL); + SCReturnInt(r); +} + +/** + * \brief Return the domain, if any, in the last pattern match. + * + */ +static void TransformTLD(DetectEngineThreadCtx *ctx, InspectionBuffer *buffer, void *options) +{ + const size_t input_len = buffer->inspect_len; + uint64_t output_len = 0; + + if (input_len) { + uint8_t output[input_len]; + + bool res = rs_get_tld(buffer->inspect, input_len, output, &output_len); + if (res == true) { + InspectionBufferCopy(buffer, output, output_len); + } + } +} + #ifdef UNITTESTS static int DetectTransformDomainTest01(void) { @@ -158,4 +206,69 @@ static void DetectTransformDomainRegisterTests(void) UtRegisterTest("DetectTransformDomainTest02", DetectTransformDomainTest02); UtRegisterTest("DetectTransformDomainTest03", DetectTransformDomainTest03); } + +static int DetectTransformTLDTest01(void) +{ + const uint8_t *input = (const uint8_t *)"www.example.com"; + uint32_t input_len = strlen((char *)input); + + const char *result = "com"; + uint32_t result_len = strlen((char *)result); + + InspectionBuffer buffer; + InspectionBufferInit(&buffer, input_len); + InspectionBufferSetup(NULL, -1, &buffer, input, input_len); + PrintRawDataFp(stdout, buffer.inspect, buffer.inspect_len); + TransformTLD(NULL, &buffer, NULL); + PrintRawDataFp(stdout, buffer.inspect, buffer.inspect_len); + FAIL_IF_NOT(buffer.inspect_len == result_len); + FAIL_IF_NOT(strncmp(result, (const char *)buffer.inspect, result_len) == 0); + InspectionBufferFree(&buffer); + PASS; +} + +static int DetectTransformTLDTest02(void) +{ + const uint8_t *input = (const uint8_t *)"hello.example.co.uk"; + uint32_t input_len = strlen((char *)input); + + const char *result = "co.uk"; + uint32_t result_len = strlen((char *)result); + + InspectionBuffer buffer; + InspectionBufferInit(&buffer, input_len); + InspectionBufferSetup(NULL, -1, &buffer, input, input_len); + PrintRawDataFp(stdout, buffer.inspect, buffer.inspect_len); + TransformTLD(NULL, &buffer, NULL); + PrintRawDataFp(stdout, buffer.inspect, buffer.inspect_len); + FAIL_IF_NOT(buffer.inspect_len == result_len); + FAIL_IF_NOT(strncmp(result, (const char *)buffer.inspect, result_len) == 0); + InspectionBufferFree(&buffer); + PASS; +} + +static int DetectTransformTLDTest03(void) +{ + const char rule[] = "alert dns any any -> any any (dns.query; tld; content:\"com\"; sid:1;)"; + ThreadVars th_v; + DetectEngineThreadCtx *det_ctx = NULL; + memset(&th_v, 0, sizeof(th_v)); + + DetectEngineCtx *de_ctx = DetectEngineCtxInit(); + FAIL_IF_NULL(de_ctx); + Signature *s = DetectEngineAppendSig(de_ctx, rule); + FAIL_IF_NULL(s); + SigGroupBuild(de_ctx); + DetectEngineThreadCtxInit(&th_v, (void *)de_ctx, (void *)&det_ctx); + DetectEngineThreadCtxDeinit(&th_v, (void *)det_ctx); + DetectEngineCtxFree(de_ctx); + PASS; +} + +static void DetectTransformTLDRegisterTests(void) +{ + UtRegisterTest("DetectTransformTLDTest01", DetectTransformTLDTest01); + UtRegisterTest("DetectTransformTLDTest02", DetectTransformTLDTest02); + UtRegisterTest("DetectTransformTLDTest03", DetectTransformTLDTest03); +} #endif