]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Modernize test_rowcount and move to dialect suite
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 19 Feb 2020 22:59:52 +0000 (17:59 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 19 Feb 2020 23:05:33 +0000 (18:05 -0500)
Amazingly there are no "rowcount" tests in suite, so these
tests should definitely be there.

Change-Id: Ib4c595fe6e16b457680ce4ee01180ccc8ddb6a40

lib/sqlalchemy/testing/suite/__init__.py
lib/sqlalchemy/testing/suite/test_rowcount.py [moved from test/sql/test_rowcount.py with 59% similarity]

index 4c71157cd5143cb5010e0b9d91dd6edbbc0caea2..d76b33f56a4b83f1f3a646ed806fa161f6d72f36 100644 (file)
@@ -5,6 +5,7 @@ from .test_dialect import *  # noqa
 from .test_insert import *  # noqa
 from .test_reflection import *  # noqa
 from .test_results import *  # noqa
+from .test_rowcount import *  # noqa
 from .test_select import *  # noqa
 from .test_sequence import *  # noqa
 from .test_types import *  # noqa
similarity index 59%
rename from test/sql/test_rowcount.py
rename to lib/sqlalchemy/testing/suite/test_rowcount.py
index 8cff8c98f75e413aae8f2f3d20c68256b61195dc..83c2f8da47e5a61a97d62e46e811e6fc460bcecd 100644 (file)
@@ -1,30 +1,25 @@
 from sqlalchemy import bindparam
 from sqlalchemy import Column
 from sqlalchemy import Integer
-from sqlalchemy import MetaData
 from sqlalchemy import Sequence
 from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy import text
-from sqlalchemy.testing import AssertsExecutionResults
+from sqlalchemy.testing import config
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 
 
-class FoundRowsTest(fixtures.TestBase, AssertsExecutionResults):
-
-    """tests rowcount functionality"""
+class RowCountTest(fixtures.TablesTest):
+    """test rowcount functionality"""
 
     __requires__ = ("sane_rowcount",)
     __backend__ = True
 
     @classmethod
-    def setup_class(cls):
-        global employees_table, metadata
-        metadata = MetaData(testing.db)
-
-        employees_table = Table(
+    def define_tables(cls, metadata):
+        Table(
             "employees",
             metadata,
             Column(
@@ -36,11 +31,10 @@ class FoundRowsTest(fixtures.TestBase, AssertsExecutionResults):
             Column("name", String(50)),
             Column("department", String(1)),
         )
-        metadata.create_all()
 
-    def setup(self):
-        global data
-        data = [
+    @classmethod
+    def insert_data(cls):
+        cls.data = data = [
             ("Angela", "A"),
             ("Andrew", "A"),
             ("Anand", "A"),
@@ -52,39 +46,43 @@ class FoundRowsTest(fixtures.TestBase, AssertsExecutionResults):
             ("Chris", "C"),
         ]
 
-        i = employees_table.insert()
-        i.execute(*[{"name": n, "department": d} for n, d in data])
-
-    def teardown(self):
-        employees_table.delete().execute()
-
-    @classmethod
-    def teardown_class(cls):
-        metadata.drop_all()
+        employees_table = cls.tables.employees
+        with config.db.begin() as conn:
+            conn.execute(
+                employees_table.insert(),
+                [{"name": n, "department": d} for n, d in data],
+            )
 
     def test_basic(self):
+        employees_table = self.tables.employees
         s = employees_table.select()
         r = s.execute().fetchall()
 
-        assert len(r) == len(data)
+        assert len(r) == len(self.data)
 
     def test_update_rowcount1(self):
+        employees_table = self.tables.employees
+
         # WHERE matches 3, 3 rows changed
         department = employees_table.c.department
         r = employees_table.update(department == "C").execute(department="Z")
         assert r.rowcount == 3
 
-    def test_update_rowcount2(self):
+    def test_update_rowcount2(self, connection):
+        employees_table = self.tables.employees
+
         # WHERE matches 3, 0 rows changed
         department = employees_table.c.department
-        r = employees_table.update(department == "C").execute(department="C")
-        assert r.rowcount == 3
 
-    @testing.skip_if(
-        testing.requires.oracle5x, "unknown DBAPI error fixed in later version"
-    )
+        r = connection.execute(
+            employees_table.update(department == "C"), {"department": "C"}
+        )
+        eq_(r.rowcount, 3)
+
     @testing.requires.sane_rowcount_w_returning
-    def test_update_rowcount_return_defaults(self):
+    def test_update_rowcount_return_defaults(self, connection):
+        employees_table = self.tables.employees
+
         department = employees_table.c.department
         stmt = (
             employees_table.update(department == "C")
@@ -92,43 +90,41 @@ class FoundRowsTest(fixtures.TestBase, AssertsExecutionResults):
             .return_defaults()
         )
 
-        r = stmt.execute()
-        assert r.rowcount == 3
+        r = connection.execute(stmt)
+        eq_(r.rowcount, 3)
 
-    def test_raw_sql_rowcount(self):
+    def test_raw_sql_rowcount(self, connection):
         # test issue #3622, make sure eager rowcount is called for text
-        with testing.db.connect() as conn:
-            result = conn.execute(
-                "update employees set department='Z' where department='C'"
-            )
-            eq_(result.rowcount, 3)
+        result = connection.execute(
+            "update employees set department='Z' where department='C'"
+        )
+        eq_(result.rowcount, 3)
 
-    def test_text_rowcount(self):
+    def test_text_rowcount(self, connection):
         # test issue #3622, make sure eager rowcount is called for text
-        with testing.db.connect() as conn:
-            result = conn.execute(
-                text(
-                    "update employees set department='Z' "
-                    "where department='C'"
-                )
-            )
-            eq_(result.rowcount, 3)
+        result = connection.execute(
+            text("update employees set department='Z' " "where department='C'")
+        )
+        eq_(result.rowcount, 3)
+
+    def test_delete_rowcount(self, connection):
+        employees_table = self.tables.employees
 
-    def test_delete_rowcount(self):
         # WHERE matches 3, 3 rows deleted
         department = employees_table.c.department
-        r = employees_table.delete(department == "C").execute()
-        assert r.rowcount == 3
+        r = connection.execute(employees_table.delete(department == "C"))
+        eq_(r.rowcount, 3)
 
     @testing.requires.sane_multi_rowcount
-    def test_multi_update_rowcount(self):
+    def test_multi_update_rowcount(self, connection):
+        employees_table = self.tables.employees
         stmt = (
             employees_table.update()
             .where(employees_table.c.name == bindparam("emp_name"))
             .values(department="C")
         )
 
-        r = testing.db.execute(
+        r = connection.execute(
             stmt,
             [
                 {"emp_name": "Bob"},
@@ -140,12 +136,14 @@ class FoundRowsTest(fixtures.TestBase, AssertsExecutionResults):
         eq_(r.rowcount, 2)
 
     @testing.requires.sane_multi_rowcount
-    def test_multi_delete_rowcount(self):
+    def test_multi_delete_rowcount(self, connection):
+        employees_table = self.tables.employees
+
         stmt = employees_table.delete().where(
             employees_table.c.name == bindparam("emp_name")
         )
 
-        r = testing.db.execute(
+        r = connection.execute(
             stmt,
             [
                 {"emp_name": "Bob"},