]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add new ``str`` subclass for postgresql bitstring
authorFederico Caselli <cfederico87@gmail.com>
Fri, 20 Jun 2025 20:28:45 +0000 (22:28 +0200)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 2 Jul 2025 02:44:41 +0000 (22:44 -0400)
Adds a new ``str`` subclass :class:`dialects.postgresql.BitString`
representing PostgreSQL bitstrings in python, that includes
functionality for converting to and from ``int`` and ``bytes``, in
addition to implementing utility methods and operators for dealing
with bits.

This new class is returned automatically by the :class:`postgresql.BIT`
type.

Fixes: #10556
Closes: #12594
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12594
Pull-request-sha: da47f40739ff9fc6e75da44bd1663aadf80e93ca

Change-Id: I64685660527c23666f7351b2c393fa86dfb643ea

doc/build/changelog/migration_21.rst
doc/build/changelog/unreleased_21/10556.rst [new file with mode: 0644]
doc/build/dialects/postgresql.rst
lib/sqlalchemy/dialects/postgresql/__init__.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/postgresql/bitstring.py [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/types.py
test/dialect/postgresql/test_bitstring.py [new file with mode: 0644]
test/dialect/postgresql/test_types.py

index dd3d81a1400450a20186412e48e0ee288cea6ecd..5634cdda647f7fddee7f627b3a98140918007c6e 100644 (file)
@@ -417,3 +417,27 @@ would appear in a valid ODBC connection string (i.e., the same as would be
 required if using the connection string directly with ``pyodbc.connect()``).
 
 :ticket:`11250`
+
+.. _change_10556:
+
+Addition of ``BitString`` subclass for handling postgresql ``BIT`` columns
+--------------------------------------------------------------------------
+
+Values of :class:`_postgresql.BIT` columns in the PostgreSQL dialect are
+returned as instances of a new ``str`` subclass,
+:class:`_postgresql.BitString`.  Previously, the value of :class:`_postgresql.BIT`
+columns was driver dependent, with most drivers returning ``str`` instances
+except ``asyncpg``, which used ``asyncpg.BitString``.
+
+With this change, for the ``psycopg``, ``psycopg2``, and ``pg8000`` drivers,
+the new :class:`_postgresql.BitString` type is mostly compatible with ``str``, but
+adds methods for bit manipulation and supports bitwise operators.
+
+As :class:`_postgresql.BitString` is a string subclass, hashability as well
+as equality tests continue to work against plain strings.   This also leaves
+ordering operators intact.
+
+For implementations using the ``asyncpg`` driver, the new type is incompatible with
+the existing ``asyncpg.BitString`` type.
+
+:ticket:`10556`
diff --git a/doc/build/changelog/unreleased_21/10556.rst b/doc/build/changelog/unreleased_21/10556.rst
new file mode 100644 (file)
index 0000000..153b9a9
--- /dev/null
@@ -0,0 +1,14 @@
+.. change::
+    :tags: feature, postgresql
+    :tickets: 10556
+
+    Adds a new ``str`` subclass :class:`_postgresql.BitString` representing
+    PostgreSQL bitstrings in python, that includes
+    functionality for converting to and from ``int`` and ``bytes``, in
+    addition to implementing utility methods and operators for dealing with bits.
+
+    This new class is returned automatically by the :class:`_postgresql.BIT` type.
+
+    .. seealso::
+
+        :ref:`change_10556`
index 009463e6ee860689d7807eae43cf9a76048fff02..de651a15b4c2b16c6efb90774da800024bfca7e3 100644 (file)
@@ -20,6 +20,23 @@ as well as array literals:
 * :class:`_postgresql.aggregate_order_by` - helper for PG's ORDER BY aggregate
   function syntax.
 
+BIT type
+--------
+
+PostgreSQL's BIT type is a so-called "bit string" that stores a string of
+ones and zeroes.   SQLAlchemy provides the :class:`_postgresql.BIT` type
+to represent columns and expressions of this type, as well as the
+:class:`_postgresql.BitString` value type which is a richly featured ``str``
+subclass that works with :class:`_postgresql.BIT`.
+
+* :class:`_postgresql.BIT` - the PostgreSQL BIT type
+
+* :class:`_postgresql.BitString` - Rich-featured ``str`` subclass returned
+  and accepted for columns and expressions that use :class:`_postgresql.BIT`.
+
+.. versionchanged:: 2.1  :class:`_postgresql.BIT` now works with the newly
+   added :class:`_postgresql.BitString` value type.
+
 .. _postgresql_json_types:
 
 JSON Types
@@ -455,6 +472,9 @@ construction arguments, are as follows:
 
 .. autoclass:: BIT
 
+.. autoclass:: BitString
+    :members:
+
 .. autoclass:: BYTEA
     :members: __init__
 
index e426df71be75c226212ae4506b30e1136d34c71e..677f3b7dd5ce65d63536baa02b075d454f952f26 100644 (file)
@@ -33,6 +33,7 @@ from .base import SMALLINT
 from .base import TEXT
 from .base import UUID
 from .base import VARCHAR
+from .bitstring import BitString
 from .dml import Insert
 from .dml import insert
 from .ext import aggregate_order_by
@@ -154,6 +155,7 @@ __all__ = (
     "JSONPATH",
     "Any",
     "All",
+    "BitString",
     "DropEnumType",
     "DropDomainType",
     "CreateDomainType",
index 3d6aae91764f6fe8b38ed45f8ba9d8714d025ee3..6b9bb0677daddcdba3c7cc3acaee93105e43d16e 100644 (file)
@@ -207,6 +207,7 @@ from .base import PGExecutionContext
 from .base import PGIdentifierPreparer
 from .base import REGCLASS
 from .base import REGCONFIG
+from .bitstring import BitString
 from .types import BIT
 from .types import BYTEA
 from .types import CITEXT
@@ -242,6 +243,25 @@ class AsyncpgTime(sqltypes.Time):
 class AsyncpgBit(BIT):
     render_bind_cast = True
 
+    def bind_processor(self, dialect):
+        asyncpg_BitString = dialect.dbapi.asyncpg.BitString
+
+        def to_bind(value):
+            if isinstance(value, str):
+                value = BitString(value)
+                value = asyncpg_BitString.from_int(int(value), len(value))
+            return value
+
+        return to_bind
+
+    def result_processor(self, dialect, coltype):
+        def to_result(value):
+            if value is not None:
+                value = BitString.from_int(value.to_int(), length=len(value))
+            return value
+
+        return to_result
+
 
 class AsyncpgByteA(BYTEA):
     render_bind_cast = True
diff --git a/lib/sqlalchemy/dialects/postgresql/bitstring.py b/lib/sqlalchemy/dialects/postgresql/bitstring.py
new file mode 100644 (file)
index 0000000..fb1dc52
--- /dev/null
@@ -0,0 +1,327 @@
+# dialects/postgresql/bitstring.py
+# Copyright (C) 2013-2025 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+from __future__ import annotations
+
+import math
+from typing import Any
+from typing import cast
+from typing import Literal
+from typing import SupportsIndex
+
+
+class BitString(str):
+    """Represent a PostgreSQL bit string in python.
+
+    This object is used by the :class:`_postgresql.BIT` type when returning
+    values.   :class:`_postgresql.BitString` values may also be constructed
+    directly and used with :class:`_postgresql.BIT` columns::
+
+        from sqlalchemy.dialects.postgresql import BitString
+
+        with engine.connect() as conn:
+            conn.execute(table.insert(), {"data": BitString("011001101")})
+
+    .. versionadded:: 2.1
+
+    """
+
+    _DIGITS = frozenset("01")
+
+    def __new__(cls, _value: str, _check: bool = True) -> BitString:
+        if isinstance(_value, BitString):
+            return _value
+        elif _check and cls._DIGITS.union(_value) > cls._DIGITS:
+            raise ValueError("BitString must only contain '0' and '1' chars")
+        else:
+            return super().__new__(cls, _value)
+
+    @classmethod
+    def from_int(cls, value: int, length: int) -> BitString:
+        """Returns a BitString consisting of the bits in the integer ``value``.
+
+        The bitstring is padded on the left with ``'0'`` bits to produce a
+        bitstring of exactly ``length`` bits.
+
+        A ``ValueError`` is raised if ``value`` or ``length`` is negative,
+        or if ``value`` cannot be represented in a bit string of at most
+        ``length`` bits.
+        """
+        if value < 0:
+            raise ValueError("value must be non-negative")
+        if length < 0:
+            raise ValueError("length must be non-negative")
+
+        # builds a format spec such as "{0:08b}"; a zero length produces
+        # the empty string regardless of ``value``
+        template_str = f"{{0:0{length}b}}" if length > 0 else ""
+        r = template_str.format(value)
+
+        # the format width is only a minimum, so an oversized value yields
+        # more than ``length`` digits and is rejected here
+        if (length == 0 and value > 0) or len(r) > length:
+            raise ValueError(
+                f"Cannot encode {value} as a BitString of length {length}"
+            )
+
+        return cls(r)
+
+    @classmethod
+    def from_bytes(cls, value: bytes, length: int = -1) -> BitString:
+        """Returns a ``BitString`` consisting of the bits in the given
+        ``value`` bytes.
+
+        If ``length`` is negative (the default), every byte of ``value``
+        contributes exactly eight bits and nothing is stripped or padded.
+
+        If a non-negative ``length`` is provided, then the length of the
+        returned bitstring will be exactly ``length``, with ``'0'`` bits
+        inserted at the left of the string in order to produce a value of
+        the required length.
+        If the bits obtained by omitting the leading ``'0'`` bits of ``value``
+        cannot be represented in a string of this length a ``ValueError``
+        will be raised.
+        """
+        # each byte is rendered as eight binary digits, first byte first
+        str_v: str = "".join(f"{int(c):08b}" for c in value)
+        if length >= 0:
+            # drop insignificant leading zeros before checking the fit
+            str_v = str_v.lstrip("0")
+
+            if len(str_v) > length:
+                raise ValueError(
+                    f"Cannot encode {value!r} as a BitString of "
+                    f"length {length}"
+                )
+            str_v = str_v.zfill(length)
+
+        return cls(str_v)
+
+    def get_bit(self, index: int) -> Literal["0", "1"]:
+        """Returns the value of the bit at the given
+        index, as the character ``"0"`` or ``"1"``::
+
+            BitString("0101").get_bit(3) == "1"
+        """
+        # plain ``str`` indexing, so the result is a ``str`` character
+        # rather than a one-element BitString
+        return cast(Literal["0", "1"], super().__getitem__(index))
+
+    @property
+    def bit_length(self) -> int:
+        return len(self)
+
+    @property
+    def octet_length(self) -> int:
+        return math.ceil(len(self) / 8)
+
+    def has_bit(self, index: int) -> bool:
+        return self.get_bit(index) == "1"
+
+    def set_bit(
+        self, index: int, value: bool | int | Literal["0", "1"]
+    ) -> BitString:
+        """Set the bit at index to the given value.
+
+        If value is an int, then it is considered to be '1' iff nonzero.
+        """
+        if index < 0 or index >= len(self):
+            raise IndexError("BitString index out of range")
+
+        if isinstance(value, (bool, int)):
+            value = "1" if value else "0"
+
+        if self.get_bit(index) == value:
+            return self
+
+        return BitString(
+            "".join([self[:index], value, self[index + 1 :]]), False
+        )
+
+    def lstrip(self, char: str | None = None) -> BitString:
+        """Returns a copy of the BitString with leading characters removed.
+
+        If omitted or None, ``char`` defaults to ``"0"``::
+
+            BitString("00010101000").lstrip() == BitString("10101000")
+            BitString("11110101111").lstrip("1") == BitString("0101111")
+        """
+        if char is None:
+            char = "0"
+        # validation is skipped (second argument False) when wrapping
+        # the stripped string
+        return BitString(super().lstrip(char), False)
+
+    def rstrip(self, char: str | None = "0") -> BitString:
+        """Returns a copy of the BitString with trailing characters removed.
+
+        If omitted or None, ``char`` defaults to ``"0"``::
+
+            BitString("00010101000").rstrip() == BitString("00010101")
+            BitString("11110101111").rstrip("1") == BitString("1111010")
+        """
+        if char is None:
+            char = "0"
+        # validation is skipped (second argument False) when wrapping
+        # the stripped string
+        return BitString(super().rstrip(char), False)
+
+    def strip(self, char: str | None = "0") -> BitString:
+        """Returns a copy of the BitString with both leading and trailing
+        characters removed.
+
+        If omitted or None, ``char`` defaults to ``"0"``::
+
+            BitString("00010101000").strip() == BitString("10101")
+            BitString("11110101111").strip("1") == BitString("010")
+        """
+        if char is None:
+            char = "0"
+        # stripping can only remove characters, so the result needs no
+        # re-validation; this matches lstrip() and rstrip()
+        return BitString(super().strip(char), False)
+
+    def removeprefix(self, prefix: str, /) -> BitString:
+        return BitString(super().removeprefix(prefix), False)
+
+    def removesuffix(self, suffix: str, /) -> BitString:
+        return BitString(super().removesuffix(suffix), False)
+
+    def replace(
+        self,
+        old: str,
+        new: str,
+        count: SupportsIndex = -1,
+    ) -> BitString:
+        new = BitString(new)
+        return BitString(super().replace(old, new, count), False)
+
+    def split(
+        self,
+        sep: str | None = None,
+        maxsplit: SupportsIndex = -1,
+    ) -> list[str]:
+        return [BitString(word) for word in super().split(sep, maxsplit)]
+
+    def zfill(self, width: SupportsIndex) -> BitString:
+        return BitString(super().zfill(width), False)
+
+    def __repr__(self) -> str:
+        return f'BitString("{self.__str__()}")'
+
+    def __int__(self) -> int:
+        return int(self, 2) if self else 0
+
+    def to_bytes(self, length: int = -1) -> bytes:
+        """Returns the integer value of this BitString as big-endian bytes.
+
+        ``length`` is the number of bytes to produce; when negative (the
+        default), ``octet_length`` is used.
+        """
+        return int(self).to_bytes(
+            length if length >= 0 else self.octet_length, byteorder="big"
+        )
+
+    def __bytes__(self) -> bytes:
+        return self.to_bytes()
+
+    def __getitem__(
+        self, key: SupportsIndex | slice[Any, Any, Any]
+    ) -> BitString:
+        return BitString(super().__getitem__(key), False)
+
+    def __add__(self, o: str) -> BitString:
+        """Return ``self + o`` as a new BitString.
+
+        ``o`` must be a ``str`` made up of ``'0'`` and ``'1'`` characters.
+        A ``TypeError`` is raised for non-string operands and a
+        ``ValueError`` if the concatenated value contains other characters.
+        """
+        if not isinstance(o, str):
+            # name the type of the offending operand, not of ``self``
+            raise TypeError(
+                f"Can only concatenate str (not '{type(o).__name__}') "
+                "to BitString"
+            )
+        return BitString("".join([self, o]))
+
+    def __radd__(self, o: str) -> BitString:
+        """Return ``o + self`` as a new BitString.
+
+        ``o`` must be a ``str`` made up of ``'0'`` and ``'1'`` characters.
+        A ``TypeError`` is raised for non-string operands and a
+        ``ValueError`` if the concatenated value contains other characters.
+        """
+        if not isinstance(o, str):
+            # name the type of the offending operand, not of ``self``
+            raise TypeError(
+                f"Can only concatenate str (not '{type(o).__name__}') "
+                "to BitString"
+            )
+        return BitString("".join([o, self]))
+
+    def __lshift__(self, amount: int) -> BitString:
+        """Shifts each bit in the bitstring to the left by the given amount.
+        String length is preserved::
+
+            BitString("000101") << 1 == BitString("001010")
+
+        An ``amount`` of zero or less leaves the bits unchanged.
+        """
+        # "0" * amount is empty for amount <= 0
+        shifted = "".join([self, "0" * amount])
+        # keep the trailing len(self) characters; an explicit start index
+        # is used because a ``[-0:]`` slice on an empty bitstring would
+        # return the whole padded string instead of an empty one
+        return BitString(shifted[len(shifted) - len(self) :], False)
+
+    def __rshift__(self, amount: int) -> BitString:
+        """Shifts each bit in the bitstring to the right by the given amount.
+        String length is preserved::
+
+            BitString("101") >> 1 == BitString("010")
+
+        An ``amount`` of zero or less leaves the bits unchanged.
+        """
+        # a ``[:-amount]`` slice would be ``[:0]`` for amount == 0 and
+        # wipe every bit, so non-positive shifts are handled up front
+        if amount <= 0:
+            return self
+
+        width = len(self)
+        # drop the ``amount`` rightmost bits (all of them when the shift
+        # is at least as wide as the string), then left-pad with zeros
+        kept = super().__getitem__(slice(0, max(width - amount, 0)))
+        return BitString(kept.zfill(width), False)
+
+    def __invert__(self) -> BitString:
+        """Inverts (~) each bit in the
+        bitstring::
+
+            ~BitString("01010") == BitString("10101")
+        """
+        return BitString("".join("1" if x == "0" else "0" for x in self))
+
+    def __and__(self, o: str) -> BitString:
+        """Performs a bitwise and (``&``) with the given operand.
+        A ``ValueError`` is raised if the operand is not the same length.
+
+        e.g.::
+
+            BitString("011") & BitString("010") == BitString("010")
+        """
+
+        if not isinstance(o, str):
+            return NotImplemented
+        # plain strings are validated here; a ValueError is raised if
+        # they contain characters other than '0' and '1'
+        o = BitString(o)
+        if len(self) != len(o):
+            raise ValueError("Operands must be the same length")
+
+        return BitString(
+            "".join(
+                "1" if (x == "1" and y == "1") else "0"
+                for x, y in zip(self, o)
+            ),
+            False,
+        )
+
+    def __or__(self, o: str) -> BitString:
+        """Performs a bitwise or (``|``) with the given operand.
+        A ``ValueError`` is raised if the operand is not the same length.
+
+        e.g.::
+
+            BitString("011") | BitString("010") == BitString("011")
+        """
+        if not isinstance(o, str):
+            return NotImplemented
+
+        if len(self) != len(o):
+            raise ValueError("Operands must be the same length")
+
+        o = BitString(o)
+        return BitString(
+            "".join(
+                "1" if (x == "1" or y == "1") else "0"
+                for (x, y) in zip(self, o)
+            ),
+            False,
+        )
+
+    def __xor__(self, o: str) -> BitString:
+        """Performs a bitwise xor (``^``) with the given operand.
+
+        The operand may be a ``BitString`` or a plain ``str`` made up of
+        ``'0'`` and ``'1'`` characters, as is the case for ``&`` and ``|``.
+        A ``ValueError`` is raised if the operand is not the same length
+        or contains other characters.
+
+        e.g.::
+
+            BitString("011") ^ BitString("010") == BitString("001")
+        """
+
+        if not isinstance(o, str):
+            return NotImplemented
+
+        # plain strings are validated here; BitString objects pass through
+        o = BitString(o)
+        if len(self) != len(o):
+            raise ValueError("Operands must be the same length")
+
+        return BitString(
+            "".join(
+                (
+                    "1"
+                    if ((x == "1" and y == "0") or (x == "0" and y == "1"))
+                    else "0"
+                )
+                for (x, y) in zip(self, o)
+            ),
+            False,
+        )
+
+    __rand__ = __and__
+    __ror__ = __or__
+    __rxor__ = __xor__
index ff5e967ef6fb3d27106af8371a0b2e93afed3359..96e5644572c562c318ac9ca874645c6f6a051bb8 100644 (file)
@@ -8,21 +8,25 @@ from __future__ import annotations
 
 import datetime as dt
 from typing import Any
+from typing import Literal
 from typing import Optional
 from typing import overload
 from typing import Type
 from typing import TYPE_CHECKING
 from uuid import UUID as _python_UUID
 
+from .bitstring import BitString
 from ...sql import sqltypes
 from ...sql import type_api
-from ...util.typing import Literal
+from ...sql.type_api import TypeEngine
 
 if TYPE_CHECKING:
     from ...engine.interfaces import Dialect
+    from ...sql.operators import ColumnOperators
     from ...sql.operators import OperatorType
+    from ...sql.type_api import _BindProcessorType
     from ...sql.type_api import _LiteralProcessorType
-    from ...sql.type_api import TypeEngine
+    from ...sql.type_api import _ResultProcessorType
 
 _DECIMAL_TYPES = (1231, 1700)
 _FLOAT_TYPES = (700, 701, 1021, 1022)
@@ -256,7 +260,18 @@ class INTERVAL(type_api.NativeForEmulated, sqltypes._AbstractInterval):
 PGInterval = INTERVAL
 
 
-class BIT(sqltypes.TypeEngine[int]):
+class BIT(sqltypes.TypeEngine[BitString]):
+    """Represent the PostgreSQL BIT type.
+
+    The :class:`_postgresql.BIT` type yields values in the form of the
+    :class:`_postgresql.BitString` Python value type.
+
+    .. versionchanged:: 2.1  The :class:`_postgresql.BIT` type now works
+       with :class:`_postgresql.BitString` values rather than plain strings.
+
+    """
+
+    render_bind_cast = True
     __visit_name__ = "BIT"
 
     def __init__(
@@ -270,6 +285,58 @@ class BIT(sqltypes.TypeEngine[int]):
             self.length = length or 1
         self.varying = varying
 
+    def bind_processor(
+        self, dialect: Dialect
+    ) -> _BindProcessorType[BitString]:
+        def bound_value(value: Any) -> Any:
+            if isinstance(value, BitString):
+                return str(value)
+            return value
+
+        return bound_value
+
+    def result_processor(
+        self, dialect: Dialect, coltype: object
+    ) -> _ResultProcessorType[BitString]:
+        def from_result_value(value: Any) -> Any:
+            if value is not None:
+                value = BitString(value)
+            return value
+
+        return from_result_value
+
+    def coerce_compared_value(
+        self, op: OperatorType | None, value: Any
+    ) -> TypeEngine[Any]:
+        if isinstance(value, str):
+            return self
+        return super().coerce_compared_value(op, value)
+
+    @property
+    def python_type(self) -> type[Any]:
+        return BitString
+
+    class comparator_factory(TypeEngine.Comparator[BitString]):
+        def __lshift__(self, other: Any) -> ColumnOperators:
+            return self.bitwise_lshift(other)
+
+        def __rshift__(self, other: Any) -> ColumnOperators:
+            return self.bitwise_rshift(other)
+
+        def __and__(self, other: Any) -> ColumnOperators:
+            return self.bitwise_and(other)
+
+        def __or__(self, other: Any) -> ColumnOperators:
+            return self.bitwise_or(other)
+
+        # NOTE: __xor__ is not defined on sql.operators.ColumnOperators.
+        # Use `bitwise_xor` directly instead.
+        # def __xor__(self, other: Any) -> ColumnOperators:
+        #     return self.bitwise_xor(other)
+
+        def __invert__(self) -> ColumnOperators:
+            return self.bitwise_not()
+
 
 PGBit = BIT
 
diff --git a/test/dialect/postgresql/test_bitstring.py b/test/dialect/postgresql/test_bitstring.py
new file mode 100644 (file)
index 0000000..2ca8c72
--- /dev/null
@@ -0,0 +1,165 @@
+from sqlalchemy import testing
+from sqlalchemy.dialects.postgresql import BitString
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.assertions import assert_raises
+from sqlalchemy.testing.assertions import eq_
+from sqlalchemy.testing.assertions import is_
+from sqlalchemy.testing.assertions import is_false
+from sqlalchemy.testing.assertions import is_not
+from sqlalchemy.testing.assertions import is_true
+
+
+class BitStringTests(fixtures.TestBase):
+
+    @testing.combinations(
+        lambda: BitString("111") == BitString("111"),
+        lambda: BitString("111") == "111",
+        lambda: BitString("111") != BitString("110"),
+        lambda: BitString("111") != "110",
+        lambda: hash(BitString("011")) == hash(BitString("011")),
+        lambda: hash(BitString("011")) == hash("011"),
+        lambda: BitString("011")[1] == BitString("1"),
+        lambda: BitString("010") > BitString("001"),
+        lambda: "010" > BitString("001"),
+        lambda: "011" <= BitString("011"),
+        lambda: "011" <= BitString("101"),
+    )
+    def test_comparisons(self, case):
+        is_true(case())
+
+    def test_sorting(self):
+        eq_(
+            sorted([BitString("110"), BitString("010"), "111", "101"]),
+            [BitString("010"), "101", BitString("110"), "111"],
+        )
+
+    def test_str_conversion(self):
+        x = BitString("1110111")
+        eq_(str(x), "1110111")
+
+        assert_raises(ValueError, lambda: BitString("1246"))
+
+    def test_same_instance_returned(self):
+        x = BitString("1110111")
+        y = BitString("1110111")
+        z = BitString(x)
+
+        eq_(x, y)
+        eq_(x, z)
+
+        is_not(x, y)
+        is_(x, z)
+
+    @testing.combinations(
+        (0, 0, BitString("")),
+        (0, 1, BitString("0")),
+        (1, 1, BitString("1")),
+        (1, 0, ValueError),
+        (1, -1, ValueError),
+        (2, 1, ValueError),
+        (-1, 4, ValueError),
+        (1, 4, BitString("0001")),
+        (1, 10, BitString("0000000001")),
+        (127, 8, BitString("01111111")),
+        (127, 10, BitString("0001111111")),
+        (1404, 8, ValueError),
+        (1404, 12, BitString("010101111100")),
+        argnames="source, bitlen, result_or_error",
+    )
+    def test_int_conversion(self, source, bitlen, result_or_error):
+        if isinstance(result_or_error, type):
+            assert_raises(
+                result_or_error, lambda: BitString.from_int(source, bitlen)
+            )
+            return
+
+        result = result_or_error
+
+        bits = BitString.from_int(source, bitlen)
+        eq_(bits, result)
+        eq_(int(bits), source)
+
+    @testing.combinations(
+        (b"", -1, BitString("")),
+        (b"", 4, BitString("0000")),
+        (b"\x00", 1, BitString("0")),
+        (b"\x01", 1, BitString("1")),
+        (b"\x01", 4, BitString("0001")),
+        (b"\x01", 10, BitString("0000000001")),
+        (b"\x01", -1, BitString("00000001")),
+        (b"\xff", 10, BitString("0011111111")),
+        (b"\xaf\x04", 8, ValueError),
+        (b"\xaf\x04", 16, BitString("1010111100000100")),
+        (b"\xaf\x04", 20, BitString("00001010111100000100")),
+        argnames="source, bitlen, result_or_error",
+    )
+    def test_bytes_conversion(self, source, bitlen, result_or_error):
+        if isinstance(result_or_error, type):
+            assert_raises(
+                result_or_error,
+                lambda: BitString.from_bytes(source, length=bitlen),
+            )
+            return
+        result = result_or_error
+
+        bits = BitString.from_bytes(source, bitlen)
+        eq_(bits, result)
+
+        # Expecting a roundtrip conversion in this case is nonsensical
+        if source == b"" and bitlen > 0:
+            return
+        eq_(bits.to_bytes(len(source)), source)
+
+    def test_get_set_bit(self):
+        eq_(BitString("1010").get_bit(2), "1")
+        eq_(BitString("0101").get_bit(2), "0")
+        assert_raises(IndexError, lambda: BitString("0").get_bit(1))
+
+        eq_(BitString("0101").set_bit(3, "0"), BitString("0100"))
+        eq_(BitString("0101").set_bit(3, "1"), BitString("0101"))
+        assert_raises(IndexError, lambda: BitString("1111").set_bit(5, "1"))
+
+    def test_string_methods(self):
+
+        eq_(BitString("01100").lstrip(), BitString("1100"))
+        eq_(BitString("01100").rstrip(), BitString("011"))
+        eq_(BitString("01100").strip(), BitString("11"))
+
+        eq_(BitString("11100").removeprefix("111"), BitString("00"))
+        eq_(BitString("11100").removeprefix("0"), BitString("11100"))
+
+        eq_(BitString("11100").removesuffix("10"), BitString("11100"))
+        eq_(BitString("11100").removesuffix("00"), BitString("111"))
+
+        eq_(
+            BitString("010101011").replace("0101", "11", 1),
+            BitString("1101011"),
+        )
+        eq_(
+            BitString("01101101").split("1", 2),
+            [BitString("0"), BitString(""), BitString("01101")],
+        )
+
+        eq_(BitString("0110").split("11"), [BitString("0"), BitString("0")])
+        eq_(BitString("111").zfill(8), BitString("00000111"))
+
+    def test_string_operators(self):
+        is_true("1" in BitString("001"))
+        is_true("0" in BitString("110"))
+        is_false("1" in BitString("000"))
+
+        is_true("001" in BitString("01001"))
+        is_true(BitString("001") in BitString("01001"))
+        is_false(BitString("000") in BitString("01001"))
+
+        eq_(BitString("010") + "001", BitString("010001"))
+        eq_("001" + BitString("010"), BitString("001010"))
+
+    def test_bitwise_operators(self):
+        eq_(~BitString("0101"), BitString("1010"))
+        eq_(BitString("010") & BitString("011"), BitString("010"))
+        eq_(BitString("010") | BitString("011"), BitString("011"))
+        eq_(BitString("010") ^ BitString("011"), BitString("001"))
+
+        eq_(BitString("001100") << 2, BitString("110000"))
+        eq_(BitString("001100") >> 2, BitString("000011"))
index 6151ed2dcc00e8efd7b1595b2f25686f279784c3..2e0d1ea5f6ba30231b1fd98d808be27caefef51b 100644 (file)
@@ -47,6 +47,7 @@ from sqlalchemy.dialects.postgresql import array_agg
 from sqlalchemy.dialects.postgresql import asyncpg
 from sqlalchemy.dialects.postgresql import base
 from sqlalchemy.dialects.postgresql import BIT
+from sqlalchemy.dialects.postgresql import BitString
 from sqlalchemy.dialects.postgresql import BYTEA
 from sqlalchemy.dialects.postgresql import CITEXT
 from sqlalchemy.dialects.postgresql import DATEMULTIRANGE
@@ -3523,7 +3524,9 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables):
             metadata,
             Column("id", postgresql.UUID, primary_key=True),
             Column("flag", postgresql.BIT),
-            Column("bitstring", postgresql.BIT(4)),
+            Column("bitstring_varying", postgresql.BIT(varying=True)),
+            Column("bitstring_varying_6", postgresql.BIT(6, varying=True)),
+            Column("bitstring_4", postgresql.BIT(4)),
             Column("addr", postgresql.INET),
             Column("addr2", postgresql.MACADDR),
             Column("addr4", postgresql.MACADDR8),
@@ -3551,7 +3554,18 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables):
         self.assert_tables_equal(special_types_table, t, strict_types=True)
         assert t.c.plain_interval.type.precision is None
         assert t.c.precision_interval.type.precision == 3
-        assert t.c.bitstring.type.length == 4
+
+        assert t.c.flag.type.varying is False
+        assert t.c.flag.type.length == 1
+
+        assert t.c.bitstring_varying.type.varying is True
+        assert t.c.bitstring_varying.type.length is None
+
+        assert t.c.bitstring_varying_6.type.varying is True
+        assert t.c.bitstring_varying_6.type.length == 6
+
+        assert t.c.bitstring_4.type.varying is False
+        assert t.c.bitstring_4.type.length == 4
 
     @testing.combinations(
         (postgresql.INET, "127.0.0.1"),
@@ -3581,6 +3595,41 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables):
             "test",
         )
 
