]> git.ipfire.org Git - thirdparty/squid.git/blobdiff - helpers/negotiate_auth/kerberos/negotiate_kerberos_auth.cc
SourceFormat Enforcement
[thirdparty/squid.git] / helpers / negotiate_auth / kerberos / negotiate_kerberos_auth.cc
index 93aa1d13394ce6d40f964d15ae2c9c47279a3fb1..270d9ca9e49e9149eb7c88419efc7113c64fa5aa 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 1996-2014 The Squid Software Foundation and contributors
+ * Copyright (C) 1996-2015 The Squid Software Foundation and contributors
  *
  * Squid software is distributed under GPLv2+ license and includes
  * contributions from numerous individuals and organizations.
@@ -36,8 +36,6 @@
  */
 
 #include "squid.h"
-#include "compat/getaddrinfo.h"
-#include "compat/getnameinfo.h"
 #include "rfc1738.h"
 
 #if HAVE_GSSAPI
@@ -72,7 +70,7 @@ int
 check_k5_err(krb5_context context, const char *function, krb5_error_code code)
 {
 
-    if (code) {
+    if (code && code != KRB5_KT_END) {
         const char *errmsg;
         errmsg = krb5_get_error_message(context, code);
         debug((char *) "%s| %s: ERROR: %s failed: %s\n", LogTime(), PROGRAM, function, errmsg);
@@ -392,12 +390,12 @@ main(int argc, char *const argv[])
             struct stat fstat;
             char *ktp;
 #endif
-           if (optarg)
+            if (optarg)
                 keytab_name = xstrdup(optarg);
-           else {
+            else {
                 fprintf(stderr, "ERROR: keytab file not given\n");
-               exit(1);
-           }
+                exit(1);
+            }
             /*
              * Some sanity checks
              */
@@ -428,12 +426,12 @@ main(int argc, char *const argv[])
 #if HAVE_SYS_STAT_H
             struct stat dstat;
 #endif
-           if (optarg)
+            if (optarg)
                 rcache_dir = xstrdup(optarg);
-           else {
+            else {
                 fprintf(stderr, "ERROR: replay cache directory not given\n");
-               exit(1);
-           }
+                exit(1);
+            }
             /*
              * Some sanity checks
              */
@@ -457,20 +455,20 @@ main(int argc, char *const argv[])
 #endif
             break;
         case 't':
-           if (optarg)
+            if (optarg)
                 rcache_type = xstrdup(optarg);
-           else {
+            else {
                 fprintf(stderr, "ERROR: replay cache type not given\n");
-               exit(1);
-           }
+                exit(1);
+            }
             break;
         case 's':
-           if (optarg)
+            if (optarg)
                 service_principal = xstrdup(optarg);
-           else {
+            else {
                 fprintf(stderr, "ERROR: service principal not given\n");
-               exit(1);
-           }
+                exit(1);
+            }
             break;
         default:
             fprintf(stderr, "Usage: \n");
@@ -570,6 +568,11 @@ main(int argc, char *const argv[])
                       LogTime(), PROGRAM, memory_keytab_name);
             }
         }
+        ret = krb5_free_kt_list(context,ktlist);
+        if (check_k5_err(context, "krb5_free_kt_list", ret)) {
+            debug((char *) "%s| %s: ERROR: Freeing list failed\n",
+                  LogTime(), PROGRAM);
+        }
     }
     krb5_free_context(context);
 #endif
@@ -584,7 +587,7 @@ main(int argc, char *const argv[])
                       strerror(ferror(stdin)));
 
                 fprintf(stdout, "BH input error\n");
-                exit(1);       /* BIIG buffer */
+                exit(1);    /* BIIG buffer */
             }
             fprintf(stdout, "BH input error\n");
             exit(0);
@@ -653,12 +656,23 @@ main(int argc, char *const argv[])
             fprintf(stdout, "BH Invalid negotiate request\n");
             continue;
         }
-        input_token.length = (size_t)base64_decode_len(buf+3);
-        debug((char *) "%s| %s: DEBUG: Decode '%s' (decoded length: %d).\n",
-              LogTime(), PROGRAM, buf + 3, (int) input_token.length);
+        const uint8_t *b64Token = reinterpret_cast<const uint8_t*>(buf+3);
+        const size_t srcLen = strlen(buf+3);
+        input_token.length = BASE64_DECODE_LENGTH(srcLen);
+        debug((char *) "%s| %s: DEBUG: Decode '%s' (decoded length estimate: %d).\n",
+              LogTime(), PROGRAM, b64Token, (int) input_token.length);
         input_token.value = xmalloc(input_token.length);
 
-        input_token.length = (size_t)base64_decode((char *) input_token.value, (unsigned int)input_token.length, buf+3);
+        struct base64_decode_ctx ctx;
+        base64_decode_init(&ctx);
+        size_t dstLen = 0;
+        if (!base64_decode_update(&ctx, &dstLen, static_cast<uint8_t*>(input_token.value), srcLen, b64Token) ||
+                !base64_decode_final(&ctx)) {
+            debug((char *) "%s| %s: ERROR: Invalid base64 token [%s]\n", LogTime(), PROGRAM, b64Token);
+            fprintf(stdout, "BH Invalid negotiate request token\n");
+            continue;
+        }
+        input_token.length = dstLen;
 
         if ((input_token.length >= sizeof ntlmProtocol + 1) &&
                 (!memcmp(input_token.value, ntlmProtocol, sizeof ntlmProtocol))) {
@@ -705,14 +719,17 @@ main(int argc, char *const argv[])
         if (output_token.length) {
             spnegoToken = (const unsigned char *) output_token.value;
             spnegoTokenLength = output_token.length;
-            token = (char *) xmalloc((size_t)base64_encode_len((int)spnegoTokenLength));
+            token = (char *) xmalloc((size_t)base64_encode_len(spnegoTokenLength));
             if (token == NULL) {
                 debug((char *) "%s| %s: ERROR: Not enough memory\n", LogTime(), PROGRAM);
                 fprintf(stdout, "BH Not enough memory\n");
                 goto cleanup;
             }
-            base64_encode_str(token, base64_encode_len((int)spnegoTokenLength),
-                              (const char *) spnegoToken, (int)spnegoTokenLength);
+            struct base64_encode_ctx tokCtx;
+            base64_encode_init(&tokCtx);
+            size_t blen = base64_encode_update(&tokCtx, reinterpret_cast<uint8_t*>(token), spnegoTokenLength, reinterpret_cast<const uint8_t*>(spnegoToken));
+            blen += base64_encode_final(&tokCtx, reinterpret_cast<uint8_t*>(token)+blen);
+            token[blen] = '\0';
 
             if (check_gss_err(major_status, minor_status, "gss_accept_sec_context()", log, 1))
                 goto cleanup;
@@ -866,3 +883,4 @@ main(int argc, char *const argv[])
     }
 }
 #endif /* HAVE_GSSAPI */
+