]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Clean up .execute in test/sql/test_functions.py
authorGord Thompson <gord@gordthompson.com>
Tue, 14 Apr 2020 13:15:07 +0000 (07:15 -0600)
committerGord Thompson <gord@gordthompson.com>
Mon, 20 Apr 2020 13:46:24 +0000 (07:46 -0600)
Change-Id: I2bc7a50893f90c6ea7e119a8558731ee32965871

test/dialect/mysql/test_types.py
test/sql/test_functions.py
test/sql/test_metadata.py

index 74f743852633f691e14de0e1e5d5bff04a17f3b8..009d1b26e8f5afb099e39d1298e23e2bab6e21a6 100644 (file)
@@ -1213,7 +1213,7 @@ class EnumSetTest(
         t2 = Table("table", m2, autoload=True)
 
         # TODO: what's wrong with the last element ?  is there
-        # latin-1 stuff forcing its way in ?
+        #       latin-1 stuff forcing its way in ?
 
         eq_(
             t2.c.value.type.enums[0:2], [u("réveillé"), u("drôle")]
index 5a6e6252b4c92bd23d6510f17ca6a8e2ae327cff..317c4677a80d96c1d8c99bcaea2105503518e409 100644 (file)
@@ -13,7 +13,6 @@ from sqlalchemy import func
 from sqlalchemy import Integer
 from sqlalchemy import literal
 from sqlalchemy import literal_column
-from sqlalchemy import MetaData
 from sqlalchemy import Numeric
 from sqlalchemy import select
 from sqlalchemy import Sequence
@@ -37,7 +36,6 @@ from sqlalchemy.sql.functions import GenericFunction
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
-from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
@@ -945,11 +943,10 @@ class ReturnTypeTest(AssertsCompiledSQL, fixtures.TestBase):
 class ExecuteTest(fixtures.TestBase):
     __backend__ = True
 
-    @engines.close_first
     def tearDown(self):
         pass
 
-    def test_conn_execute(self):
+    def test_conn_execute(self, connection):
         from sqlalchemy.sql.expression import FunctionElement
         from sqlalchemy.ext.compiler import compiles
 
@@ -960,17 +957,14 @@ class ExecuteTest(fixtures.TestBase):
         def compile_(elem, compiler, **kw):
             return compiler.process(func.current_date())
 
-        conn = testing.db.connect()
-        try:
-            x = conn.execute(func.current_date()).scalar()
-            y = conn.execute(func.current_date().select()).scalar()
-            z = conn.scalar(func.current_date())
-            q = conn.scalar(myfunc())
-        finally:
-            conn.close()
+        x = connection.execute(func.current_date()).scalar()
+        y = connection.execute(func.current_date().select()).scalar()
+        z = connection.scalar(func.current_date())
+        q = connection.scalar(myfunc())
+
         assert (x == y == z == q) is True
 
-    def test_exec_options(self):
+    def test_exec_options(self, connection):
         f = func.foo()
         eq_(f._execution_options, {})
 
@@ -979,13 +973,12 @@ class ExecuteTest(fixtures.TestBase):
         s = f.select()
         eq_(s._execution_options, {"foo": "bar"})
 
-        ret = testing.db.execute(func.now().execution_options(foo="bar"))
+        ret = connection.execute(func.now().execution_options(foo="bar"))
         eq_(ret.context.execution_options, {"foo": "bar"})
         ret.close()
 
-    @engines.close_first
     @testing.provide_metadata
-    def test_update(self):
+    def test_update(self, connection):
         """
         Tests sending functions and SQL expressions to the VALUES and SET
         clauses of INSERT/UPDATE instances, and that column-level defaults
@@ -1016,71 +1009,86 @@ class ExecuteTest(fixtures.TestBase):
             Column("value", Integer, default=7),
             Column("stuff", String(20), onupdate="thisisstuff"),
         )
-        meta.create_all()
-        t.insert(values=dict(value=func.length("one"))).execute()
-        assert t.select().execute().first().value == 3
-        t.update(values=dict(value=func.length("asfda"))).execute()
-        assert t.select().execute().first().value == 5
-
-        r = t.insert(values=dict(value=func.length("sfsaafsda"))).execute()
+        meta.create_all(connection)
+        connection.execute(t.insert(values=dict(value=func.length("one"))))
+        eq_(connection.execute(t.select()).first().value, 3)
+        connection.execute(t.update(values=dict(value=func.length("asfda"))))
+        eq_(connection.execute(t.select()).first().value, 5)
+
+        r = connection.execute(
+            t.insert(values=dict(value=func.length("sfsaafsda")))
+        )
         id_ = r.inserted_primary_key[0]
-        assert t.select(t.c.id == id_).execute().first().value == 9
-        t.update(values={t.c.value: func.length("asdf")}).execute()
-        assert t.select().execute().first().value == 4
-        t2.insert().execute()
-        t2.insert(values=dict(value=func.length("one"))).execute()
-        t2.insert(values=dict(value=func.length("asfda") + -19)).execute(
-            stuff="hi"
+        eq_(connection.execute(t.select(t.c.id == id_)).first().value, 9)
+        connection.execute(t.update(values={t.c.value: func.length("asdf")}))
+        eq_(connection.execute(t.select()).first().value, 4)
+        connection.execute(t2.insert())
+        connection.execute(t2.insert(values=dict(value=func.length("one"))))
+        connection.execute(
+            t2.insert(values=dict(value=func.length("asfda") + -19)),
+            stuff="hi",
         )
 
-        res = exec_sorted(select([t2.c.value, t2.c.stuff]))
+        res = sorted(connection.execute(select([t2.c.value, t2.c.stuff])))
         eq_(res, [(-14, "hi"), (3, None), (7, None)])
 
-        t2.update(values=dict(value=func.length("asdsafasd"))).execute(
-            stuff="some stuff"
+        connection.execute(
+            t2.update(values=dict(value=func.length("asdsafasd"))),
+            stuff="some stuff",
+        )
+        eq_(
+            connection.execute(select([t2.c.value, t2.c.stuff])).fetchall(),
+            [(9, "some stuff"), (9, "some stuff"), (9, "some stuff")],
         )
-        assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == [
-            (9, "some stuff"),
-            (9, "some stuff"),
-            (9, "some stuff"),
-        ]
 
-        t2.delete().execute()
+        connection.execute(t2.delete())
 
-        t2.insert(values=dict(value=func.length("one") + 8)).execute()
-        assert t2.select().execute().first().value == 11
+        connection.execute(
+            t2.insert(values=dict(value=func.length("one") + 8))
+        )
+        eq_(connection.execute(t2.select()).first().value, 11)
 
-        t2.update(values=dict(value=func.length("asfda"))).execute()
+        connection.execute(t2.update(values=dict(value=func.length("asfda"))))
         eq_(
-            select([t2.c.value, t2.c.stuff]).execute().first(),
+            connection.execute(select([t2.c.value, t2.c.stuff])).first(),
             (5, "thisisstuff"),
         )
 
-        t2.update(
-            values={t2.c.value: func.length("asfdaasdf"), t2.c.stuff: "foo"}
-        ).execute()
+        connection.execute(
+            t2.update(
+                values={
+                    t2.c.value: func.length("asfdaasdf"),
+                    t2.c.stuff: "foo",
+                }
+            )
+        )
 
-        eq_(select([t2.c.value, t2.c.stuff]).execute().first(), (9, "foo"))
+        eq_(
+            connection.execute(select([t2.c.value, t2.c.stuff])).first(),
+            (9, "foo"),
+        )
 
     @testing.fails_on_everything_except("postgresql")
-    def test_as_from(self):
+    def test_as_from(self, connection):
         # TODO: shouldn't this work on oracle too ?
-        x = func.current_date(bind=testing.db).execute().scalar()
-        y = func.current_date(bind=testing.db).select().execute().scalar()
-        z = func.current_date(bind=testing.db).scalar()
-        w = select(
-            ["*"], from_obj=[func.current_date(bind=testing.db)]
+        x = connection.execute(func.current_date(bind=testing.db)).scalar()
+        y = connection.execute(
+            func.current_date(bind=testing.db).select()
         ).scalar()
+        z = connection.scalar(func.current_date(bind=testing.db))
+        w = connection.scalar(
+            select(["*"], from_obj=[func.current_date(bind=testing.db)])
+        )
 
         assert x == y == z == w
 
-    def test_extract_bind(self):
+    def test_extract_bind(self, connection):
         """Basic common denominator execution tests for extract()"""
 
         date = datetime.date(2010, 5, 1)
 
         def execute(field):
-            return testing.db.execute(select([extract(field, date)])).scalar()
+            return connection.execute(select([extract(field, date)])).scalar()
 
         assert execute("year") == 2010
         assert execute("month") == 5
@@ -1092,34 +1100,25 @@ class ExecuteTest(fixtures.TestBase):
         assert execute("month") == 5
         assert execute("day") == 1
 
-    def test_extract_expression(self):
-        meta = MetaData(testing.db)
+    @testing.provide_metadata
+    def test_extract_expression(self, connection):
+        meta = self.metadata
         table = Table("test", meta, Column("dt", DateTime), Column("d", Date))
-        meta.create_all()
-        try:
-            table.insert().execute(
-                {
-                    "dt": datetime.datetime(2010, 5, 1, 12, 11, 10),
-                    "d": datetime.date(2010, 5, 1),
-                }
-            )
-            rs = select(
-                [extract("year", table.c.dt), extract("month", table.c.d)]
-            ).execute()
-            row = rs.first()
-            assert row[0] == 2010
-            assert row[1] == 5
-            rs.close()
-        finally:
-            meta.drop_all()
-
-
-def exec_sorted(statement, *args, **kw):
-    """Executes a statement and returns a sorted list plain tuple rows."""
-
-    return sorted(
-        [tuple(row) for row in statement.execute(*args, **kw).fetchall()]
-    )
+        meta.create_all(connection)
+        connection.execute(
+            table.insert(),
+            {
+                "dt": datetime.datetime(2010, 5, 1, 12, 11, 10),
+                "d": datetime.date(2010, 5, 1),
+            },
+        )
+        rs = connection.execute(
+            select([extract("year", table.c.dt), extract("month", table.c.d)])
+        )
+        row = rs.first()
+        assert row[0] == 2010
+        assert row[1] == 5
+        rs.close()
 
 
 class RegisterTest(fixtures.TestBase, AssertsCompiledSQL):
index fac369bb4b3e862e1c46364fd9e6a88dbb0f3e01..c57932bede424756a6c1967ac199d1d59d8b44cb 100644 (file)
@@ -670,7 +670,7 @@ class MetaDataTest(fixtures.TestBase, ComparesTables):
 class ToMetaDataTest(fixtures.TestBase, ComparesTables):
     @testing.requires.check_constraints
     def test_copy(self):
-        # TODO: modernize this test
+        # TODO: modernize this test for 2.0
 
         from sqlalchemy.testing.schema import Table