]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Respects transaction status of the connection used by TypeInfo.fetch
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 15 Jul 2021 11:34:22 +0000 (13:34 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 23 Jul 2021 15:55:58 +0000 (17:55 +0200)
Also make sure that the method returns None instead of throwing an
exception and leaving the connection broken in case the object is not
found.

psycopg/psycopg/_typeinfo.py
tests/test_typeinfo.py [new file with mode: 0644]
tests/types/test_array.py
tests/types/test_range.py

index 437d1fb81994a789977dd55989843ddc5c606eb4..87faf61190c3513b63216163acf881d9565bec17 100644 (file)
@@ -24,11 +24,6 @@ T = TypeVar("T", bound="TypeInfo")
 class TypeInfo:
     """
     Hold information about a PostgreSQL base type.
-
-    The class allows to:
-
-    - read information about a range type using `fetch()` and `fetch_async()`
-    - configure a composite type adaptation using `register()`
     """
 
     __module__ = "psycopg.types"
@@ -71,8 +66,17 @@ class TypeInfo:
 
         if isinstance(name, Composable):
             name = name.as_string(conn)
+
         cur = conn.cursor(binary=True, row_factory=dict_row)
-        cur.execute(cls._info_query, {"name": name})
+        # This might result in a nested transaction. What we want is to leave
+        # the function with the connection in the state we found (either idle
+        # or intrans)
+        try:
+            with conn.transaction():
+                cur.execute(cls._info_query, {"name": name})
+        except e.UndefinedObject:
+            return None
+
         recs = cur.fetchall()
         return cls._fetch(name, recs)
 
@@ -93,7 +97,12 @@ class TypeInfo:
             name = name.as_string(conn)
 
         cur = conn.cursor(binary=True, row_factory=dict_row)
-        await cur.execute(cls._info_query, {"name": name})
+        try:
+            async with conn.transaction():
+                await cur.execute(cls._info_query, {"name": name})
+        except e.UndefinedObject:
+            return None
+
         recs = await cur.fetchall()
         return cls._fetch(name, recs)
 
diff --git a/tests/test_typeinfo.py b/tests/test_typeinfo.py
new file mode 100644 (file)
index 0000000..ed792c3
--- /dev/null
@@ -0,0 +1,106 @@
+import pytest
+
+import psycopg
+from psycopg import sql
+from psycopg.pq import TransactionStatus
+from psycopg.types import TypeInfo
+
+
+@pytest.mark.parametrize("name", ["text", sql.Identifier("text")])
+@pytest.mark.parametrize("status", ["IDLE", "INTRANS"])
+def test_fetch(conn, name, status):
+    status = getattr(TransactionStatus, status)
+    if status == TransactionStatus.INTRANS:
+        conn.execute("select 1")
+
+    assert conn.info.transaction_status == status
+    info = TypeInfo.fetch(conn, name)
+    assert conn.info.transaction_status == status
+
+    assert info.name == "text"
+    # TODO: add the schema?
+    # assert info.schema == "pg_catalog"
+
+    assert info.oid == psycopg.adapters.types["text"].oid
+    assert info.array_oid == psycopg.adapters.types["text"].array_oid
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("name", ["text", sql.Identifier("text")])
+@pytest.mark.parametrize("status", ["IDLE", "INTRANS"])
+async def test_fetch_async(aconn, name, status):
+    status = getattr(TransactionStatus, status)
+    if status == TransactionStatus.INTRANS:
+        await aconn.execute("select 1")
+
+    assert aconn.info.transaction_status == status
+    info = await TypeInfo.fetch_async(aconn, name)
+    assert aconn.info.transaction_status == status
+
+    assert info.name == "text"
+    # assert info.schema == "pg_catalog"
+    assert info.oid == psycopg.adapters.types["text"].oid
+    assert info.array_oid == psycopg.adapters.types["text"].array_oid
+
+
+@pytest.mark.parametrize("name", ["nosuch", sql.Identifier("nosuch")])
+@pytest.mark.parametrize("status", ["IDLE", "INTRANS"])
+def test_fetch_not_found(conn, name, status):
+    status = getattr(TransactionStatus, status)
+    if status == TransactionStatus.INTRANS:
+        conn.execute("select 1")
+
+    assert conn.info.transaction_status == status
+    info = TypeInfo.fetch(conn, name)
+    assert conn.info.transaction_status == status
+    assert info is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("name", ["nosuch", sql.Identifier("nosuch")])
+@pytest.mark.parametrize("status", ["IDLE", "INTRANS"])
+async def test_fetch_not_found_async(aconn, name, status):
+    status = getattr(TransactionStatus, status)
+    if status == TransactionStatus.INTRANS:
+        await aconn.execute("select 1")
+
+    assert aconn.info.transaction_status == status
+    info = await TypeInfo.fetch_async(aconn, name)
+    assert aconn.info.transaction_status == status
+
+    assert info is None
+
+
+@pytest.mark.parametrize(
+    "name", ["testschema.testtype", sql.Identifier("testschema", "testtype")]
+)
+def test_fetch_by_schema_qualified_string(conn, name):
+    conn.execute("create schema if not exists testschema")
+    conn.execute("create type testschema.testtype as (foo text)")
+
+    info = TypeInfo.fetch(conn, name)
+    assert info.name == "testtype"
+    # assert info.schema == "testschema"
+    cur = conn.execute(
+        """
+        select oid, typarray from pg_type
+        where oid = 'testschema.testtype'::regtype
+        """
+    )
+    assert cur.fetchone() == (info.oid, info.array_oid)
+
+
+@pytest.mark.parametrize(
+    "name",
+    [
+        "text",
+        # TODO: support these?
+        # "pg_catalog.text",
+        # sql.Identifier("text"),
+        # sql.Identifier("pg_catalog", "text"),
+    ],
+)
+def test_registry_by_builtin_name(conn, name):
+    info = psycopg.adapters.types[name]
+    assert info.name == "text"
+    assert info.oid == 25
index dbee442f3a451a79de445981aa302eef60a304e5..4e7296654b077fde761d9e5f42cb31a3af0d4a1d 100644 (file)
@@ -111,17 +111,16 @@ def test_load_list_int(conn, obj, want, fmt_out):
 
 
 def test_array_register(conn):
-    cur = conn.cursor()
-    cur.execute("create table mytype (data text)")
-    cur.execute("""select '(foo)'::mytype, '{"(foo)"}'::mytype[] -- 1""")
+    conn.execute("create table mytype (data text)")
+    cur = conn.execute("""select '(foo)'::mytype, '{"(foo)"}'::mytype[]""")
     res = cur.fetchone()
     assert res[0] == "(foo)"
     assert res[1] == "{(foo)}"
 
     info = TypeInfo.fetch(conn, "mytype")
-    info.register(cur)
+    info.register(conn)
 
-    cur.execute("""select '(foo)'::mytype, '{"(foo)"}'::mytype[] -- 2""")
+    cur = conn.execute("""select '(foo)'::mytype, '{"(foo)"}'::mytype[]""")
     res = cur.fetchone()
     assert res[0] == "(foo)"
     assert res[1] == ["(foo)"]
index 9a35230ad5944958ded33d0db57a4cffe9bb5613..10beb092ae9f49cd774a48650e482ec9426eef58 100644 (file)
@@ -246,8 +246,7 @@ def test_fetch_info(conn, testrange, name, subtype):
 
 
 def test_fetch_info_not_found(conn):
-    with pytest.raises(conn.ProgrammingError):
-        RangeInfo.fetch(conn, "nosuchrange")
+    assert RangeInfo.fetch(conn, "nosuchrange") is None
 
 
 @pytest.mark.asyncio
@@ -262,8 +261,7 @@ async def test_fetch_info_async(aconn, testrange, name, subtype):
 
 @pytest.mark.asyncio
 async def test_fetch_info_not_found_async(aconn):
-    with pytest.raises(aconn.ProgrammingError):
-        await RangeInfo.fetch_async(aconn, "nosuchrange")
+    assert await RangeInfo.fetch_async(aconn, "nosuchrange") is None
 
 
 def test_dump_custom_empty(conn, testrange):