]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Don't throw an error on context exit if the connection is closed
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 8 Mar 2021 01:24:09 +0000 (02:24 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Mar 2021 04:07:25 +0000 (05:07 +0100)
psycopg3/psycopg3/connection.py
tests/pool/test_pool.py
tests/pool/test_pool_async.py
tests/test_connection.py
tests/test_connection_async.py

index 7cdb90f81d17f9d7b633e9ecf3b1bafd263a829c..d95434b2e79c551b04dd6fdfb59c82c76f72fa88 100644 (file)
@@ -465,17 +465,19 @@ class Connection(BaseConnection):
         exc_val: Optional[BaseException],
         exc_tb: Optional[TracebackType],
     ) -> None:
+        if self.closed:
+            return
+
         if exc_type:
             # try to rollback, but if there are problems (connection in a bad
             # state) just warn without clobbering the exception bubbling up.
-            if not self.closed:
-                try:
-                    self.rollback()
-                except Exception as exc2:
-                    warnings.warn(
-                        f"error rolling back the transaction on {self}: {exc2}",
-                        RuntimeWarning,
-                    )
+            try:
+                self.rollback()
+            except Exception as exc2:
+                warnings.warn(
+                    f"error rolling back the transaction on {self}: {exc2}",
+                    RuntimeWarning,
+                )
         else:
             self.commit()
 
@@ -641,17 +643,19 @@ class AsyncConnection(BaseConnection):
         exc_val: Optional[BaseException],
         exc_tb: Optional[TracebackType],
     ) -> None:
+        if self.closed:
+            return
+
         if exc_type:
             # try to rollback, but if there are problems (connection in a bad
             # state) just warn without clobbering the exception bubbling up.
-            if not self.closed:
-                try:
-                    await self.rollback()
-                except Exception as exc2:
-                    warnings.warn(
-                        f"error rolling back the transaction on {self}: {exc2}",
-                        RuntimeWarning,
-                    )
+            try:
+                await self.rollback()
+            except Exception as exc2:
+                warnings.warn(
+                    f"error rolling back the transaction on {self}: {exc2}",
+                    RuntimeWarning,
+                )
         else:
             await self.commit()
 
index fec574c3ca67b69c8a7b55dce1204d7960688b37..1c3177814d916b139356b6914f50460d71eedde8 100644 (file)
@@ -241,11 +241,10 @@ def test_queue_timeout_override(dsn):
 
 def test_broken_reconnect(dsn):
     with pool.ConnectionPool(dsn, minconn=1) as p:
-        with pytest.raises(psycopg3.OperationalError):
-            with p.connection() as conn:
-                with conn.execute("select pg_backend_pid()") as cur:
-                    (pid1,) = cur.fetchone()
-                conn.close()
+        with p.connection() as conn:
+            with conn.execute("select pg_backend_pid()") as cur:
+                (pid1,) = cur.fetchone()
+            conn.close()
 
         with p.connection() as conn2:
             with conn2.execute("select pg_backend_pid()") as cur:
index 810643f8dd427b93d8bde9777d9fd1716a02939d..fa46ba54351db82edf2319cd5e703e071981d541 100644 (file)
@@ -257,11 +257,10 @@ async def test_queue_timeout_override(dsn):
 
 async def test_broken_reconnect(dsn):
     async with pool.AsyncConnectionPool(dsn, minconn=1) as p:
-        with pytest.raises(psycopg3.OperationalError):
-            async with p.connection() as conn:
-                cur = await conn.execute("select pg_backend_pid()")
-                (pid1,) = await cur.fetchone()
-                await conn.close()
+        async with p.connection() as conn:
+            cur = await conn.execute("select pg_backend_pid()")
+            (pid1,) = await cur.fetchone()
+            await conn.close()
 
         async with p.connection() as conn2:
             cur = await conn2.execute("select pg_backend_pid()")
index cdfbe2e0dfdb715a572b4d27d78c758460894944..b97570d49b828e38fc45adbd56ecc6a0b2661a2b 100644 (file)
@@ -152,6 +152,12 @@ def test_context_rollback(conn, dsn):
                 cur.execute("select * from textctx")
 
 
+def test_context_close(conn):
+    with conn:
+        conn.execute("select 1")
+        conn.close()
+
+
 def test_context_rollback_no_clobber(conn, dsn, recwarn):
     with pytest.raises(ZeroDivisionError):
         with psycopg3.connect(dsn) as conn2:
index 2c303e1409c38ffbed81e62343cfaf19632df63e..e4406d1a42ed89998e2426040a73024172e417b6 100644 (file)
@@ -157,6 +157,12 @@ async def test_context_rollback(aconn, dsn):
                 await cur.execute("select * from textctx")
 
 
+async def test_context_close(aconn):
+    async with aconn:
+        await aconn.execute("select 1")
+        await aconn.close()
+
+
 async def test_context_rollback_no_clobber(conn, dsn, recwarn):
     with pytest.raises(ZeroDivisionError):
         async with await psycopg3.AsyncConnection.connect(dsn) as conn2: