from sqlalchemy import func
from sqlalchemy import Integer
from sqlalchemy import literal_column
-from sqlalchemy import MetaData
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import text
from sqlalchemy.sql import column
from sqlalchemy.sql import table
+from sqlalchemy.sql.sqltypes import NullType
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
-info_table = None
-
-
-class CaseTest(fixtures.TestBase, AssertsCompiledSQL):
+class CaseTest(fixtures.TablesTest, AssertsCompiledSQL):
__dialect__ = "default"
+ run_inserts = "once"
+ run_deletes = "never"
+
@classmethod
- def setup_test_class(cls):
- metadata = MetaData()
- global info_table
- info_table = Table(
- "infos",
+ def define_tables(cls, metadata):
+ Table(
+ "info_table",
metadata,
Column("pk", Integer, primary_key=True),
Column("info", String(30)),
)
- with testing.db.begin() as conn:
- info_table.create(conn)
-
- conn.execute(
- info_table.insert(),
- [
- {"pk": 1, "info": "pk_1_data"},
- {"pk": 2, "info": "pk_2_data"},
- {"pk": 3, "info": "pk_3_data"},
- {"pk": 4, "info": "pk_4_data"},
- {"pk": 5, "info": "pk_5_data"},
- {"pk": 6, "info": "pk_6_data"},
- ],
- )
-
@classmethod
- def teardown_test_class(cls):
- with testing.db.begin() as conn:
- info_table.drop(conn)
+ def insert_data(cls, connection):
+ info_table = cls.tables.info_table
+
+ connection.execute(
+ info_table.insert(),
+ [
+ {"pk": 1, "info": "pk_1_data"},
+ {"pk": 2, "info": "pk_2_data"},
+ {"pk": 3, "info": "pk_3_data"},
+ {"pk": 4, "info": "pk_4_data"},
+ {"pk": 5, "info": "pk_5_data"},
+ {"pk": 6, "info": "pk_6_data"},
+ ],
+ )
+ connection.commit()
@testing.requires.subqueries
def test_case(self, connection):
+ info_table = self.tables.info_table
+
inner = select(
case(
(info_table.c.pk < 3, "lessthan3"),
)
def test_text_doesnt_explode(self, connection):
+ info_table = self.tables.info_table
+
for s in [
select(
case(
)
def testcase_with_dict(self):
+ info_table = self.tables.info_table
+
query = select(
case(
{
("two", 2),
("other", 3),
]
+
+ @testing.variation("add_else", [True, False])
+ def test_type_of_case_expression_with_all_nulls(self, add_else):
+ info_table = self.tables.info_table
+
+ expr = case(
+ (info_table.c.pk < 0, None),
+ (info_table.c.pk > 9, None),
+ else_=column("q") if add_else else None,
+ )
+
+ assert isinstance(expr.type, NullType)
+
+ @testing.combinations(
+ lambda info_table: (
+ [
+ # test non-None in middle of WHENS takes precedence over Nones
+ (info_table.c.pk < 0, None),
+ (info_table.c.pk < 5, "five"),
+ (info_table.c.pk <= 9, info_table.c.pk),
+ (info_table.c.pk > 9, None),
+ ],
+ None,
+ ),
+ lambda info_table: (
+ # test non-None ELSE takes precedence over WHENs that are None
+ [(info_table.c.pk < 0, None)],
+ info_table.c.pk,
+ ),
+ lambda info_table: (
+ # test non-None WHEN takes precedence over non-None ELSE
+ [
+ (info_table.c.pk < 0, None),
+ (info_table.c.pk <= 9, info_table.c.pk),
+ (info_table.c.pk > 9, None),
+ ],
+ column("q", String),
+ ),
+ lambda info_table: (
+ # test last WHEN in list takes precedence
+ [
+ (info_table.c.pk < 0, String),
+ (info_table.c.pk > 9, None),
+ (info_table.c.pk <= 9, info_table.c.pk),
+ ],
+ column("q", String),
+ ),
+ )
+ def test_type_of_case_expression(self, when_lambda):
+ info_table = self.tables.info_table
+
+ whens, else_ = testing.resolve_lambda(
+ when_lambda, info_table=info_table
+ )
+
+ expr = case(*whens, else_=else_)
+
+ assert isinstance(expr.type, Integer)