from .base import PGIdentifierPreparer
from .base import REGCLASS
from .base import REGCONFIG
+from .types import BIT
from .types import BYTEA
from .types import CITEXT
from ... import exc
render_bind_cast = True
+class AsyncpgBit(BIT):
+ render_bind_cast = True
+
+
class AsyncpgByteA(BYTEA):
render_bind_cast = True
{
sqltypes.String: AsyncpgString,
sqltypes.ARRAY: AsyncpgARRAY,
+ BIT: AsyncpgBit,
CITEXT: CITEXT,
REGCONFIG: AsyncpgREGCONFIG,
sqltypes.Time: AsyncpgTime,
from sqlalchemy.dialects.postgresql import array_agg
from sqlalchemy.dialects.postgresql import asyncpg
from sqlalchemy.dialects.postgresql import base
+from sqlalchemy.dialects.postgresql import BIT
from sqlalchemy.dialects.postgresql import BYTEA
from sqlalchemy.dialects.postgresql import CITEXT
from sqlalchemy.dialects.postgresql import DATEMULTIRANGE
set(connection.scalars(select(t.c.value))),
{value},
)
+
+ @testing.variation("sort_by_parameter_order", [True, False])
+ @testing.variation("multiple_rows", [True, False])
+ @testing.only_on("postgresql+asyncpg")
+ @testing.requires.insert_returning
+ def test_imv_returning_datatypes_asyncpg_bit(
+ self,
+ connection,
+ metadata,
+ sort_by_parameter_order,
+ multiple_rows,
+ ):
+ """test #10532
+
+ this tests insertmanyvalues in conjunction with the BIT datatype
+ on asyncpg.
+
+ These tests are particularly for the asyncpg driver which needs
+ most types to be explicitly cast for the new IMV format
+
+ """
+ from asyncpg import BitString
+
+ t = Table(
+ "d_t",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("value", BIT(3)),
+ )
+
+ t.create(connection)
+
+ value = BitString.from_int(3, length=3)
+
+ result = connection.execute(
+ t.insert().returning(
+ t.c.id,
+ t.c.value,
+ sort_by_parameter_order=bool(sort_by_parameter_order),
+ ),
+ [{"value": value} for i in range(10)]
+ if multiple_rows
+ else {"value": value},
+ )
+
+ if multiple_rows:
+ i_range = range(1, 11)
+ else:
+ i_range = range(1, 2)
+
+ eq_(
+ set(result),
+ {(id_, value) for id_ in i_range},
+ )
+
+ eq_(
+ set(connection.scalars(select(t.c.value))),
+ {value},
+ )