]> git.ipfire.org Git - thirdparty/postgresql.git/commitdiff
Merge similar algorithms into roles_is_member_of().
authorNoah Misch <noah@leadboat.com>
Fri, 26 Mar 2021 17:42:16 +0000 (10:42 -0700)
committerNoah Misch <noah@leadboat.com>
Fri, 26 Mar 2021 17:42:16 +0000 (10:42 -0700)
The next commit would have complicated two or three algorithms, so take
this opportunity to consolidate.  No functional changes.

Reviewed by John Naylor.

Discussion: https://postgr.es/m/20201228043148.GA1053024@rfd.leadboat.com

src/backend/utils/adt/acl.c

index c7f029e2186a1cd73683187b516cd37b54850d12..e6b4bdbd7685b590e3802438ed753906417f0586 100644 (file)
@@ -50,32 +50,24 @@ typedef struct
 /*
  * We frequently need to test whether a given role is a member of some other
  * role.  In most of these tests the "given role" is the same, namely the
- * active current user.  So we can optimize it by keeping a cached list of
- * all the roles the "given role" is a member of, directly or indirectly.
- *
- * There are actually two caches, one computed under "has_privs" rules
- * (do not recurse where rolinherit isn't true) and one computed under
- * "is_member" rules (recurse regardless of rolinherit).
+ * active current user.  So we can optimize it by keeping cached lists of all
+ * the roles the "given role" is a member of, directly or indirectly.
  *
  * Possibly this mechanism should be generalized to allow caching membership
  * info for multiple roles?
  *
- * The has_privs cache is:
- * cached_privs_role is the role OID the cache is for.
- * cached_privs_roles is an OID list of roles that cached_privs_role
- *             has the privileges of (always including itself).
- * The cache is valid if cached_privs_role is not InvalidOid.
- *
- * The is_member cache is similarly:
- * cached_member_role is the role OID the cache is for.
- * cached_membership_roles is an OID list of roles that cached_member_role
- *             is a member of (always including itself).
- * The cache is valid if cached_member_role is not InvalidOid.
+ * Each element of cached_roles is an OID list of constituent roles for the
+ * corresponding element of cached_role (always including the cached_role
+ * itself).  One cache has ROLERECURSE_PRIVS semantics, and the other has
+ * ROLERECURSE_MEMBERS semantics.
  */
-static Oid     cached_privs_role = InvalidOid;
-static List *cached_privs_roles = NIL;
-static Oid     cached_member_role = InvalidOid;
-static List *cached_membership_roles = NIL;
+enum RoleRecurseType
+{
+       ROLERECURSE_PRIVS = 0,          /* recurse if rolinherit */
+       ROLERECURSE_MEMBERS = 1         /* recurse unconditionally */
+};
+static Oid     cached_role[] = {InvalidOid, InvalidOid};
+static List *cached_roles[] = {NIL, NIL};
 
 
 static const char *getid(const char *s, char *n);
