]> git.ipfire.org Git - thirdparty/wireguard-tools.git/commitdiff
ipc: cache windows lookups to avoid O(n^2) with nested lookups
authorJason A. Donenfeld <Jason@zx2c4.com>
Thu, 24 Jun 2021 11:35:48 +0000 (13:35 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Tue, 20 Jul 2021 11:24:18 +0000 (13:24 +0200)
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
src/ipc-uapi-windows.h
src/ipc-windows.h
src/wincompat/include/hashtable.h [new file with mode: 0644]

index 1aa08c4c8fac1ff67a778e346a4f919e8d287eb3..4d362d007338650372bd339631bfe1907b03b3d3 100644 (file)
@@ -10,6 +10,7 @@
 #include <stdio.h>
 #include <stdbool.h>
 #include <fcntl.h>
+#include <hashtable.h>
 
 static FILE *userspace_interface_file(const char *iface)
 {
@@ -113,6 +114,9 @@ err:
        return NULL;
 }
 
+static bool have_cached_interfaces;
+static struct hashtable cached_interfaces;
+
 static bool userspace_has_wireguard_interface(const char *iface)
 {
        char fname[MAX_PATH];
@@ -120,10 +124,13 @@ static bool userspace_has_wireguard_interface(const char *iface)
        HANDLE find_handle;
        bool ret = false;
 
+       if (have_cached_interfaces)
+               return hashtable_find_entry(&cached_interfaces, iface) != NULL;
+
        snprintf(fname, sizeof(fname), "ProtectedPrefix\\Administrators\\WireGuard\\%s", iface);
        find_handle = FindFirstFile("\\\\.\\pipe\\*", &find_data);
        if (find_handle == INVALID_HANDLE_VALUE)
-               return -GetLastError();
+               return -EIO;
        do {
                if (!strcmp(fname, find_data.cFileName)) {
                        ret = true;
@@ -139,18 +146,25 @@ static int userspace_get_wireguard_interfaces(struct string_list *list)
        static const char prefix[] = "ProtectedPrefix\\Administrators\\WireGuard\\";
        WIN32_FIND_DATA find_data;
        HANDLE find_handle;
+       char *iface;
        int ret = 0;
 
        find_handle = FindFirstFile("\\\\.\\pipe\\*", &find_data);
        if (find_handle == INVALID_HANDLE_VALUE)
-               return -GetLastError();
+               return -EIO;
        do {
                if (strncmp(prefix, find_data.cFileName, strlen(prefix)))
                        continue;
-               ret = string_list_add(list, find_data.cFileName + strlen(prefix));
+               iface = find_data.cFileName + strlen(prefix);
+               ret = string_list_add(list, iface);
                if (ret < 0)
                        goto out;
+               if (!hashtable_find_or_insert_entry(&cached_interfaces, iface)) {
+                       ret = -errno;
+                       goto out;
+               }
        } while (FindNextFile(find_handle, &find_data));
+       have_cached_interfaces = true;
 
 out:
        FindClose(find_handle);
index 14270c982ff2b0e1eb8a4d216b031294e9495c86..2382847d2d34be094d07f0e3fe325f0480c9170f 100644 (file)
 #include <ddk/ndisguid.h>
 #include <nci.h>
 #include <wireguard.h>
+#include <hashtable.h>
 
 #define IPC_SUPPORTS_KERNEL_INTERFACE
 
+static bool have_cached_kernel_interfaces;
+static struct hashtable cached_kernel_interfaces;
+
 static int kernel_get_wireguard_interfaces(struct string_list *list)
 {
        HDEVINFO dev_info = SetupDiGetClassDevsExW(&GUID_DEVCLASS_NET, NULL, NULL, DIGCF_PRESENT, NULL, NULL, NULL);
+       bool will_have_cached_kernel_interfaces = true;
 
        if (dev_info == INVALID_HANDLE_VALUE) {
                errno = EACCES;
@@ -33,6 +38,7 @@ static int kernel_get_wireguard_interfaces(struct string_list *list)
                HKEY key;
                GUID instance_id;
                char *interface_name;
+               struct hashtable_entry *entry;
 
                if (!SetupDiEnumDeviceInfo(dev_info, i, &dev_info_data)) {
                        if (GetLastError() == ERROR_NO_MORE_ITEMS)
@@ -105,7 +111,25 @@ static int kernel_get_wireguard_interfaces(struct string_list *list)
                }
 
                string_list_add(list, interface_name);
+
+               entry = hashtable_find_or_insert_entry(&cached_kernel_interfaces, interface_name);
                free(interface_name);
+               if (!entry)
+                       goto cleanup_entry;
+
+               if (SetupDiGetDeviceInstanceIdW(dev_info, &dev_info_data, NULL, 0, &buf_len) || GetLastError() != ERROR_INSUFFICIENT_BUFFER)
+                       goto cleanup_entry;
+               entry->value = calloc(sizeof(WCHAR), buf_len);
+               if (!entry->value)
+                       goto cleanup_entry;
+               if (!SetupDiGetDeviceInstanceIdW(dev_info, &dev_info_data, entry->value, buf_len, &buf_len)) {
+                       free(entry->value);
+                       entry->value = NULL;
+                       goto cleanup_entry;
+               }
+
+cleanup_entry:
+               will_have_cached_kernel_interfaces |= entry != NULL && entry->value != NULL;
 cleanup_buf:
                free(buf);
 cleanup_key:
@@ -113,15 +137,48 @@ cleanup_key:
 skip:;
        }
        SetupDiDestroyDeviceInfoList(dev_info);
+       have_cached_kernel_interfaces = will_have_cached_kernel_interfaces;
        return 0;
 }
 
 static HANDLE kernel_interface_handle(const char *iface)
 {
-       HDEVINFO dev_info = SetupDiGetClassDevsExW(&GUID_DEVCLASS_NET, NULL, NULL, DIGCF_PRESENT, NULL, NULL, NULL);
+       HDEVINFO dev_info;
        WCHAR *interfaces = NULL;
        HANDLE handle;
 
+       if (have_cached_kernel_interfaces) {
+               struct hashtable_entry *entry = hashtable_find_entry(&cached_kernel_interfaces, iface);
+               if (entry) {
+                       DWORD buf_len;
+                       if (CM_Get_Device_Interface_List_SizeW(
+                               &buf_len, (GUID *)&GUID_DEVINTERFACE_NET, (DEVINSTID_W)entry->value,
+                               CM_GET_DEVICE_INTERFACE_LIST_PRESENT) != CR_SUCCESS)
+                               goto err_hash;
+                       interfaces = calloc(buf_len, sizeof(*interfaces));
+                       if (!interfaces)
+                               goto err_hash;
+                       if (CM_Get_Device_Interface_ListW(
+                               (GUID *)&GUID_DEVINTERFACE_NET, (DEVINSTID_W)entry->value, interfaces, buf_len,
+                               CM_GET_DEVICE_INTERFACE_LIST_PRESENT) != CR_SUCCESS || !interfaces[0]) {
+                               free(interfaces);
+                               interfaces = NULL;
+                               goto err_hash;
+                       }
+                       handle = CreateFileW(interfaces, GENERIC_READ | GENERIC_WRITE,
+                                            FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, NULL,
+                                            OPEN_EXISTING, 0, NULL);
+                       free(interfaces);
+                       if (handle == INVALID_HANDLE_VALUE)
+                               goto err_hash;
+                       return handle;
+err_hash:
+                       errno = EACCES;
+                       return NULL;
+               }
+       }
+
+       dev_info = SetupDiGetClassDevsExW(&GUID_DEVCLASS_NET, NULL, NULL, DIGCF_PRESENT, NULL, NULL, NULL);
        if (dev_info == INVALID_HANDLE_VALUE)
                return NULL;
 
diff --git a/src/wincompat/include/hashtable.h b/src/wincompat/include/hashtable.h
new file mode 100644 (file)
index 0000000..bd83bbb
--- /dev/null
@@ -0,0 +1,61 @@
+/* SPDX-License-Identifier: GPL-2.0
+ *
+ * Copyright (C) 2018-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+#ifndef _HASHTABLE_H
+#define _HASHTABLE_H
+
+#include <string.h>
+
+enum { HASHTABLE_ENTRY_BUCKETS_POW2 = 1 << 10 };
+
+struct hashtable_entry {
+       char *key;
+       void *value;
+       struct hashtable_entry *next;
+};
+
+struct hashtable {
+       struct hashtable_entry *entry_buckets[HASHTABLE_ENTRY_BUCKETS_POW2];
+};
+
+static unsigned int hashtable_bucket(const char *str)
+{
+       unsigned long hash = 5381;
+       char c;
+       while ((c = *str++))
+               hash = ((hash << 5) + hash) ^ c;
+       return hash & (HASHTABLE_ENTRY_BUCKETS_POW2 - 1);
+}
+
+static struct hashtable_entry *hashtable_find_entry(struct hashtable *hashtable, const char *key)
+{
+       struct hashtable_entry *entry;
+       for (entry = hashtable->entry_buckets[hashtable_bucket(key)]; entry; entry = entry->next) {
+               if (!strcmp(entry->key, key))
+                       return entry;
+       }
+       return NULL;
+}
+
+static struct hashtable_entry *hashtable_find_or_insert_entry(struct hashtable *hashtable, const char *key)
+{
+       struct hashtable_entry **entry;
+       for (entry = &hashtable->entry_buckets[hashtable_bucket(key)]; *entry; entry = &(*entry)->next) {
+               if (!strcmp((*entry)->key, key))
+                       return *entry;
+       }
+       *entry = calloc(1, sizeof(**entry));
+       if (!*entry)
+               return NULL;
+       (*entry)->key = strdup(key);
+       if (!(*entry)->key) {
+               free(*entry);
+               *entry = NULL;
+               return NULL;
+       }
+       return *entry;
+}
+
+#endif