]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
net: implement ECN handling, rfc6040 style fd/propagate-DSCP-bits
authorFlorent Daigniere <nextgens@freenetproject.org>
Sat, 23 Feb 2019 20:50:04 +0000 (21:50 +0100)
committerFlorent Daigniere <nextgens@freenetproject.org>
Mon, 25 Feb 2019 17:20:23 +0000 (18:20 +0100)
To decide whether we should use the compatibility mode or the normal
mode with a peer, we use the handshake messages as a signaling channel.

If we receive the expected ECN bits, it most likely means they're
running a compatible version.

Signed-off-by: Florent Daigniere <nextgens@freenetproject.org>
conn.go
conn_default.go
conn_linux.go
misc.go
peer.go
receive.go
send.go

diff --git a/conn.go b/conn.go
index b8970e70a993faa61da190e1e7c4bfd21e739db2..e38160a258f1699758cbcdfbe0e573025eb4dc20 100644 (file)
--- a/conn.go
+++ b/conn.go
@@ -20,8 +20,8 @@ const (
  */
 type Bind interface {
        SetMark(value uint32) error
-       ReceiveIPv6(buff []byte) (int, Endpoint, error)
-       ReceiveIPv4(buff []byte) (int, Endpoint, error)
+       ReceiveIPv6(buff []byte) (int, Endpoint, byte, error)
+       ReceiveIPv4(buff []byte) (int, Endpoint, byte, error)
        Send(buff []byte, end Endpoint, tos byte) error
        Close() error
 }
index 6f17de5b61f191f927ff4b4f637840b6ca53c0fe..1b2586383bc8fdda4dbe662b67e660b23e738538 100644 (file)
@@ -133,26 +133,29 @@ func (bind *NativeBind) Close() error {
        return err2
 }
 
-func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+// TODO: implement TOS
+func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, byte, error) {
        if bind.ipv4 == nil {
-               return 0, nil, syscall.EAFNOSUPPORT
+               return 0, nil, 0, syscall.EAFNOSUPPORT
        }
        n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
        if endpoint != nil {
                endpoint.IP = endpoint.IP.To4()
        }
-       return n, (*NativeEndpoint)(endpoint), err
+       return n, (*NativeEndpoint)(endpoint), 0, err
 }
 
-func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+// TODO: implement TOS
+func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, byte, error) {
        if bind.ipv6 == nil {
-               return 0, nil, syscall.EAFNOSUPPORT
+               return 0, nil, 0, syscall.EAFNOSUPPORT
        }
        n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
-       return n, (*NativeEndpoint)(endpoint), err
+       return n, (*NativeEndpoint)(endpoint), 0, err
 }
 
