From ddc3bb2d1f658792df1f18e2c9c31aa378122837 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Tue, 24 Nov 2020 12:46:41 +0000 Subject: [PATCH] Raise TypeError attempting to use a Copy context more than once See #10 --- psycopg3/psycopg3/copy.py | 6 ++++-- psycopg3/psycopg3/transaction.py | 2 +- tests/test_copy.py | 10 ++++++++++ tests/test_copy_async.py | 12 ++++++++++++ 4 files changed, 27 insertions(+), 3 deletions(-) diff --git a/psycopg3/psycopg3/copy.py b/psycopg3/psycopg3/copy.py index 931bddb0a..2fcbe7ca4 100644 --- a/psycopg3/psycopg3/copy.py +++ b/psycopg3/psycopg3/copy.py @@ -172,7 +172,8 @@ class Copy(BaseCopy["Connection"]): self._finished = True def __enter__(self) -> "Copy": - assert not self._finished + if self._finished: + raise TypeError("copy blocks can be used only once") return self def __exit__( @@ -240,7 +241,8 @@ class AsyncCopy(BaseCopy["AsyncConnection"]): self._finished = True async def __aenter__(self) -> "AsyncCopy": - assert not self._finished + if self._finished: + raise TypeError("copy blocks can be used only once") return self async def __aexit__( diff --git a/psycopg3/psycopg3/transaction.py b/psycopg3/psycopg3/transaction.py index 14c6adaba..9e279688d 100644 --- a/psycopg3/psycopg3/transaction.py +++ b/psycopg3/psycopg3/transaction.py @@ -76,7 +76,7 @@ class BaseTransaction(Generic[ConnectionType]): def _enter_commands(self) -> List[str]: if not self._yolo: - raise TypeError("transaction blocks cannot be use more than once") + raise TypeError("transaction blocks can be used only once") else: self._yolo = False diff --git a/tests/test_copy.py b/tests/test_copy.py index 8beec2f10..2ad60d243 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -299,6 +299,16 @@ def test_copy_query(conn): list(copy) +def test_cant_reenter(conn): + cur = conn.cursor() + with cur.copy("copy (select 1) to stdout") as copy: + list(copy) + + with pytest.raises(TypeError): + with copy: + list(copy) + + def ensure_table(cur, tabledef, name="copy_in"): cur.execute(f"drop table if exists {name}") cur.execute(f"create table {name} ({tabledef})") diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index 3bbd8f25f..ac25b1d20 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -288,6 +288,18 @@ async def test_copy_query(aconn): pass +async def test_cant_reenter(aconn): + cur = await aconn.cursor() + async with cur.copy("copy (select 1) to stdout") as copy: + async for record in copy: + pass + + with pytest.raises(TypeError): + async with copy: + async for record in copy: + pass + + async def ensure_table(cur, tabledef, name="copy_in"): await cur.execute(f"drop table if exists {name}") await cur.execute(f"create table {name} ({tabledef})") -- 2.47.3