]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
10556: Create new BitString subclass for bitstrings for PG drivers
authorThomas Stephenson <ovangle@gmail.com>
Thu, 15 May 2025 13:21:49 +0000 (23:21 +1000)
committerThomas Stephenson <ovangle@gmail.com>
Thu, 15 May 2025 13:21:49 +0000 (23:21 +1000)
- Added custom `BitString` implementation which inherits from `str`
  and overrides methods and operators appropriately and
  implements the bitwise operations in a manner consistent with
  postgresql. Also exposes class and instance methods ensuring
  instances can be converted to/from `int` and `bytes` sensibly.

- Added default implementations of `bind_processor` and
  `result_processor` to `postgresql.BIT` type to convert `BitString`
  to/from string for PG drivers.

- Override `bind_processor` and `result_processor` in AsyncpgBit
  in order to convert `BitString` to/from `asynpg.BitString`
  in `asyncpg` driver.

- Added support for bitwise comparators to postgresql `BIT` instances

Fixes: 10556
lib/sqlalchemy/dialects/postgresql/__init__.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/postgresql/bitstring.py [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/types.py
test/dialect/postgresql/test_bitstring.py [new file with mode: 0644]
test/dialect/postgresql/test_types.py

index e426df71be75c226212ae4506b30e1136d34c71e..677f3b7dd5ce65d63536baa02b075d454f952f26 100644 (file)
@@ -33,6 +33,7 @@ from .base import SMALLINT
 from .base import TEXT
 from .base import UUID
 from .base import VARCHAR
+from .bitstring import BitString
 from .dml import Insert
 from .dml import insert
 from .ext import aggregate_order_by
@@ -154,6 +155,7 @@ __all__ = (
     "JSONPATH",
     "Any",
     "All",
+    "BitString",
     "DropEnumType",
     "DropDomainType",
     "CreateDomainType",
index 3d6aae91764f6fe8b38ed45f8ba9d8714d025ee3..12be23635d3a1edf25781e47dcb033a628152d28 100644 (file)
@@ -207,6 +207,7 @@ from .base import PGExecutionContext
 from .base import PGIdentifierPreparer
 from .base import REGCLASS
 from .base import REGCONFIG
+from .bitstring import BitString
 from .types import BIT
 from .types import BYTEA
 from .types import CITEXT
@@ -242,6 +243,32 @@ class AsyncpgTime(sqltypes.Time):
 class AsyncpgBit(BIT):
     render_bind_cast = True
 
+    def bind_processor(self, dialect):
+        asyncpg_BitString = dialect.dbapi.asyncpg.BitString
+
+        def to_bind(value):
+            print(f'processing bound value \'{value}\'')
+            if isinstance(value, str):
+                value = BitString(value)
+                r = asyncpg_BitString.from_int(int(value), len(value))
+                print(f'returning {r}')
+                return r
+            return value
+
+        return to_bind
+
+    def result_processor(self, dialect, coltype):
+        def to_result(value):
+            if value is not None:
+                print(f'result {value} length {len(value)}')
+                value = BitString.from_int(
+                    value.to_int(),
+                    length=len(value)
+                )
+            return value
+
+        return to_result
+
 
 class AsyncpgByteA(BYTEA):
     render_bind_cast = True
diff --git a/lib/sqlalchemy/dialects/postgresql/bitstring.py b/lib/sqlalchemy/dialects/postgresql/bitstring.py
new file mode 100644 (file)
index 0000000..803255d
--- /dev/null
@@ -0,0 +1,367 @@
+# dialects/postgresql/types.py
+# Copyright (C) 2013-2025 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+from __future__ import annotations
+
+import functools
+import math
+from typing import SupportsIndex, cast
+from typing import Literal
+
+
+@functools.total_ordering
+class BitString(str):
+    """Represent a PostgreSQL bit string.
+
+    e.g.
+        b = BitString('101')
+    """
+
+    def __new__(cls, _value: str, _check=True):
+        if not isinstance(_value, BitString) and (
+            _check and _value and any(c not in "01" for c in _value)
+        ):
+            print(f'value: {_value}')
+            raise ValueError("BitString must only contain '0' and '1' chars")
+        return super().__new__(cls, _value)
+
+    @classmethod
+    def from_int(cls, value: int, length: int):
+        """
+        Returns a BitString consisting of the bits in the little-endian
+        representation of the given python int ``value``. A ``ValueError``
+        is raised if ``value`` is not a non-negative integer.
+
+        If the provided ``value`` can not be represented in a bit string
+        of at most ``length``, a ``ValueError`` will be raised. The bitstring
+        will be padded on the left by ``'0'`` to bits to produce a
+        """
+        if value < 0:
+            raise ValueError("value must be a postive integer")
+
+        if length >= 0:
+            if length > 0:
+                template_str = f'{{0:0{length}b}}' if length > 0 else ''
+                r = template_str.format(value)
+            else:
+                # f'{0:00b}'.format(0) == '0'
+                r = ''
+
+            if len(r) > length:
+                raise ValueError(
+                    f"Cannot encode {value} as a BitString of length {length}"
+                )
+        else:
+            r = '{0:b}'.format(value)
+
+        return cls(r)
+
+    @classmethod
+    def from_bytes(cls, value: bytes, length: int = -1):
+        """
+        Returns a ``BitString`` consisting of the bits in the given ``value``
+        bytes.
+
+        If ``length`` is provided, then the length of the provided string
+        will be exactly ``length``, with ``'0'`` bits inserted at the left of
+        the string in order to produce a value of the required length.
+        If, the bits obtained by omitting the leading `0` bits of ``value``
+        cannot be represented in a string of this length, then a ``ValueError``
+        will be raised.
+        """
+        str_v: str = "".join(f"{c:08b}" for c in value)
+        if length >= 0:
+            str_v = str_v.lstrip('0')
+
+            if len(str_v) >= length:
+                raise ValueError(
+                    f"Cannot encode {value} as a BitString of length {length}"
+                )
+            str_v = str_v.zfill(length)
+
+        return cls(str_v)
+
+    def get_bit(self, index) -> Literal["0", "1"]:
+        """
+        Returns the value of the flag at the given index
+
+        e.g. BitString('0101').get_flag(4) == 1
+        """
+        return cast(Literal["0", "1"], super().__getitem__(index))
+
+    @property
+    def bit_length(self):
+        return len(self)
+
+    @property
+    def octet_length(self):
+        return math.ceil(len(self) / 8)
+
+    def has_bit(self, index) -> bool:
+        return self.get_bit(index) == "1"
+
+    def set_bit(
+        self, index: int, value: bool | int | Literal["0", "1"]
+    ) -> BitString:
+        """
+        Set the bit at index to the given value.
+
+        If value is an int, then it is considered to be '1' iff nonzero.
+        """
+        if index < 0 or index >= len(self):
+            raise IndexError("BitString index out of range")
+
+        if isinstance(value, (bool, int)):
+            value = "1" if value else "0"
+
+        if self.get_bit(index) == value:
+            return self
+
+        return BitString(
+            "".join([self[:index], value, self[index + 1:]]), False
+        )
+
+    # These methods probably should return str and not override the fillchar
+    # def ljust(self, width, fillchar=None) -> BitString:
+    #     """
+    #     Returns the BitString left justified in a string of length width.
+    #     Padding is done using the provided fillchar (default is '0').
+
+    #     If the width is shorter than the length, then the original BitString
+    #     is returned.
+    #     """
+    #     if width < len(self):
+    #         return self
+
+    #     fillchar = fillchar or "0"
+    #     if str(fillchar) not in "01":
+    #         raise ValueError("fillchar must be either '0' or '1'")
+
+    #     return BitString(super().ljust(width, fillchar or "0"))
+
+    # def rjust(self, width, fillchar=None) -> BitString:
+    #     if width < len(self):
+    #         return self
+
+    #     fillchar = fillchar or "0"
+    #     if str(fillchar) not in "01":
+    #         raise ValueError("fillchar must be either '0' or '1'")
+
+    #     return BitString(super().rjust(width, fillchar))
+
+    def lstrip(self, char=None) -> BitString:
+        """
+        Returns a copy of the BitString with leading characters removed.
+
+        If omitted or None, 'chars' defaults '0'
+
+        e.g.
+        BitString('00010101000').lstrip() === BitString('00010101')
+        BitString('11110101111').lstrip('1') === BitString('1111010')
+        """
+        if char is None:
+            char = "0"
+        return BitString(super().lstrip(char), False)
+
+    def rstrip(self, char=None) -> BitString:
+        """
+        Returns a copy of the BitString with trailing characters removed.
+
+        If omitted or None, 'chars' trailing '0'
+
+        e.g.
+        BitString('00010101000').rstrip() === BitString('10101000')
+        BitString('11110101111').rstrip('1') === BitString('10101111')
+        """
+        if char is None:
+            char = "0"
+        return BitString(super().rstrip(char), False)
+
+    def strip(self, char=None) -> BitString:
+        """
+        Returns a copy of the BitString with both leading and trailing
+        characters removed.
+        If ommitted or None, char defaults to '0'
+
+        e.g.
+        BitString('00010101000').rstrip() === BitString('10101')
+        BitString('11110101111').rstrip('1') === BitString('1010')
+        """
+        if char is None:
+            char = "0"
+        return BitString(super().strip(char))
+
+    def partition(self, sep: str = "0") -> tuple[BitString, str, BitString]:
+        """
+        Split the string after the first appearance of sep
+        (which defaults to '0') and return a 3-tuple containing
+        the portion of the string before the separator.
+
+        """
+        prefix, _, suffix = super().partition(sep)
+        return (BitString(prefix, False), sep, BitString(suffix, False))
+
+    def removeprefix(self, prefix: str, /) -> BitString:
+        return BitString(super().removeprefix(prefix), False)
+
+    def removesuffix(self, suffix: str, /) -> BitString:
+        return BitString(super().removesuffix(suffix), False)
+
+    def replace(self, old, new, count: SupportsIndex = -1) -> BitString:
+        new = BitString(new)
+        return BitString(super().replace(old, new, count=count), False)
+
+    def split(  # type: ignore
+            self,
+            sep=None,
+            maxsplit: SupportsIndex = -1,
+    ) -> list[BitString]:
+        return [BitString(word) for word in super().split(sep, maxsplit)]
+
+    def zfill(self, width) -> BitString:
+        return BitString(super().zfill(width), False)
+
+    def __repr__(self):
+        return f'BitString("{self.__str__()}")'
+
+    def __int__(self):
+        return int(self, 2) if self else 0
+
+    def __bytes__(self):
+        s = str(self)
+        bs = []
+        while s:
+            bs.append(int(s[-8:], 2))
+            s = s[:-8]
+        return bytes(bs)
+
+    def __lt__(self, o):
+        if isinstance(o, BitString):
+            return super().__lt__(o)
+        return NotImplemented
+
+    def __eq__(self, o):
+        return isinstance(o, BitString) and super().__eq__(o)
+
+    def __hash__(self):
+        return hash(BitString) ^ super().__hash__()
+
+    def __getitem__(self, key):
+        return BitString(super().__getitem__(key), False)
+
+    def __add__(self, o):
+        """Return self + o"""
+        if not isinstance(o, str):
+            raise TypeError((
+                "Can only concatenate str "
+                "(not '{0}') to BitString"
+            ).format(type(o)))
+        return BitString(''.join([self, o]))
+
+    def __radd__(self, o):
+        if not isinstance(o, str):
+            raise TypeError((
+                "Can only concatenate str (not '{0}') to BitString"
+            ).format(type(o)))
+        return BitString(''.join([o, self]))
+
+    def __lshift__(self, amount: int):
+        """
+        Shifts each the bitstring to the left by the given amount.
+        String length is preserved.
+
+        i.e. BitString('000101') << 1 == BitString('001010')
+        """
+        return BitString(
+            "".join([self, *("0" for _ in range(amount))])[-len(self):], False
+        )
+
+    def __rshift__(self, amount: int):
+        """
+        Shifts each bit in the bitstring to the right by the given amount.
+        String length is preserved.
+
+        e.g. BitString('101') >> 1 == BitString('010')
+        """
+        return BitString(self[:-amount], False).zfill(width=len(self))
+
+    def __invert__(self):
+        """
+        Inverts (~) each bit in the bitstring
+
+        e.g. ~BitString('01010') == BitString('10101')
+        """
+        return BitString("".join("1" if x == "0" else "0" for x in self))
+
+    def __and__(self, o):
+        """
+        Performs a bitwise and (``&``) with the given operand.
+        A ``ValueError`` is raised if the operand is not the same length.
+
+        e.g. BitString('011') & BitString('011') == BitString('010')
+        """
+
+        if not isinstance(o, str):
+            return NotImplemented
+        o = BitString(o)
+        if len(self) != len(o):
+            raise ValueError("Operands must be the same length")
+
+        return BitString(
+            "".join(
+                "1" if (x == "1" and y == "1") else "0"
+                for x, y in zip(self, o)
+            ),
+            False,
+        )
+
+    def __or__(self, o):
+        """
+        Performs a bitwise or (``|``) with the given operand.
+        A ``ValueError`` is raised if the operand is not the same length.
+
+        e.g. BitString('011') | BitString('010') == BitString('011')
+        """
+        if not isinstance(o, str):
+            return NotImplemented
+
+        if len(self) != len(o):
+            raise ValueError("Operands must be the same length")
+
+        o = BitString(o)
+        return BitString(
+            "".join(
+                "1" if (x == "1" or y == "1") else "0"
+                for (x, y) in zip(self, o)
+            ),
+            False,
+        )
+
+    def __xor__(self, o):
+        """
+        Performs a bitwise xor (``^``) with the given operand.
+        A ``ValueError`` is raised if the operand is not the same length.
+
+        e.g. BitString('011') ^ BitString('010') == BitString('001')
+        """
+
+        if not isinstance(o, BitString):
+            return NotImplemented
+
+        if len(self) != len(o):
+            raise ValueError("Operands must be the same length")
+
+        return BitString(
+            "".join(
+                (
+                    "1"
+                    if ((x == "1" and y == "0") or (x == "0" and y == "1"))
+                    else "0"
+                )
+                for (x, y) in zip(self, o)
+            ),
+            False,
+        )
index ff5e967ef6fb3d27106af8371a0b2e93afed3359..4296df35e38e541c27b37ae33f14a5f0d14bb5b3 100644 (file)
@@ -8,6 +8,7 @@ from __future__ import annotations
 
 import datetime as dt
 from typing import Any
+from typing import Literal
 from typing import Optional
 from typing import overload
 from typing import Type
@@ -16,13 +17,14 @@ from uuid import UUID as _python_UUID
 
 from ...sql import sqltypes
 from ...sql import type_api
-from ...util.typing import Literal
+from ...sql.type_api import TypeEngine
+
+from .bitstring import BitString
 
 if TYPE_CHECKING:
     from ...engine.interfaces import Dialect
     from ...sql.operators import OperatorType
     from ...sql.type_api import _LiteralProcessorType
-    from ...sql.type_api import TypeEngine
 
 _DECIMAL_TYPES = (1231, 1700)
 _FLOAT_TYPES = (700, 701, 1021, 1022)
@@ -256,7 +258,8 @@ class INTERVAL(type_api.NativeForEmulated, sqltypes._AbstractInterval):
 PGInterval = INTERVAL
 
 
-class BIT(sqltypes.TypeEngine[int]):
+class BIT(sqltypes.TypeEngine[BitString]):
+    render_bind_cast = True
     __visit_name__ = "BIT"
 
     def __init__(
@@ -270,6 +273,44 @@ class BIT(sqltypes.TypeEngine[int]):
             self.length = length or 1
         self.varying = varying
 
+    def bind_processor(self, dialect):
+        def bound_value(value):
+            if isinstance(value, BitString):
+                return str(value)
+            return value
+        return bound_value
+
+    def result_processor(self, dialect, coltype):
+        def from_result_value(value):
+            if value is not None:
+                value = BitString(value)
+            return value
+        return from_result_value
+
+    def coerce_compared_value(self, op, value) -> TypeEngine[Any]:
+        if isinstance(value, str):
+            return self
+        return super().coerce_compared_value(op, value)
+
+    @property
+    def python_type(self):
+        return BitString
+
+    class comparator_factory(TypeEngine.Comparator[BitString]):
+        def __lshift__(self, other: Any):
+            return self.bitwise_lshift(other)
+
+        def __rshift__(self, other: Any):
+            return self.bitwise_rshift(other)
+
+        def __and__(self, other: Any):
+            return self.bitwise_and(other)
+
+        def __or__(self, other: Any):
+            return self.bitwise_or(other)
+
+        def __invert__(self):
+            return self.bitwise_not()
 
 PGBit = BIT
 
diff --git a/test/dialect/postgresql/test_bitstring.py b/test/dialect/postgresql/test_bitstring.py
new file mode 100644 (file)
index 0000000..2a5b879
--- /dev/null
@@ -0,0 +1,105 @@
+from sqlalchemy.testing import fixtures
+
+from sqlalchemy.dialects.postgresql import BitString
+from sqlalchemy.testing.assertions import eq_
+from sqlalchemy.testing.assertions import is_false
+from sqlalchemy.testing.assertions import is_true
+from sqlalchemy.testing.assertions import assert_raises
+
+
+class BitStringTests(fixtures.TestBase):
+
+    def test_ctor(self):
+        x = BitString("1110111")
+        eq_(str(x), "1110111")
+        eq_(int(x), 119)
+
+        eq_(BitString("111"), BitString("111"))
+        is_false(BitString("111") == "111")
+
+        eq_(hash(BitString("011")), hash(BitString("011")))
+        is_false(hash(BitString("011")) == hash("011"))
+
+        eq_(BitString("011")[1], BitString("1"))
+
+    def test_int_conversion(self):
+        assert_raises(ValueError, lambda: BitString.from_int(127, length=6))
+
+        eq_(BitString.from_int(127, length=8), BitString("01111111"))
+        eq_(int(BitString.from_int(127, length=8)), 127)
+
+        eq_(BitString.from_int(119, length=10), BitString("0001110111"))
+        eq_(int(BitString.from_int(119, length=10)), 119)
+
+    def test_bytes_conversion(self):
+        eq_(BitString.from_bytes(b"\x01"), BitString("0000001"))
+        eq_(BitString.from_bytes(b"\x01", 4), BitString("00000001"))
+
+        eq_(BitString.from_bytes(b"\xaf\x04"), BitString("101011110010"))
+        eq_(
+            BitString.from_bytes(b"\xaf\x04", 12),
+            BitString("0000101011110010"),
+        )
+        assert_raises(
+            ValueError, lambda: BitString.from_bytes(b"\xaf\x04", 4), 1
+        )
+
+    def test_get_set_bit(self):
+        eq_(BitString("1010").get_bit(2), "1")
+        eq_(BitString("0101").get_bit(2), "0")
+        assert_raises(IndexError, lambda: BitString("0").get_bit(1))
+
+        eq_(BitString("0101").set_bit(3, "0"), BitString("0100"))
+        eq_(BitString("0101").set_bit(3, "1"), BitString("0101"))
+        assert_raises(IndexError, lambda: BitString("1111").set_bit(5, "1"))
+
+    def test_string_methods(self):
+
+        # Which of these methods should be overridden to produce BitStrings?
+        eq_(BitString("111").center(8), "  111   ")
+
+        eq_(BitString("0101").ljust(8), "0101    ")
+        eq_(BitString("0110").rjust(8), "    0110")
+
+        eq_(BitString("01100").lstrip(), BitString("1100"))
+        eq_(BitString("01100").rstrip(), BitString("011"))
+        eq_(BitString("01100").strip(), BitString("11"))
+
+        eq_(BitString("11100").removeprefix("111"), BitString("00"))
+        eq_(BitString("11100").removeprefix("0"), BitString("11100"))
+
+        eq_(BitString("11100").removesuffix("10"), BitString("11100"))
+        eq_(BitString("11100").removesuffix("00"), BitString("111"))
+
+        eq_(
+            BitString("010101011").replace("0101", "11", 1),
+            BitString("1101011"),
+        )
+        eq_(
+            BitString("01101101").split("1", 2),
+            [BitString("0"), BitString(""), BitString("01101")],
+        )
+
+        eq_(BitString("0110").split("11"), [BitString("0"), BitString("0")])
+        eq_(BitString("111").zfill(8), BitString("00000111"))
+
+    def test_str_ops(self):
+        is_true("1" in BitString("001"))
+        is_true("0" in BitString("110"))
+        is_false("1" in BitString("000"))
+
+        is_true("001" in BitString("01001"))
+        is_true(BitString("001") in BitString("01001"))
+        is_false(BitString("000") in BitString("01001"))
+
+        eq_(BitString("010") + "001", BitString("010001"))
+        eq_("001" + BitString("010"), BitString("001010"))
+
+    def test_bitwise_ops(self):
+        eq_(~BitString("0101"), BitString("1010"))
+        eq_(BitString("010") & BitString("011"), BitString("010"))
+        eq_(BitString("010") | BitString("011"), BitString("011"))
+        eq_(BitString("010") ^ BitString("011"), BitString("001"))
+
+        eq_(BitString("001100") << 2, BitString("110000"))
+        eq_(BitString("001100") >> 2, BitString("000011"))
index 795a897699b331b4c994c594cffdebce610d6cc5..6fa07a109419bbf45407c09b6e4296f7651a5ce4 100644 (file)
@@ -46,6 +46,7 @@ from sqlalchemy.dialects.postgresql import array
 from sqlalchemy.dialects.postgresql import array_agg
 from sqlalchemy.dialects.postgresql import asyncpg
 from sqlalchemy.dialects.postgresql import base
+from sqlalchemy.dialects.postgresql import BitString
 from sqlalchemy.dialects.postgresql import BIT
 from sqlalchemy.dialects.postgresql import BYTEA
 from sqlalchemy.dialects.postgresql import CITEXT
@@ -3514,7 +3515,9 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables):
             metadata,
             Column("id", postgresql.UUID, primary_key=True),
             Column("flag", postgresql.BIT),
-            Column("bitstring", postgresql.BIT(4)),
+            Column("bitstring_varying", postgresql.BIT(varying=True)),
+            Column("bitstring_varying_6", postgresql.BIT(6, varying=True)),
+            Column("bitstring_4", postgresql.BIT(4)),
             Column("addr", postgresql.INET),
             Column("addr2", postgresql.MACADDR),
             Column("addr4", postgresql.MACADDR8),
@@ -3542,7 +3545,19 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables):
         self.assert_tables_equal(special_types_table, t, strict_types=True)
         assert t.c.plain_interval.type.precision is None
         assert t.c.precision_interval.type.precision == 3
-        assert t.c.bitstring.type.length == 4
+
+        assert t.c.flag.type.varying is False
+        assert t.c.flag.type.length == 1
+
+        assert t.c.bitstring_varying.type.varying is True
+        assert t.c.bitstring_varying.type.length is None
+
+        assert t.c.bitstring_varying_6.type.varying is True
+        assert t.c.bitstring_varying_6.type.length  == 6
+
+        assert t.c.bitstring_4.type.varying is False
+        assert t.c.bitstring_4.type.length == 4
+
 
     @testing.combinations(
         (postgresql.INET, "127.0.0.1"),
@@ -3568,6 +3583,38 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables):
             "test",
         )
 
