]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Replace engine.execute w/ context manager (step1)
authorGord Thompson <gord@gordthompson.com>
Thu, 13 Feb 2020 19:14:42 +0000 (12:14 -0700)
committerGord Thompson <gord@gordthompson.com>
Mon, 17 Feb 2020 17:15:12 +0000 (10:15 -0700)
First (baby) step at replacing engine.execute
calls in test code with the new preferred way
of executing. MSSQL was targeted because it was
the easiest for me to test locally.

Change-Id: Id2e02f0e39007cbfd28ca6a535115f53c6407015

lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/testing/fixtures.py
lib/sqlalchemy/testing/suite/test_insert.py
lib/sqlalchemy/testing/suite/test_results.py
lib/sqlalchemy/testing/suite/test_types.py
test/dialect/mssql/test_query.py
test/dialect/mssql/test_types.py

index f900441a2d4962fb03fd0acffc765f36bc706b31..01e07729dabf2245a28e69e54a014f55c691d2ec 100644 (file)
@@ -182,7 +182,8 @@ execution.  Given this example::
                     Column('x', Integer))
     m.create_all(engine)
 
-    engine.execute(t.insert(), {'id': 1, 'x':1}, {'id':2, 'x':2})
+    with engine.begin() as conn:
+        conn.execute(t.insert(), {'id': 1, 'x':1}, {'id':2, 'x':2})
 
 The above column will be created with IDENTITY, however the INSERT statement
 we emit is specifying explicit values.  In the echo output we can see
index 62bf9fc1fa587d878056dd051eea2ad9089c17dd..bae0cee89d258582354acd3833bd1c57d1b3a04a 100644 (file)
@@ -56,6 +56,32 @@ class TestBase(object):
         if hasattr(self, "tearDown"):
             self.tearDown()
 
+    @config.fixture()
+    def connection(self):
+        conn = config.db.connect()
+        trans = conn.begin()
+        try:
+            yield conn
+        finally:
+            trans.rollback()
+            conn.close()
+
+    # propose a replacement for @testing.provide_metadata.
+    # the problem with this is that TablesTest below has a ".metadata"
+    # attribute already which is accessed directly as part of the
+    # @testing.provide_metadata pattern.  Might need to call this _metadata
+    # for it to be useful.
+    # @config.fixture()
+    # def metadata(self):
+    #    """Provide bound MetaData for a single test, dropping afterwards."""
+    #
+    #    from . import engines
+    #    metadata = schema.MetaData(config.db)
+    #    try:
+    #        yield metadata
+    #    finally:
+    #       engines.drop_all_tables(metadata, config.db)
+
 
 class TablesTest(TestBase):
 
index 2cc8761b8f3fe1cac7cbc289ef3f43627a44d73f..931b0ef651d71ac90bedfe935a3d0a14029cf9a8 100644 (file)
@@ -109,7 +109,8 @@ class InsertBehaviorTest(fixtures.TablesTest):
         else:
             engine = config.db
 
-        r = engine.execute(self.tables.autoinc_pk.insert(), data="some data")
+        with engine.begin() as conn:
+            r = conn.execute(self.tables.autoinc_pk.insert(), data="some data")
         assert r._soft_closed
         assert not r.closed
         assert r.is_insert
@@ -278,9 +279,10 @@ class ReturningTest(fixtures.TablesTest):
     def test_explicit_returning_pk_autocommit(self):
         engine = config.db
         table = self.tables.autoinc_pk
-        r = engine.execute(
-            table.insert().returning(table.c.id), data="some data"
-        )
+        with engine.begin() as conn:
+            r = conn.execute(
+                table.insert().returning(table.c.id), data="some data"
+            )
         pk = r.first()[0]
         fetched_pk = config.db.scalar(select([table.c.id]))
         eq_(fetched_pk, pk)
