]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- the "foreign_key" attribute on Column and ColumnElement in general
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 8 Oct 2006 02:46:40 +0000 (02:46 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 8 Oct 2006 02:46:40 +0000 (02:46 +0000)
    is deprecated, in favor of the "foreign_keys" list/set-based attribute,
    which takes into account multiple foreign keys on one column.
    "foreign_key" will return the first element in the "foreign_keys" list/set
    or None if the list is empty.
- added a user test to the relationships test, testing various new things this
change allows

CHANGES
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
test/engine/reflection.py
test/orm/relationships.py

diff --git a/CHANGES b/CHANGES
index 3abebef666c583dcf652367dab96b0e635858176..06268a1f4b68579d829343df54d787a79351a1ea 100644 (file)
--- a/CHANGES
+++ b/CHANGES
     - fixed condition that occurred during reflection when a primary key
     column was explciitly overridden, where the PrimaryKeyConstraint would
     get both the reflected and the programmatic column doubled up
+    - the "foreign_key" attribute on Column and ColumnElement in general
+    is deprecated, in favor of the "foreign_keys" list/set-based attribute,
+    which takes into account multiple foreign keys on one column. 
+    "foreign_key" will return the first element in the "foreign_keys" list/set
+    or None if the list is empty.
 - Connections/Pooling/Execution:
     - connection pool tracks open cursors and automatically closes them
     if connection is returned to pool with cursors still opened.  Can be
index 2c29bbe2ad17bc8826497f668cd94758cff89010..88efcb755042a1b4d72690a83266d3c5fdb8fdfe 100644 (file)
@@ -444,7 +444,7 @@ class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
         if not column.nullable:
             colspec += " NOT NULL"
         if column.primary_key:
-            if not column.foreign_key and first_pk and column.autoincrement and isinstance(column.type, sqltypes.Integer):
+            if len(column.foreign_keys)==0 and first_pk and column.autoincrement and isinstance(column.type, sqltypes.Integer):
                 colspec += " AUTO_INCREMENT"
         return colspec
 
index 6fe51ad9afd19fe902f0086a2bb7ba16e696fd1e..e052fe8c00ed89fbe90712f86f480ce0fee42ca7 100644 (file)
@@ -490,7 +490,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
         
     def get_column_specification(self, column, **kwargs):
         colspec = self.preparer.format_column(column)
-        if column.primary_key and not column.foreign_key and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
+        if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
             colspec += " SERIAL"
         else:
             colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
index af39950390b0fb6f1b2f88b7c6cc96b79f34e6e2..2ad2c2b8c9d165d1b2843000c4167781b943652b 100644 (file)
@@ -216,7 +216,7 @@ class PropertyLoader(StrategizedProperty):
         elif len([c for c in self.foreignkey if self.parent.unjoined_table.corresponding_column(c, False) is not None]):
             return sync.MANYTOONE
         else:
-            raise exceptions.ArgumentError("Cant determine relation direction '%s', for '%s' in mapper '%s' with primary join\n '%s'" %(repr(self.foreignkey), self.key, str(self.mapper), str(self.primaryjoin)))
+            raise exceptions.ArgumentError("Cant determine relation direction for '%s' in mapper '%s' with primary join\n '%s'" %(self.key, str(self.mapper), str(self.primaryjoin)))
             
     def _find_dependent(self):
         """searches through the primary join condition to determine which side
@@ -226,12 +226,16 @@ class PropertyLoader(StrategizedProperty):
         def foo(binary):
             if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
                 return
-            if binary.left.foreign_key is not None and binary.left.foreign_key.references(binary.right.table):
-                foreignkeys.add(binary.left)
-            elif binary.right.foreign_key is not None and binary.right.foreign_key.references(binary.left.table):
-                foreignkeys.add(binary.right)
+            for f in binary.left.foreign_keys:
+                if f.references(binary.right.table):
+                    foreignkeys.add(binary.left)
+            for f in binary.right.foreign_keys:
+                if f.references(binary.left.table):
+                    foreignkeys.add(binary.right)
         visitor = mapperutil.BinaryVisitor(foo)
         self.primaryjoin.accept_visitor(visitor)
+        if len(foreignkeys) == 0:
+            raise exceptions.ArgumentError("On relation '%s', can't figure out which side is the foreign key for join condition '%s'.  Specify the 'foreignkey' argument to the relation." % (self.key, str(self.primaryjoin)))
         self.foreignkey = foreignkeys
         
     def get_join(self):
index 1d42095614249e4c5413caef9c4231f7b903a4d4..18d1d7b144cadfcb1f57f7ef5726baf5454beb23 100644 (file)
@@ -433,12 +433,12 @@ class Column(SchemaItem, sql.ColumnClause):
         self.__originating_column = self
         if self.index is not None and self.unique is not None:
             raise exceptions.ArgumentError("Column may not define both index and unique")
-        self._foreign_key = None
+        self._foreign_keys = util.Set()
         if len(kwargs):
             raise exceptions.ArgumentError("Unknown arguments passed to Column: " + repr(kwargs.keys()))
 
     primary_key = util.SimpleProperty('_primary_key')
-    foreign_key = util.SimpleProperty('_foreign_key')
+    foreign_keys = util.SimpleProperty('_foreign_keys')
     columns = property(lambda self:[self])
 
     def __str__(self):
@@ -459,7 +459,7 @@ class Column(SchemaItem, sql.ColumnClause):
     def __repr__(self):
        return "Column(%s)" % string.join(
         [repr(self.name)] + [repr(self.type)] +
-        [repr(x) for x in [self.foreign_key] if x is not None] +
+        [repr(x) for x in self.foreign_keys if x is not None] +
         ["%s=%s" % (k, repr(getattr(self, k))) for k in ['key', 'primary_key', 'nullable', 'hidden', 'default', 'onupdate']]
        , ',')
         
@@ -501,11 +501,8 @@ class Column(SchemaItem, sql.ColumnClause):
         
         This is a copy of this Column referenced 
         by a different parent (such as an alias or select statement)"""
-        if self.foreign_key is None:
-            fk = None
-        else:
-            fk = self.foreign_key.copy()
-        c = Column(name or self.name, self.type, fk, self.default, key = name or self.key, primary_key = self.primary_key, nullable = self.nullable, hidden = self.hidden, quote=self.quote)
+        fk = [ForeignKey(f._colspec) for f in self.foreign_keys]
+        c = Column(name or self.name, self.type, self.default, key = name or self.key, primary_key = self.primary_key, nullable = self.nullable, hidden = self.hidden, quote=self.quote, *fk)
         c.table = selectable
         c.orig_set = self.orig_set
         c.__originating_column = self.__originating_column
@@ -513,8 +510,7 @@ class Column(SchemaItem, sql.ColumnClause):
             selectable.columns[c.key] = c
             if self.primary_key:
                 selectable.primary_key.append(c)
-        if fk is not None:
-            c._init_items(fk)
+        [c._init_items(f) for f in fk]
         return c
 
     def _case_sens(self):
@@ -530,8 +526,8 @@ class Column(SchemaItem, sql.ColumnClause):
             self.default.accept_schema_visitor(visitor)
         if self.onupdate is not None:
             self.onupdate.accept_schema_visitor(visitor)
-        if self.foreign_key is not None:
-            self.foreign_key.accept_schema_visitor(visitor)
+        for f in self.foreign_keys:
+            f.accept_schema_visitor(visitor)
         visitor.visit_column(self)
 
 
@@ -631,11 +627,11 @@ class ForeignKey(SchemaItem):
 
         # if a foreign key was already set up for the parent column, replace it with 
         # this one
-        if self.parent.foreign_key is not None:
-            self.parent.table.foreign_keys.remove(self.parent.foreign_key)
-        self.parent.foreign_key = self
-        self.parent.table.foreign_keys.append(self)
-
+        #if self.parent.foreign_key is not None:
+        #    self.parent.table.foreign_keys.remove(self.parent.foreign_key)
+        #self.parent.foreign_key = self
+        self.parent.foreign_keys.add(self)
+        self.parent.table.foreign_keys.add(self)
 class DefaultGenerator(SchemaItem):
     """Base class for column "default" values."""
     def __init__(self, for_update=False, metadata=None):
index a07536bc9c4a9d4af27a64e1d0a1fe8315b06349..c113edaa32bded4ded14b66c91e3348065ad2f34 100644 (file)
@@ -618,8 +618,14 @@ class ColumnElement(Selectable, CompareMixin):
     may correspond to several TableClause-attached columns)."""
     
     primary_key = property(lambda self:getattr(self, '_primary_key', False), doc="primary key flag.  indicates if this Column represents part or whole of a primary key.")
-    foreign_key = property(lambda self:getattr(self, '_foreign_key', False), doc="foreign key accessor.  points to a ForeignKey object which represents a Foreign Key placed on this column's ultimate ancestor.")
+    foreign_keys = property(lambda self:getattr(self, '_foreign_keys', []), doc="foreign key accessor.  points to a ForeignKey object which represents a Foreign Key placed on this column's ultimate ancestor.")
     columns = property(lambda self:[self], doc="Columns accessor which just returns self, to provide compatibility with Selectable objects.")
+    def _one_fkey(self):
+        if len(self._foreign_keys):
+            return list(self._foreign_keys)[0]
+        else:
+            return None
+    foreign_key = property(_one_fkey)
 
     def _get_orig_set(self):
         try:
@@ -731,7 +737,7 @@ class FromClause(Selectable):
             return
         self._columns = util.OrderedProperties()
         self._primary_key = []
-        self._foreign_keys = []
+        self._foreign_keys = util.Set()
         self._orig_cols = {}
         export = self._exportable_columns()
         for column in export:
@@ -1077,8 +1083,8 @@ class Join(FromClause):
         self._columns[column._label] = column
         if column.primary_key:
             self._primary_key.append(column)
-        if column.foreign_key:
-            self._foreign_keys.append(column.foreign_key)
+        for f in column.foreign_keys:
+            self._foreign_keys.add(f)
         return column
     def _match_primaries(self, primary, secondary):
         crit = []
@@ -1252,7 +1258,7 @@ class TableClause(FromClause):
         super(TableClause, self).__init__(name)
         self.name = self.fullname = name
         self._columns = util.OrderedProperties()
-        self._foreign_keys = []
+        self._foreign_keys = util.Set()
         self._primary_key = []
         for c in columns:
             self.append_column(c)
index 2e2c501278778c826c3e58c3ca76e0a8316a7e4a..45010111e4ad0ac2f207d117ef69071e902e817d 100644 (file)
@@ -71,7 +71,6 @@ class ReflectionTest(PersistTest):
             addresses = Table('engine_email_addresses', meta, autoload = True)
             # reference the addresses foreign key col, which will require users to be 
             # reflected at some point
-            print addresses.c.remote_user_id.foreign_key.column
             users = Table('engine_users', meta, autoload = True)
         finally:
             addresses.drop()
@@ -120,8 +119,8 @@ class ReflectionTest(PersistTest):
                 autoload=True)
             u2 = Table('users', meta2, autoload=True)
             
-            assert a2.c.user_id.foreign_key is not None
-            assert a2.c.user_id.foreign_key.parent is a2.c.user_id
+            assert len(a2.c.user_id.foreign_keys)>0
+            assert list(a2.c.user_id.foreign_keys)[0].parent is a2.c.user_id
             assert u2.join(a2).onclause == u2.c.id==a2.c.user_id
 
             meta3 = BoundMetaData(testbase.db)
@@ -336,8 +335,8 @@ class ReflectionTest(PersistTest):
         assert table_c.c.description.nullable 
         assert table.primary_key is not table_c.primary_key
         assert [x.name for x in table.primary_key] == [x.name for x in table_c.primary_key]
-        assert table2_c.c.myid.foreign_key.column is table_c.c.myid
-        assert table2_c.c.myid.foreign_key.column is not table.c.myid
+        assert list(table2_c.c.myid.foreign_keys)[0].column is table_c.c.myid
+        assert list(table2_c.c.myid.foreign_keys)[0].column is not table.c.myid
         
     # mysql throws its own exception for no such table, resulting in 
     # a sqlalchemy.SQLError instead of sqlalchemy.NoSuchTableError.
index c6c2b5e8468a54b1909c8ab787a9d9d53c062441..dcf87c4f6866c42df4970ad5ed9e66c348826ad9 100644 (file)
@@ -100,6 +100,79 @@ class RelationTest(testbase.PersistTest):
         session.delete(c) # fails
         session.flush()
         
+class RelationTest2(testbase.PersistTest):
+    """this test tests a relationship on a column that is included in multiple foreign keys,
+    as well as a self-referential relationship on a composite key where one column in the foreign key
+    is 'joined to itself'."""
+    def setUpAll(self):
+        global metadata, company_tbl, employee_tbl
+        metadata = BoundMetaData(testbase.db)
+        
+        company_tbl = Table('company', metadata,
+             Column('company_id', Integer, primary_key=True),
+             Column('name', Unicode(30)))
+
+        employee_tbl = Table('employee', metadata,
+             Column('company_id', Integer, primary_key=True),
+             Column('emp_id', Integer, primary_key=True),
+             Column('name', Unicode(30)),
+            Column('reports_to_id', Integer),
+             ForeignKeyConstraint(['company_id'], ['company.company_id']),
+             ForeignKeyConstraint(['company_id', 'reports_to_id'],
+         ['employee.company_id', 'employee.emp_id']))
+        metadata.create_all()
+         
+    def tearDownAll(self):
+        metadata.drop_all()    
+
+    def testbasic(self):
+        class Company(object):
+            pass
+        class Employee(object):
+            def __init__(self, name, company, emp_id, reports_to=None):
+                self.name = name
+                self.company = company
+                self.emp_id = emp_id
+                self.reports_to = reports_to
+                
+        mapper(Company, company_tbl)
+        mapper(Employee, employee_tbl, properties= {
+            'company':relation(Company, primaryjoin=employee_tbl.c.company_id==company_tbl.c.company_id, backref='employees'),
+            'reports_to':relation(Employee, primaryjoin=
+                and_(
+                    employee_tbl.c.emp_id==employee_tbl.c.reports_to_id,
+                    employee_tbl.c.company_id==employee_tbl.c.company_id
+                ), 
+                foreignkey=[employee_tbl.c.company_id, employee_tbl.c.emp_id],
+                backref='employees')
+        })
+
+        sess = create_session()
+        c1 = Company()
+        c2 = Company()
+
+        e1 = Employee('emp1', c1, 1)
+        e2 = Employee('emp2', c1, 2, e1)
+        e3 = Employee('emp3', c1, 3, e1)
+        e4 = Employee('emp4', c1, 4, e3)
+        e5 = Employee('emp5', c2, 1)
+        e6 = Employee('emp6', c2, 2, e5)
+        e7 = Employee('emp7', c2, 3, e5)
+
+        [sess.save(x) for x in [c1,c2]]
+        sess.flush()
+        sess.clear()
+
+        test_c1 = sess.query(Company).get(c1.company_id)
+        test_e1 = sess.query(Employee).get([c1.company_id, e1.emp_id])
+        assert test_e1.name == 'emp1'
+        test_e5 = sess.query(Employee).get([c2.company_id, e5.emp_id])
+        assert test_e5.name == 'emp5'
+        assert [x.name for x in test_e1.employees] == ['emp2', 'emp3']
+        assert sess.query(Employee).get([c1.company_id, 3]).reports_to.name == 'emp1'
+        assert sess.query(Employee).get([c2.company_id, 3]).reports_to.name == 'emp5'
+        
+        
         
 if __name__ == "__main__":
     testbase.main()