]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
wgcfg: rename Key to PublicKey
authorDavid Crawshaw <crawshaw@tailscale.com>
Tue, 7 Apr 2020 05:49:47 +0000 (15:49 +1000)
committerDavid Crawshaw <crawshaw@tailscale.com>
Fri, 1 May 2020 14:52:01 +0000 (00:52 +1000)
A few minor review cleanups while here (e.g. remove unused LessThan).

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
wgcfg/config.go
wgcfg/key.go
wgcfg/key_test.go
wgcfg/parser.go

index 2b5e7148ddab9496b0ab5e600538460c55e13725..ffb7556a9c78753bc859a3032cf2f4e236b6c9c7 100644 (file)
@@ -23,7 +23,7 @@ type Config struct {
 }
 
 type Peer struct {
-       PublicKey           Key
+       PublicKey           PublicKey
        PresharedKey        SymmetricKey
        AllowedIPs          []CIDR
        Endpoints           []Endpoint
index cdbbeea7099119a2d8a93dcd20b96098fb2090e5..cfb59d35d865c1f348c84263e6a691806317635d 100644 (file)
@@ -2,7 +2,7 @@ package wgcfg
 
 import (
        "bytes"
-       "crypto/rand"
+       cryptorand "crypto/rand"
        "crypto/subtle"
        "encoding/base64"
        "encoding/hex"
@@ -16,32 +16,22 @@ import (
 
 const KeySize = 32
 
-// Key is curve25519 key.
+// PublicKey is curve25519 key.
 // It is used by WireGuard to represent public and preshared keys.
-type Key [KeySize]byte
+type PublicKey [KeySize]byte
 
-// NewPresharedKey generates a new random key.
-func NewPresharedKey() (*Key, error) {
-       var k [KeySize]byte
-       _, err := rand.Read(k[:])
-       if err != nil {
-               return nil, err
-       }
-       return (*Key)(&k), nil
-}
+func ParseKey(b64 string) (*PublicKey, error) { return parseKeyBase64(base64.StdEncoding, b64) }
 
-func ParseKey(b64 string) (*Key, error) { return parseKeyBase64(base64.StdEncoding, b64) }
-
-func ParseHexKey(s string) (Key, error) {
+func ParseHexKey(s string) (PublicKey, error) {
        b, err := hex.DecodeString(s)
        if err != nil {
-               return Key{}, &ParseError{"invalid hex key: " + err.Error(), s}
+               return PublicKey{}, &ParseError{"invalid hex key: " + err.Error(), s}
        }
        if len(b) != KeySize {
-               return Key{}, &ParseError{fmt.Sprintf("invalid hex key length: %d", len(b)), s}
+               return PublicKey{}, &ParseError{fmt.Sprintf("invalid hex key length: %d", len(b)), s}
        }
 
-       var key Key
+       var key PublicKey
        copy(key[:], b)
        return key, nil
 }
@@ -62,31 +52,22 @@ func ParsePrivateHexKey(v string) (PrivateKey, error) {
        return pk, nil
 }
 
-func (k Key) Base64() string    { return base64.StdEncoding.EncodeToString(k[:]) }
-func (k Key) String() string    { return "pub:" + k.Base64()[:8] }
-func (k Key) HexString() string { return hex.EncodeToString(k[:]) }
-func (k Key) Equal(k2 Key) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 }
+func (k PublicKey) Base64() string          { return base64.StdEncoding.EncodeToString(k[:]) }
+func (k PublicKey) String() string          { return k.ShortString() }
+func (k PublicKey) HexString() string       { return hex.EncodeToString(k[:]) }
+func (k PublicKey) Equal(k2 PublicKey) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 }
 
-func (k *Key) ShortString() string {
-       if k.IsZero() {
-               return "[empty]"
-       }
-       long := k.String()
-       if len(long) < 10 {
-               return "invalid"
-       }
-       return "[" + long[0:4] + "…" + long[len(long)-5:len(long)-1] + "]"
+func (k *PublicKey) ShortString() string {
+       long := k.Base64()
+       return "[" + long[0:5] + "]"
 }
 
-func (k *Key) IsZero() bool {
-       if k == nil {
-               return true
-       }
-       var zeros Key
+func (k PublicKey) IsZero() bool {
+       var zeros PublicKey
        return subtle.ConstantTimeCompare(zeros[:], k[:]) == 1
 }
 
-func (k *Key) MarshalJSON() ([]byte, error) {
+func (k *PublicKey) MarshalJSON() ([]byte, error) {
        if k == nil {
                return []byte("null"), nil
        }
@@ -95,47 +76,35 @@ func (k *Key) MarshalJSON() ([]byte, error) {
        return buf.Bytes(), nil
 }
 
-func (k *Key) UnmarshalJSON(b []byte) error {
+func (k *PublicKey) UnmarshalJSON(b []byte) error {
        if k == nil {
-               return errors.New("wgcfg.Key: UnmarshalJSON on nil pointer")
+               return errors.New("wgcfg.PublicKey: UnmarshalJSON on nil pointer")
        }
        if len(b) < 3 || b[0] != '"' || b[len(b)-1] != '"' {
-               return errors.New("wgcfg.Key: UnmarshalJSON not given a string")
+               return errors.New("wgcfg.PublicKey: UnmarshalJSON not given a string")
        }
        b = b[1 : len(b)-1]
        key, err := ParseHexKey(string(b))
        if err != nil {
-               return fmt.Errorf("wgcfg.Key: UnmarshalJSON: %v", err)
+               return fmt.Errorf("wgcfg.PublicKey: UnmarshalJSON: %v", err)
        }
        copy(k[:], key[:])
        return nil
 }
 
-func (a *Key) LessThan(b *Key) bool {
-       for i := range a {
-               if a[i] < b[i] {
-                       return true
-               } else if a[i] > b[i] {
-                       return false
-               }
-       }
-       return false
-}
-
 // PrivateKey is curve25519 key.
 // It is used by WireGuard to represent private keys.
 type PrivateKey [KeySize]byte
 
 // NewPrivateKey generates a new curve25519 secret key.
 // It conforms to the format described on https://cr.yp.to/ecdh.html.
-func NewPrivateKey() (PrivateKey, error) {
-       k, err := NewPresharedKey()
+func NewPrivateKey() (pk PrivateKey, err error) {
+       _, err = cryptorand.Read(pk[:])
        if err != nil {
                return PrivateKey{}, err
        }
-       k[0] &= 248
-       k[31] = (k[31] & 127) | 64
-       return (PrivateKey)(*k), nil
+       pk.clamp()
+       return pk, nil
 }
 
 func ParsePrivateKey(b64 string) (*PrivateKey, error) {
@@ -147,9 +116,9 @@ func (k *PrivateKey) String() string           { return base64.StdEncoding.Encod
 func (k *PrivateKey) HexString() string        { return hex.EncodeToString(k[:]) }
 func (k *PrivateKey) Equal(k2 PrivateKey) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 }
 
-func (k *PrivateKey) IsZero() bool {
-       pk := Key(*k)
-       return pk.IsZero()
+func (k PrivateKey) IsZero() bool {
+       var zeros PrivateKey
+       return subtle.ConstantTimeCompare(zeros[:], k[:]) == 1
 }
 
 func (k *PrivateKey) clamp() {
@@ -158,14 +127,13 @@ func (k *PrivateKey) clamp() {
 }
 
 // Public computes the public key matching this curve25519 secret key.
-func (k *PrivateKey) Public() Key {
-       pk := Key(*k)
-       if pk.IsZero() {
-               panic("Tried to generate emptyPrivateKey.Public()")
+func (k PrivateKey) Public() PublicKey {
+       if k.IsZero() {
+               panic("wgcfg: tried to generate public key for a zero key")
        }
        var p [KeySize]byte
-       curve25519.ScalarBaseMult(&p, (*[KeySize]byte)(k))
-       return (Key)(p)
+       curve25519.ScalarBaseMult(&p, (*[KeySize]byte)(&k))
+       return (PublicKey)(p)
 }
 
 func (k PrivateKey) MarshalText() ([]byte, error) {
@@ -188,14 +156,14 @@ func (k *PrivateKey) UnmarshalText(b []byte) error {
        return nil
 }
 
-func (k PrivateKey) SharedSecret(pub Key) (ss [KeySize]byte) {
+func (k PrivateKey) SharedSecret(pub PublicKey) (ss [KeySize]byte) {
        apk := (*[KeySize]byte)(&pub)
        ask := (*[KeySize]byte)(&k)
        curve25519.ScalarMult(&ss, ask, apk)
        return ss
 }
 
-func parseKeyBase64(enc *base64.Encoding, s string) (*Key, error) {
+func parseKeyBase64(enc *base64.Encoding, s string) (*PublicKey, error) {
        k, err := enc.DecodeString(s)
        if err != nil {
                return nil, &ParseError{"Invalid key: " + err.Error(), s}
@@ -203,7 +171,7 @@ func parseKeyBase64(enc *base64.Encoding, s string) (*Key, error) {
        if len(k) != KeySize {
                return nil, &ParseError{"Keys must decode to exactly 32 bytes", s}
        }
-       var key Key
+       var key PublicKey
        copy(key[:], k)
        return &key, nil
 }
index 0b82d5fcd295abd77afb5dc5a3ea1f288f0a0514..21bffbc9896cccb0b550bf34488fd00d7fa3b80c 100644 (file)
@@ -6,10 +6,11 @@ import (
 )
 
 func TestKeyBasics(t *testing.T) {
-       k1, err := NewPresharedKey()
+       pk1, err := NewPrivateKey()
        if err != nil {
                t.Fatal(err)
        }
+       k1 := pk1.Public()
 
        b, err := k1.MarshalJSON()
        if err != nil {
@@ -18,7 +19,7 @@ func TestKeyBasics(t *testing.T) {
 
        t.Run("JSON round-trip", func(t *testing.T) {
                // should preserve the keys
-               k2 := new(Key)
+               k2 := new(PublicKey)
                if err := k2.UnmarshalJSON(b); err != nil {
                        t.Fatal(err)
                }
@@ -39,10 +40,11 @@ func TestKeyBasics(t *testing.T) {
 
        t.Run("second key", func(t *testing.T) {
                // A second call to NewPresharedKey should make a new key.
-               k3, err := NewPresharedKey()
+               pk3, err := NewPrivateKey()
                if err != nil {
                        t.Fatal(err)
                }
+               k3 := pk3.Public()
                if bytes.Equal(k1[:], k3[:]) {
                        t.Fatalf("k1 %v == k3 %v", k1[:], k3[:])
                }
@@ -52,6 +54,7 @@ func TestKeyBasics(t *testing.T) {
                }
        })
 }
+
 func TestPrivateKeyBasics(t *testing.T) {
        pri, err := NewPrivateKey()
        if err != nil {
@@ -81,7 +84,7 @@ func TestPrivateKeyBasics(t *testing.T) {
        })
 
        t.Run("JSON incompatible with Key", func(t *testing.T) {
-               k2 := new(Key)
+               k2 := new(PublicKey)
                if err := k2.UnmarshalJSON(b); err == nil {
                        t.Fatalf("successfully decoded private key as key")
                }
index e71d32b1f722dbfc7ed8652cdb0bf8e8d0aaaaa0..8db18f350b47fe923d33adca422c996a2cab5f0e 100644 (file)
@@ -100,7 +100,7 @@ func parsePersistentKeepalive(s string) (uint16, error) {
        return uint16(m), nil
 }
 
-func parseKeyHex(s string) (*Key, error) {
+func parseKeyHex(s string) (*PublicKey, error) {
        k, err := hex.DecodeString(s)
        if err != nil {
                return nil, &ParseError{"Invalid key: " + err.Error(), s}
@@ -108,7 +108,7 @@ func parseKeyHex(s string) (*Key, error) {
        if len(k) != KeySize {
                return nil, &ParseError{"Keys must decode to exactly 32 bytes", s}
        }
-       var key Key
+       var key PublicKey
        copy(key[:], k)
        return &key, nil
 }