]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(tests): make sync and async connection tests more similar
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 7 Aug 2023 22:02:04 +0000 (23:02 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 11 Oct 2023 21:45:38 +0000 (23:45 +0200)
tests/_test_connection.py [new file with mode: 0644]
tests/test_connection.py
tests/test_connection_async.py

diff --git a/tests/_test_connection.py b/tests/_test_connection.py
new file mode 100644 (file)
index 0000000..296a7f7
--- /dev/null
@@ -0,0 +1,100 @@
+"""
+Support module for test_connection[_async].py
+"""
+
+from typing import Any, List
+from dataclasses import dataclass
+
+import pytest
+import psycopg
+
+
+@pytest.fixture
+def testctx(svcconn):
+    svcconn.execute("create table if not exists testctx (id int primary key)")
+    svcconn.execute("delete from testctx")
+    return None
+
+
+@dataclass
+class ParamDef:
+    name: str
+    guc: str
+    values: List[Any]
+    non_default: str
+
+
+param_isolation = ParamDef(
+    name="isolation_level",
+    guc="isolation",
+    values=list(psycopg.IsolationLevel),
+    non_default="serializable",
+)
+param_read_only = ParamDef(
+    name="read_only",
+    guc="read_only",
+    values=[True, False],
+    non_default="on",
+)
+param_deferrable = ParamDef(
+    name="deferrable",
+    guc="deferrable",
+    values=[True, False],
+    non_default="on",
+)
+
+# Map Python values to Postgres values for the tx_params possible values
+tx_values_map = {
+    v.name.lower().replace("_", " "): v.value for v in psycopg.IsolationLevel
+}
+tx_values_map["on"] = True
+tx_values_map["off"] = False
+
+
+tx_params = [
+    param_isolation,
+    param_read_only,
+    pytest.param(param_deferrable, marks=pytest.mark.crdb_skip("deferrable")),
+]
+tx_params_isolation = [
+    pytest.param(
+        param_isolation,
+        id="isolation_level",
+        marks=pytest.mark.crdb("skip", reason="transaction isolation"),
+    ),
+    pytest.param(
+        param_read_only, id="read_only", marks=pytest.mark.crdb_skip("begin_read_only")
+    ),
+    pytest.param(
+        param_deferrable, id="deferrable", marks=pytest.mark.crdb_skip("deferrable")
+    ),
+]
+
+
+conninfo_params_timeout = [
+    (
+        "",
+        {"dbname": "mydb", "connect_timeout": None},
+        ({"dbname": "mydb"}, None),
+    ),
+    (
+        "",
+        {"dbname": "mydb", "connect_timeout": 1},
+        ({"dbname": "mydb", "connect_timeout": "1"}, 1),
+    ),
+    (
+        "dbname=postgres",
+        {},
+        ({"dbname": "postgres"}, None),
+    ),
+    (
+        "dbname=postgres connect_timeout=2",
+        {},
+        ({"dbname": "postgres", "connect_timeout": "2"}, 2),
+    ),
+    (
+        "postgresql:///postgres?connect_timeout=2",
+        {"connect_timeout": 10},
+        ({"dbname": "postgres", "connect_timeout": "10"}, 10),
+    ),
+]
index 6092bebebdc4ffb3fb64faddc42fc25fe5086820..a15a21ec96027bb5a5c457c13b7a533bdb8f70b2 100644 (file)
@@ -4,7 +4,6 @@ import pytest
 import logging
 import weakref
 from typing import Any, List
-from dataclasses import dataclass
 
 import psycopg
 from psycopg import Notify, pq, errors as e
@@ -13,6 +12,9 @@ from psycopg.conninfo import conninfo_to_dict, make_conninfo
 
 from .utils import gc_collect
 from ._test_cursor import my_row_factory
+from ._test_connection import tx_params, tx_params_isolation, tx_values_map
+from ._test_connection import conninfo_params_timeout
+from ._test_connection import testctx  # noqa: F401  # fixture
 from .test_adapt import make_bin_dumper, make_dumper
 
 
@@ -23,6 +25,11 @@ def test_connect(conn_cls, dsn):
     conn.close()
 
 
+def test_connect_bad(conn_cls):
+    with pytest.raises(psycopg.OperationalError):
+        conn_cls.connect("dbname=nosuchdb")
+
+
 def test_connect_str_subclass(conn_cls, dsn):
     class MyString(str):
         pass
@@ -33,11 +40,6 @@ def test_connect_str_subclass(conn_cls, dsn):
     conn.close()
 
 
-def test_connect_bad(conn_cls):
-    with pytest.raises(psycopg.OperationalError):
-        conn_cls.connect("dbname=nosuchdb")
-
-
 @pytest.mark.slow
 @pytest.mark.timing
 def test_connect_timeout(conn_cls, deaf_port):
@@ -83,6 +85,8 @@ def test_cursor_closed(conn):
     with pytest.raises(psycopg.OperationalError):
         with conn.cursor("foo"):
             pass
+    with pytest.raises(psycopg.OperationalError):
+        conn.cursor("foo")
     with pytest.raises(psycopg.OperationalError):
         conn.cursor()
 
@@ -123,14 +127,8 @@ def test_connection_warn_close(conn_cls, dsn, recwarn):
     assert not recwarn, [str(w.message) for w in recwarn.list]
 
 
-@pytest.fixture
-def testctx(svcconn):
-    svcconn.execute("create table if not exists testctx (id int primary key)")
-    svcconn.execute("delete from testctx")
-    return None
-
-
-def test_context_commit(conn_cls, testctx, conn, dsn):
+@pytest.mark.usefixtures("testctx")
+def test_context_commit(conn_cls, conn, dsn):
     with conn:
         with conn.cursor() as cur:
             cur.execute("insert into testctx values (42)")
@@ -144,7 +142,8 @@ def test_context_commit(conn_cls, testctx, conn, dsn):
             assert cur.fetchall() == [(42,)]
 
 
-def test_context_rollback(conn_cls, testctx, conn, dsn):
+@pytest.mark.usefixtures("testctx")
+def test_context_rollback(conn_cls, conn, dsn):
     with pytest.raises(ZeroDivisionError):
         with conn:
             with conn.cursor() as cur:
@@ -585,61 +584,6 @@ def test_server_cursor_factory(conn):
         assert isinstance(cur, MyServerCursor)
 
 
-@dataclass
-class ParamDef:
-    name: str
-    guc: str
-    values: List[Any]
-    non_default: str
-
-
-param_isolation = ParamDef(
-    name="isolation_level",
-    guc="isolation",
-    values=list(psycopg.IsolationLevel),
-    non_default="serializable",
-)
-param_read_only = ParamDef(
-    name="read_only",
-    guc="read_only",
-    values=[True, False],
-    non_default="on",
-)
-param_deferrable = ParamDef(
-    name="deferrable",
-    guc="deferrable",
-    values=[True, False],
-    non_default="on",
-)
-
-# Map Python values to Postgres values for the tx_params possible values
-tx_values_map = {
-    v.name.lower().replace("_", " "): v.value for v in psycopg.IsolationLevel
-}
-tx_values_map["on"] = True
-tx_values_map["off"] = False
-
-
-tx_params = [
-    param_isolation,
-    param_read_only,
-    pytest.param(param_deferrable, marks=pytest.mark.crdb_skip("deferrable")),
-]
-tx_params_isolation = [
-    pytest.param(
-        param_isolation,
-        id="isolation_level",
-        marks=pytest.mark.crdb("skip", reason="transaction isolation"),
-    ),
-    pytest.param(
-        param_read_only, id="read_only", marks=pytest.mark.crdb_skip("begin_read_only")
-    ),
-    pytest.param(
-        param_deferrable, id="deferrable", marks=pytest.mark.crdb_skip("deferrable")
-    ),
-]
-
-
 @pytest.mark.parametrize("param", tx_params)
 def test_transaction_param_default(conn, param):
     assert getattr(conn, param.name) is None
@@ -757,35 +701,6 @@ def test_set_transaction_param_strange(conn):
     assert conn.deferrable is False
 
 
-conninfo_params_timeout = [
-    (
-        "",
-        {"dbname": "mydb", "connect_timeout": None},
-        ({"dbname": "mydb"}, None),
-    ),
-    (
-        "",
-        {"dbname": "mydb", "connect_timeout": 1},
-        ({"dbname": "mydb", "connect_timeout": "1"}, 1),
-    ),
-    (
-        "dbname=postgres",
-        {},
-        ({"dbname": "postgres"}, None),
-    ),
-    (
-        "dbname=postgres connect_timeout=2",
-        {},
-        ({"dbname": "postgres", "connect_timeout": "2"}, 2),
-    ),
-    (
-        "postgresql:///postgres?connect_timeout=2",
-        {"connect_timeout": 10},
-        ({"dbname": "postgres", "connect_timeout": "10"}, 10),
-    ),
-]
-
-
 @pytest.mark.parametrize("dsn, kwargs, exp", conninfo_params_timeout)
 def test_get_connection_params(conn_cls, dsn, kwargs, exp):
     params = conn_cls._get_connection_params(dsn, **kwargs)
index b58b3e5851f237942053fe768551cbf026c1bd73..4b1d797ebd1c21e2a4bda7e10ae340173de4e5e3 100644 (file)
@@ -11,13 +11,11 @@ from psycopg.conninfo import conninfo_to_dict, make_conninfo
 
 from .utils import gc_collect
 from ._test_cursor import my_row_factory
-from .test_connection import tx_params, tx_params_isolation, tx_values_map
-from .test_connection import conninfo_params_timeout
-from .test_connection import testctx  # noqa: F401  # fixture
+from ._test_connection import tx_params, tx_params_isolation, tx_values_map
+from ._test_connection import conninfo_params_timeout
+from ._test_connection import testctx  # noqa: F401  # fixture
 from .test_adapt import make_bin_dumper, make_dumper
-from .test_conninfo import fake_resolve  # noqa: F401
-
-pytestmark = pytest.mark.anyio
+from .test_conninfo import fake_resolve  # noqa: F401  # fixture
 
 
 async def test_connect(aconn_cls, dsn):
@@ -89,6 +87,7 @@ async def test_cursor_closed(aconn):
     with pytest.raises(psycopg.OperationalError):
         async with aconn.cursor("foo"):
             pass
+    with pytest.raises(psycopg.OperationalError):
         aconn.cursor("foo")
     with pytest.raises(psycopg.OperationalError):
         aconn.cursor()
@@ -112,9 +111,10 @@ async def test_connection_warn_close(aconn_cls, dsn, recwarn):
     conn = await aconn_cls.connect(dsn)
     try:
         await conn.execute("select wat")
-    except Exception:
+    except psycopg.ProgrammingError:
         pass
     del conn
+    gc_collect()
     assert "INERROR" in str(recwarn.pop(ResourceWarning).message)
 
     async with await aconn_cls.connect(dsn) as conn:
@@ -163,6 +163,8 @@ async def test_context_close(aconn):
 
 @pytest.mark.crdb_skip("pg_terminate_backend")
 async def test_context_inerror_rollback_no_clobber(aconn_cls, conn, dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg")
+
     with pytest.raises(ZeroDivisionError):
         async with await aconn_cls.connect(dsn) as conn2:
             await conn2.execute("select 1")
@@ -289,6 +291,9 @@ async def test_auto_transaction_fail(aconn):
         await cur.execute("meh")
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
 
+    with pytest.raises(psycopg.errors.InFailedSqlTransaction):
+        await cur.execute("select 1")
+
     await aconn.commit()
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
     await cur.execute("select * from foo")
@@ -760,17 +765,18 @@ async def test_cancel_closed(aconn):
     aconn.cancel()
 
 
-async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve):  # noqa: F811
+@pytest.mark.usefixtures("fake_resolve")
+async def test_resolve_hostaddr_conn(aconn_cls, monkeypatch):
     got = []
 
     def fake_connect_gen(conninfo, **kwargs):
         got.append(conninfo)
         1 / 0
 
-    monkeypatch.setattr(psycopg.AsyncConnection, "_connect_gen", fake_connect_gen)
+    monkeypatch.setattr(aconn_cls, "_connect_gen", fake_connect_gen)
 
     with pytest.raises(ZeroDivisionError):
-        await psycopg.AsyncConnection.connect("host=foo.com")
+        await aconn_cls.connect("host=foo.com")
 
     assert len(got) == 1
     want = {"host": "foo.com", "hostaddr": "1.1.1.1"}