]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: fix TypeInfo.fetch() with connections using RawCursor as factory
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 13 Aug 2023 09:48:26 +0000 (10:48 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 15 Aug 2023 15:29:03 +0000 (16:29 +0100)
psycopg/psycopg/_typeinfo.py

index f33866175606281995e470c988eecfbef1924e2b..dcbb2c0950821f7f4229b34ae018c746d1eefa1d 100644 (file)
@@ -93,12 +93,13 @@ class TypeInfo:
         # the function with the connection in the state we found (either idle
         # or intrans)
         try:
-            with conn.transaction():
+            from psycopg import Cursor
+
+            with conn.transaction(), Cursor(conn, row_factory=dict_row) as cur:
                 if conn_encoding(conn) == "ascii":
-                    conn.execute("set local client_encoding to utf8")
-                with conn.cursor(row_factory=dict_row) as cur:
-                    cur.execute(cls._get_info_query(conn), {"name": name})
-                    recs = cur.fetchall()
+                    cur.execute("set local client_encoding to utf8")
+                cur.execute(cls._get_info_query(conn), {"name": name})
+                recs = cur.fetchall()
         except e.UndefinedObject:
             return None
 
@@ -109,10 +110,12 @@ class TypeInfo:
         cls: Type[T], conn: "AsyncConnection[Any]", name: str
     ) -> Optional[T]:
         try:
+            from psycopg import AsyncCursor
+
             async with conn.transaction():
-                if conn_encoding(conn) == "ascii":
-                    await conn.execute("set local client_encoding to utf8")
-                async with conn.cursor(row_factory=dict_row) as cur:
+                async with AsyncCursor(conn, row_factory=dict_row) as cur:
+                    if conn_encoding(conn) == "ascii":
+                        await cur.execute("set local client_encoding to utf8")
                     await cur.execute(cls._get_info_query(conn), {"name": name})
                     recs = await cur.fetchall()
         except e.UndefinedObject: