]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added .close() and .closed on connection and cursor
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 11 Apr 2020 09:12:43 +0000 (21:12 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 11 Apr 2020 09:12:43 +0000 (21:12 +1200)
psycopg3/connection.py
psycopg3/cursor.py
tests/test_async_connection.py
tests/test_async_cursor.py
tests/test_connection.py
tests/test_cursor.py

index 1ce4f62b2fb42a8938928cbd8d34b5adcb1d3307..6d666cf1623cc00566b07f2ed16a263f0ea1dcfa 100644 (file)
@@ -58,6 +58,10 @@ class BaseConnection:
     def close(self) -> None:
         self.pgconn.finish()
 
+    @property
+    def closed(self) -> bool:
+        return self.pgconn.status == pq.ConnStatus.BAD
+
     def cursor(
         self, name: Optional[str] = None, binary: bool = False
     ) -> cursor.BaseCursor:
index d4309b7c2ea41136e25ea290ac66601b25c9108e..2cf593e2710f98a9785c6bd29a463600f2fea91e 100644 (file)
@@ -65,6 +65,17 @@ class BaseCursor:
         self.loaders: LoadersMap = {}
         self._reset()
         self.arraysize = 1
+        self._closed = False
+
+    def __del__(self):
+        self.close()
+
+    def close(self) -> None:
+        self._closed = True
+
+    @property
+    def closed(self) -> bool:
+        return self._closed
 
     def _reset(self) -> None:
         from .adapt import Transformer
@@ -114,17 +125,20 @@ class BaseCursor:
     def _execute_send(
         self, query: Query, vars: Optional[Params]
     ) -> "QueryGen":
-        # Implement part of execute() before waiting common to sync and async
+        """
+        Implement part of execute() before waiting common to sync and async
+        """
+        if self.closed:
+            raise e.OperationalError("the cursor is closed")
+
+        if self.conn.closed:
+            raise e.OperationalError("the connection is closed")
+
         if self.conn.pgconn.status != ConnStatus.OK:
-            if self.conn.pgconn.status == ConnStatus.BAD:
-                raise e.InterfaceError(
-                    "cannot execute operations: the connection is closed"
-                )
-            else:
-                raise e.InterfaceError(
-                    f"cannot execute operations: the connection is"
-                    f" in status {self.conn.pgconn.status}"
-                )
+            raise e.InterfaceError(
+                f"cannot execute operations: the connection is"
+                f" in status {self.conn.pgconn.status}"
+            )
 
         self._reset()
 
@@ -162,7 +176,9 @@ class BaseCursor:
         return self.conn._exec_gen(self.conn.pgconn)
 
     def _execute_results(self, results: List[PGresult]) -> None:
-        # Implement part of execute() after waiting common to sync and async
+        """
+        Implement part of execute() after waiting common to sync and async
+        """
         if not results:
             raise e.InternalError("got no result from the query")
 
index e27b981e6992f59c80a0a27191cb8d06fcd308f0..463a85404aac729f5ecf28a45fee0303cb357de8 100644 (file)
@@ -14,6 +14,14 @@ def test_connect_bad(loop):
         loop.run_until_complete(AsyncConnection.connect("dbname=nosuchdb"))
 
 
+def test_close(pq, aconn):
+    assert not aconn.closed
+    aconn.close()
+    assert aconn.closed
+    aconn.close()
+    assert aconn.closed
+
+
 def test_commit(loop, pq, aconn):
     aconn.pgconn.exec_(b"drop table if exists foo")
     aconn.pgconn.exec_(b"create table foo (id int primary key)")
@@ -25,6 +33,10 @@ def test_commit(loop, pq, aconn):
     res = aconn.pgconn.exec_(b"select id from foo where id = 1")
     assert res.get_value(0, 0) == b"1"
 
+    aconn.close()
+    with pytest.raises(psycopg3.OperationalError):
+        loop.run_until_complete(aconn.commit())
+
 
 def test_rollback(loop, pq, aconn):
     aconn.pgconn.exec_(b"drop table if exists foo")
@@ -37,6 +49,10 @@ def test_rollback(loop, pq, aconn):
     res = aconn.pgconn.exec_(b"select id from foo where id = 1")
     assert res.get_value(0, 0) is None
 
+    aconn.close()
+    with pytest.raises(psycopg3.OperationalError):
+        loop.run_until_complete(aconn.rollback())
+
 
 def test_get_encoding(aconn, loop):
     cur = aconn.cursor()
index e896be4b63c68b2d635ff3467aaed667fe312ab7..7c258778e1cea521810dc9bdbf49434559451928 100644 (file)
@@ -1,4 +1,21 @@
-def test_execute_many(aconn, loop):
+import pytest
+import psycopg3
+
+
+def test_close(aconn, loop):
+    cur = aconn.cursor()
+    assert not cur.closed
+    cur.close()
+    assert cur.closed
+
+    with pytest.raises(psycopg3.OperationalError):
+        loop.run_until_complete(cur.execute("select 'foo'"))
+
+    cur.close()
+    assert cur.closed
+
+
+def test_execute_many_results(aconn, loop):
     cur = aconn.cursor()
     rv = loop.run_until_complete(cur.execute("select 'foo'; select 'bar'"))
     assert rv is cur
index 89d00fb3d2563d1da9fe3a5c0d1acfbe4268fc16..61f3eed85b4fcd6c552cee0262f886b0fce17378 100644 (file)
@@ -14,6 +14,14 @@ def test_connect_bad():
         Connection.connect("dbname=nosuchdb")
 
 
+def test_close(pq, conn):
+    assert not conn.closed
+    conn.close()
+    assert conn.closed
+    conn.close()
+    assert conn.closed
+
+
 def test_commit(pq, conn):
     conn.pgconn.exec_(b"drop table if exists foo")
     conn.pgconn.exec_(b"create table foo (id int primary key)")
@@ -25,6 +33,10 @@ def test_commit(pq, conn):
     res = conn.pgconn.exec_(b"select id from foo where id = 1")
     assert res.get_value(0, 0) == b"1"
 
+    conn.close()
+    with pytest.raises(psycopg3.OperationalError):
+        conn.commit()
+
 
 def test_rollback(pq, conn):
     conn.pgconn.exec_(b"drop table if exists foo")
@@ -37,6 +49,10 @@ def test_rollback(pq, conn):
     res = conn.pgconn.exec_(b"select id from foo where id = 1")
     assert res.get_value(0, 0) is None
 
+    conn.close()
+    with pytest.raises(psycopg3.OperationalError):
+        conn.rollback()
+
 
 def test_get_encoding(conn):
     (enc,) = conn.cursor().execute("show client_encoding").fetchone()
index faa46a42ccaebf51dbf296628979626eeef9423a..962bed7c2605b99327f8c24e8a77601e868fde61 100644 (file)
@@ -1,7 +1,21 @@
 import pytest
+import psycopg3
 
 
-def test_execute_many(conn):
+def test_close(conn):
+    cur = conn.cursor()
+    assert not cur.closed
+    cur.close()
+    assert cur.closed
+
+    with pytest.raises(psycopg3.OperationalError):
+        cur.execute("select 'foo'")
+
+    cur.close()
+    assert cur.closed
+
+
+def test_execute_many_results(conn):
     cur = conn.cursor()
     rv = cur.execute("select 'foo'; select 'bar'")
     assert rv is cur