]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- mapped classes which extend "object" and do not provide an
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 29 Dec 2007 19:20:38 +0000 (19:20 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 29 Dec 2007 19:20:38 +0000 (19:20 +0000)
__init__() method will now raise TypeError if non-empty *args
or **kwargs are present at instance construction time (and are
not consumed by any extensions such as the scoped_session mapper),
consistent with the behavior of normal Python classes [ticket:908]

CHANGES
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/scoping.py
test/orm/inheritance/basic.py
test/orm/inheritance/manytomany.py
test/orm/mapper.py
test/orm/session.py

diff --git a/CHANGES b/CHANGES
index e8d066fb2ac92901cc8f49ebb454212bc96fbf8c..4f9cf26f9af16e5043b658cf460593ed7ec0cd97 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -107,7 +107,13 @@ CHANGES
      
    - columns which are missing from a Query's select statement
      now get automatically deferred during load.
-   
+
+   - mapped classes which extend "object" and do not provide an 
+     __init__() method will now raise TypeError if non-empty *args 
+     or **kwargs are present at instance construction time (and are 
+     not consumed by any extensions such as the scoped_session mapper), 
+     consistent with the behavior of normal Python classes [ticket:908]
+     
    - fixed Query bug when filter_by() compares a relation against None
      [ticket:899]
      
index f18e54521f6fbef3ca0b0b1339787457f33475dd..09406652a9afe755bf4017bde5f7754d4e993720 100644 (file)
@@ -1121,14 +1121,19 @@ def register_class(class_, extra_init=None, on_exception=None, deferred_scalar_l
         if extra_init:
             extra_init(class_, oldinit, instance, args, kwargs)
 
-        if doinit:
-            try:
+        try:
+            if doinit:
                 oldinit(instance, *args, **kwargs)
-            except:
-                if on_exception:
-                    on_exception(class_, oldinit, instance, args, kwargs)
-                raise
-    
+            elif args or kwargs:
+                # simulate error message raised by object(), but don't copy
+                # the text verbatim
+                raise TypeError("default constructor for object() takes no parameters")
+        except:
+            if on_exception:
+                on_exception(class_, oldinit, instance, args, kwargs)
+            raise
+                
+            
     # override oldinit
     oldinit = class_.__init__
     if oldinit is None or not hasattr(oldinit, '_oldinit'):
index 3f2f2f049faf409264e675b26147a5cd29bc2c5b..19cd44884cf7f4840b728f491e7e03cda95d4333 100644 (file)
@@ -118,15 +118,19 @@ class _ScopedExt(MapperExtension):
             class_.query = query()
         
     def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
+        if self.save_on_init:
+            entity_name = kwargs.pop('_sa_entity_name', None)
+            session = kwargs.pop('_sa_session', None)
         if not isinstance(oldinit, types.MethodType):
             for key, value in kwargs.items():
                 if self.validate:
                     if not mapper.get_property(key, resolve_synonyms=False, raiseerr=False):
                         raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
                 setattr(instance, key, value)
+            kwargs.clear()
         if self.save_on_init:
-            session = kwargs.pop('_sa_session', self.context.registry())
-            session._save_impl(instance, entity_name=kwargs.pop('_sa_entity_name', None))
+            session = session or self.context.registry()
+            session._save_impl(instance, entity_name=entity_name)
         return EXT_CONTINUE
 
     def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
index 2ef76b6d8da762c69c017b538801963d35a5b7b4..f9c98ac1ce00f03861859129b6159e8995284212 100644 (file)
@@ -348,7 +348,7 @@ class FlushTest(ORMTest):
         )
         admin_mapper = mapper(Admin, admins, inherits=user_mapper)
         sess = create_session()
-        adminrole = Role('admin')
+        adminrole = Role()
         sess.save(adminrole)
         sess.flush()
 
index 7886e90ad1f5ea45bd9b1b47979cee56da270f58..d28ce8adaa6fd5e9ff062952a1a751750cc93446 100644 (file)
@@ -83,9 +83,9 @@ class InheritTest2(ORMTest):
             Column('bar_id', Integer, ForeignKey('bar.bid')))
 
     def testget(self):
-        class Foo(object):pass
-        def __init__(self, data=None):
-            self.data = data
+        class Foo(object):
+            def __init__(self, data=None):
+                self.data = data
         class Bar(Foo):pass
 
         mapper(Foo, foo)
@@ -128,7 +128,7 @@ class InheritTest2(ORMTest):
         sess.flush()
         sess.clear()
 
-        l = sess.query(Bar).select()
+        l = sess.query(Bar).all()
         print l[0]
         print l[0].foos
         self.assert_unordered_result(l, Bar,
@@ -191,7 +191,7 @@ class InheritTest3(ORMTest):
         sess.flush()
         compare = repr(b) + repr(sorted([repr(o) for o in b.foos]))
         sess.clear()
-        l = sess.query(Bar).select()
+        l = sess.query(Bar).all()
         print repr(l[0]) + repr(l[0].foos)
         found = repr(l[0]) + repr(sorted([repr(o) for o in l[0].foos]))
         self.assertEqual(found, compare)
@@ -233,11 +233,11 @@ class InheritTest3(ORMTest):
         blubid = bl1.id
         sess.clear()
 
-        l = sess.query(Blub).select()
+        l = sess.query(Blub).all()
         print l
         self.assert_(repr(l[0]) == compare)
         sess.clear()
-        x = sess.query(Blub).get_by(id=blubid)
+        x = sess.query(Blub).filter_by(id=blubid).one()
         print x
         self.assert_(repr(x) == compare)
 
index e1aca134579f7e525a1f889c29be174b22ab73dc..28885356cbd914df7a99ff514564df5933d89be8 100644 (file)
@@ -140,6 +140,30 @@ class MapperTest(MapperSuperTest):
         except Exception, e:
             assert e is ex
 
+        clear_mappers()
+        
+        # test that TypeError is raised for illegal constructor args,
+        # whether or not explicit __init__ is present [ticket:908]
+        class Foo(object):
+            def __init__(self):
+                pass
+        class Bar(object):
+            pass
+                
+        mapper(Foo, users)
+        mapper(Bar, addresses)
+        try:
+            Foo(x=5)
+            assert False
+        except TypeError:
+            assert True
+
+        try:
+            Bar(x=5)
+            assert False
+        except TypeError:
+            assert True
+
     def test_props(self):
         m = mapper(User, users, properties = {
             'addresses' : relation(mapper(Address, addresses))
@@ -1247,7 +1271,7 @@ class MapperExtensionTest(PersistTest):
         
         sess = create_session()
         i1 = Item()
-        k1 = Keyword('blue')
+        k1 = Keyword()
         sess.save(i1)
         sess.save(k1)
         sess.flush()
index db9245c7284ac93261e224396dff71b010277b56..8baf752754b9188b033046363f761f58238aeac9 100644 (file)
@@ -894,7 +894,7 @@ class ScopedMapperTest(PersistTest):
             pass
         Session.mapper(Baz, table2, extension=ext)
         assert hasattr(Baz, 'query')
-
+    
     def test_validating_constructor(self):
         s2 = SomeObject(someid=12)
         s3 = SomeOtherObject(someid=123, bogus=345)