index 4fc0bb79ddd46199b858eccef9a8cc9d84185e8b..d77d13efac6bc02a7ca3b381ab277e0854ac8dcd 100644 (file)
@@ -225,7 +225,7 @@ class ServerSideCursorsTest(
 
     def _is_server_side(self, cursor):
         if self.engine.dialect.driver == "psycopg2":
-            return cursor.name
+            return bool(cursor.name)
         elif self.engine.dialect.driver == "pymysql":
             sscursor = __import__("pymysql.cursors").cursors.SSCursor
             return isinstance(cursor, sscursor)
@@ -245,43 +245,48 @@ class ServerSideCursorsTest(
         engines.testing_reaper.close_all()
         self.engine.dispose()
 
-    def test_global_string(self):
-        engine = self._fixture(True)
-        result = engine.execute("select 1")
-        assert self._is_server_side(result.cursor)
-
-    def test_global_text(self):
-        engine = self._fixture(True)
-        result = engine.execute(text("select 1"))
-        assert self._is_server_side(result.cursor)
-
-    def test_global_expr(self):
-        engine = self._fixture(True)
-        result = engine.execute(select([1]))
-        assert self._is_server_side(result.cursor)
-
-    def test_global_off_explicit(self):
-        engine = self._fixture(False)
-        result = engine.execute(text("select 1"))
-
-        # It should be off globally ...
-
-        assert not self._is_server_side(result.cursor)
-
-    def test_stmt_option(self):
-        engine = self._fixture(False)
-
-        s = select([1]).execution_options(stream_results=True)
-        result = engine.execute(s)
-
-        # ... but enabled for this one.
-
-        assert self._is_server_side(result.cursor)
+    @testing.combinations(
+        ("global_string", True, "select 1", True),
+        ("global_text", True, text("select 1"), True),
+        ("global_expr", True, select([1]), True),
+        ("global_off_explicit", False, text("select 1"), False),
+        (
+            "stmt_option",
+            False,
+            select([1]).execution_options(stream_results=True),
+            True,
+        ),
+        (
+            "stmt_option_disabled",
+            True,
+            select([1]).execution_options(stream_results=False),
+            False,
+        ),
+        ("for_update_expr", True, select([1]).with_for_update(), True),
+        ("for_update_string", True, "SELECT 1 FOR UPDATE", True),
+        ("text_no_ss", False, text("select 42"), False),
+        (
+            "text_ss_option",
+            False,
+            text("select 42").execution_options(stream_results=True),
+            True,
+        ),
+        id_="iaaa",
+        argnames="engine_ss_arg, statement, cursor_ss_status",
+    )
+    def test_ss_cursor_status(
+        self, engine_ss_arg, statement, cursor_ss_status
+    ):
+        engine = self._fixture(engine_ss_arg)
+        with engine.begin() as conn:
+            result = conn.execute(statement)
+            eq_(self._is_server_side(result.cursor), cursor_ss_status)
+            result.close()
 
     def test_conn_option(self):
         engine = self._fixture(False)
 
-        # and this one
+        # should be enabled for this one
         result = (
             engine.connect()
             .execution_options(stream_results=True)
@@ -300,46 +305,21 @@ class ServerSideCursorsTest(
         )
         assert not self._is_server_side(result.cursor)
 
-    def test_stmt_option_disabled(self):
-        engine = self._fixture(True)
-        s = select([1]).execution_options(stream_results=False)
-        result = engine.execute(s)
-        assert not self._is_server_side(result.cursor)
-
     def test_aliases_and_ss(self):
         engine = self._fixture(False)
         s1 = select([1]).execution_options(stream_results=True).alias()
-        result = engine.execute(s1)
-        assert self._is_server_side(result.cursor)
+        with engine.begin() as conn:
+            result = conn.execute(s1)
+            assert self._is_server_side(result.cursor)
+            result.close()
 
         # s1's options shouldn't affect s2 when s2 is used as a
         # from_obj.
         s2 = select([1], from_obj=s1)
-        result = engine.execute(s2)
-        assert not self._is_server_side(result.cursor)
-
-    def test_for_update_expr(self):
-        engine = self._fixture(True)
-        s1 = select([1]).with_for_update()
-        result = engine.execute(s1)
-        assert self._is_server_side(result.cursor)
-
-    def test_for_update_string(self):
-        engine = self._fixture(True)
-        result = engine.execute("SELECT 1 FOR UPDATE")
-        assert self._is_server_side(result.cursor)
-
-    def test_text_no_ss(self):
-        engine = self._fixture(False)
-        s = text("select 42")
-        result = engine.execute(s)
-        assert not self._is_server_side(result.cursor)
-
-    def test_text_ss_option(self):
-        engine = self._fixture(False)
-        s = text("select 42").execution_options(stream_results=True)
-        result = engine.execute(s)
-        assert self._is_server_side(result.cursor)
+        with engine.begin() as conn:
+            result = conn.execute(s2)
+            assert not self._is_server_side(result.cursor)
+            result.close()
 
     @testing.provide_metadata
     def test_roundtrip(self):
index 2a5dad2d6c016519e3c10971de70cae9f8b92073..a334b8ebcdd83dbff578d72b5dc3c9bb989cbefe 100644 (file)
@@ -871,54 +871,51 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
 
         # support sqlite :memory: database...
         data_table.create(engine, checkfirst=True)
-        engine.execute(
-            data_table.insert(), {"name": "row1", "data": data_element}
-        )
-
-        row = engine.execute(select([data_table.c.data])).first()
+        with engine.connect() as conn:
+            conn.execute(
+                data_table.insert(), {"name": "row1", "data": data_element}
+            )
+            row = conn.execute(select([data_table.c.data])).first()
 
-        eq_(row, (data_element,))
-        eq_(js.mock_calls, [mock.call(data_element)])
-        eq_(jd.mock_calls, [mock.call(json.dumps(data_element))])
+            eq_(row, (data_element,))
+            eq_(js.mock_calls, [mock.call(data_element)])
+            eq_(jd.mock_calls, [mock.call(json.dumps(data_element))])
 
-    def test_round_trip_none_as_sql_null(self):
+    def test_round_trip_none_as_sql_null(self, connection):
         col = self.tables.data_table.c["nulldata"]
 
-        with config.db.connect() as conn:
-            conn.execute(
-                self.tables.data_table.insert(), {"name": "r1", "data": None}
-            )
+        conn = connection
+        conn.execute(
+            self.tables.data_table.insert(), {"name": "r1", "data": None}
+        )
 
-            eq_(
-                conn.scalar(
-                    select([self.tables.data_table.c.name]).where(
-                        col.is_(null())
-                    )
-                ),
-                "r1",
-            )
+        eq_(
+            conn.scalar(
+                select([self.tables.data_table.c.name]).where(col.is_(null()))
+            ),
+            "r1",
+        )
 
-            eq_(conn.scalar(select([col])), None)
+        eq_(conn.scalar(select([col])), None)
 
-    def test_round_trip_json_null_as_json_null(self):
+    def test_round_trip_json_null_as_json_null(self, connection):
         col = self.tables.data_table.c["data"]
 
-        with config.db.connect() as conn:
-            conn.execute(
-                self.tables.data_table.insert(),
-                {"name": "r1", "data": JSON.NULL},
-            )
+        conn = connection
+        conn.execute(
+            self.tables.data_table.insert(), {"name": "r1", "data": JSON.NULL},
+        )
 
-            eq_(
-                conn.scalar(
-                    select([self.tables.data_table.c.name]).where(
-                        cast(col, String) == "null"
-                    )
-                ),
-                "r1",
-            )
+        eq_(
+            conn.scalar(
+                select([self.tables.data_table.c.name]).where(
+                    cast(col, String) == "null"
+                )
+            ),
+            "r1",
+        )
 
-            eq_(conn.scalar(select([col])), None)
+        eq_(conn.scalar(select([col])), None)
 
     def test_round_trip_none_as_json_null(self):
         col = self.tables.data_table.c["data"]
index 718b18f5b707f9d0c661ebb0b80ff5552ca93804..aa08502223cfb5de11d3b4dce59e57b38f302d76 100644 (file)
@@ -356,7 +356,8 @@ class QueryTest(testing.AssertsExecutionResults, fixtures.TestBase):
         metadata.create_all(engine)
 
         with self.sql_execution_asserter(engine) as asserter:
-            engine.execute(t1.insert(), {"data": "somedata"})
+            with engine.begin() as conn:
+                conn.execute(t1.insert(), {"data": "somedata"})
 
         # TODO: need a dialect SQL that acts like Cursor SQL
         asserter.assert_(
@@ -381,7 +382,8 @@ class QueryTest(testing.AssertsExecutionResults, fixtures.TestBase):
         metadata.create_all(engine)
 
         with self.sql_execution_asserter(engine) as asserter:
-            engine.execute(t1.insert())
+            with engine.begin() as conn:
+                conn.execute(t1.insert())
 
         # even with pyodbc, we don't embed the scope identity on a
         # DEFAULT VALUES insert
@@ -409,7 +411,8 @@ class QueryTest(testing.AssertsExecutionResults, fixtures.TestBase):
         metadata.create_all(engine)
 
         with self.sql_execution_asserter(engine) as asserter:
-            engine.execute(t1.insert(), {"data": "somedata"})
+            with engine.begin() as conn:
+                conn.execute(t1.insert(), {"data": "somedata"})
 
         # pyodbc-specific system
         asserter.assert_(
index 92d3d9e3274e16bcffc41ac1854c9869d016e116..c95ac6e6d9ef23baa72a836ceeb1e4369a42263c 100644 (file)
@@ -1026,18 +1026,24 @@ class TypeRoundTripTest(
                 ]
 
             for counter, engine in enumerate(eng):
-                engine.execute(tbl.insert())
-                if "int_y" in tbl.c:
-                    assert engine.scalar(select([tbl.c.int_y])) == counter + 1
-                    assert (
-                        list(engine.execute(tbl.select()).first()).count(
-                            counter + 1
+                with engine.begin() as conn:
+                    conn.execute(tbl.insert())
+                    if "int_y" in tbl.c:
+                        eq_(
+                            conn.execute(select([tbl.c.int_y])).scalar(),
+                            counter + 1,
                         )
-                        == 1
-                    )
-                else:
-                    assert 1 not in list(engine.execute(tbl.select()).first())
-                engine.execute(tbl.delete())
+                        assert (
+                            list(conn.execute(tbl.select()).first()).count(
+                                counter + 1
+                            )
+                            == 1
+                        )
+                    else:
+                        assert 1 not in list(
+                            conn.execute(tbl.select()).first()
+                        )
+                    conn.execute(tbl.delete())
 
 
 class StringTest(fixtures.TestBase, AssertsCompiledSQL):