From: Mike Bayer Date: Tue, 10 Jul 2007 06:51:58 +0000 (+0000) Subject: more "column targeting" enhancements..columns have a "depth" from their ultimate... X-Git-Tag: rel_0_3_9~35 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=be29010e292739ca3545315eb2e6a9243aa53e1a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git more "column targeting" enhancements..columns have a "depth" from their ultimate source column so that corresponding_column() can find the column that is "closest" (i.e. fewest levels of proxying) to the requested column --- diff --git a/CHANGES b/CHANGES index 606d9baba6..b281a8fb6c 100644 --- a/CHANGES +++ b/CHANGES @@ -39,6 +39,9 @@ - DynamicMetaData has been renamed to ThreadLocalMetaData. the DynamicMetaData name is deprecated and is an alias for ThreadLocalMetaData or a regular MetaData if threadlocal=False + - some enhancements to "column targeting", the ability to match a column + to a "corresponding" column in another selectable. this affects mostly + ORM ability to map to complex joins - MetaData and all SchemaItems are safe to use with pickle. slow table reflections can be dumped into a pickled file to be reused later. Just reconnect the engine to the metadata after unpickling. 
[ticket:619] diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 923dd67977..3b3b9b7ed5 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -89,8 +89,8 @@ class TranslatingDict(dict): def __translate_col(self, col): ourcol = self.selectable.corresponding_column(col, keys_ok=False, raiseerr=False) - #if col is not ourcol: - # print "TD TRANSLATING ", col, "TO", ourcol +# if col is not ourcol and ourcol is not None: +# print "TD TRANSLATING ", col, "TO", ourcol if ourcol is None: return col else: diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index eb7eb8c1d9..d7d728f2b6 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -600,11 +600,11 @@ class Column(SchemaItem, sql._ColumnClause): This is a copy of this ``Column`` referenced by a different parent (such as an alias or select statement). """ - fk = [ForeignKey(f._colspec) for f in self.foreign_keys] c = Column(name or self.name, self.type, self.default, key = name or self.key, primary_key = self.primary_key, nullable = self.nullable, _is_oid = self._is_oid, quote=self.quote, *fk) c.table = selectable c.orig_set = self.orig_set + c._source_column = self c.__originating_column = self.__originating_column if not c._is_oid: selectable.columns.add(c) diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index c86fc561a6..2a22a40c13 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -1542,7 +1542,17 @@ class ColumnElement(Selectable, _CompareMixin): return True else: return False - + + def _distance(self, othercolumn): + c = othercolumn + count = 0 + while c is not self: + c = c._source_column + if c is None: + return -1 + count += 1 + return count + def _make_proxy(self, selectable, name=None): """Create a new ``ColumnElement`` representing this ``ColumnElement`` as it appears in the select list of a @@ -1695,7 +1705,7 @@ class FromClause(Selectable): """ if column in self.c: return column - + if require_embedded and 
column not in util.Set(self._get_all_embedded_columns()): if not raiseerr: return None @@ -1757,9 +1767,9 @@ class FromClause(Selectable): for co in self._adjusted_exportable_columns(): cp = self._proxy_column(co) for ci in cp.orig_set: - # note that some ambiguity is raised here, whereby a selectable might have more than - # one column that maps to an "original" column. examples include unions and joins - self._orig_cols[ci] = cp + cx = self._orig_cols.get(ci) + if cx is None or ci._distance(cp) < ci._distance(cx): + self._orig_cols[ci] = cp if self.oid_column is not None: for ci in self.oid_column.orig_set: self._orig_cols[ci] = self.oid_column @@ -2078,7 +2088,8 @@ class _Cast(ColumnElement): self.type = sqltypes.to_instance(totype) self.clause = clause self.typeclause = _TypeClause(self.type) - + self._source_column = None + def get_children(self, **kwargs): return self.clause, self.typeclause def accept_visitor(self, visitor): @@ -2090,6 +2101,7 @@ class _Cast(ColumnElement): def _make_proxy(self, selectable, name=None): if name is not None: co = _ColumnClause(name, selectable, type=self.type) + co._source_column = self co.orig_set = self.orig_set selectable.columns[name]= co return co @@ -2512,6 +2524,7 @@ class _ColumnClause(ColumnElement): self.table = selectable self.type = sqltypes.to_instance(type) self._is_oid = _is_oid + self._source_column = None self.__label = None self.case_sensitive = case_sensitive self.is_literal = is_literal @@ -2571,6 +2584,7 @@ class _ColumnClause(ColumnElement): is_literal = self.is_literal and (name is None or name == self.name) c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type=self.type, is_literal=is_literal) c.orig_set = self.orig_set + c._source_column = self if not self._is_oid: selectable.columns[c.name] = c return c diff --git a/test/orm/eagertest3.py b/test/orm/eagertest3.py index a731581d55..8e77358124 100644 --- a/test/orm/eagertest3.py +++ b/test/orm/eagertest3.py @@ -415,7 
+415,101 @@ class EagerTest5(testbase.ORMTest): # object is not in the session; therefore the lazy load cant trigger here, # eager load had to succeed assert len([c for c in d2.comments]) == 1 + +class EagerTest6(testbase.ORMTest): + def define_tables(self, metadata): + global project_t, task_t, task_status_t, task_type_t, message_t, message_type_t - + project_t = Table('prj', metadata, + Column('id', Integer, primary_key=True), + Column('created', DateTime , ), + Column('title', Unicode(100)), + ) + + task_t = Table('task', metadata, + Column('id', Integer, primary_key=True), + Column('status_id', Integer, ForeignKey('task_status.id'), nullable=False), + Column('title', Unicode(100)), + Column('task_type_id', Integer , ForeignKey('task_type.id'), nullable=False), + Column('prj_id', Integer , ForeignKey('prj.id'), nullable=False), + ) + + task_status_t = Table('task_status', metadata, + Column('id', Integer, primary_key=True), + ) + + task_type_t = Table('task_type', metadata, + Column('id', Integer, primary_key=True), + ) + + message_t = Table('msg', metadata, + Column('id', Integer, primary_key=True), + Column('posted', DateTime, index=True,), + Column('type_id', Integer, ForeignKey('msg_type.id')), + Column('task_id', Integer, ForeignKey('task.id')), + ) + + message_type_t = Table('msg_type', metadata, + Column('id', Integer, primary_key=True), + Column('name', Unicode(20)), + Column('display_name', Unicode(20)), + ) + + def setUp(self): + testbase.db.execute("INSERT INTO prj (title) values('project 1');") + testbase.db.execute("INSERT INTO task_status (id) values(1);") + testbase.db.execute("INSERT INTO task_type(id) values(1);") + testbase.db.execute("INSERT INTO task (title, task_type_id, status_id, prj_id) values('task 1',1,1,1);") + + def test_nested_joins(self): + # this is testing some subtle column resolution stuff, + # concerning corresponding_column() being extremely accurate + # as well as how mapper sets up its column properties + + class 
Task(object):pass + class Task_Type(object):pass + class Message(object):pass + class Message_Type(object):pass + + tsk_cnt_join = outerjoin(project_t, task_t, task_t.c.prj_id==project_t.c.id) + + ss = select([project_t.c.id.label('prj_id'), func.count(task_t.c.id).label('tasks_number')], + from_obj=[tsk_cnt_join], group_by=[project_t.c.id]).alias('prj_tsk_cnt_s') + j = join(project_t, ss, project_t.c.id == ss.c.prj_id) + + mapper(Task_Type, task_type_t) + + mapper( Task, task_t, + properties=dict(type=relation(Task_Type, lazy=False), + )) + + mapper(Message_Type, message_type_t) + + mapper(Message, message_t, + properties=dict(type=relation(Message_Type, lazy=False, uselist=False), + )) + + tsk_cnt_join = outerjoin(project_t, task_t, task_t.c.prj_id==project_t.c.id) + ss = select([project_t.c.id.label('prj_id'), func.count(task_t.c.id).label('tasks_number')], + from_obj=[tsk_cnt_join], group_by=[project_t.c.id]).alias('prj_tsk_cnt_s') + j = join(project_t, ss, project_t.c.id == ss.c.prj_id) + + j = outerjoin( task_t, message_t, task_t.c.id==message_t.c.task_id) + jj = select([ task_t.c.id.label('task_id'), + func.count(message_t.c.id).label('props_cnt')], + from_obj=[j], group_by=[task_t.c.id]).alias('prop_c_s') + jjj = join(task_t, jj, task_t.c.id == jj.c.task_id) + + class cls(object):pass + + props =dict(type=relation(Task_Type, lazy=False)) + print [c.key for c in jjj.c] + cls.mapper = mapper( cls, jjj, properties=props) + + session = create_session() + + for t in session.query(cls.mapper).limit(10).offset(0).list(): + print t.id, t.title, t.props_cnt + if __name__ == "__main__": testbase.main() diff --git a/test/orm/mapper.py b/test/orm/mapper.py index 63af53b967..197f988b7a 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -497,7 +497,7 @@ class MapperTest(MapperSuperTest): class_mapper(User) except exceptions.ArgumentError, e: assert str(e) == "Column '%s' is not represented in mapper's table. 
Use the `column_property()` function to force this column to be mapped as a read-only attribute." % str(f) - clear_mappers() + clear_mappers() mapper(User, users, properties={ 'concat': column_property(f), diff --git a/test/sql/selectable.py b/test/sql/selectable.py index 57ad998860..340c55837d 100755 --- a/test/sql/selectable.py +++ b/test/sql/selectable.py @@ -27,17 +27,42 @@ table2 = Table('table2', db, ) class SelectableTest(testbase.AssertMixin): + def testdistance(self): + s = select([table.c.col1.label('c2'), table.c.col1, table.c.col1.label('c1')]) + + # didn't do this yet...col.label().make_proxy() has same "distance" as col.make_proxy() so far + #assert s.corresponding_column(table.c.col1) is s.c.col1 + assert s.corresponding_column(s.c.col1) is s.c.col1 + assert s.corresponding_column(s.c.c1) is s.c.c1 + def testjoinagainstself(self): jj = select([table.c.col1.label('bar_col1')]) jjj = join(table, jj, table.c.col1==jj.c.bar_col1) + + # test column directly against itself assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1 + assert jjj.corresponding_column(jj.c.bar_col1) is jjj.c.bar_col1 + + # test alias of the join, targets the column with the least + # "distance" between the requested column and the returned column + # (i.e. 
there is less indirection between j2.c.table1_col1 and table.c.col1, than + # there is from j2.c.bar_col1 to table.c.col1) + j2 = jjj.alias('foo') + assert j2.corresponding_column(table.c.col1) is j2.c.table1_col1 + + def testjoinagainstjoin(self): j = outerjoin(table, table2, table.c.col1==table2.c.col2) jj = select([ table.c.col1.label('bar_col1')],from_obj=[j]).alias('foo') jjj = join(table, jj, table.c.col1==jj.c.bar_col1) assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1 + + j2 = jjj.alias('foo') + print j2.corresponding_column(jjj.c.table1_col1) + assert j2.corresponding_column(jjj.c.table1_col1) is j2.c.table1_col1 + assert jjj.corresponding_column(jj.c.bar_col1) is jj.c.bar_col1 def testtablealias(self): a = table.alias('a') @@ -110,8 +135,8 @@ class SelectableTest(testbase.AssertMixin): j = join(a, table2) criterion = a.c.col1 == table2.c.col2 - print - print str(j) + print criterion + print j.onclause self.assert_(criterion.compare(j.onclause)) def testselectlabels(self):