# Copyright (C) 2020 The Psycopg Team
from types import TracebackType
-from typing import (
- Any,
- AsyncIterator,
- Callable,
- Generic,
- Iterator,
- List,
- Mapping,
-)
-from typing import Optional, Sequence, Type, TYPE_CHECKING, Union
+from typing import Any, AsyncIterator, Callable, Generic, Iterator, List
+from typing import Mapping, Optional, Sequence, Type, TYPE_CHECKING, Union
from operator import attrgetter
+from contextlib import asynccontextmanager, contextmanager
from . import errors as e
from . import pq
self._pos += 1
yield row
- def copy(self, statement: Query, vars: Optional[Params] = None) -> Copy:
+ @contextmanager
+ def copy(
+ self, statement: Query, vars: Optional[Params] = None
+ ) -> Iterator[Copy]:
"""
Initiate a :sql:`COPY` operation and return a `Copy` object to manage it.
"""
+ with self._start_copy(statement, vars) as copy:
+ yield copy
+
+ def _start_copy(
+ self, statement: Query, vars: Optional[Params] = None
+ ) -> Copy:
with self.connection.lock:
self._start_query()
self.connection._start_query()
self._pos += 1
yield row
+ @asynccontextmanager
async def copy(
self, statement: Query, vars: Optional[Params] = None
- ) -> AsyncCopy:
+ ) -> AsyncIterator[AsyncCopy]:
"""
Initiate a :sql:`COPY` operation and return an `AsyncCopy` object.
"""
+ copy = await self._start_copy(statement, vars)
+ async with copy:
+ yield copy
+
+ async def _start_copy(
+ self, statement: Query, vars: Optional[Params] = None
+ ) -> AsyncCopy:
async with self.connection.lock:
self._start_query()
await self.connection._start_query()
@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_copy_out_iter(conn, format):
- cur = conn.cursor()
- copy = cur.copy(f"copy ({sample_values}) to stdout (format {format.name})")
+def test_copy_out_read(conn, format):
if format == pq.Format.TEXT:
want = [row + b"\n" for row in sample_text.splitlines()]
else:
want = sample_binary_rows
- assert list(copy) == want
-
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_copy_out_context(conn, format):
cur = conn.cursor()
- out = []
with cur.copy(
f"copy ({sample_values}) to stdout (format {format.name})"
) as copy:
- for row in copy:
- out.append(row)
+ for row in want:
+ got = copy.read()
+ assert got == row
+
+ assert copy.read() is None
+ assert copy.read() is None
+ assert copy.read() is None
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_copy_out_iter(conn, format):
if format == pq.Format.TEXT:
want = [row + b"\n" for row in sample_text.splitlines()]
else:
want = sample_binary_rows
- assert out == want
+ cur = conn.cursor()
+ with cur.copy(
+ f"copy ({sample_values}) to stdout (format {format.name})"
+ ) as copy:
+ assert list(copy) == want
@pytest.mark.parametrize(
def test_copy_in_buffers(conn, format, buffer):
cur = conn.cursor()
ensure_table(cur, sample_tabledef)
- copy = cur.copy(f"copy copy_in from stdin (format {format.name})")
- copy.write(globals()[buffer])
- copy.finish()
+ with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+ copy.write(globals()[buffer])
+
data = cur.execute("select * from copy_in order by 1").fetchall()
assert data == sample_records
def test_copy_in_buffers_pg_error(conn):
cur = conn.cursor()
ensure_table(cur, sample_tabledef)
- copy = cur.copy("copy copy_in from stdin (format text)")
- copy.write(sample_text)
- copy.write(sample_text)
with pytest.raises(e.UniqueViolation):
- copy.finish()
+ with cur.copy("copy copy_in from stdin (format text)") as copy:
+ copy.write(sample_text)
+ copy.write(sample_text)
assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR
cur = conn.cursor()
with pytest.raises(e.SyntaxError):
- cur.copy("wat")
+ with cur.copy("wat"):
+ pass
with pytest.raises(e.ProgrammingError):
- cur.copy("select 1")
+ with cur.copy("select 1"):
+ pass
with pytest.raises(e.ProgrammingError):
- cur.copy("reset timezone")
+ with cur.copy("reset timezone"):
+ pass
@pytest.mark.parametrize(
@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
async def test_copy_out_read(aconn, format):
- cur = await aconn.cursor()
- copy = await cur.copy(
- f"copy ({sample_values}) to stdout (format {format.name})"
- )
-
if format == pq.Format.TEXT:
want = [row + b"\n" for row in sample_text.splitlines()]
else:
want = sample_binary_rows
- for row in want:
- got = await copy.read()
- assert got == row
+ cur = await aconn.cursor()
+ async with cur.copy(
+ f"copy ({sample_values}) to stdout (format {format.name})"
+ ) as copy:
+ for row in want:
+ got = await copy.read()
+ assert got == row
+
+ assert await copy.read() is None
+ assert await copy.read() is None
assert await copy.read() is None
- assert await copy.read() is None
@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
async def test_copy_out_iter(aconn, format):
- cur = await aconn.cursor()
- copy = await cur.copy(
- f"copy ({sample_values}) to stdout (format {format.name})"
- )
if format == pq.Format.TEXT:
want = [row + b"\n" for row in sample_text.splitlines()]
else:
want = sample_binary_rows
- got = []
- async for row in copy:
- got.append(row)
- assert got == want
-
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-async def test_copy_out_context(aconn, format):
cur = await aconn.cursor()
- out = []
- async with (
- await cur.copy(
- f"copy ({sample_values}) to stdout (format {format.name})"
- )
+ got = []
+ async with cur.copy(
+ f"copy ({sample_values}) to stdout (format {format.name})"
) as copy:
async for row in copy:
- out.append(row)
-
- if format == pq.Format.TEXT:
- want = [row + b"\n" for row in sample_text.splitlines()]
- else:
- want = sample_binary_rows
- assert out == want
+ got.append(row)
+ assert got == want
@pytest.mark.parametrize(
async def test_copy_in_buffers(aconn, format, buffer):
cur = await aconn.cursor()
await ensure_table(cur, sample_tabledef)
- copy = await cur.copy(f"copy copy_in from stdin (format {format.name})")
- await copy.write(globals()[buffer])
- await copy.finish()
+ async with cur.copy(
+ f"copy copy_in from stdin (format {format.name})"
+ ) as copy:
+ await copy.write(globals()[buffer])
+
await cur.execute("select * from copy_in order by 1")
data = await cur.fetchall()
assert data == sample_records
async def test_copy_in_buffers_pg_error(aconn):
cur = await aconn.cursor()
await ensure_table(cur, sample_tabledef)
- copy = await cur.copy("copy copy_in from stdin (format text)")
- await copy.write(sample_text)
- await copy.write(sample_text)
with pytest.raises(e.UniqueViolation):
- await copy.finish()
+ async with cur.copy("copy copy_in from stdin (format text)") as copy:
+ await copy.write(sample_text)
+ await copy.write(sample_text)
assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
cur = await aconn.cursor()
with pytest.raises(e.SyntaxError):
- await cur.copy("wat")
+ async with cur.copy("wat"):
+ pass
with pytest.raises(e.ProgrammingError):
- await cur.copy("select 1")
+ async with cur.copy("select 1"):
+ pass
with pytest.raises(e.ProgrammingError):
- await cur.copy("reset timezone")
-
-
-@pytest.mark.parametrize(
- "format, buffer",
- [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
-)
-async def test_copy_in_buffers_with(aconn, format, buffer):
- cur = await aconn.cursor()
- await ensure_table(cur, sample_tabledef)
- async with (
- await cur.copy(f"copy copy_in from stdin (format {format.name})")
- ) as copy:
- await copy.write(globals()[buffer])
-
- await cur.execute("select * from copy_in order by 1")
- data = await cur.fetchall()
- assert data == sample_records
+ async with cur.copy("reset timezone"):
+ pass
async def test_copy_in_str(aconn):
cur = await aconn.cursor()
await ensure_table(cur, sample_tabledef)
- async with (
- await cur.copy("copy copy_in from stdin (format text)")
- ) as copy:
+ async with cur.copy("copy copy_in from stdin (format text)") as copy:
await copy.write(sample_text.decode("utf8"))
await cur.execute("select * from copy_in order by 1")
cur = await aconn.cursor()
await ensure_table(cur, sample_tabledef)
with pytest.raises(e.QueryCanceled):
- async with (
- await cur.copy("copy copy_in from stdin (format binary)")
- ) as copy:
+ async with cur.copy("copy copy_in from stdin (format binary)") as copy:
await copy.write(sample_text.decode("utf8"))
assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
cur = await aconn.cursor()
await ensure_table(cur, sample_tabledef)
with pytest.raises(e.UniqueViolation):
- async with (
- await cur.copy("copy copy_in from stdin (format text)")
- ) as copy:
+ async with cur.copy("copy copy_in from stdin (format text)") as copy:
await copy.write(sample_text)
await copy.write(sample_text)
cur = await aconn.cursor()
await ensure_table(cur, sample_tabledef)
with pytest.raises(e.QueryCanceled) as exc:
- async with (
- await cur.copy("copy copy_in from stdin (format text)")
- ) as copy:
+ async with cur.copy("copy copy_in from stdin (format text)") as copy:
await copy.write(sample_text)
raise Exception("nuttengoggenio")
cur = await aconn.cursor()
await ensure_table(cur, sample_tabledef)
- async with (
- await cur.copy(f"copy copy_in from stdin (format {format.name})")
+ async with cur.copy(
+ f"copy copy_in from stdin (format {format.name})"
) as copy:
for row in sample_records:
await copy.write_row(row)
cur = await aconn.cursor()
await ensure_table(cur, "col1 serial primary key, col2 int, data text")
- async with (
- await cur.copy(
- f"copy copy_in (col2, data) from stdin (format {format.name})"
- )
+ async with cur.copy(
+ f"copy copy_in (col2, data) from stdin (format {format.name})"
) as copy:
for row in sample_records:
await copy.write_row((None, row[2]))
await ensure_table(cur, sample_tabledef)
await aconn.set_client_encoding("utf8")
- async with (
- await cur.copy("copy copy_in from stdin (format text)")
- ) as copy:
+ async with cur.copy("copy copy_in from stdin (format text)") as copy:
for i in range(1, 256):
await copy.write_row((i, None, chr(i)))
await copy.write_row((ord(eur), None, eur))