]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
winpipe: enforce ownership of client connection
authorJason A. Donenfeld <Jason@zx2c4.com>
Fri, 30 Aug 2019 19:21:47 +0000 (13:21 -0600)
committerJason A. Donenfeld <Jason@zx2c4.com>
Fri, 30 Aug 2019 19:21:47 +0000 (13:21 -0600)
ipc/winpipe/pipe.go
ipc/winpipe/sd.go
ipc/winpipe/zsyscall_windows.go

index 1e99a93ce760a04450fc3bd802267c365cf9c2de..39ccfa4a6b22654ee86c7b12ef8d0cf1f410cfc1 100644 (file)
@@ -211,7 +211,7 @@ func tryDialPipe(ctx context.Context, path *string) (syscall.Handle, error) {
 // DialPipe connects to a named pipe by path, timing out if the connection
 // takes longer than the specified duration. If timeout is nil, then we use
 // a default timeout of 2 seconds.  (We do not use WaitNamedPipe.)
-func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
+func DialPipe(path string, timeout *time.Duration, expectedOwner *syscall.SID) (net.Conn, error) {
        var absTimeout time.Time
        if timeout != nil {
                absTimeout = time.Now().Add(*timeout)
@@ -219,7 +219,7 @@ func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
                absTimeout = time.Now().Add(time.Second * 2)
        }
        ctx, _ := context.WithDeadline(context.Background(), absTimeout)
-       conn, err := DialPipeContext(ctx, path)
+       conn, err := DialPipeContext(ctx, path, expectedOwner)
        if err == context.DeadlineExceeded {
                return nil, ErrTimeout
        }
@@ -228,7 +228,7 @@ func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
 
 // DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
 // cancellation or timeout.
-func DialPipeContext(ctx context.Context, path string) (net.Conn, error) {
+func DialPipeContext(ctx context.Context, path string, expectedOwner *syscall.SID) (net.Conn, error) {
        var err error
        var h syscall.Handle
        h, err = tryDialPipe(ctx, &path)
@@ -236,9 +236,25 @@ func DialPipeContext(ctx context.Context, path string) (net.Conn, error) {
                return nil, err
        }
 
+       if expectedOwner != nil {
+               var realOwner *syscall.SID
+               var realSd uintptr
+               err = getSecurityInfo(h, SE_FILE_OBJECT, OWNER_SECURITY_INFORMATION, &realOwner, nil, nil, nil, &realSd)
+               if err != nil {
+                       syscall.Close(h)
+                       return nil, err
+               }
+               defer localFree(realSd)
+               if !equalSid(realOwner, expectedOwner) {
+                       syscall.Close(h)
+                       return nil, syscall.ERROR_ACCESS_DENIED
+               }
+       }
+
        var flags uint32
        err = getNamedPipeInfo(h, &flags, nil, nil, nil)
        if err != nil {
+               syscall.Close(h)
                return nil, err
        }
 
index 75686b2b36010b05e592ab206d3bf7cb886b9a21..4456917320b04bdaff059c1c43789d92b38f964c 100644 (file)
@@ -12,9 +12,16 @@ import (
        "unsafe"
 )
 
-//sys convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) = advapi32.ConvertStringSecurityDescriptorToSecurityDescriptorW
-//sys localFree(mem uintptr) = LocalFree
-//sys getSecurityDescriptorLength(sd uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength
+//sys  convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) = advapi32.ConvertStringSecurityDescriptorToSecurityDescriptorW
+//sys  localFree(mem uintptr) = LocalFree
+//sys  getSecurityDescriptorLength(sd uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength
+//sys  getSecurityInfo(handle syscall.Handle, objectType uint32, securityInformation uint32, owner **syscall.SID, group **syscall.SID, dacl *uintptr, sacl *uintptr, sd *uintptr) (ret error) = advapi32.GetSecurityInfo
+//sys  equalSid(sid1 *syscall.SID, sid2 *syscall.SID) (isEqual bool) = advapi32.EqualSid
+
+const (
+       SE_FILE_OBJECT             = 1
+       OWNER_SECURITY_INFORMATION = 1
+)
 
 func SddlToSecurityDescriptor(sddl string) ([]byte, error) {
        var sdBuffer uintptr
@@ -26,4 +33,4 @@ func SddlToSecurityDescriptor(sddl string) ([]byte, error) {
        sd := make([]byte, getSecurityDescriptorLength(sdBuffer))
        copy(sd, (*[0xffff]byte)(unsafe.Pointer(sdBuffer))[:len(sd)])
        return sd, nil
-}
+}
\ No newline at end of file
index b8eedb40b3aeb85965d694b72f98a0067d8269d5..ecf3e840e3d3bcce53b7fa5c37526691a8c34104 100644 (file)
@@ -55,6 +55,8 @@ var (
        procConvertStringSecurityDescriptorToSecurityDescriptorW = modadvapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW")
        procLocalFree                                            = modkernel32.NewProc("LocalFree")
        procGetSecurityDescriptorLength                          = modadvapi32.NewProc("GetSecurityDescriptorLength")
+       procGetSecurityInfo                                      = modadvapi32.NewProc("GetSecurityInfo")
+       procEqualSid                                             = modadvapi32.NewProc("EqualSid")
        procCancelIoEx                                           = modkernel32.NewProc("CancelIoEx")
        procCreateIoCompletionPort                               = modkernel32.NewProc("CreateIoCompletionPort")
        procGetQueuedCompletionStatus                            = modkernel32.NewProc("GetQueuedCompletionStatus")
@@ -206,6 +208,20 @@ func getSecurityDescriptorLength(sd uintptr) (len uint32) {
        return
 }
 
+func getSecurityInfo(handle syscall.Handle, objectType uint32, securityInformation uint32, owner **syscall.SID, group **syscall.SID, dacl *uintptr, sacl *uintptr, sd *uintptr) (ret error) {
+       r0, _, _ := syscall.Syscall9(procGetSecurityInfo.Addr(), 8, uintptr(handle), uintptr(objectType), uintptr(securityInformation), uintptr(unsafe.Pointer(owner)), uintptr(unsafe.Pointer(group)), uintptr(unsafe.Pointer(dacl)), uintptr(unsafe.Pointer(sacl)), uintptr(unsafe.Pointer(sd)), 0)
+       if r0 != 0 {
+               ret = syscall.Errno(r0)
+       }
+       return
+}
+
+func equalSid(sid1 *syscall.SID, sid2 *syscall.SID) (isEqual bool) {
+       r0, _, _ := syscall.Syscall(procEqualSid.Addr(), 2, uintptr(unsafe.Pointer(sid1)), uintptr(unsafe.Pointer(sid2)), 0)
+       isEqual = r0 != 0
+       return
+}
+
 func cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) {
        r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0)
        if r1 == 0 {