From 83c8231ac1821703a2ec7f07bbd68e8712101e73 Mon Sep 17 00:00:00 2001 From: Gord Thompson Date: Sat, 28 Mar 2020 17:34:05 -0600 Subject: [PATCH] Clean up .execute calls in PostgreSQL tests Fixes: #5220 Change-Id: I789e45dffc2b177ebb15ea3268bb965be8b06397 --- test/dialect/postgresql/test_compiler.py | 5 +- test/dialect/postgresql/test_dialect.py | 60 +- test/dialect/postgresql/test_on_conflict.py | 21 +- test/dialect/postgresql/test_query.py | 4 +- test/dialect/postgresql/test_types.py | 900 ++++++++++---------- 5 files changed, 508 insertions(+), 482 deletions(-) diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 316f0c240b..4cc9c837d6 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -87,8 +87,9 @@ class SequenceTest(fixtures.TestBase, AssertsCompiledSQL): Column(cname[:57], Integer, primary_key=True), ) t.create(engine) - r = engine.execute(t.insert()) - assert r.inserted_primary_key == [1] + with engine.begin() as conn: + r = conn.execute(t.insert()) + assert r.inserted_primary_key == [1] class CompileTest(fixtures.TestBase, AssertsCompiledSQL): diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index 03e91482d6..d92559ac3c 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -677,7 +677,7 @@ $$ LANGUAGE plpgsql; "commit prepared 'gilberte'", ) - def test_extract(self): + def test_extract(self, connection): fivedaysago = testing.db.scalar( select([func.now()]) ) - datetime.timedelta(days=5) @@ -686,7 +686,7 @@ $$ LANGUAGE plpgsql; ("month", fivedaysago.month), ("day", fivedaysago.day), ): - r = testing.db.execute( + r = connection.execute( select( [extract(field, func.now() + datetime.timedelta(days=-5))] ) @@ -694,13 +694,13 @@ $$ LANGUAGE plpgsql; eq_(r, exp) @testing.provide_metadata - def test_checksfor_sequence(self): + def test_checksfor_sequence(self, connection): meta1 = self.metadata seq = Sequence("fooseq") t = Table("mytable", meta1, Column("col1", Integer, seq)) - seq.drop() - testing.db.execute(text("CREATE SEQUENCE fooseq")) - t.create(checkfirst=True) + seq.drop(connection) + connection.execute(text("CREATE SEQUENCE fooseq")) + t.create(connection, checkfirst=True) @testing.provide_metadata def test_schema_roundtrips(self): @@ -760,38 +760,38 @@ $$ LANGUAGE plpgsql; "some_name", ) - def test_preexecute_passivedefault(self): + @testing.provide_metadata + def test_preexecute_passivedefault(self, connection): """test that when we get a primary key column back from reflecting a table which has a default value on it, we pre- execute that DefaultClause upon insert.""" - try: - meta = MetaData(testing.db) - testing.db.execute( - text( - """ - CREATE TABLE speedy_users - ( - speedy_user_id SERIAL PRIMARY KEY, - - user_name VARCHAR NOT NULL, - user_password VARCHAR NOT NULL - ); - """ - ) + meta = self.metadata + connection.execute( + text( + """ + CREATE TABLE speedy_users + ( + speedy_user_id SERIAL PRIMARY KEY, + user_name VARCHAR NOT NULL, + user_password VARCHAR NOT NULL + ); + """ ) - t = Table("speedy_users", meta, autoload=True) - r = t.insert().execute(user_name="user", user_password="lala") - assert r.inserted_primary_key == [1] - result = t.select().execute().fetchall() - assert result == [(1, "user", "lala")] - finally: - testing.db.execute(text("drop table speedy_users")) + ) + t = Table("speedy_users", meta, autoload_with=connection) + r = connection.execute( + 
t.insert(), user_name="user", user_password="lala" + ) + assert r.inserted_primary_key == [1] + result = connection.execute(t.select()).fetchall() + assert result == [(1, "user", "lala")] + connection.execute(text("DROP TABLE speedy_users")) @testing.requires.psycopg2_or_pg8000_compatibility - def test_numeric_raise(self): + def test_numeric_raise(self, connection): stmt = text("select cast('hi' as char) as hi").columns(hi=Numeric) - assert_raises(exc.InvalidRequestError, testing.db.execute, stmt) + assert_raises(exc.InvalidRequestError, connection.execute, stmt) @testing.only_if( "postgresql >= 8.2", "requires standard_conforming_strings" diff --git a/test/dialect/postgresql/test_on_conflict.py b/test/dialect/postgresql/test_on_conflict.py index c74f4cbef0..b7316ca606 100644 --- a/test/dialect/postgresql/test_on_conflict.py +++ b/test/dialect/postgresql/test_on_conflict.py @@ -113,20 +113,17 @@ class OnConflictTest(fixtures.TablesTest): [(1, "name1")], ) - def test_on_conflict_do_nothing_connectionless(self): + def test_on_conflict_do_nothing_connectionless(self, connection): users = self.tables.users_xtra - with testing.db.connect() as conn: - result = conn.execute( - insert(users).on_conflict_do_nothing( - constraint="uq_login_email" - ), - dict(name="name1", login_email="email1"), - ) - eq_(result.inserted_primary_key, [1]) - eq_(result.returned_defaults, (1,)) + result = connection.execute( + insert(users).on_conflict_do_nothing(constraint="uq_login_email"), + dict(name="name1", login_email="email1"), + ) + eq_(result.inserted_primary_key, [1]) + eq_(result.returned_defaults, (1,)) - result = testing.db.execute( + result = connection.execute( insert(users).on_conflict_do_nothing(constraint="uq_login_email"), dict(name="name2", login_email="email1"), ) @@ -134,7 +131,7 @@ class OnConflictTest(fixtures.TablesTest): eq_(result.returned_defaults, None) eq_( - testing.db.execute( + connection.execute( users.select().where(users.c.id == 1) ).fetchall(), [(1, "name1", "email1", None)], diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py index 65defea801..74ad19d1a5 100644 --- a/test/dialect/postgresql/test_query.py +++ b/test/dialect/postgresql/test_query.py @@ -872,7 +872,7 @@ class TupleTest(fixtures.TestBase): __only_on__ = "postgresql" __backend__ = True - def test_tuple_containment(self): + def test_tuple_containment(self, connection): for test, exp in [ ([("a", "b")], True), @@ -881,7 +881,7 @@ class TupleTest(fixtures.TestBase): ([("f", "q"), ("a", "c")], False), ]: eq_( - testing.db.execute( + connection.execute( select( [ tuple_( diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 26e7bb8d69..1903e47d48 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -82,20 +82,25 @@ class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults): def insert_data(cls): data_table = cls.tables.data_table - data_table.insert().execute( - {"data": 3}, - {"data": 5}, - {"data": 7}, - {"data": 2}, - {"data": 15}, - {"data": 12}, - {"data": 6}, - {"data": 478}, - {"data": 52}, - {"data": 9}, - ) - - def test_float_coercion(self): + with testing.db.begin() as connection: + connection.execute( + data_table.insert().values( + [ + {"data": 3}, + {"data": 5}, + {"data": 7}, + {"data": 2}, + {"data": 15}, + {"data": 12}, + {"data": 6}, + {"data": 478}, + {"data": 52}, + {"data": 9}, + ] + ) + ) + + def test_float_coercion(self, connection): data_table = 
self.tables.data_table for type_, result in [ @@ -104,19 +109,19 @@ class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults): (Float(asdecimal=True), decimal.Decimal("140.381230939")), (Numeric(asdecimal=False), 140.381230939), ]: - ret = testing.db.execute( + ret = connection.execute( select([func.stddev_pop(data_table.c.data, type_=type_)]) ).scalar() eq_(round_decimal(ret, 9), result) - ret = testing.db.execute( + ret = connection.execute( select([cast(func.stddev_pop(data_table.c.data), type_)]) ).scalar() eq_(round_decimal(ret, 9), result) @testing.provide_metadata - def test_arrays_pg(self): + def test_arrays_pg(self, connection): metadata = self.metadata t1 = Table( "t", @@ -127,12 +132,14 @@ class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults): Column("q", postgresql.ARRAY(Numeric)), ) metadata.create_all() - t1.insert().execute(x=[5], y=[5], z=[6], q=[decimal.Decimal("6.4")]) - row = t1.select().execute().first() + connection.execute( + t1.insert(), x=[5], y=[5], z=[6], q=[decimal.Decimal("6.4")] + ) + row = connection.execute(t1.select()).first() eq_(row, ([5], [5], [6], [decimal.Decimal("6.4")])) @testing.provide_metadata - def test_arrays_base(self): + def test_arrays_base(self, connection): metadata = self.metadata t1 = Table( "t", @@ -143,8 +150,10 @@ class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults): Column("q", sqltypes.ARRAY(Numeric)), ) metadata.create_all() - t1.insert().execute(x=[5], y=[5], z=[6], q=[decimal.Decimal("6.4")]) - row = t1.select().execute().first() + connection.execute( + t1.insert(), x=[5], y=[5], z=[6], q=[decimal.Decimal("6.4")] + ) + row = connection.execute(t1.select()).first() eq_(row, ([5], [5], [6], [decimal.Decimal("6.4")])) @@ -154,7 +163,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): __only_on__ = "postgresql > 8.3" @testing.provide_metadata - def test_create_table(self): + def test_create_table(self, connection): metadata = self.metadata t1 = Table( "table", @@ -164,16 +173,15 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): "value", Enum("one", "two", "three", name="onetwothreetype") ), ) - with testing.db.connect() as conn: - t1.create(conn) - t1.create(conn, checkfirst=True) # check the create - conn.execute(t1.insert(), value="two") - conn.execute(t1.insert(), value="three") - conn.execute(t1.insert(), value="three") - eq_( - conn.execute(t1.select().order_by(t1.c.id)).fetchall(), - [(1, "two"), (2, "three"), (3, "three")], - ) + t1.create(connection) + t1.create(connection, checkfirst=True) # check the create + connection.execute(t1.insert(), value="two") + connection.execute(t1.insert(), value="three") + connection.execute(t1.insert(), value="three") + eq_( + connection.execute(t1.select().order_by(t1.c.id)).fetchall(), + [(1, "two"), (2, "three"), (3, "three")], + ) @testing.combinations(None, "foo") def test_create_table_schema_translate_map(self, symbol_name): @@ -236,7 +244,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): ) @testing.provide_metadata - def test_unicode_labels(self): + def test_unicode_labels(self, connection): metadata = self.metadata t1 = Table( "table", @@ -253,11 +261,11 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): ), ) metadata.create_all() - t1.insert().execute(value=util.u("drôle")) - t1.insert().execute(value=util.u("réveillé")) - t1.insert().execute(value=util.u("S’il")) + connection.execute(t1.insert(), value=util.u("drôle")) + connection.execute(t1.insert(), value=util.u("réveillé")) + 
connection.execute(t1.insert(), value=util.u("S’il")) eq_( - t1.select().order_by(t1.c.id).execute().fetchall(), + connection.execute(t1.select().order_by(t1.c.id)).fetchall(), [ (1, util.u("drôle")), (2, util.u("réveillé")), @@ -272,7 +280,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): ) @testing.provide_metadata - def test_non_native_enum(self): + def test_non_native_enum(self, connection): metadata = self.metadata t1 = Table( "foo", @@ -298,12 +306,11 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): ) ], ) - with testing.db.begin() as conn: - conn.execute(t1.insert(), {"bar": "two"}) - eq_(conn.scalar(select([t1.c.bar])), "two") + connection.execute(t1.insert(), {"bar": "two"}) + eq_(connection.scalar(select([t1.c.bar])), "two") @testing.provide_metadata - def test_non_native_enum_w_unicode(self): + def test_non_native_enum_w_unicode(self, connection): metadata = self.metadata t1 = Table( "foo", @@ -331,9 +338,8 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): ], ) - with testing.db.begin() as conn: - conn.execute(t1.insert(), {"bar": util.u("Ü")}) - eq_(conn.scalar(select([t1.c.bar])), util.u("Ü")) + connection.execute(t1.insert(), {"bar": util.u("Ü")}) + eq_(connection.scalar(select([t1.c.bar])), util.u("Ü")) @testing.provide_metadata def test_disable_create(self): @@ -614,7 +620,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): eq_(t2.c.value2.type.schema, "test_schema") @testing.provide_metadata - def test_custom_subclass(self): + def test_custom_subclass(self, connection): class MyEnum(TypeDecorator): impl = Enum("oneHI", "twoHI", "threeHI", name="myenum") @@ -631,12 +637,11 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): t1 = Table("table1", self.metadata, Column("data", MyEnum())) self.metadata.create_all(testing.db) - with testing.db.connect() as conn: - conn.execute(t1.insert(), {"data": "two"}) - eq_(conn.scalar(select([t1.c.data])), "twoHITHERE") + connection.execute(t1.insert(), {"data": "two"}) + eq_(connection.scalar(select([t1.c.data])), "twoHITHERE") @testing.provide_metadata - def test_generic_w_pg_variant(self): + def test_generic_w_pg_variant(self, connection): some_table = Table( "some_table", self.metadata, @@ -656,25 +661,26 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): ), ) - with testing.db.begin() as conn: - assert "my_enum" not in [ - e["name"] for e in inspect(conn).get_enums() - ] + assert "my_enum" not in [ + e["name"] for e in inspect(connection).get_enums() + ] - self.metadata.create_all(conn) + self.metadata.create_all(connection) - assert "my_enum" in [e["name"] for e in inspect(conn).get_enums()] + assert "my_enum" in [ + e["name"] for e in inspect(connection).get_enums() + ] - conn.execute(some_table.insert(), {"data": "five"}) + connection.execute(some_table.insert(), {"data": "five"}) - self.metadata.drop_all(conn) + self.metadata.drop_all(connection) - assert "my_enum" not in [ - e["name"] for e in inspect(conn).get_enums() - ] + assert "my_enum" not in [ + e["name"] for e in inspect(connection).get_enums() + ] @testing.provide_metadata - def test_generic_w_some_other_variant(self): + def test_generic_w_some_other_variant(self, connection): some_table = Table( "some_table", self.metadata, @@ -686,22 +692,23 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): ), ) - with testing.db.begin() as conn: - assert "my_enum" not in [ - e["name"] for e in inspect(conn).get_enums() - ] + assert "my_enum" not in [ + e["name"] for e in 
inspect(connection).get_enums() + ] - self.metadata.create_all(conn) + self.metadata.create_all(connection) - assert "my_enum" in [e["name"] for e in inspect(conn).get_enums()] + assert "my_enum" in [ + e["name"] for e in inspect(connection).get_enums() + ] - conn.execute(some_table.insert(), {"data": "two"}) + connection.execute(some_table.insert(), {"data": "two"}) - self.metadata.drop_all(conn) + self.metadata.drop_all(connection) - assert "my_enum" not in [ - e["name"] for e in inspect(conn).get_enums() - ] + assert "my_enum" not in [ + e["name"] for e in inspect(connection).get_enums() + ] class OIDTest(fixtures.TestBase): @@ -794,7 +801,7 @@ class NumericInterpretationTest(fixtures.TestBase): assert val in (23.7, decimal.Decimal("23.7")) @testing.provide_metadata - def test_numeric_default(self): + def test_numeric_default(self, connection): metadata = self.metadata # pg8000 appears to fail when the value is 0, # returns an int instead of decimal. @@ -808,9 +815,9 @@ class NumericInterpretationTest(fixtures.TestBase): Column("ff", Float(asdecimal=False), default=1), ) metadata.create_all() - t.insert().execute() + connection.execute(t.insert()) - row = t.select().execute().first() + row = connection.execute(t.select()).first() assert isinstance(row[1], decimal.Decimal) assert isinstance(row[2], float) assert isinstance(row[3], decimal.Decimal) @@ -874,7 +881,7 @@ class TimezoneTest(fixtures.TestBase): def teardown_class(cls): metadata.drop_all() - def test_with_timezone(self): + def test_with_timezone(self, connection): # get a date with a tzinfo @@ -882,35 +889,39 @@ class TimezoneTest(fixtures.TestBase): func.current_timestamp().select() ) assert somedate.tzinfo - tztable.insert().execute(id=1, name="row1", date=somedate) - row = select([tztable.c.date], tztable.c.id == 1).execute().first() + connection.execute(tztable.insert(), id=1, name="row1", date=somedate) + row = connection.execute( + select([tztable.c.date], tztable.c.id == 1) + ).first() eq_(row[0], somedate) eq_( somedate.tzinfo.utcoffset(somedate), row[0].tzinfo.utcoffset(row[0]), ) - result = ( - tztable.update(tztable.c.id == 1) - .returning(tztable.c.date) - .execute(name="newname") + result = connection.execute( + tztable.update(tztable.c.id == 1).returning(tztable.c.date), + name="newname", ) row = result.first() assert row[0] >= somedate - def test_without_timezone(self): + def test_without_timezone(self, connection): # get a date without a tzinfo somedate = datetime.datetime(2005, 10, 20, 11, 52, 0) assert not somedate.tzinfo - notztable.insert().execute(id=1, name="row1", date=somedate) - row = select([notztable.c.date], notztable.c.id == 1).execute().first() + connection.execute( + notztable.insert(), id=1, name="row1", date=somedate + ) + row = connection.execute( + select([notztable.c.date], notztable.c.id == 1) + ).first() eq_(row[0], somedate) eq_(row[0].tzinfo, None) - result = ( - notztable.update(notztable.c.id == 1) - .returning(notztable.c.date) - .execute(name="newname") + result = connection.execute( + notztable.update(notztable.c.id == 1).returning(notztable.c.date), + name="newname", ) row = result.first() assert row[0] >= somedate @@ -1267,7 +1278,8 @@ class ArrayRoundTripTest(object): ) def _fixture_456(self, table): - testing.db.execute(table.insert(), intarr=[4, 5, 6]) + with testing.db.begin() as conn: + conn.execute(table.insert(), intarr=[4, 5, 6]) def test_reflect_array_column(self): metadata2 = MetaData(testing.db) @@ -1290,39 +1302,39 @@ class ArrayRoundTripTest(object): t.create() 
@testing.provide_metadata - def test_array_agg(self): + def test_array_agg(self, connection): values_table = Table("values", self.metadata, Column("value", Integer)) self.metadata.create_all(testing.db) - testing.db.execute( + connection.execute( values_table.insert(), [{"value": i} for i in range(1, 10)] ) stmt = select([func.array_agg(values_table.c.value)]) - eq_(testing.db.execute(stmt).scalar(), list(range(1, 10))) + eq_(connection.execute(stmt).scalar(), list(range(1, 10))) stmt = select([func.array_agg(values_table.c.value)[3]]) - eq_(testing.db.execute(stmt).scalar(), 3) + eq_(connection.execute(stmt).scalar(), 3) stmt = select([func.array_agg(values_table.c.value)[2:4]]) - eq_(testing.db.execute(stmt).scalar(), [2, 3, 4]) + eq_(connection.execute(stmt).scalar(), [2, 3, 4]) - def test_array_index_slice_exprs(self): + def test_array_index_slice_exprs(self, connection): """test a variety of expressions that sometimes need parenthesizing""" stmt = select([array([1, 2, 3, 4])[2:3]]) - eq_(testing.db.execute(stmt).scalar(), [2, 3]) + eq_(connection.execute(stmt).scalar(), [2, 3]) stmt = select([array([1, 2, 3, 4])[2]]) - eq_(testing.db.execute(stmt).scalar(), 2) + eq_(connection.execute(stmt).scalar(), 2) stmt = select([(array([1, 2]) + array([3, 4]))[2:3]]) - eq_(testing.db.execute(stmt).scalar(), [2, 3]) + eq_(connection.execute(stmt).scalar(), [2, 3]) stmt = select([array([1, 2]) + array([3, 4])[2:3]]) - eq_(testing.db.execute(stmt).scalar(), [1, 2, 4]) + eq_(connection.execute(stmt).scalar(), [1, 2, 4]) stmt = select([array([1, 2])[2:3] + array([3, 4])]) - eq_(testing.db.execute(stmt).scalar(), [2, 3, 4]) + eq_(connection.execute(stmt).scalar(), [2, 3, 4]) stmt = select( [ @@ -1333,9 +1345,9 @@ class ArrayRoundTripTest(object): )[2:5] ] ) - eq_(testing.db.execute(stmt).scalar(), [2, 3, 4, 5]) + eq_(connection.execute(stmt).scalar(), [2, 3, 4, 5]) - def test_any_all_exprs_array(self): + def test_any_all_exprs_array(self, connection): stmt = select( [ 3 @@ -1348,79 +1360,90 @@ class ArrayRoundTripTest(object): ) ] ) - eq_(testing.db.execute(stmt).scalar(), True) + eq_(connection.execute(stmt).scalar(), True) - def test_insert_array(self): + def test_insert_array(self, connection): arrtable = self.tables.arrtable - arrtable.insert().execute( - intarr=[1, 2, 3], strarr=[util.u("abc"), util.u("def")] + connection.execute( + arrtable.insert(), + intarr=[1, 2, 3], + strarr=[util.u("abc"), util.u("def")], ) - results = arrtable.select().execute().fetchall() + results = connection.execute(arrtable.select()).fetchall() eq_(len(results), 1) eq_(results[0].intarr, [1, 2, 3]) eq_(results[0].strarr, [util.u("abc"), util.u("def")]) - def test_insert_array_w_null(self): + def test_insert_array_w_null(self, connection): arrtable = self.tables.arrtable - arrtable.insert().execute( - intarr=[1, None, 3], strarr=[util.u("abc"), None] + connection.execute( + arrtable.insert(), + intarr=[1, None, 3], + strarr=[util.u("abc"), None], ) - results = arrtable.select().execute().fetchall() + results = connection.execute(arrtable.select()).fetchall() eq_(len(results), 1) eq_(results[0].intarr, [1, None, 3]) eq_(results[0].strarr, [util.u("abc"), None]) - def test_array_where(self): + def test_array_where(self, connection): arrtable = self.tables.arrtable - arrtable.insert().execute( - intarr=[1, 2, 3], strarr=[util.u("abc"), util.u("def")] + connection.execute( + arrtable.insert(), + intarr=[1, 2, 3], + strarr=[util.u("abc"), util.u("def")], ) - arrtable.insert().execute(intarr=[4, 5, 6], strarr=util.u("ABC")) 
- results = ( - arrtable.select() - .where(arrtable.c.intarr == [1, 2, 3]) - .execute() - .fetchall() + connection.execute( + arrtable.insert(), intarr=[4, 5, 6], strarr=util.u("ABC") ) + results = connection.execute( + arrtable.select().where(arrtable.c.intarr == [1, 2, 3]) + ).fetchall() eq_(len(results), 1) eq_(results[0].intarr, [1, 2, 3]) - def test_array_concat(self): + def test_array_concat(self, connection): arrtable = self.tables.arrtable - arrtable.insert().execute( - intarr=[1, 2, 3], strarr=[util.u("abc"), util.u("def")] + connection.execute( + arrtable.insert(), + intarr=[1, 2, 3], + strarr=[util.u("abc"), util.u("def")], ) - results = select([arrtable.c.intarr + [4, 5, 6]]).execute().fetchall() + results = connection.execute( + select([arrtable.c.intarr + [4, 5, 6]]) + ).fetchall() eq_(len(results), 1) eq_(results[0][0], [1, 2, 3, 4, 5, 6]) - def test_array_comparison(self): + def test_array_comparison(self, connection): arrtable = self.tables.arrtable - arrtable.insert().execute( - id=5, intarr=[1, 2, 3], strarr=[util.u("abc"), util.u("def")] - ) - results = ( - select([arrtable.c.id]) - .where(arrtable.c.intarr < [4, 5, 6]) - .execute() - .fetchall() + connection.execute( + arrtable.insert(), + id=5, + intarr=[1, 2, 3], + strarr=[util.u("abc"), util.u("def")], ) + results = connection.execute( + select([arrtable.c.id]).where(arrtable.c.intarr < [4, 5, 6]) + ).fetchall() eq_(len(results), 1) eq_(results[0][0], 5) - def test_array_subtype_resultprocessor(self): + def test_array_subtype_resultprocessor(self, connection): arrtable = self.tables.arrtable - arrtable.insert().execute( + connection.execute( + arrtable.insert(), intarr=[4, 5, 6], strarr=[[util.ue("m\xe4\xe4")], [util.ue("m\xf6\xf6")]], ) - arrtable.insert().execute( + connection.execute( + arrtable.insert(), intarr=[1, 2, 3], strarr=[util.ue("m\xe4\xe4"), util.ue("m\xf6\xf6")], ) - results = ( - arrtable.select(order_by=[arrtable.c.intarr]).execute().fetchall() - ) + results = connection.execute( + arrtable.select(order_by=[arrtable.c.intarr]) + ).fetchall() eq_(len(results), 2) eq_(results[0].strarr, [util.ue("m\xe4\xe4"), util.ue("m\xf6\xf6")]) eq_( @@ -1428,9 +1451,9 @@ class ArrayRoundTripTest(object): [[util.ue("m\xe4\xe4")], [util.ue("m\xf6\xf6")]], ) - def test_array_literal_roundtrip(self): + def test_array_literal_roundtrip(self, connection): eq_( - testing.db.scalar( + connection.scalar( select( [postgresql.array([1, 2]) + postgresql.array([3, 4, 5])] ) @@ -1439,7 +1462,7 @@ class ArrayRoundTripTest(object): ) eq_( - testing.db.scalar( + connection.scalar( select( [ ( @@ -1453,7 +1476,7 @@ class ArrayRoundTripTest(object): ) eq_( - testing.db.scalar( + connection.scalar( select( [ ( @@ -1466,9 +1489,9 @@ class ArrayRoundTripTest(object): [2, 3, 4], ) - def test_array_literal_multidimensional_roundtrip(self): + def test_array_literal_multidimensional_roundtrip(self, connection): eq_( - testing.db.scalar( + connection.scalar( select( [ postgresql.array( @@ -1484,7 +1507,7 @@ class ArrayRoundTripTest(object): ) eq_( - testing.db.scalar( + connection.scalar( select( [ postgresql.array( @@ -1499,68 +1522,66 @@ class ArrayRoundTripTest(object): 3, ) - def test_array_literal_compare(self): + def test_array_literal_compare(self, connection): eq_( - testing.db.scalar(select([postgresql.array([1, 2]) < [3, 4, 5]])), + connection.scalar(select([postgresql.array([1, 2]) < [3, 4, 5]])), True, ) - def test_array_getitem_single_exec(self): + def test_array_getitem_single_exec(self, connection): arrtable = 
self.tables.arrtable self._fixture_456(arrtable) - eq_(testing.db.scalar(select([arrtable.c.intarr[2]])), 5) - testing.db.execute(arrtable.update().values({arrtable.c.intarr[2]: 7})) - eq_(testing.db.scalar(select([arrtable.c.intarr[2]])), 7) + eq_(connection.scalar(select([arrtable.c.intarr[2]])), 5) + connection.execute(arrtable.update().values({arrtable.c.intarr[2]: 7})) + eq_(connection.scalar(select([arrtable.c.intarr[2]])), 7) - def test_array_getitem_slice_exec(self): + def test_array_getitem_slice_exec(self, connection): arrtable = self.tables.arrtable - testing.db.execute( + connection.execute( arrtable.insert(), intarr=[4, 5, 6], strarr=[util.u("abc"), util.u("def")], ) - eq_(testing.db.scalar(select([arrtable.c.intarr[2:3]])), [5, 6]) - testing.db.execute( + eq_(connection.scalar(select([arrtable.c.intarr[2:3]])), [5, 6]) + connection.execute( arrtable.update().values({arrtable.c.intarr[2:3]: [7, 8]}) ) - eq_(testing.db.scalar(select([arrtable.c.intarr[2:3]])), [7, 8]) + eq_(connection.scalar(select([arrtable.c.intarr[2:3]])), [7, 8]) - def test_multi_dim_roundtrip(self): + def test_multi_dim_roundtrip(self, connection): arrtable = self.tables.arrtable - testing.db.execute(arrtable.insert(), dimarr=[[1, 2, 3], [4, 5, 6]]) + connection.execute(arrtable.insert(), dimarr=[[1, 2, 3], [4, 5, 6]]) eq_( - testing.db.scalar(select([arrtable.c.dimarr])), + connection.scalar(select([arrtable.c.dimarr])), [[-1, 0, 1], [2, 3, 4]], ) - def test_array_any_exec(self): + def test_array_any_exec(self, connection): arrtable = self.tables.arrtable - with testing.db.connect() as conn: - conn.execute(arrtable.insert(), intarr=[4, 5, 6]) - eq_( - conn.scalar( - select([arrtable.c.intarr]).where( - postgresql.Any(5, arrtable.c.intarr) - ) - ), - [4, 5, 6], - ) + connection.execute(arrtable.insert(), intarr=[4, 5, 6]) + eq_( + connection.scalar( + select([arrtable.c.intarr]).where( + postgresql.Any(5, arrtable.c.intarr) + ) + ), + [4, 5, 6], + ) - def test_array_all_exec(self): + def test_array_all_exec(self, connection): arrtable = self.tables.arrtable - with testing.db.connect() as conn: - conn.execute(arrtable.insert(), intarr=[4, 5, 6]) - eq_( - conn.scalar( - select([arrtable.c.intarr]).where( - arrtable.c.intarr.all(4, operator=operators.le) - ) - ), - [4, 5, 6], - ) + connection.execute(arrtable.insert(), intarr=[4, 5, 6]) + eq_( + connection.scalar( + select([arrtable.c.intarr]).where( + arrtable.c.intarr.all(4, operator=operators.le) + ) + ), + [4, 5, 6], + ) @testing.provide_metadata - def test_tuple_flag(self): + def test_tuple_flag(self, connection): metadata = self.metadata t1 = Table( @@ -1573,20 +1594,20 @@ class ArrayRoundTripTest(object): ), ) metadata.create_all() - testing.db.execute( + connection.execute( t1.insert(), id=1, data=["1", "2", "3"], data2=[5.4, 5.6] ) - testing.db.execute( + connection.execute( t1.insert(), id=2, data=["4", "5", "6"], data2=[1.0] ) - testing.db.execute( + connection.execute( t1.insert(), id=3, data=[["4", "5"], ["6", "7"]], data2=[[5.4, 5.6], [1.0, 1.1]], ) - r = testing.db.execute(t1.select().order_by(t1.c.id)).fetchall() + r = connection.execute(t1.select().order_by(t1.c.id)).fetchall() eq_( r, [ @@ -1641,44 +1662,45 @@ class PGArrayRoundTripTest( def test_undim_array_contains_typed_exec(self, struct): arrtable = self.tables.arrtable self._fixture_456(arrtable) - eq_( - testing.db.scalar( - select([arrtable.c.intarr]).where( - arrtable.c.intarr.contains(struct([4, 5])) - ) - ), - [4, 5, 6], - ) + with testing.db.begin() as conn: + eq_( + conn.scalar( 
+ select([arrtable.c.intarr]).where( + arrtable.c.intarr.contains(struct([4, 5])) + ) + ), + [4, 5, 6], + ) @testing.combinations((set,), (list,), (lambda elem: (x for x in elem),)) def test_dim_array_contains_typed_exec(self, struct): dim_arrtable = self.tables.dim_arrtable self._fixture_456(dim_arrtable) - eq_( - testing.db.scalar( - select([dim_arrtable.c.intarr]).where( - dim_arrtable.c.intarr.contains(struct([4, 5])) - ) - ), - [4, 5, 6], - ) - - def test_array_contained_by_exec(self): - arrtable = self.tables.arrtable - with testing.db.connect() as conn: - conn.execute(arrtable.insert(), intarr=[6, 5, 4]) + with testing.db.begin() as conn: eq_( conn.scalar( - select([arrtable.c.intarr.contained_by([4, 5, 6, 7])]) + select([dim_arrtable.c.intarr]).where( + dim_arrtable.c.intarr.contains(struct([4, 5])) + ) ), - True, + [4, 5, 6], ) - def test_undim_array_empty(self): + def test_array_contained_by_exec(self, connection): + arrtable = self.tables.arrtable + connection.execute(arrtable.insert(), intarr=[6, 5, 4]) + eq_( + connection.scalar( + select([arrtable.c.intarr.contained_by([4, 5, 6, 7])]) + ), + True, + ) + + def test_undim_array_empty(self, connection): arrtable = self.tables.arrtable self._fixture_456(arrtable) eq_( - testing.db.scalar( + connection.scalar( select([arrtable.c.intarr]).where( arrtable.c.intarr.contains([]) ) @@ -1686,18 +1708,17 @@ class PGArrayRoundTripTest( [4, 5, 6], ) - def test_array_overlap_exec(self): + def test_array_overlap_exec(self, connection): arrtable = self.tables.arrtable - with testing.db.connect() as conn: - conn.execute(arrtable.insert(), intarr=[4, 5, 6]) - eq_( - conn.scalar( - select([arrtable.c.intarr]).where( - arrtable.c.intarr.overlap([7, 6]) - ) - ), - [4, 5, 6], - ) + connection.execute(arrtable.insert(), intarr=[4, 5, 6]) + eq_( + connection.scalar( + select([arrtable.c.intarr]).where( + arrtable.c.intarr.overlap([7, 6]) + ) + ), + [4, 5, 6], + ) class HashableFlagORMTest(fixtures.TestBase): @@ -1774,20 +1795,14 @@ class TimestampTest(fixtures.TestBase, AssertsExecutionResults): __only_on__ = "postgresql" __backend__ = True - def test_timestamp(self): - engine = testing.db - connection = engine.connect() - + def test_timestamp(self, connection): s = select([text("timestamp '2007-12-25'")]) result = connection.execute(s).first() eq_(result[0], datetime.datetime(2007, 12, 25, 0, 0)) - def test_interval_arithmetic(self): + def test_interval_arithmetic(self, connection): # basically testing that we get timedelta back for an INTERVAL # result. more of a driver assertion. 
- engine = testing.db - connection = engine.connect() - s = select([text("timestamp '2007-12-25' - timestamp '2007-11-15'")]) result = connection.execute(s).first() eq_(result[0], datetime.timedelta(40)) @@ -1871,16 +1886,16 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables): assert t.c.bitstring.type.length == 4 @testing.provide_metadata - def test_tsvector_round_trip(self): + def test_tsvector_round_trip(self, connection): t = Table("t1", self.metadata, Column("data", postgresql.TSVECTOR)) t.create() - testing.db.execute(t.insert(), data="a fat cat sat") - eq_(testing.db.scalar(select([t.c.data])), "'a' 'cat' 'fat' 'sat'") + connection.execute(t.insert(), data="a fat cat sat") + eq_(connection.scalar(select([t.c.data])), "'a' 'cat' 'fat' 'sat'") - testing.db.execute(t.update(), data="'a' 'cat' 'fat' 'mat' 'sat'") + connection.execute(t.update(), data="'a' 'cat' 'fat' 'mat' 'sat'") eq_( - testing.db.scalar(select([t.c.data])), + connection.scalar(select([t.c.data])), "'a' 'cat' 'fat' 'mat' 'sat'", ) @@ -1924,21 +1939,18 @@ class UUIDTest(fixtures.TestBase): ), ("as_uuid", postgresql.UUID(as_uuid=True), uuid.uuid4(), uuid.uuid4()), id_="iaaa", + argnames="datatype, value1, value2", ) - def test_round_trip(self, datatype, value1, value2): - + def test_round_trip(self, datatype, value1, value2, connection): utable = Table("utable", MetaData(), Column("data", datatype)) - - with testing.db.connect() as conn: - conn.begin() - utable.create(conn) - conn.execute(utable.insert(), {"data": value1}) - conn.execute(utable.insert(), {"data": value2}) - r = conn.execute( - select([utable.c.data]).where(utable.c.data != value1) - ) - eq_(r.fetchone()[0], value2) - eq_(r.fetchone(), None) + utable.create(connection) + connection.execute(utable.insert(), {"data": value1}) + connection.execute(utable.insert(), {"data": value2}) + r = connection.execute( + select([utable.c.data]).where(utable.c.data != value1) + ) + eq_(r.fetchone()[0], value2) + eq_(r.fetchone(), None) @testing.combinations( ( @@ -1954,13 +1966,13 @@ class UUIDTest(fixtures.TestBase): [str(uuid.uuid4()), str(uuid.uuid4())], ), id_="iaaa", + argnames="datatype, value1, value2", ) @testing.fails_on("postgresql+pg8000", "No support for UUID with ARRAY") - def test_uuid_array(self, datatype, value1, value2): - self.test_round_trip(datatype, value1, value2) + def test_uuid_array(self, datatype, value1, value2, connection): + self.test_round_trip(datatype, value1, value2, connection) def test_no_uuid_available(self): - uuid_type = base._python_UUID base._python_UUID = None try: @@ -2258,17 +2270,18 @@ class HStoreRoundTripTest(fixtures.TablesTest): def _fixture_data(self, engine): data_table = self.tables.data_table - engine.execute( - data_table.insert(), - {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, - {"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}}, - {"name": "r3", "data": {"k1": "r3v1", "k2": "r3v2"}}, - {"name": "r4", "data": {"k1": "r4v1", "k2": "r4v2"}}, - {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2"}}, - ) + with engine.begin() as conn: + conn.execute( + data_table.insert(), + {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, + {"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}}, + {"name": "r3", "data": {"k1": "r3v1", "k2": "r3v2"}}, + {"name": "r4", "data": {"k1": "r4v1", "k2": "r4v2"}}, + {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2"}}, + ) - def _assert_data(self, compare): - data = testing.db.execute( + def _assert_data(self, compare, conn): + data = conn.execute( 
select([self.tables.data_table.c.data]).order_by( self.tables.data_table.c.name ) @@ -2276,11 +2289,12 @@ class HStoreRoundTripTest(fixtures.TablesTest): eq_([d for d, in data], compare) def _test_insert(self, engine): - engine.execute( - self.tables.data_table.insert(), - {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, - ) - self._assert_data([{"k1": "r1v1", "k2": "r1v2"}]) + with engine.begin() as conn: + conn.execute( + self.tables.data_table.insert(), + {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, + ) + self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], conn) def _non_native_engine(self): if testing.requires.psycopg2_native_hstore.enabled: @@ -2297,7 +2311,7 @@ class HStoreRoundTripTest(fixtures.TablesTest): cols = insp.get_columns("data_table") assert isinstance(cols[2]["type"], HSTORE) - def test_literal_round_trip(self): + def test_literal_round_trip(self, connection): # in particular, this tests that the array index # operator against the function is handled by PG; with some # array functions it requires outer parenthezisation on the left and @@ -2305,7 +2319,7 @@ class HStoreRoundTripTest(fixtures.TablesTest): expr = hstore( postgresql.array(["1", "2"]), postgresql.array(["3", None]) )["1"] - eq_(testing.db.scalar(select([expr])), "3") + eq_(connection.scalar(select([expr])), "3") @testing.requires.psycopg2_native_hstore def test_insert_native(self): @@ -2329,26 +2343,28 @@ class HStoreRoundTripTest(fixtures.TablesTest): def _test_criterion(self, engine): data_table = self.tables.data_table - result = engine.execute( - select([data_table.c.data]).where( - data_table.c.data["k1"] == "r3v1" - ) - ).first() - eq_(result, ({"k1": "r3v1", "k2": "r3v2"},)) + with engine.begin() as conn: + result = conn.execute( + select([data_table.c.data]).where( + data_table.c.data["k1"] == "r3v1" + ) + ).first() + eq_(result, ({"k1": "r3v1", "k2": "r3v2"},)) def _test_fixed_round_trip(self, engine): - s = select( - [ - hstore( - array(["key1", "key2", "key3"]), - array(["value1", "value2", "value3"]), - ) - ] - ) - eq_( - engine.scalar(s), - {"key1": "value1", "key2": "value2", "key3": "value3"}, - ) + with engine.begin() as conn: + s = select( + [ + hstore( + array(["key1", "key2", "key3"]), + array(["value1", "value2", "value3"]), + ) + ] + ) + eq_( + conn.scalar(s), + {"key1": "value1", "key2": "value2", "key3": "value3"}, + ) def test_fixed_round_trip_python(self): engine = self._non_native_engine() @@ -2360,26 +2376,35 @@ class HStoreRoundTripTest(fixtures.TablesTest): self._test_fixed_round_trip(engine) def _test_unicode_round_trip(self, engine): - s = select( - [ - hstore( - array( - [util.u("réveillé"), util.u("drôle"), util.u("S’il")] - ), - array( - [util.u("réveillé"), util.u("drôle"), util.u("S’il")] - ), - ) - ] - ) - eq_( - engine.scalar(s), - { - util.u("réveillé"): util.u("réveillé"), - util.u("drôle"): util.u("drôle"), - util.u("S’il"): util.u("S’il"), - }, - ) + with engine.begin() as conn: + s = select( + [ + hstore( + array( + [ + util.u("réveillé"), + util.u("drôle"), + util.u("S’il"), + ] + ), + array( + [ + util.u("réveillé"), + util.u("drôle"), + util.u("S’il"), + ] + ), + ) + ] + ) + eq_( + conn.scalar(s), + { + util.u("réveillé"): util.u("réveillé"), + util.u("drôle"): util.u("drôle"), + util.u("S’il"): util.u("S’il"), + }, + ) @testing.requires.psycopg2_native_hstore def test_unicode_round_trip_python(self): @@ -2401,11 +2426,12 @@ class HStoreRoundTripTest(fixtures.TablesTest): self._test_escaped_quotes_round_trip(engine) def 
_test_escaped_quotes_round_trip(self, engine): - engine.execute( - self.tables.data_table.insert(), - {"name": "r1", "data": {r"key \"foo\"": r'value \"bar"\ xyz'}}, - ) - self._assert_data([{r"key \"foo\"": r'value \"bar"\ xyz'}]) + with engine.begin() as conn: + conn.execute( + self.tables.data_table.insert(), + {"name": "r1", "data": {r"key \"foo\"": r'value \"bar"\ xyz'}}, + ) + self._assert_data([{r"key \"foo\"": r'value \"bar"\ xyz'}], conn) def test_orm_round_trip(self): from sqlalchemy import orm @@ -2582,52 +2608,52 @@ class _RangeTypeRoundTrip(fixtures.TablesTest): cols = insp.get_columns("data_table") assert isinstance(cols[0]["type"], self._col_type) - def _assert_data(self): - data = testing.db.execute( + def _assert_data(self, conn): + data = conn.execute( select([self.tables.data_table.c.range]) ).fetchall() eq_(data, [(self._data_obj(),)]) - def test_insert_obj(self): - testing.db.engine.execute( + def test_insert_obj(self, connection): + connection.execute( self.tables.data_table.insert(), {"range": self._data_obj()} ) - self._assert_data() + self._assert_data(connection) - def test_insert_text(self): - testing.db.engine.execute( + def test_insert_text(self, connection): + connection.execute( self.tables.data_table.insert(), {"range": self._data_str} ) - self._assert_data() + self._assert_data(connection) - def test_union_result(self): + def test_union_result(self, connection): # insert - testing.db.engine.execute( + connection.execute( self.tables.data_table.insert(), {"range": self._data_str} ) # select range_ = self.tables.data_table.c.range - data = testing.db.execute(select([range_ + range_])).fetchall() + data = connection.execute(select([range_ + range_])).fetchall() eq_(data, [(self._data_obj(),)]) - def test_intersection_result(self): + def test_intersection_result(self, connection): # insert - testing.db.engine.execute( + connection.execute( self.tables.data_table.insert(), {"range": self._data_str} ) # select range_ = self.tables.data_table.c.range - data = testing.db.execute(select([range_ * range_])).fetchall() + data = connection.execute(select([range_ * range_])).fetchall() eq_(data, [(self._data_obj(),)]) - def test_difference_result(self): + def test_difference_result(self, connection): # insert - testing.db.engine.execute( + connection.execute( self.tables.data_table.insert(), {"range": self._data_str} ) # select range_ = self.tables.data_table.c.range - data = testing.db.execute(select([range_ - range_])).fetchall() + data = connection.execute(select([range_ - range_])).fetchall() eq_(data, [(self._data_obj().__class__(empty=True),)]) @@ -2701,9 +2727,10 @@ class _DateTimeTZRangeTests(object): def tstzs(self): if self._tstzs is None: - lower = testing.db.scalar(func.current_timestamp().select()) - upper = lower + datetime.timedelta(1) - self._tstzs = (lower, upper) + with testing.db.begin() as conn: + lower = conn.scalar(func.current_timestamp().select()) + upper = lower + datetime.timedelta(1) + self._tstzs = (lower, upper) return self._tstzs @property @@ -2868,65 +2895,66 @@ class JSONRoundTripTest(fixtures.TablesTest): def _fixture_data(self, engine): data_table = self.tables.data_table - engine.execute( - data_table.insert(), - {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, - {"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}}, - {"name": "r3", "data": {"k1": "r3v1", "k2": "r3v2"}}, - {"name": "r4", "data": {"k1": "r4v1", "k2": "r4v2"}}, - {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2", "k3": 5}}, - {"name": "r6", "data": {"k1": {"r6v1": 
{"subr": [1, 2, 3]}}}}, - ) + with engine.begin() as conn: + conn.execute( + data_table.insert(), + {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, + {"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}}, + {"name": "r3", "data": {"k1": "r3v1", "k2": "r3v2"}}, + {"name": "r4", "data": {"k1": "r4v1", "k2": "r4v2"}}, + {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2", "k3": 5}}, + {"name": "r6", "data": {"k1": {"r6v1": {"subr": [1, 2, 3]}}}}, + ) - def _assert_data(self, compare, column="data"): + def _assert_data(self, compare, conn, column="data"): col = self.tables.data_table.c[column] - - data = testing.db.execute( + data = conn.execute( select([col]).order_by(self.tables.data_table.c.name) ).fetchall() eq_([d for d, in data], compare) - def _assert_column_is_NULL(self, column="data"): + def _assert_column_is_NULL(self, conn, column="data"): col = self.tables.data_table.c[column] - - data = testing.db.execute( - select([col]).where(col.is_(null())) - ).fetchall() + data = conn.execute(select([col]).where(col.is_(null()))).fetchall() eq_([d for d, in data], [None]) - def _assert_column_is_JSON_NULL(self, column="data"): + def _assert_column_is_JSON_NULL(self, conn, column="data"): col = self.tables.data_table.c[column] - - data = testing.db.execute( + data = conn.execute( select([col]).where(cast(col, String) == "null") ).fetchall() eq_([d for d, in data], [None]) def _test_insert(self, engine): - engine.execute( - self.tables.data_table.insert(), - {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, - ) - self._assert_data([{"k1": "r1v1", "k2": "r1v2"}]) + with engine.connect() as conn: + conn.execute( + self.tables.data_table.insert(), + {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, + ) + self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], conn) def _test_insert_nulls(self, engine): - engine.execute( - self.tables.data_table.insert(), {"name": "r1", "data": null()} - ) - self._assert_data([None]) + with engine.connect() as conn: + conn.execute( + self.tables.data_table.insert(), {"name": "r1", "data": null()} + ) + self._assert_data([None], conn) def _test_insert_none_as_null(self, engine): - engine.execute( - self.tables.data_table.insert(), {"name": "r1", "nulldata": None} - ) - self._assert_column_is_NULL(column="nulldata") + with engine.connect() as conn: + conn.execute( + self.tables.data_table.insert(), + {"name": "r1", "nulldata": None}, + ) + self._assert_column_is_NULL(conn, column="nulldata") def _test_insert_nulljson_into_none_as_null(self, engine): - engine.execute( - self.tables.data_table.insert(), - {"name": "r1", "nulldata": JSON.NULL}, - ) - self._assert_column_is_JSON_NULL(column="nulldata") + with engine.connect() as conn: + conn.execute( + self.tables.data_table.insert(), + {"name": "r1", "nulldata": JSON.NULL}, + ) + self._assert_column_is_JSON_NULL(conn, column="nulldata") def _non_native_engine(self, json_serializer=None, json_deserializer=None): if json_serializer is not None or json_deserializer is not None: @@ -2966,24 +2994,20 @@ class JSONRoundTripTest(fixtures.TablesTest): assert isinstance(cols[2]["type"], self.test_type) @testing.requires.psycopg2_native_json - def test_insert_native(self): - engine = testing.db - self._test_insert(engine) + def test_insert_native(self, connection): + self._test_insert(connection) @testing.requires.psycopg2_native_json - def test_insert_native_nulls(self): - engine = testing.db - self._test_insert_nulls(engine) + def test_insert_native_nulls(self, connection): + self._test_insert_nulls(connection) 
@testing.requires.psycopg2_native_json - def test_insert_native_none_as_null(self): - engine = testing.db - self._test_insert_none_as_null(engine) + def test_insert_native_none_as_null(self, connection): + self._test_insert_none_as_null(connection) @testing.requires.psycopg2_native_json - def test_insert_native_nulljson_into_none_as_null(self): - engine = testing.db - self._test_insert_nulljson_into_none_as_null(engine) + def test_insert_native_nulljson_into_none_as_null(self, connection): + self._test_insert_nulljson_into_none_as_null(connection) def test_insert_python(self): engine = self._non_native_engine() @@ -3024,7 +3048,8 @@ class JSONRoundTripTest(fixtures.TablesTest): ) s = select([cast({"key": "value", "x": "q"}, self.test_type)]) - eq_(engine.scalar(s), {"key": "value", "x": "dumps_y_loads"}) + with engine.begin() as conn: + eq_(conn.scalar(s), {"key": "value", "x": "dumps_y_loads"}) @testing.requires.psycopg2_native_json def test_custom_native(self): @@ -3045,12 +3070,12 @@ class JSONRoundTripTest(fixtures.TablesTest): self._fixture_data(engine) self._test_criterion(engine) - def test_path_query(self): + def test_path_query(self, connection): engine = testing.db self._fixture_data(engine) data_table = self.tables.data_table - result = engine.execute( + result = connection.execute( select([data_table.c.name]).where( data_table.c.data[("k1", "r6v1", "subr")].astext == "[1, 2, 3]" ) @@ -3060,23 +3085,23 @@ class JSONRoundTripTest(fixtures.TablesTest): @testing.fails_on( "postgresql < 9.4", "Improvement in PostgreSQL behavior?" ) - def test_multi_index_query(self): + def test_multi_index_query(self, connection): engine = testing.db self._fixture_data(engine) data_table = self.tables.data_table - result = engine.execute( + result = connection.execute( select([data_table.c.name]).where( data_table.c.data["k1"]["r6v1"]["subr"].astext == "[1, 2, 3]" ) ) eq_(result.scalar(), "r6") - def test_query_returned_as_text(self): + def test_query_returned_as_text(self, connection): engine = testing.db self._fixture_data(engine) data_table = self.tables.data_table - result = engine.execute( + result = connection.execute( select([data_table.c.data["k1"].astext]) ).first() if engine.dialect.returns_unicode_strings: @@ -3084,11 +3109,11 @@ class JSONRoundTripTest(fixtures.TablesTest): else: assert isinstance(result[0], util.string_types) - def test_query_returned_as_int(self): + def test_query_returned_as_int(self, connection): engine = testing.db self._fixture_data(engine) data_table = self.tables.data_table - result = engine.execute( + result = connection.execute( select([data_table.c.data["k3"].astext.cast(Integer)]).where( data_table.c.name == "r5" ) @@ -3097,33 +3122,35 @@ class JSONRoundTripTest(fixtures.TablesTest): def _test_criterion(self, engine): data_table = self.tables.data_table - result = engine.execute( - select([data_table.c.data]).where( - data_table.c.data["k1"].astext == "r3v1" - ) - ).first() - eq_(result, ({"k1": "r3v1", "k2": "r3v2"},)) + with engine.begin() as conn: + result = conn.execute( + select([data_table.c.data]).where( + data_table.c.data["k1"].astext == "r3v1" + ) + ).first() + eq_(result, ({"k1": "r3v1", "k2": "r3v2"},)) - result = engine.execute( - select([data_table.c.data]).where( - data_table.c.data["k1"].astext.cast(String) == "r3v1" - ) - ).first() - eq_(result, ({"k1": "r3v1", "k2": "r3v2"},)) + result = conn.execute( + select([data_table.c.data]).where( + data_table.c.data["k1"].astext.cast(String) == "r3v1" + ) + ).first() + eq_(result, ({"k1": "r3v1", 
"k2": "r3v2"},)) def _test_fixed_round_trip(self, engine): - s = select( - [ - cast( - {"key": "value", "key2": {"k1": "v1", "k2": "v2"}}, - self.test_type, - ) - ] - ) - eq_( - engine.scalar(s), - {"key": "value", "key2": {"k1": "v1", "k2": "v2"}}, - ) + with engine.begin() as conn: + s = select( + [ + cast( + {"key": "value", "key2": {"k1": "v1", "k2": "v2"}}, + self.test_type, + ) + ] + ) + eq_( + conn.scalar(s), + {"key": "value", "key2": {"k1": "v1", "k2": "v2"}}, + ) def test_fixed_round_trip_python(self): engine = self._non_native_engine() @@ -3135,24 +3162,25 @@ class JSONRoundTripTest(fixtures.TablesTest): self._test_fixed_round_trip(engine) def _test_unicode_round_trip(self, engine): - s = select( - [ - cast( - { - util.u("réveillé"): util.u("réveillé"), - "data": {"k1": util.u("drôle")}, - }, - self.test_type, - ) - ] - ) - eq_( - engine.scalar(s), - { - util.u("réveillé"): util.u("réveillé"), - "data": {"k1": util.u("drôle")}, - }, - ) + with engine.begin() as conn: + s = select( + [ + cast( + { + util.u("réveillé"): util.u("réveillé"), + "data": {"k1": util.u("drôle")}, + }, + self.test_type, + ) + ] + ) + eq_( + conn.scalar(s), + { + util.u("réveillé"): util.u("réveillé"), + "data": {"k1": util.u("drôle")}, + }, + ) def test_unicode_round_trip_python(self): engine = self._non_native_engine() -- 2.47.3