]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Replace engine.execute w/ context manager (step1)
authorGord Thompson <gord@gordthompson.com>
Thu, 13 Feb 2020 19:14:42 +0000 (12:14 -0700)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 6 Apr 2020 20:08:42 +0000 (16:08 -0400)
First (baby) step at replacing engine.execute
calls in test code with the new preferred way
of executing. MSSQL was targeted because it was
the easiest for me to test locally.

Change-Id: Id2e02f0e39007cbfd28ca6a535115f53c6407015
(cherry picked from commit 60f627cbd0d769e65353e720548efac9d8ab95d9)

lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/testing/fixtures.py
lib/sqlalchemy/testing/suite/test_insert.py
lib/sqlalchemy/testing/suite/test_results.py
lib/sqlalchemy/testing/suite/test_types.py
test/dialect/mssql/test_query.py
test/dialect/mssql/test_types.py

index 6ada7356dc0df95ec98eb4d79d5f395759c11f5c..b03300a6ff30c74cc701edd1eebb29a551bc22f9 100644 (file)
@@ -182,7 +182,8 @@ execution.  Given this example::
                     Column('x', Integer))
     m.create_all(engine)
 
-    engine.execute(t.insert(), {'id': 1, 'x':1}, {'id':2, 'x':2})
+    with engine.begin() as conn:
+        conn.execute(t.insert(), {'id': 1, 'x':1}, {'id':2, 'x':2})
 
 The above column will be created with IDENTITY, however the INSERT statement
 we emit is specifying explicit values.  In the echo output we can see
index 946b885d734326d283f364039d922b1c2cb806c0..4f608340195ff21b1318ef5013bc3f363260e219 100644 (file)
@@ -57,6 +57,32 @@ class TestBase(object):
         if hasattr(self, "tearDown"):
             self.tearDown()
 
+    @config.fixture()
+    def connection(self):
+        conn = config.db.connect()
+        trans = conn.begin()
+        try:
+            yield conn
+        finally:
+            trans.rollback()
+            conn.close()
+
+    # propose a replacement for @testing.provide_metadata.
+    # the problem with this is that TablesTest below has a ".metadata"
+    # attribute already which is accessed directly as part of the
+    # @testing.provide_metadata pattern.  Might need to call this _metadata
+    # for it to be useful.
+    # @config.fixture()
+    # def metadata(self):
+    #    """Provide bound MetaData for a single test, dropping afterwards."""
+    #
+    #    from . import engines
+    #    metadata = schema.MetaData(config.db)
+    #    try:
+    #        yield metadata
+    #    finally:
+    #       engines.drop_all_tables(metadata, config.db)
+
 
 class TablesTest(TestBase):
 
index fc535aa23d110bbc15df97f97b40a4db3130401a..c09f5586e35ed368d664d196f33f500ca834a540 100644 (file)
@@ -113,7 +113,8 @@ class InsertBehaviorTest(fixtures.TablesTest):
         else:
             engine = config.db
 
-        r = engine.execute(self.tables.autoinc_pk.insert(), data="some data")
+        with engine.begin() as conn:
+            r = conn.execute(self.tables.autoinc_pk.insert(), data="some data")
         assert r._soft_closed
         assert not r.closed
         assert r.is_insert
@@ -282,9 +283,10 @@ class ReturningTest(fixtures.TablesTest):
     def test_explicit_returning_pk_autocommit(self):
         engine = config.db
         table = self.tables.autoinc_pk
-        r = engine.execute(
-            table.insert().returning(table.c.id), data="some data"
-        )
+        with engine.begin() as conn:
+            r = conn.execute(
+                table.insert().returning(table.c.id), data="some data"
+            )
         pk = r.first()[0]
         fetched_pk = config.db.scalar(select([table.c.id]))
         eq_(fetched_pk, pk)
