]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
memmod: import from wireguard-windows
authorJason A. Donenfeld <Jason@zx2c4.com>
Mon, 11 Oct 2021 20:53:36 +0000 (14:53 -0600)
committerJason A. Donenfeld <Jason@zx2c4.com>
Mon, 11 Oct 2021 20:53:36 +0000 (14:53 -0600)
We'll eventually be getting rid of it here, but keep it sync'd up for
now.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
go.mod
tun/wintun/memmod/memmod_windows.go
tun/wintun/memmod/memmod_windows_32.go
tun/wintun/memmod/memmod_windows_64.go
tun/wintun/memmod/syscall_windows_32.go
tun/wintun/memmod/syscall_windows_64.go

diff --git a/go.mod b/go.mod
index 5d8388b79829b393858ff514b5d625195d4955c7..e543167ea84e3067adc5c21bdb4d4f964c05073f 100644 (file)
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,6 @@
 module golang.zx2c4.com/wireguard
 
-go 1.16
+go 1.17
 
 require (
        golang.org/x/crypto v0.0.0-20210921155107-089bfa567519
index 075c03a08ac72d19e33dfa2415b95e086001440c..da6ff9afb41d2b5e8de29764adb64fa13dab1234 100644 (file)
@@ -8,6 +8,8 @@ package memmod
 import (
        "errors"
        "fmt"
+       "strings"
+       "sync"
        "syscall"
        "unsafe"
 
@@ -62,8 +64,7 @@ func (module *Module) copySections(address uintptr, size uintptr, oldHeaders *IM
                        dest = module.codeBase + uintptr(sections[i].VirtualAddress)
                        // NOTE: On 64bit systems we truncate to 32bit here but expand again later when "PhysicalAddress" is used.
                        sections[i].SetPhysicalAddress((uint32)(dest & 0xffffffff))
-                       var dst []byte
-                       unsafeSlice(unsafe.Pointer(&dst), a2p(dest), int(sectionSize))
+                       dst := unsafe.Slice((*byte)(a2p(dest)), sectionSize)
                        for j := range dst {
                                dst[j] = 0
                        }
@@ -245,11 +246,9 @@ func (module *Module) performBaseRelocation(delta uintptr) (relocated bool, err
        for relocationHdr.VirtualAddress > 0 {
                dest := module.codeBase + uintptr(relocationHdr.VirtualAddress)
 
-               var relInfos []uint16
-               unsafeSlice(
-                       unsafe.Pointer(&relInfos),
-                       a2p(uintptr(unsafe.Pointer(relocationHdr))+unsafe.Sizeof(*relocationHdr)),
-                       int((uintptr(relocationHdr.SizeOfBlock)-unsafe.Sizeof(*relocationHdr))/unsafe.Sizeof(relInfos[0])))
+               relInfos := unsafe.Slice(
+                       (*uint16)(a2p(uintptr(unsafe.Pointer(relocationHdr))+unsafe.Sizeof(*relocationHdr))),
+                       (uintptr(relocationHdr.SizeOfBlock)-unsafe.Sizeof(*relocationHdr))/unsafe.Sizeof(uint16(0)))
                for _, relInfo := range relInfos {
                        // The upper 4 bits define the type of relocation.
                        relType := relInfo >> 12
@@ -370,10 +369,8 @@ func (module *Module) buildNameExports() error {
        if exports.NumberOfNames == 0 {
                return errors.New("No functions exported by name")
        }
-       var nameRefs []uint32
-       unsafeSlice(unsafe.Pointer(&nameRefs), a2p(module.codeBase+uintptr(exports.AddressOfNames)), int(exports.NumberOfNames))
-       var ordinals []uint16
-       unsafeSlice(unsafe.Pointer(&ordinals), a2p(module.codeBase+uintptr(exports.AddressOfNameOrdinals)), int(exports.NumberOfNames))
+       nameRefs := unsafe.Slice((*uint32)(a2p(module.codeBase+uintptr(exports.AddressOfNames))), exports.NumberOfNames)
+       ordinals := unsafe.Slice((*uint16)(a2p(module.codeBase+uintptr(exports.AddressOfNameOrdinals))), exports.NumberOfNames)
        module.nameExports = make(map[string]uint16)
        for i := range nameRefs {
                nameArray := windows.BytePtrToString((*byte)(a2p(module.codeBase + uintptr(nameRefs[i]))))
@@ -382,6 +379,76 @@ func (module *Module) buildNameExports() error {
        return nil
 }
 
+type addressRange struct {
+       start uintptr
+       end   uintptr
+}
+
+var loadedAddressRanges []addressRange
+var loadedAddressRangesMu sync.RWMutex
+var haveHookedRtlPcToFileHeader sync.Once
+var hookRtlPcToFileHeaderResult error
+
+func hookRtlPcToFileHeader() error {
+       var kernelBase windows.Handle
+       err := windows.GetModuleHandleEx(windows.GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, windows.StringToUTF16Ptr("kernelbase.dll"), &kernelBase)
+       if err != nil {
+               return err
+       }
+       imageBase := unsafe.Pointer(kernelBase)
+       dosHeader := (*IMAGE_DOS_HEADER)(imageBase)
+       ntHeaders := (*IMAGE_NT_HEADERS)(unsafe.Add(imageBase, dosHeader.E_lfanew))
+       importsDirectory := ntHeaders.OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT]
+       importDescriptor := (*IMAGE_IMPORT_DESCRIPTOR)(unsafe.Add(imageBase, importsDirectory.VirtualAddress))
+       for ; importDescriptor.Name != 0; importDescriptor = (*IMAGE_IMPORT_DESCRIPTOR)(unsafe.Add(unsafe.Pointer(importDescriptor), unsafe.Sizeof(*importDescriptor))) {
+               libraryName := windows.BytePtrToString((*byte)(unsafe.Add(imageBase, importDescriptor.Name)))
+               if strings.EqualFold(libraryName, "ntdll.dll") {
+                       break
+               }
+       }
+       if importDescriptor.Name == 0 {
+               return errors.New("ntdll.dll not found")
+       }
+       originalThunk := (*uintptr)(unsafe.Add(imageBase, importDescriptor.OriginalFirstThunk()))
+       thunk := (*uintptr)(unsafe.Add(imageBase, importDescriptor.FirstThunk))
+       for ; *originalThunk != 0; originalThunk = (*uintptr)(unsafe.Add(unsafe.Pointer(originalThunk), unsafe.Sizeof(*originalThunk))) {
+               if *originalThunk&IMAGE_ORDINAL_FLAG == 0 {
+                       function := (*IMAGE_IMPORT_BY_NAME)(unsafe.Add(imageBase, *originalThunk))
+                       name := windows.BytePtrToString(&function.Name[0])
+                       if name == "RtlPcToFileHeader" {
+                               break
+                       }
+               }
+               thunk = (*uintptr)(unsafe.Add(unsafe.Pointer(thunk), unsafe.Sizeof(*thunk)))
+       }
+       if *originalThunk == 0 {
+               return errors.New("RtlPcToFileHeader not found")
+       }
+       var oldProtect uint32
+       err = windows.VirtualProtect(uintptr(unsafe.Pointer(thunk)), unsafe.Sizeof(*thunk), windows.PAGE_READWRITE, &oldProtect)
+       if err != nil {
+               return err
+       }
+       originalRtlPcToFileHeader := *thunk
+       *thunk = windows.NewCallback(func(pcValue uintptr, baseOfImage *uintptr) uintptr {
+               loadedAddressRangesMu.RLock()
+               for i := range loadedAddressRanges {
+                       if pcValue >= loadedAddressRanges[i].start && pcValue < loadedAddressRanges[i].end {
+                               pcValue = *thunk
+                               break
+                       }
+               }
+               loadedAddressRangesMu.RUnlock()
+               ret, _, _ := syscall.Syscall(originalRtlPcToFileHeader, 2, pcValue, uintptr(unsafe.Pointer(baseOfImage)), 0)
+               return ret
+       })
+       err = windows.VirtualProtect(uintptr(unsafe.Pointer(thunk)), unsafe.Sizeof(*thunk), oldProtect, &oldProtect)
+       if err != nil {
+               return err
+       }
+       return nil
+}
+
 // LoadLibrary loads module image to memory.
 func LoadLibrary(data []byte) (module *Module, err error) {
        addr := uintptr(unsafe.Pointer(&data[0]))
@@ -513,6 +580,18 @@ func LoadLibrary(data []byte) (module *Module, err error) {
        // Register exception tables, if they exist.
        module.registerExceptionHandlers()
 
+       // Register function PCs.
+       loadedAddressRangesMu.Lock()
+       loadedAddressRanges = append(loadedAddressRanges, addressRange{module.codeBase, module.codeBase + alignedImageSize})
+       loadedAddressRangesMu.Unlock()
+       haveHookedRtlPcToFileHeader.Do(func() {
+               hookRtlPcToFileHeaderResult = hookRtlPcToFileHeader()
+       })
+       err = hookRtlPcToFileHeaderResult
+       if err != nil {
+               return
+       }
+
        // TLS callbacks are executed BEFORE the main loading.
        module.executeTLS()
 
@@ -610,26 +689,5 @@ func a2p(addr uintptr) unsafe.Pointer {
 }
 
 func memcpy(dst, src, size uintptr) {
-       var d, s []byte
-       unsafeSlice(unsafe.Pointer(&d), a2p(dst), int(size))
-       unsafeSlice(unsafe.Pointer(&s), a2p(src), int(size))
-       copy(d, s)
-}
-
-// unsafeSlice updates the slice slicePtr to be a slice
-// referencing the provided data with its length & capacity set to
-// lenCap.
-//
-// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
-// update callers to use unsafe.Slice instead of this.
-func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
-       type sliceHeader struct {
-               Data unsafe.Pointer
-               Len  int
-               Cap  int
-       }
-       h := (*sliceHeader)(slicePtr)
-       h.Data = data
-       h.Len = lenCap
-       h.Cap = lenCap
+       copy(unsafe.Slice((*byte)(a2p(dst)), size), unsafe.Slice((*byte)(a2p(src)), size))
 }
index ac76bdcca24cb3bc90852130bd936c8e0734f6a2..75d7ca1bb9c58917762b98d72021fdbf457734c0 100644 (file)
@@ -1,3 +1,4 @@
+//go:build (windows && 386) || (windows && arm)
 // +build windows,386 windows,arm
 
 /* SPDX-License-Identifier: MIT
index a6203682674ebbff625b78047f3ca3e042383325..09e6e73ad4e7aa2cf6a179221040aa3a54f6b973 100644 (file)
@@ -1,3 +1,4 @@
+//go:build (windows && amd64) || (windows && arm64)
 // +build windows,amd64 windows,arm64
 
 /* SPDX-License-Identifier: MIT
index 7abbac98a204b26884f73626990aa82d3909ac6d..007271088ee8375d65687e7fd7eda33c1b91a384 100644 (file)
@@ -1,3 +1,4 @@
+//go:build (windows && 386) || (windows && arm)
 // +build windows,386 windows,arm
 
 /* SPDX-License-Identifier: MIT
index 10c6533208975f4fda8ee3628196b5c640152994..b4752025a1fca284fd9a19bc938ff93879b6d592 100644 (file)
@@ -1,3 +1,4 @@
+//go:build (windows && amd64) || (windows && arm64)
 // +build windows,amd64 windows,arm64
 
 /* SPDX-License-Identifier: MIT