]> git.ipfire.org Git - thirdparty/kernel/stable.git/commitdiff
wifi: mac80211: verify BSS membership selectors and basic rates
authorBenjamin Berg <benjamin.berg@intel.com>
Wed, 1 Jan 2025 05:05:36 +0000 (07:05 +0200)
committerJohannes Berg <johannes.berg@intel.com>
Mon, 13 Jan 2025 14:26:45 +0000 (15:26 +0100)
We should not attempt a connection if the BSS we are connecting to
requires support for a basic rate or other feature using the BSS
membership selector. Add a check verifying this.

Signed-off-by: Benjamin Berg <benjamin.berg@intel.com>
Reviewed-by: Johannes Berg <johannes.berg@intel.com>
Signed-off-by: Miri Korenblit <miriam.rachel.korenblit@intel.com>
Link: https://patch.msgid.link/20250101070249.e58a0f34c798.Ifeb3bfd7b157ffa2ccdb20ca1cba6cf068fd117d@changeid
Signed-off-by: Johannes Berg <johannes.berg@intel.com>
net/mac80211/ieee80211_i.h
net/mac80211/mlme.c

index a214335f9fccc501cce594a6bf69e71e1b1e9db2..69a82298c7cd9baf8ea6f0fbb986f7e11ec340df 100644 (file)
@@ -404,6 +404,8 @@ struct ieee80211_mgd_auth_data {
        int tries;
        u16 algorithm, expected_transaction;
 
+       unsigned long userspace_selectors[BITS_TO_LONGS(128)];
+
        u8 key[WLAN_KEY_LEN_WEP104];
        u8 key_len, key_idx;
        bool done, waiting;
@@ -444,6 +446,8 @@ struct ieee80211_mgd_assoc_data {
        const u8 *supp_rates;
        u8 supp_rates_len;
 
+       unsigned long userspace_selectors[BITS_TO_LONGS(128)];
+
        unsigned long timeout;
        int tries;
 
index e9df97fc448b7aa12bc3e1aee9eb6dbc2e3bda7e..6b885e97e720cb4c4be1d91baeab80636f6efa05 100644 (file)
@@ -590,6 +590,63 @@ ieee80211_verify_sta_eht_mcs_support(struct ieee80211_sub_if_data *sdata,
        return true;
 }
 
+static void ieee80211_get_rates(struct ieee80211_supported_band *sband,
+                               const u8 *supp_rates,
+                               unsigned int supp_rates_len,
+                               u32 *rates, u32 *basic_rates,
+                               unsigned long *unknown_rates_selectors,
+                               bool *have_higher_than_11mbit,
+                               int *min_rate, int *min_rate_index)
+{
+       int i, j;
+
+       for (i = 0; i < supp_rates_len; i++) {
+               int rate = supp_rates[i] & 0x7f;
+               bool is_basic = !!(supp_rates[i] & 0x80);
+
+               if ((rate * 5) > 110 && have_higher_than_11mbit)
+                       *have_higher_than_11mbit = true;
+
+               /*
+                * Skip membership selectors since they're not rates.
+                *
+                * Note: Even though the membership selector and the basic
+                *       rate flag share the same bit, they are not exactly
+                *       the same.
+                */
+               if (is_basic && rate >= BSS_MEMBERSHIP_SELECTOR_MIN) {
+                       if (unknown_rates_selectors)
+                               set_bit(rate, unknown_rates_selectors);
+                       continue;
+               }
+
+               for (j = 0; j < sband->n_bitrates; j++) {
+                       struct ieee80211_rate *br;
+                       int brate;
+
+                       br = &sband->bitrates[j];
+
+                       brate = DIV_ROUND_UP(br->bitrate, 5);
+                       if (brate == rate) {
+                               if (rates)
+                                       *rates |= BIT(j);
+                               if (is_basic && basic_rates)
+                                       *basic_rates |= BIT(j);
+                               if (min_rate && (rate * 5) < *min_rate) {
+                                       *min_rate = rate * 5;
+                                       if (min_rate_index)
+                                               *min_rate_index = j;
+                               }
+                               break;
+                       }
+               }
+
+               /* Handle an unknown entry as if it is an unknown selector */
+               if (is_basic && unknown_rates_selectors && j == sband->n_bitrates)
+                       set_bit(rate, unknown_rates_selectors);
+       }
+}
+
 static bool ieee80211_chandef_usable(struct ieee80211_sub_if_data *sdata,
                                     const struct cfg80211_chan_def *chandef,
                                     u32 prohibited_flags)
@@ -820,7 +877,8 @@ ieee80211_determine_chan_mode(struct ieee80211_sub_if_data *sdata,
                              struct ieee80211_conn_settings *conn,
                              struct cfg80211_bss *cbss, int link_id,
                              struct ieee80211_chan_req *chanreq,
-                             struct cfg80211_chan_def *ap_chandef)
+                             struct cfg80211_chan_def *ap_chandef,
+                             unsigned long *userspace_selectors)
 {
        const struct cfg80211_bss_ies *ies = rcu_dereference(cbss->ies);
        struct ieee80211_bss *bss = (void *)cbss->priv;
@@ -834,6 +892,8 @@ ieee80211_determine_chan_mode(struct ieee80211_sub_if_data *sdata,
        struct ieee802_11_elems *elems;
        struct ieee80211_supported_band *sband;
        enum ieee80211_conn_mode ap_mode;
+       unsigned long unknown_rates_selectors[BITS_TO_LONGS(128)] = {};
+       unsigned long sta_selectors[BITS_TO_LONGS(128)] = {};
        int ret;
 
 again:
@@ -862,6 +922,10 @@ again:
 
        sband = sdata->local->hw.wiphy->bands[channel->band];
 
+       ieee80211_get_rates(sband, elems->supp_rates, elems->supp_rates_len,
+                           NULL, NULL, unknown_rates_selectors, NULL, NULL,
+                           NULL);
+
        switch (channel->band) {
        case NL80211_BAND_S1GHZ:
                if (WARN_ON(ap_mode != IEEE80211_CONN_MODE_S1G)) {
@@ -912,6 +976,29 @@ again:
 
        chanreq->oper = *ap_chandef;
 
+       bitmap_copy(sta_selectors, userspace_selectors, 128);
+       if (conn->mode >= IEEE80211_CONN_MODE_HT)
+               set_bit(BSS_MEMBERSHIP_SELECTOR_HT_PHY, sta_selectors);
+       if (conn->mode >= IEEE80211_CONN_MODE_VHT)
+               set_bit(BSS_MEMBERSHIP_SELECTOR_VHT_PHY, sta_selectors);
+       if (conn->mode >= IEEE80211_CONN_MODE_HE)
+               set_bit(BSS_MEMBERSHIP_SELECTOR_HE_PHY, sta_selectors);
+       if (conn->mode >= IEEE80211_CONN_MODE_EHT)
+               set_bit(BSS_MEMBERSHIP_SELECTOR_EHT_PHY, sta_selectors);
+
+       /*
+        * We do not support EPD or GLK so never add them.
+        * SAE_H2E is handled through userspace_selectors.
+        */
+
+       /* Check if we support all required features */
+       if (!bitmap_subset(unknown_rates_selectors, sta_selectors, 128)) {
+               link_id_info(sdata, link_id,
+                            "required basic rate or BSS membership selectors not supported or disabled, rejecting connection\n");
+               ret = -EINVAL;
+               goto free;
+       }
+
        ieee80211_set_chanreq_ap(sdata, chanreq, conn, ap_chandef);
 
        while (!ieee80211_chandef_usable(sdata, &chanreq->oper,
@@ -4625,62 +4712,6 @@ static void ieee80211_rx_mgmt_disassoc(struct ieee80211_sub_if_data *sdata,
                                    false);
 }
 
-static void ieee80211_get_rates(struct ieee80211_supported_band *sband,
-                               u8 *supp_rates, unsigned int supp_rates_len,
-                               u32 *rates, u32 *basic_rates,
-                               unsigned long *unknown_rates_selectors,
-                               bool *have_higher_than_11mbit,
-                               int *min_rate, int *min_rate_index)
-{
-       int i, j;
-
-       for (i = 0; i < supp_rates_len; i++) {
-               int rate = supp_rates[i] & 0x7f;
-               bool is_basic = !!(supp_rates[i] & 0x80);
-
-               if ((rate * 5) > 110 && have_higher_than_11mbit)
-                       *have_higher_than_11mbit = true;
-
-               /*
-                * Skip membership selectors since they're not rates.
-                *
-                * Note: Even though the membership selector and the basic
-                *       rate flag share the same bit, they are not exactly
-                *       the same.
-                */
-               if (is_basic && rate >= BSS_MEMBERSHIP_SELECTOR_MIN) {
-                       if (unknown_rates_selectors)
-                               set_bit(rate, unknown_rates_selectors);
-                       continue;
-               }
-
-               for (j = 0; j < sband->n_bitrates; j++) {
-                       struct ieee80211_rate *br;
-                       int brate;
-
-                       br = &sband->bitrates[j];
-
-                       brate = DIV_ROUND_UP(br->bitrate, 5);
-                       if (brate == rate) {
-                               if (rates)
-                                       *rates |= BIT(j);
-                               if (is_basic && basic_rates)
-                                       *basic_rates |= BIT(j);
-                               if (min_rate && (rate * 5) < *min_rate) {
-                                       *min_rate = rate * 5;
-                                       if (min_rate_index)
-                                               *min_rate_index = j;
-                               }
-                               break;
-                       }
-               }
-
-               /* Handle an unknown entry as if it is an unknown selector */
-               if (is_basic && unknown_rates_selectors && j == sband->n_bitrates)
-                       set_bit(rate, unknown_rates_selectors);
-       }
-}
-
 static bool ieee80211_twt_req_supported(struct ieee80211_sub_if_data *sdata,
                                        struct ieee80211_supported_band *sband,
                                        const struct link_sta_info *link_sta,
@@ -5546,7 +5577,8 @@ static int ieee80211_prep_channel(struct ieee80211_sub_if_data *sdata,
                                  struct ieee80211_link_data *link,
                                  int link_id,
                                  struct cfg80211_bss *cbss, bool mlo,
-                                 struct ieee80211_conn_settings *conn)
+                                 struct ieee80211_conn_settings *conn,
+                                 unsigned long *userspace_selectors)
 {
        struct ieee80211_local *local = sdata->local;
        bool is_6ghz = cbss->channel->band == NL80211_BAND_6GHZ;
@@ -5559,7 +5591,8 @@ static int ieee80211_prep_channel(struct ieee80211_sub_if_data *sdata,
 
        rcu_read_lock();
        elems = ieee80211_determine_chan_mode(sdata, conn, cbss, link_id,
-                                             &chanreq, &ap_chandef);
+                                             &chanreq, &ap_chandef,
+                                             userspace_selectors);
 
        if (IS_ERR(elems)) {
                rcu_read_unlock();
@@ -5753,7 +5786,8 @@ static bool ieee80211_assoc_success(struct ieee80211_sub_if_data *sdata,
                        link->u.mgd.conn = assoc_data->link[link_id].conn;
 
                        err = ieee80211_prep_channel(sdata, link, link_id, cbss,
-                                                    true, &link->u.mgd.conn);
+                                                    true, &link->u.mgd.conn,
+                                                    assoc_data->userspace_selectors);
                        if (err) {
                                link_info(link, "prep_channel failed\n");
                                goto out_err;
@@ -8263,7 +8297,8 @@ static int ieee80211_prep_connection(struct ieee80211_sub_if_data *sdata,
                                     struct cfg80211_bss *cbss, s8 link_id,
                                     const u8 *ap_mld_addr, bool assoc,
                                     struct ieee80211_conn_settings *conn,
-                                    bool override)
+                                    bool override,
+                                    unsigned long *userspace_selectors)
 {
        struct ieee80211_local *local = sdata->local;
        struct ieee80211_if_managed *ifmgd = &sdata->u.mgd;
@@ -8402,7 +8437,8 @@ static int ieee80211_prep_connection(struct ieee80211_sub_if_data *sdata,
                 */
                link->u.mgd.conn = *conn;
                err = ieee80211_prep_channel(sdata, link, link->link_id, cbss,
-                                            mlo, &link->u.mgd.conn);
+                                            mlo, &link->u.mgd.conn,
+                                            userspace_selectors);
                if (err) {
                        if (new_sta)
                                sta_info_free(local, new_sta);
@@ -8518,6 +8554,22 @@ out:
        return ret;
 }
 
+static void ieee80211_parse_cfg_selectors(unsigned long *userspace_selectors,
+                                         const u8 *supported_selectors,
+                                         u8 supported_selectors_len)
+{
+       if (supported_selectors) {
+               for (int i = 0; i < supported_selectors_len; i++) {
+                       set_bit(supported_selectors[i],
+                               userspace_selectors);
+               }
+       } else {
+               /* Assume SAE_H2E support for backward compatibility. */
+               set_bit(BSS_MEMBERSHIP_SELECTOR_SAE_H2E,
+                       userspace_selectors);
+       }
+}
+
 /* config hooks */
 int ieee80211_mgd_auth(struct ieee80211_sub_if_data *sdata,
                       struct cfg80211_auth_request *req)
@@ -8619,6 +8671,10 @@ int ieee80211_mgd_auth(struct ieee80211_sub_if_data *sdata,
                memcpy(auth_data->key, req->key, req->key_len);
        }
 
+       ieee80211_parse_cfg_selectors(auth_data->userspace_selectors,
+                                     req->supported_selectors,
+                                     req->supported_selectors_len);
+
        auth_data->algorithm = auth_alg;
 
        /* try to authenticate/probe */
@@ -8672,7 +8728,8 @@ int ieee80211_mgd_auth(struct ieee80211_sub_if_data *sdata,
 
        err = ieee80211_prep_connection(sdata, req->bss, req->link_id,
                                        req->ap_mld_addr, cont_auth,
-                                       &conn, false);
+                                       &conn, false,
+                                       auth_data->userspace_selectors);
        if (err)
                goto err_clear;
 
@@ -8959,6 +9016,10 @@ int ieee80211_mgd_assoc(struct ieee80211_sub_if_data *sdata,
                                            false);
        }
 
+       ieee80211_parse_cfg_selectors(assoc_data->userspace_selectors,
+                                     req->supported_selectors,
+                                     req->supported_selectors_len);
+
        memcpy(&ifmgd->ht_capa, &req->ht_capa, sizeof(ifmgd->ht_capa));
        memcpy(&ifmgd->ht_capa_mask, &req->ht_capa_mask,
               sizeof(ifmgd->ht_capa_mask));
@@ -9205,7 +9266,8 @@ int ieee80211_mgd_assoc(struct ieee80211_sub_if_data *sdata,
                /* only calculate the mode, hence link == NULL */
                err = ieee80211_prep_channel(sdata, NULL, i,
                                             assoc_data->link[i].bss, true,
-                                            &assoc_data->link[i].conn);
+                                            &assoc_data->link[i].conn,
+                                            assoc_data->userspace_selectors);
                if (err) {
                        req->links[i].error = err;
                        goto err_clear;
@@ -9221,7 +9283,8 @@ int ieee80211_mgd_assoc(struct ieee80211_sub_if_data *sdata,
        err = ieee80211_prep_connection(sdata, cbss, req->link_id,
                                        req->ap_mld_addr, true,
                                        &assoc_data->link[assoc_link_id].conn,
-                                       override);
+                                       override,
+                                       assoc_data->userspace_selectors);
        if (err)
                goto err_clear;