From c136b7a6e9564ce9325ea898a483d9c613261909 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 1 Dec 2008 22:09:15 +0000 Subject: [PATCH] - Improved mapper() check for non-class classes. [ticket:1236] --- CHANGES | 3 +++ lib/sqlalchemy/orm/mapper.py | 6 +++--- test/orm/mapper.py | 2 ++ test/testlib/orm.py | 2 +- 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/CHANGES b/CHANGES index f562f2cdea..00aac79d87 100644 --- a/CHANGES +++ b/CHANGES @@ -34,6 +34,9 @@ CHANGES primaryjoin/secondaryjoin are ClauseElement instances, to prevent more confusing errors later on. + + - Improved mapper() check for non-class classes. + [ticket:1236] - comparator_factory argument is now documented and supported by all MapperProperty types, diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index a6b49ab6b6..b5f43b12e9 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -101,7 +101,9 @@ class Mapper(object): function. See for details. """ - self.class_ = class_ + + self.class_ = util.assert_arg_type(class_, type, 'class_') + self.class_manager = None self.primary_key_argument = primary_key @@ -134,8 +136,6 @@ class Mapper(object): self._requires_row_aliasing = False self._inherits_equated_pairs = None - if not issubclass(class_, object): - raise sa_exc.ArgumentError("Class '%s' is not a new-style class" % class_.__name__) self.select_table = select_table if select_table: diff --git a/test/orm/mapper.py b/test/orm/mapper.py index 8cf4fec06e..b18d8ac001 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -2104,6 +2104,8 @@ class RequirementsTest(_base.MappedTest): self.assertRaises(sa.exc.ArgumentError, mapper, OldStyle, ht1) + self.assertRaises(sa.exc.ArgumentError, mapper, 123) + class NoWeakrefSupport(str): pass diff --git a/test/testlib/orm.py b/test/testlib/orm.py index b460102a61..35469edcae 100644 --- a/test/testlib/orm.py +++ b/test/testlib/orm.py @@ -109,7 +109,7 @@ def mapper(type_, *args, **kw): ('__gt__', 'noncomparable', lambda s, o: object.__gt__(s, o)), ('__nonzero__', 'truthless', lambda s: 1), ] - if type_.__bases__ == (object,): + if isinstance(type_, type) and type_.__bases__ == (object,): for method_name, option, fallback in forbidden: if (getattr(config.options, option, False) and method_name not in type_.__dict__): -- 2.47.3