index 67b0fd70396adf110484aba37964e54952f58bf3..125fefce98cad3cf7ba9541e8d91b5452821eddf 100644 (file)
@@ -225,7 +225,7 @@ class ServerSideCursorsTest(
 
     def _is_server_side(self, cursor):
         if self.engine.dialect.driver == "psycopg2":
-            return cursor.name
+            return bool(cursor.name)
         elif self.engine.dialect.driver == "pymysql":
             sscursor = __import__("pymysql.cursors").cursors.SSCursor
             return isinstance(cursor, sscursor)
@@ -245,43 +245,48 @@ class ServerSideCursorsTest(
         engines.testing_reaper.close_all()
         self.engine.dispose()
 
-    def test_global_string(self):
-        engine = self._fixture(True)
-        result = engine.execute("select 1")
-        assert self._is_server_side(result.cursor)
-
-    def test_global_text(self):
-        engine = self._fixture(True)
-        result = engine.execute(text("select 1"))
-        assert self._is_server_side(result.cursor)
-
-    def test_global_expr(self):
-        engine = self._fixture(True)
-        result = engine.execute(select([1]))
-        assert self._is_server_side(result.cursor)
-
-    def test_global_off_explicit(self):
-        engine = self._fixture(False)
-        result = engine.execute(text("select 1"))
-
-        # It should be off globally ...
-
-        assert not self._is_server_side(result.cursor)
-
-    def test_stmt_option(self):
-        engine = self._fixture(False)
-
-        s = select([1]).execution_options(stream_results=True)
-        result = engine.execute(s)
-
-        # ... but enabled for this one.
-
-        assert self._is_server_side(result.cursor)
+    @testing.combinations(
+        ("global_string", True, "select 1", True),
+        ("global_text", True, text("select 1"), True),
+        ("global_expr", True, select([1]), True),
+        ("global_off_explicit", False, text("select 1"), False),
+        (
+            "stmt_option",
+            False,
+            select([1]).execution_options(stream_results=True),
+            True,
+        ),
+        (
+            "stmt_option_disabled",
+            True,
+            select([1]).execution_options(stream_results=False),
+            False,
+        ),
+        ("for_update_expr", True, select([1]).with_for_update(), True),
+        ("for_update_string", True, "SELECT 1 FOR UPDATE", True),
+        ("text_no_ss", False, text("select 42"), False),
+        (
+            "text_ss_option",
+            False,
+            text("select 42").execution_options(stream_results=True),
+            True,
+        ),
+        id_="iaaa",
+        argnames="engine_ss_arg, statement, cursor_ss_status",
+    )
+    def test_ss_cursor_status(
+        self, engine_ss_arg, statement, cursor_ss_status
+    ):
+        engine = self._fixture(engine_ss_arg)
+        with engine.begin() as conn:
+            result = conn.execute(statement)
+            eq_(self._is_server_side(result.cursor), cursor_ss_status)
+            result.close()
 
     def test_conn_option(self):
         engine = self._fixture(False)
 
-        # and this one
+        # should be enabled for this one
         result = (
             engine.connect()
             .execution_options(stream_results=True)
@@ -300,46 +305,21 @@ class ServerSideCursorsTest(
         )
         assert not self._is_server_side(result.cursor)
 
-    def test_stmt_option_disabled(self):
-        engine = self._fixture(True)
-        s = select([1]).execution_options(stream_results=False)
-        result = engine.execute(s)
-        assert not self._is_server_side(result.cursor)
-
     def test_aliases_and_ss(self):
         engine = self._fixture(False)
         s1 = select([1]).execution_options(stream_results=True).alias()
-        result = engine.execute(s1)
-        assert self._is_server_side(result.cursor)
+        with engine.begin() as conn:
+            result = conn.execute(s1)
+            assert self._is_server_side(result.cursor)
+            result.close()
 
         # s1's options shouldn't affect s2 when s2 is used as a
         # from_obj.
         s2 = select([1], from_obj=s1)
-        result = engine.execute(s2)
-        assert not self._is_server_side(result.cursor)
-
-    def test_for_update_expr(self):
-        engine = self._fixture(True)
-        s1 = select([1]).with_for_update()
-        result = engine.execute(s1)
-        assert self._is_server_side(result.cursor)
-
-    def test_for_update_string(self):
-        engine = self._fixture(True)
-        result = engine.execute("SELECT 1 FOR UPDATE")
-        assert self._is_server_side(result.cursor)
-
-    def test_text_no_ss(self):
-        engine = self._fixture(False)
-        s = text("select 42")
-        result = engine.execute(s)
-        assert not self._is_server_side(result.cursor)
-
-    def test_text_ss_option(self):
-        engine = self._fixture(False)
-        s = text("select 42").execution_options(stream_results=True)
-        result = engine.execute(s)
-        assert self._is_server_side(result.cursor)
+        with engine.begin() as conn:
+            result = conn.execute(s2)
+            assert not self._is_server_side(result.cursor)
+            result.close()
 
     @testing.provide_metadata
     def test_roundtrip(self):
index 2a5dad2d6c016519e3c10971de70cae9f8b92073..a334b8ebcdd83dbff578d72b5dc3c9bb989cbefe 100644 (file)
@@ -871,54 +871,51 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
 
         # support sqlite :memory: database...
         data_table.create(engine, checkfirst=True)
-        engine.execute(
-            data_table.insert(), {"name": "row1", "data": data_element}
-        )
-
-        row = engine.execute(select([data_table.c.data])).first()
+        with engine.connect() as conn:
+            conn.execute(
+                data_table.insert(), {"name": "row1", "data": data_element}
+            )
+            row = conn.execute(select([data_table.c.data])).first()
 
-        eq_(row, (data_element,))
-        eq_(js.mock_calls, [mock.call(data_element)])
-        eq_(jd.mock_calls, [mock.call(json.dumps(data_element))])
+            eq_(row, (data_element,))
+            eq_(js.mock_calls, [mock.call(data_element)])
+            eq_(jd.mock_calls, [mock.call(json.dumps(data_element))])
 
-    def test_round_trip_none_as_sql_null(self):
+    def test_round_trip_none_as_sql_null(self, connection):
         col = self.tables.data_table.c["nulldata"]
 
-        with config.db.connect() as conn:
-            conn.execute(
-                self.tables.data_table.insert(), {"name": "r1", "data": None}
-            )
+        conn = connection
+        conn.execute(
+            self.tables.data_table.insert(), {"name": "r1", "data": None}
+        )
 
-            eq_(
-                conn.scalar(
-                    select([self.tables.data_table.c.name]).where(
-                        col.is_(null())
-                    )
-                ),
-                "r1",
-            )
+        eq_(
+            conn.scalar(
+                select([self.tables.data_table.c.name]).where(col.is_(null()))
+            ),
+            "r1",
+        )
 
-            eq_(conn.scalar(select([col])), None)
+        eq_(conn.scalar(select([col])), None)
 
-    def test_round_trip_json_null_as_json_null(self):
+    def test_round_trip_json_null_as_json_null(self, connection):
         col = self.tables.data_table.c["data"]
 
-        with config.db.connect() as conn:
-            conn.execute(
-                self.tables.data_table.insert(),
-                {"name": "r1", "data": JSON.NULL},
-            )
+        conn = connection
+        conn.execute(
+            self.tables.data_table.insert(), {"name": "r1", "data": JSON.NULL},
+        )
 
-            eq_(
-                conn.scalar(
-                    select([self.tables.data_table.c.name]).where(
-                        cast(col, String) == "null"
-                    )
-                ),
-                "r1",
-            )
+        eq_(
+            conn.scalar(
+                select([self.tables.data_table.c.name]).where(
+                    cast(col, String) == "null"
+                )
+            ),
+            "r1",
+        )
 
-            eq_(conn.scalar(select([col])), None)
+        eq_(conn.scalar(select([col])), None)
 
     def test_round_trip_none_as_json_null(self):
         col = self.tables.data_table.c["data"]
index bf836fc14741d05029759bd912f15843195b662b..bfd984ce7fba0f9dbd595a1c53f0c61042b24a0d 100644 (file)
@@ -356,7 +356,8 @@ class QueryTest(testing.AssertsExecutionResults, fixtures.TestBase):
         metadata.create_all(engine)
 
         with self.sql_execution_asserter(engine) as asserter:
-            engine.execute(t1.insert(), {"data": "somedata"})
+            with engine.begin() as conn:
+                conn.execute(t1.insert(), {"data": "somedata"})
 
         # TODO: need a dialect SQL that acts like Cursor SQL
         asserter.assert_(
@@ -379,7 +380,8 @@ class QueryTest(testing.AssertsExecutionResults, fixtures.TestBase):
         metadata.create_all(engine)
 
         with self.sql_execution_asserter(engine) as asserter:
-            engine.execute(t1.insert())
+            with engine.begin() as conn:
+                conn.execute(t1.insert())
 
         # even with pyodbc, we don't embed the scope identity on a
         # DEFAULT VALUES insert
@@ -403,7 +405,8 @@ class QueryTest(testing.AssertsExecutionResults, fixtures.TestBase):
         metadata.create_all(engine)
 
         with self.sql_execution_asserter(engine) as asserter:
-            engine.execute(t1.insert(), {"data": "somedata"})
+            with engine.begin() as conn:
+                conn.execute(t1.insert(), {"data": "somedata"})
 
         # pyodbc-specific system
         asserter.assert_(
index 92d3d9e3274e16bcffc41ac1854c9869d016e116..c95ac6e6d9ef23baa72a836ceeb1e4369a42263c 100644 (file)
@@ -1026,18 +1026,24 @@ class TypeRoundTripTest(
                 ]
 
             for counter, engine in enumerate(eng):
-                engine.execute(tbl.insert())
-                if "int_y" in tbl.c:
-                    assert engine.scalar(select([tbl.c.int_y])) == counter + 1
-                    assert (
-                        list(engine.execute(tbl.select()).first()).count(
-                            counter + 1
+                with engine.begin() as conn:
+                    conn.execute(tbl.insert())
+                    if "int_y" in tbl.c:
+                        eq_(
+                            conn.execute(select([tbl.c.int_y])).scalar(),
+                            counter + 1,
                         )
-                        == 1
-                    )
-                else:
-                    assert 1 not in list(engine.execute(tbl.select()).first())
-                engine.execute(tbl.delete())
+                        assert (
+                            list(conn.execute(tbl.select()).first()).count(
+                                counter + 1
+                            )
+                            == 1
+                        )
+                    else:
+                        assert 1 not in list(
+                            conn.execute(tbl.select()).first()
+                        )
+                    conn.execute(tbl.delete())
 
 
 class StringTest(fixtures.TestBase, AssertsCompiledSQL):