From: Federico Caselli Date: Fri, 20 Jun 2025 20:28:45 +0000 (+0200) Subject: Add new str` subclass for postgresql bitstring X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=baec9a3ac8fd675307026dca81b96a8f2102efa2;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add new str` subclass for postgresql bitstring Adds a new ``str`` subclass :class:`dialects.postgresql.BitString` representing PostgreSQL bitstrings in python, that includes functionality for converting to and from ``int`` and ``bytes``, in addition to implementing utility methods and operators for dealing with bits. This new class is returned automatically by the :class:`postgresql.BIT` type. Fixes: #10556 Closes: #12594 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12594 Pull-request-sha: da47f40739ff9fc6e75da44bd1663aadf80e93ca Change-Id: I64685660527c23666f7351b2c393fa86dfb643ea --- diff --git a/doc/build/changelog/migration_21.rst b/doc/build/changelog/migration_21.rst index dd3d81a140..5634cdda64 100644 --- a/doc/build/changelog/migration_21.rst +++ b/doc/build/changelog/migration_21.rst @@ -417,3 +417,27 @@ would appear in a valid ODBC connection string (i.e., the same as would be required if using the connection string directly with ``pyodbc.connect()``). :ticket:`11250` + +.. _change_10556: + +Addition of ``BitString`` subclass for handling postgresql ``BIT`` columns +-------------------------------------------------------------------------- + +Values of :class:`_postgresql.BIT` columns in the PostgreSQL dialect are +returned as instances of a new ``str`` subclass, +:class:`_postgresql.BitString`. Previously, the value of :class:`_postgresql.BIT` +columns was driver dependent, with most drivers returning ``str`` instances +except ``asyncpg``, which used ``asyncpg.BitString``. 
+ +With this change, for the ``psycopg``, ``psycopg2``, and ``pg8000`` drivers, +the new :class:`_postgresql.BitString` type is mostly compatible with ``str``, but +adds methods for bit manipulation and supports bitwise operators. + +As :class:`_postgresql.BitString` is a string subclass, hashability as well +as equality tests continue to work against plain strings. This also leaves +ordering operators intact. + +For implementations using the ``asyncpg`` driver, the new type is incompatible with +the existing ``asyncpg.BitString`` type. + +:ticket:`10556` diff --git a/doc/build/changelog/unreleased_21/10556.rst b/doc/build/changelog/unreleased_21/10556.rst new file mode 100644 index 0000000000..153b9a95e5 --- /dev/null +++ b/doc/build/changelog/unreleased_21/10556.rst @@ -0,0 +1,14 @@ +.. change:: + :tags: feature, postgresql + :tickets: 10556 + + Adds a new ``str`` subclass :class:`_postgresql.BitString` representing + PostgreSQL bitstrings in python, that includes + functionality for converting to and from ``int`` and ``bytes``, in + addition to implementing utility methods and operators for dealing with bits. + + This new class is returned automatically by the :class:`postgresql.BIT` type. + + .. seealso:: + + :ref:`change_10556` diff --git a/doc/build/dialects/postgresql.rst b/doc/build/dialects/postgresql.rst index 009463e6ee..de651a15b4 100644 --- a/doc/build/dialects/postgresql.rst +++ b/doc/build/dialects/postgresql.rst @@ -20,6 +20,23 @@ as well as array literals: * :class:`_postgresql.aggregate_order_by` - helper for PG's ORDER BY aggregate function syntax. +BIT type +-------- + +PostgreSQL's BIT type is a so-called "bit string" that stores a string of +ones and zeroes. SQLAlchemy provides the :class:`_postgresql.BIT` type +to represent columns and expressions of this type, as well as the +:class:`_postgresql.BitString` value type which is a richly featured ``str`` +subclass that works with :class:`_postgresql.BIT`. 
+ +* :class:`_postgresql.BIT` - the PostgreSQL BIT type + +* :class:`_postgresql.BitString` - Rich-featured ``str`` subclass returned + and accepted for columns and expressions that use :class:`_postgresql.BIT`. + +.. versionchanged:: 2.1 :class:`_postgresql.BIT` now works with the newly + added :class:`_postgresql.BitString` value type. + .. _postgresql_json_types: JSON Types @@ -455,6 +472,9 @@ construction arguments, are as follows: .. autoclass:: BIT +.. autoclass:: BitString + :members: + .. autoclass:: BYTEA :members: __init__ diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index e426df71be..677f3b7dd5 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -33,6 +33,7 @@ from .base import SMALLINT from .base import TEXT from .base import UUID from .base import VARCHAR +from .bitstring import BitString from .dml import Insert from .dml import insert from .ext import aggregate_order_by @@ -154,6 +155,7 @@ __all__ = ( "JSONPATH", "Any", "All", + "BitString", "DropEnumType", "DropDomainType", "CreateDomainType", diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 3d6aae9176..6b9bb0677d 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -207,6 +207,7 @@ from .base import PGExecutionContext from .base import PGIdentifierPreparer from .base import REGCLASS from .base import REGCONFIG +from .bitstring import BitString from .types import BIT from .types import BYTEA from .types import CITEXT @@ -242,6 +243,25 @@ class AsyncpgTime(sqltypes.Time): class AsyncpgBit(BIT): render_bind_cast = True + def bind_processor(self, dialect): + asyncpg_BitString = dialect.dbapi.asyncpg.BitString + + def to_bind(value): + if isinstance(value, str): + value = BitString(value) + value = asyncpg_BitString.from_int(int(value), len(value)) + 
return value + + return to_bind + + def result_processor(self, dialect, coltype): + def to_result(value): + if value is not None: + value = BitString.from_int(value.to_int(), length=len(value)) + return value + + return to_result + class AsyncpgByteA(BYTEA): render_bind_cast = True diff --git a/lib/sqlalchemy/dialects/postgresql/bitstring.py b/lib/sqlalchemy/dialects/postgresql/bitstring.py new file mode 100644 index 0000000000..fb1dc528c7 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/bitstring.py @@ -0,0 +1,327 @@ +# dialects/postgresql/bitstring.py +# Copyright (C) 2013-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + +import math +from typing import Any +from typing import cast +from typing import Literal +from typing import SupportsIndex + + +class BitString(str): + """Represent a PostgreSQL bit string in python. + + This object is used by the :class:`_postgresql.BIT` type when returning + values. :class:`_postgresql.BitString` values may also be constructed + directly and used with :class:`_postgresql.BIT` columns:: + + from sqlalchemy.dialects.postgresql import BitString + + with engine.connect() as conn: + conn.execute(table.insert(), {"data": BitString("011001101")}) + + .. versionadded:: 2.1 + + """ + + _DIGITS = frozenset("01") + + def __new__(cls, _value: str, _check: bool = True) -> BitString: + if isinstance(_value, BitString): + return _value + elif _check and cls._DIGITS.union(_value) > cls._DIGITS: + raise ValueError("BitString must only contain '0' and '1' chars") + else: + return super().__new__(cls, _value) + + @classmethod + def from_int(cls, value: int, length: int) -> BitString: + """Returns a BitString consisting of the bits in the integer ``value``. + A ``ValueError`` is raised if ``value`` is not a non-negative integer. 
+ + If the provided ``value`` can not be represented in a bit string + of at most ``length``, a ``ValueError`` will be raised. The bitstring + will be padded on the left by ``'0'`` bits to produce a + bitstring of the desired length. + """ + if value < 0: + raise ValueError("value must be non-negative") + if length < 0: + raise ValueError("length must be non-negative") + + template_str = f"{{0:0{length}b}}" if length > 0 else "" + r = template_str.format(value) + + if (length == 0 and value > 0) or len(r) > length: + raise ValueError( + f"Cannot encode {value} as a BitString of length {length}" + ) + + return cls(r) + + @classmethod + def from_bytes(cls, value: bytes, length: int = -1) -> BitString: + """Returns a ``BitString`` consisting of the bits in the given + ``value`` bytes. + + If ``length`` is provided, then the length of the returned string + will be exactly ``length``, with ``'0'`` bits inserted at the left of + the string in order to produce a value of the required length. + If the bits obtained by omitting the leading ``'0'`` bits of ``value`` + cannot be represented in a string of this length a ``ValueError`` + will be raised. 
+ """ + str_v: str = "".join(f"{int(c):08b}" for c in value) + if length >= 0: + str_v = str_v.lstrip("0") + + if len(str_v) > length: + raise ValueError( + f"Cannot encode {value!r} as a BitString of " + f"length {length}" + ) + str_v = str_v.zfill(length) + + return cls(str_v) + + def get_bit(self, index: int) -> Literal["0", "1"]: + """Returns the value of the bit at the given + index:: + + BitString("0101").get_bit(3) == "1" + """ + return cast(Literal["0", "1"], super().__getitem__(index)) + + @property + def bit_length(self) -> int: + return len(self) + + @property + def octet_length(self) -> int: + return math.ceil(len(self) / 8) + + def has_bit(self, index: int) -> bool: + return self.get_bit(index) == "1" + + def set_bit( + self, index: int, value: bool | int | Literal["0", "1"] + ) -> BitString: + """Set the bit at index to the given value. + + If value is an int, then it is considered to be '1' iff nonzero. + """ + if index < 0 or index >= len(self): + raise IndexError("BitString index out of range") + + if isinstance(value, (bool, int)): + value = "1" if value else "0" + + if self.get_bit(index) == value: + return self + + return BitString( + "".join([self[:index], value, self[index + 1 :]]), False + ) + + def lstrip(self, char: str | None = None) -> BitString: + """Returns a copy of the BitString with leading characters removed. + + If omitted or None, ``char`` defaults to ``"0"``:: + + BitString("00010101000").lstrip() == BitString("10101000") + BitString("11110101111").lstrip("1") == BitString("0101111") + """ + if char is None: + char = "0" + return BitString(super().lstrip(char), False) + + def rstrip(self, char: str | None = "0") -> BitString: + """Returns a copy of the BitString with trailing characters removed. 
+ + If omitted or None, ``'char'`` defaults to "0":: + + BitString("00010101000").rstrip() == BitString("00010101") + BitString("11110101111").rstrip("1") == BitString("1111010") + """ + if char is None: + char = "0" + return BitString(super().rstrip(char), False) + + def strip(self, char: str | None = "0") -> BitString: + """Returns a copy of the BitString with both leading and trailing + characters removed. + If omitted or None, ``'char'`` defaults to ``"0"``:: + + BitString("00010101000").strip() == BitString("10101") + BitString("11110101111").strip("1") == BitString("010") + """ + if char is None: + char = "0" + return BitString(super().strip(char)) + + def removeprefix(self, prefix: str, /) -> BitString: + return BitString(super().removeprefix(prefix), False) + + def removesuffix(self, suffix: str, /) -> BitString: + return BitString(super().removesuffix(suffix), False) + + def replace( + self, + old: str, + new: str, + count: SupportsIndex = -1, + ) -> BitString: + new = BitString(new) + return BitString(super().replace(old, new, count), False) + + def split( + self, + sep: str | None = None, + maxsplit: SupportsIndex = -1, + ) -> list[str]: + return [BitString(word) for word in super().split(sep, maxsplit)] + + def zfill(self, width: SupportsIndex) -> BitString: + return BitString(super().zfill(width), False) + + def __repr__(self) -> str: + return f'BitString("{self.__str__()}")' + + def __int__(self) -> int: + return int(self, 2) if self else 0 + + def to_bytes(self, length: int = -1) -> bytes: + return int(self).to_bytes( + length if length >= 0 else self.octet_length, byteorder="big" + ) + + def __bytes__(self) -> bytes: + return self.to_bytes() + + def __getitem__( + self, key: SupportsIndex | slice[Any, Any, Any] + ) -> BitString: + return BitString(super().__getitem__(key), False) + + def __add__(self, o: str) -> BitString: + """Return self + o""" + if not isinstance(o, str): + raise TypeError( + f"Can only concatenate str (not '{type(self)}') to 
BitString" + ) + return BitString("".join([self, o])) + + def __radd__(self, o: str) -> BitString: + if not isinstance(o, str): + raise TypeError( + f"Can only concatenate str (not '{type(self)}') to BitString" + ) + return BitString("".join([o, self])) + + def __lshift__(self, amount: int) -> BitString: + """Shifts each bit in the bitstring to the left by the given amount. + String length is preserved:: + + BitString("000101") << 1 == BitString("001010") + """ + return BitString( + "".join([self, *("0" for _ in range(amount))])[-len(self) :], False + ) + + def __rshift__(self, amount: int) -> BitString: + """Shifts each bit in the bitstring to the right by the given amount. + String length is preserved:: + + BitString("101") >> 1 == BitString("010") + """ + return BitString(self[:-amount], False).zfill(width=len(self)) + + def __invert__(self) -> BitString: + """Inverts (~) each bit in the + bitstring:: + + ~BitString("01010") == BitString("10101") + """ + return BitString("".join("1" if x == "0" else "0" for x in self)) + + def __and__(self, o: str) -> BitString: + """Performs a bitwise and (``&``) with the given operand. + A ``ValueError`` is raised if the operand is not the same length. + + e.g.:: + + BitString("011") & BitString("010") == BitString("010") + """ + + if not isinstance(o, str): + return NotImplemented + o = BitString(o) + if len(self) != len(o): + raise ValueError("Operands must be the same length") + + return BitString( + "".join( + "1" if (x == "1" and y == "1") else "0" + for x, y in zip(self, o) + ), + False, + ) + + def __or__(self, o: str) -> BitString: + """Performs a bitwise or (``|``) with the given operand. + A ``ValueError`` is raised if the operand is not the same length. 
+ + e.g.:: + + BitString("011") | BitString("010") == BitString("011") + """ + if not isinstance(o, str): + return NotImplemented + + if len(self) != len(o): + raise ValueError("Operands must be the same length") + + o = BitString(o) + return BitString( + "".join( + "1" if (x == "1" or y == "1") else "0" + for (x, y) in zip(self, o) + ), + False, + ) + + def __xor__(self, o: str) -> BitString: + """Performs a bitwise xor (``^``) with the given operand. + A ``ValueError`` is raised if the operand is not the same length. + + e.g.:: + + BitString("011") ^ BitString("010") == BitString("001") + """ + + if not isinstance(o, BitString): + return NotImplemented + + if len(self) != len(o): + raise ValueError("Operands must be the same length") + + return BitString( + "".join( + ( + "1" + if ((x == "1" and y == "0") or (x == "0" and y == "1")) + else "0" + ) + for (x, y) in zip(self, o) + ), + False, + ) + + __rand__ = __and__ + __ror__ = __or__ + __rxor__ = __xor__ diff --git a/lib/sqlalchemy/dialects/postgresql/types.py b/lib/sqlalchemy/dialects/postgresql/types.py index ff5e967ef6..96e5644572 100644 --- a/lib/sqlalchemy/dialects/postgresql/types.py +++ b/lib/sqlalchemy/dialects/postgresql/types.py @@ -8,21 +8,25 @@ from __future__ import annotations import datetime as dt from typing import Any +from typing import Literal from typing import Optional from typing import overload from typing import Type from typing import TYPE_CHECKING from uuid import UUID as _python_UUID +from .bitstring import BitString from ...sql import sqltypes from ...sql import type_api -from ...util.typing import Literal +from ...sql.type_api import TypeEngine if TYPE_CHECKING: from ...engine.interfaces import Dialect + from ...sql.operators import ColumnOperators from ...sql.operators import OperatorType + from ...sql.type_api import _BindProcessorType from ...sql.type_api import _LiteralProcessorType - from ...sql.type_api import TypeEngine + from ...sql.type_api import _ResultProcessorType 
_DECIMAL_TYPES = (1231, 1700) _FLOAT_TYPES = (700, 701, 1021, 1022) @@ -256,7 +260,18 @@ class INTERVAL(type_api.NativeForEmulated, sqltypes._AbstractInterval): PGInterval = INTERVAL -class BIT(sqltypes.TypeEngine[int]): +class BIT(sqltypes.TypeEngine[BitString]): + """Represent the PostgreSQL BIT type. + + The :class:`_postgresql.BIT` type yields values in the form of the + :class:`_postgresql.BitString` Python value type. + + .. versionchanged:: 2.1 The :class:`_postgresql.BIT` type now works + with :class:`_postgresql.BitString` values rather than plain strings. + + """ + + render_bind_cast = True __visit_name__ = "BIT" def __init__( @@ -270,6 +285,58 @@ class BIT(sqltypes.TypeEngine[int]): self.length = length or 1 self.varying = varying + def bind_processor( + self, dialect: Dialect + ) -> _BindProcessorType[BitString]: + def bound_value(value: Any) -> Any: + if isinstance(value, BitString): + return str(value) + return value + + return bound_value + + def result_processor( + self, dialect: Dialect, coltype: object + ) -> _ResultProcessorType[BitString]: + def from_result_value(value: Any) -> Any: + if value is not None: + value = BitString(value) + return value + + return from_result_value + + def coerce_compared_value( + self, op: OperatorType | None, value: Any + ) -> TypeEngine[Any]: + if isinstance(value, str): + return self + return super().coerce_compared_value(op, value) + + @property + def python_type(self) -> type[Any]: + return BitString + + class comparator_factory(TypeEngine.Comparator[BitString]): + def __lshift__(self, other: Any) -> ColumnOperators: + return self.bitwise_lshift(other) + + def __rshift__(self, other: Any) -> ColumnOperators: + return self.bitwise_rshift(other) + + def __and__(self, other: Any) -> ColumnOperators: + return self.bitwise_and(other) + + def __or__(self, other: Any) -> ColumnOperators: + return self.bitwise_or(other) + + # NOTE: __xor__ is not defined on sql.operators.ColumnOperators. 
+ # Use `bitwise_xor` directly instead. + # def __xor__(self, other: Any) -> ColumnOperators: + # return self.bitwise_xor(other) + + def __invert__(self) -> ColumnOperators: + return self.bitwise_not() + PGBit = BIT diff --git a/test/dialect/postgresql/test_bitstring.py b/test/dialect/postgresql/test_bitstring.py new file mode 100644 index 0000000000..2ca8c72fa1 --- /dev/null +++ b/test/dialect/postgresql/test_bitstring.py @@ -0,0 +1,165 @@ +from sqlalchemy import testing +from sqlalchemy.dialects.postgresql import BitString +from sqlalchemy.testing import fixtures +from sqlalchemy.testing.assertions import assert_raises +from sqlalchemy.testing.assertions import eq_ +from sqlalchemy.testing.assertions import is_ +from sqlalchemy.testing.assertions import is_false +from sqlalchemy.testing.assertions import is_not +from sqlalchemy.testing.assertions import is_true + + +class BitStringTests(fixtures.TestBase): + + @testing.combinations( + lambda: BitString("111") == BitString("111"), + lambda: BitString("111") == "111", + lambda: BitString("111") != BitString("110"), + lambda: BitString("111") != "110", + lambda: hash(BitString("011")) == hash(BitString("011")), + lambda: hash(BitString("011")) == hash("011"), + lambda: BitString("011")[1] == BitString("1"), + lambda: BitString("010") > BitString("001"), + lambda: "010" > BitString("001"), + lambda: "011" <= BitString("011"), + lambda: "011" <= BitString("101"), + ) + def test_comparisons(self, case): + is_true(case()) + + def test_sorting(self): + eq_( + sorted([BitString("110"), BitString("010"), "111", "101"]), + [BitString("010"), "101", BitString("110"), "111"], + ) + + def test_str_conversion(self): + x = BitString("1110111") + eq_(str(x), "1110111") + + assert_raises(ValueError, lambda: BitString("1246")) + + def test_same_instance_returned(self): + x = BitString("1110111") + y = BitString("1110111") + z = BitString(x) + + eq_(x, y) + eq_(x, z) + + is_not(x, y) + is_(x, z) + + @testing.combinations( + (0, 0, 
BitString("")), + (0, 1, BitString("0")), + (1, 1, BitString("1")), + (1, 0, ValueError), + (1, -1, ValueError), + (2, 1, ValueError), + (-1, 4, ValueError), + (1, 4, BitString("0001")), + (1, 10, BitString("0000000001")), + (127, 8, BitString("01111111")), + (127, 10, BitString("0001111111")), + (1404, 8, ValueError), + (1404, 12, BitString("010101111100")), + argnames="source, bitlen, result_or_error", + ) + def test_int_conversion(self, source, bitlen, result_or_error): + if isinstance(result_or_error, type): + assert_raises( + result_or_error, lambda: BitString.from_int(source, bitlen) + ) + return + + result = result_or_error + + bits = BitString.from_int(source, bitlen) + eq_(bits, result) + eq_(int(bits), source) + + @testing.combinations( + (b"", -1, BitString("")), + (b"", 4, BitString("0000")), + (b"\x00", 1, BitString("0")), + (b"\x01", 1, BitString("1")), + (b"\x01", 4, BitString("0001")), + (b"\x01", 10, BitString("0000000001")), + (b"\x01", -1, BitString("00000001")), + (b"\xff", 10, BitString("0011111111")), + (b"\xaf\x04", 8, ValueError), + (b"\xaf\x04", 16, BitString("1010111100000100")), + (b"\xaf\x04", 20, BitString("00001010111100000100")), + argnames="source, bitlen, result_or_error", + ) + def test_bytes_conversion(self, source, bitlen, result_or_error): + if isinstance(result_or_error, type): + assert_raises( + result_or_error, + lambda: BitString.from_bytes(source, length=bitlen), + ) + return + result = result_or_error + + bits = BitString.from_bytes(source, bitlen) + eq_(bits, result) + + # Expecting a roundtrip conversion in this case is nonsensical + if source == b"" and bitlen > 0: + return + eq_(bits.to_bytes(len(source)), source) + + def test_get_set_bit(self): + eq_(BitString("1010").get_bit(2), "1") + eq_(BitString("0101").get_bit(2), "0") + assert_raises(IndexError, lambda: BitString("0").get_bit(1)) + + eq_(BitString("0101").set_bit(3, "0"), BitString("0100")) + eq_(BitString("0101").set_bit(3, "1"), BitString("0101")) + 
assert_raises(IndexError, lambda: BitString("1111").set_bit(5, "1")) + + def test_string_methods(self): + + eq_(BitString("01100").lstrip(), BitString("1100")) + eq_(BitString("01100").rstrip(), BitString("011")) + eq_(BitString("01100").strip(), BitString("11")) + + eq_(BitString("11100").removeprefix("111"), BitString("00")) + eq_(BitString("11100").removeprefix("0"), BitString("11100")) + + eq_(BitString("11100").removesuffix("10"), BitString("11100")) + eq_(BitString("11100").removesuffix("00"), BitString("111")) + + eq_( + BitString("010101011").replace("0101", "11", 1), + BitString("1101011"), + ) + eq_( + BitString("01101101").split("1", 2), + [BitString("0"), BitString(""), BitString("01101")], + ) + + eq_(BitString("0110").split("11"), [BitString("0"), BitString("0")]) + eq_(BitString("111").zfill(8), BitString("00000111")) + + def test_string_operators(self): + is_true("1" in BitString("001")) + is_true("0" in BitString("110")) + is_false("1" in BitString("000")) + + is_true("001" in BitString("01001")) + is_true(BitString("001") in BitString("01001")) + is_false(BitString("000") in BitString("01001")) + + eq_(BitString("010") + "001", BitString("010001")) + eq_("001" + BitString("010"), BitString("001010")) + + def test_bitwise_operators(self): + eq_(~BitString("0101"), BitString("1010")) + eq_(BitString("010") & BitString("011"), BitString("010")) + eq_(BitString("010") | BitString("011"), BitString("011")) + eq_(BitString("010") ^ BitString("011"), BitString("001")) + + eq_(BitString("001100") << 2, BitString("110000")) + eq_(BitString("001100") >> 2, BitString("000011")) diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 6151ed2dcc..2e0d1ea5f6 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -47,6 +47,7 @@ from sqlalchemy.dialects.postgresql import array_agg from sqlalchemy.dialects.postgresql import asyncpg from sqlalchemy.dialects.postgresql import base 
from sqlalchemy.dialects.postgresql import BIT +from sqlalchemy.dialects.postgresql import BitString from sqlalchemy.dialects.postgresql import BYTEA from sqlalchemy.dialects.postgresql import CITEXT from sqlalchemy.dialects.postgresql import DATEMULTIRANGE @@ -3523,7 +3524,9 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables): metadata, Column("id", postgresql.UUID, primary_key=True), Column("flag", postgresql.BIT), - Column("bitstring", postgresql.BIT(4)), + Column("bitstring_varying", postgresql.BIT(varying=True)), + Column("bitstring_varying_6", postgresql.BIT(6, varying=True)), + Column("bitstring_4", postgresql.BIT(4)), Column("addr", postgresql.INET), Column("addr2", postgresql.MACADDR), Column("addr4", postgresql.MACADDR8), @@ -3551,7 +3554,18 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables): self.assert_tables_equal(special_types_table, t, strict_types=True) assert t.c.plain_interval.type.precision is None assert t.c.precision_interval.type.precision == 3 - assert t.c.bitstring.type.length == 4 + + assert t.c.flag.type.varying is False + assert t.c.flag.type.length == 1 + + assert t.c.bitstring_varying.type.varying is True + assert t.c.bitstring_varying.type.length is None + + assert t.c.bitstring_varying_6.type.varying is True + assert t.c.bitstring_varying_6.type.length == 6 + + assert t.c.bitstring_4.type.varying is False + assert t.c.bitstring_4.type.length == 4 @testing.combinations( (postgresql.INET, "127.0.0.1"), @@ -3581,6 +3595,41 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables): "test", ) + @testing.combinations( + (postgresql.BIT(varying=True), BitString("")), + (postgresql.BIT(varying=True), BitString("1101010101")), + (postgresql.BIT(6, varying=True), BitString("")), + (postgresql.BIT(6, varying=True), BitString("010101")), + (postgresql.BIT(1), BitString("0")), + (postgresql.BIT(4), BitString("0010")), + (postgresql.BIT(4), "0010"), + argnames="column_type, value", + ) + def test_bitstring_round_trip( + 
self, connection, metadata, column_type, value + ): + t = Table( + "bits", + metadata, + Column("name", String), + Column("value", column_type), + ) + t.create(connection) + + connection.execute(t.insert(), {"name": "test", "value": value}) + eq_( + connection.scalar(select(t.c.name).where(t.c.value == value)), + "test", + ) + + result_value = connection.scalar( + select(t.c.value).where(t.c.name == "test") + ) + assert isinstance(result_value, BitString) + eq_(result_value, value) + eq_(result_value, str(value)) + eq_(str(result_value), str(value)) + def test_tsvector_round_trip(self, connection, metadata): t = Table("t1", metadata, Column("data", postgresql.TSVECTOR)) t.create(connection) @@ -4182,6 +4231,103 @@ class HStoreRoundTripTest(fixtures.TablesTest): eq_(s.query(Data.data, Data).all(), [(d.data, d)]) +class BitTests(fixtures.TestBase): + __backend__ = True + __only_on__ = "postgresql" + + def test_concatenation(self, connection): + coltype = BIT(varying=True) + + q = select( + literal(BitString("1111"), coltype).concat(BitString("0000")) + ) + r = connection.execute(q).first() + eq_(r[0], BitString("11110000")) + + def test_invert_operator(self, connection): + coltype = BIT(4) + + q = select(literal(BitString("0010"), coltype).bitwise_not()) + r = connection.execute(q).first() + + eq_(r[0], BitString("1101")) + + def test_and_operator(self, connection): + coltype = BIT(6) + + q1 = select( + literal(BitString("001010"), coltype) + & literal(BitString("010111"), coltype) + ) + r1 = connection.execute(q1).first() + + eq_(r1[0], BitString("000010")) + + q2 = select( + literal(BitString("010101"), coltype) & BitString("001011") + ) + r2 = connection.execute(q2).first() + eq_(r2[0], BitString("000001")) + + def test_or_operator(self, connection): + coltype = BIT(6) + + q1 = select( + literal(BitString("001010"), coltype) + | literal(BitString("010111"), coltype) + ) + r1 = connection.execute(q1).first() + + eq_(r1[0], BitString("011111")) + + q2 = select( + 
literal(BitString("010101"), coltype) | BitString("001011") + ) + r2 = connection.execute(q2).first() + eq_(r2[0], BitString("011111")) + + def test_xor_operator(self, connection): + coltype = BIT(6) + + q1 = select( + literal(BitString("001010"), coltype).bitwise_xor( + literal(BitString("010111"), coltype) + ) + ) + r1 = connection.execute(q1).first() + eq_(r1[0], BitString("011101")) + + q2 = select( + literal(BitString("010101"), coltype).bitwise_xor( + BitString("001011") + ) + ) + r2 = connection.execute(q2).first() + eq_(r2[0], BitString("011110")) + + def test_lshift_operator(self, connection): + coltype = BIT(6) + + q = select( + literal(BitString("001010"), coltype), + literal(BitString("001010"), coltype) << 1, + ) + + r = connection.execute(q).first() + eq_(tuple(r), (BitString("001010"), BitString("010100"))) + + def test_rshift_operator(self, connection): + coltype = BIT(6) + + q = select( + literal(BitString("001010"), coltype), + literal(BitString("001010"), coltype) >> 1, + ) + + r = connection.execute(q).first() + eq_(tuple(r), (BitString("001010"), BitString("000101"))) + + class RangeMiscTests(fixtures.TestBase): @testing.combinations( (Range(2, 7), INT4RANGE), @@ -6639,7 +6785,7 @@ class PGInsertManyValuesTest(fixtures.TestBase): @testing.combinations( ("BYTEA", BYTEA(), b"7\xe7\x9f"), - ("BIT", BIT(3), "011"), + ("BIT", BIT(3), BitString("011")), argnames="type_,value", id_="iaa", ) @@ -6672,11 +6818,6 @@ class PGInsertManyValuesTest(fixtures.TestBase): t.create(connection) - if type_._type_affinity is BIT and testing.against("+asyncpg"): - import asyncpg - - value = asyncpg.BitString(value) - result = connection.execute( t.insert().returning( t.c.id,