]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Added __autoload__ = True for declarative
authorJason Kirtland <jek@discorporate.us>
Mon, 10 Mar 2008 18:32:07 +0000 (18:32 +0000)
committerJason Kirtland <jek@discorporate.us>
Mon, 10 Mar 2008 18:32:07 +0000 (18:32 +0000)
- declarative Base.__init__ is pickier about its kwargs

lib/sqlalchemy/ext/declarative.py
test/ext/declarative.py

index 77bffad7ac194f036c0a65dca5c6adeac6ba44fe..fbb68f579713374f7fbf5e0bbca778f72fe1e0b2 100644 (file)
@@ -170,9 +170,16 @@ class DeclarativeMeta(type):
         if '__table__' not in cls.__dict__:
             if '__tablename__' in cls.__dict__:
                 tablename = cls.__tablename__
+                autoload = cls.__dict__.get('__autoload__')
+                if autoload is True:
+                    table_kw = {'autoload': True}
+                elif autoload:
+                    table_kw = {'autoload': True, 'autoload_with': autoload}
+                else:
+                    table_kw = {}
                 cls.__table__ = table = Table(tablename, cls.metadata, *[
                     c for c in our_stuff.values() if isinstance(c, Column)
-                ])
+                ], **table_kw)
         else:
             table = cls.__table__
         
@@ -234,6 +241,9 @@ def declarative_base(engine=None, metadata=None):
         _decl_class_registry = {}
         def __init__(self, **kwargs):
             for k in kwargs:
+                if not hasattr(type(self), k):
+                    raise TypeError('%r is an invalid keyword argument for %s' %
+                                    (k, type(self).__name__))
                 setattr(self, k, kwargs[k])
     return Base
 
index 73d2578852a9961dcf0ce589941f3c1ca331959d..cef34b994641f027dadfad3d7a2c78bbd2161612 100644 (file)
@@ -216,25 +216,112 @@ class DeclarativeTest(TestBase):
         
         self.assertEquals(sess.query(Person).filter(Engineer.primary_language=='cobol').first(), Engineer(name='vlad'))
         self.assertEquals(sess.query(Company).filter(Company.employees.of_type(Engineer).any(Engineer.primary_language=='cobol')).first(), c2)
-    
-    def test_reflection(self):
+
+    def test_with_explicit_autoloaded(self):
         meta = MetaData(testing.db)
         t1 = Table('t1', meta, Column('id', String(50), primary_key=True), Column('data', String(50)))
         meta.create_all()
         try:
             class MyObj(Base):
                 __table__ = Table('t1', Base.metadata, autoload=True)
-            
+
             sess = create_session()
             m = MyObj(id="someid", data="somedata")
             sess.save(m)
             sess.flush()
-            
+
             assert t1.select().execute().fetchall() == [('someid', 'somedata')]
-            
         finally:
             meta.drop_all()
-        
-        
+
+
+class DeclarativeReflectionTest(TestBase):
+    def setUpAll(self):
+        global reflection_metadata
+        reflection_metadata = MetaData(testing.db)
+
+        Table('users', reflection_metadata,
+              Column('id', Integer, primary_key=True),
+              Column('name', String(50)))
+        Table('addresses', reflection_metadata,
+              Column('id', Integer, primary_key=True),
+              Column('email', String(50)),
+              Column('user_id', Integer, ForeignKey('users.id')))
+        reflection_metadata.create_all()
+
+    def setUp(self):
+        global Base
+        Base = declarative_base(testing.db)
+
+    def tearDown(self):
+        for t in reflection_metadata.table_iterator():
+            t.delete().execute()
+
+    def tearDownAll(self):
+        reflection_metadata.drop_all()
+
+    def test_basic(self):
+        meta = MetaData(testing.db)
+
+        class User(Base, Fixture):
+            __tablename__ = 'users'
+            __autoload__ = True
+            addresses = relation("Address", backref="user")
+
+        class Address(Base, Fixture):
+            __tablename__ = 'addresses'
+            __autoload__ = True
+
+        u1 = User(name='u1', addresses=[
+            Address(email='one'),
+            Address(email='two'),
+            ])
+        sess = create_session()
+        sess.save(u1)
+        sess.flush()
+        sess.clear()
+
+        self.assertEquals(sess.query(User).all(), [User(name='u1', addresses=[
+            Address(email='one'),
+            Address(email='two'),
+            ])])
+
+        a1 = sess.query(Address).filter(Address.email=='two').one()
+        self.assertEquals(a1, Address(email='two'))
+        self.assertEquals(a1.user, User(name='u1'))
+
+    def test_rekey(self):
+        meta = MetaData(testing.db)
+
+        class User(Base, Fixture):
+            __tablename__ = 'users'
+            __autoload__ = True
+            nom = Column('name', String(50), key='nom')
+            addresses = relation("Address", backref="user")
+
+        class Address(Base, Fixture):
+            __tablename__ = 'addresses'
+            __autoload__ = True
+
+        u1 = User(nom='u1', addresses=[
+            Address(email='one'),
+            Address(email='two'),
+            ])
+        sess = create_session()
+        sess.save(u1)
+        sess.flush()
+        sess.clear()
+
+        self.assertEquals(sess.query(User).all(), [User(nom='u1', addresses=[
+            Address(email='one'),
+            Address(email='two'),
+            ])])
+
+        a1 = sess.query(Address).filter(Address.email=='two').one()
+        self.assertEquals(a1, Address(email='two'))
+        self.assertEquals(a1.user, User(nom='u1'))
+
+        self.assertRaises(TypeError, User, name='u3')
+
 if __name__ == '__main__':
-    testing.main()
\ No newline at end of file
+    testing.main()