--- /dev/null
+import struct
+
+from psycopg.pq import Format
+from psycopg.copy import AsyncWriter
+from psycopg.copy import FileWriter as FileWriter # noqa: F401
+
+sample_records = [(40010, 40020, "hello"), (40040, None, "world")]
+sample_values = "values (40010::int, 40020::int, 'hello'::text), (40040, NULL, 'world')"
+sample_tabledef = "col1 serial primary key, col2 int, data text"
+
+sample_text = b"""\
+40010\t40020\thello
+40040\t\\N\tworld
+"""
+
+sample_binary_str = """
+5047 434f 5059 0aff 0d0a 00
+00 0000 0000 0000 00
+00 0300 0000 0400 009c 4a00 0000 0400 009c 5400 0000 0568 656c 6c6f
+
+0003 0000 0004 0000 9c68 ffff ffff 0000 0005 776f 726c 64
+
+ff ff
+"""
+
+sample_binary_rows = [
+ bytes.fromhex("".join(row.split())) for row in sample_binary_str.split("\n\n")
+]
+sample_binary = b"".join(sample_binary_rows)
+
+special_chars = {8: "b", 9: "t", 10: "n", 11: "v", 12: "f", 13: "r", ord("\\"): "\\"}
+
+
+def ensure_table(cur, tabledef, name="copy_in"):
+ cur.execute(f"drop table if exists {name}")
+ cur.execute(f"create table {name} ({tabledef})")
+
+
+async def ensure_table_async(cur, tabledef, name="copy_in"):
+ await cur.execute(f"drop table if exists {name}")
+ await cur.execute(f"create table {name} ({tabledef})")
+
+
+def py_to_raw(item, fmt):
+ """Convert from Python type to the expected result from the db"""
+ if fmt == Format.TEXT:
+ if isinstance(item, int):
+ return str(item)
+ else:
+ if isinstance(item, int):
+ # Assume int4
+ return struct.pack("!i", item)
+ elif isinstance(item, str):
+ return item.encode()
+ return item
+
+
+class AsyncFileWriter(AsyncWriter):
+ def __init__(self, file):
+ self.file = file
+
+ async def write(self, data):
+ self.file.write(data)
from psycopg.types.numeric import Int4
from ..utils import eur, gc_collect, gc_count
-from ..test_copy import sample_text, sample_binary # noqa
-from ..test_copy import ensure_table, sample_records
-from ..test_copy import sample_tabledef as sample_tabledef_pg
+from .._test_copy import sample_text, sample_binary # noqa
+from .._test_copy import ensure_table, sample_records
+from .._test_copy import sample_tabledef as sample_tabledef_pg
# CRDB int/serial are int8
sample_tabledef = sample_tabledef_pg.replace("int", "int4").replace("serial", "int4")
from psycopg.types.numeric import Int4
from ..utils import eur, gc_collect, gc_count
-from ..test_copy import sample_text, sample_binary # noqa
-from ..test_copy import sample_records
-from ..test_copy_async import ensure_table
+from .._test_copy import sample_text, sample_binary # noqa
+from .._test_copy import ensure_table_async, sample_records
from .test_copy import sample_tabledef, copyopt
pytestmark = [pytest.mark.crdb, pytest.mark.anyio]
)
async def test_copy_in_buffers(aconn, format, buffer):
cur = aconn.cursor()
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
await copy.write(globals()[buffer])
async def test_copy_in_buffers_pg_error(aconn):
cur = aconn.cursor()
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
with pytest.raises(e.UniqueViolation):
async with cur.copy("copy copy_in from stdin") as copy:
await copy.write(sample_text)
async def test_copy_in_str(aconn):
cur = aconn.cursor()
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
async with cur.copy("copy copy_in from stdin") as copy:
await copy.write(sample_text.decode())
@pytest.mark.xfail(reason="bad sqlstate - CRDB #81559")
async def test_copy_in_error(aconn):
cur = aconn.cursor()
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
with pytest.raises(e.QueryCanceled):
async with cur.copy("copy copy_in from stdin with binary") as copy:
await copy.write(sample_text.decode())
@pytest.mark.parametrize("format", Format)
async def test_copy_in_empty(aconn, format):
cur = aconn.cursor()
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
async with cur.copy(f"copy copy_in from stdin {copyopt(format)}"):
pass
@pytest.mark.slow
async def test_copy_big_size_record(aconn):
cur = aconn.cursor()
- await ensure_table(cur, "id serial primary key, data text")
+ await ensure_table_async(cur, "id serial primary key, data text")
data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024))
async with cur.copy("copy copy_in (data) from stdin") as copy:
await copy.write_row([data])
@pytest.mark.slow
async def test_copy_big_size_block(aconn):
cur = aconn.cursor()
- await ensure_table(cur, "id serial primary key, data text")
+ await ensure_table_async(cur, "id serial primary key, data text")
data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024))
copy_data = data + "\n"
async with cur.copy("copy copy_in (data) from stdin") as copy:
async def test_copy_in_buffers_with_pg_error(aconn):
cur = aconn.cursor()
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
with pytest.raises(e.UniqueViolation):
async with cur.copy("copy copy_in from stdin") as copy:
await copy.write(sample_text)
@pytest.mark.parametrize("format", Format)
async def test_copy_in_records(aconn, format):
cur = aconn.cursor()
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
for row in sample_records:
@pytest.mark.parametrize("format", Format)
async def test_copy_in_records_set_types(aconn, format):
cur = aconn.cursor()
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
copy.set_types(["int4", "int4", "text"])
@pytest.mark.parametrize("format", Format)
async def test_copy_in_records_binary(aconn, format):
cur = aconn.cursor()
- await ensure_table(cur, "col1 serial primary key, col2 int4, data text")
+ await ensure_table_async(cur, "col1 serial primary key, col2 int4, data text")
async with cur.copy(
f"copy copy_in (col2, data) from stdin {copyopt(format)}"
@pytest.mark.crdb_skip("copy canceled")
async def test_copy_in_buffers_with_py_error(aconn):
cur = aconn.cursor()
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
with pytest.raises(e.QueryCanceled) as exc:
async with cur.copy("copy copy_in from stdin") as copy:
await copy.write(sample_text)
async def test_copy_in_allchars(aconn):
cur = aconn.cursor()
- await ensure_table(cur, "col1 int primary key, col2 int, data text")
+ await ensure_table_async(cur, "col1 int primary key, col2 int, data text")
async with cur.copy("copy copy_in from stdin") as copy:
for i in range(1, 256):
import string
-import struct
import hashlib
from io import BytesIO, StringIO
from random import choice, randrange
from psycopg import sql
from psycopg import errors as e
from psycopg.pq import Format
-from psycopg.copy import Copy, LibpqWriter, QueuedLibpqWriter, FileWriter
+from psycopg.copy import Copy, LibpqWriter, QueuedLibpqWriter
from psycopg.adapt import PyFormat
from psycopg.types import TypeInfo
from psycopg.types.hstore import register_hstore
from psycopg.types.numeric import Int4
from .utils import eur, gc_collect, gc_count
+from ._test_copy import sample_text, sample_binary, sample_binary_rows # noqa
+from ._test_copy import sample_values, sample_records, sample_tabledef
+from ._test_copy import ensure_table, py_to_raw, special_chars, FileWriter
pytestmark = pytest.mark.crdb_skip("copy")
-sample_records = [(40010, 40020, "hello"), (40040, None, "world")]
-sample_values = "values (40010::int, 40020::int, 'hello'::text), (40040, NULL, 'world')"
-sample_tabledef = "col1 serial primary key, col2 int, data text"
-
-sample_text = b"""\
-40010\t40020\thello
-40040\t\\N\tworld
-"""
-
-sample_binary_str = """
-5047 434f 5059 0aff 0d0a 00
-00 0000 0000 0000 00
-00 0300 0000 0400 009c 4a00 0000 0400 009c 5400 0000 0568 656c 6c6f
-
-0003 0000 0004 0000 9c68 ffff ffff 0000 0005 776f 726c 64
-
-ff ff
-"""
-
-sample_binary_rows = [
- bytes.fromhex("".join(row.split())) for row in sample_binary_str.split("\n\n")
-]
-sample_binary = b"".join(sample_binary_rows)
-
-special_chars = {8: "b", 9: "t", 10: "n", 11: "v", 12: "f", 13: "r", ord("\\"): "\\"}
-
@pytest.mark.parametrize("format", Format)
def test_copy_out_read(conn, format):
if format == Format.TEXT:
from psycopg.types.string import StrDumper as BaseDumper
else:
- from psycopg.types.string import ( # type: ignore[assignment]
- StrBinaryDumper as BaseDumper,
- )
+ from psycopg.types.string import StrBinaryDumper
+
+ BaseDumper = StrBinaryDumper # type: ignore
class MyStrDumper(BaseDumper):
def dump(self, obj):
with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
for row in sample_records:
if format == Format.BINARY:
- row = tuple(
- Int4(i) if isinstance(i, int) else i for i in row
- ) # type: ignore[assignment]
+ row2 = tuple(Int4(i) if isinstance(i, int) else i for i in row)
+ row = row2 # type: ignore[assignment]
copy.write_row(row)
data = cur.execute("select * from copy_in order by 1").fetchall()
faker.assert_record(got, want)
-def py_to_raw(item, fmt):
- """Convert from Python type to the expected result from the db"""
- if fmt == Format.TEXT:
- if isinstance(item, int):
- return str(item)
- else:
- if isinstance(item, int):
- # Assume int4
- return struct.pack("!i", item)
- elif isinstance(item, str):
- return item.encode()
- return item
-
-
-def ensure_table(cur, tabledef, name="copy_in"):
- cur.execute(f"drop table if exists {name}")
- cur.execute(f"create table {name} ({tabledef})")
-
-
class DataGenerator:
def __init__(self, conn, nrecs, srec, offset=0, block_size=8192):
self.conn = conn
from psycopg import sql
from psycopg import errors as e
from psycopg.pq import Format
-from psycopg.copy import AsyncCopy
-from psycopg.copy import AsyncWriter, AsyncLibpqWriter, AsyncQueuedLibpqWriter
-from psycopg.types import TypeInfo
+from psycopg.copy import AsyncCopy, AsyncLibpqWriter, AsyncQueuedLibpqWriter
from psycopg.adapt import PyFormat
+from psycopg.types import TypeInfo
from psycopg.types.hstore import register_hstore
from psycopg.types.numeric import Int4
from .utils import alist, eur, gc_collect, gc_count
-from .test_copy import sample_text, sample_binary, sample_binary_rows # noqa
-from .test_copy import sample_values, sample_records, sample_tabledef
-from .test_copy import py_to_raw, special_chars
+from ._test_copy import sample_text, sample_binary, sample_binary_rows # noqa
+from ._test_copy import sample_values, sample_records, sample_tabledef
+from ._test_copy import ensure_table_async, py_to_raw, special_chars, AsyncFileWriter
-pytestmark = [
- pytest.mark.crdb_skip("copy"),
-]
+pytestmark = pytest.mark.crdb_skip("copy")
@pytest.mark.parametrize("format", Format)
async with cur.copy(
f"copy ({sample_values}) to stdout (format {format.name})"
) as copy:
- copy.set_types("int4 int4 text".split())
+ copy.set_types(["int4", "int4", "text"])
rows = await alist(copy.rows())
assert rows == sample_records
)
async def test_copy_in_buffers(aconn, format, buffer):
cur = aconn.cursor()
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
await copy.write(globals()[buffer])
async def test_copy_in_buffers_pg_error(aconn):
cur = aconn.cursor()
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
with pytest.raises(e.UniqueViolation):
async with cur.copy("copy copy_in from stdin (format text)") as copy:
await copy.write(sample_text)
async def test_copy_in_str(aconn):
cur = aconn.cursor()
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
async with cur.copy("copy copy_in from stdin (format text)") as copy:
await copy.write(sample_text.decode())
async def test_copy_in_error(aconn):
cur = aconn.cursor()
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
with pytest.raises(TypeError):
async with cur.copy("copy copy_in from stdin (format binary)") as copy:
await copy.write(sample_text.decode())
@pytest.mark.parametrize("format", Format)
async def test_copy_in_empty(aconn, format):
cur = aconn.cursor()
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
async with cur.copy(f"copy copy_in from stdin (format {format.name})"):
pass
@pytest.mark.slow
async def test_copy_big_size_record(aconn):
cur = aconn.cursor()
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024))
async with cur.copy("copy copy_in (data) from stdin") as copy:
await copy.write_row([data])
@pytest.mark.parametrize("pytype", [str, bytes, bytearray, memoryview])
async def test_copy_big_size_block(aconn, pytype):
cur = aconn.cursor()
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024))
copy_data = data + "\n" if pytype is str else pytype(data.encode() + b"\n")
async with cur.copy("copy copy_in (data) from stdin") as copy:
if format == Format.TEXT:
from psycopg.types.string import StrDumper as BaseDumper
else:
- from psycopg.types.string import ( # type: ignore[assignment]
- StrBinaryDumper as BaseDumper,
- )
+ from psycopg.types.string import StrBinaryDumper
+
+ BaseDumper = StrBinaryDumper # type: ignore
class MyStrDumper(BaseDumper):
def dump(self, obj):
aconn.adapters.register_dumper(str, MyStrDumper)
cur = aconn.cursor()
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
async with cur.copy(
f"copy copy_in (data) from stdin (format {format.name})"
@pytest.mark.parametrize("format", Format)
async def test_copy_in_error_empty(aconn, format):
cur = aconn.cursor()
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
with pytest.raises(ZeroDivisionError, match="mannaggiamiseria"):
async with cur.copy(f"copy copy_in from stdin (format {format.name})"):
raise ZeroDivisionError("mannaggiamiseria")
async def test_copy_in_buffers_with_pg_error(aconn):
cur = aconn.cursor()
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
with pytest.raises(e.UniqueViolation):
async with cur.copy("copy copy_in from stdin (format text)") as copy:
await copy.write(sample_text)
async def test_copy_in_buffers_with_py_error(aconn):
cur = aconn.cursor()
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
with pytest.raises(ZeroDivisionError, match="nuttengoggenio"):
async with cur.copy("copy copy_in from stdin (format text)") as copy:
await copy.write(sample_text)
@pytest.mark.parametrize("format", Format)
async def test_copy_in_records(aconn, format):
cur = aconn.cursor()
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
for row in sample_records:
if format == Format.BINARY:
- row = tuple(
- Int4(i) if isinstance(i, int) else i for i in row
- ) # type: ignore[assignment]
+ row2 = tuple(Int4(i) if isinstance(i, int) else i for i in row)
+ row = row2 # type: ignore[assignment]
await copy.write_row(row)
await cur.execute("select * from copy_in order by 1")
@pytest.mark.parametrize("format", Format)
async def test_copy_in_records_set_types(aconn, format):
cur = aconn.cursor()
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
copy.set_types(["int4", "int4", "text"])
@pytest.mark.parametrize("format", Format)
async def test_copy_in_records_binary(aconn, format):
cur = aconn.cursor()
- await ensure_table(cur, "col1 serial primary key, col2 int, data text")
+ await ensure_table_async(cur, "col1 serial primary key, col2 int, data text")
async with cur.copy(
f"copy copy_in (col2, data) from stdin (format {format.name})"
async def test_copy_in_allchars(aconn):
cur = aconn.cursor()
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
await aconn.execute("set client_encoding to utf8")
async with cur.copy("copy copy_in from stdin (format text)") as copy:
)
async def test_worker_life(aconn, format, buffer):
cur = aconn.cursor()
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
async with cur.copy(
f"copy copy_in from stdin (format {format.name})",
writer=AsyncQueuedLibpqWriter(cur),
cur = aconn.cursor()
writer = AsyncLibpqWriter(cur)
- await ensure_table(cur, sample_tabledef)
+ await ensure_table_async(cur, sample_tabledef)
async with cur.copy(
f"copy copy_in from stdin (format {format.name})", writer=writer
) as copy:
faker.assert_record(got, want)
-async def ensure_table(cur, tabledef, name="copy_in"):
- await cur.execute(f"drop table if exists {name}")
- await cur.execute(f"create table {name} ({tabledef})")
-
-
class DataGenerator:
def __init__(self, conn, nrecs, srec, offset=0, block_size=8192):
self.conn = conn
async def ensure_table(self):
cur = self.conn.cursor()
- await ensure_table(cur, "id integer primary key, data text")
+ await ensure_table_async(cur, "id integer primary key, data text")
def records(self):
for i, c in zip(range(self.nrecs), cycle(string.ascii_letters)):
block = block.encode()
m.update(block)
return m.hexdigest()
-
-
-class AsyncFileWriter(AsyncWriter):
- def __init__(self, file):
- self.file = file
-
- async def write(self, data):
- self.file.write(data)
names_map = {
"AsyncClientCursor": "ClientCursor",
"AsyncConnection": "Connection",
+ "AsyncCopy": "Copy",
"AsyncCursor": "Cursor",
+ "AsyncFileWriter": "FileWriter",
+ "AsyncLibpqWriter": "LibpqWriter",
+ "AsyncQueuedLibpqWriter": "QueuedLibpqWriter",
"AsyncRawCursor": "RawCursor",
"AsyncServerCursor": "ServerCursor",
"aclose": "close",