]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Cursor.copy() made into a context manager
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 15 Nov 2020 01:47:51 +0000 (01:47 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 15 Nov 2020 02:33:34 +0000 (02:33 +0000)
What was before was a factory function, however that forced to have a
pattern like:

    async with (await cursor.copy()) as copy

Now instead what should be used is:

    async with cursor.copy() as copy

With this change the user pretty much is never exposed anymore to a Copy
object in a non-entered state. This is actually useful because it
reduces the surface of the API: now for instance Copy.finis() can become
a private method.

docs/cursor.rst
docs/usage.rst
psycopg3/psycopg3/cursor.py
tests/test_copy.py
tests/test_copy_async.py
tests/test_sql.py

index 6debec8a776e86ff9a6a32e5566b983ba303a463..8f0c651e7cf0f4973b92a34ddb11313b802de977 100644 (file)
@@ -145,12 +145,6 @@ Cursor support objects
         The data in the tuple will be converted as configured on the cursor;
         see :ref:`adaptation` for details.
 
-    .. automethod:: finish
-
-        If an *error* is specified, the :sql:`COPY` operation is cancelled.
-
-        The method is called automatically at the end of a `!with` block.
-
 
 .. autoclass:: AsyncCopy
 
@@ -161,4 +155,3 @@ Cursor support objects
     .. automethod:: read
     .. automethod:: write
     .. automethod:: write_row
-    .. automethod:: finish
index 83934d3cbd1556ce1804abf17f052e1d90e04b65..d4872a5c69ab95105c9cbf1b8c2c69a9d161a267 100644 (file)
@@ -202,8 +202,9 @@ produce `!bytes`:
 .. code:: python
 
     with open("data.out", "wb") as f:
-        for data in cursor.copy("COPY table_name TO STDOUT") as copy:
-            f.write(data)
+        with cursor.copy("COPY table_name TO STDOUT") as copy:
+            for data in copy:
+                f.write(data)
 
 Asynchronous operations are supported using the same patterns on an
 `AsyncConnection`. For instance, if `!f` is an object supporting an
@@ -212,7 +213,7 @@ copy operation could be:
 
 .. code:: python
 
-    async with (await cursor.copy("COPY data FROM STDIN")) as copy:
+    async with cursor.copy("COPY data FROM STDIN") as copy:
         data = await f.read()
         if not data:
             break
index bf865e4104d2c8b2bc76811a4a12d21136b5ee03..4cef7efaed341d89951f9c72fca9d3997f7484ed 100644 (file)
@@ -5,17 +5,10 @@ psycopg3 cursor objects
 # Copyright (C) 2020 The Psycopg Team
 
 from types import TracebackType
-from typing import (
-    Any,
-    AsyncIterator,
-    Callable,
-    Generic,
-    Iterator,
-    List,
-    Mapping,
-)
-from typing import Optional, Sequence, Type, TYPE_CHECKING, Union
+from typing import Any, AsyncIterator, Callable, Generic, Iterator, List
+from typing import Mapping, Optional, Sequence, Type, TYPE_CHECKING, Union
 from operator import attrgetter
+from contextlib import asynccontextmanager, contextmanager
 
 from . import errors as e
 from . import pq
@@ -556,10 +549,19 @@ class Cursor(BaseCursor["Connection"]):
             self._pos += 1
             yield row
 
-    def copy(self, statement: Query, vars: Optional[Params] = None) -> Copy:
+    @contextmanager
+    def copy(
+        self, statement: Query, vars: Optional[Params] = None
+    ) -> Iterator[Copy]:
         """
         Initiate a :sql:`COPY` operation and return a `Copy` object to manage it.
         """
+        with self._start_copy(statement, vars) as copy:
+            yield copy
+
+    def _start_copy(
+        self, statement: Query, vars: Optional[Params] = None
+    ) -> Copy:
         with self.connection.lock:
             self._start_query()
             self.connection._start_query()
@@ -682,12 +684,20 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
             self._pos += 1
             yield row
 
+    @asynccontextmanager
     async def copy(
         self, statement: Query, vars: Optional[Params] = None
-    ) -> AsyncCopy:
+    ) -> AsyncIterator[AsyncCopy]:
         """
         Initiate a :sql:`COPY` operation and return an `AsyncCopy` object.
         """