+    @testing.combinations(
+        (postgresql.BIT(varying=True), BitString("")),
+        (postgresql.BIT(varying=True), BitString("1101010101")),
+        (postgresql.BIT(6, varying=True), BitString("")),
+        (postgresql.BIT(6, varying=True), BitString("010101")),
+        (postgresql.BIT(1), BitString("0")),
+        (postgresql.BIT(4), BitString("0010")),
+        (postgresql.BIT(4), "0010"),
+        argnames="column_type, value",
+    )
+    def test_bitstring_round_trip(
+        self, connection, metadata, column_type, value
+    ):
+        t = Table(
+            "bits",
+            metadata,
+            Column("name", String),
+            Column("value", column_type),
+        )
+        t.create(connection)
+
+        connection.execute(t.insert(), {"name": "test", "value": value})
+        eq_(
+            connection.scalar(select(t.c.name).where(t.c.value == value)),
+            "test",
+        )
+
+        result_value = connection.scalar(
+            select(t.c.value).where(t.c.name == "test")
+        )
+        assert isinstance(result_value, BitString)
+        eq_(result_value, value)
+        eq_(result_value, str(value))
+        eq_(str(result_value), str(value))
+
     def test_tsvector_round_trip(self, connection, metadata):
         t = Table("t1", metadata, Column("data", postgresql.TSVECTOR))
         t.create(connection)
