]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
[ticket:1893] implementation
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 16 Nov 2010 00:05:48 +0000 (19:05 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 16 Nov 2010 00:05:48 +0000 (19:05 -0500)
lib/sqlalchemy/ext/declarative.py
lib/sqlalchemy/schema.py
test/engine/test_metadata.py
test/ext/test_declarative.py

index dd2df63d30265021025c1caf29f1fd97552349d9..53438dc55444254fd4aa96f6b93b0d516cb2c7e2 100755 (executable)
@@ -924,7 +924,7 @@ Mapped instances then make usage of
 
 """
 
-from sqlalchemy.schema import Table, Column, MetaData
+from sqlalchemy.schema import Table, Column, MetaData, _get_table_key
 from sqlalchemy.orm import synonym as _orm_synonym, mapper,\
                                 comparable_property, class_mapper
 from sqlalchemy.orm.interfaces import MapperProperty
@@ -1257,8 +1257,8 @@ class DeclarativeMeta(type):
 class _GetColumns(object):
     def __init__(self, cls):
         self.cls = cls
-    def __getattr__(self, key):
         
+    def __getattr__(self, key):
         mapper = class_mapper(self.cls, compile=False)
         if mapper:
             prop = mapper.get_property(key, raiseerr=False)
@@ -1273,7 +1273,16 @@ class _GetColumns(object):
                             " directly to a Column)." % key)
         return getattr(self.cls, key)
 
-
+class _GetTable(object):
+    def __init__(self, key, metadata):
+        self.key = key
+        self.metadata = metadata
+    
+    def __getattr__(self, key):
+        return self.metadata.tables[
+                _get_table_key(key, self.key)
+            ]
+        
 def _deferred_relationship(cls, prop):
     def resolve_arg(arg):
         import sqlalchemy
@@ -1283,6 +1292,8 @@ def _deferred_relationship(cls, prop):
                 return _GetColumns(cls._decl_class_registry[key])
             elif key in cls.metadata.tables:
                 return cls.metadata.tables[key]
+            elif key in cls.metadata._schemas:
+                return _GetTable(key, cls.metadata)
             else:
                 return sqlalchemy.__dict__[key]
 
index a332cec36154300658e4cd08d39cf2cbb935db09..cb0cde4df2e091a8e954556ad90faeb104e0f869 100644 (file)
@@ -206,12 +206,13 @@ class Table(SchemaItem, expression.TableClause):
             if mustexist:
                 raise exc.InvalidRequestError(
                     "Table '%s' not defined" % (key))
-            metadata.tables[key] = table = object.__new__(cls)
+            table = object.__new__(cls)                    
+            metadata._add_table(name, schema, table)
             try:
                 table._init(name, metadata, *args, **kw)
                 return table
             except:
-                metadata.tables.pop(key)
+                metadata._remove_table(name, schema)
                 raise
                 
     def __init__(self, *args, **kw):
@@ -406,7 +407,7 @@ class Table(SchemaItem, expression.TableClause):
         self.ddl_listeners[event].append(listener)
 
     def _set_parent(self, metadata):
-        metadata.tables[_get_table_key(self.name, self.schema)] = self
+        metadata._add_table(self.name, self.schema, self)
         self.metadata = metadata
 
     def get_children(self, column_collections=True, 
@@ -1938,7 +1939,8 @@ class MetaData(SchemaItem):
           ``MetaData``.
 
         """
-        self.tables = {}
+        self.tables = util.frozendict()
+        self._schemas = set()
         self.bind = bind
         self.metadata = self
         self.ddl_listeners = util.defaultdict(list)
@@ -1957,6 +1959,20 @@ class MetaData(SchemaItem):
             table_or_key = table_or_key.key
         return table_or_key in self.tables
 
+    def _add_table(self, name, schema, table):
+        key = _get_table_key(name, schema)
+        dict.__setitem__(self.tables, key, table)
+        if schema:
+            self._schemas.add(schema)
+    
+    def _remove_table(self, name, schema):
+        key = _get_table_key(name, schema)
+        dict.pop(self.tables, key, None)
+        if self._schemas:
+            self._schemas = set([t.schema 
+                                for t in self.tables.values() 
+                                if t.schema is not None])
+        
     def __getstate__(self):
         return {'tables': self.tables}
 
@@ -1991,15 +2007,14 @@ class MetaData(SchemaItem):
 
     def clear(self):
         """Clear all Table objects from this MetaData."""
-        # TODO: why have clear()/remove() but not all
-        # other accesors/mutators for the tables dict ?
-        self.tables.clear()
 
+        dict.clear(self.tables)
+        self._schemas.clear()
+        
     def remove(self, table):
         """Remove the given Table object from this MetaData."""
         
-        # TODO: scan all other tables and remove FK _column
-        del self.tables[table.key]
+        self._remove_table(table.name, table.schema)
 
     @property
     def sorted_tables(self):
index b2250c808f38d5dd9c813828077d73a24ef776ba..59aa4c354311a6c26931d76682c05ec9828d3406 100644 (file)
@@ -77,6 +77,51 @@ class MetaDataTest(TestBase, ComparesTables):
             t = Table('foo%d' % i, m, cx)
         eq_(msgs, ['attach foo0.foo', 'attach foo1.foo', 'attach foo2.foo'])
         
+    def test_schema_collection_add(self):
+        metadata = MetaData()
+        
+        t1 = Table('t1', metadata, Column('x', Integer), schema='foo')
+        t2 = Table('t2', metadata, Column('x', Integer), schema='bar')
+        t3 = Table('t3', metadata, Column('x', Integer))
+        
+        eq_(metadata._schemas, set(['foo', 'bar']))
+        eq_(len(metadata.tables), 3)
+    
+    def test_schema_collection_remove(self):
+        metadata = MetaData()
+        
+        t1 = Table('t1', metadata, Column('x', Integer), schema='foo')
+        t2 = Table('t2', metadata, Column('x', Integer), schema='bar')
+        t3 = Table('t3', metadata, Column('x', Integer), schema='bar')
+        
+        metadata.remove(t3)
+        eq_(metadata._schemas, set(['foo', 'bar']))
+        eq_(len(metadata.tables), 2)
+
+        metadata.remove(t1)
+        eq_(metadata._schemas, set(['bar']))
+        eq_(len(metadata.tables), 1)
+    
+    def test_schema_collection_remove_all(self):
+        metadata = MetaData()
+        
+        t1 = Table('t1', metadata, Column('x', Integer), schema='foo')
+        t2 = Table('t2', metadata, Column('x', Integer), schema='bar')
+
+        metadata.clear()
+        eq_(metadata._schemas, set())
+        eq_(len(metadata.tables), 0)
+    
+    def test_metadata_tables_immutable(self):
+        metadata = MetaData()
+        
+        t1 = Table('t1', metadata, Column('x', Integer))
+        assert 't1' in metadata.tables
+        
+        assert_raises(
+            AttributeError,
+            lambda: metadata.tables.pop('t1')
+        )
         
     def test_dupe_tables(self):
         metadata = MetaData()
index 72e2edf30e9a8a9f336ee98263185447f3feea02..64dc81496c82720cb6489f381d5eb860b358a624 100644 (file)
@@ -299,6 +299,38 @@ class DeclarativeTest(DeclarativeTestBase):
         assert class_mapper(User).get_property('props').secondary \
             is user_to_prop
 
+    def test_string_dependency_resolution_schemas(self):
+        Base = decl.declarative_base()
+        
+        class User(Base):
+
+            __tablename__ = 'users'
+            __table_args__ = {'schema':'fooschema'}
+            
+            id = Column(Integer, primary_key=True)
+            name = Column(String(50))
+            props = relationship('Prop', secondary='fooschema.user_to_prop',
+                         primaryjoin='User.id==fooschema.user_to_prop.c.user_id',
+                         secondaryjoin='fooschema.user_to_prop.c.prop_id==Prop.id', 
+                         backref='users')
+
+        class Prop(Base):
+
+            __tablename__ = 'props'
+            __table_args__ = {'schema':'fooschema'}
+
+            id = Column(Integer, primary_key=True)
+            name = Column(String(50))
+
+        user_to_prop = Table('user_to_prop', Base.metadata,
+                     Column('user_id', Integer, ForeignKey('fooschema.users.id')), 
+                     Column('prop_id',Integer, ForeignKey('fooschema.props.id')),
+                     schema='fooschema')
+        compile_mappers()
+        
+        assert class_mapper(User).get_property('props').secondary \
+            is user_to_prop
+
     def test_uncompiled_attributes_in_relationship(self):
 
         class Address(Base, ComparableEntity):