]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Clean up (engine|db).execute for oracle
authorGord Thompson <gord@gordthompson.com>
Tue, 31 Mar 2020 15:59:16 +0000 (09:59 -0600)
committerGord Thompson <gord@gordthompson.com>
Mon, 6 Apr 2020 21:15:23 +0000 (15:15 -0600)
Change-Id: I6064fe348394152b2a47e83e43c469a153d34d27

test/dialect/oracle/test_dialect.py
test/dialect/oracle/test_types.py

index c2983bfe0f3e0923a95357c51125c8591ab195fe..ea0c230dd485840421492404ffb4c648770adbed 100644 (file)
@@ -346,8 +346,8 @@ end;
                 """
             )
 
-    def test_out_params(self):
-        result = testing.db.execute(
+    def test_out_params(self, connection):
+        result = connection.execute(
             text(
                 "begin foo(:x_in, :x_out, :y_out, " ":z_out); end;"
             ).bindparams(
@@ -363,7 +363,8 @@ end;
 
     @classmethod
     def teardown_class(cls):
-        testing.db.execute(text("DROP PROCEDURE foo"))
+        with testing.db.connect() as conn:
+            conn.execute(text("DROP PROCEDURE foo"))
 
 
 class QuotedBindRoundTripTest(fixtures.TestBase):
@@ -372,7 +373,7 @@ class QuotedBindRoundTripTest(fixtures.TestBase):
     __backend__ = True
 
     @testing.provide_metadata
-    def test_table_round_trip(self):
+    def test_table_round_trip(self, connection):
         oracle.RESERVED_WORDS.remove("UNION")
 
         metadata = self.metadata
@@ -388,14 +389,16 @@ class QuotedBindRoundTripTest(fixtures.TestBase):
         )
         metadata.create_all()
 
-        table.insert().execute({"option": 1, "plain": 1, "union": 1})
-        eq_(testing.db.execute(table.select()).first(), (1, 1, 1))
-        table.update().values(option=2, plain=2, union=2).execute()
-        eq_(testing.db.execute(table.select()).first(), (2, 2, 2))
+        connection.execute(
+            table.insert(), {"option": 1, "plain": 1, "union": 1}
+        )
+        eq_(connection.execute(table.select()).first(), (1, 1, 1))
+        connection.execute(table.update().values(option=2, plain=2, union=2))
+        eq_(connection.execute(table.select()).first(), (2, 2, 2))
 
-    def test_numeric_bind_round_trip(self):
+    def test_numeric_bind_round_trip(self, connection):
         eq_(
-            testing.db.scalar(
+            connection.scalar(
                 select(
                     [
                         literal_column("2", type_=Integer())
@@ -407,25 +410,22 @@ class QuotedBindRoundTripTest(fixtures.TestBase):
         )
 
     @testing.provide_metadata
-    def test_numeric_bind_in_crud(self):
+    def test_numeric_bind_in_crud(self, connection):
         t = Table("asfd", self.metadata, Column("100K", Integer))
-        t.create()
+        t.create(connection)
 
-        testing.db.execute(t.insert(), {"100K": 10})
-        eq_(testing.db.scalar(t.select()), 10)
+        connection.execute(t.insert(), {"100K": 10})
+        eq_(connection.scalar(t.select()), 10)
 
     @testing.provide_metadata
-    def test_expanding_quote_roundtrip(self):
+    def test_expanding_quote_roundtrip(self, connection):
         t = Table("asfd", self.metadata, Column("foo", Integer))
-        t.create()
+        t.create(connection)
 
-        with testing.db.connect() as conn:
-            conn.execute(
-                select([t]).where(
-                    t.c.foo.in_(bindparam("uid", expanding=True))
-                ),
-                uid=[1, 2, 3],
-            )
+        connection.execute(
+            select([t]).where(t.c.foo.in_(bindparam("uid", expanding=True))),
+            uid=[1, 2, 3],
+        )
 
 
 class CompatFlagsTest(fixtures.TestBase, AssertsCompiledSQL):
@@ -624,15 +624,15 @@ class ExecuteTest(fixtures.TestBase):
                 [(1,)],
             )
 
-    def test_sequences_are_integers(self):
+    def test_sequences_are_integers(self, connection):
         seq = Sequence("foo_seq")
-        seq.create(testing.db)
+        seq.create(connection)
         try:
-            val = testing.db.execute(seq)
+            val = connection.execute(seq)
             eq_(val, 1)
             assert type(val) is int
         finally:
-            seq.drop(testing.db)
+            seq.drop(connection)
 
     @testing.provide_metadata
     def test_limit_offset_for_update(self):
@@ -676,7 +676,7 @@ class UnicodeSchemaTest(fixtures.TestBase):
     __backend__ = True
 
     @testing.provide_metadata
-    def test_quoted_column_non_unicode(self):
+    def test_quoted_column_non_unicode(self, connection):
         metadata = self.metadata
         table = Table(
             "atable",
@@ -685,14 +685,14 @@ class UnicodeSchemaTest(fixtures.TestBase):
         )
         metadata.create_all()
 
-        table.insert().execute({"_underscorecolumn": u("’é")})
-        result = testing.db.execute(
+        connection.execute(table.insert(), {"_underscorecolumn": u("’é")})
+        result = connection.execute(
             table.select().where(table.c._underscorecolumn == u("’é"))
         ).scalar()
         eq_(result, u("’é"))
 
     @testing.provide_metadata
-    def test_quoted_column_unicode(self):
+    def test_quoted_column_unicode(self, connection):
         metadata = self.metadata
         table = Table(
             "atable",
@@ -701,8 +701,8 @@ class UnicodeSchemaTest(fixtures.TestBase):
         )
         metadata.create_all()
 
-        table.insert().execute({u("méil"): u("’é")})
-        result = testing.db.execute(
+        connection.execute(table.insert(), {u("méil"): u("’é")})
+        result = connection.execute(
             table.select().where(table.c[u("méil")] == u("’é"))
         ).scalar()
         eq_(result, u("’é"))
index 70c8f20f24e726c00c66291d004053526d2ba8f3..e0d97230c09ca6ab54c0bd752b486210a71700ff 100644 (file)
@@ -51,9 +51,8 @@ from sqlalchemy.util import py2k
 from sqlalchemy.util import u
 
 
-def exec_sql(engine, sql, *args, **kwargs):
-    with engine.connect() as conn:
-        return conn.exec_driver_sql(sql, *args, **kwargs)
+def exec_sql(conn, sql, *args, **kwargs):
+    return conn.exec_driver_sql(sql, *args, **kwargs)
 
 
 class DialectTypesTest(fixtures.TestBase, AssertsCompiledSQL):
@@ -249,15 +248,16 @@ class TypesTest(fixtures.TestBase):
 
         m = self.metadata
         t1 = Table("t1", m, Column("foo", Integer))
-        t1.create()
-        r = engine.execute(t1.insert().values(foo=5).returning(t1.c.foo))
-        x = r.scalar()
-        assert x == 5
-        assert isinstance(x, int)
+        with engine.begin() as conn:
+            t1.create()
+            r = conn.execute(t1.insert().values(foo=5).returning(t1.c.foo))
+            x = r.scalar()
+            assert x == 5
+            assert isinstance(x, int)
 
-        x = t1.select().scalar()
-        assert x == 5
-        assert isinstance(x, int)
+            x = conn.execute(t1.select()).scalar()
+            assert x == 5
+            assert isinstance(x, int)
 
     @testing.provide_metadata
     def test_rowid(self):
@@ -351,7 +351,7 @@ class TypesTest(fixtures.TestBase):
                 )
 
     @testing.provide_metadata
-    def test_numeric_infinity_float(self):
+    def test_numeric_infinity_float(self, connection):
         m = self.metadata
         t1 = Table(
             "t1",
@@ -377,13 +377,13 @@ class TypesTest(fixtures.TestBase):
 
         eq_(
             exec_sql(
-                testing.db, "select numericcol from t1 order by intcol"
+                connection, "select numericcol from t1 order by intcol"
             ).fetchall(),
             [(float("inf"),), (float("-inf"),)],
         )
 
     @testing.provide_metadata
-    def test_numeric_infinity_decimal(self):
+    def test_numeric_infinity_decimal(self, connection):
         m = self.metadata
         t1 = Table(
             "t1",
@@ -409,13 +409,13 @@ class TypesTest(fixtures.TestBase):
 
         eq_(
             exec_sql(
-                testing.db, "select numericcol from t1 order by intcol"
+                connection, "select numericcol from t1 order by intcol"
             ).fetchall(),
             [(decimal.Decimal("Infinity"),), (decimal.Decimal("-Infinity"),)],
         )
 
     @testing.provide_metadata
-    def test_numeric_nan_float(self):
+    def test_numeric_nan_float(self, connection):
         m = self.metadata
         t1 = Table(
             "t1",
@@ -445,7 +445,7 @@ class TypesTest(fixtures.TestBase):
             [
                 tuple(str(col) for col in row)
                 for row in exec_sql(
-                    testing.db, "select numericcol from t1 order by intcol"
+                    connection, "select numericcol from t1 order by intcol"
                 )
             ],
             [("nan",), ("nan",)],
@@ -454,7 +454,7 @@ class TypesTest(fixtures.TestBase):
     # needs https://github.com/oracle/python-cx_Oracle/
     # issues/184#issuecomment-391399292
     @testing.provide_metadata
-    def _dont_test_numeric_nan_decimal(self):
+    def _dont_test_numeric_nan_decimal(self, connection):
         m = self.metadata
         t1 = Table(
             "t1",
@@ -480,13 +480,13 @@ class TypesTest(fixtures.TestBase):
 
         eq_(
             exec_sql(
-                testing.db, "select numericcol from t1 order by intcol"
+                connection, "select numericcol from t1 order by intcol"
             ).fetchall(),
             [(decimal.Decimal("NaN"),), (decimal.Decimal("NaN"),)],
         )
 
     @testing.provide_metadata
-    def test_numerics_broken_inspection(self):
+    def test_numerics_broken_inspection(self, connection):
         """Numeric scenarios where Oracle type info is 'broken',
         returning us precision, scale of the form (0, 0) or (0, -127).
         We convert to Decimal and let int()/float() processors take over.
