]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Modernize tests
authorGord Thompson <gord@gordthompson.com>
Thu, 24 Jun 2021 18:16:32 +0000 (12:16 -0600)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 3 Jul 2021 22:50:03 +0000 (18:50 -0400)
Eliminate engine.execute() and engine.scalar()

Change-Id: I99f76d0e615ddebab2da4fd07a40a0a2796995c7

.gitignore
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/testing/warnings.py
test/dialect/mysql/test_dialect.py
test/dialect/postgresql/test_dialect.py
test/dialect/postgresql/test_query.py
test/engine/test_pool.py
test/ext/test_horizontal_shard.py
test/orm/test_composites.py
test/orm/test_dynamic.py
test/requirements.py

index 8d9d546578d6fcb9830ca18c6a3b3d8a97a815ec..c566ded772b5df14e062752a20ae40721b1b09f1 100644 (file)
@@ -38,3 +38,4 @@ test/test_schema.db
 /querytest.db
 /.mypy_cache
 /.pytest_cache
+/db_idents.txt
index cfb6c29247dc7504e19bd044f1ba89de511594de..e39010762b876b35bdae21535a783a981463f5a4 100644 (file)
@@ -3358,10 +3358,11 @@ class MySQLDialect(default.DefaultDialect):
         # https://dev.mysql.com/doc/refman/en/identifier-case-sensitivity.html
 
         charset = self._connection_charset
+        show_var = connection.execute(
+            sql.text("SHOW VARIABLES LIKE 'lower_case_table_names'")
+        )
         row = self._compat_first(
-            connection.execute(
-                sql.text("SHOW VARIABLES LIKE 'lower_case_table_names'")
-            ),
+            show_var,
             charset=charset,
         )
         if not row:
index 30f50a44f70c0b267c5e2622eea96a9e96dcb5e2..df0e5aa5e515c4b7b7b3c2c7fa96e46ad3ab84e0 100644 (file)
@@ -68,7 +68,6 @@ def setup_filters():
         #
         # Core execution
         #
-        r"The (?:Executable|Engine)\.(?:execute|scalar)\(\) method",
         #        r".*DefaultGenerator.execute\(\)",
         #
         #
index 45d119cf3c6c95b3ba6d4c183010f177f23167e3..57dd9d393ddbf599847b73aea158fdd56f42fe35 100644 (file)
@@ -499,8 +499,8 @@ class ExecutionTest(fixtures.TestBase):
         eq_(cx.dialect._connection_charset, charset)
         cx.close()
 
-    def test_sysdate(self):
-        d = testing.db.scalar(func.sysdate())
+    def test_sysdate(self, connection):
+        d = connection.execute(func.sysdate()).scalar()
         assert isinstance(d, datetime.datetime)
 
 
index 5a53e0b7e9df7478fdc1ddf5c87f710ccecaad8f..371a17819dd8a9dbfa1951f03533ec65ed0bb2e3 100644 (file)
@@ -1117,9 +1117,9 @@ $$ LANGUAGE plpgsql;
         )
 
     def test_extract(self, connection):
-        fivedaysago = testing.db.scalar(
+        fivedaysago = connection.execute(
             select(func.now().op("at time zone")("UTC"))
-        ) - datetime.timedelta(days=5)
+        ).scalar() - datetime.timedelta(days=5)
 
         for field, exp in (
             ("year", fivedaysago.year),
index db76f61ffafd348a7aa1007f08b3c3af810fc87b..a1e9c46572933c1851b465d424f0110f3f37094f 100644 (file)
@@ -903,6 +903,10 @@ class ExtractTest(fixtures.TablesTest):
     run_inserts = "once"
     run_deletes = None
 
+    class TZ(datetime.tzinfo):
+        def utcoffset(self, dt):
+            return datetime.timedelta(hours=4)
+
     @classmethod
     def setup_bind(cls):
         from sqlalchemy import event
@@ -932,11 +936,6 @@ class ExtractTest(fixtures.TablesTest):
 
     @classmethod
     def insert_data(cls, connection):
-        # TODO: why does setting hours to anything
-        # not affect the TZ in the DB col ?
-        class TZ(datetime.tzinfo):
-            def utcoffset(self, dt):
-                return datetime.timedelta(hours=4)
 
         connection.execute(
             cls.tables.t.insert(),
@@ -946,12 +945,12 @@ class ExtractTest(fixtures.TablesTest):
                 "tm": datetime.time(12, 15, 25),
                 "intv": datetime.timedelta(seconds=570),
                 "dttz": datetime.datetime(
-                    2012, 5, 10, 12, 15, 25, tzinfo=TZ()
+                    2012, 5, 10, 12, 15, 25, tzinfo=cls.TZ()
                 ),
             },
         )
 
-    def _test(self, expr, field="all", overrides=None):
+    def _test(self, connection, expr, field="all", overrides=None):
         t = self.tables.t
 
         if field == "all":
