from unittest.mock import Mock
from sqlalchemy import exc
+from sqlalchemy import testing
from sqlalchemy.orm import collections
from sqlalchemy.orm import relationship
from sqlalchemy.orm import validates
users,
)
- def test_validator_wo_backrefs_wo_removes(self):
- self._test_validator_backrefs(False, False)
-
- def test_validator_wo_backrefs_w_removes(self):
- self._test_validator_backrefs(False, True)
-
- def test_validator_w_backrefs_wo_removes(self):
- self._test_validator_backrefs(True, False)
-
- def test_validator_w_backrefs_w_removes(self):
- self._test_validator_backrefs(True, True)
-
- def _test_validator_backrefs(self, include_backrefs, include_removes):
+ @testing.variation("include_backrefs", [True, False, "default"])
+ @testing.variation("include_removes", [True, False, "default"])
+ def test_validator_backrefs(self, include_backrefs, include_removes):
users, addresses = (self.tables.users, self.tables.addresses)
canary = Mock()
+ need_remove_param = (
+ bool(include_removes) and not include_removes.default
+ )
+ validate_kw = {}
+ if not include_removes.default:
+ validate_kw["include_removes"] = bool(include_removes)
+ if not include_backrefs.default:
+ validate_kw["include_backrefs"] = bool(include_backrefs)
+
+ expect_include_backrefs = include_backrefs.default or bool(
+ include_backrefs
+ )
+ expect_include_removes = (
+ bool(include_removes) and not include_removes.default
+ )
+
class User(fixtures.ComparableEntity):
- if include_removes:
+ if need_remove_param:
- @validates(
- "addresses",
- include_removes=True,
- include_backrefs=include_backrefs,
- )
+ @validates("addresses", **validate_kw)
def validate_address(self, key, item, remove):
canary(key, item, remove)
return item
else:
- @validates(
- "addresses",
- include_removes=False,
- include_backrefs=include_backrefs,
- )
+ @validates("addresses", **validate_kw)
def validate_address(self, key, item):
canary(key, item)
return item
class Address(fixtures.ComparableEntity):
- if include_removes:
- @validates(
- "user",
- include_backrefs=include_backrefs,
- include_removes=True,
- )
+ if need_remove_param:
+
+ @validates("user", **validate_kw)
def validate_user(self, key, item, remove):
canary(key, item, remove)
return item
else:
- @validates("user", include_backrefs=include_backrefs)
+ @validates("user", **validate_kw)
def validate_user(self, key, item):
canary(key, item)
return item
# comparisons don't get caught
calls = list(canary.mock_calls)
- if include_backrefs:
- if include_removes:
+ if expect_include_backrefs:
+ if expect_include_removes:
eq_(
calls,
[
],
)
else:
- if include_removes:
+ if expect_include_removes:
eq_(
calls,
[