+    @testing.combinations(
+        (postgresql.BIT(varying=True), BitString("")),
+        (postgresql.BIT(varying=True), BitString("1101010101")),
+        (postgresql.BIT(6, varying=True), BitString("")),
+        (postgresql.BIT(6, varying=True), BitString("010101")),
+        (postgresql.BIT(1), BitString("0")),
+        (postgresql.BIT(4), BitString("0010")),
+        (postgresql.BIT(4), "0010"),
+        argnames="column_type, value",
+    )
+    def test_bitstring_round_trip(
+        self, connection, metadata, column_type, value
+    ):
+        t = Table(
+            "bits",
+            metadata,
+            Column("name", String),
+            Column("value", column_type)
+        )
+        t.create(connection)
+
+        connection.execute(t.insert(), {"name": "test", "value": value})
+        print('value type affinity', t.c.value.type._type_affinity)
+        eq_(
+            connection.scalar(select(t.c.name).where(t.c.value == value)),
+            "test",
+        )
+
+        result_value = connection.scalar(select(t.c.value).where(t.c.name == "test"))
+        assert isinstance(result_value, BitString)
+        assert str(result_value) == str(value)
+
     def test_tsvector_round_trip(self, connection, metadata):
         t = Table("t1", metadata, Column("data", postgresql.TSVECTOR))
         t.create(connection)
