]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Guard for null connections in exec functions
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 11 Apr 2020 08:13:58 +0000 (20:13 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 11 Apr 2020 08:13:58 +0000 (20:13 +1200)
psycopg3/pq/pq_ctypes.py
tests/pq/test_exec.py

index ae7b8d7b0427fe30991d0813c35c14154175c275..7cb91c0035403c0ce57bf322b82a8c08647b57b6 100644 (file)
@@ -75,8 +75,7 @@ class PGconn:
 
     @property
     def info(self) -> List["ConninfoOption"]:
-        if not self.pgconn_ptr:
-            raise PQerror("the connection is closed")
+        self._ensure_pgconn()
         opts = impl.PQconninfo(self.pgconn_ptr)
         if not opts:
             raise MemoryError("couldn't allocate connection info")
@@ -86,8 +85,7 @@ class PGconn:
             impl.PQconninfoFree(opts)
 
     def reset(self) -> None:
-        if not self.pgconn_ptr:
-            raise PQerror("the connection is no more available")
+        self._ensure_pgconn()
         impl.PQreset(self.pgconn_ptr)
 
     def reset_start(self) -> None:
@@ -149,8 +147,7 @@ class PGconn:
         return TransactionStatus(rv)
 
     def parameter_status(self, name: bytes) -> Optional[bytes]:
-        if not self.pgconn_ptr:
-            raise PQerror("the connection is closed")
+        self._ensure_pgconn()
         return impl.PQparameterStatus(self.pgconn_ptr, name)
 
     @property
@@ -188,6 +185,7 @@ class PGconn:
     def exec_(self, command: bytes) -> "PGresult":
         if not isinstance(command, bytes):
             raise TypeError(f"bytes expected, got {type(command)} instead")
+        self._ensure_pgconn()
         rv = impl.PQexec(self.pgconn_ptr, command)
         if not rv:
             raise MemoryError("couldn't allocate PGresult")
@@ -210,6 +208,7 @@ class PGconn:
         args = self._query_params_args(
             command, param_values, param_types, param_formats, result_format
         )
+        self._ensure_pgconn()
         rv = impl.PQexecParams(*args)
         if not rv:
             raise MemoryError("couldn't allocate PGresult")
@@ -226,6 +225,7 @@ class PGconn:
         args = self._query_params_args(
             command, param_values, param_types, param_formats, result_format
         )
+        self._ensure_pgconn()
         if not impl.PQsendQueryParams(*args):
             raise PQerror(
                 f"sending query and params failed: {error_message(self)}"
@@ -304,6 +304,7 @@ class PGconn:
             nparams = len(param_types)
             atypes = (impl.Oid * nparams)(*param_types)
 
+        self._ensure_pgconn()
         rv = impl.PQprepare(self.pgconn_ptr, name, command, nparams, atypes)
         if not rv:
             raise MemoryError("couldn't allocate PGresult")
@@ -338,6 +339,7 @@ class PGconn:
                 )
             aformats = (c_int * nparams)(*param_formats)
 
+        self._ensure_pgconn()
         rv = impl.PQexecPrepared(
             self.pgconn_ptr,
             name,
@@ -354,6 +356,7 @@ class PGconn:
     def describe_prepared(self, name: bytes) -> "PGresult":
         if not isinstance(name, bytes):
             raise TypeError(f"'name' must be bytes, got {type(name)} instead")
+        self._ensure_pgconn()
         rv = impl.PQdescribePrepared(self.pgconn_ptr, name)
         if not rv:
             raise MemoryError("couldn't allocate PGresult")
@@ -362,6 +365,7 @@ class PGconn:
     def describe_portal(self, name: bytes) -> "PGresult":
         if not isinstance(name, bytes):
             raise TypeError(f"'name' must be bytes, got {type(name)} instead")
+        self._ensure_pgconn()
         rv = impl.PQdescribePortal(self.pgconn_ptr, name)
         if not rv:
             raise MemoryError("couldn't allocate PGresult")
@@ -427,6 +431,10 @@ class PGconn:
             raise PQerror("the connection is closed")
         return bool(func(self.pgconn_ptr))
 
+    def _ensure_pgconn(self) -> None:
+        if not self.pgconn_ptr:
+            raise PQerror("the connection is closed")
+
 
 class PGresult:
     __slots__ = ("pgresult_ptr",)
index af8322e3ed8e00eac89ba40bb3300b0e0d598a42..d575d6b063957b2530ea8b94453c6c1a9dde3d62 100644 (file)
@@ -2,6 +2,8 @@
 
 import pytest
 
+import psycopg3
+
 
 def test_exec_none(pq, pgconn):
     with pytest.raises(TypeError):
@@ -11,12 +13,18 @@ def test_exec_none(pq, pgconn):
 def test_exec(pq, pgconn):
     res = pgconn.exec_(b"select 'hel' || 'lo'")
     assert res.get_value(0, 0) == b"hello"
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        pgconn.exec_(b"select 'hello'")
 
 
 def test_exec_params(pq, pgconn):
     res = pgconn.exec_params(b"select $1::int + $2", [b"5", b"3"])
     assert res.status == pq.ExecStatus.TUPLES_OK
     assert res.get_value(0, 0) == b"8"
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        pgconn.exec_params(b"select $1::int + $2", [b"5", b"3"])
 
 
 def test_exec_params_empty(pq, pgconn):
@@ -81,6 +89,12 @@ def test_prepare(pq, pgconn):
     res = pgconn.exec_prepared(b"prep", [b"3", b"5"])
     assert res.get_value(0, 0) == b"8"
 
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        pgconn.prepare(b"prep", b"select $1::int + $2::int")
+    with pytest.raises(psycopg3.OperationalError):
+        pgconn.exec_prepared(b"prep", [b"3", b"5"])
+
 
 def test_prepare_types(pq, pgconn):
     res = pgconn.prepare(b"prep", b"select $1 + $2", [23, 23])
@@ -131,3 +145,7 @@ def test_describe_portal(pq, pgconn):
     assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
     assert res.nfields == 1
     assert res.fname(0) == b"foo"
+
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        pgconn.describe_portal(b"cur")