]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added connection autocommit
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 20 May 2020 07:04:13 +0000 (19:04 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 20 May 2020 07:04:13 +0000 (19:04 +1200)
psycopg3/connection.py
tests/test_async_connection.py
tests/test_connection.py

index 522fa1ce362a38a906ee6684eb73ebc61437abbb..9807cb273579e93c2dedb1b7ecc8d0aefa1fa073 100644 (file)
@@ -64,6 +64,7 @@ class BaseConnection:
     def __init__(self, pgconn: pq.proto.PGconn):
         self.pgconn = pgconn
         self.cursor_factory = cursor.BaseCursor
+        self._autocommit = False
         self.dumpers: proto.DumpersMap = {}
         self.loaders: proto.LoadersMap = {}
         # name of the postgres encoding (in bytes)
@@ -77,6 +78,20 @@ class BaseConnection:
     def status(self) -> pq.ConnStatus:
         return self.pgconn.status
 
+    @property
+    def autocommit(self) -> bool:
+        return self._autocommit
+
+    @autocommit.setter
+    def autocommit(self, value: bool) -> None:
+        status = self.pgconn.transaction_status
+        if status != TransactionStatus.IDLE:
+            raise e.ProgrammingError(
+                "can't change autocommit state: connection in"
+                f" transaction status {TransactionStatus(status).name}"
+            )
+        self._autocommit = value
+
     def cursor(
         self, name: Optional[str] = None, binary: bool = False
     ) -> cursor.BaseCursor:
@@ -136,14 +151,20 @@ class Connection(BaseConnection):
 
     @classmethod
     def connect(
-        cls, conninfo: Optional[str] = None, **kwargs: Any,
+        cls,
+        conninfo: Optional[str] = None,
+        *,
+        autocommit: bool = False,
+        **kwargs: Any,
     ) -> "Connection":
         if conninfo is None and not kwargs:
             raise TypeError("missing conninfo and not parameters specified")
         conninfo = make_conninfo(conninfo or "", **kwargs)
         gen = connect(conninfo)
         pgconn = cls.wait(gen)
-        return cls(pgconn)
+        conn = cls(pgconn)
+        conn._autocommit = autocommit
+        return conn
 
     def close(self) -> None:
         self.pgconn.finish()
@@ -156,8 +177,10 @@ class Connection(BaseConnection):
 
     def _start_query(self) -> None:
         # the function is meant to be called by a cursor once the lock is taken
-        status = self.pgconn.transaction_status
-        if status == TransactionStatus.INTRANS:
+        if self._autocommit:
+            return
+
+        if self.pgconn.transaction_status == TransactionStatus.INTRANS:
             return
 
         self.pgconn.send_query(b"begin")
@@ -222,14 +245,20 @@ class AsyncConnection(BaseConnection):
 
     @classmethod
     async def connect(
-        cls, conninfo: Optional[str] = None, **kwargs: Any
+        cls,
+        conninfo: Optional[str] = None,
+        *,
+        autocommit: bool = False,
+        **kwargs: Any,
     ) -> "AsyncConnection":
         if conninfo is None and not kwargs:
             raise TypeError("missing conninfo and not parameters specified")
         conninfo = make_conninfo(conninfo or "", **kwargs)
         gen = connect(conninfo)
         pgconn = await cls.wait(gen)
-        return cls(pgconn)
+        conn = cls(pgconn)
+        conn._autocommit = autocommit
+        return conn
 
     async def close(self) -> None:
         self.pgconn.finish()
@@ -242,8 +271,10 @@ class AsyncConnection(BaseConnection):
 
     async def _start_query(self) -> None:
         # the function is meant to be called by a cursor once the lock is taken
-        status = self.pgconn.transaction_status
-        if status == TransactionStatus.INTRANS:
+        if self._autocommit:
+            return
+
+        if self.pgconn.transaction_status == TransactionStatus.INTRANS:
             return
 
         self.pgconn.send_query(b"begin")
index 8192153b9b6599d5ef73aad2ddb24804c2fc9a26..5fb596c5b375fd9d2f0de75514f9d847045b3d35 100644 (file)
@@ -95,6 +95,51 @@ def test_auto_transaction_fail(loop, aconn):
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
 
 
+def test_autocommit(loop, aconn):
+    assert aconn.autocommit is False
+    aconn.autocommit = True
+    assert aconn.autocommit
+    cur = aconn.cursor()
+    loop.run_until_complete(cur.execute("select 1"))
+    assert loop.run_until_complete(cur.fetchone()) == (1,)
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+
+
+def test_autocommit_connect(loop, dsn):
+    aconn = loop.run_until_complete(
+        psycopg3.AsyncConnection.connect(dsn, autocommit=True)
+    )
+    assert aconn.autocommit
+
+
+def test_autocommit_intrans(loop, aconn):
+    cur = aconn.cursor()
+    loop.run_until_complete(cur.execute("select 1"))
+    assert loop.run_until_complete(cur.fetchone()) == (1,)
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+    with pytest.raises(psycopg3.ProgrammingError):
+        aconn.autocommit = True
+    assert not aconn.autocommit
+
+
+def test_autocommit_inerror(loop, aconn):
+    cur = aconn.cursor()
+    with pytest.raises(psycopg3.DatabaseError):
+        loop.run_until_complete(cur.execute("meh"))
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
+    with pytest.raises(psycopg3.ProgrammingError):
+        aconn.autocommit = True
+    assert not aconn.autocommit
+
+
+def test_autocommit_unknown(loop, aconn):
+    loop.run_until_complete(aconn.close())
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.UNKNOWN
+    with pytest.raises(psycopg3.ProgrammingError):
+        aconn.autocommit = True
+    assert not aconn.autocommit
+
+
 def test_get_encoding(aconn, loop):
     cur = aconn.cursor()
     loop.run_until_complete(cur.execute("show client_encoding"))
index cfa0d9192db1c1df20819c35587020f16fb2ddec..2b6e1f49b2f44e8aa2d31e8cfcc2f1ba62784188 100644 (file)
@@ -93,6 +93,47 @@ def test_auto_transaction_fail(conn):
     assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
 
 
+def test_autocommit(conn):
+    assert conn.autocommit is False
+    conn.autocommit = True
+    assert conn.autocommit
+    cur = conn.cursor()
+    assert cur.execute("select 1").fetchone() == (1,)
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+
+
+def test_autocommit_connect(dsn):
+    conn = Connection.connect(dsn, autocommit=True)
+    assert conn.autocommit
+
+
+def test_autocommit_intrans(conn):
+    cur = conn.cursor()
+    assert cur.execute("select 1").fetchone() == (1,)
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+    with pytest.raises(psycopg3.ProgrammingError):
+        conn.autocommit = True
+    assert not conn.autocommit
+
+
+def test_autocommit_inerror(conn):
+    cur = conn.cursor()
+    with pytest.raises(psycopg3.DatabaseError):
+        cur.execute("meh")
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR
+    with pytest.raises(psycopg3.ProgrammingError):
+        conn.autocommit = True
+    assert not conn.autocommit
+
+
+def test_autocommit_unknown(conn):
+    conn.close()
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.UNKNOWN
+    with pytest.raises(psycopg3.ProgrammingError):
+        conn.autocommit = True
+    assert not conn.autocommit
+
+
 def test_get_encoding(conn):
     (enc,) = conn.cursor().execute("show client_encoding").fetchone()
     assert enc == conn.encoding