]> git.ipfire.org Git - thirdparty/bind9.git/commitdiff
Refactor OpenSSL RSA components getting to a helper function
authorTimo Teräs <timo.teras@iki.fi>
Mon, 26 Dec 2022 17:07:18 +0000 (19:07 +0200)
committerOndřej Surý <ondrej@isc.org>
Mon, 9 Jan 2023 14:55:07 +0000 (15:55 +0100)
lib/dns/opensslrsa_link.c

index acf963719ff02ca84efcc7e7a78bf5fb88054bb1..221e9e1723ac964e75212f7181de4b75930b434c 100644 (file)
                goto err; \
        }
 
+typedef struct rsa_components {
+       bool bnfree;
+       const BIGNUM *e, *n, *d, *p, *q, *dmp1, *dmq1, *iqmp;
+} rsa_components_t;
+
+static isc_result_t
+opensslrsa_components_get(const dst_key_t *key, rsa_components_t *c,
+                         bool private) {
+       REQUIRE(c->e == NULL && c->n == NULL && c->d == NULL && c->p == NULL &&
+               c->q == NULL && c->dmp1 == NULL && c->dmq1 == NULL &&
+               c->iqmp == NULL);
+
+       EVP_PKEY *pub = key->keydata.pkeypair.pub;
+       EVP_PKEY *priv = key->keydata.pkeypair.priv;
+
+       if (private && priv == NULL) {
+               return (DST_R_INVALIDPRIVATEKEY);
+       }
+#if OPENSSL_VERSION_NUMBER >= 0x30000000L
+       if (EVP_PKEY_get_bn_param(pub, OSSL_PKEY_PARAM_RSA_E,
+                                 (BIGNUM **)&c->e) == 1)
+       {
+               isc_result_t ret = ISC_R_UNSET;
+
+               c->bnfree = true;
+               if (EVP_PKEY_get_bn_param(pub, OSSL_PKEY_PARAM_RSA_N,
+                                         (BIGNUM **)&c->n) != 1)
+               {
+                       DST_RET(dst__openssl_toresult(DST_R_OPENSSLFAILURE));
+               }
+               if (!private) {
+                       return (ISC_R_SUCCESS);
+               }
+               if (EVP_PKEY_get_bn_param(priv, OSSL_PKEY_PARAM_RSA_D,
+                                         (BIGNUM **)&c->d) != 1)
+               {
+                       DST_RET(dst__openssl_toresult(DST_R_OPENSSLFAILURE));
+               }
+               if (EVP_PKEY_get_bn_param(priv, OSSL_PKEY_PARAM_RSA_FACTOR1,
+                                         (BIGNUM **)&c->p) != 1)
+               {
+                       DST_RET(dst__openssl_toresult(DST_R_OPENSSLFAILURE));
+               }
+               if (EVP_PKEY_get_bn_param(priv, OSSL_PKEY_PARAM_RSA_FACTOR2,
+                                         (BIGNUM **)&c->q) != 1)
+               {
+                       DST_RET(dst__openssl_toresult(DST_R_OPENSSLFAILURE));
+               }
+               if (EVP_PKEY_get_bn_param(priv, OSSL_PKEY_PARAM_RSA_EXPONENT1,
+                                         (BIGNUM **)&c->dmp1) != 1)
+               {
+                       DST_RET(dst__openssl_toresult(DST_R_OPENSSLFAILURE));
+               }
+               if (EVP_PKEY_get_bn_param(priv, OSSL_PKEY_PARAM_RSA_EXPONENT2,
+                                         (BIGNUM **)&c->dmq1) != 1)
+               {
+                       DST_RET(dst__openssl_toresult(DST_R_OPENSSLFAILURE));
+               }
+               if (EVP_PKEY_get_bn_param(priv,
+                                         OSSL_PKEY_PARAM_RSA_COEFFICIENT1,
+                                         (BIGNUM **)&c->iqmp) != 1)
+               {
+                       DST_RET(dst__openssl_toresult(DST_R_OPENSSLFAILURE));
+               }
+               return (ISC_R_SUCCESS);
+       err:
+               return (ret);
+       }
+#endif
+#if OPENSSL_VERSION_NUMBER < 0x30000000L || OPENSSL_API_LEVEL < 30000
+       const RSA *rsa = EVP_PKEY_get0_RSA(pub);
+       if (rsa == NULL) {
+               return (dst__openssl_toresult(DST_R_OPENSSLFAILURE));
+       }
+       RSA_get0_key(rsa, &c->n, &c->e, &c->d);
+       if (c->e == NULL || c->n == NULL) {
+               return (dst__openssl_toresult(DST_R_OPENSSLFAILURE));
+       }
+       if (!private) {
+               return (ISC_R_SUCCESS);
+       }
+       rsa = EVP_PKEY_get0_RSA(priv);
+       if (rsa == NULL) {
+               return (dst__openssl_toresult(DST_R_OPENSSLFAILURE));
+       }
+       RSA_get0_factors(rsa, &c->p, &c->q);
+       RSA_get0_crt_params(rsa, &c->dmp1, &c->dmq1, &c->iqmp);
+       return (ISC_R_SUCCESS);
+#else
+       return (DST_R_OPENSSLFAILURE);
+#endif
+}
+
+static void
+opensslrsa_components_free(rsa_components_t *c) {
+       if (!c->bnfree) {
+               return;
+       }
+       if (c->e != NULL) {
+               BN_free((BIGNUM *)c->e);
+       }
+       if (c->n != NULL) {
+               BN_free((BIGNUM *)c->n);
+       }
+       if (c->d != NULL) {
+               BN_clear_free((BIGNUM *)c->d);
+       }
+       if (c->p != NULL) {
+               BN_clear_free((BIGNUM *)c->p);
+       }
+       if (c->q != NULL) {
+               BN_clear_free((BIGNUM *)c->q);
+       }
+       if (c->dmp1 != NULL) {
+               BN_clear_free((BIGNUM *)c->dmp1);
+       }
+       if (c->dmq1 != NULL) {
+               BN_clear_free((BIGNUM *)c->dmq1);
+       }
+       if (c->iqmp != NULL) {
+               BN_clear_free((BIGNUM *)c->iqmp);
+       }
+}
+
 static bool
 opensslrsa_valid_key_alg(unsigned int key_alg) {
        switch (key_alg) {
@@ -456,35 +580,19 @@ opensslrsa_todns(const dst_key_t *key, isc_buffer_t *data) {
        unsigned int e_bytes;
        unsigned int mod_bytes;
        isc_result_t ret;
-       EVP_PKEY *pkey;
-#if OPENSSL_VERSION_NUMBER < 0x30000000L || OPENSSL_API_LEVEL < 30000
-       RSA *rsa;
-       const BIGNUM *e = NULL, *n = NULL;
-#else
-       BIGNUM *e = NULL, *n = NULL;
-#endif /* OPENSSL_VERSION_NUMBER < 0x30000000L || OPENSSL_API_LEVEL < 30000 */
+       rsa_components_t c = { 0 };
 
        REQUIRE(key->keydata.pkeypair.pub != NULL);
 
-       pkey = key->keydata.pkeypair.pub;
        isc_buffer_availableregion(data, &r);
 
-#if OPENSSL_VERSION_NUMBER < 0x30000000L || OPENSSL_API_LEVEL < 30000
-       rsa = EVP_PKEY_get1_RSA(pkey);
-       if (rsa == NULL) {
-               DST_RET(dst__openssl_toresult(DST_R_OPENSSLFAILURE));
-       }
-       RSA_get0_key(rsa, &n, &e, NULL);
-#else
-       EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_E, &e);
-       EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_N, &n);
-#endif /* OPENSSL_VERSION_NUMBER < 0x30000000L || OPENSSL_API_LEVEL < 30000 */
-       if (e == NULL || n == NULL) {
-               DST_RET(dst__openssl_toresult(DST_R_OPENSSLFAILURE));
+       ret = opensslrsa_components_get(key, &c, false);
+       if (ret != ISC_R_SUCCESS) {
+               goto err;
        }
 
-       mod_bytes = BN_num_bytes(n);
-       e_bytes = BN_num_bytes(e);
+       mod_bytes = BN_num_bytes(c.n);
+       e_bytes = BN_num_bytes(c.e);
 
        if (e_bytes < 256) { /*%< key exponent is <= 2040 bits */
                if (r.length < 1) {
@@ -505,27 +613,16 @@ opensslrsa_todns(const dst_key_t *key, isc_buffer_t *data) {
                DST_RET(ISC_R_NOSPACE);
        }
 
-       BN_bn2bin(e, r.base);
+       BN_bn2bin(c.e, r.base);
        isc_region_consume(&r, e_bytes);
-       BN_bn2bin(n, r.base);
+       BN_bn2bin(c.n, r.base);
        isc_region_consume(&r, mod_bytes);
 
        isc_buffer_add(data, e_bytes + mod_bytes);
 
        ret = ISC_R_SUCCESS;
 err:
-#if OPENSSL_VERSION_NUMBER < 0x30000000L || OPENSSL_API_LEVEL < 30000
-       if (rsa != NULL) {
-               RSA_free(rsa);
-       }
-#else
-       if (e != NULL) {
-               BN_free(e);
-       }
-       if (n != NULL) {
-               BN_free(n);
-       }
-#endif /* OPENSSL_VERSION_NUMBER < 0x30000000L || OPENSSL_API_LEVEL < 30000 */
+       opensslrsa_components_free(&c);
        return (ret);
 }
 
@@ -684,123 +781,87 @@ opensslrsa_tofile(const dst_key_t *key, const char *directory) {
        dst_private_t priv = { 0 };
        unsigned char *bufs[8] = { NULL };
        unsigned short i = 0;
-       EVP_PKEY *pkey;
-#if OPENSSL_VERSION_NUMBER < 0x30000000L || OPENSSL_API_LEVEL < 30000
-       RSA *rsa = NULL;
-       const BIGNUM *n = NULL, *e = NULL, *d = NULL;
-       const BIGNUM *p = NULL, *q = NULL;
-       const BIGNUM *dmp1 = NULL, *dmq1 = NULL, *iqmp = NULL;
-#else
-       BIGNUM *n = NULL, *e = NULL, *d = NULL;
-       BIGNUM *p = NULL, *q = NULL;
-       BIGNUM *dmp1 = NULL, *dmq1 = NULL, *iqmp = NULL;
-#endif /* OPENSSL_VERSION_NUMBER < 0x30000000L || OPENSSL_API_LEVEL < 30000 */
-
-       if (key->keydata.pkeypair.priv != NULL) {
-               pkey = key->keydata.pkeypair.priv;
-       } else if (key->keydata.pkeypair.pub != NULL) {
-               pkey = key->keydata.pkeypair.pub;
-       } else {
-               DST_RET(DST_R_NULLKEY);
-       }
+       rsa_components_t c = { 0 };
 
        if (key->external) {
                return (dst__privstruct_writefile(key, &priv, directory));
        }
 
-#if OPENSSL_VERSION_NUMBER < 0x30000000L || OPENSSL_API_LEVEL < 30000
-       rsa = EVP_PKEY_get1_RSA(pkey);
-       if (rsa == NULL) {
-               DST_RET(dst__openssl_toresult(DST_R_OPENSSLFAILURE));
-       }
-       RSA_get0_key(rsa, &n, &e, &d);
-       RSA_get0_factors(rsa, &p, &q);
-       RSA_get0_crt_params(rsa, &dmp1, &dmq1, &iqmp);
-#else
-       EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_N, &n);
-       EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_E, &e);
-       EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_D, &d);
-       EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_FACTOR1, &p);
-       EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_FACTOR2, &q);
-       EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_EXPONENT1, &dmp1);
-       EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_EXPONENT2, &dmq1);
-       EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_COEFFICIENT1, &iqmp);
-#endif /* OPENSSL_VERSION_NUMBER < 0x30000000L || OPENSSL_API_LEVEL < 30000 */
-
-       if (n == NULL || e == NULL) {
-               DST_RET(dst__openssl_toresult(DST_R_OPENSSLFAILURE));
+       ret = opensslrsa_components_get(key, &c, true);
+       if (ret != ISC_R_SUCCESS) {
+               goto err;
        }
 
        priv.elements[i].tag = TAG_RSA_MODULUS;
-       priv.elements[i].length = BN_num_bytes(n);
+       priv.elements[i].length = BN_num_bytes(c.n);
        bufs[i] = isc_mem_get(key->mctx, priv.elements[i].length);
-       BN_bn2bin(n, bufs[i]);
+       BN_bn2bin(c.n, bufs[i]);
        priv.elements[i].data = bufs[i];
        i++;
 
        priv.elements[i].tag = TAG_RSA_PUBLICEXPONENT;
-       priv.elements[i].length = BN_num_bytes(e);
+       priv.elements[i].length = BN_num_bytes(c.e);
        bufs[i] = isc_mem_get(key->mctx, priv.elements[i].length);
-       BN_bn2bin(e, bufs[i]);
+       BN_bn2bin(c.e, bufs[i]);
        priv.elements[i].data = bufs[i];
        i++;
 
-       if (d != NULL) {
+       if (c.d != NULL) {
                priv.elements[i].tag = TAG_RSA_PRIVATEEXPONENT;
-               priv.elements[i].length = BN_num_bytes(d);
+               priv.elements[i].length = BN_num_bytes(c.d);
                INSIST(i < ARRAY_SIZE(bufs));
                bufs[i] = isc_mem_get(key->mctx, priv.elements[i].length);
-               BN_bn2bin(d, bufs[i]);
+               BN_bn2bin(c.d, bufs[i]);
                priv.elements[i].data = bufs[i];
                i++;
        }
 
-       if (p != NULL) {
+       if (c.p != NULL) {
                priv.elements[i].tag = TAG_RSA_PRIME1;
-               priv.elements[i].length = BN_num_bytes(p);
+               priv.elements[i].length = BN_num_bytes(c.p);
                INSIST(i < ARRAY_SIZE(bufs));
                bufs[i] = isc_mem_get(key->mctx, priv.elements[i].length);
-               BN_bn2bin(p, bufs[i]);
+               BN_bn2bin(c.p, bufs[i]);
                priv.elements[i].data = bufs[i];
                i++;
        }
 
-       if (q != NULL) {
+       if (c.q != NULL) {
                priv.elements[i].tag = TAG_RSA_PRIME2;
-               priv.elements[i].length = BN_num_bytes(q);
+               priv.elements[i].length = BN_num_bytes(c.q);
                INSIST(i < ARRAY_SIZE(bufs));
                bufs[i] = isc_mem_get(key->mctx, priv.elements[i].length);
-               BN_bn2bin(q, bufs[i]);
+               BN_bn2bin(c.q, bufs[i]);
                priv.elements[i].data = bufs[i];
                i++;
        }
 
-       if (dmp1 != NULL) {
+       if (c.dmp1 != NULL) {
                priv.elements[i].tag = TAG_RSA_EXPONENT1;
-               priv.elements[i].length = BN_num_bytes(dmp1);
+               priv.elements[i].length = BN_num_bytes(c.dmp1);
                INSIST(i < ARRAY_SIZE(bufs));
                bufs[i] = isc_mem_get(key->mctx, priv.elements[i].length);
-               BN_bn2bin(dmp1, bufs[i]);
+               BN_bn2bin(c.dmp1, bufs[i]);
                priv.elements[i].data = bufs[i];
                i++;
        }
 
-       if (dmq1 != NULL) {
+       if (c.dmq1 != NULL) {
                priv.elements[i].tag = TAG_RSA_EXPONENT2;
-               priv.elements[i].length = BN_num_bytes(dmq1);
+               priv.elements[i].length = BN_num_bytes(c.dmq1);
                INSIST(i < ARRAY_SIZE(bufs));
                bufs[i] = isc_mem_get(key->mctx, priv.elements[i].length);
-               BN_bn2bin(dmq1, bufs[i]);
+               BN_bn2bin(c.dmq1, bufs[i]);
                priv.elements[i].data = bufs[i];
                i++;
        }
 
-       if (iqmp != NULL) {
+       if (c.iqmp != NULL) {
                priv.elements[i].tag = TAG_RSA_COEFFICIENT;
-               priv.elements[i].length = BN_num_bytes(iqmp);
+               priv.elements[i].length = BN_num_bytes(c.iqmp);
                INSIST(i < ARRAY_SIZE(bufs));
                bufs[i] = isc_mem_get(key->mctx, priv.elements[i].length);
-               BN_bn2bin(iqmp, bufs[i]);
+               BN_bn2bin(c.iqmp, bufs[i]);
                priv.elements[i].data = bufs[i];
                i++;
        }
@@ -831,34 +892,7 @@ err:
                                    priv.elements[i].length);
                }
        }
-#if OPENSSL_VERSION_NUMBER < 0x30000000L || OPENSSL_API_LEVEL < 30000
-       RSA_free(rsa);
-#else
-       if (n != NULL) {
-               BN_free(n);
-       }
-       if (e != NULL) {
-               BN_free(e);
-       }
-       if (d != NULL) {
-               BN_clear_free(d);
-       }
-       if (p != NULL) {
-               BN_clear_free(p);
-       }
-       if (q != NULL) {
-               BN_clear_free(q);
-       }
-       if (dmp1 != NULL) {
-               BN_clear_free(dmp1);
-       }
-       if (dmq1 != NULL) {
-               BN_clear_free(dmq1);
-       }
-       if (iqmp != NULL) {
-               BN_clear_free(iqmp);
-       }
-#endif /* OPENSSL_VERSION_NUMBER < 0x30000000L || OPENSSL_API_LEVEL < 30000 */
+       opensslrsa_components_free(&c);
 
        return (ret);
 }