from ...sql import roles
from ...sql import sqltypes
from ...sql import type_api
+from ...sql.base import _NoArg
from ...sql.ddl import InvokeCreateDDLBase
from ...sql.ddl import InvokeDropDDLBase
DDLGenerator = EnumGenerator
DDLDropper = EnumDropper
- def __init__(self, *enums, name: str, create_type: bool = True, **kw):
+ def __init__(
+ self,
+ *enums,
+ name: Union[str, _NoArg, None] = _NoArg.NO_ARG,
+ create_type: bool = True,
+ **kw,
+ ):
"""Construct an :class:`_postgresql.ENUM`.
Arguments are the same as that of
"non-native enum."
)
self.create_type = create_type
- super().__init__(*enums, name=name, **kw)
+ if name is not _NoArg.NO_ARG:
+ kw["name"] = name
+ super().__init__(*enums, **kw)
+
+ def coerce_compared_value(self, op, value):
+ super_coerced_type = super().coerce_compared_value(op, value)
+ if (
+ super_coerced_type._type_affinity
+ is type_api.STRINGTYPE._type_affinity
+ ):
+ return self
+ else:
+ return super_coerced_type
@classmethod
def __test_init__(cls):
from sqlalchemy import exc
from sqlalchemy import Float
from sqlalchemy import func
+from sqlalchemy import insert
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import literal
else:
assert False
+ @testing.variation("name", ["noname", "nonename", "explicit_name"])
+ @testing.variation("enum_type", ["pg", "plain"])
+ def test_native_enum_string_from_pep435(self, name, enum_type):
+ """test #9611"""
+
+ class MyEnum(_PY_Enum):
+ one = "one"
+ two = "two"
+
+ if enum_type.plain:
+ cls = Enum
+ elif enum_type.pg:
+ cls = ENUM
+ else:
+ enum_type.fail()
+
+ if name.noname:
+ e1 = cls(MyEnum)
+ eq_(e1.name, "myenum")
+ elif name.nonename:
+ e1 = cls(MyEnum, name=None)
+ eq_(e1.name, None)
+ elif name.explicit_name:
+ e1 = cls(MyEnum, name="abc")
+ eq_(e1.name, "abc")
+
+ @testing.variation("backend_type", ["native", "non_native", "pg_native"])
+ @testing.variation("enum_type", ["pep435", "str"])
+ def test_compare_to_string_round_trip(
+ self, connection, backend_type, enum_type, metadata
+ ):
+ """test #9621"""
+
+ if enum_type.pep435:
+
+ class MyEnum(_PY_Enum):
+ one = "one"
+ two = "two"
+
+ if backend_type.pg_native:
+ typ = ENUM(MyEnum, name="myenum2")
+ else:
+ typ = Enum(
+ MyEnum,
+ native_enum=bool(backend_type.native),
+ name="myenum2",
+ )
+ data = [{"someenum": MyEnum.one}, {"someenum": MyEnum.two}]
+ expected = MyEnum.two
+ elif enum_type.str:
+ if backend_type.pg_native:
+ typ = ENUM("one", "two", name="myenum2")
+ else:
+ typ = Enum(
+ "one",
+ "two",
+ native_enum=bool(backend_type.native),
+ name="myenum2",
+ )
+ data = [{"someenum": "one"}, {"someenum": "two"}]
+ expected = "two"
+ else:
+ enum_type.fail()
+
+ enum_table = Table(
+ "et2",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("someenum", typ),
+ )
+ metadata.create_all(connection)
+
+ connection.execute(insert(enum_table), data)
+ expr = select(enum_table.c.someenum).where(
+ enum_table.c.someenum == "two"
+ )
+
+ row = connection.execute(expr).one()
+ eq_(row, (expected,))
+
@testing.combinations(
(Enum("one", "two", "three")),
(ENUM("one", "two", "three", name=None)),