From: Daniele Varrazzo Date: Tue, 24 Nov 2020 12:46:41 +0000 (+0000) Subject: Raise TypeError attempting to use a Copy context more than once X-Git-Tag: 3.0.dev0~310 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ddc3bb2d1f658792df1f18e2c9c31aa378122837;p=thirdparty%2Fpsycopg.git Raise TypeError attempting to use a Copy context more than once See #10 --- diff --git a/psycopg3/psycopg3/copy.py b/psycopg3/psycopg3/copy.py index 931bddb0a..2fcbe7ca4 100644 --- a/psycopg3/psycopg3/copy.py +++ b/psycopg3/psycopg3/copy.py @@ -172,7 +172,8 @@ class Copy(BaseCopy["Connection"]): self._finished = True def __enter__(self) -> "Copy": - assert not self._finished + if self._finished: + raise TypeError("copy blocks can be used only once") return self def __exit__( @@ -240,7 +241,8 @@ class AsyncCopy(BaseCopy["AsyncConnection"]): self._finished = True async def __aenter__(self) -> "AsyncCopy": - assert not self._finished + if self._finished: + raise TypeError("copy blocks can be used only once") return self async def __aexit__( diff --git a/psycopg3/psycopg3/transaction.py b/psycopg3/psycopg3/transaction.py index 14c6adaba..9e279688d 100644 --- a/psycopg3/psycopg3/transaction.py +++ b/psycopg3/psycopg3/transaction.py @@ -76,7 +76,7 @@ class BaseTransaction(Generic[ConnectionType]): def _enter_commands(self) -> List[str]: if not self._yolo: - raise TypeError("transaction blocks cannot be use more than once") + raise TypeError("transaction blocks can be used only once") else: self._yolo = False diff --git a/tests/test_copy.py b/tests/test_copy.py index 8beec2f10..2ad60d243 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -299,6 +299,16 @@ def test_copy_query(conn): list(copy) +def test_cant_reenter(conn): + cur = conn.cursor() + with cur.copy("copy (select 1) to stdout") as copy: + list(copy) + + with pytest.raises(TypeError): + with copy: + list(copy) + + def ensure_table(cur, tabledef, name="copy_in"): cur.execute(f"drop table if exists {name}") cur.execute(f"create table {name} ({tabledef})") diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index 3bbd8f25f..ac25b1d20 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -288,6 +288,18 @@ async def test_copy_query(aconn): pass +async def test_cant_reenter(aconn): + cur = await aconn.cursor() + async with cur.copy("copy (select 1) to stdout") as copy: + async for record in copy: + pass + + with pytest.raises(TypeError): + async with copy: + async for record in copy: + pass + + async def ensure_table(cur, tabledef, name="copy_in"): await cur.execute(f"drop table if exists {name}") await cur.execute(f"create table {name} ({tabledef})")