From: Jason Kirtland Date: Mon, 10 Mar 2008 18:32:07 +0000 (+0000) Subject: - Added __autoload__ = True for declarative X-Git-Tag: rel_0_4_4~14 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ac13a8445b60e5c4373bcba301a3df16cf743311;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Added __autoload__ = True for declarative - declarative Base.__init__ is pickier about its kwargs --- diff --git a/lib/sqlalchemy/ext/declarative.py b/lib/sqlalchemy/ext/declarative.py index 77bffad7ac..fbb68f5797 100644 --- a/lib/sqlalchemy/ext/declarative.py +++ b/lib/sqlalchemy/ext/declarative.py @@ -170,9 +170,16 @@ class DeclarativeMeta(type): if '__table__' not in cls.__dict__: if '__tablename__' in cls.__dict__: tablename = cls.__tablename__ + autoload = cls.__dict__.get('__autoload__') + if autoload is True: + table_kw = {'autoload': True} + elif autoload: + table_kw = {'autoload': True, 'autoload_with': autoload} + else: + table_kw = {} cls.__table__ = table = Table(tablename, cls.metadata, *[ c for c in our_stuff.values() if isinstance(c, Column) - ]) + ], **table_kw) else: table = cls.__table__ @@ -234,6 +241,9 @@ def declarative_base(engine=None, metadata=None): _decl_class_registry = {} def __init__(self, **kwargs): for k in kwargs: + if not hasattr(type(self), k): + raise TypeError('%r is an invalid keyword argument for %s' % + (k, type(self).__name__)) setattr(self, k, kwargs[k]) return Base diff --git a/test/ext/declarative.py b/test/ext/declarative.py index 73d2578852..cef34b9946 100644 --- a/test/ext/declarative.py +++ b/test/ext/declarative.py @@ -216,25 +216,112 @@ class DeclarativeTest(TestBase): self.assertEquals(sess.query(Person).filter(Engineer.primary_language=='cobol').first(), Engineer(name='vlad')) self.assertEquals(sess.query(Company).filter(Company.employees.of_type(Engineer).any(Engineer.primary_language=='cobol')).first(), c2) - - def test_reflection(self): + + def test_with_explicit_autoloaded(self): meta = MetaData(testing.db) t1 = Table('t1', meta, Column('id', String(50), primary_key=True), Column('data', String(50))) meta.create_all() try: class MyObj(Base): __table__ = Table('t1', Base.metadata, autoload=True) - + sess = create_session() m = MyObj(id="someid", data="somedata") sess.save(m) sess.flush() - + assert t1.select().execute().fetchall() == [('someid', 'somedata')] - finally: meta.drop_all() - - + + +class DeclarativeReflectionTest(TestBase): + def setUpAll(self): + global reflection_metadata + reflection_metadata = MetaData(testing.db) + + Table('users', reflection_metadata, + Column('id', Integer, primary_key=True), + Column('name', String(50))) + Table('addresses', reflection_metadata, + Column('id', Integer, primary_key=True), + Column('email', String(50)), + Column('user_id', Integer, ForeignKey('users.id'))) + reflection_metadata.create_all() + + def setUp(self): + global Base + Base = declarative_base(testing.db) + + def tearDown(self): + for t in reflection_metadata.table_iterator(): + t.delete().execute() + + def tearDownAll(self): + reflection_metadata.drop_all() + + def test_basic(self): + meta = MetaData(testing.db) + + class User(Base, Fixture): + __tablename__ = 'users' + __autoload__ = True + addresses = relation("Address", backref="user") + + class Address(Base, Fixture): + __tablename__ = 'addresses' + __autoload__ = True + + u1 = User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ]) + sess = create_session() + sess.save(u1) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(User).all(), [User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ])]) + + a1 = sess.query(Address).filter(Address.email=='two').one() + self.assertEquals(a1, Address(email='two')) + self.assertEquals(a1.user, User(name='u1')) + + def test_rekey(self): + meta = MetaData(testing.db) + + class User(Base, Fixture): + __tablename__ = 'users' + __autoload__ = True + nom = Column('name', String(50), key='nom') + addresses = relation("Address", backref="user") + + class Address(Base, Fixture): + __tablename__ = 'addresses' + __autoload__ = True + + u1 = User(nom='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ]) + sess = create_session() + sess.save(u1) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(User).all(), [User(nom='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ])]) + + a1 = sess.query(Address).filter(Address.email=='two').one() + self.assertEquals(a1, Address(email='two')) + self.assertEquals(a1.user, User(nom='u1')) + + self.assertRaises(TypeError, User, name='u3') + if __name__ == '__main__': - testing.main() \ No newline at end of file + testing.main()