#include "util-lua-common.h"
#include "util-lua-tls.h"
-static int GetCertNotBefore(lua_State *luastate, const Flow *f, int direction)
+static const char tls_state_mt[] = "suricata:tls";
+
+struct LuaTls {
+ const SSLState *state; // state
+};
+
+static int LuaTlsFlowStateGet(lua_State *luastate)
{
+ if (!LuaStateNeedProto(luastate, ALPROTO_TLS)) {
+ return LuaCallbackError(luastate, "error: protocol not tls");
+ }
+ Flow *f = LuaStateGetFlow(luastate);
+ if (f == NULL) {
+ LUA_ERROR("failed to get flow");
+ }
+
+ struct LuaTls *s = (struct LuaTls *)lua_newuserdata(luastate, sizeof(*s));
+ if (s == NULL) {
+ LUA_ERROR("failed to allocate userdata");
+ }
+
void *state = FlowGetAppState(f);
if (state == NULL)
return LuaCallbackError(luastate, "error: no app layer state");
+ s->state = (const SSLState *)state;
+ luaL_getmetatable(luastate, tls_state_mt);
+ lua_setmetatable(luastate, -2);
+ return 1;
+}
- SSLState *ssl_state = (SSLState *)state;
- SSLStateConnp *connp = NULL;
+static int GetCertNotBefore(lua_State *luastate, bool client, const SSLState *ssl_state)
+{
+ const SSLStateConnp *connp;
- if (direction) {
+ if (client) {
connp = &ssl_state->client_connp;
} else {
connp = &ssl_state->server_connp;
if (connp->cert0_not_before == 0)
return LuaCallbackError(luastate, "error: no certificate NotBefore");
- int r = LuaPushInteger(luastate, connp->cert0_not_before);
-
- return r;
+ return LuaPushInteger(luastate, connp->cert0_not_before);
}
-static int TlsGetCertNotBefore(lua_State *luastate)
+static int LuaTlsGetServerCertNotBefore(lua_State *luastate)
{
- int r;
-
- if (!(LuaStateNeedProto(luastate, ALPROTO_TLS)))
- return LuaCallbackError(luastate, "error: protocol not tls");
-
- int direction = LuaStateGetDirection(luastate);
+ struct LuaTls *s = (struct LuaTls *)luaL_checkudata(luastate, 1, tls_state_mt);
+ if (s->state == NULL) {
+ LUA_ERROR("failed to get flow");
+ }
- Flow *f = LuaStateGetFlow(luastate);
- if (f == NULL)
- return LuaCallbackError(luastate, "internal error: no flow");
+ return GetCertNotBefore(luastate, false, s->state);
+}
- r = GetCertNotBefore(luastate, f, direction);
+static int LuaTlsGetClientCertNotBefore(lua_State *luastate)
+{
+ struct LuaTls *s = (struct LuaTls *)luaL_checkudata(luastate, 1, tls_state_mt);
+ if (s->state == NULL) {
+ LUA_ERROR("failed to get flow");
+ }
- return r;
+ return GetCertNotBefore(luastate, true, s->state);
}
-static int GetCertNotAfter(lua_State *luastate, const Flow *f, int direction)
+static int GetCertNotAfter(lua_State *luastate, bool client, const SSLState *ssl_state)
{
- void *state = FlowGetAppState(f);
- if (state == NULL)
- return LuaCallbackError(luastate, "error: no app layer state");
+ const SSLStateConnp *connp;
- SSLState *ssl_state = (SSLState *)state;
- SSLStateConnp *connp = NULL;
-
- if (direction) {
+ if (client) {
connp = &ssl_state->client_connp;
} else {
connp = &ssl_state->server_connp;
if (connp->cert0_not_after == 0)
return LuaCallbackError(luastate, "error: no certificate NotAfter");
- int r = LuaPushInteger(luastate, connp->cert0_not_after);
-
- return r;
+ return LuaPushInteger(luastate, connp->cert0_not_after);
}
-static int TlsGetCertNotAfter(lua_State *luastate)
+static int LuaTlsGetServerCertNotAfter(lua_State *luastate)
{
- int r;
-
- if (!(LuaStateNeedProto(luastate, ALPROTO_TLS)))
- return LuaCallbackError(luastate, "error: protocol not tls");
-
- int direction = LuaStateGetDirection(luastate);
-
- Flow *f = LuaStateGetFlow(luastate);
- if (f == NULL)
- return LuaCallbackError(luastate, "internal error: no flow");
+ struct LuaTls *s = (struct LuaTls *)luaL_checkudata(luastate, 1, tls_state_mt);
+ if (s->state == NULL) {
+ LUA_ERROR("failed to get state");
+ }
- r = GetCertNotAfter(luastate, f, direction);
+ return GetCertNotAfter(luastate, false, s->state);
+}
+static int LuaTlsGetClientCertNotAfter(lua_State *luastate)
+{
+ struct LuaTls *s = (struct LuaTls *)luaL_checkudata(luastate, 1, tls_state_mt);
+ if (s->state == NULL) {
+ LUA_ERROR("failed to get state");
+ }
- return r;
+ return GetCertNotAfter(luastate, true, s->state);
}
-static int GetCertInfo(lua_State *luastate, const Flow *f, int direction)
+static int GetCertInfo(lua_State *luastate, bool client, const SSLState *ssl_state)
{
- void *state = FlowGetAppState(f);
- if (state == NULL)
- return LuaCallbackError(luastate, "error: no app layer state");
+ const SSLStateConnp *connp;
- SSLState *ssl_state = (SSLState *)state;
- SSLStateConnp *connp = NULL;
-
- if (direction) {
+ if (client) {
connp = &ssl_state->client_connp;
} else {
connp = &ssl_state->server_connp;
return r;
}
-static int TlsGetCertInfo(lua_State *luastate)
-{
- int r;
-
- if (!(LuaStateNeedProto(luastate, ALPROTO_TLS)))
- return LuaCallbackError(luastate, "error: protocol not tls");
-
- int direction = LuaStateGetDirection(luastate);
-
- Flow *f = LuaStateGetFlow(luastate);
- if (f == NULL)
- return LuaCallbackError(luastate, "internal error: no flow");
-
- r = GetCertInfo(luastate, f, direction);
-
- return r;
-}
-
-static int GetAgreedVersion(lua_State *luastate, const Flow *f)
+static int LuaTlsGetServerCertInfo(lua_State *luastate)
{
- void *state = FlowGetAppState(f);
- if (state == NULL)
- return LuaCallbackError(luastate, "error: no app layer state");
-
- SSLState *ssl_state = (SSLState *)state;
-
- char ssl_version[SSL_VERSION_MAX_STRLEN];
- SSLVersionToString(ssl_state->server_connp.version, ssl_version);
+ struct LuaTls *s = (struct LuaTls *)luaL_checkudata(luastate, 1, tls_state_mt);
+ if (s->state == NULL) {
+ LUA_ERROR("failed to get state");
+ }
- return LuaPushStringBuffer(luastate, (uint8_t *)ssl_version,
- strlen(ssl_version));
+ return GetCertInfo(luastate, false, s->state);
}
-static int TlsGetVersion(lua_State *luastate)
+static int LuaTlsGetClientCertInfo(lua_State *luastate)
{
- int r;
-
- if (!(LuaStateNeedProto(luastate, ALPROTO_TLS)))
- return LuaCallbackError(luastate, "error: protocol not tls");
-
- Flow *f = LuaStateGetFlow(luastate);
- if (f == NULL)
- return LuaCallbackError(luastate, "internal error: no flow");
-
- r = GetAgreedVersion(luastate, f);
+ struct LuaTls *s = (struct LuaTls *)luaL_checkudata(luastate, 1, tls_state_mt);
+ if (s->state == NULL) {
+ LUA_ERROR("failed to get state");
+ }
- return r;
+ return GetCertInfo(luastate, true, s->state);
}
-static int GetSNI(lua_State *luastate, const Flow *f)
+static int GetSNI(lua_State *luastate, const SSLState *ssl_state)
{
- void *state = FlowGetAppState(f);
- if (state == NULL)
- return LuaCallbackError(luastate, "error: no app layer state");
-
- SSLState *ssl_state = (SSLState *)state;
-
if (ssl_state->client_connp.sni == NULL)
return LuaCallbackError(luastate, "error: no server name indication");
strlen(ssl_state->client_connp.sni));
}
-static int TlsGetSNI(lua_State *luastate)
+static int LuaTlsGetSNI(lua_State *luastate)
{
- int r;
+ struct LuaTls *s = (struct LuaTls *)luaL_checkudata(luastate, 1, tls_state_mt);
+ if (s->state == NULL) {
+ LUA_ERROR("failed to get state");
+ }
if (!(LuaStateNeedProto(luastate, ALPROTO_TLS)))
return LuaCallbackError(luastate, "error: protocol not tls");
- Flow *f = LuaStateGetFlow(luastate);
- if (f == NULL)
- return LuaCallbackError(luastate, "internal error: no flow");
-
- r = GetSNI(luastate, f);
-
- return r;
+ return GetSNI(luastate, s->state);
}
-static int GetCertSerial(lua_State *luastate, const Flow *f)
+static int GetCertChain(lua_State *luastate, bool client)
{
- void *state = FlowGetAppState(f);
- if (state == NULL)
- return LuaCallbackError(luastate, "error: no app layer state");
-
- SSLState *ssl_state = (SSLState *)state;
-
- if (ssl_state->server_connp.cert0_serial == NULL)
- return LuaCallbackError(luastate, "error: no certificate serial");
-
- return LuaPushStringBuffer(luastate,
- (uint8_t *)ssl_state->server_connp.cert0_serial,
- strlen(ssl_state->server_connp.cert0_serial));
-}
-
-static int TlsGetCertSerial(lua_State *luastate)
-{
- int r;
+ struct LuaTls *s = (struct LuaTls *)luaL_checkudata(luastate, 1, tls_state_mt);
+ if (s->state == NULL) {
+ LUA_ERROR("failed to get state");
+ }
if (!(LuaStateNeedProto(luastate, ALPROTO_TLS)))
return LuaCallbackError(luastate, "error: protocol not tls");
- Flow *f = LuaStateGetFlow(luastate);
- if (f == NULL)
- return LuaCallbackError(luastate, "internal error: no flow");
-
- r = GetCertSerial(luastate, f);
-
- return r;
-}
-
-static int GetCertChain(lua_State *luastate, const Flow *f, int direction)
-{
- void *state = FlowGetAppState(f);
- if (state == NULL)
- return LuaCallbackError(luastate, "error: no app layer state");
+ const SSLStateConnp *connp;
- SSLState *ssl_state = (SSLState *)state;
- SSLStateConnp *connp = NULL;
-
- if (direction) {
- connp = &ssl_state->client_connp;
+ if (client) {
+ connp = &s->state->client_connp;
} else {
- connp = &ssl_state->server_connp;
+ connp = &s->state->server_connp;
}
uint32_t u = 0;
lua_newtable(luastate);
SSLCertsChain *cert = NULL;
+
TAILQ_FOREACH(cert, &connp->certs, next)
{
lua_pushinteger(luastate, u++);
return 1;
}
-static int TlsGetCertChain(lua_State *luastate)
+static int LuaTlsGetServerCertChain(lua_State *luastate)
{
- int r;
+ return GetCertChain(luastate, false);
+}
- if (!(LuaStateNeedProto(luastate, ALPROTO_TLS)))
- return LuaCallbackError(luastate, "error: protocol not tls");
+static int LuaTlsGetClientCertChain(lua_State *luastate)
+{
+ return GetCertChain(luastate, true);
+}
- int direction = LuaStateGetDirection(luastate);
+static int GetCertSerial(lua_State *luastate, bool client)
+{
+ struct LuaTls *s = (struct LuaTls *)luaL_checkudata(luastate, 1, tls_state_mt);
+ if (s->state == NULL) {
+ LUA_ERROR("failed to get flow");
+ }
- Flow *f = LuaStateGetFlow(luastate);
- if (f == NULL)
- return LuaCallbackError(luastate, "internal error: no flow");
+ const SSLStateConnp *connp;
- r = GetCertChain(luastate, f, direction);
+ if (client) {
+ connp = &s->state->client_connp;
+ } else {
+ connp = &s->state->server_connp;
+ }
+ if (connp->cert0_serial == NULL)
+ return LuaCallbackError(luastate, "error: no certificate serial");
- return r;
+ return LuaPushStringBuffer(
+ luastate, (uint8_t *)connp->cert0_serial, strlen(connp->cert0_serial));
+}
+
+static int LuaTlsGetServerCertSerial(lua_State *luastate)
+{
+ return GetCertSerial(luastate, false);
}
-/** \brief register tls lua extensions in a luastate */
-int LuaRegisterTlsFunctions(lua_State *luastate)
+static int LuaTlsGetClientCertSerial(lua_State *luastate)
{
- /* registration of the callbacks */
- lua_pushcfunction(luastate, TlsGetCertNotBefore);
- lua_setglobal(luastate, "TlsGetCertNotBefore");
+ return GetCertSerial(luastate, true);
+}
- lua_pushcfunction(luastate, TlsGetCertNotAfter);
- lua_setglobal(luastate, "TlsGetCertNotAfter");
+static int GetAgreedVersion(lua_State *luastate, bool client)
+{
+ struct LuaTls *s = (struct LuaTls *)luaL_checkudata(luastate, 1, tls_state_mt);
+ if (s->state == NULL) {
+ LUA_ERROR("failed to get state");
+ }
- lua_pushcfunction(luastate, TlsGetVersion);
- lua_setglobal(luastate, "TlsGetVersion");
+ uint16_t version;
+ if (client) {
+ version = s->state->client_connp.version;
+ } else {
+ version = s->state->server_connp.version;
+ }
- lua_pushcfunction(luastate, TlsGetCertInfo);
- lua_setglobal(luastate, "TlsGetCertInfo");
+ char ssl_version[SSL_VERSION_MAX_STRLEN];
+ SSLVersionToString(version, ssl_version);
- lua_pushcfunction(luastate, TlsGetSNI);
- lua_setglobal(luastate, "TlsGetSNI");
+ lua_pushstring(luastate, (const char *)&ssl_version);
+ return 1;
+}
+
+static int LuaTlsGetServerVersion(lua_State *luastate)
+{
+ return GetAgreedVersion(luastate, false);
+}
- lua_pushcfunction(luastate, TlsGetCertSerial);
- lua_setglobal(luastate, "TlsGetCertSerial");
+static int LuaTlsGetClientVersion(lua_State *luastate)
+{
+ return GetAgreedVersion(luastate, true);
+}
- lua_pushcfunction(luastate, TlsGetCertChain);
- lua_setglobal(luastate, "TlsGetCertChain");
+static const struct luaL_Reg tlslib_meta[] = {
+ // clang-format off
+ { "get_server_cert_not_before", LuaTlsGetServerCertNotBefore },
+ { "get_client_cert_not_before", LuaTlsGetClientCertNotBefore },
+ { "get_server_cert_not_after", LuaTlsGetServerCertNotAfter },
+ { "get_client_cert_not_after", LuaTlsGetClientCertNotAfter },
+ { "get_server_version", LuaTlsGetServerVersion },
+ { "get_client_version", LuaTlsGetClientVersion },
+ { "get_server_serial", LuaTlsGetServerCertSerial },
+ { "get_client_serial", LuaTlsGetClientCertSerial },
+ { "get_server_cert_info", LuaTlsGetServerCertInfo },
+ { "get_client_cert_info", LuaTlsGetClientCertInfo },
+ { "get_client_sni", LuaTlsGetSNI },
+ { "get_client_cert_chain", LuaTlsGetClientCertChain },
+ { "get_server_cert_chain", LuaTlsGetServerCertChain },
+ { NULL, NULL, }
+ // clang-format off
+};
+
+static const struct luaL_Reg tlslib[] = {
+ // clang-format off
+ { "get_tx", LuaTlsFlowStateGet },
+ { NULL, NULL, },
+ // clang-format on
+};
+
+int SCLuaLoadTlsLib(lua_State *L)
+{
+ luaL_newmetatable(L, tls_state_mt);
+ lua_pushvalue(L, -1);
+ lua_setfield(L, -2, "__index");
+ luaL_setfuncs(L, tlslib_meta, 0);
- return 0;
+ luaL_newlib(L, tlslib);
+ return 1;
}