from sqlalchemy import bindparam
from sqlalchemy import Column
from sqlalchemy import Integer
-from sqlalchemy import MetaData
from sqlalchemy import Sequence
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import testing
from sqlalchemy import text
-from sqlalchemy.testing import AssertsExecutionResults
+from sqlalchemy.testing import config
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
-class FoundRowsTest(fixtures.TestBase, AssertsExecutionResults):
-
- """tests rowcount functionality"""
+class RowCountTest(fixtures.TablesTest):
+ """test rowcount functionality"""
__requires__ = ("sane_rowcount",)
__backend__ = True
@classmethod
- def setup_class(cls):
- global employees_table, metadata
- metadata = MetaData(testing.db)
-
- employees_table = Table(
+ def define_tables(cls, metadata):
+ Table(
"employees",
metadata,
Column(
Column("name", String(50)),
Column("department", String(1)),
)
- metadata.create_all()
- def setup(self):
- global data
- data = [
+ @classmethod
+ def insert_data(cls):
+ cls.data = data = [
("Angela", "A"),
("Andrew", "A"),
("Anand", "A"),
("Chris", "C"),
]
- i = employees_table.insert()
- i.execute(*[{"name": n, "department": d} for n, d in data])
-
- def teardown(self):
- employees_table.delete().execute()
-
- @classmethod
- def teardown_class(cls):
- metadata.drop_all()
+ employees_table = cls.tables.employees
+ with config.db.begin() as conn:
+ conn.execute(
+ employees_table.insert(),
+ [{"name": n, "department": d} for n, d in data],
+ )
def test_basic(self):
+ employees_table = self.tables.employees
s = employees_table.select()
r = s.execute().fetchall()
- assert len(r) == len(data)
+ assert len(r) == len(self.data)
def test_update_rowcount1(self):
+ employees_table = self.tables.employees
+
# WHERE matches 3, 3 rows changed
department = employees_table.c.department
r = employees_table.update(department == "C").execute(department="Z")
assert r.rowcount == 3
- def test_update_rowcount2(self):
+ def test_update_rowcount2(self, connection):
+ employees_table = self.tables.employees
+
# WHERE matches 3, 0 rows changed
department = employees_table.c.department
- r = employees_table.update(department == "C").execute(department="C")
- assert r.rowcount == 3
- @testing.skip_if(
- testing.requires.oracle5x, "unknown DBAPI error fixed in later version"
- )
+ r = connection.execute(
+ employees_table.update(department == "C"), {"department": "C"}
+ )
+ eq_(r.rowcount, 3)
+
@testing.requires.sane_rowcount_w_returning
- def test_update_rowcount_return_defaults(self):
+ def test_update_rowcount_return_defaults(self, connection):
+ employees_table = self.tables.employees
+
department = employees_table.c.department
stmt = (
employees_table.update(department == "C")
.return_defaults()
)
- r = stmt.execute()
- assert r.rowcount == 3
+ r = connection.execute(stmt)
+ eq_(r.rowcount, 3)
- def test_raw_sql_rowcount(self):
+ def test_raw_sql_rowcount(self, connection):
# test issue #3622, make sure eager rowcount is called for text
- with testing.db.connect() as conn:
- result = conn.execute(
- "update employees set department='Z' where department='C'"
- )
- eq_(result.rowcount, 3)
+ result = connection.execute(
+ "update employees set department='Z' where department='C'"
+ )
+ eq_(result.rowcount, 3)
- def test_text_rowcount(self):
+ def test_text_rowcount(self, connection):
# test issue #3622, make sure eager rowcount is called for text
- with testing.db.connect() as conn:
- result = conn.execute(
- text(
- "update employees set department='Z' "
- "where department='C'"
- )
- )
- eq_(result.rowcount, 3)
+ result = connection.execute(
+ text("update employees set department='Z' " "where department='C'")
+ )
+ eq_(result.rowcount, 3)
+
+ def test_delete_rowcount(self, connection):
+ employees_table = self.tables.employees
- def test_delete_rowcount(self):
# WHERE matches 3, 3 rows deleted
department = employees_table.c.department
- r = employees_table.delete(department == "C").execute()
- assert r.rowcount == 3
+ r = connection.execute(employees_table.delete(department == "C"))
+ eq_(r.rowcount, 3)
@testing.requires.sane_multi_rowcount
- def test_multi_update_rowcount(self):
+ def test_multi_update_rowcount(self, connection):
+ employees_table = self.tables.employees
stmt = (
employees_table.update()
.where(employees_table.c.name == bindparam("emp_name"))
.values(department="C")
)
- r = testing.db.execute(
+ r = connection.execute(
stmt,
[
{"emp_name": "Bob"},
eq_(r.rowcount, 2)
@testing.requires.sane_multi_rowcount
- def test_multi_delete_rowcount(self):
+ def test_multi_delete_rowcount(self, connection):
+ employees_table = self.tables.employees
+
stmt = employees_table.delete().where(
employees_table.c.name == bindparam("emp_name")
)
- r = testing.db.execute(
+ r = connection.execute(
stmt,
[
{"emp_name": "Bob"},