#include "main/snort_config.h"
#include "protocols/gre.h"
+#include "checksum.h"
+
#ifdef UNIT_TEST
#include "catch/snort_catch.h"
#endif
void log(TextLog* const, const uint8_t* pkt, const uint16_t len) override;
bool encode(const uint8_t* const raw_in, const uint16_t raw_len,
EncState&, Buffer&, Flow*) override;
+ void update(const ip::IpApi& api, const EncodeFlags flags, uint8_t* raw_pkt,
+ uint16_t lyr_len, uint32_t& updated_len) override;
};
static const uint32_t GRE_HEADER_LEN = 4;
* see RFCs 1701, 2784 and 2637
*/
-bool GreCodec::encode(const uint8_t* const raw_in, const uint16_t raw_len,
- EncState& enc, Buffer& buf, Flow*)
-
+void GreCodec::update(const ip::IpApi& api, const EncodeFlags /*flags*/, uint8_t* raw_pkt,
+ uint16_t lyr_len, uint32_t& updated_len)
{
- if (raw_len > GRE_HEADER_LEN)
+ UNUSED(api);
+ gre::GREHdr* const greh = reinterpret_cast<gre::GREHdr*>(raw_pkt);
+
+ updated_len += lyr_len;
+
+ if (GRE_CHKSUM(greh))
{
- ErrorMessage("Invalid GRE header length: %u",raw_len);
- return false;
+ assert(lyr_len >= 6);
+ // Checksum field is zero for computing checksum
+ *(uint16_t*)(raw_pkt + 4) = 0;
+ *(uint16_t*)(raw_pkt + 4) = checksum::cksum_add((uint16_t*)raw_pkt, updated_len);
}
+}
+bool GreCodec::encode(const uint8_t* const raw_in, const uint16_t raw_len,
+ EncState& enc, Buffer& buf, Flow*)
+{
if (!buf.allocate(raw_len))
return false;
enc.next_proto = IpProtocol::GRE;
enc.next_ethertype = greh_out->proto();
- GRE_CHKSUM(greh_out);
+ if (GRE_SEQ(greh_out))
+ {
+ uint16_t len = 4; // Flags, version and protocol
+
+ if (GRE_CHKSUM(greh_out))
+ len += 4;
+
+ if (GRE_KEY(greh_out))
+ len += 4;
+
+ *(uint32_t*)(buf.data() + len) += ntohl(1);
+ }
+
+ if (GRE_CHKSUM(greh_out))
+ {
+ assert(raw_len >= 6);
+ // Checksum field is zero for computing checksum
+ *(uint16_t*)(buf.data() + 4) = 0;
+ *(uint16_t*)(buf.data() + 4) = checksum::cksum_add((uint16_t*)buf.data(),
+ buf.size());
+ }
return true;
}