@@ -4675,8 +4667,8 @@ initialize_acl(void)
        {
                /*
                 * In normal mode, set a callback on any syscache invalidation of rows
-                * of pg_auth_members (for each AUTHMEM search in this file) or
-                * pg_authid (for has_rolinherit())
+                * of pg_auth_members (for roles_is_member_of()) or pg_authid (for
+                * has_rolinherit())
                 */
                CacheRegisterSyscacheCallback(AUTHMEMROLEMEM,
                                                                          RoleMembershipCacheCallback,
@@ -4695,8 +4687,8 @@ static void
 RoleMembershipCacheCallback(Datum arg, int cacheid, uint32 hashvalue)
 {
        /* Force membership caches to be recomputed on next use */
-       cached_privs_role = InvalidOid;
-       cached_member_role = InvalidOid;
+       cached_role[ROLERECURSE_PRIVS] = InvalidOid;
+       cached_role[ROLERECURSE_MEMBERS] = InvalidOid;
 }
 
 
@@ -4718,30 +4710,35 @@ has_rolinherit(Oid roleid)
 
 
 /*
- * Get a list of roles that the specified roleid has the privileges of
+ * Get a list of roles that the specified roleid is a member of
  *
- * This is defined not to recurse through roles that don't have rolinherit
- * set; for such roles, membership implies the ability to do SET ROLE, but
- * the privileges are not available until you've done so.
+ * Type ROLERECURSE_PRIVS recurses only through roles that have rolinherit
+ * set, while ROLERECURSE_MEMBERS recurses through all roles.  This sets
+ * *is_admin==true if and only if role "roleid" has an ADMIN OPTION membership
+ * in role "admin_of".
  *
  * Since indirect membership testing is relatively expensive, we cache
  * a list of memberships.  Hence, the result is only guaranteed good until
- * the next call of roles_has_privs_of()!
+ * the next call of roles_is_member_of()!
  *
  * For the benefit of select_best_grantor, the result is defined to be
  * in breadth-first order, ie, closer relationships earlier.
  */
 static List *
-roles_has_privs_of(Oid roleid)
+roles_is_member_of(Oid roleid, enum RoleRecurseType type,
+                                  Oid admin_of, bool *is_admin)
 {
        List       *roles_list;
        ListCell   *l;
-       List       *new_cached_privs_roles;
+       List       *new_cached_roles;
        MemoryContext oldctx;
 
-       /* If cache is already valid, just return the list */
-       if (OidIsValid(cached_privs_role) && cached_privs_role == roleid)
-               return cached_privs_roles;
+       Assert(OidIsValid(admin_of) == PointerIsValid(is_admin));
+
+       /* If cache is valid and ADMIN OPTION not sought, just return the list */
+       if (cached_role[type] == roleid && !OidIsValid(admin_of) &&
+               OidIsValid(cached_role[type]))
+               return cached_roles[type];
 
        /*
         * Find all the roles that roleid is a member of, including multi-level
@@ -4762,9 +4759,8 @@ roles_has_privs_of(Oid roleid)
                CatCList   *memlist;
                int                     i;
 
-               /* Ignore non-inheriting roles */
-               if (!has_rolinherit(memberid))
-                       continue;
+               if (type == ROLERECURSE_PRIVS && !has_rolinherit(memberid))
+                       continue;                       /* ignore non-inheriting roles */
 
                /* Find roles that memberid is directly a member of */
                memlist = SearchSysCacheList1(AUTHMEMMEMROLE,
@@ -4775,83 +4771,13 @@ roles_has_privs_of(Oid roleid)
                        Oid                     otherid = ((Form_pg_auth_members) GETSTRUCT(tup))->roleid;
 
                        /*
-                        * Even though there shouldn't be any loops in the membership
-                        * graph, we must test for having already seen this role. It is
-                        * legal for instance to have both A->B and A->C->B.
+                        * While otherid==InvalidOid shouldn't appear in the catalog, the
+                        * OidIsValid() avoids crashing if that arises.
                         */
-                       roles_list = list_append_unique_oid(roles_list, otherid);
-               }
-               ReleaseSysCacheList(memlist);
-       }
-
-       /*
-        * Copy the completed list into TopMemoryContext so it will persist.
-        */
-       oldctx = MemoryContextSwitchTo(TopMemoryContext);
-       new_cached_privs_roles = list_copy(roles_list);
-       MemoryContextSwitchTo(oldctx);
-       list_free(roles_list);
-
-       /*
-        * Now safe to assign to state variable
-        */
-       cached_privs_role = InvalidOid; /* just paranoia */
-       list_free(cached_privs_roles);
-       cached_privs_roles = new_cached_privs_roles;
-       cached_privs_role = roleid;
-
-       /* And now we can return the answer */
-       return cached_privs_roles;
-}
-
-
-/*
- * Get a list of roles that the specified roleid is a member of
- *
- * This is defined to recurse through roles regardless of rolinherit.
- *
- * Since indirect membership testing is relatively expensive, we cache
- * a list of memberships.  Hence, the result is only guaranteed good until
- * the next call of roles_is_member_of()!
- */
-static List *
-roles_is_member_of(Oid roleid)
-{
-       List       *roles_list;
-       ListCell   *l;
-       List       *new_cached_membership_roles;
-       MemoryContext oldctx;
-
-       /* If cache is already valid, just return the list */
-       if (OidIsValid(cached_member_role) && cached_member_role == roleid)
-               return cached_membership_roles;
-
-       /*
-        * Find all the roles that roleid is a member of, including multi-level
-        * recursion.  The role itself will always be the first element of the
-        * resulting list.
-        *
-        * Each element of the list is scanned to see if it adds any indirect
-        * memberships.  We can use a single list as both the record of
-        * already-found memberships and the agenda of roles yet to be scanned.
-        * This is a bit tricky but works because the foreach() macro doesn't
-        * fetch the next list element until the bottom of the loop.
-        */
-       roles_list = list_make1_oid(roleid);
-
-       foreach(l, roles_list)
-       {
-               Oid                     memberid = lfirst_oid(l);
-               CatCList   *memlist;
-               int                     i;
-
-               /* Find roles that memberid is directly a member of */
-               memlist = SearchSysCacheList1(AUTHMEMMEMROLE,
-                                                                         ObjectIdGetDatum(memberid));
-               for (i = 0; i < memlist->n_members; i++)
-               {
-                       HeapTuple       tup = &memlist->members[i]->tuple;
-                       Oid                     otherid = ((Form_pg_auth_members) GETSTRUCT(tup))->roleid;
+                       if (otherid == admin_of &&
+                               ((Form_pg_auth_members) GETSTRUCT(tup))->admin_option &&
+                               OidIsValid(admin_of))
+                               *is_admin = true;
 
                        /*
                         * Even though there shouldn't be any loops in the membership
@@ -4867,20 +4793,20 @@ roles_is_member_of(Oid roleid)
         * Copy the completed list into TopMemoryContext so it will persist.
         */
        oldctx = MemoryContextSwitchTo(TopMemoryContext);
-       new_cached_membership_roles = list_copy(roles_list);
+       new_cached_roles = list_copy(roles_list);
        MemoryContextSwitchTo(oldctx);
        list_free(roles_list);
 
        /*
         * Now safe to assign to state variable
         */
-       cached_member_role = InvalidOid;        /* just paranoia */
-       list_free(cached_membership_roles);
-       cached_membership_roles = new_cached_membership_roles;
-       cached_member_role = roleid;
+       cached_role[type] = InvalidOid; /* just paranoia */
+       list_free(cached_roles[type]);
+       cached_roles[type] = new_cached_roles;
+       cached_role[type] = roleid;
 
        /* And now we can return the answer */
-       return cached_membership_roles;
+       return cached_roles[type];
 }
 
 
@@ -4906,7 +4832,9 @@ has_privs_of_role(Oid member, Oid role)
         * Find all the roles that member has the privileges of, including
         * multi-level recursion, then see if target role is any one of them.
         */
-       return list_member_oid(roles_has_privs_of(member), role);
+       return list_member_oid(roles_is_member_of(member, ROLERECURSE_PRIVS,
+                                                                                         InvalidOid, NULL),
+                                                  role);
 }
 
 
@@ -4930,7 +4858,9 @@ is_member_of_role(Oid member, Oid role)
         * Find all the roles that member is a member of, including multi-level
         * recursion, then see if target role is any one of them.
         */
-       return list_member_oid(roles_is_member_of(member), role);
+       return list_member_oid(roles_is_member_of(member, ROLERECURSE_MEMBERS,
+                                                                                         InvalidOid, NULL),
+                                                  role);
 }
 
 /*
@@ -4964,7 +4894,9 @@ is_member_of_role_nosuper(Oid member, Oid role)
         * Find all the roles that member is a member of, including multi-level
         * recursion, then see if target role is any one of them.
         */
-       return list_member_oid(roles_is_member_of(member), role);
+       return list_member_oid(roles_is_member_of(member, ROLERECURSE_MEMBERS,
+                                                                                         InvalidOid, NULL),
+                                                  role);
 }
 
 
