From: Mike Bayer Date: Sun, 8 Oct 2006 02:46:40 +0000 (+0000) Subject: - the "foreign_key" attribute on Column and ColumnElement in general X-Git-Tag: rel_0_3_0~76 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=51f16d14980c4a061ea3e224c52acf91008f0b20;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - the "foreign_key" attribute on Column and ColumnElement in general is deprecated, in favor of the "foreign_keys" list/set-based attribute, which takes into account multiple foreign keys on one column. "foreign_key" will return the first element in the "foreign_keys" list/set or None if the list is empty. - added a user test to the relationships test, testing various new things this change allows --- diff --git a/CHANGES b/CHANGES index 3abebef666..06268a1f4b 100644 --- a/CHANGES +++ b/CHANGES @@ -38,6 +38,11 @@ - fixed condition that occurred during reflection when a primary key column was explciitly overridden, where the PrimaryKeyConstraint would get both the reflected and the programmatic column doubled up + - the "foreign_key" attribute on Column and ColumnElement in general + is deprecated, in favor of the "foreign_keys" list/set-based attribute, + which takes into account multiple foreign keys on one column. + "foreign_key" will return the first element in the "foreign_keys" list/set + or None if the list is empty. - Connections/Pooling/Execution: - connection pool tracks open cursors and automatically closes them if connection is returned to pool with cursors still opened. Can be diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 2c29bbe2ad..88efcb7550 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -444,7 +444,7 @@ class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): if not column.nullable: colspec += " NOT NULL" if column.primary_key: - if not column.foreign_key and first_pk and column.autoincrement and isinstance(column.type, sqltypes.Integer): + if len(column.foreign_keys)==0 and first_pk and column.autoincrement and isinstance(column.type, sqltypes.Integer): colspec += " AUTO_INCREMENT" return colspec diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 6fe51ad9af..e052fe8c00 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -490,7 +490,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) - if column.primary_key and not column.foreign_key and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): + if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): colspec += " SERIAL" else: colspec += " " + column.type.engine_impl(self.engine).get_col_spec() diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index af39950390..2ad2c2b8c9 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -216,7 +216,7 @@ class PropertyLoader(StrategizedProperty): elif len([c for c in self.foreignkey if self.parent.unjoined_table.corresponding_column(c, False) is not None]): return sync.MANYTOONE else: - raise exceptions.ArgumentError("Cant determine relation direction '%s', for '%s' in mapper '%s' with primary join\n '%s'" %(repr(self.foreignkey), self.key, str(self.mapper), str(self.primaryjoin))) + raise exceptions.ArgumentError("Cant determine relation direction for '%s' in mapper '%s' with primary join\n '%s'" %(self.key, str(self.mapper), str(self.primaryjoin))) def _find_dependent(self): """searches through the primary join condition to determine which side @@ -226,12 +226,16 @@ class PropertyLoader(StrategizedProperty): def foo(binary): if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): return - if binary.left.foreign_key is not None and binary.left.foreign_key.references(binary.right.table): - foreignkeys.add(binary.left) - elif binary.right.foreign_key is not None and binary.right.foreign_key.references(binary.left.table): - foreignkeys.add(binary.right) + for f in binary.left.foreign_keys: + if f.references(binary.right.table): + foreignkeys.add(binary.left) + for f in binary.right.foreign_keys: + if f.references(binary.left.table): + foreignkeys.add(binary.right) visitor = mapperutil.BinaryVisitor(foo) self.primaryjoin.accept_visitor(visitor) + if len(foreignkeys) == 0: + raise exceptions.ArgumentError("On relation '%s', can't figure out which side is the foreign key for join condition '%s'. Specify the 'foreignkey' argument to the relation." % (self.key, str(self.primaryjoin))) self.foreignkey = foreignkeys def get_join(self): diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 1d42095614..18d1d7b144 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -433,12 +433,12 @@ class Column(SchemaItem, sql.ColumnClause): self.__originating_column = self if self.index is not None and self.unique is not None: raise exceptions.ArgumentError("Column may not define both index and unique") - self._foreign_key = None + self._foreign_keys = util.Set() if len(kwargs): raise exceptions.ArgumentError("Unknown arguments passed to Column: " + repr(kwargs.keys())) primary_key = util.SimpleProperty('_primary_key') - foreign_key = util.SimpleProperty('_foreign_key') + foreign_keys = util.SimpleProperty('_foreign_keys') columns = property(lambda self:[self]) def __str__(self): @@ -459,7 +459,7 @@ class Column(SchemaItem, sql.ColumnClause): def __repr__(self): return "Column(%s)" % string.join( [repr(self.name)] + [repr(self.type)] + - [repr(x) for x in [self.foreign_key] if x is not None] + + [repr(x) for x in self.foreign_keys if x is not None] + ["%s=%s" % (k, repr(getattr(self, k))) for k in ['key', 'primary_key', 'nullable', 'hidden', 'default', 'onupdate']] , ',') @@ -501,11 +501,8 @@ class Column(SchemaItem, sql.ColumnClause): This is a copy of this Column referenced by a different parent (such as an alias or select statement)""" - if self.foreign_key is None: - fk = None - else: - fk = self.foreign_key.copy() - c = Column(name or self.name, self.type, fk, self.default, key = name or self.key, primary_key = self.primary_key, nullable = self.nullable, hidden = self.hidden, quote=self.quote) + fk = [ForeignKey(f._colspec) for f in self.foreign_keys] + c = Column(name or self.name, self.type, self.default, key = name or self.key, primary_key = self.primary_key, nullable = self.nullable, hidden = self.hidden, quote=self.quote, *fk) c.table = selectable c.orig_set = self.orig_set c.__originating_column = self.__originating_column @@ -513,8 +510,7 @@ class Column(SchemaItem, sql.ColumnClause): selectable.columns[c.key] = c if self.primary_key: selectable.primary_key.append(c) - if fk is not None: - c._init_items(fk) + [c._init_items(f) for f in fk] return c def _case_sens(self): @@ -530,8 +526,8 @@ class Column(SchemaItem, sql.ColumnClause): self.default.accept_schema_visitor(visitor) if self.onupdate is not None: self.onupdate.accept_schema_visitor(visitor) - if self.foreign_key is not None: - self.foreign_key.accept_schema_visitor(visitor) + for f in self.foreign_keys: + f.accept_schema_visitor(visitor) visitor.visit_column(self) @@ -631,11 +627,11 @@ class ForeignKey(SchemaItem): # if a foreign key was already set up for the parent column, replace it with # this one - if self.parent.foreign_key is not None: - self.parent.table.foreign_keys.remove(self.parent.foreign_key) - self.parent.foreign_key = self - self.parent.table.foreign_keys.append(self) - + #if self.parent.foreign_key is not None: + # self.parent.table.foreign_keys.remove(self.parent.foreign_key) + #self.parent.foreign_key = self + self.parent.foreign_keys.add(self) + self.parent.table.foreign_keys.add(self) class DefaultGenerator(SchemaItem): """Base class for column "default" values.""" def __init__(self, for_update=False, metadata=None): diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index a07536bc9c..c113edaa32 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -618,8 +618,14 @@ class ColumnElement(Selectable, CompareMixin): may correspond to several TableClause-attached columns).""" primary_key = property(lambda self:getattr(self, '_primary_key', False), doc="primary key flag. indicates if this Column represents part or whole of a primary key.") - foreign_key = property(lambda self:getattr(self, '_foreign_key', False), doc="foreign key accessor. points to a ForeignKey object which represents a Foreign Key placed on this column's ultimate ancestor.") + foreign_keys = property(lambda self:getattr(self, '_foreign_keys', []), doc="foreign key accessor. points to a ForeignKey object which represents a Foreign Key placed on this column's ultimate ancestor.") columns = property(lambda self:[self], doc="Columns accessor which just returns self, to provide compatibility with Selectable objects.") + def _one_fkey(self): + if len(self._foreign_keys): + return list(self._foreign_keys)[0] + else: + return None + foreign_key = property(_one_fkey) def _get_orig_set(self): try: @@ -731,7 +737,7 @@ class FromClause(Selectable): return self._columns = util.OrderedProperties() self._primary_key = [] - self._foreign_keys = [] + self._foreign_keys = util.Set() self._orig_cols = {} export = self._exportable_columns() for column in export: @@ -1077,8 +1083,8 @@ class Join(FromClause): self._columns[column._label] = column if column.primary_key: self._primary_key.append(column) - if column.foreign_key: - self._foreign_keys.append(column.foreign_key) + for f in column.foreign_keys: + self._foreign_keys.add(f) return column def _match_primaries(self, primary, secondary): crit = [] @@ -1252,7 +1258,7 @@ class TableClause(FromClause): super(TableClause, self).__init__(name) self.name = self.fullname = name self._columns = util.OrderedProperties() - self._foreign_keys = [] + self._foreign_keys = util.Set() self._primary_key = [] for c in columns: self.append_column(c) diff --git a/test/engine/reflection.py b/test/engine/reflection.py index 2e2c501278..45010111e4 100644 --- a/test/engine/reflection.py +++ b/test/engine/reflection.py @@ -71,7 +71,6 @@ class ReflectionTest(PersistTest): addresses = Table('engine_email_addresses', meta, autoload = True) # reference the addresses foreign key col, which will require users to be # reflected at some point - print addresses.c.remote_user_id.foreign_key.column users = Table('engine_users', meta, autoload = True) finally: addresses.drop() @@ -120,8 +119,8 @@ class ReflectionTest(PersistTest): autoload=True) u2 = Table('users', meta2, autoload=True) - assert a2.c.user_id.foreign_key is not None - assert a2.c.user_id.foreign_key.parent is a2.c.user_id + assert len(a2.c.user_id.foreign_keys)>0 + assert list(a2.c.user_id.foreign_keys)[0].parent is a2.c.user_id assert u2.join(a2).onclause == u2.c.id==a2.c.user_id meta3 = BoundMetaData(testbase.db) @@ -336,8 +335,8 @@ class ReflectionTest(PersistTest): assert table_c.c.description.nullable assert table.primary_key is not table_c.primary_key assert [x.name for x in table.primary_key] == [x.name for x in table_c.primary_key] - assert table2_c.c.myid.foreign_key.column is table_c.c.myid - assert table2_c.c.myid.foreign_key.column is not table.c.myid + assert list(table2_c.c.myid.foreign_keys)[0].column is table_c.c.myid + assert list(table2_c.c.myid.foreign_keys)[0].column is not table.c.myid # mysql throws its own exception for no such table, resulting in # a sqlalchemy.SQLError instead of sqlalchemy.NoSuchTableError. diff --git a/test/orm/relationships.py b/test/orm/relationships.py index c6c2b5e846..dcf87c4f68 100644 --- a/test/orm/relationships.py +++ b/test/orm/relationships.py @@ -100,6 +100,79 @@ class RelationTest(testbase.PersistTest): session.delete(c) # fails session.flush() +class RelationTest2(testbase.PersistTest): + """this test tests a relationship on a column that is included in multiple foreign keys, + as well as a self-referential relationship on a composite key where one column in the foreign key + is 'joined to itself'.""" + def setUpAll(self): + global metadata, company_tbl, employee_tbl + metadata = BoundMetaData(testbase.db) + + company_tbl = Table('company', metadata, + Column('company_id', Integer, primary_key=True), + Column('name', Unicode(30))) + + employee_tbl = Table('employee', metadata, + Column('company_id', Integer, primary_key=True), + Column('emp_id', Integer, primary_key=True), + Column('name', Unicode(30)), + Column('reports_to_id', Integer), + ForeignKeyConstraint(['company_id'], ['company.company_id']), + ForeignKeyConstraint(['company_id', 'reports_to_id'], + ['employee.company_id', 'employee.emp_id'])) + metadata.create_all() + + def tearDownAll(self): + metadata.drop_all() + + def testbasic(self): + class Company(object): + pass + class Employee(object): + def __init__(self, name, company, emp_id, reports_to=None): + self.name = name + self.company = company + self.emp_id = emp_id + self.reports_to = reports_to + + mapper(Company, company_tbl) + mapper(Employee, employee_tbl, properties= { + 'company':relation(Company, primaryjoin=employee_tbl.c.company_id==company_tbl.c.company_id, backref='employees'), + 'reports_to':relation(Employee, primaryjoin= + and_( + employee_tbl.c.emp_id==employee_tbl.c.reports_to_id, + employee_tbl.c.company_id==employee_tbl.c.company_id + ), + foreignkey=[employee_tbl.c.company_id, employee_tbl.c.emp_id], + backref='employees') + }) + + sess = create_session() + c1 = Company() + c2 = Company() + + e1 = Employee('emp1', c1, 1) + e2 = Employee('emp2', c1, 2, e1) + e3 = Employee('emp3', c1, 3, e1) + e4 = Employee('emp4', c1, 4, e3) + e5 = Employee('emp5', c2, 1) + e6 = Employee('emp6', c2, 2, e5) + e7 = Employee('emp7', c2, 3, e5) + + [sess.save(x) for x in [c1,c2]] + sess.flush() + sess.clear() + + test_c1 = sess.query(Company).get(c1.company_id) + test_e1 = sess.query(Employee).get([c1.company_id, e1.emp_id]) + assert test_e1.name == 'emp1' + test_e5 = sess.query(Employee).get([c2.company_id, e5.emp_id]) + assert test_e5.name == 'emp5' + assert [x.name for x in test_e1.employees] == ['emp2', 'emp3'] + assert sess.query(Employee).get([c1.company_id, 3]).reports_to.name == 'emp1' + assert sess.query(Employee).get([c2.company_id, 3]).reports_to.name == 'emp5' + + if __name__ == "__main__": testbase.main()