From: Mike Bayer Date: Mon, 1 Dec 2008 22:09:15 +0000 (+0000) Subject: - Improved mapper() check for non-class classes. X-Git-Tag: rel_0_5_0~151 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c136b7a6e9564ce9325ea898a483d9c613261909;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Improved mapper() check for non-class classes. [ticket:1236] --- diff --git a/CHANGES b/CHANGES index f562f2cdea..00aac79d87 100644 --- a/CHANGES +++ b/CHANGES @@ -34,6 +34,9 @@ CHANGES primaryjoin/secondaryjoin are ClauseElement instances, to prevent more confusing errors later on. + + - Improved mapper() check for non-class classes. + [ticket:1236] - comparator_factory argument is now documented and supported by all MapperProperty types, diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index a6b49ab6b6..b5f43b12e9 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -101,7 +101,9 @@ class Mapper(object): function. See for details. """ - self.class_ = class_ + + self.class_ = util.assert_arg_type(class_, type, 'class_') + self.class_manager = None self.primary_key_argument = primary_key @@ -134,8 +136,6 @@ class Mapper(object): self._requires_row_aliasing = False self._inherits_equated_pairs = None - if not issubclass(class_, object): - raise sa_exc.ArgumentError("Class '%s' is not a new-style class" % class_.__name__) self.select_table = select_table if select_table: diff --git a/test/orm/mapper.py b/test/orm/mapper.py index 8cf4fec06e..b18d8ac001 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -2104,6 +2104,8 @@ class RequirementsTest(_base.MappedTest): self.assertRaises(sa.exc.ArgumentError, mapper, OldStyle, ht1) + self.assertRaises(sa.exc.ArgumentError, mapper, 123) + class NoWeakrefSupport(str): pass diff --git a/test/testlib/orm.py b/test/testlib/orm.py index b460102a61..35469edcae 100644 --- a/test/testlib/orm.py +++ b/test/testlib/orm.py @@ -109,7 +109,7 @@ def mapper(type_, *args, **kw): ('__gt__', 'noncomparable', lambda s, o: object.__gt__(s, o)), ('__nonzero__', 'truthless', lambda s: 1), ] - if type_.__bases__ == (object,): + if isinstance(type_, type) and type_.__bases__ == (object,): for method_name, option, fallback in forbidden: if (getattr(config.options, option, False) and method_name not in type_.__dict__):