from __future__ import annotations
+from collections.abc import Callable
+from collections.abc import Sequence
+
import sqlalchemy as sa
from .. import assertions
from .. import config
@config.mark_base_test_class()
class TestBase:
# A sequence of requirement names matching testing.requires decorators
- __requires__ = ()
+ __requires__: tuple[str, ...] = ()
# A sequence of dialect names to exclude from the test class.
- __unsupported_on__ = ()
+ __unsupported_on__: tuple[str, ...] = ()
# If present, test class is only runnable for the *single* specified
# dialect. If you need multiple, use __unsupported_on__ and invert.
- __only_on__ = None
+ __only_on__: tuple[str, ...] | str | None = None
# A sequence of no-arg callables. If any are True, the entire testcase is
# skipped.
- __skip_if__ = None
+ __skip_if__: Sequence[Callable[[], bool]] | None = None
# if True, the testing reaper will not attempt to touch connection
# state after a test is completed and before the outer teardown
# starts
- __leave_connections_for_teardown__ = False
+ __leave_connections_for_teardown__: bool = False
+
+ __backend__: bool
def assert_(self, val, msg=None):
assert val, msg
from sqlalchemy.testing.assertions import eq_
from sqlalchemy.testing.assertions import eq_regex
from sqlalchemy.testing.assertions import expect_raises
+from sqlalchemy.testing.assertions import in_
from sqlalchemy.testing.assertions import ne_
+from sqlalchemy.testing.assertions import not_in
class DialectTest(fixtures.TestBase):
is_true(isinstance(cursor, AsyncClientCursor))
await engine.dispose()
+
+
+class TwoPhaseCommitTest(fixtures.TestBase):
+ __only_on__ = ("+psycopg2", "+psycopg")
+ __backend__ = True
+
+ @testing.fixture(autouse=True)
+ def reap_xid(self):
+ with config.db.connect() as connection:
+ before = connection.recover_twophase()
+ yield
+ with config.db.connect() as connection:
+ for xid in connection.recover_twophase():
+ if xid not in before:
+ connection.rollback_prepared(xid, recover=True)
+
+ @testing.variation("mode", ["noid", "withid", "driverid"])
+ @testing.variation("commit", [True, False])
+ def test_provided_id_round_trip(self, mode: testing.Variation, commit):
+ c1 = config.db.connect()
+ dc = c1.connection.driver_connection
+ c2 = config.db.connect()
+ if mode.noid:
+ transaction = c1.begin_twophase()
+ xid = transaction.xid
+ elif mode.withid:
+ xid = "myid"
+ transaction = c1.begin_twophase(xid)
+ eq_(transaction.xid, "myid")
+ elif mode.driverid:
+ xid_obj = dc.xid(42, "abc", "def")
+ xid = str(xid_obj)
+ transaction = c1.begin_twophase(xid_obj)
+ eq_(transaction.xid, xid_obj)
+ else:
+ mode.fail()
+ transaction.prepare()
+ in_(xid, c2.recover_twophase())
+ if commit:
+ c2.commit_prepared(xid, recover=True)
+ else:
+ c2.rollback_prepared(xid, recover=True)
+ not_in(xid, c2.recover_twophase())
+ c2.close()
+ c1.detach()
+ dc.close()
+
+ @testing.variation("commit", [True, False])
+ def test_default_pg_dialect(self, commit):
+ dialect = postgresql.PGDialect
+ c1 = config.db.connect()
+ dc = c1.connection.driver_connection
+ c2 = config.db.connect()
+ c2.execution_options(isolation_level="AUTOCOMMIT")
+ xid = "myid"
+ dialect.do_begin_twophase(c1.dialect, c1, xid)
+ dialect.do_prepare_twophase(c1.dialect, c1, xid)
+
+ in_(xid, dialect.do_recover_twophase(c2.dialect, c2))
+ if commit:
+ dialect.do_commit_twophase(c2.dialect, c2, xid, recover=True)
+ else:
+ dialect.do_rollback_twophase(c2.dialect, c2, xid, recover=True)
+ not_in(xid, dialect.do_recover_twophase(c2.dialect, c2))
+ c2.close()
+ c1.detach()
+ dc.close()