]> git.ipfire.org Git - thirdparty/kernel/linux.git/commitdiff
wifi: mt76: mt7996: fix key add/remove imbalance
authorFelix Fietkau <nbd@nbd.name>
Mon, 15 Sep 2025 07:59:02 +0000 (09:59 +0200)
committerFelix Fietkau <nbd@nbd.name>
Mon, 15 Sep 2025 11:23:01 +0000 (13:23 +0200)
Ensure that a key for a link is only added and removed once.
When bringing up a link, only upload keys for that particular link, instead
of iterating over all of them.

Link: https://patch.msgid.link/20250915075910.47558-7-nbd@nbd.name
Signed-off-by: Felix Fietkau <nbd@nbd.name>
drivers/net/wireless/mediatek/mt76/mt7996/main.c

index a81f2133cdc9eeed8ab3c9dcf06bbe46f7d9e532..d706b8bb244e21c0705b705732cb6cca2b844e67 100644 (file)
@@ -182,107 +182,96 @@ mt7996_init_bitrate_mask(struct ieee80211_vif *vif, struct mt7996_vif_link *mlin
 static int
 mt7996_set_hw_key(struct ieee80211_hw *hw, enum set_key_cmd cmd,
                  struct ieee80211_vif *vif, struct ieee80211_sta *sta,
-                 struct ieee80211_key_conf *key)
+                 unsigned int link_id, struct ieee80211_key_conf *key)
 {
        struct mt7996_dev *dev = mt7996_hw_dev(hw);
+       struct mt7996_sta_link *msta_link;
+       struct mt7996_vif_link *link;
        int idx = key->keyidx;
-       unsigned int link_id;
-       unsigned long links;
+       u8 *wcid_keyidx;
 
-       if (key->link_id >= 0)
-               links = BIT(key->link_id);
-       else if (sta && sta->valid_links)
-               links = sta->valid_links;
-       else if (vif->valid_links)
-               links = vif->valid_links;
-       else
-               links = BIT(0);
+       link = mt7996_vif_link(dev, vif, link_id);
+       if (!link)
+               return 0;
 
-       for_each_set_bit(link_id, &links, IEEE80211_MLD_MAX_NUM_LINKS) {
-               struct mt7996_sta_link *msta_link;
-               struct mt7996_vif_link *link;
-               u8 *wcid_keyidx;
-               int err;
+       if (!mt7996_vif_link_phy(link))
+               return 0;
 
-               link = mt7996_vif_link(dev, vif, link_id);
-               if (!link)
-                       continue;
+       if (sta) {
+               struct mt7996_sta *msta;
 
-               if (sta) {
-                       struct mt7996_sta *msta;
+               msta = (struct mt7996_sta *)sta->drv_priv;
+               msta_link = mt76_dereference(msta->link[link_id],
+                                            &dev->mt76);
+               if (!msta_link)
+                       return 0;
 
-                       msta = (struct mt7996_sta *)sta->drv_priv;
-                       msta_link = mt76_dereference(msta->link[link_id],
-                                                    &dev->mt76);
-                       if (!msta_link)
-                               continue;
+               if (!msta_link->wcid.sta)
+                       return -EOPNOTSUPP;
+       } else {
+               msta_link = &link->msta_link;
+       }
+       wcid_keyidx = &msta_link->wcid.hw_key_idx;
 
-                       if (!msta_link->wcid.sta)
-                               return -EOPNOTSUPP;
-               } else {
-                       msta_link = &link->msta_link;
-               }
-               wcid_keyidx = &msta_link->wcid.hw_key_idx;
-
-               switch (key->cipher) {
-               case WLAN_CIPHER_SUITE_AES_CMAC:
-               case WLAN_CIPHER_SUITE_BIP_CMAC_256:
-               case WLAN_CIPHER_SUITE_BIP_GMAC_128:
-               case WLAN_CIPHER_SUITE_BIP_GMAC_256:
-                       if (key->keyidx == 6 || key->keyidx == 7) {
-                               wcid_keyidx = &msta_link->wcid.hw_key_idx2;
-                               key->flags |= IEEE80211_KEY_FLAG_GENERATE_MMIE;
-                       }
-                       break;
-               default:
-                       break;
+       switch (key->cipher) {
+       case WLAN_CIPHER_SUITE_AES_CMAC:
+       case WLAN_CIPHER_SUITE_BIP_CMAC_256:
+       case WLAN_CIPHER_SUITE_BIP_GMAC_128:
+       case WLAN_CIPHER_SUITE_BIP_GMAC_256:
+               if (key->keyidx == 6 || key->keyidx == 7) {
+                       wcid_keyidx = &msta_link->wcid.hw_key_idx2;
+                       key->flags |= IEEE80211_KEY_FLAG_GENERATE_MMIE;
                }
+               break;
+       default:
+               break;
+       }
 
-               if (cmd == SET_KEY && !sta && !link->mt76.cipher) {
-                       struct ieee80211_bss_conf *link_conf;
-
-                       link_conf = link_conf_dereference_protected(vif,
-                                                                   link_id);
-                       if (!link_conf)
-                               link_conf = &vif->bss_conf;
+       if (cmd == SET_KEY && !sta && !link->mt76.cipher) {
+               struct ieee80211_bss_conf *link_conf;
 
-                       link->mt76.cipher =
-                               mt76_connac_mcu_get_cipher(key->cipher);
-                       mt7996_mcu_add_bss_info(link->phy, vif, link_conf,
-                                               &link->mt76, msta_link, true);
-               }
+               link_conf = link_conf_dereference_protected(vif,
+                                                           link_id);
+               if (!link_conf)
+                       link_conf = &vif->bss_conf;
 
-               if (cmd == SET_KEY)
-                       *wcid_keyidx = idx;
-               else if (idx == *wcid_keyidx)
-                       *wcid_keyidx = -1;
+               link->mt76.cipher =
+                       mt76_connac_mcu_get_cipher(key->cipher);
+               mt7996_mcu_add_bss_info(link->phy, vif, link_conf,
+                                       &link->mt76, msta_link, true);
+       }
 
-               if (cmd != SET_KEY && sta)
-                       continue;
+       if (cmd == SET_KEY)
+               *wcid_keyidx = idx;
+       else if (idx == *wcid_keyidx)
+               *wcid_keyidx = -1;
 
-               mt76_wcid_key_setup(&dev->mt76, &msta_link->wcid, key);
+       if (cmd != SET_KEY && sta)
+               return 0;
 
-               err = mt7996_mcu_add_key(&dev->mt76, vif, key,
-                                        MCU_WMWA_UNI_CMD(STA_REC_UPDATE),
-                                        &msta_link->wcid, cmd);
-               if (err)
-                       return err;
-       }
+       mt76_wcid_key_setup(&dev->mt76, &msta_link->wcid, key);
 
-       return 0;
+       return mt7996_mcu_add_key(&dev->mt76, vif, key,
+                                 MCU_WMWA_UNI_CMD(STA_REC_UPDATE),
+                                 &msta_link->wcid, cmd);
 }
 
+struct mt7996_key_iter_data {
+    enum set_key_cmd cmd;
+    unsigned int link_id;
+};
+
 static void
 mt7996_key_iter(struct ieee80211_hw *hw, struct ieee80211_vif *vif,
                struct ieee80211_sta *sta, struct ieee80211_key_conf *key,
                void *data)
 {
-       enum set_key_cmd *cmd = data;
+       struct mt7996_key_iter_data *it = data;
 
        if (sta)
                return;
 
-       WARN_ON(mt7996_set_hw_key(hw, *cmd, vif, NULL, key));
+       WARN_ON(mt7996_set_hw_key(hw, it->cmd, vif, NULL, it->link_id, key));
 }
 
 int mt7996_vif_link_add(struct mt76_phy *mphy, struct ieee80211_vif *vif,
@@ -293,9 +282,12 @@ int mt7996_vif_link_add(struct mt76_phy *mphy, struct ieee80211_vif *vif,
        struct mt7996_vif *mvif = (struct mt7996_vif *)vif->drv_priv;
        struct mt7996_sta_link *msta_link = &link->msta_link;
        struct mt7996_phy *phy = mphy->priv;
-       enum set_key_cmd key_cmd = SET_KEY;
        struct mt7996_dev *dev = phy->dev;
        u8 band_idx = phy->mt76->band_idx;
+       struct mt7996_key_iter_data it = {
+               .cmd = SET_KEY,
+               .link_id = link_conf->link_id,
+       };
        struct mt76_txq *mtxq;
        int mld_idx, idx, ret;
 
@@ -373,7 +365,7 @@ int mt7996_vif_link_add(struct mt76_phy *mphy, struct ieee80211_vif *vif,
                                   CONN_STATE_PORT_SECURE, true);
        rcu_assign_pointer(dev->mt76.wcid[idx], &msta_link->wcid);
 
-       ieee80211_iter_keys(mphy->hw, vif, mt7996_key_iter, &key_cmd);
+       ieee80211_iter_keys(mphy->hw, vif, mt7996_key_iter, &it);
 
        if (mvif->mt76.deflink_id == IEEE80211_LINK_UNSPECIFIED)
                mvif->mt76.deflink_id = link_conf->link_id;
@@ -388,12 +380,15 @@ void mt7996_vif_link_remove(struct mt76_phy *mphy, struct ieee80211_vif *vif,
        struct mt7996_vif_link *link = container_of(mlink, struct mt7996_vif_link, mt76);
        struct mt7996_vif *mvif = (struct mt7996_vif *)vif->drv_priv;
        struct mt7996_sta_link *msta_link = &link->msta_link;
-       enum set_key_cmd key_cmd = DISABLE_KEY;
        struct mt7996_phy *phy = mphy->priv;
        struct mt7996_dev *dev = phy->dev;
+       struct mt7996_key_iter_data it = {
+               .cmd = SET_KEY,
+               .link_id = link_conf->link_id,
+       };
        int idx = msta_link->wcid.idx;
 
-       ieee80211_iter_keys(mphy->hw, vif, mt7996_key_iter, &key_cmd);
+       ieee80211_iter_keys(mphy->hw, vif, mt7996_key_iter, &it);
 
        mt7996_mcu_add_sta(dev, link_conf, NULL, link, NULL,
                           CONN_STATE_DISCONNECT, false);
@@ -594,8 +589,9 @@ static int mt7996_set_key(struct ieee80211_hw *hw, enum set_key_cmd cmd,
                          struct ieee80211_key_conf *key)
 {
        struct mt7996_dev *dev = mt7996_hw_dev(hw);
-       struct mt7996_vif *mvif = (struct mt7996_vif *)vif->drv_priv;
-       int err;
+       unsigned int link_id;
+       unsigned long links;
+       int err = 0;
 
        /* The hardware does not support per-STA RX GTK, fallback
         * to software mode for these.
@@ -629,11 +625,22 @@ static int mt7996_set_key(struct ieee80211_hw *hw, enum set_key_cmd cmd,
                return -EOPNOTSUPP;
        }
 
-       if (!mt7996_vif_link_phy(&mvif->deflink))
-               return 0; /* defer until after link add */
-
        mutex_lock(&dev->mt76.mutex);
-       err = mt7996_set_hw_key(hw, cmd, vif, sta, key);
+
+       if (key->link_id >= 0)
+               links = BIT(key->link_id);
+       else if (sta && sta->valid_links)
+               links = sta->valid_links;
+       else if (vif->valid_links)
+               links = vif->valid_links;
+       else
+               links = BIT(0);
+
+       for_each_set_bit(link_id, &links, IEEE80211_MLD_MAX_NUM_LINKS) {
+               err = mt7996_set_hw_key(hw, cmd, vif, sta, link_id, key);
+               if (err)
+                       break;
+       }
        mutex_unlock(&dev->mt76.mutex);
 
        return err;