@@ -983,29 +982,31 @@ class ExtractTest(fixtures.TablesTest):
             fields.update(overrides)
 
         for field in fields:
-            result = self.bind.scalar(
+            result = connection.execute(
                 select(extract(field, expr)).select_from(t)
-            )
+            ).scalar()
             eq_(result, fields[field])
 
-    def test_one(self):
+    def test_one(self, connection):
         t = self.tables.t
-        self._test(t.c.dtme, "all")
+        self._test(connection, t.c.dtme, "all")
 
-    def test_two(self):
+    def test_two(self, connection):
         t = self.tables.t
         self._test(
+            connection,
             t.c.dtme + t.c.intv,
             overrides={"epoch": 1336652695.0, "minute": 24},
         )
 
-    def test_three(self):
+    def test_three(self, connection):
         self.tables.t
 
-        actual_ts = self.bind.scalar(
+        actual_ts = self.bind.connect().execute(
             func.current_timestamp()
-        ) - datetime.timedelta(days=5)
+        ).scalar() - datetime.timedelta(days=5)
         self._test(
+            connection,
             func.current_timestamp() - datetime.timedelta(days=5),
             {
                 "hour": actual_ts.hour,
@@ -1014,9 +1015,10 @@ class ExtractTest(fixtures.TablesTest):
             },
         )
 
-    def test_four(self):
+    def test_four(self, connection):
         t = self.tables.t
         self._test(
+            connection,
             datetime.timedelta(days=5) + t.c.dt,
             overrides={
                 "day": 15,
@@ -1026,23 +1028,26 @@ class ExtractTest(fixtures.TablesTest):
             },
         )
 
-    def test_five(self):
+    def test_five(self, connection):
         t = self.tables.t
         self._test(
+            connection,
             func.coalesce(t.c.dtme, func.current_timestamp()),
             overrides={"epoch": 1336652125.0},
         )
 
-    def test_six(self):
+    def test_six(self, connection):
         t = self.tables.t
         self._test(
+            connection,
             t.c.tm + datetime.timedelta(seconds=30),
             "time",
             overrides={"second": 55},
         )
 
-    def test_seven(self):
+    def test_seven(self, connection):
         self._test(
+            connection,
             literal(datetime.timedelta(seconds=10))
             - literal(datetime.timedelta(seconds=10)),
             "all",
@@ -1056,49 +1061,53 @@ class ExtractTest(fixtures.TablesTest):
             },
         )
 
-    def test_eight(self):
+    def test_eight(self, connection):
         t = self.tables.t
         self._test(
+            connection,
             t.c.tm + datetime.timedelta(seconds=30),
             {"hour": 12, "minute": 15, "second": 55},
         )
 
-    def test_nine(self):
-        self._test(text("t.dt + t.tm"))
+    def test_nine(self, connection):
+        self._test(connection, text("t.dt + t.tm"))
 
-    def test_ten(self):
+    def test_ten(self, connection):
         t = self.tables.t
-        self._test(t.c.dt + t.c.tm)
+        self._test(connection, t.c.dt + t.c.tm)
 
-    def test_eleven(self):
+    def test_eleven(self, connection):
         self._test(
+            connection,
             func.current_timestamp() - func.current_timestamp(),
             {"year": 0, "month": 0, "day": 0, "hour": 0},
         )
 
-    def test_twelve(self):
+    def test_twelve(self, connection):
         t = self.tables.t
-        actual_ts = self.bind.scalar(func.current_timestamp()).replace(
-            tzinfo=None
-        ) - datetime.datetime(2012, 5, 10, 12, 15, 25)
 
