]> git.ipfire.org Git - thirdparty/squid.git/blobdiff - helpers/negotiate_auth/wrapper/negotiate_wrapper.cc
SourceFormat Enforcement
[thirdparty/squid.git] / helpers / negotiate_auth / wrapper / negotiate_wrapper.cc
index 6570d1ace6a7f451865c80d58872cecd63f0cb6f..0d90701ceae3b5e03cdb4e1365f03d2923f725cd 100644 (file)
@@ -1,3 +1,11 @@
+/*
+ * Copyright (C) 1996-2015 The Squid Software Foundation and contributors
+ *
+ * Squid software is distributed under GPLv2+ license and includes
+ * contributions from numerous individuals and organizations.
+ * Please see the COPYING and CONTRIBUTORS files for details.
+ */
+
 /*
  * -----------------------------------------------------------------------------
  *
  *
  * -----------------------------------------------------------------------------
  */
-/*
- * Hosted at http://sourceforge.net/projects/squidkerbauth
- */
 
-#include "config.h"
-#include "nw_base64.h"
+#include "squid.h"
+#include "base64.h"
 
-#if HAVE_STRING_H
-#include <string.h>
-#endif
-#if HAVE_STDIO_H
-#include <stdio.h>
-#endif
-#if HAVE_STDLIB_H
-#include <stdlib.h>
-#endif
+#include <cerrno>
+#include <cstring>
+#include <cstdlib>
+#include <ctime>
 #if HAVE_NETDB_H
 #include <netdb.h>
 #endif
 #if HAVE_UNISTD_H
 #include <unistd.h>
 #endif
-#if HAVE_TIME_H
-#include <time.h>
-#endif
-#if HAVE_SYS_TIME_H
-#include <sys/time.h>
-#endif
-#if HAVE_ERRNO_H
-#include <errno.h>
-#endif
 
 #if !defined(HAVE_DECL_XMALLOC) || !HAVE_DECL_XMALLOC
 #define xmalloc malloc
@@ -108,14 +99,12 @@ main(int argc, char *const argv[])
     char tbuff[MAX_AUTHTOKEN_LEN];
     char buff[MAX_AUTHTOKEN_LEN+2];
     char *c;
-    static int err = 0;
     int debug = 0;
     int length;
     int nstart = 0, kstart = 0;
     int nend = 0, kend = 0;
-    char *token;
+    uint8_t *token;
     char **nargs, **kargs;
-    int i,j;
     int fpid;
     FILE *FDKIN,*FDKOUT;
     FILE *FDNIN,*FDNOUT;
@@ -132,13 +121,13 @@ main(int argc, char *const argv[])
         return 0;
     }
 
-    j = 1;
+    int j = 1;
     if (!strncasecmp(argv[1],"-d",2)) {
         debug = 1;
         j = 2;
     }
 
