From: Daniele Varrazzo Date: Tue, 7 Dec 2021 20:01:13 +0000 (+0100) Subject: Raise OperationalError on Connection.cursor() if closed X-Git-Tag: pool-3.1~91 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=bfcf726e8bdae0e79361d9d6991a108f25f4f5d0;p=thirdparty%2Fpsycopg.git Raise OperationalError on Connection.cursor() if closed --- diff --git a/docs/news.rst b/docs/news.rst index 0748b9fdd..f3c13193b 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -19,8 +19,9 @@ Current release Psycopg 3.0.6 ^^^^^^^^^^^^^ -- `ServerCursor.close()` doesn't raise exceptions if the connection is closed +- Don't raise exceptions on `ServerCursor.close()` if the connection is closed (:ticket:`#173`). +- Fail on `Connection.cursor()` if the connection is closed (:ticket:`#174`). Psycopg 3.0.5 diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index 652252d2a..7cf508428 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -411,13 +411,7 @@ class BaseConnection(Generic[Row]): Only used to implement internal commands such as "commit", with eventual arguments bound client-side. The cursor can do more complex stuff. """ - if self.pgconn.status != ConnStatus.OK: - if self.pgconn.status == ConnStatus.BAD: - raise e.OperationalError("the connection is closed") - raise e.InterfaceError( - f"cannot execute operations: the connection is" - f" in status {self.pgconn.status}" - ) + self._check_connection_ok() if isinstance(command, str): command = command.encode(pgconn_encoding(self.pgconn)) @@ -444,6 +438,17 @@ class BaseConnection(Generic[Row]): ) return result + def _check_connection_ok(self) -> None: + if self.pgconn.status == ConnStatus.OK: + return + + if self.pgconn.status == ConnStatus.BAD: + raise e.OperationalError("the connection is closed") + raise e.InterfaceError( + f"cannot execute operations: the connection is" + f" in status {self.pgconn.status}" + ) + def _start_query(self) -> PQGen[None]: """Generator to start a transaction if necessary.""" if self._autocommit: @@ -769,6 +774,8 @@ class Connection(BaseConnection[Row]): """ Return a new cursor to send commands and queries to the connection. """ + self._check_connection_ok() + if not row_factory: row_factory = self.row_factory @@ -798,12 +805,13 @@ class Connection(BaseConnection[Row]): binary: bool = False, ) -> Cursor[Row]: """Execute a query and return a cursor to read its results.""" - cur = self.cursor() - if binary: - cur.format = Format.BINARY - try: + cur = self.cursor() + if binary: + cur.format = Format.BINARY + return cur.execute(query, params, prepare=prepare) + except e.Error as ex: raise ex.with_traceback(None) diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index d892a974f..b27a44dfe 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -216,6 +216,8 @@ class AsyncConnection(BaseConnection[Row]): """ Return a new `AsyncCursor` to send commands and queries to the connection. """ + self._check_connection_ok() + if not row_factory: row_factory = self.row_factory @@ -244,12 +246,13 @@ class AsyncConnection(BaseConnection[Row]): prepare: Optional[bool] = None, binary: bool = False, ) -> AsyncCursor[Row]: - cur = self.cursor() - if binary: - cur.format = Format.BINARY - try: + cur = self.cursor() + if binary: + cur.format = Format.BINARY + return await cur.execute(query, params, prepare=prepare) + except e.Error as ex: raise ex.with_traceback(None) diff --git a/tests/test_connection.py b/tests/test_connection.py index f60eb4641..84cb81f25 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -70,13 +70,14 @@ def test_connect_timeout(): def test_close(conn): assert not conn.closed assert not conn.broken + + cur = conn.cursor() + conn.close() assert conn.closed assert not conn.broken assert conn.pgconn.status == conn.ConnStatus.BAD - cur = conn.cursor() - conn.close() assert conn.closed assert conn.pgconn.status == conn.ConnStatus.BAD @@ -97,6 +98,15 @@ def test_broken(conn): assert conn.broken +def test_cursor_closed(conn): + conn.close() + with pytest.raises(psycopg.OperationalError): + with conn.cursor("foo"): + pass + with pytest.raises(psycopg.OperationalError): + conn.cursor() + + def test_connection_warn_close(dsn, recwarn): conn = Connection.connect(dsn) conn.close() diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index a32cc2f3b..705b5f931 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -72,13 +72,14 @@ async def test_connect_timeout(): async def test_close(aconn): assert not aconn.closed assert not aconn.broken + + cur = aconn.cursor() + await aconn.close() assert aconn.closed assert not aconn.broken assert aconn.pgconn.status == aconn.ConnStatus.BAD - cur = aconn.cursor() - await aconn.close() assert aconn.closed assert aconn.pgconn.status == aconn.ConnStatus.BAD @@ -99,6 +100,16 @@ async def test_broken(aconn): assert aconn.broken +async def test_cursor_closed(aconn): + await aconn.close() + with pytest.raises(psycopg.OperationalError): + async with aconn.cursor("foo"): + pass + aconn.cursor("foo") + with pytest.raises(psycopg.OperationalError): + aconn.cursor() + + async def test_connection_warn_close(dsn, recwarn): conn = await AsyncConnection.connect(dsn) await conn.close()