From: Mike Bayer Date: Sat, 14 Jul 2007 21:57:51 +0000 (+0000) Subject: - improved ability to get the "correct" and most minimal set of primary key X-Git-Tag: rel_0_3_9~16 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=8889d2c1bc7c527271909c0896e5d053c6aa369e;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - improved ability to get the "correct" and most minimal set of primary key columns from a join, equating foreign keys and otherwise equated columns. this is also mostly to help inheritance scenarios formulate the best choice of primary key columns. [ticket:185] - added 'bind' argument to Sequence.create()/drop(), ColumnDefault.execute() --- diff --git a/CHANGES b/CHANGES index 0a1990ea1b..699fb9bcaf 100644 --- a/CHANGES +++ b/CHANGES @@ -13,6 +13,7 @@ InstrumentedList-like (e.g. over keys instead of values) - association proxies no longer bind tightly to source collections [ticket:597], and are constructed with a thunk instead + - added selectone_by() to assignmapper - orm - forwards-compatibility with 0.4: added one(), first(), and all() to Query. almost all Query functionality from 0.4 is @@ -56,6 +57,11 @@ - composite primary key is represented as a non-keyed set to allow for composite keys consisting of cols with the same name; occurs within a Join. helps inheritance scenarios formulate correct PK. + - improved ability to get the "correct" and most minimal set of primary key + columns from a join, equating foreign keys and otherwise equated columns. + this is also mostly to help inheritance scenarios formulate the best + choice of primary key columns. [ticket:185] + - added 'bind' argument to Sequence.create()/drop(), ColumnDefault.execute() - some enhancements to "column targeting", the ability to match a column to a "corresponding" column in another selectable. this affects mostly ORM ability to map to complex joins @@ -117,9 +123,6 @@ - fix port option handling for pyodbc [ticket:634] - now able to reflect start and increment values for identity columns - preliminary support for using scope_identity() with pyodbc - -- extensions - - added selectone_by() to assignmapper 0.3.8 - engines diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index c62ed33734..d9df0da90c 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -59,13 +59,15 @@ class SchemaItem(object): """Return the engine or None if no engine.""" if raiseerr: - e = self._derived_metadata().bind + m = self._derived_metadata() + e = m and m.bind or None if e is None: raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine or Connection.") else: return e else: - return self._derived_metadata().bind + m = self._derived_metadata() + return m and m.bind or None def get_engine(self): """Return the engine or raise an error if no engine. @@ -280,7 +282,9 @@ class Table(SchemaItem, sql.TableClause): self.schema = kwargs.pop('schema', None) self.indexes = util.Set() self.constraints = util.Set() + self._columns = sql.ColumnCollection() self.primary_key = PrimaryKeyConstraint() + self._foreign_keys = util.OrderedSet() self.quote = kwargs.pop('quote', False) self.quote_schema = kwargs.pop('quote_schema', False) if self.schema is not None: @@ -298,6 +302,11 @@ class Table(SchemaItem, sql.TableClause): # store extra kwargs, which should only contain db-specific options self.kwargs = kwargs + def _export_columns(self, columns=None): + # override FromClause's collection initialization logic; TableClause and Table + # implement it differently + pass + def _get_case_sensitive_schema(self): try: return getattr(self, '_case_sensitive_schema') @@ -545,6 +554,14 @@ class Column(SchemaItem, sql._ColumnClause): def _get_engine(self): return self.table.bind + def references(self, column): + """return true if this column references the given column via foreign key""" + for fk in self.foreign_keys: + if fk.column is column: + return True + else: + return False + def append_foreign_key(self, fk): fk._set_parent(self) @@ -763,7 +780,7 @@ class DefaultGenerator(SchemaItem): def __init__(self, for_update=False, metadata=None): self.for_update = for_update - self._metadata = metadata + self._metadata = util.assert_arg_type(metadata, (MetaData, type(None)), 'metadata') def _derived_metadata(self): try: @@ -782,8 +799,10 @@ class DefaultGenerator(SchemaItem): else: self.column.default = self - def execute(self, **kwargs): - return self._get_engine(raiseerr=True).execute_default(self, **kwargs) + def execute(self, bind=None, **kwargs): + if bind is None: + bind = self._get_engine(raiseerr=True) + return bind.execute_default(self, **kwargs) def __repr__(self): return "DefaultGenerator()" @@ -845,12 +864,15 @@ class Sequence(DefaultGenerator): super(Sequence, self)._set_parent(column) column.sequence = self - def create(self): - self._get_engine(raiseerr=True).create(self) - return self + def create(self, bind=None): + if bind is None: + bind = self._get_engine(raiseerr=True) + bind.create(self) - def drop(self): - self._get_engine(raiseerr=True).drop(self) + def drop(self, bind=None): + if bind is None: + bind = self._get_engine(raiseerr=True) + bind.drop(self) def accept_visitor(self, visitor): """Call the visit_seauence method on the given visitor.""" diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index d1fc3fef17..0961cd4ee3 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -1764,7 +1764,7 @@ class FromClause(Selectable): """) oid_column = property(_get_oid_column) - def _export_columns(self): + def _export_columns(self, columns=None): """Initialize column collections. The collections include the primary key, foreign keys, list of @@ -1777,14 +1777,16 @@ class FromClause(Selectable): its parent ``Selectable`` is this ``FromClause``. """ - if hasattr(self, '_columns'): + if hasattr(self, '_columns') and columns is None: # TODO: put a mutex here ? this is a key place for threading probs return self._columns = ColumnCollection() self._primary_key = ColumnSet() self._foreign_keys = util.Set() self._orig_cols = {} - for co in self._adjusted_exportable_columns(): + if columns is None: + columns = self._adjusted_exportable_columns() + for co in columns: cp = self._proxy_column(co) for ci in cp.orig_set: cx = self._orig_cols.get(ci) @@ -2250,15 +2252,38 @@ class Join(FromClause): encodedname = property(lambda s: s.name.encode('ascii', 'backslashreplace')) def _init_primary_key(self): - pkcol = util.OrderedSet() - for col in self._adjusted_exportable_columns(): - if col.primary_key: - pkcol.add(col) - for col in list(pkcol): - for f in col.foreign_keys: - if f.column in pkcol: - pkcol.remove(col) - self.primary_key.extend(pkcol) + pkcol = util.Set([c for c in self._adjusted_exportable_columns() if c.primary_key]) + + equivs = {} + def add_equiv(a, b): + for x, y in ((a, b), (b, a)): + if x in equivs: + equivs[x].add(y) + else: + equivs[x] = util.Set([y]) + + class BinaryVisitor(ClauseVisitor): + def visit_binary(self, binary): + if binary.operator == '=': + add_equiv(binary.left, binary.right) + BinaryVisitor().traverse(self.onclause) + + for col in pkcol: + for fk in col.foreign_keys: + if fk.column in pkcol: + add_equiv(col, fk.column) + + omit = util.Set() + for col in pkcol: + p = col + for c in equivs.get(col, util.Set()): + if p.references(c) or (c.primary_key and not p.primary_key): + omit.add(p) + p = c + + self.__primary_key = ColumnSet([c for c in self._adjusted_exportable_columns() if c.primary_key and c not in omit]) + + primary_key = property(lambda s:s.__primary_key) def _locate_oid_column(self): return self.left.oid_column @@ -2333,7 +2358,11 @@ class Join(FromClause): collist.append(c) self.__folded_equivalents = collist return self.__folded_equivalents - + + folded_equivalents = property(_get_folded_equivalents, doc="Returns the column list of this Join with all equivalently-named, " + "equated columns folded into one column, where 'equated' means they are " + "equated to each other in the ON clause of this join.") + def select(self, whereclause = None, fold_equivalents=False, **kwargs): """Create a ``Select`` from this ``Join``. @@ -2353,7 +2382,7 @@ class Join(FromClause): """ if fold_equivalents: - collist = self._get_folded_equivalents() + collist = self.folded_equivalents else: collist = [self.left, self.right] @@ -2632,12 +2661,8 @@ class TableClause(FromClause): super(TableClause, self).__init__(name) self.name = self.fullname = name self.encodedname = self.name.encode('ascii', 'backslashreplace') - self._columns = ColumnCollection() - self._foreign_keys = util.OrderedSet() - self._primary_key = ColumnCollection() - for c in columns: - self.append_column(c) self._oid_column = _ColumnClause('oid', self, _is_oid=True) + self._export_columns(columns) def named_with_column(self): return True @@ -2649,6 +2674,10 @@ class TableClause(FromClause): def _locate_oid_column(self): return self._oid_column + def _proxy_column(self, c): + self.append_column(c) + return c + def _orig_columns(self): try: return self._orig_cols @@ -2666,6 +2695,7 @@ class TableClause(FromClause): return [c for c in self.c] else: return [] + def accept_visitor(self, visitor): visitor.visit_table(self) diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index c827f1e7d6..b47822d613 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -10,6 +10,7 @@ except ImportError: import dummy_thread as thread import dummy_threading as threading +from sqlalchemy import exceptions import md5 import sys import warnings @@ -128,7 +129,16 @@ def duck_type_collection(col, default=None): return dict else: return default - + +def assert_arg_type(arg, argtype, name): + if isinstance(arg, argtype): + return arg + else: + if isinstance(argtype, tuple): + raise exceptions.ArgumentError("Argument '%s' is expected to be one of type %s, got '%s'" % (name, ' or '.join(["'%s'" % str(a) for a in argtype]), str(type(arg)))) + else: + raise exceptions.ArgumentError("Argument '%s' is expected to be of type '%s', got '%s'" % (name, str(argtype), str(type(arg)))) + def warn_exception(func): """executes the given function, catches all exceptions and converts to a warning.""" try: diff --git a/test/orm/inheritance.py b/test/orm/inheritance.py index 0458716e5a..d608b2387a 100644 --- a/test/orm/inheritance.py +++ b/test/orm/inheritance.py @@ -481,6 +481,80 @@ class InheritTest7(testbase.ORMTest): a.password = 'sadmin' sess.flush() assert user_roles.count().scalar() == 1 + +class InheritTest8(testbase.ORMTest): + """test the construction of mapper.primary_key when an inheriting relationship + joins on a column other than primary key column.""" + keep_data = True + + def define_tables(self, metadata): + global person_table, employee_table, Person, Employee + + person_table = Table("persons", metadata, + Column("id", Integer, primary_key=True), + Column("name", String(80)), + ) + + employee_table = Table("employees", metadata, + Column("id", Integer, primary_key=True), + Column("salary", Integer), + Column("person_id", Integer, ForeignKey("persons.id")), + ) + + class Person(object): + def __init__(self, name): + self.name = name + + class Employee(Person): pass + + def insert_data(self): + person_insert = person_table.insert() + person_insert.execute(id=1, name='alice') + person_insert.execute(id=2, name='bob') + + employee_insert = employee_table.insert() + employee_insert.execute(id=2, salary=250, person_id=1) # alice + employee_insert.execute(id=3, salary=200, person_id=2) # bob + + def test_implicit(self): + person_mapper = mapper(Person, person_table) + mapper(Employee, employee_table, inherits=person_mapper) + print class_mapper(Employee).primary_key + assert list(class_mapper(Employee).primary_key) == [person_table.c.id, employee_table.c.id] + self._do_test(True) + + def test_explicit_props(self): + person_mapper = mapper(Person, person_table) + mapper(Employee, employee_table, inherits=person_mapper, properties={'pid':person_table.c.id, 'eid':employee_table.c.id}) + self._do_test(True) + + def test_explicit_composite_pk(self): + person_mapper = mapper(Person, person_table) + mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id, employee_table.c.id]) + self._do_test(True) + + def test_explicit_pk(self): + person_mapper = mapper(Person, person_table) + mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id]) + self._do_test(False) + + def _do_test(self, composite): + session = create_session() + query = session.query(Employee) + + if composite: + alice1 = query.get([1,2]) + bob = query.get([2,3]) + alice2 = query.get([1,2]) + else: + alice1 = query.get(1) + bob = query.get(2) + alice2 = query.get(1) + + assert alice1.name == alice2.name == 'alice' + assert bob.name == 'bob' + + if __name__ == "__main__": testbase.main() diff --git a/test/sql/defaults.py b/test/sql/defaults.py index 09c58d2c2f..10a3610f99 100644 --- a/test/sql/defaults.py +++ b/test/sql/defaults.py @@ -231,6 +231,16 @@ class SequenceTest(PersistTest): self.assert_(x == 1) finally: s.drop() + + @testbase.supported('postgres', 'oracle') + def teststandalone_explicit(self): + s = Sequence("my_sequence") + s.create(bind=testbase.db) + try: + x = s.execute(testbase.db) + self.assert_(x == 1) + finally: + s.drop(testbase.db) @testbase.supported('postgres', 'oracle') def teststandalone2(self): diff --git a/test/sql/selectable.py b/test/sql/selectable.py index 853821f9af..ecd8253b8f 100755 --- a/test/sql/selectable.py +++ b/test/sql/selectable.py @@ -170,6 +170,61 @@ class SelectableTest(testbase.AssertMixin): print str(criterion) print str(j.onclause) self.assert_(criterion.compare(j.onclause)) + +class PrimaryKeyTest(testbase.AssertMixin): + def test_join_pk_collapse_implicit(self): + """test that redundant columns in a join get 'collapsed' into a minimal primary key, + which is the root column along a chain of foreign key relationships.""" + + meta = MetaData() + a = Table('a', meta, Column('id', Integer, primary_key=True)) + b = Table('b', meta, Column('id', Integer, ForeignKey('a.id'), primary_key=True)) + c = Table('c', meta, Column('id', Integer, ForeignKey('b.id'), primary_key=True)) + d = Table('d', meta, Column('id', Integer, ForeignKey('c.id'), primary_key=True)) + + assert c.c.id.references(b.c.id) + assert not d.c.id.references(a.c.id) + + assert list(a.join(b).primary_key) == [a.c.id] + assert list(b.join(c).primary_key) == [b.c.id] + assert list(a.join(b).join(c).primary_key) == [a.c.id] + assert list(b.join(c).join(d).primary_key) == [b.c.id] + assert list(d.join(c).join(b).primary_key) == [b.c.id] + assert list(a.join(b).join(c).join(d).primary_key) == [a.c.id] + + def test_join_pk_collapse_explicit(self): + """test that redundant columns in a join get 'collapsed' into a minimal primary key, + which is the root column along a chain of explicit join conditions.""" + + meta = MetaData() + a = Table('a', meta, Column('id', Integer, primary_key=True), Column('x', Integer)) + b = Table('b', meta, Column('id', Integer, ForeignKey('a.id'), primary_key=True), Column('x', Integer)) + c = Table('c', meta, Column('id', Integer, ForeignKey('b.id'), primary_key=True), Column('x', Integer)) + d = Table('d', meta, Column('id', Integer, ForeignKey('c.id'), primary_key=True), Column('x', Integer)) + + print list(a.join(b, a.c.x==b.c.id).primary_key) + assert list(a.join(b, a.c.x==b.c.id).primary_key) == [b.c.id] + assert list(b.join(c, b.c.x==c.c.id).primary_key) == [b.c.id] + assert list(a.join(b).join(c, c.c.id==b.c.x).primary_key) == [a.c.id] + assert list(b.join(c, c.c.x==b.c.id).join(d).primary_key) == [c.c.id] + assert list(b.join(c, c.c.id==b.c.x).join(d).primary_key) == [b.c.id] + assert list(d.join(b, d.c.id==b.c.id).join(c, b.c.id==c.c.x).primary_key) == [c.c.id] + assert list(a.join(b).join(c, c.c.id==b.c.x).join(d).primary_key) == [a.c.id] + + assert list(a.join(b, and_(a.c.id==b.c.id, a.c.x==b.c.id)).primary_key) == [a.c.id] + + def test_init_doesnt_blowitaway(self): + meta = MetaData() + a = Table('a', meta, Column('id', Integer, primary_key=True), Column('x', Integer)) + b = Table('b', meta, Column('id', Integer, ForeignKey('a.id'), primary_key=True), Column('x', Integer)) + + j = a.join(b) + assert list(j.primary_key) == [a.c.id] + + j.foreign_keys + assert list(j.primary_key) == [a.c.id] + + if __name__ == "__main__": testbase.main() diff --git a/test/testbase.py b/test/testbase.py index 54fcd9db23..7c5095d1a0 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -267,8 +267,11 @@ class ORMTest(AssertMixin): metadata = MetaData(db) self.define_tables(metadata) metadata.create_all() + self.insert_data() def define_tables(self, metadata): raise NotImplementedError() + def insert_data(self): + pass def get_metadata(self): return metadata def tearDownAll(self):