From: Thomas Stephenson Date: Thu, 15 May 2025 13:21:49 +0000 (+1000) Subject: 10556: Create new BitString subclass for bitstrings for PG drivers X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=761748582718bde596fcb46d7109a274b910d0cc;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git 10556: Create new BitString subclass for bitstrings for PG drivers - Added custom `BitString` implementation which inherits from `str` and overrides methods and operators appropriately and implements the bitwise operations in a manner consistent with postgresql. Also exposes class and instance methods ensuring instances can be converted to/from `int` and `bytes` sensibly. - Added default implementations of `bind_processor` and `result_processor` to `postgresql.BIT` type to convert `BitString` to/from string for PG drivers. - Override `bind_processor` and `result_processor` in AsyncpgBit in order to convert `BitString` to/from `asynpg.BitString` in `asyncpg` driver. - Added support for bitwise comparators to postgresql `BIT` instances Fixes: 10556 --- diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index e426df71be..677f3b7dd5 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -33,6 +33,7 @@ from .base import SMALLINT from .base import TEXT from .base import UUID from .base import VARCHAR +from .bitstring import BitString from .dml import Insert from .dml import insert from .ext import aggregate_order_by @@ -154,6 +155,7 @@ __all__ = ( "JSONPATH", "Any", "All", + "BitString", "DropEnumType", "DropDomainType", "CreateDomainType", diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 3d6aae9176..12be23635d 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -207,6 +207,7 @@ from .base import PGExecutionContext from .base import PGIdentifierPreparer from .base import REGCLASS from .base import REGCONFIG +from .bitstring import BitString from .types import BIT from .types import BYTEA from .types import CITEXT @@ -242,6 +243,32 @@ class AsyncpgTime(sqltypes.Time): class AsyncpgBit(BIT): render_bind_cast = True + def bind_processor(self, dialect): + asyncpg_BitString = dialect.dbapi.asyncpg.BitString + + def to_bind(value): + print(f'processing bound value \'{value}\'') + if isinstance(value, str): + value = BitString(value) + r = asyncpg_BitString.from_int(int(value), len(value)) + print(f'returning {r}') + return r + return value + + return to_bind + + def result_processor(self, dialect, coltype): + def to_result(value): + if value is not None: + print(f'result {value} length {len(value)}') + value = BitString.from_int( + value.to_int(), + length=len(value) + ) + return value + + return to_result + class AsyncpgByteA(BYTEA): render_bind_cast = True diff --git a/lib/sqlalchemy/dialects/postgresql/bitstring.py b/lib/sqlalchemy/dialects/postgresql/bitstring.py new file mode 100644 index 0000000000..803255d840 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/bitstring.py @@ -0,0 +1,367 @@ +# dialects/postgresql/types.py +# Copyright (C) 2013-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + +import functools +import math +from typing import SupportsIndex, cast +from typing import Literal + + +@functools.total_ordering +class BitString(str): + """Represent a PostgreSQL bit string. + + e.g. + b = BitString('101') + """ + + def __new__(cls, _value: str, _check=True): + if not isinstance(_value, BitString) and ( + _check and _value and any(c not in "01" for c in _value) + ): + print(f'value: {_value}') + raise ValueError("BitString must only contain '0' and '1' chars") + return super().__new__(cls, _value) + + @classmethod + def from_int(cls, value: int, length: int): + """ + Returns a BitString consisting of the bits in the little-endian + representation of the given python int ``value``. A ``ValueError`` + is raised if ``value`` is not a non-negative integer. + + If the provided ``value`` can not be represented in a bit string + of at most ``length``, a ``ValueError`` will be raised. The bitstring + will be padded on the left by ``'0'`` to bits to produce a + """ + if value < 0: + raise ValueError("value must be a postive integer") + + if length >= 0: + if length > 0: + template_str = f'{{0:0{length}b}}' if length > 0 else '' + r = template_str.format(value) + else: + # f'{0:00b}'.format(0) == '0' + r = '' + + if len(r) > length: + raise ValueError( + f"Cannot encode {value} as a BitString of length {length}" + ) + else: + r = '{0:b}'.format(value) + + return cls(r) + + @classmethod + def from_bytes(cls, value: bytes, length: int = -1): + """ + Returns a ``BitString`` consisting of the bits in the given ``value`` + bytes. + + If ``length`` is provided, then the length of the provided string + will be exactly ``length``, with ``'0'`` bits inserted at the left of + the string in order to produce a value of the required length. + If, the bits obtained by omitting the leading `0` bits of ``value`` + cannot be represented in a string of this length, then a ``ValueError`` + will be raised. + """ + str_v: str = "".join(f"{c:08b}" for c in value) + if length >= 0: + str_v = str_v.lstrip('0') + + if len(str_v) >= length: + raise ValueError( + f"Cannot encode {value} as a BitString of length {length}" + ) + str_v = str_v.zfill(length) + + return cls(str_v) + + def get_bit(self, index) -> Literal["0", "1"]: + """ + Returns the value of the flag at the given index + + e.g. BitString('0101').get_flag(4) == 1 + """ + return cast(Literal["0", "1"], super().__getitem__(index)) + + @property + def bit_length(self): + return len(self) + + @property + def octet_length(self): + return math.ceil(len(self) / 8) + + def has_bit(self, index) -> bool: + return self.get_bit(index) == "1" + + def set_bit( + self, index: int, value: bool | int | Literal["0", "1"] + ) -> BitString: + """ + Set the bit at index to the given value. + + If value is an int, then it is considered to be '1' iff nonzero. + """ + if index < 0 or index >= len(self): + raise IndexError("BitString index out of range") + + if isinstance(value, (bool, int)): + value = "1" if value else "0" + + if self.get_bit(index) == value: + return self + + return BitString( + "".join([self[:index], value, self[index + 1:]]), False + ) + + # These methods probably should return str and not override the fillchar + # def ljust(self, width, fillchar=None) -> BitString: + # """ + # Returns the BitString left justified in a string of length width. + # Padding is done using the provided fillchar (default is '0'). + + # If the width is shorter than the length, then the original BitString + # is returned. + # """ + # if width < len(self): + # return self + + # fillchar = fillchar or "0" + # if str(fillchar) not in "01": + # raise ValueError("fillchar must be either '0' or '1'") + + # return BitString(super().ljust(width, fillchar or "0")) + + # def rjust(self, width, fillchar=None) -> BitString: + # if width < len(self): + # return self + + # fillchar = fillchar or "0" + # if str(fillchar) not in "01": + # raise ValueError("fillchar must be either '0' or '1'") + + # return BitString(super().rjust(width, fillchar)) + + def lstrip(self, char=None) -> BitString: + """ + Returns a copy of the BitString with leading characters removed. + + If omitted or None, 'chars' defaults '0' + + e.g. + BitString('00010101000').lstrip() === BitString('00010101') + BitString('11110101111').lstrip('1') === BitString('1111010') + """ + if char is None: + char = "0" + return BitString(super().lstrip(char), False) + + def rstrip(self, char=None) -> BitString: + """ + Returns a copy of the BitString with trailing characters removed. + + If omitted or None, 'chars' trailing '0' + + e.g. + BitString('00010101000').rstrip() === BitString('10101000') + BitString('11110101111').rstrip('1') === BitString('10101111') + """ + if char is None: + char = "0" + return BitString(super().rstrip(char), False) + + def strip(self, char=None) -> BitString: + """ + Returns a copy of the BitString with both leading and trailing + characters removed. + If ommitted or None, char defaults to '0' + + e.g. + BitString('00010101000').rstrip() === BitString('10101') + BitString('11110101111').rstrip('1') === BitString('1010') + """ + if char is None: + char = "0" + return BitString(super().strip(char)) + + def partition(self, sep: str = "0") -> tuple[BitString, str, BitString]: + """ + Split the string after the first appearance of sep + (which defaults to '0') and return a 3-tuple containing + the portion of the string before the separator. + + """ + prefix, _, suffix = super().partition(sep) + return (BitString(prefix, False), sep, BitString(suffix, False)) + + def removeprefix(self, prefix: str, /) -> BitString: + return BitString(super().removeprefix(prefix), False) + + def removesuffix(self, suffix: str, /) -> BitString: + return BitString(super().removesuffix(suffix), False) + + def replace(self, old, new, count: SupportsIndex = -1) -> BitString: + new = BitString(new) + return BitString(super().replace(old, new, count=count), False) + + def split( # type: ignore + self, + sep=None, + maxsplit: SupportsIndex = -1, + ) -> list[BitString]: + return [BitString(word) for word in super().split(sep, maxsplit)] + + def zfill(self, width) -> BitString: + return BitString(super().zfill(width), False) + + def __repr__(self): + return f'BitString("{self.__str__()}")' + + def __int__(self): + return int(self, 2) if self else 0 + + def __bytes__(self): + s = str(self) + bs = [] + while s: + bs.append(int(s[-8:], 2)) + s = s[:-8] + return bytes(bs) + + def __lt__(self, o): + if isinstance(o, BitString): + return super().__lt__(o) + return NotImplemented + + def __eq__(self, o): + return isinstance(o, BitString) and super().__eq__(o) + + def __hash__(self): + return hash(BitString) ^ super().__hash__() + + def __getitem__(self, key): + return BitString(super().__getitem__(key), False) + + def __add__(self, o): + """Return self + o""" + if not isinstance(o, str): + raise TypeError(( + "Can only concatenate str " + "(not '{0}') to BitString" + ).format(type(o))) + return BitString(''.join([self, o])) + + def __radd__(self, o): + if not isinstance(o, str): + raise TypeError(( + "Can only concatenate str (not '{0}') to BitString" + ).format(type(o))) + return BitString(''.join([o, self])) + + def __lshift__(self, amount: int): + """ + Shifts each the bitstring to the left by the given amount. + String length is preserved. + + i.e. BitString('000101') << 1 == BitString('001010') + """ + return BitString( + "".join([self, *("0" for _ in range(amount))])[-len(self):], False + ) + + def __rshift__(self, amount: int): + """ + Shifts each bit in the bitstring to the right by the given amount. + String length is preserved. + + e.g. BitString('101') >> 1 == BitString('010') + """ + return BitString(self[:-amount], False).zfill(width=len(self)) + + def __invert__(self): + """ + Inverts (~) each bit in the bitstring + + e.g. ~BitString('01010') == BitString('10101') + """ + return BitString("".join("1" if x == "0" else "0" for x in self)) + + def __and__(self, o): + """ + Performs a bitwise and (``&``) with the given operand. + A ``ValueError`` is raised if the operand is not the same length. + + e.g. BitString('011') & BitString('011') == BitString('010') + """ + + if not isinstance(o, str): + return NotImplemented + o = BitString(o) + if len(self) != len(o): + raise ValueError("Operands must be the same length") + + return BitString( + "".join( + "1" if (x == "1" and y == "1") else "0" + for x, y in zip(self, o) + ), + False, + ) + + def __or__(self, o): + """ + Performs a bitwise or (``|``) with the given operand. + A ``ValueError`` is raised if the operand is not the same length. + + e.g. BitString('011') | BitString('010') == BitString('011') + """ + if not isinstance(o, str): + return NotImplemented + + if len(self) != len(o): + raise ValueError("Operands must be the same length") + + o = BitString(o) + return BitString( + "".join( + "1" if (x == "1" or y == "1") else "0" + for (x, y) in zip(self, o) + ), + False, + ) + + def __xor__(self, o): + """ + Performs a bitwise xor (``^``) with the given operand. + A ``ValueError`` is raised if the operand is not the same length. + + e.g. BitString('011') ^ BitString('010') == BitString('001') + """ + + if not isinstance(o, BitString): + return NotImplemented + + if len(self) != len(o): + raise ValueError("Operands must be the same length") + + return BitString( + "".join( + ( + "1" + if ((x == "1" and y == "0") or (x == "0" and y == "1")) + else "0" + ) + for (x, y) in zip(self, o) + ), + False, + ) diff --git a/lib/sqlalchemy/dialects/postgresql/types.py b/lib/sqlalchemy/dialects/postgresql/types.py index ff5e967ef6..4296df35e3 100644 --- a/lib/sqlalchemy/dialects/postgresql/types.py +++ b/lib/sqlalchemy/dialects/postgresql/types.py @@ -8,6 +8,7 @@ from __future__ import annotations import datetime as dt from typing import Any +from typing import Literal from typing import Optional from typing import overload from typing import Type @@ -16,13 +17,14 @@ from uuid import UUID as _python_UUID from ...sql import sqltypes from ...sql import type_api -from ...util.typing import Literal +from ...sql.type_api import TypeEngine + +from .bitstring import BitString if TYPE_CHECKING: from ...engine.interfaces import Dialect from ...sql.operators import OperatorType from ...sql.type_api import _LiteralProcessorType - from ...sql.type_api import TypeEngine _DECIMAL_TYPES = (1231, 1700) _FLOAT_TYPES = (700, 701, 1021, 1022) @@ -256,7 +258,8 @@ class INTERVAL(type_api.NativeForEmulated, sqltypes._AbstractInterval): PGInterval = INTERVAL -class BIT(sqltypes.TypeEngine[int]): +class BIT(sqltypes.TypeEngine[BitString]): + render_bind_cast = True __visit_name__ = "BIT" def __init__( @@ -270,6 +273,44 @@ class BIT(sqltypes.TypeEngine[int]): self.length = length or 1 self.varying = varying + def bind_processor(self, dialect): + def bound_value(value): + if isinstance(value, BitString): + return str(value) + return value + return bound_value + + def result_processor(self, dialect, coltype): + def from_result_value(value): + if value is not None: + value = BitString(value) + return value + return from_result_value + + def coerce_compared_value(self, op, value) -> TypeEngine[Any]: + if isinstance(value, str): + return self + return super().coerce_compared_value(op, value) + + @property + def python_type(self): + return BitString + + class comparator_factory(TypeEngine.Comparator[BitString]): + def __lshift__(self, other: Any): + return self.bitwise_lshift(other) + + def __rshift__(self, other: Any): + return self.bitwise_rshift(other) + + def __and__(self, other: Any): + return self.bitwise_and(other) + + def __or__(self, other: Any): + return self.bitwise_or(other) + + def __invert__(self): + return self.bitwise_not() PGBit = BIT diff --git a/test/dialect/postgresql/test_bitstring.py b/test/dialect/postgresql/test_bitstring.py new file mode 100644 index 0000000000..2a5b879c14 --- /dev/null +++ b/test/dialect/postgresql/test_bitstring.py @@ -0,0 +1,105 @@ +from sqlalchemy.testing import fixtures + +from sqlalchemy.dialects.postgresql import BitString +from sqlalchemy.testing.assertions import eq_ +from sqlalchemy.testing.assertions import is_false +from sqlalchemy.testing.assertions import is_true +from sqlalchemy.testing.assertions import assert_raises + + +class BitStringTests(fixtures.TestBase): + + def test_ctor(self): + x = BitString("1110111") + eq_(str(x), "1110111") + eq_(int(x), 119) + + eq_(BitString("111"), BitString("111")) + is_false(BitString("111") == "111") + + eq_(hash(BitString("011")), hash(BitString("011"))) + is_false(hash(BitString("011")) == hash("011")) + + eq_(BitString("011")[1], BitString("1")) + + def test_int_conversion(self): + assert_raises(ValueError, lambda: BitString.from_int(127, length=6)) + + eq_(BitString.from_int(127, length=8), BitString("01111111")) + eq_(int(BitString.from_int(127, length=8)), 127) + + eq_(BitString.from_int(119, length=10), BitString("0001110111")) + eq_(int(BitString.from_int(119, length=10)), 119) + + def test_bytes_conversion(self): + eq_(BitString.from_bytes(b"\x01"), BitString("0000001")) + eq_(BitString.from_bytes(b"\x01", 4), BitString("00000001")) + + eq_(BitString.from_bytes(b"\xaf\x04"), BitString("101011110010")) + eq_( + BitString.from_bytes(b"\xaf\x04", 12), + BitString("0000101011110010"), + ) + assert_raises( + ValueError, lambda: BitString.from_bytes(b"\xaf\x04", 4), 1 + ) + + def test_get_set_bit(self): + eq_(BitString("1010").get_bit(2), "1") + eq_(BitString("0101").get_bit(2), "0") + assert_raises(IndexError, lambda: BitString("0").get_bit(1)) + + eq_(BitString("0101").set_bit(3, "0"), BitString("0100")) + eq_(BitString("0101").set_bit(3, "1"), BitString("0101")) + assert_raises(IndexError, lambda: BitString("1111").set_bit(5, "1")) + + def test_string_methods(self): + + # Which of these methods should be overridden to produce BitStrings? + eq_(BitString("111").center(8), " 111 ") + + eq_(BitString("0101").ljust(8), "0101 ") + eq_(BitString("0110").rjust(8), " 0110") + + eq_(BitString("01100").lstrip(), BitString("1100")) + eq_(BitString("01100").rstrip(), BitString("011")) + eq_(BitString("01100").strip(), BitString("11")) + + eq_(BitString("11100").removeprefix("111"), BitString("00")) + eq_(BitString("11100").removeprefix("0"), BitString("11100")) + + eq_(BitString("11100").removesuffix("10"), BitString("11100")) + eq_(BitString("11100").removesuffix("00"), BitString("111")) + + eq_( + BitString("010101011").replace("0101", "11", 1), + BitString("1101011"), + ) + eq_( + BitString("01101101").split("1", 2), + [BitString("0"), BitString(""), BitString("01101")], + ) + + eq_(BitString("0110").split("11"), [BitString("0"), BitString("0")]) + eq_(BitString("111").zfill(8), BitString("00000111")) + + def test_str_ops(self): + is_true("1" in BitString("001")) + is_true("0" in BitString("110")) + is_false("1" in BitString("000")) + + is_true("001" in BitString("01001")) + is_true(BitString("001") in BitString("01001")) + is_false(BitString("000") in BitString("01001")) + + eq_(BitString("010") + "001", BitString("010001")) + eq_("001" + BitString("010"), BitString("001010")) + + def test_bitwise_ops(self): + eq_(~BitString("0101"), BitString("1010")) + eq_(BitString("010") & BitString("011"), BitString("010")) + eq_(BitString("010") | BitString("011"), BitString("011")) + eq_(BitString("010") ^ BitString("011"), BitString("001")) + + eq_(BitString("001100") << 2, BitString("110000")) + eq_(BitString("001100") >> 2, BitString("000011")) diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 795a897699..6fa07a1094 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -46,6 +46,7 @@ from sqlalchemy.dialects.postgresql import array from sqlalchemy.dialects.postgresql import array_agg from sqlalchemy.dialects.postgresql import asyncpg from sqlalchemy.dialects.postgresql import base +from sqlalchemy.dialects.postgresql import BitString from sqlalchemy.dialects.postgresql import BIT from sqlalchemy.dialects.postgresql import BYTEA from sqlalchemy.dialects.postgresql import CITEXT @@ -3514,7 +3515,9 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables): metadata, Column("id", postgresql.UUID, primary_key=True), Column("flag", postgresql.BIT), - Column("bitstring", postgresql.BIT(4)), + Column("bitstring_varying", postgresql.BIT(varying=True)), + Column("bitstring_varying_6", postgresql.BIT(6, varying=True)), + Column("bitstring_4", postgresql.BIT(4)), Column("addr", postgresql.INET), Column("addr2", postgresql.MACADDR), Column("addr4", postgresql.MACADDR8), @@ -3542,7 +3545,19 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables): self.assert_tables_equal(special_types_table, t, strict_types=True) assert t.c.plain_interval.type.precision is None assert t.c.precision_interval.type.precision == 3 - assert t.c.bitstring.type.length == 4 + + assert t.c.flag.type.varying is False + assert t.c.flag.type.length == 1 + + assert t.c.bitstring_varying.type.varying is True + assert t.c.bitstring_varying.type.length is None + + assert t.c.bitstring_varying_6.type.varying is True + assert t.c.bitstring_varying_6.type.length == 6 + + assert t.c.bitstring_4.type.varying is False + assert t.c.bitstring_4.type.length == 4 + @testing.combinations( (postgresql.INET, "127.0.0.1"), @@ -3568,6 +3583,38 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables): "test", ) + @testing.combinations( + (postgresql.BIT(varying=True), BitString("")), + (postgresql.BIT(varying=True), BitString("1101010101")), + (postgresql.BIT(6, varying=True), BitString("")), + (postgresql.BIT(6, varying=True), BitString("010101")), + (postgresql.BIT(1), BitString("0")), + (postgresql.BIT(4), BitString("0010")), + (postgresql.BIT(4), "0010"), + argnames="column_type, value", + ) + def test_bitstring_round_trip( + self, connection, metadata, column_type, value + ): + t = Table( + "bits", + metadata, + Column("name", String), + Column("value", column_type) + ) + t.create(connection) + + connection.execute(t.insert(), {"name": "test", "value": value}) + print('value type affinity', t.c.value.type._type_affinity) + eq_( + connection.scalar(select(t.c.name).where(t.c.value == value)), + "test", + ) + + result_value = connection.scalar(select(t.c.value).where(t.c.name == "test")) + assert isinstance(result_value, BitString) + assert str(result_value) == str(value) + def test_tsvector_round_trip(self, connection, metadata): t = Table("t1", metadata, Column("data", postgresql.TSVECTOR)) t.create(connection) @@ -4169,6 +4216,107 @@ class HStoreRoundTripTest(fixtures.TablesTest): eq_(s.query(Data.data, Data).all(), [(d.data, d)]) +class BitTests(fixtures.TestBase): + def test_concatenation(self, connection): + coltype = BIT(varying=True) + + q = select( + literal(BitString('1111'), coltype).concat(BitString('0000')) + ) + r = connection.execute(q).first() + eq_(r[0], BitString('11110000')) + + @testing.skip("compiler bug") + def test_invert_operator(self, connection): + coltype = BIT(4) + + q = select( + literal(BitString('0010'), coltype).bitwise_not() + ) + r = connection.execute(q).first() + + # Observing r[0] == '1101' here. + # See: sql.compiler.Compiler._label_select_column + # The unary operator does not "wrap a column expression" + # and it isn't a from clause of the select, + # the compiler doesn't actually add the column to the select's + # result_columns and thus the type's result_processor never gets + # called. + eq_(r[0], BitString('1101')) + + def test_and_operator(self, connection): + coltype = BIT(6) + + q1 = select( + literal(BitString('001010'), coltype) + & literal(BitString('010111'), coltype) + ) + r1 = connection.execute(q1).first() + + eq_(r1[0], BitString('000010')) + + q2 = select( + literal(BitString('010101'), coltype) & BitString('001011') + ) + r2 = connection.execute(q2).first() + eq_(r2[0], BitString('000001')) + + def test_or_operator(self, connection): + coltype = BIT(6) + + q1 = select( + literal(BitString('001010'), coltype) + & literal(BitString('010111'), coltype) + ) + r1 = connection.execute(q1).first() + + eq_(r1[0], BitString('011111')) + + q2 = select( + literal(BitString('010101')) & BitString('001001') + ) + r2 = connection.execute(q2).first() + eq_(r2[0], BitString('011101')) + + def test_xor_operator(self, connection): + coltype = BIT(6) + + q1 = select( + literal(BitString('001010'), coltype) + & literal(BitString('010111'), coltype) + ) + r1 = connection.execute(q1).first() + eq_(r1[0], BitString('001101')) + + q2 = select( + literal(BitString('010101'), coltype) & BitString('001011') + ) + r2 = connection.execute(q2).first() + eq_(r2[0], BitString('011110')) + + def test_lshift_operator(self, connection): + coltype = BIT(6) + + q = select( + literal(BitString('001010'), coltype), + literal(BitString('001010'), coltype) << 1, + ) + + r = connection.execute(q).first() + eq_(tuple(r), (BitString('001010'), BitString('010100'))) + + def test_rshift_operator(self, connection): + coltype = BIT(6) + + q = select( + literal(BitString('001010'), coltype), + literal(BitString('001010'), coltype) >> 1 + ) + + r = connection.execute(q).first() + eq_(tuple(r), (BitString('001010'), BitString('000101'))) + + class RangeMiscTests(fixtures.TestBase): @testing.combinations( (Range(2, 7), INT4RANGE), @@ -6626,7 +6774,7 @@ class PGInsertManyValuesTest(fixtures.TestBase): @testing.combinations( ("BYTEA", BYTEA(), b"7\xe7\x9f"), - ("BIT", BIT(3), "011"), + ("BIT", BIT(3), BitString("011")), argnames="type_,value", id_="iaa", ) @@ -6659,11 +6807,6 @@ class PGInsertManyValuesTest(fixtures.TestBase): t.create(connection) - if type_._type_affinity is BIT and testing.against("+asyncpg"): - import asyncpg - - value = asyncpg.BitString(value) - result = connection.execute( t.insert().returning( t.c.id,