-    for (i=j; i<argc; i++) {
+    for (int i=j; i<argc; ++i) {
         if (!strncasecmp(argv[i],"--ntlm",6))
             nstart = i;
         if (!strncasecmp(argv[i],"--kerberos",10))
@@ -168,7 +157,7 @@ main(int argc, char *const argv[])
     nargs[nend-nstart]=NULL;
     if (debug) {
         fprintf(stderr, "%s| %s: NTLM command: ", LogTime(), PROGRAM);
-        for (i=0; i<nend-nstart; i++)
+        for (int i=0; i<nend-nstart; ++i)
             fprintf(stderr, "%s ", nargs[i]);
         fprintf(stderr, "\n");
     }
@@ -180,7 +169,7 @@ main(int argc, char *const argv[])
     kargs[kend-kstart]=NULL;
     if (debug) {
         fprintf(stderr, "%s| %s: Kerberos command: ", LogTime(), PROGRAM);
-        for (i=0; i<kend-kstart; i++)
+        for (int i=0; i<kend-kstart; ++i)
             fprintf(stderr, "%s ", kargs[i]);
         fprintf(stderr, "\n");
     }
@@ -198,7 +187,6 @@ main(int argc, char *const argv[])
         return 1;
     }
 
-
     if  (( fpid = vfork()) < 0 ) {
         fprintf(stderr, "%s| %s: Failed first fork\n", LogTime(), PROGRAM);
         return 1;
@@ -279,7 +267,6 @@ main(int argc, char *const argv[])
     setbuf(FDNIN, NULL);
     setbuf(FDNOUT, NULL);
 
-
     while (1) {
         if (fgets(buf, sizeof(buf) - 1, stdin) == NULL) {
             if (ferror(stdin)) {
@@ -299,20 +286,16 @@ main(int argc, char *const argv[])
         if (c) {
             *c = '\0';
             length = c - buf;
+            if (debug)
+                fprintf(stderr, "%s| %s: Got '%s' from squid (length: %d).\n",
+                        LogTime(), PROGRAM, buf, length);
         } else {
-            err = 1;
-        }
-        if (err) {
             if (debug)
                 fprintf(stderr, "%s| %s: Oversized message\n", LogTime(),
                         PROGRAM);
             fprintf(stdout, "BH Oversized message\n");
-            err = 0;
             continue;
         }
-        if (debug)
-            fprintf(stderr, "%s| %s: Got '%s' from squid (length: %d).\n",
-                    LogTime(), PROGRAM, buf, length);
 
         if (buf[0] == '\0') {
             if (debug)
@@ -346,17 +329,28 @@ main(int argc, char *const argv[])
             fprintf(stdout, "BH Invalid negotiate request\n");
             continue;
         }
-        length = nw_base64_decode_len(buf + 3);
+        length = BASE64_DECODE_LENGTH(strlen(buf+3));
         if (debug)
             fprintf(stderr, "%s| %s: Decode '%s' (decoded length: %d).\n",
                     LogTime(), PROGRAM, buf + 3, (int) length);
 
-        if ((token = (char *)xmalloc(length)) == NULL) {
+        if ((token = static_cast<uint8_t *>(xmalloc(length))) == NULL) {
             fprintf(stderr, "%s| %s: Error allocating memory for token\n", LogTime(), PROGRAM);
             return 1;
         }
 
-        nw_base64_decode(token, buf + 3, length);
+        struct base64_decode_ctx ctx;
+        base64_decode_init(&ctx);
+        size_t dstLen = 0;
+        if (!base64_decode_update(&ctx, &dstLen, token, strlen(buf+3), reinterpret_cast<const uint8_t*>(buf+3)) ||
+                !base64_decode_final(&ctx)) {
+            if (debug)
+                fprintf(stderr, "%s| %s: Invalid base64 token [%s]\n", LogTime(), PROGRAM, buf+3);
+            fprintf(stdout, "BH Invalid negotiate request token\n");
+            continue;
+        }
+        length = dstLen;
+        token[dstLen] = '\0';
 
         if ((static_cast<size_t>(length) >= sizeof(ntlmProtocol) + 1) &&
                 (!memcmp(token, ntlmProtocol, sizeof ntlmProtocol))) {
@@ -386,13 +380,13 @@ main(int argc, char *const argv[])
             if (strlen(tbuff) >= 3 && (!strncmp(tbuff,"AF ",3) || !strncmp(tbuff,"NA ",3))) {
                 strncpy(buff,tbuff,3);
                 buff[3]='=';
-                for (unsigned int i=2; i<=strlen(tbuff); i++)
+                for (unsigned int i=2; i<=strlen(tbuff); ++i)
                     buff[i+2] = tbuff[i];
             } else {
                 strcpy(buff,tbuff);
             }
         } else {
-            free(token);
+            xfree(token);
             if (debug)
                 fprintf(stderr, "%s| %s: received Kerberos token\n",
                         LogTime(), PROGRAM);
@@ -418,3 +412,4 @@ main(int argc, char *const argv[])
 
     return 1;
 }
+