]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
conn: use netip for std bind
authorJason A. Donenfeld <Jason@zx2c4.com>
Fri, 18 Mar 2022 04:23:02 +0000 (22:23 -0600)
committerJason A. Donenfeld <Jason@zx2c4.com>
Fri, 18 Mar 2022 04:23:02 +0000 (22:23 -0600)
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
conn/bind_std.go

index ab800ba29cf22986fc99469da2868e2b5e874f2f..e0f6cdddb5428bb4b7351725f3dc056db3f8597f 100644 (file)
@@ -27,7 +27,7 @@ type StdNetBind struct {
 
 func NewStdNetBind() Bind { return &StdNetBind{} }
 
-type StdNetEndpoint net.UDPAddr
+type StdNetEndpoint netip.AddrPort
 
 var (
        _ Bind     = (*StdNetBind)(nil)
@@ -36,18 +36,13 @@ var (
 
 func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
        e, err := netip.ParseAddrPort(s)
-       return (*StdNetEndpoint)(&net.UDPAddr{
-               IP:   e.Addr().AsSlice(),
-               Port: int(e.Port()),
-               Zone: e.Addr().Zone(),
-       }), err
+       return (*StdNetEndpoint)(&e), err
 }
 
 func (*StdNetEndpoint) ClearSrc() {}
 
 func (e *StdNetEndpoint) DstIP() netip.Addr {
-       a, _ := netip.AddrFromSlice((*net.UDPAddr)(e).IP)
-       return a
+       return (*netip.AddrPort)(e).Addr()
 }
 
 func (e *StdNetEndpoint) SrcIP() netip.Addr {
@@ -55,18 +50,12 @@ func (e *StdNetEndpoint) SrcIP() netip.Addr {
 }
 
 func (e *StdNetEndpoint) DstToBytes() []byte {
-       addr := (*net.UDPAddr)(e)
-       out := addr.IP.To4()
-       if out == nil {
-               out = addr.IP
-       }
-       out = append(out, byte(addr.Port&0xff))
-       out = append(out, byte((addr.Port>>8)&0xff))
-       return out
+       b, _ := (*netip.AddrPort)(e).MarshalBinary()
+       return b
 }
 
 func (e *StdNetEndpoint) DstToString() string {
-       return (*net.UDPAddr)(e).String()
+       return (*netip.AddrPort)(e).String()
 }
 
 func (e *StdNetEndpoint) SrcToString() string {
@@ -162,18 +151,15 @@ func (bind *StdNetBind) Close() error {
 
 func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
        return func(buff []byte) (int, Endpoint, error) {
-               n, endpoint, err := conn.ReadFromUDP(buff)
-               if endpoint != nil {
-                       endpoint.IP = endpoint.IP.To4()
-               }
-               return n, (*StdNetEndpoint)(endpoint), err
+               n, endpoint, err := conn.ReadFromUDPAddrPort(buff)
+               return n, (*StdNetEndpoint)(&endpoint), err
        }
 }
 
 func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
        return func(buff []byte) (int, Endpoint, error) {
-               n, endpoint, err := conn.ReadFromUDP(buff)
-               return n, (*StdNetEndpoint)(endpoint), err
+               n, endpoint, err := conn.ReadFromUDPAddrPort(buff)
+               return n, (*StdNetEndpoint)(&endpoint), err
        }
 }
 
@@ -183,11 +169,12 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
        if !ok {
                return ErrWrongEndpointType
        }
+       addrPort := (*netip.AddrPort)(nend)
 
        bind.mu.Lock()
        blackhole := bind.blackhole4
        conn := bind.ipv4
-       if nend.IP.To4() == nil {
+       if addrPort.Addr().Is6() {
                blackhole = bind.blackhole6
                conn = bind.ipv6
        }
@@ -199,6 +186,6 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
        if conn == nil {
                return syscall.EAFNOSUPPORT
        }
-       _, err = conn.WriteToUDP(buff, (*net.UDPAddr)(nend))
+       _, err = conn.WriteToUDPAddrPort(buff, *addrPort)
        return err
 }