import datetime
import decimal
from enum import Enum as _PY_Enum
+import functools
from ipaddress import IPv4Address
from ipaddress import IPv4Network
from ipaddress import IPv6Address
t.drop(connection)
eq_(inspect(connection).get_enums(), [])
- def _type_combinations(exclude_json=False, exclude_empty_lists=False):
+ def _type_combinations(
+ exclude_json=False,
+ exclude_empty_lists=False,
+ exclude_arrays_with_none=False,
+ ):
def str_values(x):
return ["one", "two: %s" % x, "three", "four", "five"]
def __ne__(self, other):
return not self.__eq__(other)
+ simple_enum = ["one", "two", "three", "four", "five", "six"]
difficult_enum = [
"Value",
"With space",
"With,comma",
+ "NULL",
'With"quote',
"With\\escape",
"""Various!@#$%^*()"'\\][{};:.<>|_+~chars""",
]
- def make_difficult_enum(cls_, native):
- return cls_(
- *difficult_enum, name="difficult_enum", native_enum=native
- )
+ def make_enum(cls_, members, native):
+ return cls_(*members, name="difficult_enum", native_enum=native)
- def difficult_enum_values(x):
- return [v for i, v in enumerate(difficult_enum) if i != x - 1]
+ def make_enum_values(members, x, *, include_none=False):
+ arr = [v for i, v in enumerate(members) if i != x - 1]
+ if include_none:
+ arr.insert(2, None)
+ return arr
elements = [
(sqltypes.Integer, lambda x: [1, x, 3, 4, 5]),
enum_values,
),
(
- make_difficult_enum(sqltypes.Enum, native=True),
- difficult_enum_values,
+ make_enum(sqltypes.Enum, difficult_enum, native=True),
+ functools.partial(make_enum_values, difficult_enum),
),
(
- make_difficult_enum(sqltypes.Enum, native=False),
- difficult_enum_values,
+ make_enum(sqltypes.Enum, difficult_enum, native=False),
+ functools.partial(make_enum_values, difficult_enum),
),
(
- make_difficult_enum(postgresql.ENUM, native=True),
- difficult_enum_values,
+ make_enum(postgresql.ENUM, difficult_enum, native=True),
+ functools.partial(make_enum_values, difficult_enum),
),
]
+ if not exclude_arrays_with_none:
+ elements.extend(
+ [
+ (
+ # unquoted ENUM values including NULL in the data
+ make_enum(sqltypes.Enum, simple_enum, native=True),
+ functools.partial(
+ make_enum_values, simple_enum, include_none=True
+ ),
+ ),
+ (
+ # unquoted ENUM values including NULL in the data
+ make_enum(sqltypes.Enum, simple_enum, native=False),
+ functools.partial(
+ make_enum_values, simple_enum, include_none=True
+ ),
+ ),
+ (
+ # unquoted ENUM values including NULL in the data
+ make_enum(postgresql.ENUM, simple_enum, native=True),
+ functools.partial(
+ make_enum_values, simple_enum, include_none=True
+ ),
+ ),
+ (
+ # quoted ENUM values, including both
+ # quoted "NULL" and real NULL in the data
+ make_enum(sqltypes.Enum, difficult_enum, native=True),
+ functools.partial(
+ make_enum_values, difficult_enum, include_none=True
+ ),
+ ),
+ (
+ # quoted ENUM values, including both
+ # quoted "NULL" and real NULL in the data
+ make_enum(sqltypes.Enum, difficult_enum, native=False),
+ functools.partial(
+ make_enum_values, difficult_enum, include_none=True
+ ),
+ ),
+ (
+ # quoted ENUM values, including both
+ # quoted "NULL" and real NULL in the data
+ make_enum(
+ postgresql.ENUM, difficult_enum, native=True
+ ),
+ functools.partial(
+ make_enum_values, difficult_enum, include_none=True
+ ),
+ ),
+ ]
+ )
if not exclude_empty_lists:
elements.extend(
[
elements[i] = elem
return testing.combinations_list(
- elements, argnames="type_,gen", id_="na"
+ elements, argnames="type_,generate_data", id_="na"
)
@classmethod
meta.create_all(connection)
- def go(gen):
+ def go(generate_data):
connection.execute(
table.insert(),
- [{"id": 1, "bar": gen(1)}, {"id": 2, "bar": gen(2)}],
+ [
+ {"id": 1, "bar": generate_data(1)},
+ {"id": 2, "bar": generate_data(2)},
+ ],
)
return table
@_type_combinations()
def test_type_specific_value_select(
- self, type_specific_fixture, connection, type_, gen
+ self, type_specific_fixture, connection, type_, generate_data
):
- table = type_specific_fixture(gen)
+ table = type_specific_fixture(generate_data)
rows = connection.execute(
select(table.c.bar).order_by(table.c.id)
).all()
- eq_(rows, [(gen(1),), (gen(2),)])
+ eq_(rows, [(generate_data(1),), (generate_data(2),)])
@_type_combinations()
def test_type_specific_value_update(
- self, type_specific_fixture, connection, type_, gen
+ self, type_specific_fixture, connection, type_, generate_data
):
- table = type_specific_fixture(gen)
+ table = type_specific_fixture(generate_data)
- new_gen = gen(3)
+ new_gen = generate_data(3)
connection.execute(
table.update().where(table.c.id == 2).values(bar=new_gen)
)
@_type_combinations(exclude_empty_lists=True)
def test_type_specific_slice_update(
- self, type_specific_fixture, connection, type_, gen
+ self, type_specific_fixture, connection, type_, generate_data
):
- table = type_specific_fixture(gen)
+ table = type_specific_fixture(generate_data)
- new_gen = gen(3)
+ new_gen = generate_data(3)
if not table.c.bar.type._variant_mapping:
# this is not likely to occur to users but we need to just
select(table.c.bar).order_by(table.c.id)
).all()
- sliced_gen = gen(2)
+ sliced_gen = generate_data(2)
sliced_gen[0:3] = new_gen[1:4]
- eq_(rows, [(gen(1),), (sliced_gen,)])
+ eq_(rows, [(generate_data(1),), (sliced_gen,)])
@_type_combinations(exclude_json=True, exclude_empty_lists=True)
def test_type_specific_value_delete(
- self, type_specific_fixture, connection, type_, gen
+ self, type_specific_fixture, connection, type_, generate_data
):
- table = type_specific_fixture(gen)
+ table = type_specific_fixture(generate_data)
- new_gen = gen(2)
+ new_gen = generate_data(2)
connection.execute(table.delete().where(table.c.bar == new_gen))
):
ARRAY = postgresql.ARRAY
- @ArrayRoundTripTest._cls_type_combinations(exclude_json=True)
+ @ArrayRoundTripTest._cls_type_combinations(
+ exclude_json=True, exclude_arrays_with_none=True
+ )
def test_type_specific_contains(
- self, type_specific_fixture, connection, type_, gen
+ self, type_specific_fixture, connection, type_, generate_data
):
- table = type_specific_fixture(gen)
+ table = type_specific_fixture(generate_data)
connection.execute(
table.insert(),
- [{"id": 1, "bar": gen(1)}, {"id": 2, "bar": gen(2)}],
+ [
+ {"id": 1, "bar": generate_data(1)},
+ {"id": 2, "bar": generate_data(2)},
+ ],
)
id_, value = connection.execute(
- select(table).where(table.c.bar.contains(gen(1)))
+ select(table).where(table.c.bar.contains(generate_data(1)))
).first()
eq_(id_, 1)
- eq_(value, gen(1))
+ eq_(value, generate_data(1))
@testing.combinations(
(set,), (list,), (lambda elem: (x for x in elem),), argnames="struct"