+        copy = await self._start_copy(statement, vars)
+        async with copy:
+            yield copy
+
+    async def _start_copy(
+        self, statement: Query, vars: Optional[Params] = None
+    ) -> AsyncCopy:
         async with self.connection.lock:
             self._start_query()
             await self.connection._start_query()
index 6dfc30a81376d888d5df7e7b1213b9a21c815119..1d342b84fd344525a87de80df6a64d5267722eda 100644 (file)
@@ -35,31 +35,37 @@ sample_binary = b"".join(sample_binary_rows)
 
 
 @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_copy_out_iter(conn, format):
-    cur = conn.cursor()
-    copy = cur.copy(f"copy ({sample_values}) to stdout (format {format.name})")
+def test_copy_out_read(conn, format):
     if format == pq.Format.TEXT:
         want = [row + b"\n" for row in sample_text.splitlines()]
     else:
         want = sample_binary_rows
-    assert list(copy) == want
-
 
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_copy_out_context(conn, format):
     cur = conn.cursor()
-    out = []
     with cur.copy(
         f"copy ({sample_values}) to stdout (format {format.name})"
     ) as copy:
-        for row in copy:
-            out.append(row)
+        for row in want:
+            got = copy.read()
+            assert got == row
+
+        assert copy.read() is None
+        assert copy.read() is None
 
+    assert copy.read() is None
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_copy_out_iter(conn, format):
     if format == pq.Format.TEXT:
         want = [row + b"\n" for row in sample_text.splitlines()]
     else:
         want = sample_binary_rows
