]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added connection.status property
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 11 Apr 2020 11:36:16 +0000 (23:36 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 11 Apr 2020 11:36:16 +0000 (23:36 +1200)
Exposing ConnStatus and TransactionStatus on the connection object.

psycopg3/connection.py
tests/test_async_connection.py
tests/test_connection.py

index 6d666cf1623cc00566b07f2ed16a263f0ea1dcfa..094f89695e33549aa8f1e49efc3a82831d575919 100644 (file)
@@ -47,6 +47,10 @@ class BaseConnection:
     ProgrammingError = e.ProgrammingError
     NotSupportedError = e.NotSupportedError
 
+    # Enums useful for the connection
+    ConnStatus = pq.ConnStatus
+    TransactionStatus = pq.TransactionStatus
+
     def __init__(self, pgconn: pq.PGconn):
         self.pgconn = pgconn
         self.cursor_factory = cursor.BaseCursor
@@ -60,7 +64,11 @@ class BaseConnection:
 
     @property
     def closed(self) -> bool:
-        return self.pgconn.status == pq.ConnStatus.BAD
+        return self.status == self.ConnStatus.BAD
+
+    @property
+    def status(self) -> pq.ConnStatus:
+        return self.pgconn.status
 
     def cursor(
         self, name: Optional[str] = None, binary: bool = False
@@ -116,7 +124,7 @@ class BaseConnection:
         conn = pq.PGconn.connect_start(conninfo.encode("utf8"))
         logger.debug("connection started, status %s", conn.status.name)
         while 1:
-            if conn.status == pq.ConnStatus.BAD:
+            if conn.status == cls.ConnStatus.BAD:
                 raise e.OperationalError(
                     f"connection is bad: {pq.error_message(conn)}"
                 )
index 463a85404aac729f5ecf28a45fee0303cb357de8..b639ba9b98ed54b267dc7166179e0f400946c911 100644 (file)
@@ -4,9 +4,9 @@ import psycopg3
 from psycopg3 import AsyncConnection
 
 
-def test_connect(pq, dsn, loop):
+def test_connect(dsn, loop):
     conn = loop.run_until_complete(AsyncConnection.connect(dsn))
-    assert conn.pgconn.status == pq.ConnStatus.OK
+    assert conn.status == conn.ConnStatus.OK
 
 
 def test_connect_bad(loop):
@@ -14,22 +14,24 @@ def test_connect_bad(loop):
         loop.run_until_complete(AsyncConnection.connect("dbname=nosuchdb"))
 
 
-def test_close(pq, aconn):
+def test_close(aconn):
     assert not aconn.closed
     aconn.close()
     assert aconn.closed
+    assert aconn.status == aconn.ConnStatus.BAD
     aconn.close()
     assert aconn.closed
+    assert aconn.status == aconn.ConnStatus.BAD
 
 
-def test_commit(loop, pq, aconn):
+def test_commit(loop, aconn):
     aconn.pgconn.exec_(b"drop table if exists foo")
     aconn.pgconn.exec_(b"create table foo (id int primary key)")
     aconn.pgconn.exec_(b"begin")
-    assert aconn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
     res = aconn.pgconn.exec_(b"insert into foo values (1)")
     loop.run_until_complete(aconn.commit())
-    assert aconn.pgconn.transaction_status == pq.TransactionStatus.IDLE
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
     res = aconn.pgconn.exec_(b"select id from foo where id = 1")
     assert res.get_value(0, 0) == b"1"
 
@@ -38,14 +40,14 @@ def test_commit(loop, pq, aconn):
         loop.run_until_complete(aconn.commit())
 
 
-def test_rollback(loop, pq, aconn):
+def test_rollback(loop, aconn):
     aconn.pgconn.exec_(b"drop table if exists foo")
     aconn.pgconn.exec_(b"create table foo (id int primary key)")
     aconn.pgconn.exec_(b"begin")
-    assert aconn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
     res = aconn.pgconn.exec_(b"insert into foo values (1)")
     loop.run_until_complete(aconn.rollback())
-    assert aconn.pgconn.transaction_status == pq.TransactionStatus.IDLE
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
     res = aconn.pgconn.exec_(b"select id from foo where id = 1")
     assert res.get_value(0, 0) is None
 
index 61f3eed85b4fcd6c552cee0262f886b0fce17378..562dcfab77b2852a5a19e7e242afa7671e74a3bf 100644 (file)
@@ -4,9 +4,9 @@ import psycopg3
 from psycopg3 import Connection
 
 
-def test_connect(pq, dsn):
+def test_connect(dsn):
     conn = Connection.connect(dsn)
-    assert conn.pgconn.status == pq.ConnStatus.OK
+    assert conn.status == conn.ConnStatus.OK
 
 
 def test_connect_bad():
@@ -14,22 +14,24 @@ def test_connect_bad():
         Connection.connect("dbname=nosuchdb")
 
 
-def test_close(pq, conn):
+def test_close(conn):
     assert not conn.closed
     conn.close()
     assert conn.closed
+    assert conn.status == conn.ConnStatus.BAD
     conn.close()
     assert conn.closed
+    assert conn.status == conn.ConnStatus.BAD
 
 
-def test_commit(pq, conn):
+def test_commit(conn):
     conn.pgconn.exec_(b"drop table if exists foo")
     conn.pgconn.exec_(b"create table foo (id int primary key)")
     conn.pgconn.exec_(b"begin")
-    assert conn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
     res = conn.pgconn.exec_(b"insert into foo values (1)")
     conn.commit()
-    assert conn.pgconn.transaction_status == pq.TransactionStatus.IDLE
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
     res = conn.pgconn.exec_(b"select id from foo where id = 1")
     assert res.get_value(0, 0) == b"1"
 
@@ -38,14 +40,14 @@ def test_commit(pq, conn):
         conn.commit()
 
 
-def test_rollback(pq, conn):
+def test_rollback(conn):
     conn.pgconn.exec_(b"drop table if exists foo")
     conn.pgconn.exec_(b"create table foo (id int primary key)")
     conn.pgconn.exec_(b"begin")
-    assert conn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
     res = conn.pgconn.exec_(b"insert into foo values (1)")
     conn.rollback()
-    assert conn.pgconn.transaction_status == pq.TransactionStatus.IDLE
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
     res = conn.pgconn.exec_(b"select id from foo where id = 1")
     assert res.get_value(0, 0) is None