]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Added code from windows branch
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Sun, 27 Aug 2017 13:41:00 +0000 (15:41 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Sun, 27 Aug 2017 13:41:00 +0000 (15:41 +0200)
src/build.cmd [new file with mode: 0755]
src/conn_default.go
src/daemon_windows.go [new file with mode: 0644]
src/timers.go
src/tun_windows.go [new file with mode: 0644]
src/uapi_windows.go [new file with mode: 0644]

diff --git a/src/build.cmd b/src/build.cmd
new file mode 100755 (executable)
index 0000000..52cb883
--- /dev/null
@@ -0,0 +1,6 @@
+@echo off
+
+REM builds wireguard for windows
+
+go get
+go build -o wireguard-go.exe
index a6dc97da15a0a1bfb285f9938fc8c05bbdc09940..5ef2659a57e113f62b05c87739f1a1497673a3fe 100644 (file)
@@ -6,6 +6,6 @@ import (
        "net"
 )
 
-func setFwmark(conn *net.UDPConn, value int) error {
+func setMark(conn *net.UDPConn, value int) error {
        return nil
 }
diff --git a/src/daemon_windows.go b/src/daemon_windows.go
new file mode 100644 (file)
index 0000000..d5ec1e8
--- /dev/null
@@ -0,0 +1,34 @@
+package main\r
+\r
+import (\r
+       "os"\r
+)\r
+\r
+/* Daemonizes the process on windows\r
+ *\r
+ * This is done by spawning and releasing a copy with the --foreground flag\r
+ */\r
+\r
+func Daemonize() error {\r
+       argv := []string{os.Args[0], "--foreground"}\r
+       argv = append(argv, os.Args[1:]...)\r
+       attr := &os.ProcAttr{\r
+               Dir: ".",\r
+               Env: os.Environ(),\r
+               Files: []*os.File{\r
+                       os.Stdin,\r
+                       nil,\r
+                       nil,\r
+               },\r
+       }\r
+       process, err := os.StartProcess(\r
+               argv[0],\r
+               argv,\r
+               attr,\r
+       )\r
+       if err != nil {\r
+               return err\r
+       }\r
+       process.Release()\r
+       return nil\r
+}\r
index ab2e7adf59fc53b8130288b6be0056a31ae6fdfb..de54a96db934e7e25a545592db0b9096a8d4add5 100644 (file)
-package main
-
-import (
-       "bytes"
-       "encoding/binary"
-       "golang.org/x/crypto/blake2s"
-       "math/rand"
-       "sync/atomic"
-       "time"
-)
-
-/* Called when a new authenticated message has been send
- *
- */
-func (peer *Peer) KeepKeyFreshSending() {
-       kp := peer.keyPairs.Current()
-       if kp == nil {
-               return
-       }
-       nonce := atomic.LoadUint64(&kp.sendNonce)
-       if nonce > RekeyAfterMessages {
-               signalSend(peer.signal.handshakeBegin)
-       }
-       if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime {
-               signalSend(peer.signal.handshakeBegin)
-       }
-}
-
-/* Called when a new authenticated message has been recevied
- *
- */
-func (peer *Peer) KeepKeyFreshReceiving() {
-       // TODO: Add a guard, clear on handshake complete (clear in TimerHandshakeComplete)
-       kp := peer.keyPairs.Current()
-       if kp == nil {
-               return
-       }
-       if !kp.isInitiator {
-               return
-       }
-       nonce := atomic.LoadUint64(&kp.sendNonce)
-       send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
-       if send {
-               signalSend(peer.signal.handshakeBegin)
-       }
-}
-
-/* Queues a keep-alive if no packets are queued for peer
- */
-func (peer *Peer) SendKeepAlive() bool {
-       elem := peer.device.NewOutboundElement()
-       elem.packet = nil
-       if len(peer.queue.nonce) == 0 {
-               select {
-               case peer.queue.nonce <- elem:
-                       return true
-               default:
-                       return false
-               }
-       }
-       return true
-}
-
-/* Event:
- * Sent non-empty (authenticated) transport message
- */
-func (peer *Peer) TimerDataSent() {
-       timerStop(peer.timer.keepalivePassive)
-       if !peer.timer.pendingNewHandshake {
-               peer.timer.pendingNewHandshake = true
-               peer.timer.newHandshake.Reset(NewHandshakeTime)
-       }
-}
-
-/* Event:
- * Received non-empty (authenticated) transport message
- */
-func (peer *Peer) TimerDataReceived() {
-       if peer.timer.pendingKeepalivePassive {
-               peer.timer.needAnotherKeepalive = true
-               return
-       }
-       peer.timer.pendingKeepalivePassive = false
-       peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
-}
-
-/* Event:
- * Any (authenticated) packet received
- */
-func (peer *Peer) TimerAnyAuthenticatedPacketReceived() {
-       timerStop(peer.timer.newHandshake)
-}
-
-/* Event:
- * Any authenticated packet send / received.
- */
-func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {
-       interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
-       if interval > 0 {
-               duration := time.Duration(interval) * time.Second
-               peer.timer.keepalivePersistent.Reset(duration)
-       }
-}
-
-/* Called after succesfully completing a handshake.
- * i.e. after:
- *
- * - Valid handshake response
- * - First transport message under the "next" key
- */
-func (peer *Peer) TimerHandshakeComplete() {
-       atomic.StoreInt64(
-               &peer.stats.lastHandshakeNano,
-               time.Now().UnixNano(),
-       )
-       signalSend(peer.signal.handshakeCompleted)
-       peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
-}
-
-/* Event:
- * An ephemeral key is generated
- *
- * i.e after:
- *
- * CreateMessageInitiation
- * CreateMessageResponse
- *
- * Schedules the deletion of all key material
- * upon failure to complete a handshake
- */
-func (peer *Peer) TimerEphemeralKeyCreated() {
-       peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
-}
-
-func (peer *Peer) RoutineTimerHandler() {
-       device := peer.device
-       indices := &device.indices
-
-       logDebug := device.log.Debug
-       logDebug.Println("Routine, timer handler, started for peer", peer.String())
-
-       for {
-               select {
-
-               case <-peer.signal.stop:
-                       return
-
-               // keep-alives
-
-               case <-peer.timer.keepalivePersistent.C:
-
-                       interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
-                       if interval > 0 {
-                               logDebug.Println("Sending keep-alive to", peer.String())
-                               peer.SendKeepAlive()
-                       }
-
-               case <-peer.timer.keepalivePassive.C:
-
-                       logDebug.Println("Sending keep-alive to", peer.String())
-
-                       peer.SendKeepAlive()
-
-                       if peer.timer.needAnotherKeepalive {
-                               peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
-                               peer.timer.needAnotherKeepalive = false
-                       }
-
-               // unresponsive session
-
-               case <-peer.timer.newHandshake.C:
-
-                       logDebug.Println("Retrying handshake with", peer.String(), "due to lack of reply")
-
-                       signalSend(peer.signal.handshakeBegin)
-
-               // clear key material
-
-               case <-peer.timer.zeroAllKeys.C:
-
-                       logDebug.Println("Clearing all key material for", peer.String())
-
-                       hs := &peer.handshake
-                       hs.mutex.Lock()
-
-                       kp := &peer.keyPairs
-                       kp.mutex.Lock()
-
-                       // unmap indecies
-
-                       indices.mutex.Lock()
-                       if kp.previous != nil {
-                               delete(indices.table, kp.previous.localIndex)
-                       }
-                       if kp.current != nil {
-                               delete(indices.table, kp.current.localIndex)
-                       }
-                       if kp.next != nil {
-                               delete(indices.table, kp.next.localIndex)
-                       }
-                       delete(indices.table, hs.localIndex)
-                       indices.mutex.Unlock()
-
-                       // zero out key pairs (TODO: better than wait for GC)
-
-                       kp.current = nil
-                       kp.previous = nil
-                       kp.next = nil
-                       kp.mutex.Unlock()
-
-                       // zero out handshake
-
-                       hs.localIndex = 0
-                       hs.localEphemeral = NoisePrivateKey{}
-                       hs.remoteEphemeral = NoisePublicKey{}
-                       hs.chainKey = [blake2s.Size]byte{}
-                       hs.hash = [blake2s.Size]byte{}
-                       hs.mutex.Unlock()
-               }
-       }
-}
-
-/* This is the state machine for handshake initiation
- *
- * Associated with this routine is the signal "handshakeBegin"
- * The routine will read from the "handshakeBegin" channel
- * at most every RekeyTimeout seconds
- */
-func (peer *Peer) RoutineHandshakeInitiator() {
-       device := peer.device
-
-       logInfo := device.log.Info
-       logError := device.log.Error
-       logDebug := device.log.Debug
-       logDebug.Println("Routine, handshake initator, started for", peer.String())
-
-       var temp [256]byte
-
-       for {
-
-               // wait for signal
-
-               select {
-               case <-peer.signal.handshakeBegin:
-               case <-peer.signal.stop:
-                       return
-               }
-
-               // set deadline
-
-       BeginHandshakes:
-
-               signalClear(peer.signal.handshakeReset)
-               deadline := time.NewTimer(RekeyAttemptTime)
-
-       AttemptHandshakes:
-
-               for attempts := uint(1); ; attempts++ {
-
-                       // check if deadline reached
-
-                       select {
-                       case <-deadline.C:
-                               logInfo.Println("Handshake negotiation timed out for:", peer.String())
-                               signalSend(peer.signal.flushNonceQueue)
-                               timerStop(peer.timer.keepalivePersistent)
-                               break
-                       case <-peer.signal.stop:
-                               return
-                       default:
-                       }
-
-                       signalClear(peer.signal.handshakeCompleted)
-
-                       // create initiation message
-
-                       msg, err := peer.device.CreateMessageInitiation(peer)
-                       if err != nil {
-                               logError.Println("Failed to create handshake initiation message:", err)
-                               break AttemptHandshakes
-                       }
-
-                       jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
-
-                       // marshal and send
-
-                       writer := bytes.NewBuffer(temp[:0])
-                       binary.Write(writer, binary.LittleEndian, msg)
-                       packet := writer.Bytes()
-                       peer.mac.AddMacs(packet)
-
-                       _, err = peer.SendBuffer(packet)
-                       if err != nil {
-                               logError.Println(
-                                       "Failed to send handshake initiation message to",
-                                       peer.String(), ":", err,
-                               )
-                               break
-                       }
-
-                       peer.TimerAnyAuthenticatedPacketTraversal()
-
-                       // set handshake timeout
-
-                       timeout := time.NewTimer(RekeyTimeout + jitter)
-                       logDebug.Println(
-                               "Handshake initiation attempt",
-                               attempts, "sent to", peer.String(),
-                       )
-
-                       // wait for handshake or timeout
-
-                       select {
-
-                       case <-peer.signal.stop:
-                               return
-
-                       case <-peer.signal.handshakeCompleted:
-                               <-timeout.C
-                               break AttemptHandshakes
-
-                       case <-peer.signal.handshakeReset:
-                               <-timeout.C
-                               goto BeginHandshakes
-
-                       case <-timeout.C:
-                               // TODO: Clear source address for peer
-                               continue
-                       }
-               }
-
-               // clear signal set in the meantime
-
-               signalClear(peer.signal.handshakeBegin)
-       }
-}
+package main\r
+\r
+import (\r
+       "bytes"\r
+       "encoding/binary"\r
+       "golang.org/x/crypto/blake2s"\r
+       "math/rand"\r
+       "sync/atomic"\r
+       "time"\r
+)\r
+\r
+/* Called when a new authenticated message has been send\r
+ *\r
+ */\r
+func (peer *Peer) KeepKeyFreshSending() {\r
+       kp := peer.keyPairs.Current()\r
+       if kp == nil {\r
+               return\r
+       }\r
+       nonce := atomic.LoadUint64(&kp.sendNonce)\r
+       if nonce > RekeyAfterMessages {\r
+               signalSend(peer.signal.handshakeBegin)\r
+       }\r
+       if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime {\r
+               signalSend(peer.signal.handshakeBegin)\r
+       }\r
+}\r
+\r
+/* Called when a new authenticated message has been recevied\r
+ *\r
+ */\r
+func (peer *Peer) KeepKeyFreshReceiving() {\r
+       // TODO: Add a guard, clear on handshake complete (clear in TimerHandshakeComplete)\r
+       kp := peer.keyPairs.Current()\r
+       if kp == nil {\r
+               return\r
+       }\r
+       if !kp.isInitiator {\r
+               return\r
+       }\r
+       nonce := atomic.LoadUint64(&kp.sendNonce)\r
+       send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving\r
+       if send {\r
+               signalSend(peer.signal.handshakeBegin)\r
+       }\r
+}\r
+\r
+/* Queues a keep-alive if no packets are queued for peer\r
+ */\r
+func (peer *Peer) SendKeepAlive() bool {\r
+       elem := peer.device.NewOutboundElement()\r
+       elem.packet = nil\r
+       if len(peer.queue.nonce) == 0 {\r
+               select {\r
+               case peer.queue.nonce <- elem:\r
+                       return true\r
+               default:\r
+                       return false\r
+               }\r
+       }\r
+       return true\r
+}\r
+\r
+/* Event:\r
+ * Sent non-empty (authenticated) transport message\r
+ */\r
+func (peer *Peer) TimerDataSent() {\r
+       timerStop(peer.timer.keepalivePassive)\r
+       if !peer.timer.pendingNewHandshake {\r
+               peer.timer.pendingNewHandshake = true\r
+               peer.timer.newHandshake.Reset(NewHandshakeTime)\r
+       }\r
+}\r
+\r
+/* Event:\r
+ * Received non-empty (authenticated) transport message\r
+ */\r
+func (peer *Peer) TimerDataReceived() {\r
+       if peer.timer.pendingKeepalivePassive {\r
+               peer.timer.needAnotherKeepalive = true\r
+               return\r
+       }\r
+       peer.timer.pendingKeepalivePassive = false\r
+       peer.timer.keepalivePassive.Reset(KeepaliveTimeout)\r
+}\r
+\r
+/* Event:\r
+ * Any (authenticated) packet received\r
+ */\r
+func (peer *Peer) TimerAnyAuthenticatedPacketReceived() {\r
+       timerStop(peer.timer.newHandshake)\r
+}\r
+\r
+/* Event:\r
+ * Any authenticated packet send / received.\r
+ */\r
+func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {\r
+       interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)\r
+       if interval > 0 {\r
+               duration := time.Duration(interval) * time.Second\r
+               peer.timer.keepalivePersistent.Reset(duration)\r
+       }\r
+}\r
+\r
+/* Called after succesfully completing a handshake.\r
+ * i.e. after:\r
+ *\r
+ * - Valid handshake response\r
+ * - First transport message under the "next" key\r
+ */\r
+func (peer *Peer) TimerHandshakeComplete() {\r
+       atomic.StoreInt64(\r
+               &peer.stats.lastHandshakeNano,\r
+               time.Now().UnixNano(),\r
+       )\r
+       signalSend(peer.signal.handshakeCompleted)\r
+       peer.device.log.Info.Println("Negotiated new handshake for", peer.String())\r
+}\r
+\r
+/* Event:\r
+ * An ephemeral key is generated\r
+ *\r
+ * i.e after:\r
+ *\r
+ * CreateMessageInitiation\r
+ * CreateMessageResponse\r
+ *\r
+ * Schedules the deletion of all key material\r
+ * upon failure to complete a handshake\r
+ */\r
+func (peer *Peer) TimerEphemeralKeyCreated() {\r
+       peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)\r
+}\r
+\r
+func (peer *Peer) RoutineTimerHandler() {\r
+       device := peer.device\r
+       indices := &device.indices\r
+\r
+       logDebug := device.log.Debug\r
+       logDebug.Println("Routine, timer handler, started for peer", peer.String())\r
+\r
+       for {\r
+               select {\r
+\r
+               case <-peer.signal.stop:\r
+                       return\r
+\r
+               // keep-alives\r
+\r
+               case <-peer.timer.keepalivePersistent.C:\r
+\r
+                       interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)\r
+                       if interval > 0 {\r
+                               logDebug.Println("Sending keep-alive to", peer.String())\r
+                               peer.SendKeepAlive()\r
+                       }\r
+\r
+               case <-peer.timer.keepalivePassive.C:\r
+\r
+                       logDebug.Println("Sending keep-alive to", peer.String())\r
+\r
+                       peer.SendKeepAlive()\r
+\r
+                       if peer.timer.needAnotherKeepalive {\r
+                               peer.timer.keepalivePassive.Reset(KeepaliveTimeout)\r
+                               peer.timer.needAnotherKeepalive = false\r
+                       }\r
+\r
+               // unresponsive session\r
+\r
+               case <-peer.timer.newHandshake.C:\r
+\r
+                       logDebug.Println("Retrying handshake with", peer.String(), "due to lack of reply")\r
+\r
+                       signalSend(peer.signal.handshakeBegin)\r
+\r
+               // clear key material\r
+\r
+               case <-peer.timer.zeroAllKeys.C:\r
+\r
+                       logDebug.Println("Clearing all key material for", peer.String())\r
+\r
+                       hs := &peer.handshake\r
+                       hs.mutex.Lock()\r
+\r
+                       kp := &peer.keyPairs\r
+                       kp.mutex.Lock()\r
+\r
+                       // unmap indecies\r
+\r
+                       indices.mutex.Lock()\r
+                       if kp.previous != nil {\r
+                               delete(indices.table, kp.previous.localIndex)\r
+                       }\r
+                       if kp.current != nil {\r
+                               delete(indices.table, kp.current.localIndex)\r
+                       }\r
+                       if kp.next != nil {\r
+                               delete(indices.table, kp.next.localIndex)\r
+                       }\r
+                       delete(indices.table, hs.localIndex)\r
+                       indices.mutex.Unlock()\r
+\r
+                       // zero out key pairs (TODO: better than wait for GC)\r
+\r
+                       kp.current = nil\r
+                       kp.previous = nil\r
+                       kp.next = nil\r
+                       kp.mutex.Unlock()\r
+\r
+                       // zero out handshake\r
+\r
+                       hs.localIndex = 0\r
+                       hs.localEphemeral = NoisePrivateKey{}\r
+                       hs.remoteEphemeral = NoisePublicKey{}\r
+                       hs.chainKey = [blake2s.Size]byte{}\r
+                       hs.hash = [blake2s.Size]byte{}\r
+                       hs.mutex.Unlock()\r
+               }\r
+       }\r
+}\r
+\r
+/* This is the state machine for handshake initiation\r
+ *\r
+ * Associated with this routine is the signal "handshakeBegin"\r
+ * The routine will read from the "handshakeBegin" channel\r
+ * at most every RekeyTimeout seconds\r
+ */\r
+func (peer *Peer) RoutineHandshakeInitiator() {\r
+       device := peer.device\r
+\r
+       logInfo := device.log.Info\r
+       logError := device.log.Error\r
+       logDebug := device.log.Debug\r
+       logDebug.Println("Routine, handshake initator, started for", peer.String())\r
+\r
+       var temp [256]byte\r
+\r
+       for {\r
+\r
+               // wait for signal\r
+\r
+               select {\r
+               case <-peer.signal.handshakeBegin:\r
+               case <-peer.signal.stop:\r
+                       return\r
+               }\r
+\r
+               // set deadline\r
+\r
+       BeginHandshakes:\r
+\r
+               signalClear(peer.signal.handshakeReset)\r
+               deadline := time.NewTimer(RekeyAttemptTime)\r
+\r
+       AttemptHandshakes:\r
+\r
+               for attempts := uint(1); ; attempts++ {\r
+\r
+                       // check if deadline reached\r
+\r
+                       select {\r
+                       case <-deadline.C:\r
+                               logInfo.Println("Handshake negotiation timed out for:", peer.String())\r
+                               signalSend(peer.signal.flushNonceQueue)\r
+                               timerStop(peer.timer.keepalivePersistent)\r
+                               break\r
+                       case <-peer.signal.stop:\r
+                               return\r
+                       default:\r
+                       }\r
+\r
+                       signalClear(peer.signal.handshakeCompleted)\r
+\r
+                       // create initiation message\r
+\r
+                       msg, err := peer.device.CreateMessageInitiation(peer)\r
+                       if err != nil {\r
+                               logError.Println("Failed to create handshake initiation message:", err)\r
+                               break AttemptHandshakes\r
+                       }\r
+\r
+                       jitter := time.Millisecond * time.Duration(rand.Uint32()%334)\r
+\r
+                       // marshal and send\r
+\r
+                       writer := bytes.NewBuffer(temp[:0])\r
+                       binary.Write(writer, binary.LittleEndian, msg)\r
+                       packet := writer.Bytes()\r
+                       peer.mac.AddMacs(packet)\r
+\r
+                       _, err = peer.SendBuffer(packet)\r
+                       if err != nil {\r
+                               logError.Println(\r
+                                       "Failed to send handshake initiation message to",\r
+                                       peer.String(), ":", err,\r
+                               )\r
+                               continue\r
+                       }\r
+\r
+                       peer.TimerAnyAuthenticatedPacketTraversal()\r
+\r
+                       // set handshake timeout\r
+\r
+                       timeout := time.NewTimer(RekeyTimeout + jitter)\r
+                       logDebug.Println(\r
+                               "Handshake initiation attempt",\r
+                               attempts, "sent to", peer.String(),\r
+                       )\r
+\r
+                       // wait for handshake or timeout\r
+\r
+                       select {\r
+\r
+                       case <-peer.signal.stop:\r
+                               return\r
+\r
+                       case <-peer.signal.handshakeCompleted:\r
+                               <-timeout.C\r
+                               break AttemptHandshakes\r
+\r
+                       case <-peer.signal.handshakeReset:\r
+                               <-timeout.C\r
+                               goto BeginHandshakes\r
+\r
+                       case <-timeout.C:\r
+                               // TODO: Clear source address for peer\r
+                               continue\r
+                       }\r
+               }\r
+\r
+               // clear signal set in the meantime\r
+\r
+               signalClear(peer.signal.handshakeBegin)\r
+       }\r
+}\r
diff --git a/src/tun_windows.go b/src/tun_windows.go
new file mode 100644 (file)
index 0000000..0711032
--- /dev/null
@@ -0,0 +1,475 @@
+package main
+
+import (
+       "encoding/binary"
+       "errors"
+       "fmt"
+       "golang.org/x/sys/windows"
+       "golang.org/x/sys/windows/registry"
+       "net"
+       "sync"
+       "syscall"
+       "time"
+       "unsafe"
+)
+
+/* Relies on the OpenVPN TAP-Windows driver (NDIS 6 version)
+ *
+ * https://github.com/OpenVPN/tap-windows
+ */
+
+type NativeTUN struct {
+       fd     windows.Handle
+       rl     sync.Mutex
+       wl     sync.Mutex
+       ro     *windows.Overlapped
+       wo     *windows.Overlapped
+       events chan TUNEvent
+       name   string
+}
+
+const (
+       METHOD_BUFFERED = 0
+       ComponentID     = "tap0901" // tap0801
+)
+
+func ctl_code(device_type, function, method, access uint32) uint32 {
+       return (device_type << 16) | (access << 14) | (function << 2) | method
+}
+
+func TAP_CONTROL_CODE(request, method uint32) uint32 {
+       return ctl_code(file_device_unknown, request, method, 0)
+}
+
+var (
+       errIfceNameNotFound = errors.New("Failed to find the name of interface")
+
+       TAP_IOCTL_GET_MAC               = TAP_CONTROL_CODE(1, METHOD_BUFFERED)
+       TAP_IOCTL_GET_VERSION           = TAP_CONTROL_CODE(2, METHOD_BUFFERED)
+       TAP_IOCTL_GET_MTU               = TAP_CONTROL_CODE(3, METHOD_BUFFERED)
+       TAP_IOCTL_GET_INFO              = TAP_CONTROL_CODE(4, METHOD_BUFFERED)
+       TAP_IOCTL_CONFIG_POINT_TO_POINT = TAP_CONTROL_CODE(5, METHOD_BUFFERED)
+       TAP_IOCTL_SET_MEDIA_STATUS      = TAP_CONTROL_CODE(6, METHOD_BUFFERED)
+       TAP_IOCTL_CONFIG_DHCP_MASQ      = TAP_CONTROL_CODE(7, METHOD_BUFFERED)
+       TAP_IOCTL_GET_LOG_LINE          = TAP_CONTROL_CODE(8, METHOD_BUFFERED)
+       TAP_IOCTL_CONFIG_DHCP_SET_OPT   = TAP_CONTROL_CODE(9, METHOD_BUFFERED)
+       TAP_IOCTL_CONFIG_TUN            = TAP_CONTROL_CODE(10, METHOD_BUFFERED)
+
+       file_device_unknown = uint32(0x00000022)
+       nCreateEvent,
+       nResetEvent,
+       nGetOverlappedResult uintptr
+)
+
+func init() {
+       k32, err := windows.LoadLibrary("kernel32.dll")
+       if err != nil {
+               panic("LoadLibrary " + err.Error())
+       }
+       defer windows.FreeLibrary(k32)
+       nCreateEvent = getProcAddr(k32, "CreateEventW")
+       nResetEvent = getProcAddr(k32, "ResetEvent")
+       nGetOverlappedResult = getProcAddr(k32, "GetOverlappedResult")
+}
+
+/* implementation of the read/write/closer interface */
+
+func getProcAddr(lib windows.Handle, name string) uintptr {
+       addr, err := windows.GetProcAddress(lib, name)
+       if err != nil {
+               panic(name + " " + err.Error())
+       }
+       return addr
+}
+
+func resetEvent(h windows.Handle) error {
+       r, _, err := syscall.Syscall(nResetEvent, 1, uintptr(h), 0, 0)
+       if r == 0 {
+               return err
+       }
+       return nil
+}
+
+func getOverlappedResult(h windows.Handle, overlapped *windows.Overlapped) (int, error) {
+       var n int
+       r, _, err := syscall.Syscall6(
+               nGetOverlappedResult,
+               4,
+               uintptr(h),
+               uintptr(unsafe.Pointer(overlapped)),
+               uintptr(unsafe.Pointer(&n)), 1, 0, 0)
+
+       if r == 0 {
+               return n, err
+       }
+       return n, nil
+}
+
+func newOverlapped() (*windows.Overlapped, error) {
+       var overlapped windows.Overlapped
+       r, _, err := syscall.Syscall6(nCreateEvent, 4, 0, 1, 0, 0, 0, 0)
+       if r == 0 {
+               return nil, err
+       }
+       overlapped.HEvent = windows.Handle(r)
+       return &overlapped, nil
+}
+
+func (f *NativeTUN) Events() chan TUNEvent {
+       return f.events
+}
+
+func (f *NativeTUN) Close() error {
+       return windows.Close(f.fd)
+}
+
+func (f *NativeTUN) Write(b []byte) (int, error) {
+       f.wl.Lock()
+       defer f.wl.Unlock()
+
+       if err := resetEvent(f.wo.HEvent); err != nil {
+               return 0, err
+       }
+       var n uint32
+       err := windows.WriteFile(f.fd, b, &n, f.wo)
+       if err != nil && err != windows.ERROR_IO_PENDING {
+               return int(n), err
+       }
+       return getOverlappedResult(f.fd, f.wo)
+}
+
+func (f *NativeTUN) Read(b []byte) (int, error) {
+       f.rl.Lock()
+       defer f.rl.Unlock()
+
+       if err := resetEvent(f.ro.HEvent); err != nil {
+               return 0, err
+       }
+       var done uint32
+       err := windows.ReadFile(f.fd, b, &done, f.ro)
+       if err != nil && err != windows.ERROR_IO_PENDING {
+               return int(done), err
+       }
+       return getOverlappedResult(f.fd, f.ro)
+}
+
+func getdeviceid(
+       targetComponentId string,
+       targetDeviceName string,
+) (deviceid string, err error) {
+
+       getName := func(instanceId string) (string, error) {
+               path := fmt.Sprintf(
+                       `SYSTEM\CurrentControlSet\Control\Network\{4D36E972-E325-11CE-BFC1-08002BE10318}\%s\Connection`,
+                       instanceId,
+               )
+
+               key, err := registry.OpenKey(
+                       registry.LOCAL_MACHINE,
+                       path,
+                       registry.READ,
+               )
+
+               if err != nil {
+                       return "", err
+               }
+               defer key.Close()
+
+               val, _, err := key.GetStringValue("Name")
+               key.Close()
+               return val, err
+       }
+
+       getInstanceId := func(keyName string) (string, string, error) {
+               path := fmt.Sprintf(
+                       `SYSTEM\CurrentControlSet\Control\Class\{4D36E972-E325-11CE-BFC1-08002BE10318}\%s`,
+                       keyName,
+               )
+
+               key, err := registry.OpenKey(
+                       registry.LOCAL_MACHINE,
+                       path,
+                       registry.READ,
+               )
+
+               if err != nil {
+                       return "", "", err
+               }
+               defer key.Close()
+
+               componentId, _, err := key.GetStringValue("ComponentId")
+               if err != nil {
+                       return "", "", err
+               }
+
+               instanceId, _, err := key.GetStringValue("NetCfgInstanceId")
+
+               return componentId, instanceId, err
+       }
+
+       // find list of all network devices
+
+       k, err := registry.OpenKey(
+               registry.LOCAL_MACHINE,
+               `SYSTEM\CurrentControlSet\Control\Class\{4D36E972-E325-11CE-BFC1-08002BE10318}`,
+               registry.READ,
+       )
+
+       if err != nil {
+               return "", fmt.Errorf("Failed to open the adapter registry, TAP driver may be not installed, %v", err)
+       }
+
+       defer k.Close()
+
+       keys, err := k.ReadSubKeyNames(-1)
+
+       if err != nil {
+               return "", err
+       }
+
+       // look for matching component id and name
+
+       var componentFound bool
+
+       for _, v := range keys {
+
+               componentId, instanceId, err := getInstanceId(v)
+               if err != nil || componentId != targetComponentId {
+                       continue
+               }
+
+               componentFound = true
+
+               deviceName, err := getName(instanceId)
+               if err != nil || deviceName != targetDeviceName {
+                       continue
+               }
+
+               return instanceId, nil
+       }
+
+       // provide a descriptive error message
+
+       if componentFound {
+               return "", fmt.Errorf("Unable to find tun/tap device with name = %s", targetDeviceName)
+       }
+
+       return "", fmt.Errorf(
+               "Unable to find device in registry with ComponentId = %s, is tap-windows installed?",
+               targetComponentId,
+       )
+}
+
+// setStatus is used to bring up or bring down the interface
+func setStatus(fd windows.Handle, status bool) error {
+       var code [4]byte
+       if status {
+               binary.LittleEndian.PutUint32(code[:], 1)
+       }
+
+       var bytesReturned uint32
+       rdbbuf := make([]byte, windows.MAXIMUM_REPARSE_DATA_BUFFER_SIZE)
+       return windows.DeviceIoControl(
+               fd,
+               TAP_IOCTL_SET_MEDIA_STATUS,
+               &code[0],
+               uint32(4),
+               &rdbbuf[0],
+               uint32(len(rdbbuf)),
+               &bytesReturned,
+               nil,
+       )
+}
+
+/* When operating in TUN mode we must assign an ip address & subnet to the device.
+ *
+ */
+func setTUN(fd windows.Handle, network string) error {
+       var bytesReturned uint32
+       rdbbuf := make([]byte, windows.MAXIMUM_REPARSE_DATA_BUFFER_SIZE)
+       localIP, remoteNet, err := net.ParseCIDR(network)
+
+       if err != nil {
+               return fmt.Errorf("Failed to parse network CIDR in config, %v", err)
+       }
+
+       if localIP.To4() == nil {
+               return fmt.Errorf("Provided network(%s) is not a valid IPv4 address", network)
+       }
+
+       var param [12]byte
+
+       copy(param[0:4], localIP.To4())
+       copy(param[4:8], remoteNet.IP.To4())
+       copy(param[8:12], remoteNet.Mask)
+
+       return windows.DeviceIoControl(
+               fd,
+               TAP_IOCTL_CONFIG_TUN,
+               &param[0],
+               uint32(12),
+               &rdbbuf[0],
+               uint32(len(rdbbuf)),
+               &bytesReturned,
+               nil,
+       )
+}
+
+func (tun *NativeTUN) MTU() (int, error) {
+       var mtu [4]byte
+       var bytesReturned uint32
+       err := windows.DeviceIoControl(
+               tun.fd,
+               TAP_IOCTL_GET_MTU,
+               &mtu[0],
+               uint32(len(mtu)),
+               &mtu[0],
+               uint32(len(mtu)),
+               &bytesReturned,
+               nil,
+       )
+       val := binary.LittleEndian.Uint32(mtu[:])
+       return int(val), err
+}
+
+func (tun *NativeTUN) Name() string {
+       return tun.name
+}
+
+func CreateTUN(name string) (TUNDevice, error) {
+
+       // find the device in registry.
+
+       deviceid, err := getdeviceid(ComponentID, name)
+       if err != nil {
+               return nil, err
+       }
+       path := "\\\\.\\Global\\" + deviceid + ".tap"
+       pathp, err := windows.UTF16PtrFromString(path)
+       if err != nil {
+               return nil, err
+       }
+
+       // create TUN device
+
+       handle, err := windows.CreateFile(
+               pathp,
+               windows.GENERIC_READ|windows.GENERIC_WRITE,
+               0,
+               nil,
+               windows.OPEN_EXISTING,
+               windows.FILE_ATTRIBUTE_SYSTEM|windows.FILE_FLAG_OVERLAPPED,
+               0,
+       )
+
+       if err != nil {
+               return nil, err
+       }
+
+       ro, err := newOverlapped()
+       if err != nil {
+               windows.Close(handle)
+               return nil, err
+       }
+
+       wo, err := newOverlapped()
+       if err != nil {
+               windows.Close(handle)
+               return nil, err
+       }
+
+       tun := &NativeTUN{
+               fd:     handle,
+               name:   name,
+               ro:     ro,
+               wo:     wo,
+               events: make(chan TUNEvent, 5),
+       }
+
+       // find addresses of interface
+       // TODO: fix this hack, the question is how
+
+       inter, err := net.InterfaceByName(name)
+       if err != nil {
+               windows.Close(handle)
+               return nil, err
+       }
+
+       addrs, err := inter.Addrs()
+       if err != nil {
+               windows.Close(handle)
+               return nil, err
+       }
+
+       var ip net.IP
+       for _, addr := range addrs {
+               ip = func() net.IP {
+                       switch v := addr.(type) {
+                       case *net.IPNet:
+                               return v.IP.To4()
+                       case *net.IPAddr:
+                               return v.IP.To4()
+                       }
+                       return nil
+               }()
+               if ip != nil {
+                       break
+               }
+       }
+
+       if ip == nil {
+               windows.Close(handle)
+               return nil, errors.New("No IPv4 address found for interface")
+       }
+
+       // bring up device.
+
+       if err := setStatus(handle, true); err != nil {
+               windows.Close(handle)
+               return nil, err
+       }
+
+       // set tun mode
+
+       mask := ip.String() + "/0"
+       if err := setTUN(handle, mask); err != nil {
+               windows.Close(handle)
+               return nil, err
+       }
+
+       // start listener
+
+       go func(native *NativeTUN, ifname string) {
+               // TODO: Fix this very niave implementation
+               var (
+                       statusUp  bool
+                       statusMTU int
+               )
+
+               for ; ; time.Sleep(time.Second) {
+                       intr, err := net.InterfaceByName(name)
+                       if err != nil {
+                               // TODO: handle
+                               return
+                       }
+
+                       // Up / Down event
+                       up := (intr.Flags & net.FlagUp) != 0
+                       if up != statusUp && up {
+                               native.events <- TUNEventUp
+                       }
+                       if up != statusUp && !up {
+                               native.events <- TUNEventDown
+                       }
+                       statusUp = up
+
+                       // MTU changes
+                       if intr.MTU != statusMTU {
+                               native.events <- TUNEventMTUUpdate
+                       }
+                       statusMTU = intr.MTU
+               }
+       }(tun, name)
+
+       return tun, nil
+}
diff --git a/src/uapi_windows.go b/src/uapi_windows.go
new file mode 100644 (file)
index 0000000..d56e965
--- /dev/null
@@ -0,0 +1,44 @@
+package main
+
+/* UAPI on windows uses a bidirectional named pipe
+ */
+
+import (
+       "fmt"
+       "github.com/Microsoft/go-winio"
+       "golang.org/x/sys/windows"
+       "net"
+)
+
+const (
+       ipcErrorIO         = -int64(windows.ERROR_BROKEN_PIPE)
+       ipcErrorNotDefined = -int64(windows.ERROR_SERVICE_SPECIFIC_ERROR)
+       ipcErrorProtocol   = -int64(windows.ERROR_SERVICE_SPECIFIC_ERROR)
+       ipcErrorInvalid    = -int64(windows.ERROR_SERVICE_SPECIFIC_ERROR)
+)
+
+const PipeNameFmt = "\\\\.\\pipe\\wireguard-ipc-%s"
+
+type UAPIListener struct {
+       listener net.Listener
+}
+
+func (uapi *UAPIListener) Accept() (net.Conn, error) {
+       return nil, nil
+}
+
+func (uapi *UAPIListener) Close() error {
+       return uapi.listener.Close()
+}
+
+func (uapi *UAPIListener) Addr() net.Addr {
+       return nil
+}
+
+func NewUAPIListener(name string) (net.Listener, error) {
+       path := fmt.Sprintf(PipeNameFmt, name)
+       return winio.ListenPipe(path, &winio.PipeConfig{
+               InputBufferSize:  2048,
+               OutputBufferSize: 2048,
+       })
+}