From: Mike Bayer Date: Wed, 19 Feb 2020 22:59:52 +0000 (-0500) Subject: Modernize test_rowcount and move to dialect suite X-Git-Tag: rel_1_4_0b1~515 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1149a81d53ff6825048dd8dea09fb95c803a8944;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Modernize test_rowcount and move to dialect suite Amazingly there are no "rowcount" tests in suite, so these tests should definitely be there. Change-Id: Ib4c595fe6e16b457680ce4ee01180ccc8ddb6a40 --- diff --git a/lib/sqlalchemy/testing/suite/__init__.py b/lib/sqlalchemy/testing/suite/__init__.py index 4c71157cd5..d76b33f56a 100644 --- a/lib/sqlalchemy/testing/suite/__init__.py +++ b/lib/sqlalchemy/testing/suite/__init__.py @@ -5,6 +5,7 @@ from .test_dialect import * # noqa from .test_insert import * # noqa from .test_reflection import * # noqa from .test_results import * # noqa +from .test_rowcount import * # noqa from .test_select import * # noqa from .test_sequence import * # noqa from .test_types import * # noqa diff --git a/test/sql/test_rowcount.py b/lib/sqlalchemy/testing/suite/test_rowcount.py similarity index 59% rename from test/sql/test_rowcount.py rename to lib/sqlalchemy/testing/suite/test_rowcount.py index 8cff8c98f7..83c2f8da47 100644 --- a/test/sql/test_rowcount.py +++ b/lib/sqlalchemy/testing/suite/test_rowcount.py @@ -1,30 +1,25 @@ from sqlalchemy import bindparam from sqlalchemy import Column from sqlalchemy import Integer -from sqlalchemy import MetaData from sqlalchemy import Sequence from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import text -from sqlalchemy.testing import AssertsExecutionResults +from sqlalchemy.testing import config from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures -class FoundRowsTest(fixtures.TestBase, AssertsExecutionResults): - - """tests rowcount functionality""" +class RowCountTest(fixtures.TablesTest): + """test rowcount functionality""" __requires__ = ("sane_rowcount",) __backend__ = True @classmethod - def setup_class(cls): - global employees_table, metadata - metadata = MetaData(testing.db) - - employees_table = Table( + def define_tables(cls, metadata): + Table( "employees", metadata, Column( @@ -36,11 +31,10 @@ class FoundRowsTest(fixtures.TestBase, AssertsExecutionResults): Column("name", String(50)), Column("department", String(1)), ) - metadata.create_all() - def setup(self): - global data - data = [ + @classmethod + def insert_data(cls): + cls.data = data = [ ("Angela", "A"), ("Andrew", "A"), ("Anand", "A"), @@ -52,39 +46,43 @@ class FoundRowsTest(fixtures.TestBase, AssertsExecutionResults): ("Chris", "C"), ] - i = employees_table.insert() - i.execute(*[{"name": n, "department": d} for n, d in data]) - - def teardown(self): - employees_table.delete().execute() - - @classmethod - def teardown_class(cls): - metadata.drop_all() + employees_table = cls.tables.employees + with config.db.begin() as conn: + conn.execute( + employees_table.insert(), + [{"name": n, "department": d} for n, d in data], + ) def test_basic(self): + employees_table = self.tables.employees s = employees_table.select() r = s.execute().fetchall() - assert len(r) == len(data) + assert len(r) == len(self.data) def test_update_rowcount1(self): + employees_table = self.tables.employees + # WHERE matches 3, 3 rows changed department = employees_table.c.department r = employees_table.update(department == "C").execute(department="Z") assert r.rowcount == 3 - def test_update_rowcount2(self): + def test_update_rowcount2(self, connection): + employees_table = self.tables.employees + # WHERE matches 3, 0 rows changed department = employees_table.c.department - r = employees_table.update(department == "C").execute(department="C") - assert r.rowcount == 3 - @testing.skip_if( - testing.requires.oracle5x, "unknown DBAPI error fixed in later version" - ) + r = connection.execute( + employees_table.update(department == "C"), {"department": "C"} + ) + eq_(r.rowcount, 3) + @testing.requires.sane_rowcount_w_returning - def test_update_rowcount_return_defaults(self): + def test_update_rowcount_return_defaults(self, connection): + employees_table = self.tables.employees + department = employees_table.c.department stmt = ( employees_table.update(department == "C") @@ -92,43 +90,41 @@ class FoundRowsTest(fixtures.TestBase, AssertsExecutionResults): .return_defaults() ) - r = stmt.execute() - assert r.rowcount == 3 + r = connection.execute(stmt) + eq_(r.rowcount, 3) - def test_raw_sql_rowcount(self): + def test_raw_sql_rowcount(self, connection): # test issue #3622, make sure eager rowcount is called for text - with testing.db.connect() as conn: - result = conn.execute( - "update employees set department='Z' where department='C'" - ) - eq_(result.rowcount, 3) + result = connection.execute( + "update employees set department='Z' where department='C'" + ) + eq_(result.rowcount, 3) - def test_text_rowcount(self): + def test_text_rowcount(self, connection): # test issue #3622, make sure eager rowcount is called for text - with testing.db.connect() as conn: - result = conn.execute( - text( - "update employees set department='Z' " - "where department='C'" - ) - ) - eq_(result.rowcount, 3) + result = connection.execute( + text("update employees set department='Z' " "where department='C'") + ) + eq_(result.rowcount, 3) + + def test_delete_rowcount(self, connection): + employees_table = self.tables.employees - def test_delete_rowcount(self): # WHERE matches 3, 3 rows deleted department = employees_table.c.department - r = employees_table.delete(department == "C").execute() - assert r.rowcount == 3 + r = connection.execute(employees_table.delete(department == "C")) + eq_(r.rowcount, 3) @testing.requires.sane_multi_rowcount - def test_multi_update_rowcount(self): + def test_multi_update_rowcount(self, connection): + employees_table = self.tables.employees stmt = ( employees_table.update() .where(employees_table.c.name == bindparam("emp_name")) .values(department="C") ) - r = testing.db.execute( + r = connection.execute( stmt, [ {"emp_name": "Bob"}, @@ -140,12 +136,14 @@ class FoundRowsTest(fixtures.TestBase, AssertsExecutionResults): eq_(r.rowcount, 2) @testing.requires.sane_multi_rowcount - def test_multi_delete_rowcount(self): + def test_multi_delete_rowcount(self, connection): + employees_table = self.tables.employees + stmt = employees_table.delete().where( employees_table.c.name == bindparam("emp_name") ) - r = testing.db.execute( + r = connection.execute( stmt, [ {"emp_name": "Bob"},