__visit_name__ = "BYTEA"
-class INET(sqltypes.TypeEngine[str]):
+class _NetworkAddressTypeMixin:
+
+ def coerce_compared_value(
+ self, op: Optional[OperatorType], value: Any
+ ) -> TypeEngine[Any]:
+ if TYPE_CHECKING:
+ assert isinstance(self, TypeEngine)
+ return self
+
+
+class INET(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]):
__visit_name__ = "INET"
PGInet = INET
-class CIDR(sqltypes.TypeEngine[str]):
+class CIDR(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]):
__visit_name__ = "CIDR"
PGCidr = CIDR
-class MACADDR(sqltypes.TypeEngine[str]):
+class MACADDR(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]):
__visit_name__ = "MACADDR"
PGMacAddr = MACADDR
-class MACADDR8(sqltypes.TypeEngine[str]):
+class MACADDR8(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]):
__visit_name__ = "MACADDR8"
def test_bit_compile(self, type_, expected):
self.assert_compile(type_, expected)
+ @testing.combinations(
+ (psycopg.dialect(),),
+ (psycopg2.dialect(),),
+ (asyncpg.dialect(),),
+ (pg8000.dialect(),),
+ argnames="dialect",
+ id_="n",
+ )
+ def test_network_address_cast(self, metadata, dialect):
+ t = Table(
+ "addresses",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("addr", postgresql.INET),
+ Column("addr2", postgresql.MACADDR),
+ Column("addr3", postgresql.CIDR),
+ Column("addr4", postgresql.MACADDR8),
+ )
+ stmt = select(t.c.id).where(
+ t.c.addr == "127.0.0.1",
+ t.c.addr2 == "08:00:2b:01:02:03",
+ t.c.addr3 == "192.168.100.128/25",
+ t.c.addr4 == "08:00:2b:01:02:03:04:05",
+ )
+ param, param2, param3, param4 = {
+ "format": ("%s", "%s", "%s", "%s"),
+ "numeric_dollar": ("$1", "$2", "$3", "$4"),
+ "pyformat": (
+ "%(addr_1)s",
+ "%(addr2_1)s",
+ "%(addr3_1)s",
+ "%(addr4_1)s",
+ ),
+ }[dialect.paramstyle]
+ expected = (
+ "SELECT addresses.id FROM addresses "
+ f"WHERE addresses.addr = {param} "
+ f"AND addresses.addr2 = {param2} "
+ f"AND addresses.addr3 = {param3} "
+ f"AND addresses.addr4 = {param4}"
+ )
+ self.assert_compile(stmt, expected, dialect=dialect)
+
class SpecialTypesTest(fixtures.TablesTest, ComparesTables):
"""test DDL and reflection of PG-specific types"""
assert t.c.precision_interval.type.precision == 3
assert t.c.bitstring.type.length == 4
+ @testing.combinations(
+ (postgresql.INET, "127.0.0.1"),
+ (postgresql.CIDR, "192.168.100.128/25"),
+ (postgresql.MACADDR, "08:00:2b:01:02:03"),
+ (postgresql.MACADDR8, "08:00:2b:01:02:03:04:05"),
+ argnames="column_type, value",
+ id_="na",
+ )
+ def test_network_address_round_trip(
+ self, connection, metadata, column_type, value
+ ):
+ t = Table(
+ "addresses",
+ metadata,
+ Column("name", String),
+ Column("value", column_type),
+ )
+ t.create(connection)
+ connection.execute(t.insert(), {"name": "test", "value": value})
+ eq_(
+ connection.scalar(select(t.c.name).where(t.c.value == value)),
+ "test",
+ )
+
def test_tsvector_round_trip(self, connection, metadata):
t = Table("t1", metadata, Column("data", postgresql.TSVECTOR))
t.create(connection)