@@ -506,21 +506,22 @@ class TypesTest(fixtures.TestBase):
             Column("nidata", Numeric(5, 0)),
             Column("fdata", Float()),
         )
-        foo.create()
+        foo.create(connection)
 
-        foo.insert().execute(
+        connection.execute(
+            foo.insert(),
             {
                 "idata": 5,
                 "ndata": decimal.Decimal("45.6"),
                 "ndata2": decimal.Decimal("45.0"),
                 "nidata": decimal.Decimal("53"),
                 "fdata": 45.68392,
-            }
+            },
         )
 
         stmt = "SELECT idata, ndata, ndata2, nidata, fdata FROM foo"
 
-        row = exec_sql(testing.db, stmt).fetchall()[0]
+        row = exec_sql(connection, stmt).fetchall()[0]
         eq_(
             [type(x) for x in row],
             [int, decimal.Decimal, decimal.Decimal, int, float],
@@ -556,7 +557,7 @@ class TypesTest(fixtures.TestBase):
             (SELECT CAST((SELECT fdata FROM foo) AS FLOAT) FROM DUAL) AS fdata
         FROM dual
         """
-        row = exec_sql(testing.db, stmt).fetchall()[0]
+        row = exec_sql(connection, stmt).fetchall()[0]
         eq_(
             [type(x) for x in row],
             [int, decimal.Decimal, int, int, decimal.Decimal],
@@ -566,7 +567,7 @@ class TypesTest(fixtures.TestBase):
             (5, decimal.Decimal("45.6"), 45, 53, decimal.Decimal("45.68392")),
         )
 
-        row = testing.db.execute(
+        row = connection.execute(
             text(stmt).columns(
                 idata=Integer(),
                 ndata=Numeric(20, 2),
@@ -613,7 +614,7 @@ class TypesTest(fixtures.TestBase):
         )
         WHERE ROWNUM >= 0) anon_1
         """
-        row = exec_sql(testing.db, stmt).fetchall()[0]
+        row = exec_sql(connection, stmt).fetchall()[0]
         eq_(
             [type(x) for x in row],
             [int, decimal.Decimal, int, int, decimal.Decimal],
@@ -623,7 +624,7 @@ class TypesTest(fixtures.TestBase):
             (5, decimal.Decimal("45.6"), 45, 53, decimal.Decimal("45.68392")),
         )
 
-        row = testing.db.execute(
+        row = connection.execute(
             text(stmt).columns(
                 anon_1_idata=Integer(),
                 anon_1_ndata=Numeric(20, 2),
@@ -647,7 +648,7 @@ class TypesTest(fixtures.TestBase):
             ),
         )
 
-        row = testing.db.execute(
+        row = connection.execute(
             text(stmt).columns(
                 anon_1_idata=Integer(),
                 anon_1_ndata=Numeric(20, 2, asdecimal=False),
@@ -663,22 +664,23 @@ class TypesTest(fixtures.TestBase):
 
     def test_numeric_no_coerce_decimal_mode(self):
         engine = testing_engine(options=dict(coerce_to_decimal=False))
-
-        # raw SQL no longer coerces to decimal
-        value = exec_sql(engine, "SELECT 5.66 FROM DUAL").scalar()
-        assert isinstance(value, float)
-
-        # explicit typing still *does* coerce to decimal
-        # (change in 1.2)
-        value = engine.scalar(
-            text("SELECT 5.66 AS foo FROM DUAL").columns(
-                foo=Numeric(4, 2, asdecimal=True)
+        with engine.connect() as conn:
+            # raw SQL no longer coerces to decimal
+            value = exec_sql(conn, "SELECT 5.66 FROM DUAL").scalar()
+            assert isinstance(value, float)
+
+            # explicit typing still *does* coerce to decimal
+            # (change in 1.2)
+            value = conn.scalar(
+                text("SELECT 5.66 AS foo FROM DUAL").columns(
+                    foo=Numeric(4, 2, asdecimal=True)
+                )
             )
-        )
-        assert isinstance(value, decimal.Decimal)
+            assert isinstance(value, decimal.Decimal)
 
+    def test_numeric_coerce_decimal_mode(self, connection):
         # default behavior is raw SQL coerces to decimal
-        value = exec_sql(testing.db, "SELECT 5.66 FROM DUAL").scalar()
+        value = exec_sql(connection, "SELECT 5.66 FROM DUAL").scalar()
         assert isinstance(value, decimal.Decimal)
 
     @testing.combinations(
@@ -726,12 +728,15 @@ class TypesTest(fixtures.TestBase):
     @testing.fails_if(
         testing.requires.python3, "cx_oracle always returns unicode on py3k"
     )
-    def test_coerce_to_unicode(self):
+    def test_coerce_to_unicode(self, connection):
         engine = testing_engine(options=dict(coerce_to_unicode=False))
-        value = exec_sql(engine, "SELECT 'hello' FROM DUAL").scalar()
-        assert isinstance(value, util.binary_type)
+        with engine.connect() as conn_no_coerce:
+            value = exec_sql(
+                conn_no_coerce, "SELECT 'hello' FROM DUAL"
+            ).scalar()
+            assert isinstance(value, util.binary_type)
 
-        value = exec_sql(testing.db, "SELECT 'hello' FROM DUAL").scalar()
+        value = exec_sql(connection, "SELECT 'hello' FROM DUAL").scalar()
         assert isinstance(value, util.text_type)
 
     @testing.provide_metadata
@@ -772,7 +777,7 @@ class TypesTest(fixtures.TestBase):
             [row[k] for k in row.keys()]
 
     @testing.provide_metadata
-    def test_raw_roundtrip(self):
+    def test_raw_roundtrip(self, connection):
         metadata = self.metadata
         raw_table = Table(
             "raw",
@@ -781,8 +786,8 @@ class TypesTest(fixtures.TestBase):
             Column("data", oracle.RAW(35)),
         )
         metadata.create_all()
-        testing.db.execute(raw_table.insert(), id=1, data=b("ABCDEF"))
-        eq_(testing.db.execute(raw_table.select()).first(), (1, b("ABCDEF")))
+        connection.execute(raw_table.insert(), id=1, data=b("ABCDEF"))
+        eq_(connection.execute(raw_table.select()).first(), (1, b("ABCDEF")))
 
     @testing.provide_metadata
     def test_reflect_nvarchar(self):
@@ -860,18 +865,19 @@ class TypesTest(fixtures.TestBase):
         eq_(t2.c.c4.type.length, 180)
 
     @testing.provide_metadata
-    def test_long_type(self):
+    def test_long_type(self, connection):
         metadata = self.metadata
 
         t = Table("t", metadata, Column("data", oracle.LONG))
         metadata.create_all(testing.db)
-        testing.db.execute(t.insert(), data="xyz")
-        eq_(testing.db.scalar(select([t.c.data])), "xyz")
+        connection.execute(t.insert(), data="xyz")
+        eq_(connection.scalar(select([t.c.data])), "xyz")
 
-    def test_longstring(self):
-        metadata = MetaData(testing.db)
+    @testing.provide_metadata
+    def test_longstring(self, connection):
+        metadata = self.metadata
         exec_sql(
-            testing.db,
+            connection,
             """
         CREATE TABLE Z_TEST
         (
@@ -881,11 +887,11 @@ class TypesTest(fixtures.TestBase):
         """,
         )
         try:
-            t = Table("z_test", metadata, autoload=True)
-            t.insert().execute(id=1.0, add_user="foobar")
-            assert t.select().execute().fetchall() == [(1, "foobar")]
+            t = Table("z_test", metadata, autoload_with=connection)
+            connection.execute(t.insert(), id=1.0, add_user="foobar")
+            assert connection.execute(t.select()).fetchall() == [(1, "foobar")]
         finally:
-            exec_sql(testing.db, "DROP TABLE Z_TEST")
+            exec_sql(connection, "DROP TABLE Z_TEST")
 
 
 class LOBFetchTest(fixtures.TablesTest):
@@ -923,7 +929,8 @@ class LOBFetchTest(fixtures.TablesTest):
             for i in range(1, 20)
         ]
 
-        testing.db.execute(cls.tables.z_test.insert(), data)
+        with testing.db.begin() as conn:
+            conn.execute(cls.tables.z_test.insert(), data)
 
         binary_table = cls.tables.binary_table
         fname = os.path.join(
@@ -932,24 +939,26 @@ class LOBFetchTest(fixtures.TablesTest):
         with open(fname, "rb") as file_:
             cls.stream = stream = file_.read(12000)
 
-        for i in range(1, 11):
-            binary_table.insert().execute(id=i, data=stream)
+        with testing.db.begin() as conn:
+            for i in range(1, 11):
+                conn.execute(binary_table.insert(), id=i, data=stream)
 
     def test_lobs_without_convert(self):
         engine = testing_engine(options=dict(auto_convert_lobs=False))
         t = self.tables.z_test
-        row = engine.execute(t.select().where(t.c.id == 1)).first()
-        eq_(row["data"].read(), "this is text 1")
-        eq_(row["bindata"].read(), b("this is binary 1"))
+        with engine.begin() as conn:
+            row = conn.execute(t.select().where(t.c.id == 1)).first()
+            eq_(row["data"].read(), "this is text 1")
+            eq_(row["bindata"].read(), b("this is binary 1"))
 
-    def test_lobs_with_convert(self):
+    def test_lobs_with_convert(self, connection):
         t = self.tables.z_test
-        row = testing.db.execute(t.select().where(t.c.id == 1)).first()
+        row = connection.execute(t.select().where(t.c.id == 1)).first()
         eq_(row["data"], "this is text 1")
         eq_(row["bindata"], b("this is binary 1"))
 
-    def test_lobs_with_convert_raw(self):
-        row = exec_sql(testing.db, "select data, bindata from z_test").first()
+    def test_lobs_with_convert_raw(self, connection):
+        row = exec_sql(connection, "select data, bindata from z_test").first()
         eq_(row["data"], "this is text 1")
         eq_(row["bindata"], b("this is binary 1"))
 
@@ -958,7 +967,8 @@ class LOBFetchTest(fixtures.TablesTest):
             options=dict(auto_convert_lobs=False, arraysize=1)
         )
         result = exec_sql(
-            engine, "select id, data, bindata from z_test order by id"
+            engine.connect(),
+            "select id, data, bindata from z_test order by id",
         )
         results = result.fetchall()
 
@@ -991,18 +1001,21 @@ class LOBFetchTest(fixtures.TablesTest):
         engine = testing_engine(
             options=dict(auto_convert_lobs=True, arraysize=1)
         )
-        result = exec_sql(
-            engine, "select id, data, bindata from z_test order by id"
-        )
-        results = result.fetchall()
+        with engine.connect() as conn:
+            result = exec_sql(
+                conn, "select id, data, bindata from z_test order by id",
+            )
+            results = result.fetchall()
 
-        eq_(
-            [
-                dict(id=row["id"], data=row["data"], bindata=row["bindata"])
-                for row in results
-            ],
-            self.data,
-        )
+            eq_(
+                [
+                    dict(
+                        id=row["id"], data=row["data"], bindata=row["bindata"]
+                    )
+                    for row in results
+                ],
+                self.data,
+            )
 
     def test_large_stream(self):
         binary_table = self.tables.binary_table