]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Don't raise an exception using cur.description with closed connection
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 7 Dec 2021 20:31:07 +0000 (21:31 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 7 Dec 2021 21:21:52 +0000 (22:21 +0100)
Close #172.

docs/news.rst
psycopg/psycopg/_column.py
psycopg/psycopg/cursor.py
psycopg/psycopg/server_cursor.py
tests/test_cursor.py

index f3c13193b05466b6fa3c8ed46c1a92093d02dc58..8d254b12efc1332503bef953813e6de32e690beb 100644 (file)
@@ -19,6 +19,8 @@ Current release
 Psycopg 3.0.6
 ^^^^^^^^^^^^^
 
+- Allow to use `Cursor.description` if the connection is closed
+  (:ticket:`#172`).
 - Don't raise exceptions on `ServerCursor.close()` if the connection is closed
   (:ticket:`#173`).
 - Fail on `Connection.cursor()` if the connection is closed (:ticket:`#174`).
index 60773f9de86d713ac4f73efa5ed4d985b6f4147c..909106d2e20ab6a21a32ed869dd7d7149e6a19e3 100644 (file)
@@ -8,7 +8,6 @@ from typing import Any, NamedTuple, Optional, Sequence, TYPE_CHECKING
 from operator import attrgetter
 
 from . import errors as e
-from ._encodings import pgconn_encoding
 
 if TYPE_CHECKING:
     from .cursor import BaseCursor
@@ -32,7 +31,7 @@ class Column(Sequence[Any]):
         if not fname:
             raise e.InterfaceError(f"no name available for column {index}")
 
-        self._name = fname.decode(pgconn_encoding(cursor._pgconn))
+        self._name = fname.decode(cursor._encoding)
 
         self._data = ColumnData(
             ftype=res.ftype(index),
index 93f5d8fa968b233d04c5b2de5a6ab2a5ce449b98..298638649b337321cd096f1d80c388c0af502710 100644 (file)
@@ -53,7 +53,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
         __slots__ = """
             _conn format _adapters arraysize _closed _results pgresult _pos
             _iresult _rowcount _query _tx _last_query _row_factory _make_row
-            _pgconn
+            _pgconn _encoding
             __weakref__
             """.split()
 
@@ -80,6 +80,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
         self._iresult = 0
         self._rowcount = -1
         self._query: Optional[PostgresQuery]
+        self._encoding = "utf-8"
         if reset_query:
             self._query = None
 
@@ -254,9 +255,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
             self._send_prepare(name, pgq)
             (result,) = yield from execute(self._pgconn)
             if result.status == ExecStatus.FATAL_ERROR:
-                raise e.error_from_result(
-                    result, encoding=pgconn_encoding(self._pgconn)
-                )
+                raise e.error_from_result(result, encoding=self._encoding)
             self._send_query_prepared(name, pgq, binary=binary)
 
         # run the query
@@ -323,6 +322,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
             raise e.InterfaceError("the cursor is closed")
 
         self._reset()
+        self._encoding = pgconn_encoding(self._pgconn)
         if not self._last_query or (self._last_query is not query):
             self._last_query = None
             self._tx = adapt.Transformer(self)
@@ -429,9 +429,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
         statuses = {res.status for res in results}
         badstats = statuses.difference(self._status_ok)
         if results[-1].status == ExecStatus.FATAL_ERROR:
-            raise e.error_from_result(
-                results[-1], encoding=pgconn_encoding(self._pgconn)
-            )
+            raise e.error_from_result(results[-1], encoding=self._encoding)
         elif statuses.intersection(self._status_copy):
             raise e.ProgrammingError(
                 "COPY cannot be used with this method; use copy() insead"
@@ -476,9 +474,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
         if status in (ExecStatus.COPY_IN, ExecStatus.COPY_OUT):
             return
         elif status == ExecStatus.FATAL_ERROR:
-            raise e.error_from_result(
-                result, encoding=pgconn_encoding(self._pgconn)
-            )
+            raise e.error_from_result(result, encoding=self._encoding)
         else:
             raise e.ProgrammingError(
                 "copy() should be used only with COPY ... TO STDOUT or COPY ..."
index 28f198bac9ec8aca27a0bc47bef3b8bc416bce7a..b4289a67373375f5348acfcd1338764d4a2a968c 100644 (file)
@@ -15,7 +15,6 @@ from .abc import ConnectionType, Query, Params, PQGen
 from .rows import Row, RowFactory, AsyncRowFactory
 from .cursor import BaseCursor, Cursor, execute
 from .cursor_async import AsyncCursor
-from ._encodings import pgconn_encoding
 
 if TYPE_CHECKING:
     from .connection import Connection
@@ -59,8 +58,7 @@ class ServerCursorHelper(Generic[ConnectionType, Row]):
     ) -> PQGen[None]:
         """Generator implementing `ServerCursor.execute()`."""
 
-        conn = cur._conn
-        query = self._make_declare_statement(conn, query)
+        query = self._make_declare_statement(cur, query)
 
         # If the cursor is being reused, the previous one must be closed.
         if self.described:
@@ -70,7 +68,7 @@ class ServerCursorHelper(Generic[ConnectionType, Row]):
         yield from cur._start_query(query)
         pgq = cur._convert_query(query, params)
         cur._execute_send(pgq, no_pqexec=True)
-        results = yield from execute(conn.pgconn)
+        results = yield from execute(cur._conn.pgconn)
         if results[-1].status != pq.ExecStatus.COMMAND_OK:
             cur._raise_from_results(results)
 
@@ -87,9 +85,7 @@ class ServerCursorHelper(Generic[ConnectionType, Row]):
         self, cur: BaseCursor[ConnectionType, Row]
     ) -> PQGen[None]:
         conn = cur._conn
-        conn.pgconn.send_describe_portal(
-            self.name.encode(pgconn_encoding(conn.pgconn))
-        )
+        conn.pgconn.send_describe_portal(self.name.encode(cur._encoding))
         results = yield from execute(conn.pgconn)
         cur._execute_results(results, format=self._format)
         self.described = True
@@ -155,11 +151,11 @@ class ServerCursorHelper(Generic[ConnectionType, Row]):
         yield from cur._conn._exec_command(query)
 
     def _make_declare_statement(
-        self, conn: ConnectionType, query: Query
+        self, cur: BaseCursor[ConnectionType, Row], query: Query
     ) -> sql.Composable:
 
         if isinstance(query, bytes):
-            query = query.decode(pgconn_encoding(conn.pgconn))
+            query = query.decode(cur._encoding)
         if not isinstance(query, sql.Composable):
             query = sql.SQL(query)
 
index e850a2cb3f318d763c22646deff75dba6580ce0f..0874aede972e2ab07f568cc151febb2d38143fd0 100644 (file)
@@ -679,6 +679,16 @@ class TestColumn:
         assert cur.description == []
         assert cur.fetchall() == [()]
 
+    def test_description_closed_connection(self, conn):
+        # If we have reasons to break this test we will (e.g. we really need
+        # the connection). In #172 it fails just by accident.
+        cur = conn.execute("select 1::int4 as foo")
+        conn.close()
+        assert len(cur.description) == 1
+        col = cur.description[0]
+        assert col.name == "foo"
+        assert col.type_code == 23
+
 
 def test_str(conn):
     cur = conn.cursor()