@@ -4977,8 +4909,6 @@ bool
 is_admin_of_role(Oid member, Oid role)
 {
        bool            result = false;
-       List       *roles_list;
-       ListCell   *l;
 
        if (superuser_arg(member))
                return true;
@@ -5016,44 +4946,7 @@ is_admin_of_role(Oid member, Oid role)
                return member == GetSessionUserId() &&
                        !InLocalUserIdChange() && !InSecurityRestrictedOperation();
 
-       /*
-        * Find all the roles that member is a member of, including multi-level
-        * recursion.  We build a list in the same way that is_member_of_role does
-        * to track visited and unvisited roles.
-        */
-       roles_list = list_make1_oid(member);
-
-       foreach(l, roles_list)
-       {
-               Oid                     memberid = lfirst_oid(l);
-               CatCList   *memlist;
-               int                     i;
-
-               /* Find roles that memberid is directly a member of */
-               memlist = SearchSysCacheList1(AUTHMEMMEMROLE,
-                                                                         ObjectIdGetDatum(memberid));
-               for (i = 0; i < memlist->n_members; i++)
-               {
-                       HeapTuple       tup = &memlist->members[i]->tuple;
-                       Oid                     otherid = ((Form_pg_auth_members) GETSTRUCT(tup))->roleid;
-
-                       if (otherid == role &&
-                               ((Form_pg_auth_members) GETSTRUCT(tup))->admin_option)
-                       {
-                               /* Found what we came for, so can stop searching */
-                               result = true;
-                               break;
-                       }
-
-                       roles_list = list_append_unique_oid(roles_list, otherid);
-               }
-               ReleaseSysCacheList(memlist);
-               if (result)
-                       break;
-       }
-
-       list_free(roles_list);
-
+       (void) roles_is_member_of(member, ROLERECURSE_MEMBERS, role, &result);
        return result;
 }
 
@@ -5125,10 +5018,11 @@ select_best_grantor(Oid roleId, AclMode privileges,
        /*
         * Otherwise we have to do a careful search to see if roleId has the
         * privileges of any suitable role.  Note: we can hang onto the result of
-        * roles_has_privs_of() throughout this loop, because aclmask_direct()
+        * roles_is_member_of() throughout this loop, because aclmask_direct()
         * doesn't query any role memberships.
         */
-       roles_list = roles_has_privs_of(roleId);
+       roles_list = roles_is_member_of(roleId, ROLERECURSE_PRIVS,
+                                                                       InvalidOid, NULL);
 
        /* initialize candidate result as default */
        *grantorId = roleId;