]> git.ipfire.org Git - thirdparty/systemd.git/commitdiff
shared/ip-procotol-list: generalize and rework parse_ip_protocol()
authorZbigniew Jędrzejewski-Szmek <zbyszek@in.waw.pl>
Sat, 16 Sep 2023 10:43:16 +0000 (12:43 +0200)
committerZbigniew Jędrzejewski-Szmek <zbyszek@in.waw.pl>
Fri, 22 Sep 2023 06:17:42 +0000 (08:17 +0200)
Optionally, accept protocols that don't have a known name.
Avoid any allocations in the common case.
Return more granular error codes: -ERANGE for negative values,
-EOPNOTSUPP if the protocol is a valid number, but we don't know
the protocol, and -EINVAL only if it's not a numerical string.

src/shared/ip-protocol-list.c
src/shared/ip-protocol-list.h
src/test/test-ip-protocol-list.c

index 21d88f2660776f761b2f296c233c4277b071cc10..14155b679ab5b90030aeeffe201d53f76feaff60 100644 (file)
@@ -37,33 +37,40 @@ int ip_protocol_from_name(const char *name) {
         return sc->id;
 }
 
-int parse_ip_protocol(const char *s) {
-        _cleanup_free_ char *str = NULL;
-        int i, r;
+int parse_ip_protocol_full(const char *s, bool relaxed) {
+        int r, p;
 
         assert(s);
 
         if (isempty(s))
                 return IPPROTO_IP;
 
-        /* Do not use strdupa() here, as the input string may come from *
-         * command line or config files. */
-        str = strdup(s);
-        if (!str)
+        /* People commonly use lowercase protocol names, which we can look up very quickly, so let's try that
+         * first. */
+        r = ip_protocol_from_name(s);
+        if (r >= 0)
+                return r;
+
+        /* Do not use strdupa() here, as the input string may come from command line or config files. */
+        _cleanup_free_ char *t = strdup(s);
+        if (!t)
                 return -ENOMEM;
 
-        i = ip_protocol_from_name(ascii_strlower(str));
-        if (i >= 0)
-                return i;
+        r = ip_protocol_from_name(ascii_strlower(t));
+        if (r >= 0)
+                return r;
 
-        r = safe_atoi(str, &i);
+        r = safe_atoi(t, &p);
         if (r < 0)
                 return r;
+        if (p < 0)
+                return -ERANGE;
 
-        if (!ip_protocol_to_name(i))
-                return -EINVAL;
+        /* If @relaxed, we don't check that we have a name for the protocol. */
+        if (!relaxed && !ip_protocol_to_name(p))
+                return -EPROTONOSUPPORT;
 
-        return i;
+        return p;
 }
 
 const char *ip_protocol_to_tcp_udp(int id) {
index b40ec083016fd3d63af07855f51a65113d17f955..a0875ef234c4244ef776f52532189322f9f5e788 100644 (file)
@@ -1,9 +1,14 @@
 /* SPDX-License-Identifier: LGPL-2.1-or-later */
 #pragma once
 
+#include <stdbool.h>
+
 const char *ip_protocol_to_name(int id);
 int ip_protocol_from_name(const char *name);
-int parse_ip_protocol(const char *s);
+int parse_ip_protocol_full(const char *s, bool relaxed);
+static inline int parse_ip_protocol(const char *s) {
+        return parse_ip_protocol_full(s, false);
+}
 
 const char *ip_protocol_to_tcp_udp(int id);
 int ip_protocol_from_tcp_udp(const char *ip_protocol);
index 018441d497a38c7c69405155001e91f7014b661f..dfff015f53e44310f83e268825732b4c5f8865df 100644 (file)
@@ -17,13 +17,13 @@ static void test_int(int i) {
         assert_se(ip_protocol_from_name(ip_protocol_to_name(parse_ip_protocol(str))) == i);
 }
 
-static void test_int_fail(int i) {
+static void test_int_fail(int i, int error) {
         char str[DECIMAL_STR_MAX(int)];
 
         assert_se(!ip_protocol_to_name(i));
 
         xsprintf(str, "%i", i);
-        assert_se(parse_ip_protocol(str) == -EINVAL);
+        assert_se(parse_ip_protocol(str) == error);
 }
 
 static void test_str(const char *s) {
@@ -31,39 +31,41 @@ static void test_str(const char *s) {
         assert_se(streq(ip_protocol_to_name(parse_ip_protocol(s)), s));
 }
 
-static void test_str_fail(const char *s) {
+static void test_str_fail(const char *s, int error) {
         assert_se(ip_protocol_from_name(s) == -EINVAL);
-        assert_se(parse_ip_protocol(s) == -EINVAL);
-}
-
-static void test_parse_ip_protocol_one(const char *s, int expected) {
-        assert_se(parse_ip_protocol(s) == expected);
+        assert_se(parse_ip_protocol(s) == error);
 }
 
 TEST(integer) {
         test_int(IPPROTO_TCP);
         test_int(IPPROTO_DCCP);
-        test_int_fail(-1);
-        test_int_fail(1024 * 1024);
+        test_int_fail(-1, -ERANGE);
+        test_int_fail(1024 * 1024, -EPROTONOSUPPORT);
 }
 
 TEST(string) {
         test_str("sctp");
         test_str("udp");
-        test_str_fail("hoge");
-        test_str_fail("-1");
-        test_str_fail("1000000000");
+        test_str_fail("hoge", -EINVAL);
+        test_str_fail("-1", -ERANGE);
+        test_str_fail("1000000000", -EPROTONOSUPPORT);
 }
 
 TEST(parse_ip_protocol) {
-        test_parse_ip_protocol_one("sctp", IPPROTO_SCTP);
-        test_parse_ip_protocol_one("ScTp", IPPROTO_SCTP);
-        test_parse_ip_protocol_one("ip", IPPROTO_IP);
-        test_parse_ip_protocol_one("", IPPROTO_IP);
-        test_parse_ip_protocol_one("1", 1);
-        test_parse_ip_protocol_one("0", 0);
-        test_parse_ip_protocol_one("-10", -EINVAL);
-        test_parse_ip_protocol_one("100000000", -EINVAL);
+        assert_se(parse_ip_protocol("sctp") == IPPROTO_SCTP);
+        assert_se(parse_ip_protocol("ScTp") == IPPROTO_SCTP);
+        assert_se(parse_ip_protocol("ip") == IPPROTO_IP);
+        assert_se(parse_ip_protocol("") == IPPROTO_IP);
+        assert_se(parse_ip_protocol("1") == 1);
+        assert_se(parse_ip_protocol("0") == 0);
+        assert_se(parse_ip_protocol("-10") == -ERANGE);
+        assert_se(parse_ip_protocol("100000000") == -EPROTONOSUPPORT);
+}
+
+TEST(parse_ip_protocol_full) {
+        assert_se(parse_ip_protocol_full("-1", true) == -ERANGE);
+        assert_se(parse_ip_protocol_full("0", true) == 0);
+        assert_se(parse_ip_protocol_full("11", true) == 11);
 }
 
 DEFINE_TEST_MAIN(LOG_INFO);