]> git.ipfire.org Git - thirdparty/kernel/linux.git/commitdiff
net: mctp: separate key correlation across nets
authorJeremy Kerr <jk@codeconstruct.com.au>
Mon, 19 Feb 2024 09:51:50 +0000 (17:51 +0800)
committerPaolo Abeni <pabeni@redhat.com>
Thu, 22 Feb 2024 12:32:55 +0000 (13:32 +0100)
Currently, we lookup sk_keys from the entire struct net_namespace, which
may contain multiple MCTP net IDs. In those cases we want to distinguish
between endpoints with the same EID but different net ID.

Add the net ID data to the struct mctp_sk_key, populate on add and
filter on this during route lookup.

For the ioctl interface, we use a default net of
MCTP_INITIAL_DEFAULT_NET (ie., what will be in use for single-net
configurations), but we'll extend the ioctl interface to provide
net-specific tag allocation in an upcoming change.

Signed-off-by: Jeremy Kerr <jk@codeconstruct.com.au>
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
include/net/mctp.h
net/mctp/af_mctp.c
net/mctp/route.c
net/mctp/test/route-test.c

index f937a325ea6f2c809b4d6bdf514d90ffd1cc08cb..0dfae6f51a322c5dcb519131aa17c2ebc5c16519 100644 (file)
@@ -133,6 +133,7 @@ struct mctp_sock {
  *    - through an expiry timeout, on a per-socket timer
  */
 struct mctp_sk_key {
+       unsigned int    net;
        mctp_eid_t      peer_addr;
        mctp_eid_t      local_addr; /* MCTP_ADDR_ANY for local owned tags */
        __u8            tag; /* incoming tag match; invert TO for local */
@@ -254,6 +255,7 @@ int mctp_local_output(struct sock *sk, struct mctp_route *rt,
 
 void mctp_key_unref(struct mctp_sk_key *key);
 struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
+                                        unsigned int netid,
                                         mctp_eid_t local, mctp_eid_t peer,
                                         bool manual, u8 *tagp);
 
index d8197e9e233bd06a4fe18e41a7277a503c3bc308..05315a422ffb33f1a1c5ec032931fb6799f392c2 100644 (file)
@@ -367,8 +367,8 @@ static int mctp_ioctl_alloctag(struct mctp_sock *msk, unsigned long arg)
        if (ctl.flags)
                return -EINVAL;
 
-       key = mctp_alloc_local_tag(msk, MCTP_ADDR_ANY, ctl.peer_addr,
-                                  true, &tag);
+       key = mctp_alloc_local_tag(msk, MCTP_INITIAL_DEFAULT_NET,
+                                  MCTP_ADDR_ANY, ctl.peer_addr, true, &tag);
        if (IS_ERR(key))
                return PTR_ERR(key);
 
index 0c2ed75a9e28275aed06f2519ef6d59d8464dabd..28648a7ec866664bfe35a28c51d34026bb6201cf 100644 (file)
@@ -107,9 +107,12 @@ static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb)
  * and peer addresses, or either being ANY.
  */
 
-static bool mctp_key_match(struct mctp_sk_key *key, mctp_eid_t local,
-                          mctp_eid_t peer, u8 tag)
+static bool mctp_key_match(struct mctp_sk_key *key, unsigned int net,
+                          mctp_eid_t local, mctp_eid_t peer, u8 tag)
 {
+       if (key->net != net)
+               return false;
+
        if (!mctp_address_matches(key->local_addr, local))
                return false;
 
@@ -126,7 +129,7 @@ static bool mctp_key_match(struct mctp_sk_key *key, mctp_eid_t local,
  * key exists.
  */
 static struct mctp_sk_key *mctp_lookup_key(struct net *net, struct sk_buff *skb,
-                                          mctp_eid_t peer,
+                                          unsigned int netid, mctp_eid_t peer,
                                           unsigned long *irqflags)
        __acquires(&key->lock)
 {
@@ -142,7 +145,7 @@ static struct mctp_sk_key *mctp_lookup_key(struct net *net, struct sk_buff *skb,
        spin_lock_irqsave(&net->mctp.keys_lock, flags);
 
        hlist_for_each_entry(key, &net->mctp.keys, hlist) {
-               if (!mctp_key_match(key, mh->dest, peer, tag))
+               if (!mctp_key_match(key, netid, mh->dest, peer, tag))
                        continue;
 
                spin_lock(&key->lock);
@@ -165,6 +168,7 @@ static struct mctp_sk_key *mctp_lookup_key(struct net *net, struct sk_buff *skb,
 }
 
 static struct mctp_sk_key *mctp_key_alloc(struct mctp_sock *msk,
+                                         unsigned int net,
                                          mctp_eid_t local, mctp_eid_t peer,
                                          u8 tag, gfp_t gfp)
 {
@@ -174,6 +178,7 @@ static struct mctp_sk_key *mctp_key_alloc(struct mctp_sock *msk,
        if (!key)
                return NULL;
 
+       key->net = net;
        key->peer_addr = peer;
        key->local_addr = local;
        key->tag = tag;
@@ -219,8 +224,8 @@ static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk)
        }
 
        hlist_for_each_entry(tmp, &net->mctp.keys, hlist) {
-               if (mctp_key_match(tmp, key->local_addr, key->peer_addr,
-                                  key->tag)) {
+               if (mctp_key_match(tmp, key->net, key->local_addr,
+                                  key->peer_addr, key->tag)) {
                        spin_lock(&tmp->lock);
                        if (tmp->valid)
                                rc = -EEXIST;
@@ -361,6 +366,7 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
        struct net *net = dev_net(skb->dev);
        struct mctp_sock *msk;
        struct mctp_hdr *mh;
+       unsigned int netid;
        unsigned long f;
        u8 tag, flags;
        int rc;
@@ -379,6 +385,7 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
 
        /* grab header, advance data ptr */
        mh = mctp_hdr(skb);
+       netid = mctp_cb(skb)->net;
        skb_pull(skb, sizeof(struct mctp_hdr));
 
        if (mh->ver != 1)
@@ -392,7 +399,7 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
        /* lookup socket / reasm context, exactly matching (src,dest,tag).
         * we hold a ref on the key, and key->lock held.
         */
-       key = mctp_lookup_key(net, skb, mh->src, &f);
+       key = mctp_lookup_key(net, skb, netid, mh->src, &f);
 
        if (flags & MCTP_HDR_FLAG_SOM) {
                if (key) {
@@ -406,7 +413,8 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
                         * this lookup requires key->peer to be MCTP_ADDR_ANY,
                         * it doesn't match just any key->peer.
                         */
-                       any_key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY, &f);
+                       any_key = mctp_lookup_key(net, skb, netid,
+                                                 MCTP_ADDR_ANY, &f);
                        if (any_key) {
                                msk = container_of(any_key->sk,
                                                   struct mctp_sock, sk);
@@ -443,7 +451,7 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
                 * packets for this message
                 */
                if (!key) {
-                       key = mctp_key_alloc(msk, mh->dest, mh->src,
+                       key = mctp_key_alloc(msk, netid, mh->dest, mh->src,
                                             tag, GFP_ATOMIC);
                        if (!key) {
                                rc = -ENOMEM;
@@ -637,6 +645,7 @@ static void mctp_reserve_tag(struct net *net, struct mctp_sk_key *key,
  * it for the socket msk
  */
 struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
+                                        unsigned int netid,
                                         mctp_eid_t local, mctp_eid_t peer,
                                         bool manual, u8 *tagp)
 {
@@ -651,7 +660,7 @@ struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
                peer = MCTP_ADDR_ANY;
 
        /* be optimistic, alloc now */
-       key = mctp_key_alloc(msk, local, peer, 0, GFP_KERNEL);
+       key = mctp_key_alloc(msk, netid, local, peer, 0, GFP_KERNEL);
        if (!key)
                return ERR_PTR(-ENOMEM);
 
@@ -668,6 +677,10 @@ struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
                 * lock held, they don't change over the lifetime of the key.
                 */
 
+               /* tags are net-specific */
+               if (tmp->net != netid)
+                       continue;
+
                /* if we don't own the tag, it can't conflict */
                if (tmp->tag & MCTP_HDR_FLAG_TO)
                        continue;
@@ -716,6 +729,7 @@ struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
 }
 
 static struct mctp_sk_key *mctp_lookup_prealloc_tag(struct mctp_sock *msk,
+                                                   unsigned int netid,
                                                    mctp_eid_t daddr,
                                                    u8 req_tag, u8 *tagp)
 {
@@ -730,6 +744,9 @@ static struct mctp_sk_key *mctp_lookup_prealloc_tag(struct mctp_sock *msk,
        spin_lock_irqsave(&mns->keys_lock, flags);
 
        hlist_for_each_entry(tmp, &mns->keys, hlist) {
+               if (tmp->net != netid)
+                       continue;
+
                if (tmp->tag != req_tag)
                        continue;
 
@@ -910,6 +927,7 @@ int mctp_local_output(struct sock *sk, struct mctp_route *rt,
        struct mctp_sk_key *key;
        struct mctp_hdr *hdr;
        unsigned long flags;
+       unsigned int netid;
        unsigned int mtu;
        mctp_eid_t saddr;
        bool ext_rt;
@@ -960,16 +978,17 @@ int mctp_local_output(struct sock *sk, struct mctp_route *rt,
                rc = 0;
        }
        spin_unlock_irqrestore(&rt->dev->addrs_lock, flags);
+       netid = READ_ONCE(rt->dev->net);
 
        if (rc)
                goto out_release;
 
        if (req_tag & MCTP_TAG_OWNER) {
                if (req_tag & MCTP_TAG_PREALLOC)
-                       key = mctp_lookup_prealloc_tag(msk, daddr,
+                       key = mctp_lookup_prealloc_tag(msk, netid, daddr,
                                                       req_tag, &tag);
                else
-                       key = mctp_alloc_local_tag(msk, saddr, daddr,
+                       key = mctp_alloc_local_tag(msk, netid, saddr, daddr,
                                                   false, &tag);
 
                if (IS_ERR(key)) {
index 714e5ae47629f53b8c8864417ead167ab857f16a..b3dbd3600d916eda41852033b74bbd60ccf1c4a3 100644 (file)
@@ -552,6 +552,7 @@ static void mctp_test_route_input_sk_keys(struct kunit *test)
        struct mctp_sock *msk;
        struct socket *sock;
        unsigned long flags;
+       unsigned int net;
        int rc;
        u8 c;
 
@@ -559,6 +560,7 @@ static void mctp_test_route_input_sk_keys(struct kunit *test)
 
        dev = mctp_test_create_dev();
        KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev);
+       net = READ_ONCE(dev->mdev->net);
 
        rt = mctp_test_create_route(&init_net, dev->mdev, 8, 68);
        KUNIT_ASSERT_NOT_ERR_OR_NULL(test, rt);
@@ -570,8 +572,9 @@ static void mctp_test_route_input_sk_keys(struct kunit *test)
        mns = &sock_net(sock->sk)->mctp;
 
        /* set the incoming tag according to test params */
-       key = mctp_key_alloc(msk, params->key_local_addr, params->key_peer_addr,
-                            params->key_tag, GFP_KERNEL);
+       key = mctp_key_alloc(msk, net, params->key_local_addr,
+                            params->key_peer_addr, params->key_tag,
+                            GFP_KERNEL);
 
        KUNIT_ASSERT_NOT_ERR_OR_NULL(test, key);