@@ -4169,6 +4216,107 @@ class HStoreRoundTripTest(fixtures.TablesTest):
             eq_(s.query(Data.data, Data).all(), [(d.data, d)])
 
 
+class BitTests(fixtures.TestBase):
+    def test_concatenation(self, connection):
+        coltype = BIT(varying=True)
+
+        q = select(
+            literal(BitString('1111'), coltype).concat(BitString('0000'))
+        )
+        r = connection.execute(q).first()
+        eq_(r[0], BitString('11110000'))
+
+    @testing.skip("compiler bug")
+    def test_invert_operator(self, connection):
+        coltype = BIT(4)
+
+        q = select(
+            literal(BitString('0010'), coltype).bitwise_not()
+        )
+        r = connection.execute(q).first()
+
+        # Observing r[0] == '1101' here.
+        # See: sql.compiler.Compiler._label_select_column
+        # The unary operator does not "wrap a column expression"
+        # and it isn't a from clause of the select,
+        # the compiler doesn't actually add the column to the select's
+        # result_columns and thus the type's result_processor never gets
+        # called.
+        eq_(r[0], BitString('1101'))
+
+    def test_and_operator(self, connection):
+        coltype = BIT(6)
+
+        q1 = select(
+            literal(BitString('001010'), coltype)
+            & literal(BitString('010111'), coltype)
+        )
+        r1 = connection.execute(q1).first()
+
+        eq_(r1[0], BitString('000010'))
+
+        q2 = select(
+            literal(BitString('010101'), coltype) & BitString('001011')
+        )
+        r2 = connection.execute(q2).first()
+        eq_(r2[0], BitString('000001'))
+
+    def test_or_operator(self, connection):
+        coltype = BIT(6)
+
+        q1 = select(
+            literal(BitString('001010'), coltype)
+            & literal(BitString('010111'), coltype)
+        )
+        r1 = connection.execute(q1).first()
+
+        eq_(r1[0], BitString('011111'))
+
+        q2 = select(
+            literal(BitString('010101')) & BitString('001001')
+        )
+        r2 = connection.execute(q2).first()
+        eq_(r2[0], BitString('011101'))
+
+    def test_xor_operator(self, connection):
+        coltype = BIT(6)
+
+        q1 = select(
+            literal(BitString('001010'), coltype)
+            & literal(BitString('010111'), coltype)
+        )
+        r1 = connection.execute(q1).first()
+        eq_(r1[0], BitString('001101'))
+
+        q2 = select(
+            literal(BitString('010101'), coltype) & BitString('001011')
+        )
+        r2 = connection.execute(q2).first()
+        eq_(r2[0], BitString('011110'))
+
+    def test_lshift_operator(self, connection):
+        coltype = BIT(6)
+
+        q = select(
+            literal(BitString('001010'), coltype),
+            literal(BitString('001010'), coltype) << 1,
+        )
+
+        r = connection.execute(q).first()
+        eq_(tuple(r), (BitString('001010'), BitString('010100')))
+
+    def test_rshift_operator(self, connection):
+        coltype = BIT(6)
+
+        q = select(
+            literal(BitString('001010'), coltype),
+            literal(BitString('001010'), coltype) >> 1
+        )
+
+        r = connection.execute(q).first()
+        eq_(tuple(r), (BitString('001010'), BitString('000101')))
+
+
 class RangeMiscTests(fixtures.TestBase):
     @testing.combinations(
         (Range(2, 7), INT4RANGE),
@@ -6626,7 +6774,7 @@ class PGInsertManyValuesTest(fixtures.TestBase):
 
     @testing.combinations(
         ("BYTEA", BYTEA(), b"7\xe7\x9f"),
-        ("BIT", BIT(3), "011"),
+        ("BIT", BIT(3), BitString("011")),
         argnames="type_,value",
         id_="iaa",
     )
@@ -6659,11 +6807,6 @@ class PGInsertManyValuesTest(fixtures.TestBase):
 
         t.create(connection)
 
-        if type_._type_affinity is BIT and testing.against("+asyncpg"):
-            import asyncpg
-
-            value = asyncpg.BitString(value)
-
         result = connection.execute(
             t.insert().returning(
                 t.c.id,