-package main
-
-import (
- "bytes"
- "encoding/binary"
- "golang.org/x/crypto/blake2s"
- "math/rand"
- "sync/atomic"
- "time"
-)
-
-/* Called when a new authenticated message has been send
- *
- */
-func (peer *Peer) KeepKeyFreshSending() {
- kp := peer.keyPairs.Current()
- if kp == nil {
- return
- }
- nonce := atomic.LoadUint64(&kp.sendNonce)
- if nonce > RekeyAfterMessages {
- signalSend(peer.signal.handshakeBegin)
- }
- if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime {
- signalSend(peer.signal.handshakeBegin)
- }
-}
-
-/* Called when a new authenticated message has been recevied
- *
- */
-func (peer *Peer) KeepKeyFreshReceiving() {
- // TODO: Add a guard, clear on handshake complete (clear in TimerHandshakeComplete)
- kp := peer.keyPairs.Current()
- if kp == nil {
- return
- }
- if !kp.isInitiator {
- return
- }
- nonce := atomic.LoadUint64(&kp.sendNonce)
- send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
- if send {
- signalSend(peer.signal.handshakeBegin)
- }
-}
-
-/* Queues a keep-alive if no packets are queued for peer
- */
-func (peer *Peer) SendKeepAlive() bool {
- elem := peer.device.NewOutboundElement()
- elem.packet = nil
- if len(peer.queue.nonce) == 0 {
- select {
- case peer.queue.nonce <- elem:
- return true
- default:
- return false
- }
- }
- return true
-}
-
-/* Event:
- * Sent non-empty (authenticated) transport message
- */
-func (peer *Peer) TimerDataSent() {
- timerStop(peer.timer.keepalivePassive)
- if !peer.timer.pendingNewHandshake {
- peer.timer.pendingNewHandshake = true
- peer.timer.newHandshake.Reset(NewHandshakeTime)
- }
-}
-
-/* Event:
- * Received non-empty (authenticated) transport message
- */
-func (peer *Peer) TimerDataReceived() {
- if peer.timer.pendingKeepalivePassive {
- peer.timer.needAnotherKeepalive = true
- return
- }
- peer.timer.pendingKeepalivePassive = false
- peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
-}
-
-/* Event:
- * Any (authenticated) packet received
- */
-func (peer *Peer) TimerAnyAuthenticatedPacketReceived() {
- timerStop(peer.timer.newHandshake)
-}
-
-/* Event:
- * Any authenticated packet send / received.
- */
-func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {
- interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
- if interval > 0 {
- duration := time.Duration(interval) * time.Second
- peer.timer.keepalivePersistent.Reset(duration)
- }
-}
-
-/* Called after succesfully completing a handshake.
- * i.e. after:
- *
- * - Valid handshake response
- * - First transport message under the "next" key
- */
-func (peer *Peer) TimerHandshakeComplete() {
- atomic.StoreInt64(
- &peer.stats.lastHandshakeNano,
- time.Now().UnixNano(),
- )
- signalSend(peer.signal.handshakeCompleted)
- peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
-}
-
-/* Event:
- * An ephemeral key is generated
- *
- * i.e after:
- *
- * CreateMessageInitiation
- * CreateMessageResponse
- *
- * Schedules the deletion of all key material
- * upon failure to complete a handshake
- */
-func (peer *Peer) TimerEphemeralKeyCreated() {
- peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
-}
-
-func (peer *Peer) RoutineTimerHandler() {
- device := peer.device
- indices := &device.indices
-
- logDebug := device.log.Debug
- logDebug.Println("Routine, timer handler, started for peer", peer.String())
-
- for {
- select {
-
- case <-peer.signal.stop:
- return
-
- // keep-alives
-
- case <-peer.timer.keepalivePersistent.C:
-
- interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
- if interval > 0 {
- logDebug.Println("Sending keep-alive to", peer.String())
- peer.SendKeepAlive()
- }
-
- case <-peer.timer.keepalivePassive.C:
-
- logDebug.Println("Sending keep-alive to", peer.String())
-
- peer.SendKeepAlive()
-
- if peer.timer.needAnotherKeepalive {
- peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
- peer.timer.needAnotherKeepalive = false
- }
-
- // unresponsive session
-
- case <-peer.timer.newHandshake.C:
-
- logDebug.Println("Retrying handshake with", peer.String(), "due to lack of reply")
-
- signalSend(peer.signal.handshakeBegin)
-
- // clear key material
-
- case <-peer.timer.zeroAllKeys.C:
-
- logDebug.Println("Clearing all key material for", peer.String())
-
- hs := &peer.handshake
- hs.mutex.Lock()
-
- kp := &peer.keyPairs
- kp.mutex.Lock()
-
- // unmap indecies
-
- indices.mutex.Lock()
- if kp.previous != nil {
- delete(indices.table, kp.previous.localIndex)
- }
- if kp.current != nil {
- delete(indices.table, kp.current.localIndex)
- }
- if kp.next != nil {
- delete(indices.table, kp.next.localIndex)
- }
- delete(indices.table, hs.localIndex)
- indices.mutex.Unlock()
-
- // zero out key pairs (TODO: better than wait for GC)
-
- kp.current = nil
- kp.previous = nil
- kp.next = nil
- kp.mutex.Unlock()
-
- // zero out handshake
-
- hs.localIndex = 0
- hs.localEphemeral = NoisePrivateKey{}
- hs.remoteEphemeral = NoisePublicKey{}
- hs.chainKey = [blake2s.Size]byte{}
- hs.hash = [blake2s.Size]byte{}
- hs.mutex.Unlock()
- }
- }
-}
-
-/* This is the state machine for handshake initiation
- *
- * Associated with this routine is the signal "handshakeBegin"
- * The routine will read from the "handshakeBegin" channel
- * at most every RekeyTimeout seconds
- */
-func (peer *Peer) RoutineHandshakeInitiator() {
- device := peer.device
-
- logInfo := device.log.Info
- logError := device.log.Error
- logDebug := device.log.Debug
- logDebug.Println("Routine, handshake initator, started for", peer.String())
-
- var temp [256]byte
-
- for {
-
- // wait for signal
-
- select {
- case <-peer.signal.handshakeBegin:
- case <-peer.signal.stop:
- return
- }
-
- // set deadline
-
- BeginHandshakes:
-
- signalClear(peer.signal.handshakeReset)
- deadline := time.NewTimer(RekeyAttemptTime)
-
- AttemptHandshakes:
-
- for attempts := uint(1); ; attempts++ {
-
- // check if deadline reached
-
- select {
- case <-deadline.C:
- logInfo.Println("Handshake negotiation timed out for:", peer.String())
- signalSend(peer.signal.flushNonceQueue)
- timerStop(peer.timer.keepalivePersistent)
- break
- case <-peer.signal.stop:
- return
- default:
- }
-
- signalClear(peer.signal.handshakeCompleted)
-
- // create initiation message
-
- msg, err := peer.device.CreateMessageInitiation(peer)
- if err != nil {
- logError.Println("Failed to create handshake initiation message:", err)
- break AttemptHandshakes
- }
-
- jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
-
- // marshal and send
-
- writer := bytes.NewBuffer(temp[:0])
- binary.Write(writer, binary.LittleEndian, msg)
- packet := writer.Bytes()
- peer.mac.AddMacs(packet)
-
- _, err = peer.SendBuffer(packet)
- if err != nil {
- logError.Println(
- "Failed to send handshake initiation message to",
- peer.String(), ":", err,
- )
- break
- }
-
- peer.TimerAnyAuthenticatedPacketTraversal()
-
- // set handshake timeout
-
- timeout := time.NewTimer(RekeyTimeout + jitter)
- logDebug.Println(
- "Handshake initiation attempt",
- attempts, "sent to", peer.String(),
- )
-
- // wait for handshake or timeout
-
- select {
-
- case <-peer.signal.stop:
- return
-
- case <-peer.signal.handshakeCompleted:
- <-timeout.C
- break AttemptHandshakes
-
- case <-peer.signal.handshakeReset:
- <-timeout.C
- goto BeginHandshakes
-
- case <-timeout.C:
- // TODO: Clear source address for peer
- continue
- }
- }
-
- // clear signal set in the meantime
-
- signalClear(peer.signal.handshakeBegin)
- }
-}
+package main\r
+\r
+import (\r
+ "bytes"\r
+ "encoding/binary"\r
+ "golang.org/x/crypto/blake2s"\r
+ "math/rand"\r
+ "sync/atomic"\r
+ "time"\r
+)\r
+\r
+/* Called when a new authenticated message has been send\r
+ *\r
+ */\r
+func (peer *Peer) KeepKeyFreshSending() {\r
+ kp := peer.keyPairs.Current()\r
+ if kp == nil {\r
+ return\r
+ }\r
+ nonce := atomic.LoadUint64(&kp.sendNonce)\r
+ if nonce > RekeyAfterMessages {\r
+ signalSend(peer.signal.handshakeBegin)\r
+ }\r
+ if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime {\r
+ signalSend(peer.signal.handshakeBegin)\r
+ }\r
+}\r
+\r
+/* Called when a new authenticated message has been recevied\r
+ *\r
+ */\r
+func (peer *Peer) KeepKeyFreshReceiving() {\r
+ // TODO: Add a guard, clear on handshake complete (clear in TimerHandshakeComplete)\r
+ kp := peer.keyPairs.Current()\r
+ if kp == nil {\r
+ return\r
+ }\r
+ if !kp.isInitiator {\r
+ return\r
+ }\r
+ nonce := atomic.LoadUint64(&kp.sendNonce)\r
+ send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving\r
+ if send {\r
+ signalSend(peer.signal.handshakeBegin)\r
+ }\r
+}\r
+\r
+/* Queues a keep-alive if no packets are queued for peer\r
+ */\r
+func (peer *Peer) SendKeepAlive() bool {\r
+ elem := peer.device.NewOutboundElement()\r
+ elem.packet = nil\r
+ if len(peer.queue.nonce) == 0 {\r
+ select {\r
+ case peer.queue.nonce <- elem:\r
+ return true\r
+ default:\r
+ return false\r
+ }\r
+ }\r
+ return true\r
+}\r
+\r
+/* Event:\r
+ * Sent non-empty (authenticated) transport message\r
+ */\r
+func (peer *Peer) TimerDataSent() {\r
+ timerStop(peer.timer.keepalivePassive)\r
+ if !peer.timer.pendingNewHandshake {\r
+ peer.timer.pendingNewHandshake = true\r
+ peer.timer.newHandshake.Reset(NewHandshakeTime)\r
+ }\r
+}\r
+\r
+/* Event:\r
+ * Received non-empty (authenticated) transport message\r
+ */\r
+func (peer *Peer) TimerDataReceived() {\r
+ if peer.timer.pendingKeepalivePassive {\r
+ peer.timer.needAnotherKeepalive = true\r
+ return\r
+ }\r
+ peer.timer.pendingKeepalivePassive = false\r
+ peer.timer.keepalivePassive.Reset(KeepaliveTimeout)\r
+}\r
+\r
+/* Event:\r
+ * Any (authenticated) packet received\r
+ */\r
+func (peer *Peer) TimerAnyAuthenticatedPacketReceived() {\r
+ timerStop(peer.timer.newHandshake)\r
+}\r
+\r
+/* Event:\r
+ * Any authenticated packet send / received.\r
+ */\r
+func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {\r
+ interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)\r
+ if interval > 0 {\r
+ duration := time.Duration(interval) * time.Second\r
+ peer.timer.keepalivePersistent.Reset(duration)\r
+ }\r
+}\r
+\r
+/* Called after succesfully completing a handshake.\r
+ * i.e. after:\r
+ *\r
+ * - Valid handshake response\r
+ * - First transport message under the "next" key\r
+ */\r
+func (peer *Peer) TimerHandshakeComplete() {\r
+ atomic.StoreInt64(\r
+ &peer.stats.lastHandshakeNano,\r
+ time.Now().UnixNano(),\r
+ )\r
+ signalSend(peer.signal.handshakeCompleted)\r
+ peer.device.log.Info.Println("Negotiated new handshake for", peer.String())\r
+}\r
+\r
+/* Event:\r
+ * An ephemeral key is generated\r
+ *\r
+ * i.e after:\r
+ *\r
+ * CreateMessageInitiation\r
+ * CreateMessageResponse\r
+ *\r
+ * Schedules the deletion of all key material\r
+ * upon failure to complete a handshake\r
+ */\r
+func (peer *Peer) TimerEphemeralKeyCreated() {\r
+ peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)\r
+}\r
+\r
+func (peer *Peer) RoutineTimerHandler() {\r
+ device := peer.device\r
+ indices := &device.indices\r
+\r
+ logDebug := device.log.Debug\r
+ logDebug.Println("Routine, timer handler, started for peer", peer.String())\r
+\r
+ for {\r
+ select {\r
+\r
+ case <-peer.signal.stop:\r
+ return\r
+\r
+ // keep-alives\r
+\r
+ case <-peer.timer.keepalivePersistent.C:\r
+\r
+ interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)\r
+ if interval > 0 {\r
+ logDebug.Println("Sending keep-alive to", peer.String())\r
+ peer.SendKeepAlive()\r
+ }\r
+\r
+ case <-peer.timer.keepalivePassive.C:\r
+\r
+ logDebug.Println("Sending keep-alive to", peer.String())\r
+\r
+ peer.SendKeepAlive()\r
+\r
+ if peer.timer.needAnotherKeepalive {\r
+ peer.timer.keepalivePassive.Reset(KeepaliveTimeout)\r
+ peer.timer.needAnotherKeepalive = false\r
+ }\r
+\r
+ // unresponsive session\r
+\r
+ case <-peer.timer.newHandshake.C:\r
+\r
+ logDebug.Println("Retrying handshake with", peer.String(), "due to lack of reply")\r
+\r
+ signalSend(peer.signal.handshakeBegin)\r
+\r
+ // clear key material\r
+\r
+ case <-peer.timer.zeroAllKeys.C:\r
+\r
+ logDebug.Println("Clearing all key material for", peer.String())\r
+\r
+ hs := &peer.handshake\r
+ hs.mutex.Lock()\r
+\r
+ kp := &peer.keyPairs\r
+ kp.mutex.Lock()\r
+\r
+ // unmap indecies\r
+\r
+ indices.mutex.Lock()\r
+ if kp.previous != nil {\r
+ delete(indices.table, kp.previous.localIndex)\r
+ }\r
+ if kp.current != nil {\r
+ delete(indices.table, kp.current.localIndex)\r
+ }\r
+ if kp.next != nil {\r
+ delete(indices.table, kp.next.localIndex)\r
+ }\r
+ delete(indices.table, hs.localIndex)\r
+ indices.mutex.Unlock()\r
+\r
+ // zero out key pairs (TODO: better than wait for GC)\r
+\r
+ kp.current = nil\r
+ kp.previous = nil\r
+ kp.next = nil\r
+ kp.mutex.Unlock()\r
+\r
+ // zero out handshake\r
+\r
+ hs.localIndex = 0\r
+ hs.localEphemeral = NoisePrivateKey{}\r
+ hs.remoteEphemeral = NoisePublicKey{}\r
+ hs.chainKey = [blake2s.Size]byte{}\r
+ hs.hash = [blake2s.Size]byte{}\r
+ hs.mutex.Unlock()\r
+ }\r
+ }\r
+}\r
+\r
+/* This is the state machine for handshake initiation\r
+ *\r
+ * Associated with this routine is the signal "handshakeBegin"\r
+ * The routine will read from the "handshakeBegin" channel\r
+ * at most every RekeyTimeout seconds\r
+ */\r
+func (peer *Peer) RoutineHandshakeInitiator() {\r
+ device := peer.device\r
+\r
+ logInfo := device.log.Info\r
+ logError := device.log.Error\r
+ logDebug := device.log.Debug\r
+ logDebug.Println("Routine, handshake initator, started for", peer.String())\r
+\r
+ var temp [256]byte\r
+\r
+ for {\r
+\r
+ // wait for signal\r
+\r
+ select {\r
+ case <-peer.signal.handshakeBegin:\r
+ case <-peer.signal.stop:\r
+ return\r
+ }\r
+\r
+ // set deadline\r
+\r
+ BeginHandshakes:\r
+\r
+ signalClear(peer.signal.handshakeReset)\r
+ deadline := time.NewTimer(RekeyAttemptTime)\r
+\r
+ AttemptHandshakes:\r
+\r
+ for attempts := uint(1); ; attempts++ {\r
+\r
+ // check if deadline reached\r
+\r
+ select {\r
+ case <-deadline.C:\r
+ logInfo.Println("Handshake negotiation timed out for:", peer.String())\r
+ signalSend(peer.signal.flushNonceQueue)\r
+ timerStop(peer.timer.keepalivePersistent)\r
+ break\r
+ case <-peer.signal.stop:\r
+ return\r
+ default:\r
+ }\r
+\r
+ signalClear(peer.signal.handshakeCompleted)\r
+\r
+ // create initiation message\r
+\r
+ msg, err := peer.device.CreateMessageInitiation(peer)\r
+ if err != nil {\r
+ logError.Println("Failed to create handshake initiation message:", err)\r
+ break AttemptHandshakes\r
+ }\r
+\r
+ jitter := time.Millisecond * time.Duration(rand.Uint32()%334)\r
+\r
+ // marshal and send\r
+\r
+ writer := bytes.NewBuffer(temp[:0])\r
+ binary.Write(writer, binary.LittleEndian, msg)\r
+ packet := writer.Bytes()\r
+ peer.mac.AddMacs(packet)\r
+\r
+ _, err = peer.SendBuffer(packet)\r
+ if err != nil {\r
+ logError.Println(\r
+ "Failed to send handshake initiation message to",\r
+ peer.String(), ":", err,\r
+ )\r
+ continue\r
+ }\r
+\r
+ peer.TimerAnyAuthenticatedPacketTraversal()\r
+\r
+ // set handshake timeout\r
+\r
+ timeout := time.NewTimer(RekeyTimeout + jitter)\r
+ logDebug.Println(\r
+ "Handshake initiation attempt",\r
+ attempts, "sent to", peer.String(),\r
+ )\r
+\r
+ // wait for handshake or timeout\r
+\r
+ select {\r
+\r
+ case <-peer.signal.stop:\r
+ return\r
+\r
+ case <-peer.signal.handshakeCompleted:\r
+ <-timeout.C\r
+ break AttemptHandshakes\r
+\r
+ case <-peer.signal.handshakeReset:\r
+ <-timeout.C\r
+ goto BeginHandshakes\r
+\r
+ case <-timeout.C:\r
+ // TODO: Clear source address for peer\r
+ continue\r
+ }\r
+ }\r
+\r
+ // clear signal set in the meantime\r
+\r
+ signalClear(peer.signal.handshakeBegin)\r
+ }\r
+}\r
--- /dev/null
+package main
+
+import (
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "golang.org/x/sys/windows"
+ "golang.org/x/sys/windows/registry"
+ "net"
+ "sync"
+ "syscall"
+ "time"
+ "unsafe"
+)
+
+/* Relies on the OpenVPN TAP-Windows driver (NDIS 6 version)
+ *
+ * https://github.com/OpenVPN/tap-windows
+ */
+
+type NativeTUN struct {
+ fd windows.Handle
+ rl sync.Mutex
+ wl sync.Mutex
+ ro *windows.Overlapped
+ wo *windows.Overlapped
+ events chan TUNEvent
+ name string
+}
+
+const (
+ METHOD_BUFFERED = 0
+ ComponentID = "tap0901" // tap0801
+)
+
+func ctl_code(device_type, function, method, access uint32) uint32 {
+ return (device_type << 16) | (access << 14) | (function << 2) | method
+}
+
+func TAP_CONTROL_CODE(request, method uint32) uint32 {
+ return ctl_code(file_device_unknown, request, method, 0)
+}
+
+var (
+ errIfceNameNotFound = errors.New("Failed to find the name of interface")
+
+ TAP_IOCTL_GET_MAC = TAP_CONTROL_CODE(1, METHOD_BUFFERED)
+ TAP_IOCTL_GET_VERSION = TAP_CONTROL_CODE(2, METHOD_BUFFERED)
+ TAP_IOCTL_GET_MTU = TAP_CONTROL_CODE(3, METHOD_BUFFERED)
+ TAP_IOCTL_GET_INFO = TAP_CONTROL_CODE(4, METHOD_BUFFERED)
+ TAP_IOCTL_CONFIG_POINT_TO_POINT = TAP_CONTROL_CODE(5, METHOD_BUFFERED)
+ TAP_IOCTL_SET_MEDIA_STATUS = TAP_CONTROL_CODE(6, METHOD_BUFFERED)
+ TAP_IOCTL_CONFIG_DHCP_MASQ = TAP_CONTROL_CODE(7, METHOD_BUFFERED)
+ TAP_IOCTL_GET_LOG_LINE = TAP_CONTROL_CODE(8, METHOD_BUFFERED)
+ TAP_IOCTL_CONFIG_DHCP_SET_OPT = TAP_CONTROL_CODE(9, METHOD_BUFFERED)
+ TAP_IOCTL_CONFIG_TUN = TAP_CONTROL_CODE(10, METHOD_BUFFERED)
+
+ file_device_unknown = uint32(0x00000022)
+ nCreateEvent,
+ nResetEvent,
+ nGetOverlappedResult uintptr
+)
+
+func init() {
+ k32, err := windows.LoadLibrary("kernel32.dll")
+ if err != nil {
+ panic("LoadLibrary " + err.Error())
+ }
+ defer windows.FreeLibrary(k32)
+ nCreateEvent = getProcAddr(k32, "CreateEventW")
+ nResetEvent = getProcAddr(k32, "ResetEvent")
+ nGetOverlappedResult = getProcAddr(k32, "GetOverlappedResult")
+}
+
+/* implementation of the read/write/closer interface */
+
+func getProcAddr(lib windows.Handle, name string) uintptr {
+ addr, err := windows.GetProcAddress(lib, name)
+ if err != nil {
+ panic(name + " " + err.Error())
+ }
+ return addr
+}
+
+func resetEvent(h windows.Handle) error {
+ r, _, err := syscall.Syscall(nResetEvent, 1, uintptr(h), 0, 0)
+ if r == 0 {
+ return err
+ }
+ return nil
+}
+
+func getOverlappedResult(h windows.Handle, overlapped *windows.Overlapped) (int, error) {
+ var n int
+ r, _, err := syscall.Syscall6(
+ nGetOverlappedResult,
+ 4,
+ uintptr(h),
+ uintptr(unsafe.Pointer(overlapped)),
+ uintptr(unsafe.Pointer(&n)), 1, 0, 0)
+
+ if r == 0 {
+ return n, err
+ }
+ return n, nil
+}
+
+func newOverlapped() (*windows.Overlapped, error) {
+ var overlapped windows.Overlapped
+ r, _, err := syscall.Syscall6(nCreateEvent, 4, 0, 1, 0, 0, 0, 0)
+ if r == 0 {
+ return nil, err
+ }
+ overlapped.HEvent = windows.Handle(r)
+ return &overlapped, nil
+}
+
+func (f *NativeTUN) Events() chan TUNEvent {
+ return f.events
+}
+
+func (f *NativeTUN) Close() error {
+ return windows.Close(f.fd)
+}
+
+func (f *NativeTUN) Write(b []byte) (int, error) {
+ f.wl.Lock()
+ defer f.wl.Unlock()
+
+ if err := resetEvent(f.wo.HEvent); err != nil {
+ return 0, err
+ }
+ var n uint32
+ err := windows.WriteFile(f.fd, b, &n, f.wo)
+ if err != nil && err != windows.ERROR_IO_PENDING {
+ return int(n), err
+ }
+ return getOverlappedResult(f.fd, f.wo)
+}
+
+func (f *NativeTUN) Read(b []byte) (int, error) {
+ f.rl.Lock()
+ defer f.rl.Unlock()
+
+ if err := resetEvent(f.ro.HEvent); err != nil {
+ return 0, err
+ }
+ var done uint32
+ err := windows.ReadFile(f.fd, b, &done, f.ro)
+ if err != nil && err != windows.ERROR_IO_PENDING {
+ return int(done), err
+ }
+ return getOverlappedResult(f.fd, f.ro)
+}
+
+func getdeviceid(
+ targetComponentId string,
+ targetDeviceName string,
+) (deviceid string, err error) {
+
+ getName := func(instanceId string) (string, error) {
+ path := fmt.Sprintf(
+ `SYSTEM\CurrentControlSet\Control\Network\{4D36E972-E325-11CE-BFC1-08002BE10318}\%s\Connection`,
+ instanceId,
+ )
+
+ key, err := registry.OpenKey(
+ registry.LOCAL_MACHINE,
+ path,
+ registry.READ,
+ )
+
+ if err != nil {
+ return "", err
+ }
+ defer key.Close()
+
+ val, _, err := key.GetStringValue("Name")
+ key.Close()
+ return val, err
+ }
+
+ getInstanceId := func(keyName string) (string, string, error) {
+ path := fmt.Sprintf(
+ `SYSTEM\CurrentControlSet\Control\Class\{4D36E972-E325-11CE-BFC1-08002BE10318}\%s`,
+ keyName,
+ )
+
+ key, err := registry.OpenKey(
+ registry.LOCAL_MACHINE,
+ path,
+ registry.READ,
+ )
+
+ if err != nil {
+ return "", "", err
+ }
+ defer key.Close()
+
+ componentId, _, err := key.GetStringValue("ComponentId")
+ if err != nil {
+ return "", "", err
+ }
+
+ instanceId, _, err := key.GetStringValue("NetCfgInstanceId")
+
+ return componentId, instanceId, err
+ }
+
+ // find list of all network devices
+
+ k, err := registry.OpenKey(
+ registry.LOCAL_MACHINE,
+ `SYSTEM\CurrentControlSet\Control\Class\{4D36E972-E325-11CE-BFC1-08002BE10318}`,
+ registry.READ,
+ )
+
+ if err != nil {
+ return "", fmt.Errorf("Failed to open the adapter registry, TAP driver may be not installed, %v", err)
+ }
+
+ defer k.Close()
+
+ keys, err := k.ReadSubKeyNames(-1)
+
+ if err != nil {
+ return "", err
+ }
+
+ // look for matching component id and name
+
+ var componentFound bool
+
+ for _, v := range keys {
+
+ componentId, instanceId, err := getInstanceId(v)
+ if err != nil || componentId != targetComponentId {
+ continue
+ }
+
+ componentFound = true
+
+ deviceName, err := getName(instanceId)
+ if err != nil || deviceName != targetDeviceName {
+ continue
+ }
+
+ return instanceId, nil
+ }
+
+ // provide a descriptive error message
+
+ if componentFound {
+ return "", fmt.Errorf("Unable to find tun/tap device with name = %s", targetDeviceName)
+ }
+
+ return "", fmt.Errorf(
+ "Unable to find device in registry with ComponentId = %s, is tap-windows installed?",
+ targetComponentId,
+ )
+}
+
+// setStatus is used to bring up or bring down the interface
+func setStatus(fd windows.Handle, status bool) error {
+ var code [4]byte
+ if status {
+ binary.LittleEndian.PutUint32(code[:], 1)
+ }
+
+ var bytesReturned uint32
+ rdbbuf := make([]byte, windows.MAXIMUM_REPARSE_DATA_BUFFER_SIZE)
+ return windows.DeviceIoControl(
+ fd,
+ TAP_IOCTL_SET_MEDIA_STATUS,
+ &code[0],
+ uint32(4),
+ &rdbbuf[0],
+ uint32(len(rdbbuf)),
+ &bytesReturned,
+ nil,
+ )
+}
+
+/* When operating in TUN mode we must assign an ip address & subnet to the device.
+ *
+ */
+func setTUN(fd windows.Handle, network string) error {
+ var bytesReturned uint32
+ rdbbuf := make([]byte, windows.MAXIMUM_REPARSE_DATA_BUFFER_SIZE)
+ localIP, remoteNet, err := net.ParseCIDR(network)
+
+ if err != nil {
+ return fmt.Errorf("Failed to parse network CIDR in config, %v", err)
+ }
+
+ if localIP.To4() == nil {
+ return fmt.Errorf("Provided network(%s) is not a valid IPv4 address", network)
+ }
+
+ var param [12]byte
+
+ copy(param[0:4], localIP.To4())
+ copy(param[4:8], remoteNet.IP.To4())
+ copy(param[8:12], remoteNet.Mask)
+
+ return windows.DeviceIoControl(
+ fd,
+ TAP_IOCTL_CONFIG_TUN,
+ ¶m[0],
+ uint32(12),
+ &rdbbuf[0],
+ uint32(len(rdbbuf)),
+ &bytesReturned,
+ nil,
+ )
+}
+
+func (tun *NativeTUN) MTU() (int, error) {
+ var mtu [4]byte
+ var bytesReturned uint32
+ err := windows.DeviceIoControl(
+ tun.fd,
+ TAP_IOCTL_GET_MTU,
+ &mtu[0],
+ uint32(len(mtu)),
+ &mtu[0],
+ uint32(len(mtu)),
+ &bytesReturned,
+ nil,
+ )
+ val := binary.LittleEndian.Uint32(mtu[:])
+ return int(val), err
+}
+
+func (tun *NativeTUN) Name() string {
+ return tun.name
+}
+
+func CreateTUN(name string) (TUNDevice, error) {
+
+ // find the device in registry.
+
+ deviceid, err := getdeviceid(ComponentID, name)
+ if err != nil {
+ return nil, err
+ }
+ path := "\\\\.\\Global\\" + deviceid + ".tap"
+ pathp, err := windows.UTF16PtrFromString(path)
+ if err != nil {
+ return nil, err
+ }
+
+ // create TUN device
+
+ handle, err := windows.CreateFile(
+ pathp,
+ windows.GENERIC_READ|windows.GENERIC_WRITE,
+ 0,
+ nil,
+ windows.OPEN_EXISTING,
+ windows.FILE_ATTRIBUTE_SYSTEM|windows.FILE_FLAG_OVERLAPPED,
+ 0,
+ )
+
+ if err != nil {
+ return nil, err
+ }
+
+ ro, err := newOverlapped()
+ if err != nil {
+ windows.Close(handle)
+ return nil, err
+ }
+
+ wo, err := newOverlapped()
+ if err != nil {
+ windows.Close(handle)
+ return nil, err
+ }
+
+ tun := &NativeTUN{
+ fd: handle,
+ name: name,
+ ro: ro,
+ wo: wo,
+ events: make(chan TUNEvent, 5),
+ }
+
+ // find addresses of interface
+ // TODO: fix this hack, the question is how
+
+ inter, err := net.InterfaceByName(name)
+ if err != nil {
+ windows.Close(handle)
+ return nil, err
+ }
+
+ addrs, err := inter.Addrs()
+ if err != nil {
+ windows.Close(handle)
+ return nil, err
+ }
+
+ var ip net.IP
+ for _, addr := range addrs {
+ ip = func() net.IP {
+ switch v := addr.(type) {
+ case *net.IPNet:
+ return v.IP.To4()
+ case *net.IPAddr:
+ return v.IP.To4()
+ }
+ return nil
+ }()
+ if ip != nil {
+ break
+ }
+ }
+
+ if ip == nil {
+ windows.Close(handle)
+ return nil, errors.New("No IPv4 address found for interface")
+ }
+
+ // bring up device.
+
+ if err := setStatus(handle, true); err != nil {
+ windows.Close(handle)
+ return nil, err
+ }
+
+ // set tun mode
+
+ mask := ip.String() + "/0"
+ if err := setTUN(handle, mask); err != nil {
+ windows.Close(handle)
+ return nil, err
+ }
+
+ // start listener
+
+ go func(native *NativeTUN, ifname string) {
+ // TODO: Fix this very niave implementation
+ var (
+ statusUp bool
+ statusMTU int
+ )
+
+ for ; ; time.Sleep(time.Second) {
+ intr, err := net.InterfaceByName(name)
+ if err != nil {
+ // TODO: handle
+ return
+ }
+
+ // Up / Down event
+ up := (intr.Flags & net.FlagUp) != 0
+ if up != statusUp && up {
+ native.events <- TUNEventUp
+ }
+ if up != statusUp && !up {
+ native.events <- TUNEventDown
+ }
+ statusUp = up
+
+ // MTU changes
+ if intr.MTU != statusMTU {
+ native.events <- TUNEventMTUUpdate
+ }
+ statusMTU = intr.MTU
+ }
+ }(tun, name)
+
+ return tun, nil
+}