-        self._test(
+        actual_ts = connection.scalar(
             func.current_timestamp()
-            - func.coalesce(t.c.dtme, func.current_timestamp()),
+        ) - datetime.datetime(2012, 5, 10, 12, 15, 25, tzinfo=self.TZ())
+
+        self._test(
+            connection,
+            func.current_timestamp() - t.c.dttz,
             {"day": actual_ts.days},
         )
 
-    def test_thirteen(self):
+    def test_thirteen(self, connection):
         t = self.tables.t
-        self._test(t.c.dttz, "all+tz")
+        self._test(connection, t.c.dttz, "all+tz")
 
-    def test_fourteen(self):
+    def test_fourteen(self, connection):
         t = self.tables.t
-        self._test(t.c.tm, "time")
+        self._test(connection, t.c.tm, "time")
 
-    def test_fifteen(self):
+    def test_fifteen(self, connection):
         t = self.tables.t
         self._test(
+            connection,
             datetime.timedelta(days=5) + t.c.dtme,
             overrides={"day": 15, "epoch": 1337084125.0},
         )
index 70671134f1f588853022851892fca7b4ed2a1c0b..fa5bd8cffb3685b392f7ac2e06e0b66598945cfa 100644 (file)
@@ -738,7 +738,8 @@ class PoolEventsTest(PoolTestBase):
         event.listen(engine, "connect", listen_three)
         event.listen(engine.__class__, "connect", listen_four)
 
-        engine.execute(select(1)).close()
+        with engine.connect() as conn:
+            conn.execute(select(1))
         eq_(
             canary, ["listen_one", "listen_four", "listen_two", "listen_three"]
         )
index bb06d9648ba905ca7badaea2e47b9d7c445bc3a7..76e94918a8bfc5567f5042a7cb7b094c1dd6ecf1 100644 (file)
@@ -233,11 +233,11 @@ class ShardTest(object):
         self._fixture_data()
         # not sure what this is testing except the fixture data itself
         eq_(
-            db2.execute(weather_locations.select()).fetchall(),
+            db2.connect().execute(weather_locations.select()).fetchall(),
             [(1, "Asia", "Tokyo")],
         )
         eq_(
-            db1.execute(weather_locations.select()).fetchall(),
+            db1.connect().execute(weather_locations.select()).fetchall(),
             [
                 (2, "North America", "New York"),
                 (3, "North America", "Toronto"),
index 6ee87eefe76514f34ef88b6b30fee71388e747f7..5fb7cf50ffb495a4f8e27b5efd4dac4794dc15f2 100644 (file)
@@ -775,7 +775,7 @@ class MappedSelectTest(fixtures.MappedTest):
             },
         )
 
-    def test_set_composite_attrs_via_selectable(self):
+    def test_set_composite_attrs_via_selectable(self, connection):
         Values, CustomValues, values, Descriptions, descriptions = (
             self.classes.Values,
             self.classes.CustomValues,
@@ -796,11 +796,11 @@ class MappedSelectTest(fixtures.MappedTest):
         session.add(d)
         session.commit()
         eq_(
-            testing.db.execute(descriptions.select()).fetchall(),
+            connection.execute(descriptions.select()).fetchall(),
             [(1, "Color", "Number")],
         )
         eq_(
-            testing.db.execute(values.select()).fetchall(),
+            connection.execute(values.select()).fetchall(),
             [(1, 1, "Red", "5"), (2, 1, "Blue", "1")],
         )
 
index b4ab4047d5f50513b610570548ee80a5c8c8c386..4a936dd3e524bd287cc110738b5382cbdef11645 100644 (file)
@@ -923,15 +923,19 @@ class UOWTest(
         sess.add(u)
         sess.commit()
         eq_(
-            testing.db.scalar(
+            sess.connection()
+            .execute(
                 select(func.count("*")).where(addresses.c.user_id == None)
-            ),  # noqa
+            )
+            .scalar(),  # noqa
             0,
         )
         eq_(
-            testing.db.scalar(
+            sess.connection()
+            .execute(
                 select(func.count("*")).where(addresses.c.user_id != None)
-            ),  # noqa
+            )
+            .scalar(),  # noqa
             6,
         )
 
@@ -941,26 +945,30 @@ class UOWTest(
 
         if expected:
             eq_(
-                testing.db.scalar(
+                sess.connection()
+                .execute(
                     select(func.count("*")).where(
                         addresses.c.user_id == None
                     )  # noqa
-                ),
+                )
+                .scalar(),
                 6,
             )
             eq_(
-                testing.db.scalar(
+                sess.connection()
+                .execute(
                     select(func.count("*")).where(
                         addresses.c.user_id != None
                     )  # noqa
-                ),
+                )
+                .scalar(),
                 0,
             )
         else:
             eq_(
-                testing.db.scalar(
-                    select(func.count("*")).select_from(addresses)
-                ),
+                sess.connection()
+                .execute(select(func.count("*")).select_from(addresses))
+                .scalar(),
                 0,
             )
 
index 86ae3d4b5b4e416a82b8f744f3abc136a9c1d613..6786c8dafcd9b44953950beda9e642552f5e9618 100644 (file)
@@ -1660,15 +1660,17 @@ class DefaultRequirements(SuiteRequirements):
         )
 
     def _has_mysql_on_windows(self, config):
-        return (
-            against(config, ["mysql", "mariadb"])
-        ) and config.db.dialect._detect_casing(config.db) == 1
+        with config.db.connect() as conn:
+            return (
+                against(config, ["mysql", "mariadb"])
+            ) and config.db.dialect._detect_casing(conn) == 1
 
     def _has_mysql_fully_case_sensitive(self, config):
-        return (
-            against(config, "mysql")
-            and config.db.dialect._detect_casing(config.db) == 0
-        )
+        with config.db.connect() as conn:
+            return (
+                against(config, "mysql")
+                and config.db.dialect._detect_casing(conn) == 0
+            )
 
     @property
     def postgresql_utf8_server_encoding(self):