]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - pdns/dns_random.cc
Merge pull request #8223 from PowerDNS/omoerbeek-patch-1
[thirdparty/pdns.git] / pdns / dns_random.cc
index 4357af56258f803b7478fa8cba1163d14965014c..48b910c8f4f85d7c2a561f664de3c52fc37dd922 100644 (file)
@@ -210,7 +210,7 @@ void dns_random_init(const string& data __attribute__((unused)), bool force) {
 }
 
 /* Parts of this code come from arc4random_uniform */
-unsigned int dns_random(unsigned int upper_bound) {
+uint32_t dns_random(uint32_t upper_bound) {
   if (chosen_rng == RNG_UNINITIALIZED)
     dns_random_setup();
 
@@ -241,17 +241,19 @@ unsigned int dns_random(unsigned int upper_bound) {
     throw std::runtime_error("Unreachable at " __FILE__ ":" + boost::lexical_cast<std::string>(__LINE__)); // cannot be reached
   case RNG_SODIUM:
 #if defined(HAVE_RANDOMBYTES_STIR) && !defined(USE_URANDOM_ONLY)
-    return static_cast<unsigned int>(randombytes_uniform(static_cast<uint32_t>(upper_bound)));
+    return randombytes_uniform(upper_bound);
 #else
     throw std::runtime_error("Unreachable at " __FILE__ ":" + boost::lexical_cast<std::string>(__LINE__)); // cannot be reached
 #endif /* RND_SODIUM */
   case RNG_OPENSSL: {
 #if defined(HAVE_RAND_BYTES) && !defined(USE_URANDOM_ONLY)
-      unsigned int num=0;
-      while(num < min) {
+      uint32_t num = 0;
+      do {
         if (RAND_bytes(reinterpret_cast<unsigned char*>(&num), sizeof(num)) < 1)
           throw std::runtime_error("Openssl RNG was not seeded");
       }
+      while(num < min);
+
       return num % upper_bound;
 #else
       throw std::runtime_error("Unreachable at " __FILE__ ":" + boost::lexical_cast<std::string>(__LINE__)); // cannot be reached
@@ -259,11 +261,13 @@ unsigned int dns_random(unsigned int upper_bound) {
      }
   case RNG_GETRANDOM: {
 #if defined(HAVE_GETRANDOM) && !defined(USE_URANDOM_ONLY)
-      unsigned int num=0;
-      while(num < min) {
+      uint32_t num = 0;
+      do {
         if (getrandom(&num, sizeof(num), 0) != sizeof(num))
           throw std::runtime_error("getrandom() failed: " + std::string(strerror(errno)));
       }
+      while(num < min);
+
       return num % upper_bound;
 #else
       throw std::runtime_error("Unreachable at " __FILE__ ":" + boost::lexical_cast<std::string>(__LINE__)); // cannot be reached
@@ -271,25 +275,44 @@ unsigned int dns_random(unsigned int upper_bound) {
       }
   case RNG_ARC4RANDOM:
 #if defined(HAVE_ARC4RANDOM) && !defined(USE_URANDOM_ONLY)
-    return static_cast<unsigned int>(arc4random_uniform(static_cast<uint32_t>(upper_bound)));
+    return arc4random_uniform(upper_bound);
 #else
     throw std::runtime_error("Unreachable at " __FILE__ ":" + boost::lexical_cast<std::string>(__LINE__)); // cannot be reached
 #endif
   case RNG_URANDOM: {
-      unsigned int num = 0;
-      while(num < min) {
-        if (read(urandom_fd, &num, sizeof(num)) < 0) {
+      uint32_t num = 0;
+      size_t attempts = 5;
+      do {
+        ssize_t got = read(urandom_fd, &num, sizeof(num));
+        if (got < 0) {
+          if (errno == EINTR) {
+            continue;
+          }
+
           (void)close(urandom_fd);
           throw std::runtime_error("Cannot read random device");
         }
+        else if (static_cast<size_t>(got) != sizeof(num)) {
+          /* short read, let's retry */
+          if (attempts == 0) {
+            throw std::runtime_error("Too many short reads on random device");
+          }
+          attempts--;
+          continue;
+        }
       }
+      while(num < min);
+
       return num % upper_bound;
     }
 #if defined(HAVE_KISS_RNG)
   case RNG_KISS: {
-      unsigned int num = 0;
-      while(num < min)
+      uint32_t num = 0;
+      do {
         num = kiss_rand();
+      }
+      while(num < min);
+
       return num % upper_bound;
     }
 #endif