@@ -4182,6 +4231,103 @@ class HStoreRoundTripTest(fixtures.TablesTest):
             eq_(s.query(Data.data, Data).all(), [(d.data, d)])
 
 
+class BitTests(fixtures.TestBase):
+    __backend__ = True
+    __only_on__ = "postgresql"
+
+    def test_concatenation(self, connection):
+        coltype = BIT(varying=True)
+
+        q = select(
+            literal(BitString("1111"), coltype).concat(BitString("0000"))
+        )
+        r = connection.execute(q).first()
+        eq_(r[0], BitString("11110000"))
+
+    def test_invert_operator(self, connection):
+        coltype = BIT(4)
+
+        q = select(literal(BitString("0010"), coltype).bitwise_not())
+        r = connection.execute(q).first()
+
+        eq_(r[0], BitString("1101"))
+
+    def test_and_operator(self, connection):
+        coltype = BIT(6)
+
+        q1 = select(
+            literal(BitString("001010"), coltype)
+            & literal(BitString("010111"), coltype)
+        )
+        r1 = connection.execute(q1).first()
+
+        eq_(r1[0], BitString("000010"))
+
+        q2 = select(
+            literal(BitString("010101"), coltype) & BitString("001011")
+        )
+        r2 = connection.execute(q2).first()
+        eq_(r2[0], BitString("000001"))
+
+    def test_or_operator(self, connection):
+        coltype = BIT(6)
+
+        q1 = select(
+            literal(BitString("001010"), coltype)
+            | literal(BitString("010111"), coltype)
+        )
+        r1 = connection.execute(q1).first()
+
+        eq_(r1[0], BitString("011111"))
+
+        q2 = select(
+            literal(BitString("010101"), coltype) | BitString("001011")
+        )
+        r2 = connection.execute(q2).first()
+        eq_(r2[0], BitString("011111"))
+
+    def test_xor_operator(self, connection):
+        coltype = BIT(6)
+
+        q1 = select(
+            literal(BitString("001010"), coltype).bitwise_xor(
+                literal(BitString("010111"), coltype)
+            )
+        )
+        r1 = connection.execute(q1).first()
+        eq_(r1[0], BitString("011101"))
+
+        q2 = select(
+            literal(BitString("010101"), coltype).bitwise_xor(
+                BitString("001011")
+            )
+        )
+        r2 = connection.execute(q2).first()
+        eq_(r2[0], BitString("011110"))
+
+    def test_lshift_operator(self, connection):
+        coltype = BIT(6)
+
+        q = select(
+            literal(BitString("001010"), coltype),
+            literal(BitString("001010"), coltype) << 1,
+        )
+
+        r = connection.execute(q).first()
+        eq_(tuple(r), (BitString("001010"), BitString("010100")))
+
+    def test_rshift_operator(self, connection):
+        coltype = BIT(6)
+
+        q = select(
+            literal(BitString("001010"), coltype),
+            literal(BitString("001010"), coltype) >> 1,
+        )
+
+        r = connection.execute(q).first()
+        eq_(tuple(r), (BitString("001010"), BitString("000101")))
+
+
 class RangeMiscTests(fixtures.TestBase):
     @testing.combinations(
         (Range(2, 7), INT4RANGE),
@@ -6639,7 +6785,7 @@ class PGInsertManyValuesTest(fixtures.TestBase):
 
     @testing.combinations(
         ("BYTEA", BYTEA(), b"7\xe7\x9f"),
-        ("BIT", BIT(3), "011"),
+        ("BIT", BIT(3), BitString("011")),
         argnames="type_,value",
         id_="iaa",
     )
@@ -6672,11 +6818,6 @@ class PGInsertManyValuesTest(fixtures.TestBase):
 
         t.create(connection)
 
-        if type_._type_affinity is BIT and testing.against("+asyncpg"):
-            import asyncpg
-
-            value = asyncpg.BitString(value)
-
         result = connection.execute(
             t.insert().returning(
                 t.c.id,