]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
conn: implement RIO for fast Windows UDP sockets
authorJason A. Donenfeld <Jason@zx2c4.com>
Mon, 22 Feb 2021 17:47:41 +0000 (18:47 +0100)
committerJason A. Donenfeld <Jason@zx2c4.com>
Thu, 25 Feb 2021 14:08:08 +0000 (15:08 +0100)
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
conn/bind_std.go
conn/bind_windows.go [new file with mode: 0644]
conn/boundif_windows.go [deleted file]
conn/default.go
conn/winrio/rio_windows.go [new file with mode: 0644]
device/queueconstants_default.go
device/queueconstants_windows.go [new file with mode: 0644]
go.mod
go.sum

index 193c4fed3313395b8771c0c6f3355dae2f0d3b26..28d14643dee2cb47d8b451a1673194b064770dd8 100644 (file)
@@ -128,6 +128,8 @@ func (bind *StdNetBind) Close() error {
                err2 = bind.ipv6.Close()
                bind.ipv6 = nil
        }
+       bind.blackhole4 = false
+       bind.blackhole6 = false
        if err1 != nil {
                return err1
        }
diff --git a/conn/bind_windows.go b/conn/bind_windows.go
new file mode 100644 (file)
index 0000000..1e2712e
--- /dev/null
@@ -0,0 +1,581 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+       "encoding/binary"
+       "io"
+       "net"
+       "strconv"
+       "sync"
+       "sync/atomic"
+       "unsafe"
+
+       "golang.org/x/sys/windows"
+
+       "golang.zx2c4.com/wireguard/conn/winrio"
+)
+
+const (
+       packetsPerRing = 1024
+       bytesPerPacket = 2048 - 32
+       receiveSpins   = 15
+)
+
+type ringPacket struct {
+       addr WinRingEndpoint
+       data [bytesPerPacket]byte
+}
+
+type ringBuffer struct {
+       packets    uintptr
+       head, tail uint32
+       id         winrio.BufferId
+       iocp       windows.Handle
+       isFull     bool
+       cq         winrio.Cq
+       mu         sync.Mutex
+       overlapped windows.Overlapped
+}
+
+func (rb *ringBuffer) Push() *ringPacket {
+       for rb.isFull {
+               panic("ring is full")
+       }
+       ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
+       rb.tail += 1
+       if rb.tail == rb.head {
+               rb.isFull = true
+       }
+       return ret
+}
+
+func (rb *ringBuffer) Return(count uint32) {
+       if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull {
+               return
+       }
+       rb.head += count
+       rb.isFull = false
+}
+
+type afWinRingBind struct {
+       sock      windows.Handle
+       rx, tx    ringBuffer
+       rq        winrio.Rq
+       mu        sync.Mutex
+       blackhole bool
+}
+
+// WinRingBind uses Windows registered I/O for fast ring buffered networking.
+type WinRingBind struct {
+       v4, v6 afWinRingBind
+       mu     sync.RWMutex
+       isOpen uint32
+}
+
+func NewDefaultBind() Bind { return NewWinRingBind() }
+
+func NewWinRingBind() Bind {
+       if !winrio.Initialize() {
+               return NewStdNetBind()
+       }
+       return new(WinRingBind)
+}
+
+type WinRingEndpoint struct {
+       family uint16
+       data   [30]byte
+}
+
+var _ Bind = (*WinRingBind)(nil)
+var _ Endpoint = (*WinRingEndpoint)(nil)
+
+func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
+       host, port, err := net.SplitHostPort(s)
+       if err != nil {
+               return nil, err
+       }
+       host16, err := windows.UTF16PtrFromString(host)
+       if err != nil {
+               return nil, err
+       }
+       port16, err := windows.UTF16PtrFromString(port)
+       if err != nil {
+               return nil, err
+       }
+       hints := windows.AddrinfoW{
+               Flags:    windows.AI_NUMERICHOST,
+               Family:   windows.AF_UNSPEC,
+               Socktype: windows.SOCK_DGRAM,
+               Protocol: windows.IPPROTO_UDP,
+       }
+       var addrinfo *windows.AddrinfoW
+       err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo)
+       if err != nil {
+               return nil, err
+       }
+       defer windows.FreeAddrInfoW(addrinfo)
+       if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) {
+               return nil, windows.ERROR_INVALID_ADDRESS
+       }
+       var src []byte
+       var dst [unsafe.Sizeof(WinRingEndpoint{})]byte
+       unsafeSlice(unsafe.Pointer(&src), unsafe.Pointer(addrinfo.Addr), int(addrinfo.Addrlen))
+       copy(dst[:], src)
+       return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil
+}
+
+func (*WinRingEndpoint) ClearSrc() {}
+
+func (e *WinRingEndpoint) DstIP() net.IP {
+       switch e.family {
+       case windows.AF_INET:
+               return append([]byte{}, e.data[2:6]...)
+       case windows.AF_INET6:
+               return append([]byte{}, e.data[6:22]...)
+       }
+       return nil
+}
+
+func (e *WinRingEndpoint) SrcIP() net.IP {
+       return nil // not supported
+}
+
+func (e *WinRingEndpoint) DstToBytes() []byte {
+       switch e.family {
+       case windows.AF_INET:
+               b := make([]byte, 0, 6)
+               b = append(b, e.data[2:6]...)
+               b = append(b, e.data[1], e.data[0])
+               return b
+       case windows.AF_INET6:
+               b := make([]byte, 0, 18)
+               b = append(b, e.data[6:22]...)
+               b = append(b, e.data[1], e.data[0])
+               return b
+       }
+       return nil
+}
+
+func (e *WinRingEndpoint) DstToString() string {
+       switch e.family {
+       case windows.AF_INET:
+               addr := net.UDPAddr{IP: e.data[2:6], Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
+               return addr.String()
+       case windows.AF_INET6:
+               var zone string
+               if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
+                       zone = strconv.FormatUint(uint64(scope), 10)
+               }
+               addr := net.UDPAddr{IP: e.data[6:22], Zone: zone, Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
+               return addr.String()
+       }
+       return ""
+}
+
+func (e *WinRingEndpoint) SrcToString() string {
+       return ""
+}
+
+func (ring *ringBuffer) CloseAndZero() {
+       if ring.cq != 0 {
+               winrio.CloseCompletionQueue(ring.cq)
+               ring.cq = 0
+       }
+       if ring.iocp != 0 {
+               windows.CloseHandle(ring.iocp)
+               ring.iocp = 0
+       }
+       if ring.id != 0 {
+               winrio.DeregisterBuffer(ring.id)
+               ring.id = 0
+       }
+       if ring.packets != 0 {
+               windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
+               ring.packets = 0
+       }
+}
+
+func (bind *afWinRingBind) CloseAndZero() {
+       bind.rx.CloseAndZero()
+       bind.tx.CloseAndZero()
+       if bind.sock != 0 {
+               windows.CloseHandle(bind.sock)
+               bind.sock = 0
+       }
+       bind.blackhole = false
+}
+
+func (bind *WinRingBind) closeAndZero() {
+       atomic.StoreUint32(&bind.isOpen, 0)
+       bind.v4.CloseAndZero()
+       bind.v6.CloseAndZero()
+}
+
+func (ring *ringBuffer) Open() error {
+       var err error
+       packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing
+       ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
+       if err != nil {
+               return err
+       }
+       ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen))
+       if err != nil {
+               return err
+       }
+       ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
+       if err != nil {
+               return err
+       }
+       ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped)
+       if err != nil {
+               return err
+       }
+       return nil
+}
+
+func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) {
+       var err error
+       bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
+       if err != nil {
+               return nil, err
+       }
+       err = bind.rx.Open()
+       if err != nil {
+               return nil, err
+       }
+       err = bind.tx.Open()
+       if err != nil {
+               return nil, err
+       }
+       bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0)
+       if err != nil {
+               return nil, err
+       }
+       err = windows.Bind(bind.sock, sa)
+       if err != nil {
+               return nil, err
+       }
+       sa, err = windows.Getsockname(bind.sock)
+       if err != nil {
+               return nil, err
+       }
+       return sa, nil
+}
+
+func (bind *WinRingBind) Open(port uint16) (selectedPort uint16, err error) {
+       bind.mu.Lock()
+       defer bind.mu.Unlock()
+       defer func() {
+               if err != nil {
+                       bind.closeAndZero()
+               }
+       }()
+       if atomic.LoadUint32(&bind.isOpen) != 0 {
+               return 0, ErrBindAlreadyOpen
+       }
+       var sa windows.Sockaddr
+       sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)})
+       if err != nil {
+               return 0, err
+       }
+       sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port})
+       if err != nil {
+               return 0, err
+       }
+       selectedPort = uint16(sa.(*windows.SockaddrInet6).Port)
+       for i := 0; i < packetsPerRing; i++ {
+               err = bind.v4.InsertReceiveRequest()
+               if err != nil {
+                       return 0, err
+               }
+               err = bind.v6.InsertReceiveRequest()
+               if err != nil {
+                       return 0, err
+               }
+       }
+       atomic.StoreUint32(&bind.isOpen, 1)
+       return
+}
+
+func (bind *WinRingBind) Close() error {
+       bind.mu.RLock()
+       if atomic.LoadUint32(&bind.isOpen) != 1 {
+               bind.mu.RUnlock()
+               return nil
+       }
+       atomic.StoreUint32(&bind.isOpen, 2)
+       windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
+       windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
+       windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
+       windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil)
+       bind.mu.RUnlock()
+       bind.mu.Lock()
+       defer bind.mu.Unlock()
+       bind.closeAndZero()
+       return nil
+}
+
+func (bind *WinRingBind) SetMark(mark uint32) error {
+       return nil
+}
+
+func (bind *afWinRingBind) InsertReceiveRequest() error {
+       packet := bind.rx.Push()
+       dataBuffer := &winrio.Buffer{
+               Id:     bind.rx.id,
+               Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets),
+               Length: uint32(len(packet.data)),
+       }
+       addressBuffer := &winrio.Buffer{
+               Id:     bind.rx.id,
+               Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets),
+               Length: uint32(unsafe.Sizeof(packet.addr)),
+       }
+       bind.mu.Lock()
+       defer bind.mu.Unlock()
+       return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
+}
+
+//go:linkname procyield runtime.procyield
+func procyield(cycles uint32)
+
+func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, error) {
+       if atomic.LoadUint32(isOpen) != 1 {
+               return 0, nil, net.ErrClosed
+       }
+       bind.rx.mu.Lock()
+       defer bind.rx.mu.Unlock()
+       var count uint32
+       var results [1]winrio.Result
+       for tries := 0; count == 0 && tries < receiveSpins; tries++ {
+               if tries > 0 {
+                       if atomic.LoadUint32(isOpen) != 1 {
+                               return 0, nil, net.ErrClosed
+                       }
+                       procyield(1)
+               }
+               count = winrio.DequeueCompletion(bind.rx.cq, results[:])
+       }
+       if count == 0 {
+               err := winrio.Notify(bind.rx.cq)
+               if err != nil {
+                       return 0, nil, err
+               }
+               var bytes uint32
+               var key uintptr
+               var overlapped *windows.Overlapped
+               err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
+               if err != nil {
+                       return 0, nil, err
+               }
+               if atomic.LoadUint32(isOpen) != 1 {
+                       return 0, nil, net.ErrClosed
+               }
+               count = winrio.DequeueCompletion(bind.rx.cq, results[:])
+               if count == 0 {
+                       return 0, nil, io.ErrNoProgress
+
+               }
+       }
+       bind.rx.Return(1)
+       err := bind.InsertReceiveRequest()
+       if err != nil {
+               return 0, nil, err
+       }
+       if results[0].Status != 0 {
+               return 0, nil, windows.Errno(results[0].Status)
+       }
+       packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
+       ep := packet.addr
+       n := copy(buf, packet.data[:results[0].BytesTransferred])
+       return n, &ep, nil
+}
+
+func (bind *WinRingBind) ReceiveIPv4(buf []byte) (int, Endpoint, error) {
+       bind.mu.RLock()
+       defer bind.mu.RUnlock()
+       return bind.v4.Receive(buf, &bind.isOpen)
+}
+
+func (bind *WinRingBind) ReceiveIPv6(buf []byte) (int, Endpoint, error) {
+       bind.mu.RLock()
+       defer bind.mu.RUnlock()
+       return bind.v6.Receive(buf, &bind.isOpen)
+}
+
+func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint32) error {
+       if atomic.LoadUint32(isOpen) != 1 {
+               return net.ErrClosed
+       }
+       if len(buf) > bytesPerPacket {
+               return io.ErrShortBuffer
+       }
+       bind.tx.mu.Lock()
+       defer bind.tx.mu.Unlock()
+       var results [packetsPerRing]winrio.Result
+       count := winrio.DequeueCompletion(bind.tx.cq, results[:])
+       if count == 0 && bind.tx.isFull {
+               err := winrio.Notify(bind.tx.cq)
+               if err != nil {
+                       return err
+               }
+               var bytes uint32
+               var key uintptr
+               var overlapped *windows.Overlapped
+               err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
+               if err != nil {
+                       return err
+               }
+               if atomic.LoadUint32(isOpen) != 1 {
+                       return net.ErrClosed
+               }
+               count = winrio.DequeueCompletion(bind.tx.cq, results[:])
+               if count == 0 {
+                       return io.ErrNoProgress
+               }
+       }
+       if count > 0 {
+               bind.tx.Return(count)
+       }
+       packet := bind.tx.Push()
+       packet.addr = *nend
+       copy(packet.data[:], buf)
+       dataBuffer := &winrio.Buffer{
+               Id:     bind.tx.id,
+               Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets),
+               Length: uint32(len(buf)),
+       }
+       addressBuffer := &winrio.Buffer{
+               Id:     bind.tx.id,
+               Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets),
+               Length: uint32(unsafe.Sizeof(packet.addr)),
+       }
+       bind.mu.Lock()
+       defer bind.mu.Unlock()
+       return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
+}
+
+func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error {
+       nend, ok := endpoint.(*WinRingEndpoint)
+       if !ok {
+               return ErrWrongEndpointType
+       }
+       bind.mu.RLock()
+       defer bind.mu.RUnlock()
+       switch nend.family {
+       case windows.AF_INET:
+               if bind.v4.blackhole {
+                       return nil
+               }
+               return bind.v4.Send(buf, nend, &bind.isOpen)
+       case windows.AF_INET6:
+               if bind.v6.blackhole {
+                       return nil
+               }
+               return bind.v6.Send(buf, nend, &bind.isOpen)
+       }
+       return nil
+}
+
+func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
+       sysconn, err := bind.ipv4.SyscallConn()
+       if err != nil {
+               return err
+       }
+       err2 := sysconn.Control(func(fd uintptr) {
+               err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex)
+       })
+       if err2 != nil {
+               return err2
+       }
+       if err != nil {
+               return err
+       }
+       bind.blackhole4 = blackhole
+       return nil
+}
+
+func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
+       sysconn, err := bind.ipv6.SyscallConn()
+       if err != nil {
+               return err
+       }
+       err2 := sysconn.Control(func(fd uintptr) {
+               err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex)
+       })
+       if err2 != nil {
+               return err2
+       }
+       if err != nil {
+               return err
+       }
+       bind.blackhole6 = blackhole
+       return nil
+}
+func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
+       bind.mu.RLock()
+       defer bind.mu.RUnlock()
+       if atomic.LoadUint32(&bind.isOpen) != 1 {
+               return net.ErrClosed
+       }
+       err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
+       if err != nil {
+               return err
+       }
+       bind.v4.blackhole = blackhole
+       return nil
+}
+
+func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
+       bind.mu.RLock()
+       defer bind.mu.RUnlock()
+       if atomic.LoadUint32(&bind.isOpen) != 1 {
+               return net.ErrClosed
+       }
+       err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
+       if err != nil {
+               return err
+       }
+       bind.v6.blackhole = blackhole
+       return nil
+}
+
+func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error {
+       const IP_UNICAST_IF = 31
+       /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
+       var bytes [4]byte
+       binary.BigEndian.PutUint32(bytes[:], interfaceIndex)
+       interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
+       err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex))
+       if err != nil {
+               return err
+       }
+       return nil
+}
+
+func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error {
+       const IPV6_UNICAST_IF = 31
+       return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex))
+}
+
+// unsafeSlice updates the slice slicePtr to be a slice
+// referencing the provided data with its length & capacity set to
+// lenCap.
+//
+// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
+// update callers to use unsafe.Slice instead of this.
+func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
+       type sliceHeader struct {
+               Data unsafe.Pointer
+               Len  int
+               Cap  int
+       }
+       h := (*sliceHeader)(slicePtr)
+       h.Data = data
+       h.Len = lenCap
+       h.Cap = lenCap
+}
diff --git a/conn/boundif_windows.go b/conn/boundif_windows.go
deleted file mode 100644 (file)
index 6f6fdd8..0000000
+++ /dev/null
@@ -1,59 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package conn
-
-import (
-       "encoding/binary"
-       "unsafe"
-
-       "golang.org/x/sys/windows"
-)
-
-const (
-       sockoptIP_UNICAST_IF   = 31
-       sockoptIPV6_UNICAST_IF = 31
-)
-
-func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
-       /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
-       bytes := make([]byte, 4)
-       binary.BigEndian.PutUint32(bytes, interfaceIndex)
-       interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
-
-       sysconn, err := bind.ipv4.SyscallConn()
-       if err != nil {
-               return err
-       }
-       err2 := sysconn.Control(func(fd uintptr) {
-               err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, sockoptIP_UNICAST_IF, int(interfaceIndex))
-       })
-       if err2 != nil {
-               return err2
-       }
-       if err != nil {
-               return err
-       }
-       bind.blackhole4 = blackhole
-       return nil
-}
-
-func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
-       sysconn, err := bind.ipv6.SyscallConn()
-       if err != nil {
-               return err
-       }
-       err2 := sysconn.Control(func(fd uintptr) {
-               err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, sockoptIPV6_UNICAST_IF, int(interfaceIndex))
-       })
-       if err2 != nil {
-               return err2
-       }
-       if err != nil {
-               return err
-       }
-       bind.blackhole6 = blackhole
-       return nil
-}
index cd9bfb04824e53d353298e2dad662427982a2c08..161454a0e6c2b8125588447f3b122c75e5c2319e 100644 (file)
@@ -1,4 +1,4 @@
-// +build !linux
+// +build !linux,!windows
 
 /* SPDX-License-Identifier: MIT
  *
diff --git a/conn/winrio/rio_windows.go b/conn/winrio/rio_windows.go
new file mode 100644 (file)
index 0000000..1785a02
--- /dev/null
@@ -0,0 +1,243 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package winrio
+
+import (
+       "log"
+       "sync"
+       "syscall"
+       "unsafe"
+
+       "golang.org/x/sys/windows"
+)
+
+const (
+       MsgDontNotify = 1
+       MsgDefer      = 2
+       MsgWaitAll    = 4
+       MsgCommitOnly = 8
+
+       MaxCqSize = 0x8000000
+
+       invalidBufferId = 0xFFFFFFFF
+       invalidCq       = 0
+       invalidRq       = 0
+       corruptCq       = 0xFFFFFFFF
+)
+
+var extensionFunctionTable struct {
+       cbSize                   uint32
+       rioReceive               uintptr
+       rioReceiveEx             uintptr
+       rioSend                  uintptr
+       rioSendEx                uintptr
+       rioCloseCompletionQueue  uintptr
+       rioCreateCompletionQueue uintptr
+       rioCreateRequestQueue    uintptr
+       rioDequeueCompletion     uintptr
+       rioDeregisterBuffer      uintptr
+       rioNotify                uintptr
+       rioRegisterBuffer        uintptr
+       rioResizeCompletionQueue uintptr
+       rioResizeRequestQueue    uintptr
+}
+
+type Cq uintptr
+
+type Rq uintptr
+
+type BufferId uintptr
+
+type Buffer struct {
+       Id     BufferId
+       Offset uint32
+       Length uint32
+}
+
+type Result struct {
+       Status           int32
+       BytesTransferred uint32
+       SocketContext    uint64
+       RequestContext   uint64
+}
+
+type notificationCompletionType uint32
+
+const (
+       eventCompletion notificationCompletionType = 1
+       iocpCompletion  notificationCompletionType = 2
+)
+
+type eventNotificationCompletion struct {
+       completionType notificationCompletionType
+       event          windows.Handle
+       notifyReset    uint32
+}
+
+type iocpNotificationCompletion struct {
+       completionType notificationCompletionType
+       iocp           windows.Handle
+       key            uintptr
+       overlapped     *windows.Overlapped
+}
+
+var initialized sync.Once
+var available bool
+
+func Initialize() bool {
+       initialized.Do(func() {
+               var (
+                       err    error
+                       socket windows.Handle
+                       cq     Cq
+               )
+               defer func() {
+                       if err == nil {
+                               return
+                       }
+                       if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 7 {
+                               return
+                       }
+                       log.Printf("Registered I/O is unavailable: %v", err)
+               }()
+               socket, err = Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
+               if err != nil {
+                       return
+               }
+               defer windows.CloseHandle(socket)
+               var WSAID_MULTIPLE_RIO = &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}}
+               const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024
+               ob := uint32(0)
+               err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER,
+                       (*byte)(unsafe.Pointer(WSAID_MULTIPLE_RIO)), uint32(unsafe.Sizeof(*WSAID_MULTIPLE_RIO)),
+                       (*byte)(unsafe.Pointer(&extensionFunctionTable)), uint32(unsafe.Sizeof(extensionFunctionTable)),
+                       &ob, nil, 0)
+               if err != nil {
+                       return
+               }
+               // While we should be able to stop here, after getting the function pointers, some anti-virus actually causes
+               // failures in RIOCreateRequestQueue, so keep going to be certain this is supported.
+               cq, err = CreatePolledCompletionQueue(2)
+               if err != nil {
+                       return
+               }
+               defer CloseCompletionQueue(cq)
+               _, err = CreateRequestQueue(socket, 1, 1, 1, 1, cq, cq, 0)
+               if err != nil {
+                       return
+               }
+               available = true
+       })
+       return available
+}
+
+func Socket(af, typ, proto int32) (windows.Handle, error) {
+       return windows.WSASocket(af, typ, proto, nil, 0, windows.WSA_FLAG_REGISTERED_IO)
+}
+
+func CloseCompletionQueue(cq Cq) {
+       _, _, _ = syscall.Syscall(extensionFunctionTable.rioCloseCompletionQueue, 1, uintptr(cq), 0, 0)
+}
+
+func CreateEventCompletionQueue(queueSize uint32, event windows.Handle, notifyReset bool) (Cq, error) {
+       notificationCompletion := &eventNotificationCompletion{
+               completionType: eventCompletion,
+               event:          event,
+       }
+       if notifyReset {
+               notificationCompletion.notifyReset = 1
+       }
+       ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
+       if ret == invalidCq {
+               return 0, err
+       }
+       return Cq(ret), nil
+}
+
+func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintptr, overlapped *windows.Overlapped) (Cq, error) {
+       notificationCompletion := &iocpNotificationCompletion{
+               completionType: iocpCompletion,
+               iocp:           iocp,
+               overlapped:     overlapped,
+       }
+       ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
+       if ret == invalidCq {
+               return 0, err
+       }
+       return Cq(ret), nil
+}
+
+func CreatePolledCompletionQueue(queueSize uint32) (Cq, error) {
+       ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), 0, 0)
+       if ret == invalidCq {
+               return 0, err
+       }
+       return Cq(ret), nil
+}
+
+func CreateRequestQueue(socket windows.Handle, maxOutstandingReceive, maxReceiveDataBuffers, maxOutstandingSend, maxSendDataBuffers uint32, receiveCq, sendCq Cq, socketContext uintptr) (Rq, error) {
+       ret, _, err := syscall.Syscall9(extensionFunctionTable.rioCreateRequestQueue, 8, uintptr(socket), uintptr(maxOutstandingReceive), uintptr(maxReceiveDataBuffers), uintptr(maxOutstandingSend), uintptr(maxSendDataBuffers), uintptr(receiveCq), uintptr(sendCq), socketContext, 0)
+       if ret == invalidRq {
+               return 0, err
+       }
+       return Rq(ret), nil
+}
+
+func DequeueCompletion(cq Cq, results []Result) uint32 {
+       var array uintptr
+       if len(results) > 0 {
+               array = uintptr(unsafe.Pointer(&results[0]))
+       }
+       ret, _, _ := syscall.Syscall(extensionFunctionTable.rioDequeueCompletion, 3, uintptr(cq), array, uintptr(len(results)))
+       if ret == corruptCq {
+               panic("cq is corrupt")
+       }
+       return uint32(ret)
+}
+
+func DeregisterBuffer(id BufferId) {
+       _, _, _ = syscall.Syscall(extensionFunctionTable.rioDeregisterBuffer, 1, uintptr(id), 0, 0)
+}
+
+func RegisterBuffer(buffer []byte) (BufferId, error) {
+       var buf unsafe.Pointer
+       if len(buffer) > 0 {
+               buf = unsafe.Pointer(&buffer[0])
+       }
+       return RegisterPointer(buf, uint32(len(buffer)))
+}
+
+func RegisterPointer(ptr unsafe.Pointer, size uint32) (BufferId, error) {
+       ret, _, err := syscall.Syscall(extensionFunctionTable.rioRegisterBuffer, 2, uintptr(ptr), uintptr(size), 0)
+       if ret == invalidBufferId {
+               return 0, err
+       }
+       return BufferId(ret), nil
+}
+
+func SendEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
+       ret, _, err := syscall.Syscall9(extensionFunctionTable.rioSendEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
+       if ret == 0 {
+               return err
+       }
+       return nil
+}
+
+func ReceiveEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
+       ret, _, err := syscall.Syscall9(extensionFunctionTable.rioReceiveEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
+       if ret == 0 {
+               return err
+       }
+       return nil
+}
+
+func Notify(cq Cq) error {
+       ret, _, _ := syscall.Syscall(extensionFunctionTable.rioNotify, 1, uintptr(cq), 0, 0)
+       if ret != 0 {
+               return windows.Errno(ret)
+       }
+       return nil
+}
index 773c2cada1700cdd81f15e9e2a3e10ebe66bb7ed..d5c6927bcad3e4fd8297ac57a5c2c403c1675d0a 100644 (file)
@@ -1,4 +1,4 @@
-// +build !android,!ios
+// +build !android,!ios,!windows
 
 /* SPDX-License-Identifier: MIT
  *
diff --git a/device/queueconstants_windows.go b/device/queueconstants_windows.go
new file mode 100644 (file)
index 0000000..e330a6b
--- /dev/null
@@ -0,0 +1,15 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+const (
+       QueueStagedSize            = 128
+       QueueOutboundSize          = 1024
+       QueueInboundSize           = 1024
+       QueueHandshakeSize         = 1024
+       MaxSegmentSize             = 2048 - 32 // largest possible UDP datagram
+       PreallocatedBuffersPerPool = 0         // Disable and allow for infinite memory growth
+)
diff --git a/go.mod b/go.mod
index 11b6c7f23db19d150a79dc49819ae17e60cc71d5..0aa27d850936cb004bda913f12a3847c9b672a6f 100644 (file)
--- a/go.mod
+++ b/go.mod
@@ -3,7 +3,7 @@ module golang.zx2c4.com/wireguard
 go 1.16
 
 require (
-       golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad
-       golang.org/x/net v0.0.0-20201224014010-6772e930b67b
-       golang.org/x/sys v0.0.0-20210105210732-16f7687f5001
+       golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83
+       golang.org/x/net v0.0.0-20210224082022-3d97a244fca7
+       golang.org/x/sys v0.0.0-20210225014209-683adc9d29d7
 )
diff --git a/go.sum b/go.sum
index 62a8501763ae49e21c6b92be4655071e9309a880..1ccf774cac244ac0c3c39a552603705d8a968f96 100644 (file)
--- a/go.sum
+++ b/go.sum
@@ -1,17 +1,16 @@
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
-golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad h1:DN0cp81fZ3njFcrLCytUHRSUkqBjfTo4Tx9RJTWs0EY=
-golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
+golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 h1:/ZScEX8SfEmUGRHs0gxpqteO5nfNW6axyZbBdw9A12g=
+golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
 golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
-golang.org/x/net v0.0.0-20201224014010-6772e930b67b h1:iFwSg7t5GZmB/Q5TjiEAsdoLDrdJRC1RiF2WhuV29Qw=
-golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
+golang.org/x/net v0.0.0-20210224082022-3d97a244fca7 h1:OgUuv8lsRpBibGNbSizVwKWlysjaNzmC9gYMhPVfqFM=
+golang.org/x/net v0.0.0-20210224082022-3d97a244fca7/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210105210732-16f7687f5001 h1:/dSxr6gT0FNI1MO5WLJo8mTmItROeOKTkDn+7OwWBos=
-golang.org/x/sys v0.0.0-20210105210732-16f7687f5001/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210225014209-683adc9d29d7 h1:pk3Y+QnSKjMLfO/HIqzn/Zvv3/IHjRPhwblrmUuodzw=
+golang.org/x/sys v0.0.0-20210225014209-683adc9d29d7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e h1:FDhOuMEY4JVRztM/gsbk+IKUQ8kj74bxZrgw87eMMVc=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=