#include "reg.h"
 #include "util.h"
 
+struct rtw89_eapol_2_of_2 {
+       struct ieee80211_hdr_3addr hdr;
+       u8 gtkbody[14];
+       u8 key_des_ver;
+       u8 rsvd[92];
+} __packed __aligned(2);
+
 static const u8 mss_signature[] = {0x4D, 0x53, 0x53, 0x4B, 0x50, 0x4F, 0x4F, 0x4C};
 
 union rtw89_fw_element_arg {
        return ret;
 }
 
+static struct sk_buff *rtw89_eapol_get(struct rtw89_dev *rtwdev,
+                                      struct rtw89_vif *rtwvif)
+{
+       static const u8 gtkbody[] = {0xAA, 0xAA, 0x03, 0x00, 0x00, 0x00, 0x88,
+                                    0x8E, 0x01, 0x03, 0x00, 0x5F, 0x02, 0x03};
+       struct ieee80211_vif *vif = rtwvif_to_vif(rtwvif);
+       struct ieee80211_bss_conf *bss_conf = &vif->bss_conf;
+       struct rtw89_wow_param *rtw_wow = &rtwdev->wow;
+       struct rtw89_eapol_2_of_2 *eapol_pkt;
+       struct sk_buff *skb;
+       u8 key_des_ver;
+
+       if (rtw_wow->ptk_alg == 3)
+               key_des_ver = 1;
+       else if (rtw_wow->akm == 1 || rtw_wow->akm == 2)
+               key_des_ver = 2;
+       else if (rtw_wow->akm > 2 && rtw_wow->akm < 7)
+               key_des_ver = 3;
+       else
+               key_des_ver = 0;
+
+       skb = dev_alloc_skb(sizeof(*eapol_pkt));
+       if (!skb)
+               return NULL;
+
+       eapol_pkt = skb_put_zero(skb, sizeof(*eapol_pkt));
+       eapol_pkt->hdr.frame_control = cpu_to_le16(IEEE80211_FTYPE_DATA |
+                                                  IEEE80211_FCTL_TODS |
+                                                  IEEE80211_FCTL_PROTECTED);
+       ether_addr_copy(eapol_pkt->hdr.addr1, bss_conf->bssid);
+       ether_addr_copy(eapol_pkt->hdr.addr2, vif->addr);
+       ether_addr_copy(eapol_pkt->hdr.addr3, bss_conf->bssid);
+       memcpy(eapol_pkt->gtkbody, gtkbody, sizeof(gtkbody));
+       eapol_pkt->key_des_ver = key_des_ver;
+
+       return skb;
+}
+
 static int rtw89_fw_h2c_add_general_pkt(struct rtw89_dev *rtwdev,
                                        struct rtw89_vif *rtwvif,
                                        enum rtw89_fw_pkt_ofld_type type,
        case RTW89_PKT_OFLD_TYPE_QOS_NULL:
                skb = ieee80211_nullfunc_get(rtwdev->hw, vif, -1, true);
                break;
+       case RTW89_PKT_OFLD_TYPE_EAPOL_KEY:
+               skb = rtw89_eapol_get(rtwdev, rtwvif);
+               break;
        default:
                goto err;
        }
        return ret;
 }
 
