]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
wgcfg: new config package
authorDavid Crawshaw <crawshaw@tailscale.com>
Wed, 17 Apr 2019 13:41:25 +0000 (09:41 -0400)
committerDavid Crawshaw <david@zentus.com>
Mon, 30 Mar 2020 22:32:52 +0000 (09:32 +1100)
Based on types and config parser from wireguard-windows.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
wgcfg/config.go [new file with mode: 0644]
wgcfg/ip.go [new file with mode: 0644]
wgcfg/key.go [new file with mode: 0644]
wgcfg/key_test.go [new file with mode: 0644]
wgcfg/name.go [new file with mode: 0644]
wgcfg/parser.go [new file with mode: 0644]
wgcfg/parser_test.go [new file with mode: 0644]
wgcfg/writer.go [new file with mode: 0644]

diff --git a/wgcfg/config.go b/wgcfg/config.go
new file mode 100644 (file)
index 0000000..2b5e714
--- /dev/null
@@ -0,0 +1,78 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+// Package wgcfg has types and a parser for representing WireGuard config.
+package wgcfg
+
+import (
+       "fmt"
+       "strings"
+)
+
+// Config is a wireguard configuration.
+type Config struct {
+       Name       string
+       PrivateKey PrivateKey
+       Addresses  []CIDR
+       ListenPort uint16
+       MTU        uint16
+       DNS        []IP
+       Peers      []Peer
+}
+
+type Peer struct {
+       PublicKey           Key
+       PresharedKey        SymmetricKey
+       AllowedIPs          []CIDR
+       Endpoints           []Endpoint
+       PersistentKeepalive uint16
+}
+
+type Endpoint struct {
+       Host string
+       Port uint16
+}
+
+func (e *Endpoint) String() string {
+       if strings.IndexByte(e.Host, ':') > 0 {
+               return fmt.Sprintf("[%s]:%d", e.Host, e.Port)
+       }
+       return fmt.Sprintf("%s:%d", e.Host, e.Port)
+}
+
+func (e *Endpoint) IsEmpty() bool {
+       return len(e.Host) == 0
+}
+
+// Copy makes a deep copy of Config.
+// The result aliases no memory with the original.
+func (cfg Config) Copy() Config {
+       res := cfg
+       if res.Addresses != nil {
+               res.Addresses = append([]CIDR{}, res.Addresses...)
+       }
+       if res.DNS != nil {
+               res.DNS = append([]IP{}, res.DNS...)
+       }
+       peers := make([]Peer, 0, len(res.Peers))
+       for _, peer := range res.Peers {
+               peers = append(peers, peer.Copy())
+       }
+       res.Peers = peers
+       return res
+}
+
+// Copy makes a deep copy of Peer.
+// The result aliases no memory with the original.
+func (peer Peer) Copy() Peer {
+       res := peer
+       if res.AllowedIPs != nil {
+               res.AllowedIPs = append([]CIDR{}, res.AllowedIPs...)
+       }
+       if res.Endpoints != nil {
+               res.Endpoints = append([]Endpoint{}, res.Endpoints...)
+       }
+       return res
+}
diff --git a/wgcfg/ip.go b/wgcfg/ip.go
new file mode 100644 (file)
index 0000000..ecf5faf
--- /dev/null
@@ -0,0 +1,128 @@
+package wgcfg
+
+import (
+       "fmt"
+       "net"
+)
+
+// IP is an IPv4 or an IPv6 address.
+//
+// Internally the address is always represented in its IPv6 form.
+// IPv4 addresses use the IPv4-in-IPv6 syntax.
+type IP struct {
+       Addr [16]byte
+}
+
+func (ip IP) String() string { return net.IP(ip.Addr[:]).String() }
+
+func (ip *IP) IP() net.IP { return net.IP(ip.Addr[:]) }
+func (ip *IP) Is6() bool  { return !ip.Is4() }
+func (ip *IP) Is4() bool {
+       return ip.Addr[0] == 0 && ip.Addr[1] == 0 &&
+               ip.Addr[2] == 0 && ip.Addr[3] == 0 &&
+               ip.Addr[4] == 0 && ip.Addr[5] == 0 &&
+               ip.Addr[6] == 0 && ip.Addr[7] == 0 &&
+               ip.Addr[8] == 0 && ip.Addr[9] == 0 &&
+               ip.Addr[10] == 0xff && ip.Addr[11] == 0xff
+}
+func (ip *IP) To4() []byte {
+       if ip.Is4() {
+               return ip.Addr[12:16]
+       } else {
+               return nil
+       }
+}
+func (ip *IP) Equal(x *IP) bool {
+       if ip == nil || x == nil {
+               return false
+       }
+       // TODO: this isn't hard, write a more efficient implementation.
+       return ip.IP().Equal(x.IP())
+}
+
+func (ip IP) MarshalText() ([]byte, error) {
+       return []byte(ip.String()), nil
+}
+
+func (ip *IP) UnmarshalText(text []byte) error {
+       parsedIP := ParseIP(string(text))
+       if parsedIP == nil {
+               return fmt.Errorf("wgcfg.IP: UnmarshalText: bad IP address %q", string(text))
+       }
+       *ip = *parsedIP
+       return nil
+}
+
+func IPv4(b0, b1, b2, b3 byte) (ip IP) {
+       ip.Addr[10], ip.Addr[11] = 0xff, 0xff // IPv4-in-IPv6 prefix
+       ip.Addr[12] = b0
+       ip.Addr[13] = b1
+       ip.Addr[14] = b2
+       ip.Addr[15] = b3
+       return ip
+}
+
+// ParseIP parses the string representation of an address into an IP.
+//
+// It accepts IPv4 notation such as "1.2.3.4" and IPv6 notation like ""::0".
+// If the string is not a valid IP address, ParseIP returns nil.
+func ParseIP(s string) *IP {
+       netIP := net.ParseIP(s)
+       if netIP == nil {
+               return nil
+       }
+       ip := new(IP)
+       copy(ip.Addr[:], netIP.To16())
+       return ip
+}
+
+// CIDR is a compact IP address and subnet mask.
+type CIDR struct {
+       IP   IP
+       Mask uint8 // 0-32 for IsIPv4, 4-128 for IsIPv6
+}
+
+// ParseCIDR parses CIDR notation into a CIDR type.
+// Typical CIDR strings look like "192.168.1.0/24".
+func ParseCIDR(s string) (cidr *CIDR, err error) {
+       netIP, netAddr, err := net.ParseCIDR(s)
+       if err != nil {
+               return nil, err
+       }
+       cidr = new(CIDR)
+       copy(cidr.IP.Addr[:], netIP.To16())
+       ones, _ := netAddr.Mask.Size()
+       cidr.Mask = uint8(ones)
+
+       return cidr, nil
+}
+
+func (r CIDR) String() string { return r.IPNet().String() }
+
+func (r *CIDR) IPNet() *net.IPNet {
+       bits := 128
+       if r.IP.Is4() {
+               bits = 32
+       }
+       return &net.IPNet{IP: r.IP.IP(), Mask: net.CIDRMask(int(r.Mask), bits)}
+}
+func (r *CIDR) Contains(ip *IP) bool {
+       if r == nil || ip == nil {
+               return false
+       }
+       // TODO: this isn't hard, write a more efficient implementation.
+       return r.IPNet().Contains(ip.IP())
+}
+
+func (r CIDR) MarshalText() ([]byte, error) {
+       return []byte(r.String()), nil
+}
+
+func (r *CIDR) UnmarshalText(text []byte) error {
+       cidr, err := ParseCIDR(string(text))
+       if err != nil {
+               return fmt.Errorf("wgcfg.CIDR: UnmarshalText: %v", err)
+       }
+       *r = *cidr
+       return nil
+}
diff --git a/wgcfg/key.go b/wgcfg/key.go
new file mode 100644 (file)
index 0000000..1597203
--- /dev/null
@@ -0,0 +1,240 @@
+package wgcfg
+
+import (
+       "bytes"
+       "crypto/rand"
+       "crypto/subtle"
+       "encoding/base64"
+       "encoding/hex"
+       "errors"
+       "fmt"
+       "strings"
+
+       "golang.org/x/crypto/chacha20poly1305"
+       "golang.org/x/crypto/curve25519"
+)
+
+const KeySize = 32
+
+// Key is curve25519 key.
+// It is used by WireGuard to represent public and preshared keys.
+type Key [KeySize]byte
+
+// NewPresharedKey generates a new random key.
+func NewPresharedKey() (*Key, error) {
+       var k [KeySize]byte
+       _, err := rand.Read(k[:])
+       if err != nil {
+               return nil, err
+       }
+       return (*Key)(&k), nil
+}
+
+func ParseKey(b64 string) (*Key, error) { return parseKeyBase64(base64.StdEncoding, b64) }
+
+func ParseHexKey(s string) (Key, error) {
+       b, err := hex.DecodeString(s)
+       if err != nil {
+               return Key{}, &ParseError{"invalid hex key: " + err.Error(), s}
+       }
+       if len(b) != KeySize {
+               return Key{}, &ParseError{fmt.Sprintf("invalid hex key length: %d", len(b)), s}
+       }
+
+       var key Key
+       copy(key[:], b)
+       return key, nil
+}
+
+func ParsePrivateHexKey(v string) (PrivateKey, error) {
+       k, err := ParseHexKey(v)
+       if err != nil {
+               return PrivateKey{}, err
+       }
+       pk := PrivateKey(k)
+       if pk.IsZero() {
+               // Do not clamp a zero key, pass the zero through
+               // (much like NaN propagation) so that IsZero reports
+               // a useful result.
+               return pk, nil
+       }
+       pk.clamp()
+       return pk, nil
+}
+
+func (k Key) Base64() string    { return base64.StdEncoding.EncodeToString(k[:]) }
+func (k Key) String() string    { return "pub:" + k.Base64()[:8] }
+func (k Key) HexString() string { return hex.EncodeToString(k[:]) }
+func (k Key) Equal(k2 Key) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 }
+
+func (k *Key) ShortString() string {
+       if k.IsZero() {
+               return "[empty]"
+       }
+       long := k.String()
+       if len(long) < 10 {
+               return "invalid"
+       }
+       return "[" + long[0:4] + "…" + long[len(long)-5:len(long)-1] + "]"
+}
+
+func (k *Key) IsZero() bool {
+       if k == nil {
+               return true
+       }
+       var zeros Key
+       return subtle.ConstantTimeCompare(zeros[:], k[:]) == 1
+}
+
+func (k *Key) MarshalJSON() ([]byte, error) {
+       if k == nil {
+               return []byte("null"), nil
+       }
+       buf := new(bytes.Buffer)
+       fmt.Fprintf(buf, `"%x"`, k[:])
+       return buf.Bytes(), nil
+}
+
+func (k *Key) UnmarshalJSON(b []byte) error {
+       if k == nil {
+               return errors.New("wgcfg.Key: UnmarshalJSON on nil pointer")
+       }
+       if len(b) < 3 || b[0] != '"' || b[len(b)-1] != '"' {
+               return errors.New("wgcfg.Key: UnmarshalJSON not given a string")
+       }
+       b = b[1 : len(b)-1]
+       key, err := ParseHexKey(string(b))
+       if err != nil {
+               return fmt.Errorf("wgcfg.Key: UnmarshalJSON: %v", err)
+       }
+       copy(k[:], key[:])
+       return nil
+}
+
+func (a *Key) LessThan(b *Key) bool {
+       for i := range a {
+               if a[i] < b[i] {
+                       return true
+               } else if a[i] > b[i] {
+                       return false
+               }
+       }
+       return false
+}
+
+// PrivateKey is curve25519 key.
+// It is used by WireGuard to represent private keys.
+type PrivateKey [KeySize]byte
+
+// NewPrivateKey generates a new curve25519 secret key.
+// It conforms to the format described on https://cr.yp.to/ecdh.html.
+func NewPrivateKey() (PrivateKey, error) {
+       k, err := NewPresharedKey()
+       if err != nil {
+               return PrivateKey{}, err
+       }
+       k[0] &= 248
+       k[31] = (k[31] & 127) | 64
+       return (PrivateKey)(*k), nil
+}
+
+func ParsePrivateKey(b64 string) (*PrivateKey, error) {
+       k, err := parseKeyBase64(base64.StdEncoding, b64)
+       return (*PrivateKey)(k), err
+}
+
+func (k *PrivateKey) String() string           { return base64.StdEncoding.EncodeToString(k[:]) }
+func (k *PrivateKey) HexString() string        { return hex.EncodeToString(k[:]) }
+func (k *PrivateKey) Equal(k2 PrivateKey) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 }
+
+func (k *PrivateKey) IsZero() bool {
+       pk := Key(*k)
+       return pk.IsZero()
+}
+
+func (k *PrivateKey) clamp() {
+       k[0] &= 248
+       k[31] = (k[31] & 127) | 64
+}
+
+// Public computes the public key matching this curve25519 secret key.
+func (k *PrivateKey) Public() Key {
+       pk := Key(*k)
+       if pk.IsZero() {
+               panic("Tried to generate emptyPrivateKey.Public()")
+       }
+       var p [KeySize]byte
+       curve25519.ScalarBaseMult(&p, (*[KeySize]byte)(k))
+       return (Key)(p)
+}
+
+func (k PrivateKey) MarshalText() ([]byte, error) {
+       buf := new(bytes.Buffer)
+       fmt.Fprintf(buf, `privkey:%x`, k[:])
+       return buf.Bytes(), nil
+}
+
+func (k *PrivateKey) UnmarshalText(b []byte) error {
+       s := string(b)
+       if !strings.HasPrefix(s, `privkey:`) {
+               return errors.New("wgcfg.PrivateKey: UnmarshalText not given a private-key string")
+       }
+       s = strings.TrimPrefix(s, `privkey:`)
+       key, err := ParseHexKey(s)
+       if err != nil {
+               return fmt.Errorf("wgcfg.PrivateKey: UnmarshalText: %v", err)
+       }
+       copy(k[:], key[:])
+       return nil
+}
+
+func (k PrivateKey) SharedSecret(pub Key) (ss [KeySize]byte) {
+       apk := (*[KeySize]byte)(&pub)
+       ask := (*[KeySize]byte)(&k)
+       curve25519.ScalarMult(&ss, ask, apk)
+       return ss
+}
+
+func parseKeyBase64(enc *base64.Encoding, s string) (*Key, error) {
+       k, err := enc.DecodeString(s)
+       if err != nil {
+               return nil, &ParseError{"Invalid key: " + err.Error(), s}
+       }
+       if len(k) != KeySize {
+               return nil, &ParseError{"Keys must decode to exactly 32 bytes", s}
+       }
+       var key Key
+       copy(key[:], k)
+       return &key, nil
+}
+
+func ParseSymmetricKey(b64 string) (SymmetricKey, error) {
+       k, err := parseKeyBase64(base64.StdEncoding, b64)
+       if err != nil {
+               return SymmetricKey{}, err
+       }
+       return SymmetricKey(*k), nil
+}
+
+func ParseSymmetricHexKey(s string) (SymmetricKey, error) {
+       b, err := hex.DecodeString(s)
+       if err != nil {
+               return SymmetricKey{}, &ParseError{"invalid symmetric hex key: " + err.Error(), s}
+       }
+       if len(b) != chacha20poly1305.KeySize {
+               return SymmetricKey{}, &ParseError{fmt.Sprintf("invalid symmetric hex key length: %d", len(b)), s}
+       }
+       var key SymmetricKey
+       copy(key[:], b)
+       return key, nil
+}
+
+// SymmetricKey is a chacha20poly1305 key.
+// It is used by WireGuard to represent pre-shared symmetric keys.
+type SymmetricKey [chacha20poly1305.KeySize]byte
+
+func (k SymmetricKey) Base64() string             { return base64.StdEncoding.EncodeToString(k[:]) }
+func (k SymmetricKey) String() string             { return "sym:" + k.Base64()[:8] }
+func (k SymmetricKey) HexString() string          { return hex.EncodeToString(k[:]) }
+func (k SymmetricKey) IsZero() bool               { return k.Equal(SymmetricKey{}) }
+func (k SymmetricKey) Equal(k2 SymmetricKey) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 }
diff --git a/wgcfg/key_test.go b/wgcfg/key_test.go
new file mode 100644 (file)
index 0000000..0b82d5f
--- /dev/null
@@ -0,0 +1,107 @@
+package wgcfg
+
+import (
+       "bytes"
+       "testing"
+)
+
+func TestKeyBasics(t *testing.T) {
+       k1, err := NewPresharedKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       b, err := k1.MarshalJSON()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       t.Run("JSON round-trip", func(t *testing.T) {
+               // should preserve the keys
+               k2 := new(Key)
+               if err := k2.UnmarshalJSON(b); err != nil {
+                       t.Fatal(err)
+               }
+               if !bytes.Equal(k1[:], k2[:]) {
+                       t.Fatalf("k1 %v != k2 %v", k1[:], k2[:])
+               }
+               if b1, b2 := k1.String(), k2.String(); b1 != b2 {
+                       t.Fatalf("base64-encoded keys do not match: %s, %s", b1, b2)
+               }
+       })
+
+       t.Run("JSON incompatible with PrivateKey", func(t *testing.T) {
+               k2 := new(PrivateKey)
+               if err := k2.UnmarshalText(b); err == nil {
+                       t.Fatalf("successfully decoded key as private key")
+               }
+       })
+
+       t.Run("second key", func(t *testing.T) {
+               // A second call to NewPresharedKey should make a new key.
+               k3, err := NewPresharedKey()
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if bytes.Equal(k1[:], k3[:]) {
+                       t.Fatalf("k1 %v == k3 %v", k1[:], k3[:])
+               }
+               // Check for obvious comparables to make sure we are not generating bad strings somewhere.
+               if b1, b2 := k1.String(), k3.String(); b1 == b2 {
+                       t.Fatalf("base64-encoded keys match: %s, %s", b1, b2)
+               }
+       })
+}
+func TestPrivateKeyBasics(t *testing.T) {
+       pri, err := NewPrivateKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       b, err := pri.MarshalText()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       t.Run("JSON round-trip", func(t *testing.T) {
+               // should preserve the keys
+               pri2 := new(PrivateKey)
+               if err := pri2.UnmarshalText(b); err != nil {
+                       t.Fatal(err)
+               }
+               if !bytes.Equal(pri[:], pri2[:]) {
+                       t.Fatalf("pri %v != pri2 %v", pri[:], pri2[:])
+               }
+               if b1, b2 := pri.String(), pri2.String(); b1 != b2 {
+                       t.Fatalf("base64-encoded keys do not match: %s, %s", b1, b2)
+               }
+               if pub1, pub2 := pri.Public().String(), pri2.Public().String(); pub1 != pub2 {
+                       t.Fatalf("base64-encoded public keys do not match: %s, %s", pub1, pub2)
+               }
+       })
+
+       t.Run("JSON incompatible with Key", func(t *testing.T) {
+               k2 := new(Key)
+               if err := k2.UnmarshalJSON(b); err == nil {
+                       t.Fatalf("successfully decoded private key as key")
+               }
+       })
+
+       t.Run("second key", func(t *testing.T) {
+               // A second call to New should make a new key.
+               pri3, err := NewPrivateKey()
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if bytes.Equal(pri[:], pri3[:]) {
+                       t.Fatalf("pri %v == pri3 %v", pri[:], pri3[:])
+               }
+               // Check for obvious comparables to make sure we are not generating bad strings somewhere.
+               if b1, b2 := pri.String(), pri3.String(); b1 == b2 {
+                       t.Fatalf("base64-encoded keys match: %s, %s", b1, b2)
+               }
+               if pub1, pub2 := pri.Public().String(), pri3.Public().String(); pub1 == pub2 {
+                       t.Fatalf("base64-encoded public keys match: %s, %s", pub1, pub2)
+               }
+       })
+}
diff --git a/wgcfg/name.go b/wgcfg/name.go
new file mode 100644 (file)
index 0000000..28bc0f0
--- /dev/null
@@ -0,0 +1,49 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package wgcfg
+
+import (
+       "regexp"
+       "strings"
+)
+
+var reservedNames = []string{
+       "CON", "PRN", "AUX", "NUL",
+       "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", "COM8", "COM9",
+       "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9",
+}
+
+const specialChars = "/\\<>:\"|?*\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x00"
+
+var allowedNameFormat *regexp.Regexp
+
+func init() {
+       allowedNameFormat = regexp.MustCompile("^[a-zA-Z0-9_=+.-]{1,32}$")
+}
+
+func isReserved(name string) bool {
+       if len(name) == 0 {
+               return false
+       }
+       for _, reserved := range reservedNames {
+               if strings.EqualFold(name, reserved) {
+                       return true
+               }
+       }
+       return false
+}
+
+func hasSpecialChars(name string) bool {
+       return strings.ContainsAny(name, specialChars)
+}
+
+func TunnelNameIsValid(name string) bool {
+       // Aside from our own restrictions, let's impose the Windows restrictions first
+       if isReserved(name) || hasSpecialChars(name) {
+               return false
+       }
+       return allowedNameFormat.MatchString(name)
+}
diff --git a/wgcfg/parser.go b/wgcfg/parser.go
new file mode 100644 (file)
index 0000000..45a6057
--- /dev/null
@@ -0,0 +1,397 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package wgcfg
+
+import (
+       "encoding/hex"
+       "fmt"
+       "net"
+       "strconv"
+       "strings"
+)
+
+type ParseError struct {
+       why      string
+       offender string
+}
+
+func (e *ParseError) Error() string {
+       return fmt.Sprintf("%s: ‘%s’", e.why, e.offender)
+}
+
+func parseEndpoints(s string) ([]Endpoint, error) {
+       var eps []Endpoint
+       vals := strings.Split(s, ",")
+       for _, val := range vals {
+               e, err := parseEndpoint(val)
+               if err != nil {
+                       return nil, err
+               }
+               eps = append(eps, *e)
+       }
+       return eps, nil
+}
+
+func parseEndpoint(s string) (*Endpoint, error) {
+       i := strings.LastIndexByte(s, ':')
+       if i < 0 {
+               return nil, &ParseError{"Missing port from endpoint", s}
+       }
+       host, portStr := s[:i], s[i+1:]
+       if len(host) < 1 {
+               return nil, &ParseError{"Invalid endpoint host", host}
+       }
+       port, err := parsePort(portStr)
+       if err != nil {
+               return nil, err
+       }
+       hostColon := strings.IndexByte(host, ':')
+       if host[0] == '[' || host[len(host)-1] == ']' || hostColon > 0 {
+               err := &ParseError{"Brackets must contain an IPv6 address", host}
+               if len(host) > 3 && host[0] == '[' && host[len(host)-1] == ']' && hostColon > 0 {
+                       maybeV6 := net.ParseIP(host[1 : len(host)-1])
+                       if maybeV6 == nil || len(maybeV6) != net.IPv6len {
+                               return nil, err
+                       }
+               } else {
+                       return nil, err
+               }
+               host = host[1 : len(host)-1]
+       }
+       return &Endpoint{host, uint16(port)}, nil
+}
+
+func parseMTU(s string) (uint16, error) {
+       m, err := strconv.Atoi(s)
+       if err != nil {
+               return 0, err
+       }
+       if m < 576 || m > 65535 {
+               return 0, &ParseError{"Invalid MTU", s}
+       }
+       return uint16(m), nil
+}
+
+func parsePort(s string) (uint16, error) {
+       m, err := strconv.Atoi(s)
+       if err != nil {
+               return 0, err
+       }
+       if m < 0 || m > 65535 {
+               return 0, &ParseError{"Invalid port", s}
+       }
+       return uint16(m), nil
+}
+
+func parsePersistentKeepalive(s string) (uint16, error) {
+       if s == "off" {
+               return 0, nil
+       }
+       m, err := strconv.Atoi(s)
+       if err != nil {
+               return 0, err
+       }
+       if m < 0 || m > 65535 {
+               return 0, &ParseError{"Invalid persistent keepalive", s}
+       }
+       return uint16(m), nil
+}
+
+func parseKeyHex(s string) (*Key, error) {
+       k, err := hex.DecodeString(s)
+       if err != nil {
+               return nil, &ParseError{"Invalid key: " + err.Error(), s}
+       }
+       if len(k) != KeySize {
+               return nil, &ParseError{"Keys must decode to exactly 32 bytes", s}
+       }
+       var key Key
+       copy(key[:], k)
+       return &key, nil
+}
+
+func parseBytesOrStamp(s string) (uint64, error) {
+       b, err := strconv.ParseUint(s, 10, 64)
+       if err != nil {
+               return 0, &ParseError{"Number must be a number between 0 and 2^64-1: " + err.Error(), s}
+       }
+       return b, nil
+}
+
+func splitList(s string) ([]string, error) {
+       var out []string
+       for _, split := range strings.Split(s, ",") {
+               trim := strings.TrimSpace(split)
+               if len(trim) == 0 {
+                       return nil, &ParseError{"Two commas in a row", s}
+               }
+               out = append(out, trim)
+       }
+       return out, nil
+}
+
+type parserState int
+
+const (
+       inInterfaceSection parserState = iota
+       inPeerSection
+       notInASection
+)
+
+func (c *Config) maybeAddPeer(p *Peer) {
+       if p != nil {
+               c.Peers = append(c.Peers, *p)
+       }
+}
+
+func FromWgQuick(s string, name string) (*Config, error) {
+       if !TunnelNameIsValid(name) {
+               return nil, &ParseError{"Tunnel name is not valid", name}
+       }
+       lines := strings.Split(s, "\n")
+       parserState := notInASection
+       conf := Config{Name: name}
+       sawPrivateKey := false
+       var peer *Peer
+       for _, line := range lines {
+               pound := strings.IndexByte(line, '#')
+               if pound >= 0 {
+                       line = line[:pound]
+               }
+               line = strings.TrimSpace(line)
+               lineLower := strings.ToLower(line)
+               if len(line) == 0 {
+                       continue
+               }
+               if lineLower == "[interface]" {
+                       conf.maybeAddPeer(peer)
+                       parserState = inInterfaceSection
+                       continue
+               }
+               if lineLower == "[peer]" {
+                       conf.maybeAddPeer(peer)
+                       peer = &Peer{}
+                       parserState = inPeerSection
+                       continue
+               }
+               if parserState == notInASection {
+                       return nil, &ParseError{"Line must occur in a section", line}
+               }
+               equals := strings.IndexByte(line, '=')
+               if equals < 0 {
+                       return nil, &ParseError{"Invalid config key is missing an equals separator", line}
+               }
+               key, val := strings.TrimSpace(lineLower[:equals]), strings.TrimSpace(line[equals+1:])
+               if len(val) == 0 {
+                       return nil, &ParseError{"Key must have a value", line}
+               }
+               if parserState == inInterfaceSection {
+                       switch key {
+                       case "privatekey":
+                               k, err := ParseKey(val)
+                               if err != nil {
+                                       return nil, err
+                               }
+                               conf.PrivateKey = PrivateKey(*k)
+                               sawPrivateKey = true
+                       case "listenport":
+                               p, err := parsePort(val)
+                               if err != nil {
+                                       return nil, err
+                               }
+                               conf.ListenPort = p
+                       case "mtu":
+                               m, err := parseMTU(val)
+                               if err != nil {
+                                       return nil, err
+                               }
+                               conf.MTU = m
+                       case "address":
+                               addresses, err := splitList(val)
+                               if err != nil {
+                                       return nil, err
+                               }
+                               for _, address := range addresses {
+                                       a, err := ParseCIDR(address)
+                                       if err != nil {
+                                               return nil, err
+                                       }
+                                       conf.Addresses = append(conf.Addresses, *a)
+                               }
+                       case "dns":
+                               addresses, err := splitList(val)
+                               if err != nil {
+                                       return nil, err
+                               }
+                               for _, address := range addresses {
+                                       a := ParseIP(address)
+                                       if a == nil {
+                                               return nil, &ParseError{"Invalid IP address", address}
+                                       }
+                                       conf.DNS = append(conf.DNS, *a)
+                               }
+                       default:
+                               return nil, &ParseError{"Invalid key for [Interface] section", key}
+                       }
+               } else if parserState == inPeerSection {
+                       switch key {
+                       case "publickey":
+                               k, err := ParseKey(val)
+                               if err != nil {
+                                       return nil, err
+                               }
+                               peer.PublicKey = *k
+                       case "presharedkey":
+                               k, err := ParseKey(val)
+                               if err != nil {
+                                       return nil, err
+                               }
+                               peer.PresharedKey = SymmetricKey(*k)
+                       case "allowedips":
+                               addresses, err := splitList(val)
+                               if err != nil {
+                                       return nil, err
+                               }
+                               for _, address := range addresses {
+                                       a, err := ParseCIDR(address)
+                                       if err != nil {
+                                               return nil, err
+                                       }
+                                       peer.AllowedIPs = append(peer.AllowedIPs, *a)
+                               }
+                       case "persistentkeepalive":
+                               p, err := parsePersistentKeepalive(val)
+                               if err != nil {
+                                       return nil, err
+                               }
+                               peer.PersistentKeepalive = p
+                       case "endpoint":
+                               eps, err := parseEndpoints(val)
+                               if err != nil {
+                                       return nil, err
+                               }
+                               peer.Endpoints = eps
+                       default:
+                               return nil, &ParseError{"Invalid key for [Peer] section", key}
+                       }
+               }
+       }
+       conf.maybeAddPeer(peer)
+
+       if !sawPrivateKey {
+               return nil, &ParseError{"An interface must have a private key", "[none specified]"}
+       }
+       for _, p := range conf.Peers {
+               if p.PublicKey.IsZero() {
+                       return nil, &ParseError{"All peers must have public keys", "[none specified]"}
+               }
+       }
+
+       return &conf, nil
+}
+
+// TODO(apenwarr): This is incompatibe with current Device.IpcSetOperation.
+//  It duplicates all the parser stuff in there, but is missing some
+//  keywords. Nothing useful seems to need it anymore.
+func Broken_FromUAPI(s string, existingConfig *Config) (*Config, error) {
+       lines := strings.Split(s, "\n")
+       parserState := inInterfaceSection
+       conf := Config{
+               Name:      existingConfig.Name,
+               Addresses: existingConfig.Addresses,
+               DNS:       existingConfig.DNS,
+               MTU:       existingConfig.MTU,
+       }
+       var peer *Peer
+       for _, line := range lines {
+               if len(line) == 0 {
+                       continue
+               }
+               equals := strings.IndexByte(line, '=')
+               if equals < 0 {
+                       return nil, &ParseError{"Invalid config key is missing an equals separator", line}
+               }
+               key, val := line[:equals], line[equals+1:]
+               if len(val) == 0 {
+                       return nil, &ParseError{"Key must have a value", line}
+               }
+               switch key {
+               case "public_key":
+                       conf.maybeAddPeer(peer)
+                       peer = &Peer{}
+                       parserState = inPeerSection
+               case "errno":
+                       if val == "0" {
+                               continue
+                       } else {
+                               return nil, &ParseError{"Error in getting configuration", val}
+                       }
+               }
+               if parserState == inInterfaceSection {
+                       switch key {
+                       case "private_key":
+                               k, err := parseKeyHex(val)
+                               if err != nil {
+                                       return nil, err
+                               }
+                               conf.PrivateKey = PrivateKey(*k)
+                       case "listen_port":
+                               p, err := parsePort(val)
+                               if err != nil {
+                                       return nil, err
+                               }
+                               conf.ListenPort = p
+                       case "fwmark":
+                               // Ignored for now.
+
+                       default:
+                               return nil, &ParseError{"Invalid key for interface section", key}
+                       }
+               } else if parserState == inPeerSection {
+                       switch key {
+                       case "public_key":
+                               k, err := parseKeyHex(val)
+                               if err != nil {
+                                       return nil, err
+                               }
+                               peer.PublicKey = *k
+                       case "preshared_key":
+                               k, err := parseKeyHex(val)
+                               if err != nil {
+                                       return nil, err
+                               }
+                               peer.PresharedKey = SymmetricKey(*k)
+                       case "protocol_version":
+                               if val != "1" {
+                                       return nil, &ParseError{"Protocol version must be 1", val}
+                               }
+                       case "allowed_ip":
+                               a, err := ParseCIDR(val)
+                               if err != nil {
+                                       return nil, err
+                               }
+                               peer.AllowedIPs = append(peer.AllowedIPs, *a)
+                       case "persistent_keepalive_interval":
+                               p, err := parsePersistentKeepalive(val)
+                               if err != nil {
+                                       return nil, err
+                               }
+                               peer.PersistentKeepalive = p
+                       case "endpoint":
+                               eps, err := parseEndpoints(val)
+                               if err != nil {
+                                       return nil, err
+                               }
+                               peer.Endpoints = eps
+                       default:
+                               return nil, &ParseError{"Invalid key for peer section", key}
+                       }
+               }
+       }
+       conf.maybeAddPeer(peer)
+
+       return &conf, nil
+}
diff --git a/wgcfg/parser_test.go b/wgcfg/parser_test.go
new file mode 100644 (file)
index 0000000..d0df537
--- /dev/null
@@ -0,0 +1,127 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package wgcfg
+
+import (
+       "reflect"
+       "runtime"
+       "testing"
+)
+
+const testInput = `
+[Interface] 
+Address = 10.192.122.1/24 
+Address = 10.10.0.1/16 
+PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk= 
+ListenPort = 51820  #comments don't matter
+
+[Peer] 
+PublicKey   =   xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg=    
+Endpoint = 192.95.5.67:1234 
+AllowedIPs = 10.192.122.3/32, 10.192.124.1/24
+
+[Peer] 
+PublicKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0= 
+Endpoint = [2607:5300:60:6b0::c05f:543]:2468 
+AllowedIPs = 10.192.122.4/32, 192.168.0.0/16
+PersistentKeepalive = 100
+
+[Peer] 
+PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA= 
+PresharedKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0= 
+Endpoint = test.wireguard.com:18981 
+AllowedIPs = 10.10.10.230/32`
+
+func noError(t *testing.T, err error) bool {
+       if err == nil {
+               return true
+       }
+       _, fn, line, _ := runtime.Caller(1)
+       t.Errorf("Error at %s:%d: %#v", fn, line, err)
+       return false
+}
+
+func equal(t *testing.T, expected, actual interface{}) bool {
+       if reflect.DeepEqual(expected, actual) {
+               return true
+       }
+       _, fn, line, _ := runtime.Caller(1)
+       t.Errorf("Failed equals at %s:%d\nactual   %#v\nexpected %#v", fn, line, actual, expected)
+       return false
+}
+func lenTest(t *testing.T, actualO interface{}, expected int) bool {
+       actual := reflect.ValueOf(actualO).Len()
+       if reflect.DeepEqual(expected, actual) {
+               return true
+       }
+       _, fn, line, _ := runtime.Caller(1)
+       t.Errorf("Wrong length at %s:%d\nactual   %#v\nexpected %#v", fn, line, actual, expected)
+       return false
+}
+func contains(t *testing.T, list, element interface{}) bool {
+       listValue := reflect.ValueOf(list)
+       for i := 0; i < listValue.Len(); i++ {
+               if reflect.DeepEqual(listValue.Index(i).Interface(), element) {
+                       return true
+               }
+       }
+       _, fn, line, _ := runtime.Caller(1)
+       t.Errorf("Error %s:%d\nelement not found: %#v", fn, line, element)
+       return false
+}
+
+func TestFromWgQuick(t *testing.T) {
+       conf, err := FromWgQuick(testInput, "test")
+       if noError(t, err) {
+
+               lenTest(t, conf.Addresses, 2)
+               contains(t, conf.Addresses, CIDR{IPv4(10, 10, 0, 1), 16})
+               contains(t, conf.Addresses, CIDR{IPv4(10, 192, 122, 1), 24})
+               equal(t, "yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=", conf.PrivateKey.String())
+               equal(t, uint16(51820), conf.ListenPort)
+
+               lenTest(t, conf.Peers, 3)
+               lenTest(t, conf.Peers[0].AllowedIPs, 2)
+               equal(t, Endpoint{Host: "192.95.5.67", Port: 1234}, conf.Peers[0].Endpoints[0])
+               equal(t, "xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg=", conf.Peers[0].PublicKey.Base64())
+
+               lenTest(t, conf.Peers[1].AllowedIPs, 2)
+               equal(t, Endpoint{Host: "2607:5300:60:6b0::c05f:543", Port: 2468}, conf.Peers[1].Endpoints[0])
+               equal(t, "TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=", conf.Peers[1].PublicKey.Base64())
+               equal(t, uint16(100), conf.Peers[1].PersistentKeepalive)
+
+               lenTest(t, conf.Peers[2].AllowedIPs, 1)
+               equal(t, Endpoint{Host: "test.wireguard.com", Port: 18981}, conf.Peers[2].Endpoints[0])
+               equal(t, "gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA=", conf.Peers[2].PublicKey.Base64())
+               equal(t, "TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=", conf.Peers[2].PresharedKey.Base64())
+       }
+}
+
+func TestParseEndpoint(t *testing.T) {
+       _, err := parseEndpoint("[192.168.42.0:]:51880")
+       if err == nil {
+               t.Error("Error was expected")
+       }
+       e, err := parseEndpoint("192.168.42.0:51880")
+       if noError(t, err) {
+               equal(t, "192.168.42.0", e.Host)
+               equal(t, uint16(51880), e.Port)
+       }
+       e, err = parseEndpoint("test.wireguard.com:18981")
+       if noError(t, err) {
+               equal(t, "test.wireguard.com", e.Host)
+               equal(t, uint16(18981), e.Port)
+       }
+       e, err = parseEndpoint("[2607:5300:60:6b0::c05f:543]:2468")
+       if noError(t, err) {
+               equal(t, "2607:5300:60:6b0::c05f:543", e.Host)
+               equal(t, uint16(2468), e.Port)
+       }
+       _, err = parseEndpoint("[::::::invalid:18981")
+       if err == nil {
+               t.Error("Error was expected")
+       }
+}
diff --git a/wgcfg/writer.go b/wgcfg/writer.go
new file mode 100644 (file)
index 0000000..aafb2a7
--- /dev/null
@@ -0,0 +1,75 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package wgcfg
+
+import (
+       "errors"
+       "fmt"
+       "net"
+       "strings"
+)
+
+func (conf *Config) ToUAPI() (string, error) {
+       output := new(strings.Builder)
+       fmt.Fprintf(output, "private_key=%s\n", conf.PrivateKey.HexString())
+
+       if conf.ListenPort > 0 {
+               fmt.Fprintf(output, "listen_port=%d\n", conf.ListenPort)
+       }
+
+       output.WriteString("replace_peers=true\n")
+
+       for _, peer := range conf.Peers {
+               fmt.Fprintf(output, "public_key=%s\n", peer.PublicKey.HexString())
+               fmt.Fprintf(output, "protocol_version=1\n")
+               fmt.Fprintf(output, "replace_allowed_ips=true\n")
+
+               if !peer.PresharedKey.IsZero() {
+                       fmt.Fprintf(output, "preshared_key = %s\n", peer.PresharedKey.String())
+               }
+
+               if len(peer.AllowedIPs) > 0 {
+                       for _, address := range peer.AllowedIPs {
+                               fmt.Fprintf(output, "allowed_ip=%s\n", address.String())
+                       }
+               }
+
+               if len(peer.Endpoints) > 0 {
+                       var reps []string
+                       for _, ep := range peer.Endpoints {
+                               ips, err := net.LookupIP(ep.Host)
+                               if err != nil {
+                                       return "", err
+                               }
+                               var ip net.IP
+                               for _, iterip := range ips {
+                                       iterip = iterip.To4()
+                                       if iterip != nil {
+                                               ip = iterip
+                                               break
+                                       }
+                                       if ip == nil {
+                                               ip = iterip
+                                       }
+                               }
+                               if ip == nil {
+                                       return "", errors.New("Unable to resolve IP address of endpoint")
+                               }
+                               resolvedEndpoint := Endpoint{ip.String(), ep.Port}
+                               reps = append(reps, resolvedEndpoint.String())
+                       }
+                       fmt.Fprintf(output, "endpoint=%s\n", strings.Join(reps, ","))
+               } else {
+                       fmt.Fprint(output, "endpoint=\n")
+               }
+
+               // Note: this needs to come *after* endpoint definitions,
+               // because setting it will trigger a handshake to all
+               // already-defined endpoints.
+               fmt.Fprintf(output, "persistent_keepalive_interval=%d\n", peer.PersistentKeepalive)
+       }
+       return output.String(), nil
+}