]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
namedpipe: rename from winpipe to keep in sync with CL299009
authorJason A. Donenfeld <Jason@zx2c4.com>
Sat, 30 Oct 2021 00:39:56 +0000 (02:39 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Thu, 4 Nov 2021 11:53:52 +0000 (12:53 +0100)
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
ipc/namedpipe/file.go [moved from ipc/winpipe/file.go with 96% similarity]
ipc/namedpipe/namedpipe.go [moved from ipc/winpipe/winpipe.go with 83% similarity]
ipc/namedpipe/namedpipe_test.go [moved from ipc/winpipe/winpipe_test.go with 81% similarity]
ipc/uapi_windows.go
tun/wintun/dll_windows.go [deleted file]
tun/wintun/session_windows.go [deleted file]
tun/wintun/wintun_windows.go [deleted file]

similarity index 96%
rename from ipc/winpipe/file.go
rename to ipc/namedpipe/file.go
index 319565f826512965483ba78d9aa3f7caa226d61d..9c2481d96864192a5d54a7c6d3f3a10f5c45d587 100644 (file)
@@ -1,12 +1,12 @@
-//go:build windows
+// Copyright 2021 The Go Authors. All rights reserved.
+// Copyright 2015 Microsoft
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
 
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2005 Microsoft
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
+//go:build windows
+// +build windows
 
-package winpipe
+package namedpipe
 
 import (
        "io"
similarity index 83%
rename from ipc/winpipe/winpipe.go
rename to ipc/namedpipe/namedpipe.go
index e3719d69e3880959da1d05575e136d66d50f6d4c..6db5ea31e03315dcc503a0bb70882fb2c660c306 100644 (file)
@@ -1,13 +1,13 @@
-//go:build windows
+// Copyright 2021 The Go Authors. All rights reserved.
+// Copyright 2015 Microsoft
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
 
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2005 Microsoft
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
+//go:build windows
+// +build windows
 
-// Package winpipe implements a net.Conn and net.Listener around Windows named pipes.
-package winpipe
+// Package namedpipe implements a net.Conn and net.Listener around Windows named pipes.
+package namedpipe
 
 import (
        "context"
@@ -15,6 +15,7 @@ import (
        "net"
        "os"
        "runtime"
+       "sync/atomic"
        "time"
        "unsafe"
 
@@ -28,7 +29,7 @@ type pipe struct {
 
 type messageBytePipe struct {
        pipe
-       writeClosed bool
+       writeClosed int32
        readEOF     bool
 }
 
@@ -50,25 +51,26 @@ func (f *pipe) SetDeadline(t time.Time) error {
 
 // CloseWrite closes the write side of a message pipe in byte mode.
 func (f *messageBytePipe) CloseWrite() error {
-       if f.writeClosed {
+       if !atomic.CompareAndSwapInt32(&f.writeClosed, 0, 1) {
                return io.ErrClosedPipe
        }
        err := f.file.Flush()
        if err != nil {
+               atomic.StoreInt32(&f.writeClosed, 0)
                return err
        }
        _, err = f.file.Write(nil)
        if err != nil {
+               atomic.StoreInt32(&f.writeClosed, 0)
                return err
        }
-       f.writeClosed = true
        return nil
 }
 
 // Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
 // they are used to implement CloseWrite.
 func (f *messageBytePipe) Write(b []byte) (int, error) {
-       if f.writeClosed {
+       if atomic.LoadInt32(&f.writeClosed) != 0 {
                return 0, io.ErrClosedPipe
        }
        if len(b) == 0 {
@@ -142,30 +144,24 @@ type DialConfig struct {
        ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID.
 }
 
-// Dial connects to the specified named pipe by path, timing out if the connection
-// takes longer than the specified duration. If timeout is nil, then we use
-// a default timeout of 2 seconds.
-func Dial(path string, timeout *time.Duration, config *DialConfig) (net.Conn, error) {
-       var absTimeout time.Time
-       if timeout != nil {
-               absTimeout = time.Now().Add(*timeout)
-       } else {
-               absTimeout = time.Now().Add(2 * time.Second)
+// DialTimeout connects to the specified named pipe by path, timing out if the
+// connection  takes longer than the specified duration. If timeout is zero, then
+// we use a default timeout of 2 seconds.
+func (config *DialConfig) DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
+       if timeout == 0 {
+               timeout = time.Second * 2
        }
+       absTimeout := time.Now().Add(timeout)
        ctx, _ := context.WithDeadline(context.Background(), absTimeout)
-       conn, err := DialContext(ctx, path, config)
+       conn, err := config.DialContext(ctx, path)
        if err == context.DeadlineExceeded {
                return nil, os.ErrDeadlineExceeded
        }
        return conn, err
 }
 
-// DialContext attempts to connect to the specified named pipe by path
-// cancellation or timeout.
-func DialContext(ctx context.Context, path string, config *DialConfig) (net.Conn, error) {
-       if config == nil {
-               config = &DialConfig{}
-       }
+// DialContext attempts to connect to the specified named pipe by path.
+func (config *DialConfig) DialContext(ctx context.Context, path string) (net.Conn, error) {
        var err error
        var h windows.Handle
        h, err = tryDialPipe(ctx, &path)
@@ -213,6 +209,18 @@ func DialContext(ctx context.Context, path string, config *DialConfig) (net.Conn
        return &pipe{file: f, path: path}, nil
 }
 
+var defaultDialer DialConfig
+
+// DialTimeout calls DialConfig.DialTimeout using an empty configuration.
+func DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
+       return defaultDialer.DialTimeout(path, timeout)
+}
+
+// DialContext calls DialConfig.DialContext using an empty configuration.
+func DialContext(ctx context.Context, path string) (net.Conn, error) {
+       return defaultDialer.DialContext(ctx, path)
+}
+
 type acceptResponse struct {
        f   *file
        err error
@@ -222,12 +230,12 @@ type pipeListener struct {
        firstHandle windows.Handle
        path        string
        config      ListenConfig
-       acceptCh    chan (chan acceptResponse)
+       acceptCh    chan chan acceptResponse
        closeCh     chan int
        doneCh      chan int
 }
 
-func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, first bool) (windows.Handle, error) {
+func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, isFirstPipe bool) (windows.Handle, error) {
        path16, err := windows.UTF16PtrFromString(path)
        if err != nil {
                return 0, &os.PathError{Op: "open", Path: path, Err: err}
@@ -247,7 +255,7 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste
        oa.ObjectName = &ntPath
 
        // The security descriptor is only needed for the first pipe.
-       if first {
+       if isFirstPipe {
                if sd != nil {
                        oa.SecurityDescriptor = sd
                } else {
@@ -257,7 +265,7 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste
                                return 0, err
                        }
                        defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl)))
-                       sd, err := windows.NewSecurityDescriptor()
+                       sd, err = windows.NewSecurityDescriptor()
                        if err != nil {
                                return 0, err
                        }
@@ -275,11 +283,11 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste
 
        disposition := uint32(windows.FILE_OPEN)
        access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE)
-       if first {
+       if isFirstPipe {
                disposition = windows.FILE_CREATE
                // By not asking for read or write access, the named pipe file system
                // will put this pipe into an initially disconnected state, blocking
-               // client connections until the next call with first == false.
+               // client connections until the next call with isFirstPipe == false.
                access = windows.SYNCHRONIZE
        }
 
@@ -395,10 +403,7 @@ type ListenConfig struct {
 
 // Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe.
 // The pipe must not already exist.
-func Listen(path string, c *ListenConfig) (net.Listener, error) {
-       if c == nil {
-               c = &ListenConfig{}
-       }
+func (c *ListenConfig) Listen(path string) (net.Listener, error) {
        h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true)
        if err != nil {
                return nil, err
@@ -407,12 +412,12 @@ func Listen(path string, c *ListenConfig) (net.Listener, error) {
                firstHandle: h,
                path:        path,
                config:      *c,
-               acceptCh:    make(chan (chan acceptResponse)),
+               acceptCh:    make(chan chan acceptResponse),
                closeCh:     make(chan int),
                doneCh:      make(chan int),
        }
        // The first connection is swallowed on Windows 7 & 8, so synthesize it.
-       if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 {
+       if maj, min, _ := windows.RtlGetNtVersionNumbers(); maj < 6 || (maj == 6 && min < 4) {
                path16, err := windows.UTF16PtrFromString(path)
                if err == nil {
                        h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0)
@@ -425,6 +430,13 @@ func Listen(path string, c *ListenConfig) (net.Listener, error) {
        return l, nil
 }
 
+var defaultListener ListenConfig
+
+// Listen calls ListenConfig.Listen using an empty configuration.
+func Listen(path string) (net.Listener, error) {
+       return defaultListener.Listen(path)
+}
+
 func connectPipe(p *file) error {
        c, err := p.prepareIo()
        if err != nil {
similarity index 81%
rename from ipc/winpipe/winpipe_test.go
rename to ipc/namedpipe/namedpipe_test.go
index ea515e3e071e06b30b33b678d0a5706c998a2086..0573d0fbb1c2a492a8dd364fd0f2f014ed045422 100644 (file)
@@ -1,12 +1,12 @@
-//go:build windows
+// Copyright 2021 The Go Authors. All rights reserved.
+// Copyright 2015 Microsoft
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
 
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2005 Microsoft
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
+//go:build windows
+// +build windows
 
-package winpipe_test
+package namedpipe_test
 
 import (
        "bufio"
@@ -22,7 +22,7 @@ import (
        "time"
 
        "golang.org/x/sys/windows"
-       "golang.zx2c4.com/wireguard/ipc/winpipe"
+       "golang.zx2c4.com/wireguard/ipc/namedpipe"
 )
 
 func randomPipePath() string {
@@ -30,7 +30,7 @@ func randomPipePath() string {
        if err != nil {
                panic(err)
        }
-       return `\\.\PIPE\go-winpipe-test-` + guid.String()
+       return `\\.\PIPE\go-namedpipe-test-` + guid.String()
 }
 
 func TestPingPong(t *testing.T) {
@@ -39,7 +39,7 @@ func TestPingPong(t *testing.T) {
                pong = 24
        )
        pipePath := randomPipePath()
-       listener, err := winpipe.Listen(pipePath, nil)
+       listener, err := namedpipe.Listen(pipePath)
        if err != nil {
                t.Fatalf("unable to listen on pipe: %v", err)
        }
@@ -64,11 +64,12 @@ func TestPingPong(t *testing.T) {
                        t.Fatalf("unable to write pong to pipe: %v", err)
                }
        }()
-       client, err := winpipe.Dial(pipePath, nil, nil)
+       client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
        if err != nil {
                t.Fatalf("unable to dial pipe: %v", err)
        }
        defer client.Close()
+       client.SetDeadline(time.Now().Add(time.Second * 5))
        var data [1]byte
        data[0] = ping
        _, err = client.Write(data[:])
@@ -85,7 +86,7 @@ func TestPingPong(t *testing.T) {
 }
 
 func TestDialUnknownFailsImmediately(t *testing.T) {
-       _, err := winpipe.Dial(randomPipePath(), nil, nil)
+       _, err := namedpipe.DialTimeout(randomPipePath(), time.Duration(0))
        if !errors.Is(err, syscall.ENOENT) {
                t.Fatalf("expected ENOENT got %v", err)
        }
@@ -93,13 +94,15 @@ func TestDialUnknownFailsImmediately(t *testing.T) {
 
 func TestDialListenerTimesOut(t *testing.T) {
        pipePath := randomPipePath()
-       l, err := winpipe.Listen(pipePath, nil)
+       l, err := namedpipe.Listen(pipePath)
        if err != nil {
                t.Fatal(err)
        }
        defer l.Close()
-       d := 10 * time.Millisecond
-       _, err = winpipe.Dial(pipePath, &d, nil)
+       pipe, err := namedpipe.DialTimeout(pipePath, 10*time.Millisecond)
+       if err == nil {
+               pipe.Close()
+       }
        if err != os.ErrDeadlineExceeded {
                t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
        }
@@ -107,14 +110,17 @@ func TestDialListenerTimesOut(t *testing.T) {
 
 func TestDialContextListenerTimesOut(t *testing.T) {
        pipePath := randomPipePath()
-       l, err := winpipe.Listen(pipePath, nil)
+       l, err := namedpipe.Listen(pipePath)
        if err != nil {
                t.Fatal(err)
        }
        defer l.Close()
        d := 10 * time.Millisecond
        ctx, _ := context.WithTimeout(context.Background(), d)
-       _, err = winpipe.DialContext(ctx, pipePath, nil)
+       pipe, err := namedpipe.DialContext(ctx, pipePath)
+       if err == nil {
+               pipe.Close()
+       }
        if err != context.DeadlineExceeded {
                t.Fatalf("expected context.DeadlineExceeded, got %v", err)
        }
@@ -123,14 +129,14 @@ func TestDialContextListenerTimesOut(t *testing.T) {
 func TestDialListenerGetsCancelled(t *testing.T) {
        pipePath := randomPipePath()
        ctx, cancel := context.WithCancel(context.Background())
-       l, err := winpipe.Listen(pipePath, nil)
+       l, err := namedpipe.Listen(pipePath)
        if err != nil {
                t.Fatal(err)
        }
-       ch := make(chan error)
        defer l.Close()
+       ch := make(chan error)
        go func(ctx context.Context, ch chan error) {
-               _, err := winpipe.DialContext(ctx, pipePath, nil)
+               _, err := namedpipe.DialContext(ctx, pipePath)
                ch <- err
        }(ctx, ch)
        time.Sleep(time.Millisecond * 30)
@@ -147,23 +153,28 @@ func TestDialAccessDeniedWithRestrictedSD(t *testing.T) {
        }
        pipePath := randomPipePath()
        sd, _ := windows.SecurityDescriptorFromString("D:")
-       c := winpipe.ListenConfig{
+       l, err := (&namedpipe.ListenConfig{
                SecurityDescriptor: sd,
-       }
-       l, err := winpipe.Listen(pipePath, &c)
+       }).Listen(pipePath)
        if err != nil {
                t.Fatal(err)
        }
        defer l.Close()
-       _, err = winpipe.Dial(pipePath, nil, nil)
+       pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
+       if err == nil {
+               pipe.Close()
+       }
        if !errors.Is(err, windows.ERROR_ACCESS_DENIED) {
                t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err)
        }
 }
 
-func getConnection(cfg *winpipe.ListenConfig) (client net.Conn, server net.Conn, err error) {
+func getConnection(cfg *namedpipe.ListenConfig) (client net.Conn, server net.Conn, err error) {
        pipePath := randomPipePath()
-       l, err := winpipe.Listen(pipePath, cfg)
+       if cfg == nil {
+               cfg = &namedpipe.ListenConfig{}
+       }
+       l, err := cfg.Listen(pipePath)
        if err != nil {
                return
        }
@@ -179,7 +190,7 @@ func getConnection(cfg *winpipe.ListenConfig) (client net.Conn, server net.Conn,
                ch <- response{c, err}
        }()
 
-       c, err := winpipe.Dial(pipePath, nil, nil)
+       c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
        if err != nil {
                return
        }
@@ -236,7 +247,7 @@ func server(l net.Listener, ch chan int) {
 
 func TestFullListenDialReadWrite(t *testing.T) {
        pipePath := randomPipePath()
-       l, err := winpipe.Listen(pipePath, nil)
+       l, err := namedpipe.Listen(pipePath)
        if err != nil {
                t.Fatal(err)
        }
@@ -245,7 +256,7 @@ func TestFullListenDialReadWrite(t *testing.T) {
        ch := make(chan int)
        go server(l, ch)
 
-       c, err := winpipe.Dial(pipePath, nil, nil)
+       c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
        if err != nil {
                t.Fatal(err)
        }
@@ -275,7 +286,7 @@ func TestFullListenDialReadWrite(t *testing.T) {
 
 func TestCloseAbortsListen(t *testing.T) {
        pipePath := randomPipePath()
-       l, err := winpipe.Listen(pipePath, nil)
+       l, err := namedpipe.Listen(pipePath)
        if err != nil {
                t.Fatal(err)
        }
@@ -328,7 +339,7 @@ func TestCloseServerEOFClient(t *testing.T) {
 }
 
 func TestCloseWriteEOF(t *testing.T) {
-       cfg := &winpipe.ListenConfig{
+       cfg := &namedpipe.ListenConfig{
                MessageMode: true,
        }
        c, s, err := getConnection(cfg)
@@ -356,7 +367,7 @@ func TestCloseWriteEOF(t *testing.T) {
 
 func TestAcceptAfterCloseFails(t *testing.T) {
        pipePath := randomPipePath()
-       l, err := winpipe.Listen(pipePath, nil)
+       l, err := namedpipe.Listen(pipePath)
        if err != nil {
                t.Fatal(err)
        }
@@ -369,12 +380,15 @@ func TestAcceptAfterCloseFails(t *testing.T) {
 
 func TestDialTimesOutByDefault(t *testing.T) {
        pipePath := randomPipePath()
-       l, err := winpipe.Listen(pipePath, nil)
+       l, err := namedpipe.Listen(pipePath)
        if err != nil {
                t.Fatal(err)
        }
        defer l.Close()
-       _, err = winpipe.Dial(pipePath, nil, nil)
+       pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds.
+       if err == nil {
+               pipe.Close()
+       }
        if err != os.ErrDeadlineExceeded {
                t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
        }
@@ -382,7 +396,7 @@ func TestDialTimesOutByDefault(t *testing.T) {
 
 func TestTimeoutPendingRead(t *testing.T) {
        pipePath := randomPipePath()
-       l, err := winpipe.Listen(pipePath, nil)
+       l, err := namedpipe.Listen(pipePath)
        if err != nil {
                t.Fatal(err)
        }
@@ -400,7 +414,7 @@ func TestTimeoutPendingRead(t *testing.T) {
                close(serverDone)
        }()
 
-       client, err := winpipe.Dial(pipePath, nil, nil)
+       client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
        if err != nil {
                t.Fatal(err)
        }
@@ -430,7 +444,7 @@ func TestTimeoutPendingRead(t *testing.T) {
 
 func TestTimeoutPendingWrite(t *testing.T) {
        pipePath := randomPipePath()
-       l, err := winpipe.Listen(pipePath, nil)
+       l, err := namedpipe.Listen(pipePath)
        if err != nil {
                t.Fatal(err)
        }
@@ -448,7 +462,7 @@ func TestTimeoutPendingWrite(t *testing.T) {
                close(serverDone)
        }()
 
-       client, err := winpipe.Dial(pipePath, nil, nil)
+       client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
        if err != nil {
                t.Fatal(err)
        }
@@ -480,13 +494,12 @@ type CloseWriter interface {
 }
 
 func TestEchoWithMessaging(t *testing.T) {
-       c := winpipe.ListenConfig{
+       pipePath := randomPipePath()
+       l, err := (&namedpipe.ListenConfig{
                MessageMode:      true,  // Use message mode so that CloseWrite() is supported
                InputBufferSize:  65536, // Use 64KB buffers to improve performance
                OutputBufferSize: 65536,
-       }
-       pipePath := randomPipePath()
-       l, err := winpipe.Listen(pipePath, &c)
+       }).Listen(pipePath)
        if err != nil {
                t.Fatal(err)
        }
@@ -496,19 +509,21 @@ func TestEchoWithMessaging(t *testing.T) {
        clientDone := make(chan bool)
        go func() {
                // server echo
-               conn, e := l.Accept()
-               if e != nil {
-                       t.Fatal(e)
+               conn, err := l.Accept()
+               if err != nil {
+                       t.Fatal(err)
                }
                defer conn.Close()
 
                time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent
-               io.Copy(conn, conn)
+               _, err = io.Copy(conn, conn)
+               if err != nil {
+                       t.Fatal(err)
+               }
                conn.(CloseWriter).CloseWrite()
                close(listenerDone)
        }()
-       timeout := 1 * time.Second
-       client, err := winpipe.Dial(pipePath, &timeout, nil)
+       client, err := namedpipe.DialTimeout(pipePath, time.Second)
        if err != nil {
                t.Fatal(err)
        }
@@ -521,7 +536,7 @@ func TestEchoWithMessaging(t *testing.T) {
                if e != nil {
                        t.Fatal(e)
                }
-               if n != 2 {
+               if n != 2 || bytes[0] != 0 || bytes[1] != 1 {
                        t.Fatalf("expected 2 bytes, got %v", n)
                }
                close(clientDone)
@@ -545,7 +560,7 @@ func TestEchoWithMessaging(t *testing.T) {
 
 func TestConnectRace(t *testing.T) {
        pipePath := randomPipePath()
-       l, err := winpipe.Listen(pipePath, nil)
+       l, err := namedpipe.Listen(pipePath)
        if err != nil {
                t.Fatal(err)
        }
@@ -565,7 +580,7 @@ func TestConnectRace(t *testing.T) {
        }()
 
        for i := 0; i < 1000; i++ {
-               c, err := winpipe.Dial(pipePath, nil, nil)
+               c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
                if err != nil {
                        t.Fatal(err)
                }
@@ -580,7 +595,7 @@ func TestMessageReadMode(t *testing.T) {
        var wg sync.WaitGroup
        defer wg.Wait()
        pipePath := randomPipePath()
-       l, err := winpipe.Listen(pipePath, &winpipe.ListenConfig{MessageMode: true})
+       l, err := (&namedpipe.ListenConfig{MessageMode: true}).Listen(pipePath)
        if err != nil {
                t.Fatal(err)
        }
@@ -602,7 +617,7 @@ func TestMessageReadMode(t *testing.T) {
                s.Close()
        }()
 
-       c, err := winpipe.Dial(pipePath, nil, nil)
+       c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
        if err != nil {
                t.Fatal(err)
        }
@@ -643,13 +658,13 @@ func TestListenConnectRace(t *testing.T) {
                var wg sync.WaitGroup
                wg.Add(1)
                go func() {
-                       c, err := winpipe.Dial(pipePath, nil, nil)
+                       c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
                        if err == nil {
                                c.Close()
                        }
                        wg.Done()
                }()
-               s, err := winpipe.Listen(pipePath, nil)
+               s, err := namedpipe.Listen(pipePath)
                if err != nil {
                        t.Error(i, err)
                } else {
index a4d68da81b8ce7aa2eb1ac85af1e32079ebc8ab2..a1bfbd1bd69e8111ff9f7f74e5c31a59e003ae4f 100644 (file)
@@ -9,8 +9,7 @@ import (
        "net"
 
        "golang.org/x/sys/windows"
-
-       "golang.zx2c4.com/wireguard/ipc/winpipe"
+       "golang.zx2c4.com/wireguard/ipc/namedpipe"
 )
 
 // TODO: replace these with actual standard windows error numbers from the win package
@@ -61,10 +60,9 @@ func init() {
 }
 
 func UAPIListen(name string) (net.Listener, error) {
-       config := winpipe.ListenConfig{
+       listener, err := (&namedpipe.ListenConfig{
                SecurityDescriptor: UAPISecurityDescriptor,
-       }
-       listener, err := winpipe.Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config)
+       }).Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\` + name)
        if err != nil {
                return nil, err
        }
diff --git a/tun/wintun/dll_windows.go b/tun/wintun/dll_windows.go
deleted file mode 100644 (file)
index 3832c1e..0000000
+++ /dev/null
@@ -1,128 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package wintun
-
-import (
-       "fmt"
-       "sync"
-       "sync/atomic"
-       "unsafe"
-
-       "golang.org/x/sys/windows"
-)
-
-func newLazyDLL(name string, onLoad func(d *lazyDLL)) *lazyDLL {
-       return &lazyDLL{Name: name, onLoad: onLoad}
-}
-
-func (d *lazyDLL) NewProc(name string) *lazyProc {
-       return &lazyProc{dll: d, Name: name}
-}
-
-type lazyProc struct {
-       Name string
-       mu   sync.Mutex
-       dll  *lazyDLL
-       addr uintptr
-}
-
-func (p *lazyProc) Find() error {
-       if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr))) != nil {
-               return nil
-       }
-       p.mu.Lock()
-       defer p.mu.Unlock()
-       if p.addr != 0 {
-               return nil
-       }
-
-       err := p.dll.Load()
-       if err != nil {
-               return fmt.Errorf("Error loading %v DLL: %w", p.dll.Name, err)
-       }
-       addr, err := p.nameToAddr()
-       if err != nil {
-               return fmt.Errorf("Error getting %v address: %w", p.Name, err)
-       }
-
-       atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr)), unsafe.Pointer(addr))
-       return nil
-}
-
-func (p *lazyProc) Addr() uintptr {
-       err := p.Find()
-       if err != nil {
-               panic(err)
-       }
-       return p.addr
-}
-
-type lazyDLL struct {
-       Name   string
-       mu     sync.Mutex
-       module windows.Handle
-       onLoad func(d *lazyDLL)
-}
-
-func (d *lazyDLL) Load() error {
-       if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
-               return nil
-       }
-       d.mu.Lock()
-       defer d.mu.Unlock()
-       if d.module != 0 {
-               return nil
-       }
-
-       const (
-               LOAD_LIBRARY_SEARCH_APPLICATION_DIR = 0x00000200
-               LOAD_LIBRARY_SEARCH_SYSTEM32        = 0x00000800
-       )
-       module, err := windows.LoadLibraryEx(d.Name, 0, LOAD_LIBRARY_SEARCH_APPLICATION_DIR|LOAD_LIBRARY_SEARCH_SYSTEM32)
-       if err != nil {
-               return fmt.Errorf("Unable to load library: %w", err)
-       }
-
-       atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
-       if d.onLoad != nil {
-               d.onLoad(d)
-       }
-       return nil
-}
-
-func (p *lazyProc) nameToAddr() (uintptr, error) {
-       return windows.GetProcAddress(p.dll.module, p.Name)
-}
-
-// Version returns the version of the Wintun DLL.
-func Version() string {
-       if modwintun.Load() != nil {
-               return "unknown"
-       }
-       resInfo, err := windows.FindResource(modwintun.module, windows.ResourceID(1), windows.RT_VERSION)
-       if err != nil {
-               return "unknown"
-       }
-       data, err := windows.LoadResourceData(modwintun.module, resInfo)
-       if err != nil {
-               return "unknown"
-       }
-
-       var fixedInfo *windows.VS_FIXEDFILEINFO
-       fixedInfoLen := uint32(unsafe.Sizeof(*fixedInfo))
-       err = windows.VerQueryValue(unsafe.Pointer(&data[0]), `\`, unsafe.Pointer(&fixedInfo), &fixedInfoLen)
-       if err != nil {
-               return "unknown"
-       }
-       version := fmt.Sprintf("%d.%d", (fixedInfo.FileVersionMS>>16)&0xff, (fixedInfo.FileVersionMS>>0)&0xff)
-       if nextNibble := (fixedInfo.FileVersionLS >> 16) & 0xff; nextNibble != 0 {
-               version += fmt.Sprintf(".%d", nextNibble)
-       }
-       if nextNibble := (fixedInfo.FileVersionLS >> 0) & 0xff; nextNibble != 0 {
-               version += fmt.Sprintf(".%d", nextNibble)
-       }
-       return version
-}
diff --git a/tun/wintun/session_windows.go b/tun/wintun/session_windows.go
deleted file mode 100644 (file)
index f023baf..0000000
+++ /dev/null
@@ -1,90 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package wintun
-
-import (
-       "syscall"
-       "unsafe"
-
-       "golang.org/x/sys/windows"
-)
-
-type Session struct {
-       handle uintptr
-}
-
-const (
-       PacketSizeMax   = 0xffff    // Maximum packet size
-       RingCapacityMin = 0x20000   // Minimum ring capacity (128 kiB)
-       RingCapacityMax = 0x4000000 // Maximum ring capacity (64 MiB)
-)
-
-// Packet with data
-type Packet struct {
-       Next *Packet              // Pointer to next packet in queue
-       Size uint32               // Size of packet (max WINTUN_MAX_IP_PACKET_SIZE)
-       Data *[PacketSizeMax]byte // Pointer to layer 3 IPv4 or IPv6 packet
-}
-
-var (
-       procWintunAllocateSendPacket   = modwintun.NewProc("WintunAllocateSendPacket")
-       procWintunEndSession           = modwintun.NewProc("WintunEndSession")
-       procWintunGetReadWaitEvent     = modwintun.NewProc("WintunGetReadWaitEvent")
-       procWintunReceivePacket        = modwintun.NewProc("WintunReceivePacket")
-       procWintunReleaseReceivePacket = modwintun.NewProc("WintunReleaseReceivePacket")
-       procWintunSendPacket           = modwintun.NewProc("WintunSendPacket")
-       procWintunStartSession         = modwintun.NewProc("WintunStartSession")
-)
-
-func (wintun *Adapter) StartSession(capacity uint32) (session Session, err error) {
-       r0, _, e1 := syscall.Syscall(procWintunStartSession.Addr(), 2, uintptr(wintun.handle), uintptr(capacity), 0)
-       if r0 == 0 {
-               err = e1
-       } else {
-               session = Session{r0}
-       }
-       return
-}
-
-func (session Session) End() {
-       syscall.Syscall(procWintunEndSession.Addr(), 1, session.handle, 0, 0)
-       session.handle = 0
-}
-
-func (session Session) ReadWaitEvent() (handle windows.Handle) {
-       r0, _, _ := syscall.Syscall(procWintunGetReadWaitEvent.Addr(), 1, session.handle, 0, 0)
-       handle = windows.Handle(r0)
-       return
-}
-
-func (session Session) ReceivePacket() (packet []byte, err error) {
-       var packetSize uint32
-       r0, _, e1 := syscall.Syscall(procWintunReceivePacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packetSize)), 0)
-       if r0 == 0 {
-               err = e1
-               return
-       }
-       packet = unsafe.Slice((*byte)(unsafe.Pointer(r0)), packetSize)
-       return
-}
-
-func (session Session) ReleaseReceivePacket(packet []byte) {
-       syscall.Syscall(procWintunReleaseReceivePacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0)
-}
-
-func (session Session) AllocateSendPacket(packetSize int) (packet []byte, err error) {
-       r0, _, e1 := syscall.Syscall(procWintunAllocateSendPacket.Addr(), 2, session.handle, uintptr(packetSize), 0)
-       if r0 == 0 {
-               err = e1
-               return
-       }
-       packet = unsafe.Slice((*byte)(unsafe.Pointer(r0)), packetSize)
-       return
-}
-
-func (session Session) SendPacket(packet []byte) {
-       syscall.Syscall(procWintunSendPacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0)
-}
diff --git a/tun/wintun/wintun_windows.go b/tun/wintun/wintun_windows.go
deleted file mode 100644 (file)
index 2fe26a7..0000000
+++ /dev/null
@@ -1,150 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package wintun
-
-import (
-       "log"
-       "runtime"
-       "syscall"
-       "unsafe"
-
-       "golang.org/x/sys/windows"
-)
-
-type loggerLevel int
-
-const (
-       logInfo loggerLevel = iota
-       logWarn
-       logErr
-)
-
-const AdapterNameMax = 128
-
-type Adapter struct {
-       handle uintptr
-}
-
-var (
-       modwintun                         = newLazyDLL("wintun.dll", setupLogger)
-       procWintunCreateAdapter           = modwintun.NewProc("WintunCreateAdapter")
-       procWintunOpenAdapter             = modwintun.NewProc("WintunOpenAdapter")
-       procWintunCloseAdapter            = modwintun.NewProc("WintunCloseAdapter")
-       procWintunDeleteDriver            = modwintun.NewProc("WintunDeleteDriver")
-       procWintunGetAdapterLUID          = modwintun.NewProc("WintunGetAdapterLUID")
-       procWintunGetRunningDriverVersion = modwintun.NewProc("WintunGetRunningDriverVersion")
-)
-
-type TimestampedWriter interface {
-       WriteWithTimestamp(p []byte, ts int64) (n int, err error)
-}
-
-func logMessage(level loggerLevel, timestamp uint64, msg *uint16) int {
-       if tw, ok := log.Default().Writer().(TimestampedWriter); ok {
-               tw.WriteWithTimestamp([]byte(log.Default().Prefix()+windows.UTF16PtrToString(msg)), (int64(timestamp)-116444736000000000)*100)
-       } else {
-               log.Println(windows.UTF16PtrToString(msg))
-       }
-       return 0
-}
-
-func setupLogger(dll *lazyDLL) {
-       var callback uintptr
-       if runtime.GOARCH == "386" {
-               callback = windows.NewCallback(func(level loggerLevel, timestampLow, timestampHigh uint32, msg *uint16) int {
-                       return logMessage(level, uint64(timestampHigh)<<32|uint64(timestampLow), msg)
-               })
-       } else if runtime.GOARCH == "arm" {
-               callback = windows.NewCallback(func(level loggerLevel, _, timestampLow, timestampHigh uint32, msg *uint16) int {
-                       return logMessage(level, uint64(timestampHigh)<<32|uint64(timestampLow), msg)
-               })
-       } else if runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" {
-               callback = windows.NewCallback(logMessage)
-       }
-       syscall.Syscall(dll.NewProc("WintunSetLogger").Addr(), 1, callback, 0, 0)
-}
-
-func closeAdapter(wintun *Adapter) {
-       syscall.Syscall(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0)
-}
-
-// CreateAdapter creates a Wintun adapter. name is the cosmetic name of the adapter.
-// tunnelType represents the type of adapter and should be "Wintun". requestedGUID is
-// the GUID of the created network adapter, which then influences NLA generation
-// deterministically. If it is set to nil, the GUID is chosen by the system at random,
-// and hence a new NLA entry is created for each new adapter.
-func CreateAdapter(name string, tunnelType string, requestedGUID *windows.GUID) (wintun *Adapter, err error) {
-       var name16 *uint16
-       name16, err = windows.UTF16PtrFromString(name)
-       if err != nil {
-               return
-       }
-       var tunnelType16 *uint16
-       tunnelType16, err = windows.UTF16PtrFromString(tunnelType)
-       if err != nil {
-               return
-       }
-       r0, _, e1 := syscall.Syscall(procWintunCreateAdapter.Addr(), 3, uintptr(unsafe.Pointer(name16)), uintptr(unsafe.Pointer(tunnelType16)), uintptr(unsafe.Pointer(requestedGUID)))
-       if r0 == 0 {
-               err = e1
-               return
-       }
-       wintun = &Adapter{handle: r0}
-       runtime.SetFinalizer(wintun, closeAdapter)
-       return
-}
-
-// OpenAdapter opens an existing Wintun adapter by name.
-func OpenAdapter(name string) (wintun *Adapter, err error) {
-       var name16 *uint16
-       name16, err = windows.UTF16PtrFromString(name)
-       if err != nil {
-               return
-       }
-       r0, _, e1 := syscall.Syscall(procWintunOpenAdapter.Addr(), 1, uintptr(unsafe.Pointer(name16)), 0, 0)
-       if r0 == 0 {
-               err = e1
-               return
-       }
-       wintun = &Adapter{handle: r0}
-       runtime.SetFinalizer(wintun, closeAdapter)
-       return
-}
-
-// Close closes a Wintun adapter.
-func (wintun *Adapter) Close() (err error) {
-       runtime.SetFinalizer(wintun, nil)
-       r1, _, e1 := syscall.Syscall(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0)
-       if r1 == 0 {
-               err = e1
-       }
-       return
-}
-
-// Uninstall removes the driver from the system if no drivers are currently in use.
-func Uninstall() (err error) {
-       r1, _, e1 := syscall.Syscall(procWintunDeleteDriver.Addr(), 0, 0, 0, 0)
-       if r1 == 0 {
-               err = e1
-       }
-       return
-}
-
-// RunningVersion returns the version of the loaded driver.
-func RunningVersion() (version uint32, err error) {
-       r0, _, e1 := syscall.Syscall(procWintunGetRunningDriverVersion.Addr(), 0, 0, 0, 0)
-       version = uint32(r0)
-       if version == 0 {
-               err = e1
-       }
-       return
-}
-
-// LUID returns the LUID of the adapter.
-func (wintun *Adapter) LUID() (luid uint64) {
-       syscall.Syscall(procWintunGetAdapterLUID.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&luid)), 0)
-       return
-}