]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Raise TypeError attempting to use a Copy context more than once
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 24 Nov 2020 12:46:41 +0000 (12:46 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 24 Nov 2020 12:46:41 +0000 (12:46 +0000)
See #10

psycopg3/psycopg3/copy.py
psycopg3/psycopg3/transaction.py
tests/test_copy.py
tests/test_copy_async.py

index 931bddb0a5f49de15660b8aa53bf024c216668bb..2fcbe7ca42940d33526243027799c5b2a48ca5f6 100644 (file)
@@ -172,7 +172,8 @@ class Copy(BaseCopy["Connection"]):
         self._finished = True
 
     def __enter__(self) -> "Copy":
-        assert not self._finished
+        if self._finished:
+            raise TypeError("copy blocks can be used only once")
         return self
 
     def __exit__(
@@ -240,7 +241,8 @@ class AsyncCopy(BaseCopy["AsyncConnection"]):
         self._finished = True
 
     async def __aenter__(self) -> "AsyncCopy":
-        assert not self._finished
+        if self._finished:
+            raise TypeError("copy blocks can be used only once")
         return self
 
     async def __aexit__(
index 14c6adaba707c146d98d4f1a47a751732fab7374..9e279688d786f35123c2e734e146075b39f13a9e 100644 (file)
@@ -76,7 +76,7 @@ class BaseTransaction(Generic[ConnectionType]):
 
     def _enter_commands(self) -> List[str]:
         if not self._yolo:
-            raise TypeError("transaction blocks cannot be use more than once")
+            raise TypeError("transaction blocks can be used only once")
         else:
             self._yolo = False
 
index 8beec2f103bf5b6517083d09349cafa837e0f136..2ad60d2433197110cae024672e47ba781c6aca3d 100644 (file)
@@ -299,6 +299,16 @@ def test_copy_query(conn):
         list(copy)
 
 
+def test_cant_reenter(conn):
+    cur = conn.cursor()
+    with cur.copy("copy (select 1) to stdout") as copy:
+        list(copy)
+
+    with pytest.raises(TypeError):
+        with copy:
+            list(copy)
+
+
 def ensure_table(cur, tabledef, name="copy_in"):
     cur.execute(f"drop table if exists {name}")
     cur.execute(f"create table {name} ({tabledef})")
index 3bbd8f25f73cd97e97cc8bcb7b84e289b2fc41cf..ac25b1d20268550bb66e8c063b02be0d5844fcce 100644 (file)
@@ -288,6 +288,18 @@ async def test_copy_query(aconn):
             pass
 
 
+async def test_cant_reenter(aconn):
+    cur = await aconn.cursor()
+    async with cur.copy("copy (select 1) to stdout") as copy:
+        async for record in copy:
+            pass
+
+    with pytest.raises(TypeError):
+        async with copy:
+            async for record in copy:
+                pass
+
+
 async def ensure_table(cur, tabledef, name="copy_in"):
     await cur.execute(f"drop table if exists {name}")
     await cur.execute(f"create table {name} ({tabledef})")