+int rtw89_fw_h2c_wow_gtk_ofld(struct rtw89_dev *rtwdev,
+                             struct rtw89_vif *rtwvif,
+                             bool enable)
+{
+       struct rtw89_wow_param *rtw_wow = &rtwdev->wow;
+       struct rtw89_wow_gtk_info *gtk_info = &rtw_wow->gtk_info;
+       struct rtw89_h2c_wow_gtk_ofld *h2c;
+       u8 macid = rtwvif->mac_id;
+       u32 len = sizeof(*h2c);
+       struct sk_buff *skb;
+       u8 pkt_id_eapol = 0;
+       int ret;
+
+       if (!rtw_wow->gtk_alg)
+               return 0;
+
+       skb = rtw89_fw_h2c_alloc_skb_with_hdr(rtwdev, len);
+       if (!skb) {
+               rtw89_err(rtwdev, "failed to alloc skb for gtk ofld\n");
+               return -ENOMEM;
+       }
+
+       skb_put(skb, len);
+       h2c = (struct rtw89_h2c_wow_gtk_ofld *)skb->data;
+
+       if (!enable) {
+               skb_put_zero(skb, sizeof(*gtk_info));
+               goto hdr;
+       }
+
+       ret = rtw89_fw_h2c_add_general_pkt(rtwdev, rtwvif,
+                                          RTW89_PKT_OFLD_TYPE_EAPOL_KEY,
+                                          &pkt_id_eapol);
+       if (ret)
+               goto fail;
+
+       /* not support TKIP and IEEE80211W yet */
+       h2c->w0 = le32_encode_bits(enable, RTW89_H2C_WOW_GTK_OFLD_W0_EN) |
+                 le32_encode_bits(0, RTW89_H2C_WOW_GTK_OFLD_W0_TKIP_EN) |
+                 le32_encode_bits(0, RTW89_H2C_WOW_GTK_OFLD_W0_IEEE80211W_EN) |
+                 le32_encode_bits(macid, RTW89_H2C_WOW_GTK_OFLD_W0_MAC_ID) |
+                 le32_encode_bits(pkt_id_eapol, RTW89_H2C_WOW_GTK_OFLD_W0_GTK_RSP_ID);
+       h2c->w1 = le32_encode_bits(rtw_wow->akm, RTW89_H2C_WOW_GTK_OFLD_W1_ALGO_AKM_SUIT);
+       h2c->gtk_info = rtw_wow->gtk_info;
+
+hdr:
+       rtw89_h2c_pkt_set_hdr(rtwdev, skb, FWCMD_TYPE_H2C,
+                             H2C_CAT_MAC,
+                             H2C_CL_MAC_WOW,
+                             H2C_FUNC_GTK_OFLD, 0, 1,
+                             len);
+
+       ret = rtw89_h2c_tx(rtwdev, skb, false);
+       if (ret) {
+               rtw89_err(rtwdev, "failed to send h2c\n");
+               goto fail;
+       }
+       return 0;
+
+fail:
+       dev_kfree_skb_any(skb);
+
+       return ret;
+}
+
 /* Return < 0, if failures happen during waiting for the condition.
  * Return 0, when waiting for the condition succeeds.
  * Return > 0, if the wait is considered unreachable due to driver/FW design,
 
        le32p_replace_bits((__le32 *)h2c + 5, val, BIT(31));
 }
 
+struct rtw89_h2c_wow_gtk_ofld {
+       __le32 w0;
+       __le32 w1;
+       struct rtw89_wow_gtk_info gtk_info;
+} __packed;
+
+#define RTW89_H2C_WOW_GTK_OFLD_W0_EN BIT(0)
+#define RTW89_H2C_WOW_GTK_OFLD_W0_TKIP_EN BIT(1)
+#define RTW89_H2C_WOW_GTK_OFLD_W0_IEEE80211W_EN BIT(2)
+#define RTW89_H2C_WOW_GTK_OFLD_W0_PAIRWISE_WAKEUP BIT(3)
+#define RTW89_H2C_WOW_GTK_OFLD_W0_NOREKEY_WAKEUP BIT(4)
+#define RTW89_H2C_WOW_GTK_OFLD_W0_MAC_ID GENMASK(23, 16)
+#define RTW89_H2C_WOW_GTK_OFLD_W0_GTK_RSP_ID GENMASK(31, 24)
+#define RTW89_H2C_WOW_GTK_OFLD_W1_PMF_SA_QUERY_ID GENMASK(7, 0)
+#define RTW89_H2C_WOW_GTK_OFLD_W1_PMF_BIP_SEC_ALGO GENMASK(9, 8)
+#define RTW89_H2C_WOW_GTK_OFLD_W1_ALGO_AKM_SUIT GENMASK(17, 10)
+
 enum rtw89_btc_btf_h2c_class {
        BTFC_SET = 0x10,
        BTFC_GET = 0x11,
 #define H2C_FUNC_KEEP_ALIVE            0x0
 #define H2C_FUNC_DISCONNECT_DETECT     0x1
 #define H2C_FUNC_WOW_GLOBAL            0x2
+#define H2C_FUNC_GTK_OFLD              0x3
 #define H2C_FUNC_WAKEUP_CTRL           0x8
 #define H2C_FUNC_WOW_CAM_UPD           0xC
 
                                 struct rtw89_vif *rtwvif, bool enable);
 int rtw89_fw_wow_cam_update(struct rtw89_dev *rtwdev,
                            struct rtw89_wow_cam_info *cam_info);
+int rtw89_fw_h2c_wow_gtk_ofld(struct rtw89_dev *rtwdev,
+                             struct rtw89_vif *rtwvif,
+                             bool enable);
 int rtw89_fw_h2c_add_mcc(struct rtw89_dev *rtwdev,
                         const struct rtw89_fw_mcc_add_req *p);
 int rtw89_fw_h2c_start_mcc(struct rtw89_dev *rtwdev,
 
 
        device_set_wakeup_enable(rtwdev->dev, enabled);
 }
+
+static void rtw89_set_rekey_data(struct ieee80211_hw *hw,
+                                struct ieee80211_vif *vif,
+                                struct cfg80211_gtk_rekey_data *data)
+{
+       struct rtw89_dev *rtwdev = hw->priv;
+       struct rtw89_wow_param *rtw_wow = &rtwdev->wow;
+       struct rtw89_wow_gtk_info *gtk_info = &rtw_wow->gtk_info;
+
+       if (data->kek_len > sizeof(gtk_info->kek) ||
+           data->kck_len > sizeof(gtk_info->kck)) {
+               rtw89_warn(rtwdev, "kek or kck length over fw limit\n");
+               return;
+       }
+
+       mutex_lock(&rtwdev->mutex);
+
+       memcpy(gtk_info->kek, data->kek, data->kek_len);
+       memcpy(gtk_info->kck, data->kck, data->kck_len);
+
+       mutex_unlock(&rtwdev->mutex);
+}
 #endif
 
 const struct ieee80211_ops rtw89_ops = {
        .suspend                = rtw89_ops_suspend,
        .resume                 = rtw89_ops_resume,
        .set_wakeup             = rtw89_ops_set_wakeup,
+       .set_rekey_data         = rtw89_set_rekey_data,
 #endif
 };
 EXPORT_SYMBOL(rtw89_ops);