]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Raise the correct error (diag and all) on commit error
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 2 Dec 2020 03:35:42 +0000 (03:35 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 2 Dec 2020 03:40:14 +0000 (03:40 +0000)
Also added other error-related tests from psycopg2 test suite.

psycopg3/psycopg3/connection.py
tests/test_errors.py

index f370f0be1733b7708cb5835c54b8a9c5afb8b9f9..a11419fe28c8c37eaad2480da3aac0f7377d0d61 100644 (file)
@@ -329,12 +329,17 @@ class Connection(BaseConnection):
             command = command.as_string(self).encode(self.client_encoding)
 
         self.pgconn.send_query(command)
-        results = self.wait(execute(self.pgconn))
-        if results[-1].status != ExecStatus.COMMAND_OK:
-            raise e.OperationalError(
-                f"error on {command.decode('utf8')}:"
-                f" {pq.error_message(results[-1], encoding=self.client_encoding)}"
-            )
+        result = self.wait(execute(self.pgconn))[-1]
+        if result.status != ExecStatus.COMMAND_OK:
+            if result.status == ExecStatus.FATAL_ERROR:
+                raise e.error_from_result(
+                    result, encoding=self.client_encoding
+                )
+            else:
+                raise e.InterfaceError(
+                    f"unexpected result {ExecStatus(result.status).name}"
+                    f" from command {command.decode('utf8')!r}"
+                )
 
     @contextmanager
     def transaction(
@@ -487,12 +492,17 @@ class AsyncConnection(BaseConnection):
             command = command.as_string(self).encode(self.client_encoding)
 
         self.pgconn.send_query(command)
-        results = await self.wait(execute(self.pgconn))
-        if results[-1].status != ExecStatus.COMMAND_OK:
-            raise e.OperationalError(
-                f"error on {command.decode('utf8')}:"
-                f" {pq.error_message(results[-1], encoding=self.client_encoding)}"
-            )
+        result = (await self.wait(execute(self.pgconn)))[-1]
+        if result.status != ExecStatus.COMMAND_OK:
+            if result.status == ExecStatus.FATAL_ERROR:
+                raise e.error_from_result(
+                    result, encoding=self.client_encoding
+                )
+            else:
+                raise e.InterfaceError(
+                    f"unexpected result {ExecStatus(result.status).name}"
+                    f" from command {command.decode('utf8')!r}"
+                )
 
     @asynccontextmanager
     async def transaction(
index f6da644a4029cfc575befc4ce3db13a72e7fa322..63792a215c8e55d4dfefc22d3cc92060f79cf672 100644 (file)
@@ -1,4 +1,6 @@
+import gc
 import pickle
+from weakref import ref
 
 import pytest
 
@@ -47,6 +49,24 @@ def test_diag_right_attr(pgconn, monkeypatch):
     assert len(checked) == len(pq.DiagnosticField)
 
 
+def test_diag_attrs_9_6(conn):
+    cur = conn.cursor()
+    cur.execute(
+        """
+        create temp table test_exc (
+            data int constraint chk_eq1 check (data = 1)
+        )"""
+    )
+    with pytest.raises(e.Error) as exc:
+        cur.execute("insert into test_exc values(2)")
+    diag = exc.value.diag
+    assert diag.sqlstate == "23514"
+    assert diag.schema_name[:7] == "pg_temp"
+    assert diag.table_name == "test_exc"
+    assert diag.constraint_name == "chk_eq1"
+    assert diag.severity_nonlocalized == "ERROR"
+
+
 @pytest.mark.parametrize("enc", ["utf8", "latin9"])
 def test_diag_encoding(conn, enc):
     msgs = []
@@ -131,3 +151,66 @@ def test_diag_pickle(conn):
         assert getattr(diag1, f.name.lower()) == getattr(diag2, f.name.lower())
 
     assert diag2.sqlstate == "42P01"
+
+
+def test_diag_survives_cursor(conn):
+    cur = conn.cursor()
+    with pytest.raises(e.Error) as exc:
+        cur.execute("select * from nosuchtable")
+
+    diag = exc.value.diag
+    del exc
+    w = ref(cur)
+    del cur
+    gc.collect()
+    assert w() is None
+    assert diag.sqlstate == "42P01"
+
+
+def test_diag_independent(conn):
+    conn.autocommit = True
+    cur = conn.cursor()
+
+    with pytest.raises(e.Error) as exc1:
+        cur.execute("l'acqua e' poca e 'a papera nun galleggia")
+
+    with pytest.raises(e.Error) as exc2:
+        cur.execute("select level from water where ducks > 1")
+
+    assert exc1.value.diag.sqlstate == "42601"
+    assert exc2.value.diag.sqlstate == "42P01"
+
+
+def test_diag_from_commit(conn):
+    cur = conn.cursor()
+    cur.execute(
+        """
+        create temp table test_deferred (
+           data int primary key,
+           ref int references test_deferred (data)
+               deferrable initially deferred)
+    """
+    )
+    cur.execute("insert into test_deferred values (1,2)")
+    with pytest.raises(e.Error) as exc:
+        conn.commit()
+
+    assert exc.value.diag.sqlstate == "23503"
+
+
+@pytest.mark.asyncio
+async def test_diag_from_commit_async(aconn):
+    cur = await aconn.cursor()
+    await cur.execute(
+        """
+        create temp table test_deferred (
+           data int primary key,
+           ref int references test_deferred (data)
+               deferrable initially deferred)
+    """
+    )
+    await cur.execute("insert into test_deferred values (1,2)")
+    with pytest.raises(e.Error) as exc:
+        await aconn.commit()
+
+    assert exc.value.diag.sqlstate == "23503"