-    assert out == want
+    cur = conn.cursor()
+    with cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    ) as copy:
+        assert list(copy) == want
 
 
 @pytest.mark.parametrize(
@@ -69,9 +75,9 @@ def test_copy_out_context(conn, format):
 def test_copy_in_buffers(conn, format, buffer):
     cur = conn.cursor()
     ensure_table(cur, sample_tabledef)
-    copy = cur.copy(f"copy copy_in from stdin (format {format.name})")
-    copy.write(globals()[buffer])
-    copy.finish()
+    with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+        copy.write(globals()[buffer])
+
     data = cur.execute("select * from copy_in order by 1").fetchall()
     assert data == sample_records
 
@@ -79,11 +85,10 @@ def test_copy_in_buffers(conn, format, buffer):
 def test_copy_in_buffers_pg_error(conn):
     cur = conn.cursor()
     ensure_table(cur, sample_tabledef)
-    copy = cur.copy("copy copy_in from stdin (format text)")
-    copy.write(sample_text)
-    copy.write(sample_text)
     with pytest.raises(e.UniqueViolation):
-        copy.finish()
+        with cur.copy("copy copy_in from stdin (format text)") as copy:
+            copy.write(sample_text)
+            copy.write(sample_text)
     assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR
 
 
@@ -93,13 +98,16 @@ def test_copy_bad_result(conn):
     cur = conn.cursor()
 
     with pytest.raises(e.SyntaxError):
-        cur.copy("wat")
+        with cur.copy("wat"):
+            pass
 
     with pytest.raises(e.ProgrammingError):
-        cur.copy("select 1")
+        with cur.copy("select 1"):
+            pass
 
     with pytest.raises(e.ProgrammingError):
-        cur.copy("reset timezone")
+        with cur.copy("reset timezone"):
+            pass
 
 
 @pytest.mark.parametrize(
index 9ac4f9dee771b158b42f74d179b0ae5a56d1fc3d..8cc0e05aed037a4b6b92a79d1b9286ce43b040d6 100644 (file)
@@ -12,57 +12,40 @@ pytestmark = pytest.mark.asyncio
 
 @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
 async def test_copy_out_read(aconn, format):
-    cur = await aconn.cursor()
-    copy = await cur.copy(
-        f"copy ({sample_values}) to stdout (format {format.name})"
-    )
-
     if format == pq.Format.TEXT:
         want = [row + b"\n" for row in sample_text.splitlines()]
     else:
         want = sample_binary_rows
 
-    for row in want:
-        got = await copy.read()
-        assert got == row
+    cur = await aconn.cursor()
+    async with cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    ) as copy:
+        for row in want:
+            got = await copy.read()
+            assert got == row
+
+        assert await copy.read() is None
+        assert await copy.read() is None
 
     assert await copy.read() is None
-    assert await copy.read() is None
 
 
 @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
 async def test_copy_out_iter(aconn, format):
-    cur = await aconn.cursor()
-    copy = await cur.copy(
-        f"copy ({sample_values}) to stdout (format {format.name})"
-    )
     if format == pq.Format.TEXT:
         want = [row + b"\n" for row in sample_text.splitlines()]
     else:
         want = sample_binary_rows
-    got = []
-    async for row in copy:
-        got.append(row)
-    assert got == want
-
 
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-async def test_copy_out_context(aconn, format):
     cur = await aconn.cursor()
-    out = []
-    async with (
-        await cur.copy(
-            f"copy ({sample_values}) to stdout (format {format.name})"
-        )
+    got = []
+    async with cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
     ) as copy:
         async for row in copy:
-            out.append(row)
-
-    if format == pq.Format.TEXT:
-        want = [row + b"\n" for row in sample_text.splitlines()]
-    else:
-        want = sample_binary_rows
-    assert out == want
+            got.append(row)
+    assert got == want
 
 
 @pytest.mark.parametrize(
@@ -72,9 +55,11 @@ async def test_copy_out_context(aconn, format):
 async def test_copy_in_buffers(aconn, format, buffer):
     cur = await aconn.cursor()
     await ensure_table(cur, sample_tabledef)
-    copy = await cur.copy(f"copy copy_in from stdin (format {format.name})")
-    await copy.write(globals()[buffer])
-    await copy.finish()
+    async with cur.copy(
+        f"copy copy_in from stdin (format {format.name})"
+    ) as copy:
+        await copy.write(globals()[buffer])
+
     await cur.execute("select * from copy_in order by 1")
     data = await cur.fetchall()
     assert data == sample_records
@@ -83,11 +68,10 @@ async def test_copy_in_buffers(aconn, format, buffer):
 async def test_copy_in_buffers_pg_error(aconn):
     cur = await aconn.cursor()
     await ensure_table(cur, sample_tabledef)
-    copy = await cur.copy("copy copy_in from stdin (format text)")
-    await copy.write(sample_text)
-    await copy.write(sample_text)
     with pytest.raises(e.UniqueViolation):
-        await copy.finish()
+        async with cur.copy("copy copy_in from stdin (format text)") as copy:
+            await copy.write(sample_text)
+            await copy.write(sample_text)
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
 
 
@@ -97,38 +81,22 @@ async def test_copy_bad_result(aconn):
     cur = await aconn.cursor()
 
     with pytest.raises(e.SyntaxError):
-        await cur.copy("wat")
+        async with cur.copy("wat"):
+            pass
 
     with pytest.raises(e.ProgrammingError):
-        await cur.copy("select 1")
+        async with cur.copy("select 1"):
+            pass
 
     with pytest.raises(e.ProgrammingError):
-        await cur.copy("reset timezone")
-
-
-@pytest.mark.parametrize(
-    "format, buffer",
-    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
-)
-async def test_copy_in_buffers_with(aconn, format, buffer):
-    cur = await aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
-    async with (
-        await cur.copy(f"copy copy_in from stdin (format {format.name})")
-    ) as copy:
-        await copy.write(globals()[buffer])
-
-    await cur.execute("select * from copy_in order by 1")
-    data = await cur.fetchall()
-    assert data == sample_records
+        async with cur.copy("reset timezone"):
+            pass
 
 
 async def test_copy_in_str(aconn):
     cur = await aconn.cursor()
     await ensure_table(cur, sample_tabledef)
-    async with (
-        await cur.copy("copy copy_in from stdin (format text)")
-    ) as copy:
+    async with cur.copy("copy copy_in from stdin (format text)") as copy:
         await copy.write(sample_text.decode("utf8"))
 
     await cur.execute("select * from copy_in order by 1")
@@ -140,9 +108,7 @@ async def test_copy_in_str_binary(aconn):
     cur = await aconn.cursor()
     await ensure_table(cur, sample_tabledef)
     with pytest.raises(e.QueryCanceled):
-        async with (
-            await cur.copy("copy copy_in from stdin (format binary)")
-        ) as copy:
+        async with cur.copy("copy copy_in from stdin (format binary)") as copy:
             await copy.write(sample_text.decode("utf8"))
 
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
@@ -152,9 +118,7 @@ async def test_copy_in_buffers_with_pg_error(aconn):
     cur = await aconn.cursor()
     await ensure_table(cur, sample_tabledef)
     with pytest.raises(e.UniqueViolation):
-        async with (
-            await cur.copy("copy copy_in from stdin (format text)")
-        ) as copy:
+        async with cur.copy("copy copy_in from stdin (format text)") as copy:
             await copy.write(sample_text)
             await copy.write(sample_text)
 
@@ -165,9 +129,7 @@ async def test_copy_in_buffers_with_py_error(aconn):
     cur = await aconn.cursor()
     await ensure_table(cur, sample_tabledef)
     with pytest.raises(e.QueryCanceled) as exc:
-        async with (
-            await cur.copy("copy copy_in from stdin (format text)")
-        ) as copy:
+        async with cur.copy("copy copy_in from stdin (format text)") as copy:
             await copy.write(sample_text)
             raise Exception("nuttengoggenio")
 
@@ -183,8 +145,8 @@ async def test_copy_in_records(aconn, format):
     cur = await aconn.cursor()
     await ensure_table(cur, sample_tabledef)
 
-    async with (
-        await cur.copy(f"copy copy_in from stdin (format {format.name})")
+    async with cur.copy(
+        f"copy copy_in from stdin (format {format.name})"
     ) as copy:
         for row in sample_records:
             await copy.write_row(row)
@@ -202,10 +164,8 @@ async def test_copy_in_records_binary(aconn, format):
     cur = await aconn.cursor()
     await ensure_table(cur, "col1 serial primary key, col2 int, data text")
 
-    async with (
-        await cur.copy(
-            f"copy copy_in (col2, data) from stdin (format {format.name})"
-        )
+    async with cur.copy(
+        f"copy copy_in (col2, data) from stdin (format {format.name})"
     ) as copy:
         for row in sample_records:
             await copy.write_row((None, row[2]))
@@ -220,9 +180,7 @@ async def test_copy_in_allchars(aconn):
     await ensure_table(cur, sample_tabledef)
 
     await aconn.set_client_encoding("utf8")
-    async with (
-        await cur.copy("copy copy_in from stdin (format text)")
-    ) as copy:
+    async with cur.copy("copy copy_in from stdin (format text)") as copy:
         for i in range(1, 256):
             await copy.write_row((i, None, chr(i)))
         await copy.write_row((ord(eur), None, eur))
index 62ff15d0bcaff594a16385e13cbc81879961bad5..2d093c9a5745c6d20a907c4a2789c102695e1647 100755 (executable)
@@ -182,12 +182,12 @@ class TestSqlFormat:
             copy.write_row((10, "a", "b", "c"))
             copy.write_row((20, "d", "e", "f"))
 
-        copy = cur.copy(
+        with cur.copy(
             sql.SQL("copy (select {f} from {t} order by id) to stdout").format(
                 t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z")
             )
-        )
-        assert list(copy) == [b"c\n", b"f\n"]
+        ) as copy:
+            assert list(copy) == [b"c\n", b"f\n"]
 
 
 class TestIdentifier: