]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Improved mapper() check for non-class classes.
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 1 Dec 2008 22:09:15 +0000 (22:09 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 1 Dec 2008 22:09:15 +0000 (22:09 +0000)
[ticket:1236]

CHANGES
lib/sqlalchemy/orm/mapper.py
test/orm/mapper.py
test/testlib/orm.py

diff --git a/CHANGES b/CHANGES
index f562f2cdea647533f5b56a8dd0b9f3aa2aa50da7..00aac79d87b0f3fb3421b21617a6f178f0082485 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -34,6 +34,9 @@ CHANGES
       primaryjoin/secondaryjoin are ClauseElement 
       instances, to prevent more confusing errors later 
       on.
+
+    - Improved mapper() check for non-class classes.
+      [ticket:1236]
     
     - comparator_factory argument is now documented
       and supported by all MapperProperty types,
index a6b49ab6b65a8199ba9406d5ee4f7ef13a56450d..b5f43b12e96d06c20f03ef028e5b63a9cd7cc746 100644 (file)
@@ -101,7 +101,9 @@ class Mapper(object):
         function.  See for details.
 
         """
-        self.class_ = class_
+
+        self.class_ = util.assert_arg_type(class_, type, 'class_')
+
         self.class_manager = None
 
         self.primary_key_argument = primary_key
@@ -134,8 +136,6 @@ class Mapper(object):
         self._requires_row_aliasing = False
         self._inherits_equated_pairs = None
 
-        if not issubclass(class_, object):
-            raise sa_exc.ArgumentError("Class '%s' is not a new-style class" % class_.__name__)
 
         self.select_table = select_table
         if select_table:
index 8cf4fec06e6b229b2e868730c50a26a16c259728..b18d8ac0015f1f963d9bed084dbe9eeaa0da3d4c 100644 (file)
@@ -2104,6 +2104,8 @@ class RequirementsTest(_base.MappedTest):
 
         self.assertRaises(sa.exc.ArgumentError, mapper, OldStyle, ht1)
 
+        self.assertRaises(sa.exc.ArgumentError, mapper, 123)
+        
         class NoWeakrefSupport(str):
             pass
 
index b460102a615b6a4c16500da2a01fe294ca3d8f5e..35469edcaeb59e6f1fbd94772177f5ba3a9109e6 100644 (file)
@@ -109,7 +109,7 @@ def mapper(type_, *args, **kw):
         ('__gt__', 'noncomparable', lambda s, o: object.__gt__(s, o)),
         ('__nonzero__', 'truthless', lambda s: 1), ]
 
-    if type_.__bases__ == (object,):
+    if isinstance(type_, type) and type_.__bases__ == (object,):
         for method_name, option, fallback in forbidden:
             if (getattr(config.options, option, False) and
                 method_name not in type_.__dict__):