-func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error {
+// TODO: implement TOS
+func (bind *NativeBind) Send(buff []byte, endpoint Endpoint, tos byte) error {
        var err error
        nend := endpoint.(*NativeEndpoint)
        if nend.IP.To4() != nil {
index 83cf1a2683e82d645cf97d4035cec4b20ed4c3b3..cc1ce2e8a410a9e254b2d91d4ebade578784702d 100644 (file)
@@ -232,30 +232,32 @@ func (bind *NativeBind) Close() error {
        return err3
 }
 
-func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, byte, error) {
        var end NativeEndpoint
+       var tos byte
        if bind.sock6 == -1 {
-               return 0, nil, syscall.EAFNOSUPPORT
+               return 0, nil, tos, syscall.EAFNOSUPPORT
        }
-       n, err := receive6(
+       n, tos, err := receive6(
                bind.sock6,
                buff,
                &end,
        )
-       return n, &end, err
+       return n, &end, tos, err
 }
 
-func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, byte, error) {
        var end NativeEndpoint
+       var tos byte
        if bind.sock4 == -1 {
-               return 0, nil, syscall.EAFNOSUPPORT
+               return 0, nil, tos, syscall.EAFNOSUPPORT
        }
-       n, err := receive4(
+       n, tos, err := receive4(
                bind.sock4,
                buff,
                &end,
        )
-       return n, &end, err
+       return n, &end, tos, err
 }
 
 func (bind *NativeBind) Send(buff []byte, end Endpoint, tos byte) error {
@@ -384,6 +386,15 @@ func create4(port uint16) (int, uint16, error) {
                        return err
                }
 
+               if err := unix.SetsockoptInt(
+                       fd,
+                       unix.IPPROTO_IP,
+                       unix.IP_RECVTOS,
+                       1,
+               ); err != nil {
+                       return err
+               }
+
                return unix.Bind(fd, &addr)
        }(); err != nil {
                unix.Close(fd)
@@ -442,6 +453,15 @@ func create6(port uint16) (int, uint16, error) {
                        return err
                }
 
+               if err := unix.SetsockoptInt(
+                       fd,
+                       unix.IPPROTO_IPV6,
+                       unix.IPV6_RECVTCLASS,
+                       1,
+               ); err != nil {
+                       return err
+               }
+
                return unix.Bind(fd, &addr)
 
        }(); err != nil {
@@ -452,12 +472,13 @@ func create6(port uint16) (int, uint16, error) {
        return fd, uint16(addr.Port), err
 }
 
+type ipTos struct {
+       tos byte
+}
+
 func send4(sock int, end *NativeEndpoint, buff []byte, tos byte) error {
 
        // construct message header
-       type ipTos struct {
-               tos byte
-       }
 
        cmsg := struct {
                cmsghdr unix.Cmsghdr
@@ -505,9 +526,6 @@ func send4(sock int, end *NativeEndpoint, buff []byte, tos byte) error {
 func send6(sock int, end *NativeEndpoint, buff []byte, tos byte) error {
 
        // construct message header
-       type ipTos struct {
-               tos byte
-       }
 
        cmsg := struct {
                cmsghdr unix.Cmsghdr
@@ -555,19 +573,21 @@ func send6(sock int, end *NativeEndpoint, buff []byte, tos byte) error {
        return err
 }
 
-func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
+func receive4(sock int, buff []byte, end *NativeEndpoint) (int, byte, error) {
 
        // contruct message header
 
        var cmsg struct {
-               cmsghdr unix.Cmsghdr
-               pktinfo unix.Inet4Pktinfo
+               cmsghdr  unix.Cmsghdr
+               pktinfo  unix.Inet4Pktinfo
+               cmsghdr2 unix.Cmsghdr
+               iptos    ipTos
        }
 
        size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
 
        if err != nil {
-               return 0, err
+               return 0, 0, err
        }
        end.isV6 = false
 
@@ -576,7 +596,6 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
        }
 
        // update source cache
-
        if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
                cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
                cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
@@ -584,22 +603,31 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
                end.src4().ifindex = cmsg.pktinfo.Ifindex
        }
 
-       return size, nil
+       tos := byte(0)
+       if cmsg.cmsghdr2.Level == unix.IPPROTO_IP &&
+               cmsg.cmsghdr2.Type == unix.IP_TOS &&
+               cmsg.cmsghdr2.Len >= 1 {
+               tos = cmsg.iptos.tos
+       }
+
+       return size, tos, nil
 }
 
-func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
+func receive6(sock int, buff []byte, end *NativeEndpoint) (int, byte, error) {
 
        // contruct message header
 
        var cmsg struct {
-               cmsghdr unix.Cmsghdr
-               pktinfo unix.Inet6Pktinfo
+               cmsghdr  unix.Cmsghdr
+               pktinfo  unix.Inet6Pktinfo
+               cmsghdr2 unix.Cmsghdr
+               iptos    ipTos
        }
 
        size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
 
        if err != nil {
-               return 0, err
+               return 0, 0, err
        }
        end.isV6 = true
 
@@ -616,7 +644,14 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
                end.dst6().ZoneId = cmsg.pktinfo.Ifindex
        }
 
-       return size, nil
+       tos := byte(0)
+       if cmsg.cmsghdr2.Level == unix.IPPROTO_IPV6 &&
+               cmsg.cmsghdr2.Type == unix.IPV6_TCLASS &&
+               cmsg.cmsghdr2.Len >= 1 {
+               tos = cmsg.iptos.tos
+       }
+
+       return size, tos, nil
 }
 
 func (bind *NativeBind) routineRouteListener(device *Device) {
diff --git a/misc.go b/misc.go
index 6786cb5633cc95cd19644e58f59d9aa1494fb48b..e5688a59a8661a64b272c8de730c7c54781544bd 100644 (file)
--- a/misc.go
+++ b/misc.go
@@ -46,3 +46,62 @@ func min(a, b uint) uint {
        }
        return a
 }
+
+// called from receive
+func ecn_rfc6040_egress(inner byte, outer byte) (byte, bool) {
+       /*
+       +---------+------------------------------------------------+
+       |Arriving |            Arriving Outer Header               |
+       |   Inner +---------+------------+------------+------------+
+       |  Header | Not-ECT | ECT(0)     | ECT(1)     |     CE     |
+       +---------+---------+------------+------------+------------+
+       | Not-ECT | Not-ECT |Not-ECT(!!!)|Not-ECT(!!!)| <drop>(!!!)|
+       |  ECT(0) |  ECT(0) | ECT(0)     | ECT(1)     |     CE     |
+       |  ECT(1) |  ECT(1) | ECT(1) (!) | ECT(1)     |     CE     |
+       |    CE   |      CE |     CE     |     CE(!!!)|     CE     |
+       +---------+---------+------------+------------+------------+
+       */
+       innerECN := CongestionExperienced & inner
+       outerECN := CongestionExperienced & outer
+
+       switch outerECN {
+       case CongestionExperienced:
+               switch innerECN {
+               case NotECNTransport:
+                       return 0, true
+               }
+               return (inner  & (CongestionExperienced ^ 255)) | CongestionExperienced, false
+       case ECNTransport1:
+               switch innerECN {
+               case ECNTransport0:
+                       return (inner  & (CongestionExperienced ^ 255)) | ECNTransport1, false
+               }
+       }
+       return inner, false
+}
+
+// called from send
+func ecn_rfc6040_ingress(inner byte, useNormalMode bool) byte {
+       /*
+       +-----------------+-------------------------------+
+       | Incoming Header |    Departing Outer Header     |
+       | (also equal to  +---------------+---------------+
+       | departing Inner | Compatibility |    Normal     |
+       |     Header)     |     Mode      |     Mode      |
+       +-----------------+---------------+---------------+
+       |    Not-ECT      |   Not-ECT     |   Not-ECT     |
+       |     ECT(0)      |   Not-ECT     |    ECT(0)     |
+       |     ECT(1)      |   Not-ECT     |    ECT(1)     |
+       |       CE        |   Not-ECT     |      CE       |
+       +-----------------+---------------+---------------+
+       */
+       if !useNormalMode {
+               inner &= (CongestionExperienced ^ 255)
+       }
+
+       return inner
+}
+
+func ecn_rfc6040_enabled(tos byte) bool {
+       return (CongestionExperienced & tos) == ECNTransport0
+}
diff --git a/peer.go b/peer.go
index 96cfa61125760576cb67d9b97706c620046d8b70..642a0ee2710d7f8b8447019f4d6363a5cf6f8091 100644 (file)
--- a/peer.go
+++ b/peer.go
@@ -15,6 +15,14 @@ import (
 
 const (
        PeerRoutineNumber = 3
+
+       DiffServAF41          = 0x88 // AF41
+       NotECNTransport       = 0x00 // Not-ECT (Not ECN-Capable Transport)
+       ECNTransport1         = 0x01 // ECT(1) (ECN-Capable Transport(1))
+       ECNTransport0         = 0x02 // ECT(0) (ECN-Capable Transport(0))
+       CongestionExperienced = 0x03 // CE (Congestion Experienced)
+
+       HandshakeDSCP = DiffServAF41 | ECNTransport0 // AF41, plus 10 ECN
 )
 
 type Peer struct {
@@ -25,6 +33,7 @@ type Peer struct {
        device                      *Device
        endpoint                    Endpoint
        persistentKeepaliveInterval uint16
+       isECNConfirmed              AtomicBool
 
        // This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly
        stats struct {
index fb848eb064d7d44466dd97099c2efc91c96c46b8..03dbd4b6e7119726c01b89df5bafe18347a91940 100644 (file)
@@ -23,6 +23,7 @@ type QueueHandshakeElement struct {
        packet   []byte
        endpoint Endpoint
        buffer   *[MaxMessageSize]byte
+       isECNCompatible bool
 }
 
 type QueueInboundElement struct {
@@ -33,6 +34,7 @@ type QueueInboundElement struct {
        counter  uint64
        keypair  *Keypair
        endpoint Endpoint
+       tos      byte
 }
 
 func (elem *QueueInboundElement) Drop() {
@@ -108,6 +110,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
                err      error
                size     int
                endpoint Endpoint
+               outerTOS byte
        )
 
        for {
@@ -116,9 +119,9 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
 
                switch IP {
                case ipv4.Version:
-                       size, endpoint, err = bind.ReceiveIPv4(buffer[:])
+                       size, endpoint, outerTOS, err = bind.ReceiveIPv4(buffer[:])
                case ipv6.Version:
-                       size, endpoint, err = bind.ReceiveIPv6(buffer[:])
+                       size, endpoint, outerTOS, err = bind.ReceiveIPv6(buffer[:])
                default:
                        panic("invalid IP version")
                }
@@ -178,6 +181,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
                        elem.endpoint = endpoint
                        elem.counter = 0
                        elem.Mutex = sync.Mutex{}
+                       elem.tos = outerTOS
                        elem.Lock()
 
                        // add to decryption queues
@@ -213,6 +217,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
                                        buffer:   buffer,
                                        packet:   packet,
                                        endpoint: endpoint,
+                                       isECNCompatible: ecn_rfc6040_enabled(outerTOS),
                                },
                        )) {
                                buffer = device.GetMessageBuffer()
@@ -426,7 +431,7 @@ func (device *Device) RoutineHandshake() {
                        peer.SetEndpointFromPacket(elem.endpoint)
 
                        logDebug.Println(peer, "- Received handshake initiation")
-
+                       peer.isECNConfirmed.Set(elem.isECNCompatible)
                        peer.SendHandshakeResponse()
 
                case MessageResponseType:
@@ -473,6 +478,7 @@ func (device *Device) RoutineHandshake() {
 
                        peer.timersSessionDerived()
                        peer.timersHandshakeComplete()
+                       peer.isECNConfirmed.Set(elem.isECNCompatible)
                        peer.SendKeepalive()
                        select {
                        case peer.signals.newKeypairArrived <- struct{}{}:
@@ -565,6 +571,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
                        }
                        peer.timersDataReceived()
 
+                       var shouldDrop bool
                        // verify source and strip padding
 
                        switch elem.packet[0] >> 4 {
@@ -595,6 +602,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
                                        continue
                                }
 
+                               elem.tos, shouldDrop = ecn_rfc6040_egress(elem.packet[1], elem.tos)
                        case ipv6.Version:
 
                                // strip padding
@@ -623,10 +631,15 @@ func (peer *Peer) RoutineSequentialReceiver() {
                                        continue
                                }
 
+                               elem.tos, shouldDrop = ecn_rfc6040_egress(elem.packet[1], elem.tos);
                        default:
                                logInfo.Println("Packet with invalid IP version from", peer)
                                continue
                        }
+                       if shouldDrop {
+                               logInfo.Println("ECN/Congestion detected, dropping packet from", peer)
+                               continue
+                       }
 
                        // write to tun device
 
diff --git a/send.go b/send.go
index 57bb67b51568a0249f2c3e3e8726ba90ea311ee7..f787027ba89b99cbd10ed78c5a175f32f11802fc 100644 (file)
--- a/send.go
+++ b/send.go
@@ -41,10 +41,6 @@ import (
  * (to allow the construction of transport messages in-place)
  */
 
-const (
-       HandshakeDSCP = 0x88 // AF41, plus 00 ECN
-)
-
 type QueueOutboundElement struct {
        dropped int32
        sync.Mutex
@@ -299,14 +295,20 @@ func (device *Device) RoutineReadFromTUN() {
                        }
                        dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
                        peer = device.allowedips.LookupIPv4(dst)
-                       elem.tos = elem.packet[1];
+                       if peer == nil {
+                               continue
+                       }
+                       elem.tos = ecn_rfc6040_ingress(elem.packet[1], peer.isECNConfirmed.Get())
                case ipv6.Version:
                        if len(elem.packet) < ipv6.HeaderLen {
                                continue
                        }
                        dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
                        peer = device.allowedips.LookupIPv6(dst)
-                       elem.tos = elem.packet[1];
+                       if peer == nil {
+                               continue
+                       }
+                       elem.tos = ecn_rfc6040_ingress(elem.packet[1], peer.isECNConfirmed.Get())
                default:
                        logDebug.Println("Received packet with unknown IP version")
                }