From: Gord Thompson Date: Thu, 13 Feb 2020 19:14:42 +0000 (-0700) Subject: Replace engine.execute w/ context manager (step1) X-Git-Tag: rel_1_3_16~4^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4a03d2100e0141e8df51e58ed8ee14e5584443fe;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Replace engine.execute w/ context manager (step1) First (baby) step at replacing engine.execute calls in test code with the new preferred way of executing. MSSQL was targeted because it was the easiest for me to test locally. Change-Id: Id2e02f0e39007cbfd28ca6a535115f53c6407015 (cherry picked from commit 60f627cbd0d769e65353e720548efac9d8ab95d9) --- diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 6ada7356dc..b03300a6ff 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -182,7 +182,8 @@ execution. Given this example:: Column('x', Integer)) m.create_all(engine) - engine.execute(t.insert(), {'id': 1, 'x':1}, {'id':2, 'x':2}) + with engine.begin() as conn: + conn.execute(t.insert(), {'id': 1, 'x':1}, {'id':2, 'x':2}) The above column will be created with IDENTITY, however the INSERT statement we emit is specifying explicit values. In the echo output we can see diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 946b885d73..4f60834019 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -57,6 +57,32 @@ class TestBase(object): if hasattr(self, "tearDown"): self.tearDown() + @config.fixture() + def connection(self): + conn = config.db.connect() + trans = conn.begin() + try: + yield conn + finally: + trans.rollback() + conn.close() + + # propose a replacement for @testing.provide_metadata. + # the problem with this is that TablesTest below has a ".metadata" + # attribute already which is accessed directly as part of the + # @testing.provide_metadata pattern. Might need to call this _metadata + # for it to be useful. + # @config.fixture() + # def metadata(self): + # """Provide bound MetaData for a single test, dropping afterwards.""" + # + # from . import engines + # metadata = schema.MetaData(config.db) + # try: + # yield metadata + # finally: + # engines.drop_all_tables(metadata, config.db) + class TablesTest(TestBase): diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py index fc535aa23d..c09f5586e3 100644 --- a/lib/sqlalchemy/testing/suite/test_insert.py +++ b/lib/sqlalchemy/testing/suite/test_insert.py @@ -113,7 +113,8 @@ class InsertBehaviorTest(fixtures.TablesTest): else: engine = config.db - r = engine.execute(self.tables.autoinc_pk.insert(), data="some data") + with engine.begin() as conn: + r = conn.execute(self.tables.autoinc_pk.insert(), data="some data") assert r._soft_closed assert not r.closed assert r.is_insert @@ -282,9 +283,10 @@ class ReturningTest(fixtures.TablesTest): def test_explicit_returning_pk_autocommit(self): engine = config.db table = self.tables.autoinc_pk - r = engine.execute( - table.insert().returning(table.c.id), data="some data" - ) + with engine.begin() as conn: + r = conn.execute( + table.insert().returning(table.c.id), data="some data" + ) pk = r.first()[0] fetched_pk = config.db.scalar(select([table.c.id])) eq_(fetched_pk, pk) diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py index 67b0fd7039..125fefce98 100644 --- a/lib/sqlalchemy/testing/suite/test_results.py +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -225,7 +225,7 @@ class ServerSideCursorsTest( def _is_server_side(self, cursor): if self.engine.dialect.driver == "psycopg2": - return cursor.name + return bool(cursor.name) elif self.engine.dialect.driver == "pymysql": sscursor = __import__("pymysql.cursors").cursors.SSCursor return isinstance(cursor, sscursor) @@ -245,43 +245,48 @@ class ServerSideCursorsTest( engines.testing_reaper.close_all() self.engine.dispose() - def test_global_string(self): - engine = self._fixture(True) - result = engine.execute("select 1") - assert self._is_server_side(result.cursor) - - def test_global_text(self): - engine = self._fixture(True) - result = engine.execute(text("select 1")) - assert self._is_server_side(result.cursor) - - def test_global_expr(self): - engine = self._fixture(True) - result = engine.execute(select([1])) - assert self._is_server_side(result.cursor) - - def test_global_off_explicit(self): - engine = self._fixture(False) - result = engine.execute(text("select 1")) - - # It should be off globally ... - - assert not self._is_server_side(result.cursor) - - def test_stmt_option(self): - engine = self._fixture(False) - - s = select([1]).execution_options(stream_results=True) - result = engine.execute(s) - - # ... but enabled for this one. - - assert self._is_server_side(result.cursor) + @testing.combinations( + ("global_string", True, "select 1", True), + ("global_text", True, text("select 1"), True), + ("global_expr", True, select([1]), True), + ("global_off_explicit", False, text("select 1"), False), + ( + "stmt_option", + False, + select([1]).execution_options(stream_results=True), + True, + ), + ( + "stmt_option_disabled", + True, + select([1]).execution_options(stream_results=False), + False, + ), + ("for_update_expr", True, select([1]).with_for_update(), True), + ("for_update_string", True, "SELECT 1 FOR UPDATE", True), + ("text_no_ss", False, text("select 42"), False), + ( + "text_ss_option", + False, + text("select 42").execution_options(stream_results=True), + True, + ), + id_="iaaa", + argnames="engine_ss_arg, statement, cursor_ss_status", + ) + def test_ss_cursor_status( + self, engine_ss_arg, statement, cursor_ss_status + ): + engine = self._fixture(engine_ss_arg) + with engine.begin() as conn: + result = conn.execute(statement) + eq_(self._is_server_side(result.cursor), cursor_ss_status) + result.close() def test_conn_option(self): engine = self._fixture(False) - # and this one + # should be enabled for this one result = ( engine.connect() .execution_options(stream_results=True) @@ -300,46 +305,21 @@ class ServerSideCursorsTest( ) assert not self._is_server_side(result.cursor) - def test_stmt_option_disabled(self): - engine = self._fixture(True) - s = select([1]).execution_options(stream_results=False) - result = engine.execute(s) - assert not self._is_server_side(result.cursor) - def test_aliases_and_ss(self): engine = self._fixture(False) s1 = select([1]).execution_options(stream_results=True).alias() - result = engine.execute(s1) - assert self._is_server_side(result.cursor) + with engine.begin() as conn: + result = conn.execute(s1) + assert self._is_server_side(result.cursor) + result.close() # s1's options shouldn't affect s2 when s2 is used as a # from_obj. s2 = select([1], from_obj=s1) - result = engine.execute(s2) - assert not self._is_server_side(result.cursor) - - def test_for_update_expr(self): - engine = self._fixture(True) - s1 = select([1]).with_for_update() - result = engine.execute(s1) - assert self._is_server_side(result.cursor) - - def test_for_update_string(self): - engine = self._fixture(True) - result = engine.execute("SELECT 1 FOR UPDATE") - assert self._is_server_side(result.cursor) - - def test_text_no_ss(self): - engine = self._fixture(False) - s = text("select 42") - result = engine.execute(s) - assert not self._is_server_side(result.cursor) - - def test_text_ss_option(self): - engine = self._fixture(False) - s = text("select 42").execution_options(stream_results=True) - result = engine.execute(s) - assert self._is_server_side(result.cursor) + with engine.begin() as conn: + result = conn.execute(s2) + assert not self._is_server_side(result.cursor) + result.close() @testing.provide_metadata def test_roundtrip(self): diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 2a5dad2d6c..a334b8ebcd 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -871,54 +871,51 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): # support sqlite :memory: database... data_table.create(engine, checkfirst=True) - engine.execute( - data_table.insert(), {"name": "row1", "data": data_element} - ) - - row = engine.execute(select([data_table.c.data])).first() + with engine.connect() as conn: + conn.execute( + data_table.insert(), {"name": "row1", "data": data_element} + ) + row = conn.execute(select([data_table.c.data])).first() - eq_(row, (data_element,)) - eq_(js.mock_calls, [mock.call(data_element)]) - eq_(jd.mock_calls, [mock.call(json.dumps(data_element))]) + eq_(row, (data_element,)) + eq_(js.mock_calls, [mock.call(data_element)]) + eq_(jd.mock_calls, [mock.call(json.dumps(data_element))]) - def test_round_trip_none_as_sql_null(self): + def test_round_trip_none_as_sql_null(self, connection): col = self.tables.data_table.c["nulldata"] - with config.db.connect() as conn: - conn.execute( - self.tables.data_table.insert(), {"name": "r1", "data": None} - ) + conn = connection + conn.execute( + self.tables.data_table.insert(), {"name": "r1", "data": None} + ) - eq_( - conn.scalar( - select([self.tables.data_table.c.name]).where( - col.is_(null()) - ) - ), - "r1", - ) + eq_( + conn.scalar( + select([self.tables.data_table.c.name]).where(col.is_(null())) + ), + "r1", + ) - eq_(conn.scalar(select([col])), None) + eq_(conn.scalar(select([col])), None) - def test_round_trip_json_null_as_json_null(self): + def test_round_trip_json_null_as_json_null(self, connection): col = self.tables.data_table.c["data"] - with config.db.connect() as conn: - conn.execute( - self.tables.data_table.insert(), - {"name": "r1", "data": JSON.NULL}, - ) + conn = connection + conn.execute( + self.tables.data_table.insert(), {"name": "r1", "data": JSON.NULL}, + ) - eq_( - conn.scalar( - select([self.tables.data_table.c.name]).where( - cast(col, String) == "null" - ) - ), - "r1", - ) + eq_( + conn.scalar( + select([self.tables.data_table.c.name]).where( + cast(col, String) == "null" + ) + ), + "r1", + ) - eq_(conn.scalar(select([col])), None) + eq_(conn.scalar(select([col])), None) def test_round_trip_none_as_json_null(self): col = self.tables.data_table.c["data"] diff --git a/test/dialect/mssql/test_query.py b/test/dialect/mssql/test_query.py index bf836fc147..bfd984ce7f 100644 --- a/test/dialect/mssql/test_query.py +++ b/test/dialect/mssql/test_query.py @@ -356,7 +356,8 @@ class QueryTest(testing.AssertsExecutionResults, fixtures.TestBase): metadata.create_all(engine) with self.sql_execution_asserter(engine) as asserter: - engine.execute(t1.insert(), {"data": "somedata"}) + with engine.begin() as conn: + conn.execute(t1.insert(), {"data": "somedata"}) # TODO: need a dialect SQL that acts like Cursor SQL asserter.assert_( @@ -379,7 +380,8 @@ class QueryTest(testing.AssertsExecutionResults, fixtures.TestBase): metadata.create_all(engine) with self.sql_execution_asserter(engine) as asserter: - engine.execute(t1.insert()) + with engine.begin() as conn: + conn.execute(t1.insert()) # even with pyodbc, we don't embed the scope identity on a # DEFAULT VALUES insert @@ -403,7 +405,8 @@ class QueryTest(testing.AssertsExecutionResults, fixtures.TestBase): metadata.create_all(engine) with self.sql_execution_asserter(engine) as asserter: - engine.execute(t1.insert(), {"data": "somedata"}) + with engine.begin() as conn: + conn.execute(t1.insert(), {"data": "somedata"}) # pyodbc-specific system asserter.assert_( diff --git a/test/dialect/mssql/test_types.py b/test/dialect/mssql/test_types.py index 92d3d9e327..c95ac6e6d9 100644 --- a/test/dialect/mssql/test_types.py +++ b/test/dialect/mssql/test_types.py @@ -1026,18 +1026,24 @@ class TypeRoundTripTest( ] for counter, engine in enumerate(eng): - engine.execute(tbl.insert()) - if "int_y" in tbl.c: - assert engine.scalar(select([tbl.c.int_y])) == counter + 1 - assert ( - list(engine.execute(tbl.select()).first()).count( - counter + 1 + with engine.begin() as conn: + conn.execute(tbl.insert()) + if "int_y" in tbl.c: + eq_( + conn.execute(select([tbl.c.int_y])).scalar(), + counter + 1, ) - == 1 - ) - else: - assert 1 not in list(engine.execute(tbl.select()).first()) - engine.execute(tbl.delete()) + assert ( + list(conn.execute(tbl.select()).first()).count( + counter + 1 + ) + == 1 + ) + else: + assert 1 not in list( + conn.execute(tbl.select()).first() + ) + conn.execute(tbl.delete()) class StringTest(fixtures.TestBase, AssertsCompiledSQL):