]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Raise OperationalError on Connection.cursor() if closed
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 7 Dec 2021 20:01:13 +0000 (21:01 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 7 Dec 2021 20:32:01 +0000 (21:32 +0100)
docs/news.rst
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
tests/test_connection.py
tests/test_connection_async.py

index 0748b9fdd5777cd35d1d9aca0fa92f8a5bdf88bc..f3c13193b05466b6fa3c8ed46c1a92093d02dc58 100644 (file)
@@ -19,8 +19,9 @@ Current release
 Psycopg 3.0.6
 ^^^^^^^^^^^^^
 
-- `ServerCursor.close()` doesn't raise exceptions if the connection is closed
+- Don't raise exceptions on `ServerCursor.close()` if the connection is closed
   (:ticket:`#173`).
+- Fail on `Connection.cursor()` if the connection is closed (:ticket:`#174`).
 
 
 Psycopg 3.0.5
index 652252d2ad7aff63ba1c01b85807f994e55073fb..7cf508428174731073ee8f439fd4ec808887436c 100644 (file)
@@ -411,13 +411,7 @@ class BaseConnection(Generic[Row]):
         Only used to implement internal commands such as "commit", with eventual
         arguments bound client-side. The cursor can do more complex stuff.
         """
-        if self.pgconn.status != ConnStatus.OK:
-            if self.pgconn.status == ConnStatus.BAD:
-                raise e.OperationalError("the connection is closed")
-            raise e.InterfaceError(
-                f"cannot execute operations: the connection is"
-                f" in status {self.pgconn.status}"
-            )
+        self._check_connection_ok()
 
         if isinstance(command, str):
             command = command.encode(pgconn_encoding(self.pgconn))
@@ -444,6 +438,17 @@ class BaseConnection(Generic[Row]):
                 )
         return result
 
+    def _check_connection_ok(self) -> None:
+        if self.pgconn.status == ConnStatus.OK:
+            return
+
+        if self.pgconn.status == ConnStatus.BAD:
+            raise e.OperationalError("the connection is closed")
+        raise e.InterfaceError(
+            f"cannot execute operations: the connection is"
+            f" in status {self.pgconn.status}"
+        )
+
     def _start_query(self) -> PQGen[None]:
         """Generator to start a transaction if necessary."""
         if self._autocommit:
@@ -769,6 +774,8 @@ class Connection(BaseConnection[Row]):
         """
         Return a new cursor to send commands and queries to the connection.
         """
+        self._check_connection_ok()
+
         if not row_factory:
             row_factory = self.row_factory
 
@@ -798,12 +805,13 @@ class Connection(BaseConnection[Row]):
         binary: bool = False,
     ) -> Cursor[Row]:
         """Execute a query and return a cursor to read its results."""
-        cur = self.cursor()
-        if binary:
-            cur.format = Format.BINARY
-
         try:
+            cur = self.cursor()
+            if binary:
+                cur.format = Format.BINARY
+
             return cur.execute(query, params, prepare=prepare)
+
         except e.Error as ex:
             raise ex.with_traceback(None)
 
index d892a974f7f9e17b407b68154dca704eac86dfc5..b27a44dfed4bb3bc73e6e4c5e73105d87582998d 100644 (file)
@@ -216,6 +216,8 @@ class AsyncConnection(BaseConnection[Row]):
         """
         Return a new `AsyncCursor` to send commands and queries to the connection.
         """
+        self._check_connection_ok()
+
         if not row_factory:
             row_factory = self.row_factory
 
@@ -244,12 +246,13 @@ class AsyncConnection(BaseConnection[Row]):
         prepare: Optional[bool] = None,
         binary: bool = False,
     ) -> AsyncCursor[Row]:
-        cur = self.cursor()
-        if binary:
-            cur.format = Format.BINARY
-
         try:
+            cur = self.cursor()
+            if binary:
+                cur.format = Format.BINARY
+
             return await cur.execute(query, params, prepare=prepare)
+
         except e.Error as ex:
             raise ex.with_traceback(None)
 
index f60eb4641ecbe5c14e5d7b70dde035a92635d357..84cb81f2577a2941f9eaa50da2bfed692faa06a5 100644 (file)
@@ -70,13 +70,14 @@ def test_connect_timeout():
 def test_close(conn):
     assert not conn.closed
     assert not conn.broken
+
+    cur = conn.cursor()
+
     conn.close()
     assert conn.closed
     assert not conn.broken
     assert conn.pgconn.status == conn.ConnStatus.BAD
 
-    cur = conn.cursor()
-
     conn.close()
     assert conn.closed
     assert conn.pgconn.status == conn.ConnStatus.BAD
@@ -97,6 +98,15 @@ def test_broken(conn):
     assert conn.broken
 
 
+def test_cursor_closed(conn):
+    conn.close()
+    with pytest.raises(psycopg.OperationalError):
+        with conn.cursor("foo"):
+            pass
+    with pytest.raises(psycopg.OperationalError):
+        conn.cursor()
+
+
 def test_connection_warn_close(dsn, recwarn):
     conn = Connection.connect(dsn)
     conn.close()
index a32cc2f3b28d8270a84cdb5f92f6e97e5faa071f..705b5f931468070bd26c1829ff8865f60bdffba3 100644 (file)
@@ -72,13 +72,14 @@ async def test_connect_timeout():
 async def test_close(aconn):
     assert not aconn.closed
     assert not aconn.broken
+
+    cur = aconn.cursor()
+
     await aconn.close()
     assert aconn.closed
     assert not aconn.broken
     assert aconn.pgconn.status == aconn.ConnStatus.BAD
 
-    cur = aconn.cursor()
-
     await aconn.close()
     assert aconn.closed
     assert aconn.pgconn.status == aconn.ConnStatus.BAD
@@ -99,6 +100,16 @@ async def test_broken(aconn):
     assert aconn.broken
 
 
+async def test_cursor_closed(aconn):
+    await aconn.close()
+    with pytest.raises(psycopg.OperationalError):
+        async with aconn.cursor("foo"):
+            pass
+        aconn.cursor("foo")
+    with pytest.raises(psycopg.OperationalError):
+        aconn.cursor()
+
+
 async def test_connection_warn_close(dsn, recwarn):
     conn = await AsyncConnection.connect(dsn)
     await conn.close()