From b920869ef54d05e73e2a980b73647d6050ffeb9d Mon Sep 17 00:00:00 2001 From: Gord Thompson Date: Thu, 24 Jun 2021 12:16:32 -0600 Subject: [PATCH] Modernize tests Eliminate engine.execute() and engine.scalar() Change-Id: I99f76d0e615ddebab2da4fd07a40a0a2796995c7 --- .gitignore | 1 + lib/sqlalchemy/dialects/mysql/base.py | 7 ++- lib/sqlalchemy/testing/warnings.py | 1 - test/dialect/mysql/test_dialect.py | 4 +- test/dialect/postgresql/test_dialect.py | 4 +- test/dialect/postgresql/test_query.py | 81 ++++++++++++++----------- test/engine/test_pool.py | 3 +- test/ext/test_horizontal_shard.py | 4 +- test/orm/test_composites.py | 6 +- test/orm/test_dynamic.py | 30 +++++---- test/requirements.py | 16 ++--- 11 files changed, 89 insertions(+), 68 deletions(-) diff --git a/.gitignore b/.gitignore index 8d9d546578..c566ded772 100644 --- a/.gitignore +++ b/.gitignore @@ -38,3 +38,4 @@ test/test_schema.db /querytest.db /.mypy_cache /.pytest_cache +/db_idents.txt diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index cfb6c29247..e39010762b 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -3358,10 +3358,11 @@ class MySQLDialect(default.DefaultDialect): # https://dev.mysql.com/doc/refman/en/identifier-case-sensitivity.html charset = self._connection_charset + show_var = connection.execute( + sql.text("SHOW VARIABLES LIKE 'lower_case_table_names'") + ) row = self._compat_first( - connection.execute( - sql.text("SHOW VARIABLES LIKE 'lower_case_table_names'") - ), + show_var, charset=charset, ) if not row: diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py index 30f50a44f7..df0e5aa5e5 100644 --- a/lib/sqlalchemy/testing/warnings.py +++ b/lib/sqlalchemy/testing/warnings.py @@ -68,7 +68,6 @@ def setup_filters(): # # Core execution # - r"The (?:Executable|Engine)\.(?:execute|scalar)\(\) method", # r".*DefaultGenerator.execute\(\)", # # diff --git a/test/dialect/mysql/test_dialect.py b/test/dialect/mysql/test_dialect.py index 45d119cf3c..57dd9d393d 100644 --- a/test/dialect/mysql/test_dialect.py +++ b/test/dialect/mysql/test_dialect.py @@ -499,8 +499,8 @@ class ExecutionTest(fixtures.TestBase): eq_(cx.dialect._connection_charset, charset) cx.close() - def test_sysdate(self): - d = testing.db.scalar(func.sysdate()) + def test_sysdate(self, connection): + d = connection.execute(func.sysdate()).scalar() assert isinstance(d, datetime.datetime) diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index 5a53e0b7e9..371a17819d 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -1117,9 +1117,9 @@ $$ LANGUAGE plpgsql; ) def test_extract(self, connection): - fivedaysago = testing.db.scalar( + fivedaysago = connection.execute( select(func.now().op("at time zone")("UTC")) - ) - datetime.timedelta(days=5) + ).scalar() - datetime.timedelta(days=5) for field, exp in ( ("year", fivedaysago.year), diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py index db76f61ffa..a1e9c46572 100644 --- a/test/dialect/postgresql/test_query.py +++ b/test/dialect/postgresql/test_query.py @@ -903,6 +903,10 @@ class ExtractTest(fixtures.TablesTest): run_inserts = "once" run_deletes = None + class TZ(datetime.tzinfo): + def utcoffset(self, dt): + return datetime.timedelta(hours=4) + @classmethod def setup_bind(cls): from sqlalchemy import event @@ -932,11 +936,6 @@ class ExtractTest(fixtures.TablesTest): @classmethod def insert_data(cls, connection): - # TODO: why does setting hours to anything - # not affect the TZ in the DB col ? - class TZ(datetime.tzinfo): - def utcoffset(self, dt): - return datetime.timedelta(hours=4) connection.execute( cls.tables.t.insert(), @@ -946,12 +945,12 @@ class ExtractTest(fixtures.TablesTest): "tm": datetime.time(12, 15, 25), "intv": datetime.timedelta(seconds=570), "dttz": datetime.datetime( - 2012, 5, 10, 12, 15, 25, tzinfo=TZ() + 2012, 5, 10, 12, 15, 25, tzinfo=cls.TZ() ), }, ) - def _test(self, expr, field="all", overrides=None): + def _test(self, connection, expr, field="all", overrides=None): t = self.tables.t if field == "all": @@ -983,29 +982,31 @@ class ExtractTest(fixtures.TablesTest): fields.update(overrides) for field in fields: - result = self.bind.scalar( + result = connection.execute( select(extract(field, expr)).select_from(t) - ) + ).scalar() eq_(result, fields[field]) - def test_one(self): + def test_one(self, connection): t = self.tables.t - self._test(t.c.dtme, "all") + self._test(connection, t.c.dtme, "all") - def test_two(self): + def test_two(self, connection): t = self.tables.t self._test( + connection, t.c.dtme + t.c.intv, overrides={"epoch": 1336652695.0, "minute": 24}, ) - def test_three(self): + def test_three(self, connection): self.tables.t - actual_ts = self.bind.scalar( + actual_ts = self.bind.connect().execute( func.current_timestamp() - ) - datetime.timedelta(days=5) + ).scalar() - datetime.timedelta(days=5) self._test( + connection, func.current_timestamp() - datetime.timedelta(days=5), { "hour": actual_ts.hour, @@ -1014,9 +1015,10 @@ class ExtractTest(fixtures.TablesTest): }, ) - def test_four(self): + def test_four(self, connection): t = self.tables.t self._test( + connection, datetime.timedelta(days=5) + t.c.dt, overrides={ "day": 15, @@ -1026,23 +1028,26 @@ class ExtractTest(fixtures.TablesTest): }, ) - def test_five(self): + def test_five(self, connection): t = self.tables.t self._test( + connection, func.coalesce(t.c.dtme, func.current_timestamp()), overrides={"epoch": 1336652125.0}, ) - def test_six(self): + def test_six(self, connection): t = self.tables.t self._test( + connection, t.c.tm + datetime.timedelta(seconds=30), "time", overrides={"second": 55}, ) - def test_seven(self): + def test_seven(self, connection): self._test( + connection, literal(datetime.timedelta(seconds=10)) - literal(datetime.timedelta(seconds=10)), "all", @@ -1056,49 +1061,53 @@ class ExtractTest(fixtures.TablesTest): }, ) - def test_eight(self): + def test_eight(self, connection): t = self.tables.t self._test( + connection, t.c.tm + datetime.timedelta(seconds=30), {"hour": 12, "minute": 15, "second": 55}, ) - def test_nine(self): - self._test(text("t.dt + t.tm")) + def test_nine(self, connection): + self._test(connection, text("t.dt + t.tm")) - def test_ten(self): + def test_ten(self, connection): t = self.tables.t - self._test(t.c.dt + t.c.tm) + self._test(connection, t.c.dt + t.c.tm) - def test_eleven(self): + def test_eleven(self, connection): self._test( + connection, func.current_timestamp() - func.current_timestamp(), {"year": 0, "month": 0, "day": 0, "hour": 0}, ) - def test_twelve(self): + def test_twelve(self, connection): t = self.tables.t - actual_ts = self.bind.scalar(func.current_timestamp()).replace( - tzinfo=None - ) - datetime.datetime(2012, 5, 10, 12, 15, 25) - self._test( + actual_ts = connection.scalar( func.current_timestamp() - - func.coalesce(t.c.dtme, func.current_timestamp()), + ) - datetime.datetime(2012, 5, 10, 12, 15, 25, tzinfo=self.TZ()) + + self._test( + connection, + func.current_timestamp() - t.c.dttz, {"day": actual_ts.days}, ) - def test_thirteen(self): + def test_thirteen(self, connection): t = self.tables.t - self._test(t.c.dttz, "all+tz") + self._test(connection, t.c.dttz, "all+tz") - def test_fourteen(self): + def test_fourteen(self, connection): t = self.tables.t - self._test(t.c.tm, "time") + self._test(connection, t.c.tm, "time") - def test_fifteen(self): + def test_fifteen(self, connection): t = self.tables.t self._test( + connection, datetime.timedelta(days=5) + t.c.dtme, overrides={"day": 15, "epoch": 1337084125.0}, ) diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py index 70671134f1..fa5bd8cffb 100644 --- a/test/engine/test_pool.py +++ b/test/engine/test_pool.py @@ -738,7 +738,8 @@ class PoolEventsTest(PoolTestBase): event.listen(engine, "connect", listen_three) event.listen(engine.__class__, "connect", listen_four) - engine.execute(select(1)).close() + with engine.connect() as conn: + conn.execute(select(1)) eq_( canary, ["listen_one", "listen_four", "listen_two", "listen_three"] ) diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py index bb06d9648b..76e94918a8 100644 --- a/test/ext/test_horizontal_shard.py +++ b/test/ext/test_horizontal_shard.py @@ -233,11 +233,11 @@ class ShardTest(object): self._fixture_data() # not sure what this is testing except the fixture data itself eq_( - db2.execute(weather_locations.select()).fetchall(), + db2.connect().execute(weather_locations.select()).fetchall(), [(1, "Asia", "Tokyo")], ) eq_( - db1.execute(weather_locations.select()).fetchall(), + db1.connect().execute(weather_locations.select()).fetchall(), [ (2, "North America", "New York"), (3, "North America", "Toronto"), diff --git a/test/orm/test_composites.py b/test/orm/test_composites.py index 6ee87eefe7..5fb7cf50ff 100644 --- a/test/orm/test_composites.py +++ b/test/orm/test_composites.py @@ -775,7 +775,7 @@ class MappedSelectTest(fixtures.MappedTest): }, ) - def test_set_composite_attrs_via_selectable(self): + def test_set_composite_attrs_via_selectable(self, connection): Values, CustomValues, values, Descriptions, descriptions = ( self.classes.Values, self.classes.CustomValues, @@ -796,11 +796,11 @@ class MappedSelectTest(fixtures.MappedTest): session.add(d) session.commit() eq_( - testing.db.execute(descriptions.select()).fetchall(), + connection.execute(descriptions.select()).fetchall(), [(1, "Color", "Number")], ) eq_( - testing.db.execute(values.select()).fetchall(), + connection.execute(values.select()).fetchall(), [(1, 1, "Red", "5"), (2, 1, "Blue", "1")], ) diff --git a/test/orm/test_dynamic.py b/test/orm/test_dynamic.py index b4ab4047d5..4a936dd3e5 100644 --- a/test/orm/test_dynamic.py +++ b/test/orm/test_dynamic.py @@ -923,15 +923,19 @@ class UOWTest( sess.add(u) sess.commit() eq_( - testing.db.scalar( + sess.connection() + .execute( select(func.count("*")).where(addresses.c.user_id == None) - ), # noqa + ) + .scalar(), # noqa 0, ) eq_( - testing.db.scalar( + sess.connection() + .execute( select(func.count("*")).where(addresses.c.user_id != None) - ), # noqa + ) + .scalar(), # noqa 6, ) @@ -941,26 +945,30 @@ class UOWTest( if expected: eq_( - testing.db.scalar( + sess.connection() + .execute( select(func.count("*")).where( addresses.c.user_id == None ) # noqa - ), + ) + .scalar(), 6, ) eq_( - testing.db.scalar( + sess.connection() + .execute( select(func.count("*")).where( addresses.c.user_id != None ) # noqa - ), + ) + .scalar(), 0, ) else: eq_( - testing.db.scalar( - select(func.count("*")).select_from(addresses) - ), + sess.connection() + .execute(select(func.count("*")).select_from(addresses)) + .scalar(), 0, ) diff --git a/test/requirements.py b/test/requirements.py index 86ae3d4b5b..6786c8dafc 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -1660,15 +1660,17 @@ class DefaultRequirements(SuiteRequirements): ) def _has_mysql_on_windows(self, config): - return ( - against(config, ["mysql", "mariadb"]) - ) and config.db.dialect._detect_casing(config.db) == 1 + with config.db.connect() as conn: + return ( + against(config, ["mysql", "mariadb"]) + ) and config.db.dialect._detect_casing(conn) == 1 def _has_mysql_fully_case_sensitive(self, config): - return ( - against(config, "mysql") - and config.db.dialect._detect_casing(config.db) == 0 - ) + with config.db.connect() as conn: + return ( + against(config, "mysql") + and config.db.dialect._detect_casing(conn) == 0 + ) @property def postgresql_utf8_server_encoding(self): -- 2.47.2