// DialPipe connects to a named pipe by path, timing out if the connection
// takes longer than the specified duration. If timeout is nil, then we use
// a default timeout of 2 seconds. (We do not use WaitNamedPipe.)
-func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
+func DialPipe(path string, timeout *time.Duration, expectedOwner *syscall.SID) (net.Conn, error) {
var absTimeout time.Time
if timeout != nil {
absTimeout = time.Now().Add(*timeout)
absTimeout = time.Now().Add(time.Second * 2)
}
ctx, _ := context.WithDeadline(context.Background(), absTimeout)
- conn, err := DialPipeContext(ctx, path)
+ conn, err := DialPipeContext(ctx, path, expectedOwner)
if err == context.DeadlineExceeded {
return nil, ErrTimeout
}
// DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
// cancellation or timeout.
-func DialPipeContext(ctx context.Context, path string) (net.Conn, error) {
+func DialPipeContext(ctx context.Context, path string, expectedOwner *syscall.SID) (net.Conn, error) {
var err error
var h syscall.Handle
h, err = tryDialPipe(ctx, &path)
return nil, err
}
+ if expectedOwner != nil {
+ var realOwner *syscall.SID
+ var realSd uintptr
+ err = getSecurityInfo(h, SE_FILE_OBJECT, OWNER_SECURITY_INFORMATION, &realOwner, nil, nil, nil, &realSd)
+ if err != nil {
+ syscall.Close(h)
+ return nil, err
+ }
+ defer localFree(realSd)
+ if !equalSid(realOwner, expectedOwner) {
+ syscall.Close(h)
+ return nil, syscall.ERROR_ACCESS_DENIED
+ }
+ }
+
var flags uint32
err = getNamedPipeInfo(h, &flags, nil, nil, nil)
if err != nil {
+ syscall.Close(h)
return nil, err
}
"unsafe"
)
-//sys convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) = advapi32.ConvertStringSecurityDescriptorToSecurityDescriptorW
-//sys localFree(mem uintptr) = LocalFree
-//sys getSecurityDescriptorLength(sd uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength
+//sys convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) = advapi32.ConvertStringSecurityDescriptorToSecurityDescriptorW
+//sys localFree(mem uintptr) = LocalFree
+//sys getSecurityDescriptorLength(sd uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength
+//sys getSecurityInfo(handle syscall.Handle, objectType uint32, securityInformation uint32, owner **syscall.SID, group **syscall.SID, dacl *uintptr, sacl *uintptr, sd *uintptr) (ret error) = advapi32.GetSecurityInfo
+//sys equalSid(sid1 *syscall.SID, sid2 *syscall.SID) (isEqual bool) = advapi32.EqualSid
+
+const (
+ SE_FILE_OBJECT = 1
+ OWNER_SECURITY_INFORMATION = 1
+)
func SddlToSecurityDescriptor(sddl string) ([]byte, error) {
var sdBuffer uintptr
sd := make([]byte, getSecurityDescriptorLength(sdBuffer))
copy(sd, (*[0xffff]byte)(unsafe.Pointer(sdBuffer))[:len(sd)])
return sd, nil
-}
+}
\ No newline at end of file
procConvertStringSecurityDescriptorToSecurityDescriptorW = modadvapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW")
procLocalFree = modkernel32.NewProc("LocalFree")
procGetSecurityDescriptorLength = modadvapi32.NewProc("GetSecurityDescriptorLength")
+ procGetSecurityInfo = modadvapi32.NewProc("GetSecurityInfo")
+ procEqualSid = modadvapi32.NewProc("EqualSid")
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
return
}
+func getSecurityInfo(handle syscall.Handle, objectType uint32, securityInformation uint32, owner **syscall.SID, group **syscall.SID, dacl *uintptr, sacl *uintptr, sd *uintptr) (ret error) {
+ r0, _, _ := syscall.Syscall9(procGetSecurityInfo.Addr(), 8, uintptr(handle), uintptr(objectType), uintptr(securityInformation), uintptr(unsafe.Pointer(owner)), uintptr(unsafe.Pointer(group)), uintptr(unsafe.Pointer(dacl)), uintptr(unsafe.Pointer(sacl)), uintptr(unsafe.Pointer(sd)), 0)
+ if r0 != 0 {
+ ret = syscall.Errno(r0)
+ }
+ return
+}
+
+func equalSid(sid1 *syscall.SID, sid2 *syscall.SID) (isEqual bool) {
+ r0, _, _ := syscall.Syscall(procEqualSid.Addr(), 2, uintptr(unsafe.Pointer(sid1)), uintptr(unsafe.Pointer(sid2)), 0)
+ isEqual = r0 != 0
+ return
+}
+
func cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) {
r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0)
if r1 == 0 {