From 08d825a39dce16becb176b3c32f4adb2dad6886a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 14 Feb 2009 23:02:33 +0000 Subject: [PATCH] - merged -r5727:5797 of trunk - newest pg8000 handles unicode statements correctly. --- CHANGES | 71 ++++++ .../custom_attributes/listen_for_events.py | 3 +- lib/sqlalchemy/dialects/postgres/base.py | 9 +- lib/sqlalchemy/dialects/postgres/pg8000.py | 2 +- lib/sqlalchemy/ext/sqlsoup.py | 6 +- lib/sqlalchemy/orm/attributes.py | 13 +- lib/sqlalchemy/orm/evaluator.py | 4 +- lib/sqlalchemy/orm/interfaces.py | 60 ++--- lib/sqlalchemy/orm/properties.py | 60 +++-- lib/sqlalchemy/orm/query.py | 176 +++++++------- lib/sqlalchemy/orm/unitofwork.py | 2 +- lib/sqlalchemy/orm/util.py | 11 +- lib/sqlalchemy/sql/compiler.py | 22 +- lib/sqlalchemy/sql/expression.py | 214 +++++++++--------- lib/sqlalchemy/sql/functions.py | 9 +- lib/sqlalchemy/sql/operators.py | 2 +- lib/sqlalchemy/util.py | 5 +- test/dialect/mssql.py | 55 +++++ test/dialect/postgres.py | 11 +- test/orm/cascade.py | 111 +++++++++ test/orm/expire.py | 48 +++- test/orm/inheritance/query.py | 107 ++++++++- test/orm/query.py | 48 +++- test/orm/relationships.py | 75 +++++- test/profiling/memusage.py | 2 + test/sql/functions.py | 6 +- test/sql/generative.py | 7 + test/sql/labels.py | 45 +++- test/sql/select.py | 54 ++++- 29 files changed, 932 insertions(+), 306 deletions(-) diff --git a/CHANGES b/CHANGES index 36c8398c0a..1603ce6594 100644 --- a/CHANGES +++ b/CHANGES @@ -3,6 +3,77 @@ ======= CHANGES ======= +0.5.3 +===== +- orm + - Query now implements __clause_element__() which produces + its selectable, which means a Query instance can be accepted + in many SQL expressions, including col.in_(query), + union(query1, query2), select([foo]).select_from(query), + etc. + + - a session.expire() on a particular collection attribute + will clear any pending backref additions as well, so that + the next access correctly returns only what was present + in the database. Presents some degree of a workaround for + [ticket:1315], although we are considering removing the + flush([objects]) feature altogether. + + - improvements to the "determine direction" logic of + relation() such that the direction of tricky situations + like mapper(A.join(B)) -> relation-> mapper(B) can be + determined. + + - When flushing partial sets of objects using session.flush([somelist]), + pending objects which remain pending after the operation won't + inadvertently be added as persistent. [ticket:1306] + + - Added "post_configure_attribute" method to InstrumentationManager, + so that the "listen_for_events.py" example works again. + [ticket:1314] + + - Fixed bugs in Query regarding simultaneous selection of + multiple joined-table inheritance entities with common base + classes: + + - previously the adaption applied to "B" on + "A JOIN B" would be erroneously partially applied + to "A". + + - comparisons on relations (i.e. A.related==someb) + were not getting adapted when they should. + + - Other filterings, like + query(A).join(A.bs).filter(B.foo=='bar'), were erroneously + adapting "B.foo" as though it were an "A". + +- sql + - Fixed missing _label attribute on Function object, others + when used in a select() with use_labels (such as when used + in an ORM column_property()). [ticket:1302] + + - anonymous alias names now truncate down to the max length + allowed by the dialect. More significant on DBs like + Oracle with very small character limits. [ticket:1309] + + - the __selectable__() interface has been replaced entirely + by __clause_element__(). + + - The per-dialect cache used by TypeEngine to cache + dialect-specific types is now a WeakKeyDictionary. + This to prevent dialect objects from + being referenced forever for an application that + creates an arbitrarily large number of engines + or dialects. There is a small performance penalty + which will be resolved in 0.6. [ticket:1299] + +- postgres + - Index reflection won't fail when an index with + multiple expressions is encountered. + +- mssql + - Preliminary support for pymssql 1.0.1 + 0.5.2 ====== diff --git a/examples/custom_attributes/listen_for_events.py b/examples/custom_attributes/listen_for_events.py index c028e0fb48..de28df5b3a 100644 --- a/examples/custom_attributes/listen_for_events.py +++ b/examples/custom_attributes/listen_for_events.py @@ -7,11 +7,10 @@ across the board. from sqlalchemy.orm.interfaces import AttributeExtension, InstrumentationManager class InstallListeners(InstrumentationManager): - def instrument_attribute(self, class_, key, inst): + def post_configure_attribute(self, class_, key, inst): """Add an event listener to an InstrumentedAttribute.""" inst.impl.extensions.insert(0, AttributeListener(key)) - return super(InstallListeners, self).instrument_attribute(class_, key, inst) class AttributeListener(AttributeExtension): """Generic event listener. diff --git a/lib/sqlalchemy/dialects/postgres/base.py b/lib/sqlalchemy/dialects/postgres/base.py index 7db0dd8823..705778cc5b 100644 --- a/lib/sqlalchemy/dialects/postgres/base.py +++ b/lib/sqlalchemy/dialects/postgres/base.py @@ -823,10 +823,11 @@ class PGDialect(default.DefaultDialect): sv_idx_name = None for row in c.fetchall(): idx_name, unique, expr, prd, col = row - if expr and not idx_name == sv_idx_name: - util.warn( - "Skipped unsupported reflection of expression-based index %s" - % idx_name) + if expr: + if idx_name != sv_idx_name: + util.warn( + "Skipped unsupported reflection of expression-based index %s" + % idx_name) sv_idx_name = idx_name continue if prd and not idx_name == sv_idx_name: diff --git a/lib/sqlalchemy/dialects/postgres/pg8000.py b/lib/sqlalchemy/dialects/postgres/pg8000.py index 00636dbfe4..47ccab3f8b 100644 --- a/lib/sqlalchemy/dialects/postgres/pg8000.py +++ b/lib/sqlalchemy/dialects/postgres/pg8000.py @@ -45,7 +45,7 @@ class Postgres_pg8000ExecutionContext(default.DefaultExecutionContext): class Postgres_pg8000(PGDialect): driver = 'pg8000' - supports_unicode_statements = False + supports_unicode_statements = True supports_unicode_binds = True diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py index 37b9d8fa89..f2754793b0 100644 --- a/lib/sqlalchemy/ext/sqlsoup.py +++ b/lib/sqlalchemy/ext/sqlsoup.py @@ -397,7 +397,7 @@ class SelectableClassType(type): def update(cls, whereclause=None, values=None, **kwargs): _ddl_error(cls) - def __selectable__(cls): + def __clause_element__(cls): return cls._table def __getattr__(cls, attr): @@ -442,7 +442,7 @@ def _selectable_name(selectable): return x def class_for_table(selectable, **mapper_kwargs): - selectable = expression._selectable(selectable) + selectable = expression._clause_element_as_expr(selectable) mapname = 'Mapped' + _selectable_name(selectable) if isinstance(mapname, unicode): engine_encoding = selectable.metadata.bind.dialect.encoding @@ -531,7 +531,7 @@ class SqlSoup: def with_labels(self, item): # TODO give meaningful aliases - return self.map(expression._selectable(item).select(use_labels=True).alias('foo')) + return self.map(expression._clause_element_as_expr(item).select(use_labels=True).alias('foo')) def join(self, *args, **kwargs): j = join(*args, **kwargs) diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 446c55b41e..e3901f9b10 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -758,6 +758,7 @@ class CollectionAttributeImpl(AttributeImpl): state.commit([self.key]) if self.key in state.pending: + # pending items exist. issue a modified event, # add/remove new items. state.modified_event(self, True, user_data) @@ -1027,6 +1028,7 @@ class InstanceState(object): if impl.accepts_scalar_loader: self.callables[key] = self self.dict.pop(key, None) + self.pending.pop(key, None) self.committed_state.pop(key, None) def reset(self, key): @@ -1201,6 +1203,9 @@ class ClassManager(dict): manager = create_manager_for_cls(cls) manager.instrument_attribute(key, inst, True) + def post_configure_attribute(self, key): + pass + def uninstrument_attribute(self, key, propagated=False): if key not in self: return @@ -1354,6 +1359,9 @@ class _ClassInstrumentationAdapter(ClassManager): if not propagated: self._adapted.instrument_attribute(self.class_, key, inst) + def post_configure_attribute(self, key): + self._adapted.post_configure_attribute(self.class_, key, self[key]) + def install_descriptor(self, key, inst): self._adapted.install_descriptor(self.class_, key, inst) @@ -1579,9 +1587,10 @@ def register_attribute_impl(class_, key, **kw): key, factory or list) else: typecallable = kw.pop('typecallable', None) - + manager[key].impl = _create_prop(class_, key, manager, typecallable=typecallable, **kw) - + manager.post_configure_attribute(key) + def register_descriptor(class_, key, proxy_property=None, comparator=None, parententity=None, property_=None): manager = manager_of_class(class_) diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py index 8628c8a239..e788913531 100644 --- a/lib/sqlalchemy/orm/evaluator.py +++ b/lib/sqlalchemy/orm/evaluator.py @@ -34,8 +34,8 @@ class EvaluatorCompiler(object): return lambda obj: None def visit_column(self, clause): - if 'parententity' in clause._annotations: - key = clause._annotations['parententity']._get_col_to_prop(clause).key + if 'parentmapper' in clause._annotations: + key = clause._annotations['parentmapper']._get_col_to_prop(clause).key else: key = clause.key get_corresponding_attr = operator.attrgetter(key) diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 6c3c3b1bae..3b7507def6 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -8,7 +8,7 @@ Semi-private module containing various base classes used throughout the ORM. -Defines the extension classes :class:`MapperExtension`, +Defines the extension classes :class:`MapperExtension`, :class:`SessionExtension`, and :class:`AttributeExtension` as well as other user-subclassable extension objects. @@ -167,7 +167,7 @@ class MapperExtension(object): ``__new__``, and after initial attribute population has occurred. - This typicically occurs when the instance is created based on + This typically occurs when the instance is created based on incoming result rows, and is only called once for that instance's lifetime. @@ -325,7 +325,7 @@ class SessionExtension(object): `query_context` was the query context object. `result` is the result object returned from the bulk operation. """ - + class MapperProperty(object): """Manage the relationship of a ``Mapper`` to a single class attribute, as well as that attribute as it appears on individual @@ -394,36 +394,36 @@ class MapperProperty(object): def instrument_class(self, mapper): raise NotImplementedError() - + _compile_started = False _compile_finished = False - + def init(self): """Called after all mappers are created to assemble relationships between mappers and perform other post-mapper-creation - initialization steps. - + initialization steps. + """ self._compile_started = True self.do_init() self._compile_finished = True - + def do_init(self): """Perform subclass-specific initialization post-mapper-creation steps. This is a *template* method called by the ``MapperProperty`` object's init() method. - + """ pass - + def post_instrument_class(self, mapper): """Perform instrumentation adjustments that need to occur after init() has completed. - + """ pass - + def register_dependencies(self, *args, **kwargs): """Called by the ``Mapper`` in response to the UnitOfWork calling the ``Mapper``'s register_dependencies operation. @@ -482,10 +482,10 @@ class PropComparator(expression.ColumnOperators): def adapted(self, adapter): """Return a copy of this PropComparator which will use the given adaption function on the local side of generated expressions. - + """ return self.__class__(self.prop, self.mapper, adapter) - + @staticmethod def any_op(a, b, **kwargs): return a.any(b, **kwargs) @@ -589,7 +589,7 @@ class StrategizedProperty(MapperProperty): def post_instrument_class(self, mapper): if self.is_primary(): self.strategy.init_class_attribute(mapper) - + def build_path(entity, key, prev=None): if prev: return prev + (entity, key) @@ -738,35 +738,35 @@ class PropertyOption(MapperOption): class AttributeExtension(object): """An event handler for individual attribute change events. - - AttributeExtension is assembled within the descriptors associated - with a mapped class. - + + AttributeExtension is assembled within the descriptors associated + with a mapped class. + """ def append(self, state, value, initiator): """Receive a collection append event. - + The returned value will be used as the actual value to be appended. - + """ return value def remove(self, state, value, initiator): """Receive a remove event. - + No return value is defined. - + """ pass def set(self, state, value, oldvalue, initiator): """Receive a set event. - + The returned value will be used as the actual value to be set. - + """ return value @@ -855,7 +855,12 @@ class LoaderStrategy(object): return fn class InstrumentationManager(object): - """User-defined class instrumentation extension.""" + """User-defined class instrumentation extension. + + The API for this class should be considered as semi-stable, + and may change slightly with new releases. + + """ # r4361 added a mandatory (cls) constructor to this interface. # given that, perhaps class_ should be dropped from all of these @@ -878,6 +883,9 @@ class InstrumentationManager(object): def instrument_attribute(self, class_, key, inst): pass + def post_configure_attribute(self, class_, key, inst): + pass + def install_descriptor(self, class_, key, inst): setattr(class_, key, inst) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index f05613f5c0..2a772dcac2 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -121,7 +121,7 @@ class ColumnProperty(StrategizedProperty): if self.adapter: return self.adapter(self.prop.columns[0]) else: - return self.prop.columns[0]._annotate({"parententity": self.mapper}) + return self.prop.columns[0]._annotate({"parententity": self.mapper, "parentmapper":self.mapper}) def operate(self, op, *other, **kwargs): return op(self.__clause_element__(), *other, **kwargs) @@ -417,7 +417,7 @@ class RelationProperty(StrategizedProperty): if backref: raise sa_exc.ArgumentError("backref and back_populates keyword arguments are mutually exclusive") self.backref = None - elif isinstance(backref, str): + elif isinstance(backref, basestring): # propagate explicitly sent primary/secondary join conditions to the BackRef object if # just a string was sent if secondary is not None: @@ -485,11 +485,11 @@ class RelationProperty(StrategizedProperty): if self.property.direction in [ONETOMANY, MANYTOMANY]: return ~self._criterion_exists() else: - return self.property._optimized_compare(None, adapt_source=self.adapter) + return _orm_annotate(self.property._optimized_compare(None, adapt_source=self.adapter)) elif self.property.uselist: raise sa_exc.InvalidRequestError("Can't compare a collection to an object or collection; use contains() to test for membership.") else: - return self.property._optimized_compare(other, adapt_source=self.adapter) + return _orm_annotate(self.property._optimized_compare(other, adapt_source=self.adapter)) def _criterion_exists(self, criterion=None, **kwargs): if getattr(self, '_of_type', None): @@ -889,29 +889,45 @@ class RelationProperty(StrategizedProperty): self.direction = MANYTOONE else: - for mappedtable, parenttable in [(self.mapper.mapped_table, self.parent.mapped_table), (self.mapper.local_table, self.parent.local_table)]: - onetomany = [c for c in self._foreign_keys if mappedtable.c.contains_column(c)] - manytoone = [c for c in self._foreign_keys if parenttable.c.contains_column(c)] - - if not onetomany and not manytoone: - raise sa_exc.ArgumentError( - "Can't determine relation direction for relationship '%s' " - "- foreign key columns are present in neither the " - "parent nor the child's mapped tables" %(str(self))) - elif onetomany and manytoone: - continue - elif onetomany: + foreign_keys = [f for c, f in self.synchronize_pairs] + + parentcols = util.column_set(self.parent.mapped_table.c) + targetcols = util.column_set(self.mapper.mapped_table.c) + + # fk collection which suggests ONETOMANY. + onetomany_fk = targetcols.intersection(foreign_keys) + + # fk collection which suggests MANYTOONE. + manytoone_fk = parentcols.intersection(foreign_keys) + + if not onetomany_fk and not manytoone_fk: + raise sa_exc.ArgumentError( + "Can't determine relation direction for relationship '%s' " + "- foreign key columns are present in neither the " + "parent nor the child's mapped tables" % self ) + + elif onetomany_fk and manytoone_fk: + # fks on both sides. do the same + # test only based on the local side. + referents = [c for c, f in self.synchronize_pairs] + onetomany_local = parentcols.intersection(referents) + manytoone_local = targetcols.intersection(referents) + + if onetomany_local and not manytoone_local: self.direction = ONETOMANY - break - elif manytoone: + elif manytoone_local and not onetomany_local: self.direction = MANYTOONE - break - else: + elif onetomany_fk: + self.direction = ONETOMANY + elif manytoone_fk: + self.direction = MANYTOONE + + if not self.direction: raise sa_exc.ArgumentError( "Can't determine relation direction for relationship '%s' " "- foreign key columns are present in both the parent and " "the child's mapped tables. Specify 'foreign_keys' " - "argument." % (str(self))) + "argument." % self) if self.cascade.delete_orphan and not self.single_parent and \ (self.direction is MANYTOMANY or self.direction is MANYTOONE): @@ -1001,7 +1017,7 @@ class RelationProperty(StrategizedProperty): def _refers_to_parent_table(self): - return self.parent.mapped_table is self.target or self.parent.mapped_table is self.target + return self.parent.mapped_table is self.target def _is_self_referential(self): return self.mapper.common_parent(self.parent) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index de7f66882b..db9ce1d676 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -43,7 +43,7 @@ aliased = AliasedClass def _generative(*assertions): """Mark a method as generative.""" - + @util.decorator def generate(fn, *args, **kw): self = args[0]._clone() @@ -127,6 +127,7 @@ class Query(object): def __mapper_loads_polymorphically_with(self, mapper, adapter): for m2 in mapper._with_polymorphic_mappers: + self._polymorphic_adapters[m2] = adapter for m in m2.iterate_to_root(): self._polymorphic_adapters[m.mapped_table] = self._polymorphic_adapters[m.local_table] = adapter @@ -139,12 +140,13 @@ class Query(object): if isinstance(from_obj, expression.Alias): self._from_obj_alias = sql_util.ColumnAdapter(self._from_obj, equivs) - + def _get_polymorphic_adapter(self, entity, selectable): self.__mapper_loads_polymorphically_with(entity.mapper, sql_util.ColumnAdapter(selectable, entity.mapper._equivalent_columns)) def _reset_polymorphic_adapter(self, mapper): for m2 in mapper._with_polymorphic_mappers: + self._polymorphic_adapters.pop(m2, None) for m in m2.iterate_to_root(): self._polymorphic_adapters.pop(m.mapped_table, None) self._polymorphic_adapters.pop(m.local_table, None) @@ -282,7 +284,7 @@ class Query(object): if self._order_by: raise sa_exc.InvalidRequestError("Query.%s() being called on a Query with existing criterion. " % meth) self.__no_criterion_condition(meth) - + def __no_statement_condition(self, meth): if self._statement: raise sa_exc.InvalidRequestError( @@ -317,37 +319,35 @@ class Query(object): @property def statement(self): """The full SELECT statement represented by this Query.""" - - return self._compile_context(labels=self._with_labels).statement._annotate({'_halt_adapt': True}) - @property - def _nested_statement(self): - return self.with_labels().enable_eagerloads(False).statement.correlate(None) + return self._compile_context(labels=self._with_labels).statement._annotate({'_halt_adapt': True}) def subquery(self): """return the full SELECT statement represented by this Query, embedded within an Alias. - + Eager JOIN generation within the query is disabled. - - """ + """ return self.enable_eagerloads(False).statement.alias() + def __clause_element__(self): + return self.enable_eagerloads(False).statement + @_generative() def enable_eagerloads(self, value): """Control whether or not eager joins are rendered. - - When set to False, the returned Query will not render + + When set to False, the returned Query will not render eager joins regardless of eagerload() options or mapper-level lazy=False configurations. - + This is used primarily when nesting the Query's statement into a subquery or other selectable. - + """ self._enable_eagerloads = value - + @_generative() def with_labels(self): """Apply column labels to the return value of Query.statement. @@ -410,7 +410,7 @@ class Query(object): attribute of the mapper will be used, if any. This is useful for mappers that don't have polymorphic loading behavior by default, such as concrete table mappers. - + """ entity = self._generate_mapper_zero() entity.set_with_polymorphic(self, cls_or_mappers, selectable=selectable, discriminator=discriminator) @@ -554,7 +554,7 @@ class Query(object): those being selected. """ - fromclause = self._nested_statement + fromclause = self.with_labels().enable_eagerloads(False).statement.correlate(None) q = self._from_selectable(fromclause) if entities: q._set_entities(entities) @@ -728,27 +728,27 @@ class Query(object): q2 = sess.query(SomeClass).filter(SomeClass.bar=='foo') q3 = q1.union(q2) - + The method accepts multiple Query objects so as to control the level of nesting. A series of ``union()`` calls such as:: - + x.union(y).union(z).all() - + will nest on each ``union()``, and produces:: - + SELECT * FROM (SELECT * FROM (SELECT * FROM X UNION SELECT * FROM y) UNION SELECT * FROM Z) - + Whereas:: - + x.union(y, z).all() - + produces:: SELECT * FROM (SELECT * FROM X UNION SELECT * FROM y UNION SELECT * FROM Z) """ return self._from_selectable( - expression.union(*([self._nested_statement]+ [x._nested_statement for x in q]))) + expression.union(*([self]+ list(q)))) def union_all(self, *q): """Produce a UNION ALL of this Query against one or more queries. @@ -758,7 +758,7 @@ class Query(object): """ return self._from_selectable( - expression.union_all(*([self._nested_statement]+ [x._nested_statement for x in q])) + expression.union_all(*([self]+ list(q))) ) def intersect(self, *q): @@ -769,7 +769,7 @@ class Query(object): """ return self._from_selectable( - expression.intersect(*([self._nested_statement]+ [x._nested_statement for x in q])) + expression.intersect(*([self]+ list(q))) ) def intersect_all(self, *q): @@ -780,7 +780,7 @@ class Query(object): """ return self._from_selectable( - expression.intersect_all(*([self._nested_statement]+ [x._nested_statement for x in q])) + expression.intersect_all(*([self]+ list(q))) ) def except_(self, *q): @@ -791,7 +791,7 @@ class Query(object): """ return self._from_selectable( - expression.except_(*([self._nested_statement]+ [x._nested_statement for x in q])) + expression.except_(*([self]+ list(q))) ) def except_all(self, *q): @@ -802,7 +802,7 @@ class Query(object): """ return self._from_selectable( - expression.except_all(*([self._nested_statement]+ [x._nested_statement for x in q])) + expression.except_all(*([self]+ list(q))) ) @util.accepts_a_list_as_starargs(list_deprecation='pending') @@ -887,7 +887,7 @@ class Query(object): @_generative(__no_statement_condition, __no_limit_offset) def __join(self, keys, outerjoin, create_aliases, from_joinpoint): - + # copy collections that may mutate so they do not affect # the copied-from query. self.__currenttables = set(self.__currenttables) @@ -904,7 +904,7 @@ class Query(object): # after the method completes, # the query's joinpoint will be set to this. right_entity = None - + for arg1 in util.to_list(keys): aliased_entity = False alias_criterion = False @@ -970,7 +970,7 @@ class Query(object): if ent.corresponds_to(left_entity): clause = ent.selectable break - + if not clause: if isinstance(onclause, interfaces.PropComparator): clause = onclause.__clause_element__() @@ -985,14 +985,14 @@ class Query(object): onclause = prop # start looking at the right side of the join - + mp, right_selectable, is_aliased_class = _entity_info(right_entity) - + if mp is not None and right_mapper is not None and not mp.common_parent(right_mapper): raise sa_exc.InvalidRequestError( "Join target %s does not correspond to the right side of join condition %s" % (right_entity, onclause) ) - + if not right_mapper and mp: right_mapper = mp @@ -1004,7 +1004,7 @@ class Query(object): if not right_selectable.is_derived_from(right_mapper.mapped_table): raise sa_exc.InvalidRequestError( - "Selectable '%s' is not derived from '%s'" % + "Selectable '%s' is not derived from '%s'" % (right_selectable.description, right_mapper.mapped_table.description)) if not isinstance(right_selectable, expression.Alias): @@ -1026,7 +1026,7 @@ class Query(object): # for joins across plain relation()s, try not to specify the # same joins twice. the __currenttables collection tracks # what plain mapped tables we've joined to already. - + if prop.table in self.__currenttables: if prop.secondary is not None and prop.secondary not in self.__currenttables: # TODO: this check is not strong enough for different paths to the same endpoint which @@ -1039,7 +1039,7 @@ class Query(object): if prop.secondary: self.__currenttables.add(prop.secondary) self.__currenttables.add(prop.table) - + if of_type: right_entity = of_type else: @@ -1057,8 +1057,8 @@ class Query(object): onclause = right_adapter.traverse(onclause) onclause = self._adapt_clause(onclause, False, True) - # determine if we want _ORMJoin to alias the onclause - # to the given left side. This is used if we're joining against a + # determine if we want _ORMJoin to alias the onclause + # to the given left side. This is used if we're joining against a # select_from() selectable, from_self() call, or the onclause # has been resolved into a MapperProperty. Otherwise we assume # the onclause itself contains more specific information on how to @@ -1066,10 +1066,10 @@ class Query(object): join_to_left = not is_aliased_class or \ onclause is prop or \ clause is self._from_obj and self._from_obj_alias - - # create the join + + # create the join clause = orm_join(clause, right_entity, onclause, isouter=outerjoin, join_to_left=join_to_left) - + # set up state for the query as a whole if alias_criterion: # adapt filter() calls based on our right side adaptation @@ -1080,14 +1080,14 @@ class Query(object): # and adapt when it renders columns and fetches them from results if aliased_entity: self.__mapper_loads_polymorphically_with( - right_mapper, + right_mapper, ORMAdapter(right_entity, equivalents=right_mapper._equivalent_columns) ) - - # loop finished. we're selecting from + + # loop finished. we're selecting from # our final clause now self._from_obj = clause - + # future joins with from_joinpoint=True join from our established right_entity. self._joinpoint = right_entity @@ -1126,13 +1126,13 @@ class Query(object): if isinstance(stop, int) and isinstance(start, int) and stop - start <= 0: return [] - + # perhaps we should execute a count() here so that we # can still use LIMIT/OFFSET ? elif (isinstance(start, int) and start < 0) \ or (isinstance(stop, int) and stop < 0): return list(self)[item] - + res = self.slice(start, stop) if step is not None: return list(res)[None:None:item.step] @@ -1204,7 +1204,7 @@ class Query(object): if not isinstance(statement, (expression._TextClause, expression._SelectBaseMixin)): raise sa_exc.ArgumentError("from_statement accepts text(), select(), and union() objects only.") - + self._statement = statement def first(self): @@ -1439,18 +1439,18 @@ class Query(object): def count(self): """Apply this query's criterion to a SELECT COUNT statement. - + If column expressions or LIMIT/OFFSET/DISTINCT are present, - the query "SELECT count(1) FROM (SELECT ...)" is issued, + the query "SELECT count(1) FROM (SELECT ...)" is issued, so that the result matches the total number of rows this query would return. For mapped entities, - the primary key columns of each is written to the + the primary key columns of each is written to the columns clause of the nested SELECT statement. - + For a Query which is only against mapped entities, - a simpler "SELECT count(1) FROM table1, table2, ... - WHERE criterion" is issued. - + a simpler "SELECT count(1) FROM table1, table2, ... + WHERE criterion" is issued. + """ should_nest = [self._should_nest_selectable] def ent_cols(ent): @@ -1459,8 +1459,8 @@ class Query(object): else: should_nest[0] = True return [ent.column] - - return self._col_aggregate(sql.literal_column('1'), sql.func.count, + + return self._col_aggregate(sql.literal_column('1'), sql.func.count, nested_cols=chain(*[ent_cols(ent) for ent in self._entities]), should_nest = should_nest[0] ) @@ -1498,9 +1498,9 @@ class Query(object): def delete(self, synchronize_session='evaluate'): """Perform a bulk delete query. - Deletes rows matched by this query from the database. - - :param synchronize_session: chooses the strategy for the removal of matched + Deletes rows matched by this query from the database. + + :param synchronize_session: chooses the strategy for the removal of matched objects from the session. Valid values are: False @@ -1528,10 +1528,10 @@ class Query(object): The method does *not* offer in-Python cascading of relations - it is assumed that ON DELETE CASCADE is configured for any foreign key references which require it. The Session needs to be expired (occurs automatically after commit(), or call expire_all()) - in order for the state of dependent objects subject to delete or delete-orphan cascade to be + in order for the state of dependent objects subject to delete or delete-orphan cascade to be correctly represented. - - Also, the ``before_delete()`` and ``after_delete()`` :class:`~sqlalchemy.orm.interfaces.MapperExtension` + + Also, the ``before_delete()`` and ``after_delete()`` :class:`~sqlalchemy.orm.interfaces.MapperExtension` methods are not called from this method. For a delete hook here, use the ``after_bulk_delete()`` :class:`~sqlalchemy.orm.interfaces.MapperExtension` method. @@ -1591,11 +1591,11 @@ class Query(object): def update(self, values, synchronize_session='evaluate'): """Perform a bulk update query. - Updates rows matched by this query in the database. - + Updates rows matched by this query in the database. + :param values: a dictionary with attributes names as keys and literal values or sql expressions - as values. - + as values. + :param synchronize_session: chooses the strategy to update the attributes on objects in the session. Valid values are: @@ -1621,10 +1621,14 @@ class Query(object): The method does *not* offer in-Python cascading of relations - it is assumed that ON UPDATE CASCADE is configured for any foreign key references which require it. - - Also, the ``before_update()`` and ``after_update()`` :class:`~sqlalchemy.orm.interfaces.MapperExtension` + + The Session needs to be expired (occurs automatically after commit(), or call expire_all()) + in order for the state of dependent objects subject foreign key cascade to be + correctly represented. + + Also, the ``before_update()`` and ``after_update()`` :class:`~sqlalchemy.orm.interfaces.MapperExtension` methods are not called from this method. For an update hook here, use the - ``after_bulk_update()`` :class:`~sqlalchemy.orm.interfaces.MapperExtension` method. + ``after_bulk_update()`` :class:`~sqlalchemy.orm.interfaces.SessionExtension` method. """ @@ -1694,7 +1698,7 @@ class Query(object): for ext in session.extensions: ext.after_bulk_update(session, self, context, result) - + return result.rowcount def _compile_context(self, labels=True): @@ -1894,10 +1898,7 @@ class _MapperEntity(_QueryEntity): adapter = None if not self.is_aliased_class and query._polymorphic_adapters: - for mapper in self.mapper.iterate_to_root(): - adapter = query._polymorphic_adapters.get(mapper.mapped_table, None) - if adapter: - break + adapter = query._polymorphic_adapters.get(self.mapper, None) if not adapter and self.adapter: adapter = self.adapter @@ -1959,7 +1960,7 @@ class _MapperEntity(_QueryEntity): # apply adaptation to the mapper's order_by if needed. if adapter: context.order_by = adapter.adapt_list(util.to_list(context.order_by)) - + for value in self.mapper._iterate_polymorphic_properties(self._with_polymorphic): if query._only_load_props and value.key not in query._only_load_props: continue @@ -1971,14 +1972,14 @@ class _MapperEntity(_QueryEntity): only_load_props=query._only_load_props, column_collection=context.primary_columns ) - + if self._polymorphic_discriminator: if adapter: pd = adapter.columns[self._polymorphic_discriminator] else: pd = self._polymorphic_discriminator context.primary_columns.append(pd) - + def __str__(self): return str(self.mapper) @@ -1995,20 +1996,25 @@ class _ColumnEntity(_QueryEntity): column = column.__clause_element__() else: self._result_label = getattr(column, 'key', None) - + if not isinstance(column, expression.ColumnElement) and hasattr(column, '_select_iterable'): for c in column._select_iterable: if c is column: break _ColumnEntity(query, c) - + if c is not column: return if not isinstance(column, sql.ColumnElement): raise sa_exc.InvalidRequestError("Invalid column expression '%r'" % column) - if not hasattr(column, '_label'): + # if the Column is unnamed, give it a + # label() so that mutable column expressions + # can be located in the result even + # if the expression's identity has been changed + # due to adaption + if not column._label: column = column.label(None) query._entities.append(self) diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index c756045a1e..61c58b2499 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -267,7 +267,7 @@ class UOWTransaction(object): for elem in self.elements: if elem.isdelete: self.session._remove_newly_deleted(elem.state) - else: + elif not elem.listonly: self.session._register_newly_persistent(elem.state) def _sort_dependencies(self): diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index c863729017..1cca9e00b8 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -258,13 +258,20 @@ class ORMAdapter(sql_util.ColumnAdapter): """ def __init__(self, entity, equivalents=None, chain_to=None): - mapper, selectable, is_aliased_class = _entity_info(entity) + self.mapper, selectable, is_aliased_class = _entity_info(entity) if is_aliased_class: self.aliased_class = entity else: self.aliased_class = None sql_util.ColumnAdapter.__init__(self, selectable, equivalents, chain_to) + def replace(self, elem): + entity = elem._annotations.get('parentmapper', None) + if not entity or entity.isa(self.mapper): + return sql_util.ColumnAdapter.replace(self, elem) + else: + return None + class AliasedClass(object): """Represents an 'alias'ed form of a mapped class for usage with Query. @@ -303,7 +310,7 @@ class AliasedClass(object): self.__name__ = 'AliasedClass_' + str(self.__target) def __adapt_element(self, elem): - return self.__adapter.traverse(elem)._annotate({'parententity': self}) + return self.__adapter.traverse(elem)._annotate({'parententity': self, 'parentmapper':self.__mapper}) def __adapt_prop(self, prop): existing = getattr(self.__target, prop.key) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 4608024bbf..787827e8be 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -277,8 +277,9 @@ class SQLCompiler(engine.Compiled): else: schema_prefix = '' tablename = column.table.name - if isinstance(tablename, sql._generated_label): - tablename = tablename % self.anon_map + tablename = isinstance(tablename, sql._generated_label) and \ + self._truncated_identifier("alias", tablename) or tablename + return schema_prefix + self.preparer.quote(tablename, column.table.quote) + "." + name def escape_literal_column(self, text): @@ -330,8 +331,16 @@ class SQLCompiler(engine.Compiled): return sep.join(s for s in (self.process(c) for c in clauselist.clauses) if s is not None) - def visit_calculatedclause(self, clause, **kwargs): - return self.process(clause.clause_expr) + def visit_case(self, clause, **kwargs): + x = "CASE " + if clause.value: + x += self.process(clause.value) + " " + for cond, result in clause.whens: + x += "WHEN " + self.process(cond) + " THEN " + self.process(result) + " " + if clause.else_: + x += "ELSE " + self.process(clause.else_) + " " + x += "END" + return x def visit_cast(self, cast, **kwargs): return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause)) @@ -444,8 +453,11 @@ class SQLCompiler(engine.Compiled): def visit_alias(self, alias, asfrom=False, **kwargs): if asfrom: + alias_name = isinstance(alias.name, sql._generated_label) and \ + self._truncated_identifier("alias", alias.name) or alias.name + return self.process(alias.original, asfrom=True, **kwargs) + " AS " + \ - self.preparer.format_alias(alias, alias.name % self.anon_map) + self.preparer.format_alias(alias, alias_name) else: return self.process(alias.original, **kwargs) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 34d50c9c59..cfc7f407eb 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -461,26 +461,8 @@ def case(whens, value=None, else_=None): }) """ - try: - whens = util.dictlike_iteritems(whens) - except TypeError: - pass - - if value: - crit_filter = _literal_as_binds - else: - crit_filter = _no_literals - - whenlist = [ClauseList('WHEN', crit_filter(c), 'THEN', _literal_as_binds(r), operator=None) - for (c,r) in whens] - if else_ is not None: - whenlist.append(ClauseList('ELSE', _literal_as_binds(else_), operator=None)) - if whenlist: - type = list(whenlist[-1])[-1].type - else: - type = None - cc = _CalculatedClause(None, 'CASE', value, type_=type, operator=None, group_contents=False, *whenlist + ['END']) - return cc + + return _Case(whens, value=value, else_=else_) def cast(clause, totype, **kwargs): """Return a ``CAST`` function. @@ -508,9 +490,10 @@ def collate(expression, collation): """Return the clause ``expression COLLATE collation``.""" expr = _literal_as_binds(expression) - return _CalculatedClause( - expr, expr, _literal_as_text(collation), - operator=operators.collate, group=False) + return _BinaryExpression( + expr, + _literal_as_text(collation), + operators.collate, type_=expr.type) def exists(*args, **kwargs): """Return an ``EXISTS`` clause as applied to a :class:`~sqlalchemy.sql.expression.Select` object. @@ -922,6 +905,12 @@ def _literal_as_text(element): else: return element +def _clause_element_as_expr(element): + if hasattr(element, '__clause_element__'): + return element.__clause_element__() + else: + return element + def _literal_as_column(element): if hasattr(element, '__clause_element__'): return element.__clause_element__() @@ -958,14 +947,6 @@ def _corresponding_column_or_error(fromclause, column, require_embedded=False): % (column, getattr(column, 'table', None), fromclause.description)) return c -def _selectable(element): - if hasattr(element, '__selectable__'): - return element.__selectable__() - elif isinstance(element, Selectable): - return element - else: - raise exc.ArgumentError("Object %r is not a Selectable and does not implement `__selectable__()`" % element) - def is_column(col): """True if ``col`` is an instance of ``ColumnElement``.""" return isinstance(col, ColumnElement) @@ -1148,18 +1129,16 @@ class ClauseElement(Visitable): The return value is a :class:`~sqlalchemy.engine.Compiled` object. Calling `str()` or `unicode()` on the returned value will yield - a string representation of the result. The ``Compiled`` + a string representation of the result. The :class:`~sqlalchemy.engine.Compiled` object also can return a dictionary of bind parameter names and values using the `params` accessor. - bind - An ``Engine`` or ``Connection`` from which a + :param bind: An ``Engine`` or ``Connection`` from which a ``Compiled`` will be acquired. This argument takes precedence over this ``ClauseElement``'s bound engine, if any. - dialect - A ``Dialect`` instance frmo which a ``Compiled`` + :param dialect: A ``Dialect`` instance frmo which a ``Compiled`` will be acquired. This argument takes precedence over the `bind` argument as well as this ``ClauseElement``'s bound engine, if any. @@ -1433,6 +1412,8 @@ class _CompareMixin(ColumnOperators): return self._in_impl(operators.in_op, operators.notin_op, other) def _in_impl(self, op, negate_op, seq_or_selectable): + seq_or_selectable = _clause_element_as_expr(seq_or_selectable) + if isinstance(seq_or_selectable, _ScalarSelect): return self.__compare( op, seq_or_selectable, negate=negate_op) @@ -1450,7 +1431,8 @@ class _CompareMixin(ColumnOperators): for o in seq_or_selectable: if not _is_literal(o): if not isinstance( o, _CompareMixin): - raise exc.InvalidRequestError( "in() function accepts either a list of non-selectable values, or a selectable: "+repr(o) ) + raise exc.InvalidRequestError( + "in() function accepts either a list of non-selectable values, or a selectable: %r" % o) else: o = self._bind_param(o) args.append(o) @@ -1534,9 +1516,7 @@ class _CompareMixin(ColumnOperators): def collate(self, collation): """Produce a COLLATE clause, i.e. `` COLLATE utf8_bin``""" - return _CalculatedClause( - None, self, _literal_as_text(collation), - operator=operators.collate, group=False) + return collate(self, collation) def op(self, operator): """produce a generic operator function. @@ -1607,7 +1587,8 @@ class ColumnElement(ClauseElement, _CompareMixin): primary_key = False foreign_keys = [] quote = None - + _label = None + @property def _select_iterable(self): return (self, ) @@ -1830,6 +1811,10 @@ class FromClause(Selectable): return ClauseAdapter(alias).traverse(self) def correspond_on_equivalents(self, column, equivalents): + """Return corresponding_column for the given column, or if None + search for a match in the given dictionary. + + """ col = self.corresponding_column(column, require_embedded=True) if col is None and col in equivalents: for equiv in equivalents[col]: @@ -1843,11 +1828,9 @@ class FromClause(Selectable): object from this ``Selectable`` which corresponds to that original ``Column`` via a common anscestor column. - column - the target ``ColumnElement`` to be matched + :param column: the target ``ColumnElement`` to be matched - require_embedded - only return corresponding columns for the given + :param require_embedded: only return corresponding columns for the given ``ColumnElement``, if the given ``ColumnElement`` is actually present within a sub-element of this ``FromClause``. Normally the column will match if it merely @@ -2216,73 +2199,55 @@ class BooleanClauseList(ClauseList, ColumnElement): return (self, ) -class _CalculatedClause(ColumnElement): - """Describe a calculated SQL expression that has a type, like ``CASE``. - - Extends ``ColumnElement`` to provide column-level comparison - operators. - - """ +class _Case(ColumnElement): + __visit_name__ = 'case' - __visit_name__ = 'calculatedclause' + def __init__(self, whens, value=None, else_=None): + try: + whens = util.dictlike_iteritems(whens) + except TypeError: + pass - def __init__(self, name, *clauses, **kwargs): - self.name = name - self.type = sqltypes.to_instance(kwargs.get('type_', None)) - self._bind = kwargs.get('bind', None) - self.group = kwargs.pop('group', True) - clauses = ClauseList( - operator=kwargs.get('operator', None), - group_contents=kwargs.get('group_contents', True), - *clauses) - if self.group: - self.clause_expr = clauses.self_group() + if value: + whenlist = [(_literal_as_binds(c).self_group(), _literal_as_binds(r)) for (c, r) in whens] else: - self.clause_expr = clauses - - @property - def key(self): - return self.name or '_calc_' + whenlist = [(_no_literals(c).self_group(), _literal_as_binds(r)) for (c, r) in whens] + + if whenlist: + type_ = list(whenlist[-1])[-1].type + else: + type_ = None + + self.value = value + self.type = type_ + self.whens = whenlist + if else_ is not None: + self.else_ = _literal_as_binds(else_) + else: + self.else_ = None def _copy_internals(self, clone=_clone): - self.clause_expr = clone(self.clause_expr) - - @property - def clauses(self): - if isinstance(self.clause_expr, _Grouping): - return self.clause_expr.element - else: - return self.clause_expr + if self.value: + self.value = clone(self.value) + self.whens = [(clone(x), clone(y)) for x, y in self.whens] + if self.else_: + self.else_ = clone(self.else_) def get_children(self, **kwargs): - return self.clause_expr, + if self.value: + yield self.value + for x, y in self.whens: + yield x + yield y + if self.else_: + yield self.else_ @property def _from_objects(self): - return self.clauses._from_objects - - def _bind_param(self, obj): - return _BindParamClause(self.name, obj, type_=self.type, unique=True) + return itertools.chain(*[x._from_objects for x in self.get_children()]) - def select(self): - return select([self]) - - def scalar(self): - return select([self]).execute().scalar() - - def execute(self): - return select([self]).execute() - - def _compare_type(self, obj): - return self.type - -class Function(_CalculatedClause, FromClause): - """Describe a SQL function. - - Extends ``_CalculatedClause``, turn the *clauselist* into function - arguments, also adds a `packagenames` argument. - - """ +class Function(ColumnElement, FromClause): + """Describe a SQL function.""" __visit_name__ = 'function' @@ -2302,12 +2267,36 @@ class Function(_CalculatedClause, FromClause): def columns(self): return [self] + @util.memoized_property + def clauses(self): + return self.clause_expr.element + + @property + def _from_objects(self): + return self.clauses._from_objects + + def get_children(self, **kwargs): + return self.clause_expr, + def _copy_internals(self, clone=_clone): - _CalculatedClause._copy_internals(self, clone=clone) + self.clause_expr = clone(self.clause_expr) self._reset_exported() + util.reset_memoized(self, 'clauses') + + def _bind_param(self, obj): + return _BindParamClause(self.name, obj, type_=self.type, unique=True) - def get_children(self, **kwargs): - return _CalculatedClause.get_children(self, **kwargs) + def select(self): + return select([self]) + + def scalar(self): + return select([self]).execute().scalar() + + def execute(self): + return select([self]).execute() + + def _compare_type(self, obj): + return self.type class _Cast(ColumnElement): @@ -2493,8 +2482,8 @@ class Join(FromClause): __visit_name__ = 'join' def __init__(self, left, right, onclause=None, isouter=False): - self.left = _selectable(left) - self.right = _selectable(right).self_group() + self.left = _literal_as_text(left) + self.right = _literal_as_text(right).self_group() if onclause is None: self.onclause = self._match_primaries(self.left, self.right) @@ -2843,9 +2832,12 @@ class ColumnClause(_Immutable, ColumnElement): elif self.table and self.table.named_with_column: if getattr(self.table, 'schema', None): - label = self.table.schema + "_" + _escape_for_generated(self.table.name) + "_" + _escape_for_generated(self.name) + label = self.table.schema + "_" + \ + _escape_for_generated(self.table.name) + "_" + \ + _escape_for_generated(self.name) else: - label = _escape_for_generated(self.table.name) + "_" + _escape_for_generated(self.name) + label = _escape_for_generated(self.table.name) + "_" + \ + _escape_for_generated(self.name) if label in self.table.c: # TODO: coverage does not seem to be present for this @@ -3133,6 +3125,8 @@ class CompoundSelect(_SelectBaseMixin, FromClause): # some DBs do not like ORDER BY in the inner queries of a UNION, etc. for n, s in enumerate(selects): + s = _clause_element_as_expr(s) + if not numcols: numcols = len(s.c) elif len(s.c) != numcols: @@ -3398,9 +3392,7 @@ class Select(_SelectBaseMixin, FromClause): """return a new select() construct with the given FROM expression applied to its list of FROM objects.""" - if _is_literal(fromclause): - fromclause = _TextClause(fromclause) - + fromclause = _literal_as_text(fromclause) self._froms = self._froms.union([fromclause]) @_generative diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 1bcc6d864f..c6cb938d44 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -13,18 +13,13 @@ class _GenericMeta(VisitableType): class GenericFunction(Function): __metaclass__ = _GenericMeta - def __init__(self, type_=None, group=True, args=(), **kwargs): + def __init__(self, type_=None, args=(), **kwargs): self.packagenames = [] self.name = self.__class__.__name__ self._bind = kwargs.get('bind', None) - if group: - self.clause_expr = ClauseList( + self.clause_expr = ClauseList( operator=operators.comma_op, group_contents=True, *args).self_group() - else: - self.clause_expr = ClauseList( - operator=operators.comma_op, - group_contents=True, *args) self.type = sqltypes.to_instance( type_ or getattr(self, '__return_type__', None)) diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index cd1e48cafe..879f0f3e51 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -122,7 +122,7 @@ _PRECEDENCE = { and_: 3, or_: 2, comma_op: -1, - collate: -2, + collate: 7, as_: -1, exists: 0, _smallest: -1000, diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index e1619cbc07..aeafb76475 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -1383,10 +1383,7 @@ class memoized_instancemethod(object): return oneshot def reset_memoized(instance, name): - try: - del instance.__dict__[name] - except KeyError: - pass + instance.__dict__.pop(name, None) class WeakIdentityMapping(weakref.WeakKeyDictionary): """A WeakKeyDictionary with an object identity index. diff --git a/test/dialect/mssql.py b/test/dialect/mssql.py index ace572fd54..3ce8a9220b 100755 --- a/test/dialect/mssql.py +++ b/test/dialect/mssql.py @@ -126,6 +126,61 @@ class CompileTest(TestBase, AssertsCompiledSQL): self.assert_compile(func.current_date(), "GETDATE()") self.assert_compile(func.length(3), "LEN(:length_1)") + +class IdentityInsertTest(TestBase, AssertsCompiledSQL): + __only_on__ = 'mssql' + __dialect__ = mssql.MSSQLDialect() + + def setUpAll(self): + global metadata, cattable + metadata = MetaData(testing.db) + + cattable = Table('cattable', metadata, + Column('id', Integer), + Column('description', String(50)), + PrimaryKeyConstraint('id', name='PK_cattable'), + ) + + def setUp(self): + metadata.create_all() + + def tearDown(self): + metadata.drop_all() + + def test_compiled(self): + self.assert_compile(cattable.insert().values(id=9, description='Python'), "INSERT INTO cattable (id, description) VALUES (:id, :description)") + + def test_execute(self): + cattable.insert().values(id=9, description='Python').execute() + + cats = cattable.select().order_by(cattable.c.id).execute() + self.assertEqual([(9, 'Python')], list(cats)) + + result = cattable.insert().values(description='PHP').execute() + self.assertEqual([10], result.last_inserted_ids()) + lastcat = cattable.select().order_by(desc(cattable.c.id)).execute() + self.assertEqual((10, 'PHP'), lastcat.fetchone()) + + def test_executemany(self): + cattable.insert().execute([ + {'id': 89, 'description': 'Python'}, + {'id': 8, 'description': 'Ruby'}, + {'id': 3, 'description': 'Perl'}, + {'id': 1, 'description': 'Java'}, + ]) + + cats = cattable.select().order_by(cattable.c.id).execute() + self.assertEqual([(1, 'Java'), (3, 'Perl'), (8, 'Ruby'), (89, 'Python')], list(cats)) + + cattable.insert().execute([ + {'description': 'PHP'}, + {'description': 'Smalltalk'}, + ]) + + lastcats = cattable.select().order_by(desc(cattable.c.id)).limit(2).execute() + self.assertEqual([(91, 'Smalltalk'), (90, 'PHP')], list(lastcats)) + + class ReflectionTest(TestBase): __only_on__ = 'mssql' diff --git a/test/dialect/postgres.py b/test/dialect/postgres.py index e62ef93eb3..dfe2dfd182 100644 --- a/test/dialect/postgres.py +++ b/test/dialect/postgres.py @@ -639,15 +639,23 @@ class MiscTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL): m1 = MetaData(testing.db) t1 = Table('party', m1, Column('id', String(10), nullable=False), - Column('name', String(20), index=True) + Column('name', String(20), index=True), + Column('aname', String(20)) ) m1.create_all() + testing.db.execute(""" create index idx1 on party ((id || name)) """, None) testing.db.execute(""" create unique index idx2 on party (id) where name = 'test' """, None) + + testing.db.execute(""" + create index idx3 on party using btree + (lower(name::text), lower(aname::text)) + """) + try: m2 = MetaData(testing.db) @@ -663,6 +671,7 @@ class MiscTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL): # Make sure indexes are in the order we expect them in tmp = [(idx.name, idx) for idx in t2.indexes] tmp.sort() + r1, r2 = [idx[1] for idx in tmp] assert r1.name == 'idx2' diff --git a/test/orm/cascade.py b/test/orm/cascade.py index 3345a5d8cf..746dc0e52f 100644 --- a/test/orm/cascade.py +++ b/test/orm/cascade.py @@ -1176,5 +1176,116 @@ class CollectionAssignmentOrphanTest(_base.MappedTest): eq_(sess.query(A).get(a1.id), A(name='a1', bs=[B(name='b1'), B(name='b2'), B(name='b3')])) + +class PartialFlushTest(_base.MappedTest): + """test cascade behavior as it relates to object lists passed to flush(). + + """ + def define_tables(self, metadata): + Table("base", metadata, + Column("id", Integer, primary_key=True), + Column("descr", String(50)) + ) + + Table("noninh_child", metadata, + Column('id', Integer, primary_key=True), + Column('base_id', Integer, ForeignKey('base.id')) + ) + + Table("parent", metadata, + Column("id", Integer, ForeignKey("base.id"), primary_key=True) + ) + Table("inh_child", metadata, + Column("id", Integer, ForeignKey("base.id"), primary_key=True), + Column("parent_id", Integer, ForeignKey("parent.id")) + ) + + + @testing.resolve_artifact_names + def test_o2m_m2o(self): + class Base(_base.ComparableEntity): + pass + class Child(_base.ComparableEntity): + pass + + mapper(Base, base, properties={ + 'children':relation(Child, backref='parent') + }) + mapper(Child, noninh_child) + + sess = create_session() + + c1, c2 = Child(), Child() + b1 = Base(descr='b1', children=[c1, c2]) + sess.add(b1) + + assert c1 in sess.new + assert c2 in sess.new + sess.flush([b1]) + + # c1, c2 get cascaded into the session on o2m. + # not sure if this is how I like this + # to work but that's how it works for now. + assert c1 in sess and c1 not in sess.new + assert c2 in sess and c2 not in sess.new + assert b1 in sess and b1 not in sess.new + + sess = create_session() + c1, c2 = Child(), Child() + b1 = Base(descr='b1', children=[c1, c2]) + sess.add(b1) + sess.flush([c1]) + # m2o, otoh, doesn't cascade up the other way. + assert c1 in sess and c1 not in sess.new + assert c2 in sess and c2 in sess.new + assert b1 in sess and b1 in sess.new + + sess = create_session() + c1, c2 = Child(), Child() + b1 = Base(descr='b1', children=[c1, c2]) + sess.add(b1) + sess.flush([c1, c2]) + # m2o, otoh, doesn't cascade up the other way. + assert c1 in sess and c1 not in sess.new + assert c2 in sess and c2 not in sess.new + assert b1 in sess and b1 in sess.new + + @testing.resolve_artifact_names + def test_circular_sort(self): + """test ticket 1306""" + + class Base(_base.ComparableEntity): + pass + class Parent(Base): + pass + class Child(Base): + pass + + mapper(Base,base) + + mapper(Child, inh_child, + inherits=Base, + properties={'parent': relation( + Parent, + backref='children', + primaryjoin=inh_child.c.parent_id == parent.c.id + )} + ) + + + mapper(Parent,parent, inherits=Base) + + sess = create_session() + p1 = Parent() + + c1, c2, c3 = Child(), Child(), Child() + p1.children = [c1, c2, c3] + sess.add(p1) + + sess.flush([c1]) + assert p1 in sess.new + assert c1 not in sess.new + assert c2 in sess.new + if __name__ == "__main__": testenv.main() diff --git a/test/orm/expire.py b/test/orm/expire.py index 4e8771347e..c11fb69dfe 100644 --- a/test/orm/expire.py +++ b/test/orm/expire.py @@ -747,6 +747,51 @@ class PolymorphicExpireTest(_base.MappedTest): self.assert_sql_count(testing.db, go, 2) self.assertEquals(Engineer.name.get_history(e1), (['new engineer name'],(), ['engineer1'])) +class ExpiredPendingTest(_fixtures.FixtureTest): + run_define_tables = 'once' + run_setup_classes = 'once' + run_setup_mappers = None + run_inserts = None + + @testing.resolve_artifact_names + def test_expired_pending(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user'), + }) + mapper(Address, addresses) + + sess = create_session() + a1 = Address(email_address='a1') + sess.add(a1) + sess.flush() + + u1 = User(name='u1') + a1.user = u1 + sess.flush() + + # expire 'addresses'. backrefs + # which attach to u1 will expect to be "pending" + sess.expire(u1, ['addresses']) + + # attach an Address. now its "pending" + # in user.addresses + a2 = Address(email_address='a2') + a2.user = u1 + + # expire u1.addresses again. this expires + # "pending" as well. + sess.expire(u1, ['addresses']) + + # insert a new row + sess.execute(addresses.insert(), dict(email_address='a3', user_id=u1.id)) + + # only two addresses pulled from the DB, no "pending" + assert len(u1.addresses) == 2 + + sess.flush() + sess.expire_all() + assert len(u1.addresses) == 3 + class RefreshTest(_fixtures.FixtureTest): @@ -783,9 +828,6 @@ class RefreshTest(_fixtures.FixtureTest): s.expire(u) # get the attribute, it refreshes - print "OK------" -# print u.__dict__ -# print u._state.callables assert u.name == 'jack' assert id(a) not in [id(x) for x in u.addresses] diff --git a/test/orm/inheritance/query.py b/test/orm/inheritance/query.py index fe948931b6..ca789f8338 100644 --- a/test/orm/inheritance/query.py +++ b/test/orm/inheritance/query.py @@ -9,6 +9,8 @@ from sqlalchemy.orm import * from sqlalchemy import exc as sa_exc from testlib import * from testlib import fixtures +from orm import _base +from testlib.testing import eq_ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.engine import default @@ -748,7 +750,7 @@ class SelfReferentialTestJoinedToBase(ORMTest): sess.query(Engineer).join('reports_to', aliased=True).filter(Person.name=='dogbert').first(), Engineer(name='dilbert')) -class SelfReferentialTestJoinedToJoined(ORMTest): +class SelfReferentialJ2JTest(ORMTest): keep_mappers = True def define_tables(self, metadata): @@ -773,7 +775,7 @@ class SelfReferentialTestJoinedToJoined(ORMTest): mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer', properties={ - 'reports_to':relation(Manager, primaryjoin=managers.c.person_id==engineers.c.reports_to_id) + 'reports_to':relation(Manager, primaryjoin=managers.c.person_id==engineers.c.reports_to_id, backref='engineers') }) def test_has(self): @@ -800,6 +802,62 @@ class SelfReferentialTestJoinedToJoined(ORMTest): self.assertEquals( sess.query(Engineer).join('reports_to', aliased=True).filter(Manager.name=='dogbert').first(), Engineer(name='dilbert')) + + def test_filter_aliasing(self): + m1 = Manager(name='dogbert') + m2 = Manager(name='foo') + e1 = Engineer(name='wally', primary_language='java', reports_to=m1) + e2 = Engineer(name='dilbert', primary_language='c++', reports_to=m2) + e3 = Engineer(name='etc', primary_language='c++') + sess = create_session() + sess.add_all([m1, m2, e1, e2, e3]) + sess.flush() + sess.expunge_all() + + # filter aliasing applied to Engineer doesn't whack Manager + self.assertEquals( + sess.query(Manager).join(Manager.engineers).filter(Manager.name=='dogbert').all(), + [m1] + ) + + self.assertEquals( + sess.query(Manager).join(Manager.engineers).filter(Engineer.name=='dilbert').all(), + [m2] + ) + + self.assertEquals( + sess.query(Manager, Engineer).join(Manager.engineers).order_by(Manager.name.desc()).all(), + [ + (m2, e2), + (m1, e1), + ] + ) + + def test_relation_compare(self): + m1 = Manager(name='dogbert') + m2 = Manager(name='foo') + e1 = Engineer(name='dilbert', primary_language='java', reports_to=m1) + e2 = Engineer(name='wally', primary_language='c++', reports_to=m2) + e3 = Engineer(name='etc', primary_language='c++') + sess = create_session() + sess.add(m1) + sess.add(m2) + sess.add(e1) + sess.add(e2) + sess.add(e3) + sess.flush() + sess.expunge_all() + + self.assertEquals( + sess.query(Manager).join(Manager.engineers).filter(Engineer.reports_to==None).all(), + [] + ) + + self.assertEquals( + sess.query(Manager).join(Manager.engineers).filter(Engineer.reports_to==m1).all(), + [m1] + ) + class M2MFilterTest(ORMTest): @@ -868,6 +926,8 @@ class M2MFilterTest(ORMTest): self.assertEquals(sess.query(Organization).filter(Organization.engineers.any(Engineer.name=='e1')).all(), [Organization(name='org1')]) class SelfReferentialM2MTest(ORMTest, AssertsCompiledSQL): + keep_mappers = True + def define_tables(self, metadata): Base = declarative_base(metadata=metadata) @@ -895,9 +955,50 @@ class SelfReferentialM2MTest(ORMTest, AssertsCompiledSQL): Child1.left_child2 = relation(Child2, secondary = secondary_table, primaryjoin = Parent.id == secondary_table.c.right_id, secondaryjoin = Parent.id == secondary_table.c.left_id, - uselist = False, + uselist = False, backref="right_children" ) + + def test_query_crit(self): + session = create_session() + c11, c12, c13 = Child1(), Child1(), Child1() + c21, c22, c23 = Child2(), Child2(), Child2() + + c11.left_child2 = c22 + c12.left_child2 = c22 + c13.left_child2 = c23 + + session.add_all([c11, c12, c13, c21, c22, c23]) + session.flush() + + # test that the join to Child2 doesn't alias Child1 in the select + eq_( + set(session.query(Child1).join(Child1.left_child2)), + set([c11, c12, c13]) + ) + + eq_( + set(session.query(Child1, Child2).join(Child1.left_child2)), + set([(c11, c22), (c12, c22), (c13, c23)]) + ) + + # test __eq__() on property is annotating correctly + eq_( + set(session.query(Child2).join(Child2.right_children).filter(Child1.left_child2==c22)), + set([c22]) + ) + + # test the same again + self.assert_compile( + session.query(Child2).join(Child2.right_children).filter(Child1.left_child2==c22).with_labels().statement, + "SELECT parent.id AS parent_id, child2.id AS child2_id, parent.cls AS parent_cls FROM " + "secondary AS secondary_1, parent JOIN child2 ON parent.id = child2.id JOIN secondary AS secondary_2 " + "ON parent.id = secondary_2.left_id JOIN (SELECT parent.id AS parent_id, parent.cls AS parent_cls, " + "child1.id AS child1_id FROM parent JOIN child1 ON parent.id = child1.id) AS anon_1 ON " + "anon_1.parent_id = secondary_2.right_id WHERE anon_1.parent_id = secondary_1.right_id AND :param_1 = secondary_1.left_id", + dialect=default.DefaultDialect() + ) + def test_eager_join(self): session = create_session() diff --git a/test/orm/query.py b/test/orm/query.py index 6f0b69f17a..a27dcfadc7 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -516,15 +516,57 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): self.assert_compile(sess.query(x).filter(x==5).statement, "SELECT lala(users.id) AS foo FROM users WHERE lala(users.id) = :param_1", dialect=default.DefaultDialect()) -class CompileTest(QueryTest): +class ExpressionTest(QueryTest, AssertsCompiledSQL): - def test_deferred(self): + def test_deferred_instances(self): session = create_session() s = session.query(User).filter(and_(addresses.c.email_address == bindparam('emailad'), Address.user_id==User.id)).statement l = list(session.query(User).instances(s.execute(emailad = 'jack@bean.com'))) - assert [User(id=7)] == l + eq_([User(id=7)], l) + + def test_in(self): + session = create_session() + s = session.query(User.id).join(User.addresses).group_by(User.id).having(func.count(Address.id) > 2) + eq_( + session.query(User).filter(User.id.in_(s)).all(), + [User(id=8)] + ) + + def test_union(self): + s = create_session() + + q1 = s.query(User).filter(User.name=='ed').with_labels() + q2 = s.query(User).filter(User.name=='fred').with_labels() + eq_( + s.query(User).from_statement(union(q1, q2).order_by('users_name')).all(), + [User(name='ed'), User(name='fred')] + ) + + def test_select(self): + s = create_session() + + # this is actually not legal on most DBs since the subquery has no alias + q1 = s.query(User).filter(User.name=='ed') + self.assert_compile( + select([q1]), + "SELECT id, name FROM (SELECT users.id AS id, users.name AS name FROM users WHERE users.name = :name_1)", + dialect=default.DefaultDialect() + ) + + def test_join(self): + s = create_session() + + # TODO: do we want aliased() to detect a query and convert to subquery() + # automatically ? + q1 = s.query(Address).filter(Address.email_address=='jack@bean.com') + adalias = aliased(Address, q1.subquery()) + eq_( + s.query(User, adalias).join((adalias, User.id==adalias.user_id)).all(), + [(User(id=7,name=u'jack'), Address(email_address=u'jack@bean.com',user_id=7,id=1))] + ) + # more slice tests are available in test/orm/generative.py class SliceTest(QueryTest): def test_first(self): diff --git a/test/orm/relationships.py b/test/orm/relationships.py index 532203ce20..9787216f73 100644 --- a/test/orm/relationships.py +++ b/test/orm/relationships.py @@ -1,7 +1,7 @@ import testenv; testenv.configure_for_tests() import datetime from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData +from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData, and_ from testlib.sa.orm import mapper, relation, backref, create_session, compile_mappers, clear_mappers from testlib.testing import eq_, startswith_ from orm import _base, _fixtures @@ -650,6 +650,79 @@ class RelationTest6(_base.MappedTest): [TagInstance(data='iplc_case'), TagInstance(data='not_iplc_case')] ) +class AmbiguousJoinInterpretedAsSelfRef(_base.MappedTest): + """test ambiguous joins due to FKs on both sides treated as self-referential. + + this mapping is very similar to that of test/orm/inheritance/query.py + SelfReferentialTestJoinedToBase , except that inheritance is not used + here. + + """ + + def define_tables(self, metadata): + subscriber_table = Table('subscriber', metadata, + Column('id', Integer, primary_key=True), + Column('dummy', String(10)) # to appease older sqlite version + ) + + address_table = Table('address', + metadata, + Column('subscriber_id', Integer, ForeignKey('subscriber.id'), primary_key=True), + Column('type', String(1), primary_key=True), + ) + + @testing.resolve_artifact_names + def setup_mappers(self): + subscriber_and_address = subscriber.join(address, + and_(address.c.subscriber_id==subscriber.c.id, address.c.type.in_(['A', 'B', 'C']))) + + class Address(_base.ComparableEntity): + pass + + class Subscriber(_base.ComparableEntity): + pass + + mapper(Address, address) + + mapper(Subscriber, subscriber_and_address, properties={ + 'id':[subscriber.c.id, address.c.subscriber_id], + 'addresses' : relation(Address, + backref=backref("customer")) + }) + + @testing.resolve_artifact_names + def test_mapping(self): + from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE + sess = create_session() + assert Subscriber.addresses.property.direction is ONETOMANY + assert Address.customer.property.direction is MANYTOONE + + s1 = Subscriber(type='A', + addresses = [ + Address(type='D'), + Address(type='E'), + ] + ) + a1 = Address(type='B', customer=Subscriber(type='C')) + + assert s1.addresses[0].customer is s1 + assert a1.customer.addresses[0] is a1 + + sess.add_all([s1, a1]) + + sess.flush() + sess.expunge_all() + + eq_( + sess.query(Subscriber).order_by(Subscriber.type).all(), + [ + Subscriber(id=1, type=u'A'), + Subscriber(id=2, type=u'B'), + Subscriber(id=2, type=u'C') + ] + ) + + class ManualBackrefTest(_fixtures.FixtureTest): """Test explicit relations that are backrefs to each other.""" diff --git a/test/profiling/memusage.py b/test/profiling/memusage.py index 3cb4dfb9fa..4f65f1d33c 100644 --- a/test/profiling/memusage.py +++ b/test/profiling/memusage.py @@ -6,6 +6,8 @@ from sqlalchemy.orm.session import _sessions import operator from testlib import testing from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey, PickleType +import sqlalchemy as sa +from sqlalchemy.sql import column from orm import _base import sqlalchemy as sa from sqlalchemy.sql import column diff --git a/test/sql/functions.py b/test/sql/functions.py index ac9b7e3292..1519575036 100644 --- a/test/sql/functions.py +++ b/test/sql/functions.py @@ -37,7 +37,11 @@ class CompileTest(TestBase, AssertsCompiledSQL): GenericFunction.__init__(self, args=[arg], **kwargs) self.assert_compile(fake_func('foo'), "fake_func(%s)" % bindtemplate % {'name':'param_1', 'position':1}, dialect=dialect) - + + def test_use_labels(self): + self.assert_compile(select([func.foo()], use_labels=True), + "SELECT foo() AS foo_1" + ) def test_underscores(self): self.assert_compile(func.if_(), "if()") diff --git a/test/sql/generative.py b/test/sql/generative.py index 2072fb75e8..3947a450fe 100644 --- a/test/sql/generative.py +++ b/test/sql/generative.py @@ -447,6 +447,13 @@ class ClauseAdapterTest(TestBase, AssertsCompiledSQL): self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1)), "SELECT * FROM table2 WHERE t1alias.col1 = table2.col2") self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2)), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = table2.col2") + self.assert_compile(vis.traverse(case([(t1.c.col1==5, t1.c.col2)], else_=t1.c.col1)), + "CASE WHEN (t1alias.col1 = :col1_1) THEN t1alias.col2 ELSE t1alias.col1 END" + ) + self.assert_compile(vis.traverse(case([(5, t1.c.col2)], value=t1.c.col1, else_=t1.c.col1)), + "CASE t1alias.col1 WHEN :param_1 THEN t1alias.col2 ELSE t1alias.col1 END" + ) + s = select(['*'], from_obj=[t1]).alias('foo') self.assert_compile(s.select(), "SELECT foo.* FROM (SELECT * FROM table1) AS foo") diff --git a/test/sql/labels.py b/test/sql/labels.py index 5a620be8c8..94ee20342e 100644 --- a/test/sql/labels.py +++ b/test/sql/labels.py @@ -4,9 +4,6 @@ from sqlalchemy import exc as exceptions from testlib import * from sqlalchemy.engine import default -# TODO: either create a mock dialect with named paramstyle and a short identifier length, -# or find a way to just use sqlite dialect and make those changes - IDENT_LENGTH = 29 class LabelTypeTest(TestBase): @@ -20,13 +17,18 @@ class LabelTypeTest(TestBase): class LongLabelsTest(TestBase, AssertsCompiledSQL): def setUpAll(self): - global metadata, table1, maxlen + global metadata, table1, table2, maxlen metadata = MetaData(testing.db) table1 = Table("some_large_named_table", metadata, Column("this_is_the_primarykey_column", Integer, Sequence("this_is_some_large_seq"), primary_key=True), Column("this_is_the_data_column", String(30)) ) + table2 = Table("table_with_exactly_29_characs", metadata, + Column("this_is_the_primarykey_column", Integer, Sequence("some_seq"), primary_key=True), + Column("this_is_the_data_column", String(30)) + ) + metadata.create_all() maxlen = testing.db.dialect.max_identifier_length @@ -87,6 +89,37 @@ class LongLabelsTest(TestBase, AssertsCompiledSQL): (3, "data3"), ], repr(result) + def test_table_alias_names(self): + self.assert_compile( + table2.alias().select(), + "SELECT table_with_exactly_29_c_1.this_is_the_primarykey_column, table_with_exactly_29_c_1.this_is_the_data_column FROM table_with_exactly_29_characs AS table_with_exactly_29_c_1" + ) + + ta = table2.alias() + dialect = default.DefaultDialect() + dialect.max_identifier_length = IDENT_LENGTH + self.assert_compile( + select([table1, ta]).select_from(table1.join(ta, table1.c.this_is_the_data_column==ta.c.this_is_the_data_column)).\ + where(ta.c.this_is_the_data_column=='data3'), + + "SELECT some_large_named_table.this_is_the_primarykey_column, some_large_named_table.this_is_the_data_column, " + "table_with_exactly_29_c_1.this_is_the_primarykey_column, table_with_exactly_29_c_1.this_is_the_data_column FROM " + "some_large_named_table JOIN table_with_exactly_29_characs AS table_with_exactly_29_c_1 ON " + "some_large_named_table.this_is_the_data_column = table_with_exactly_29_c_1.this_is_the_data_column " + "WHERE table_with_exactly_29_c_1.this_is_the_data_column = :this_is_the_data_column_1", + dialect=dialect + ) + + table2.insert().execute( + {"this_is_the_primarykey_column":1, "this_is_the_data_column":"data1"}, + {"this_is_the_primarykey_column":2, "this_is_the_data_column":"data2"}, + {"this_is_the_primarykey_column":3, "this_is_the_data_column":"data3"}, + {"this_is_the_primarykey_column":4, "this_is_the_data_column":"data4"}, + ) + + r = table2.alias().select().execute() + assert r.fetchall() == [(x, "data%d" % x) for x in range(1, 5)] + def test_colbinds(self): table1.insert().execute(**{"this_is_the_primarykey_column":1, "this_is_the_data_column":"data1"}) table1.insert().execute(**{"this_is_the_primarykey_column":2, "this_is_the_data_column":"data2"}) @@ -153,9 +186,9 @@ class LongLabelsTest(TestBase, AssertsCompiledSQL): "FROM some_large_named_table WHERE some_large_named_table.this_is_the_primarykey_column = :this_1) AS anon_1", dialect=compile_dialect) compile_dialect = default.DefaultDialect(label_length=4) - self.assert_compile(x, "SELECT anon_1.this_is_the_primarykey_column AS _1, anon_1.this_is_the_data_column AS _2 FROM " + self.assert_compile(x, "SELECT _1.this_is_the_primarykey_column AS _1, _1.this_is_the_data_column AS _2 FROM " "(SELECT some_large_named_table.this_is_the_primarykey_column AS _3, some_large_named_table.this_is_the_data_column AS _4 " - "FROM some_large_named_table WHERE some_large_named_table.this_is_the_primarykey_column = :_1) AS anon_1", dialect=compile_dialect) + "FROM some_large_named_table WHERE some_large_named_table.this_is_the_primarykey_column = :_1) AS _1", dialect=compile_dialect) if __name__ == '__main__': diff --git a/test/sql/select.py b/test/sql/select.py index 782016e7d6..a4de6e331e 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -131,6 +131,28 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A select([ClauseList(column('a'), column('b'))]).select_from('sometable'), 'SELECT a, b FROM sometable' ) + + def test_use_labels(self): + self.assert_compile( + select([table1.c.myid==5], use_labels=True), + "SELECT mytable.myid = :myid_1 AS anon_1 FROM mytable" + ) + + self.assert_compile( + select([func.foo()], use_labels=True), + "SELECT foo() AS foo_1" + ) + + self.assert_compile( + select([not_(True)], use_labels=True), + "SELECT NOT :param_1" # TODO: should this make an anon label ?? + ) + + self.assert_compile( + select([cast("data", Integer)], use_labels=True), # this will work with plain Integer in 0.6 + "SELECT CAST(:param_1 AS INTEGER) AS anon_1" + ) + def test_nested_uselabels(self): """test nested anonymous label generation. this essentially tests the ANONYMOUS_LABEL regex. @@ -357,7 +379,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A select([x.label('foo')]), 'SELECT a AND b AND c AS foo' ) - + self.assert_compile( and_(table1.c.myid == 12, table1.c.name=='asdf', table2.c.othername == 'foo', "sysdate() = today()"), "mytable.myid = :myid_1 AND mytable.name = :name_1 "\ @@ -812,20 +834,28 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today self.assert_compile(select([extract("day", func.to_date("03/20/2005", "MM/DD/YYYY"))]), "SELECT extract(day FROM to_date(:to_date_1, :to_date_2)) AS extract_1") def test_collate(self): - for expr in (select([table1.c.name.collate('somecol')]), - select([collate(table1.c.name, 'somecol')])): + for expr in (select([table1.c.name.collate('latin1_german2_ci')]), + select([collate(table1.c.name, 'latin1_german2_ci')])): self.assert_compile( - expr, "SELECT mytable.name COLLATE somecol FROM mytable") + expr, "SELECT mytable.name COLLATE latin1_german2_ci AS anon_1 FROM mytable") - expr = select([table1.c.name.collate('somecol').like('%x%')]) + assert table1.c.name.collate('latin1_german2_ci').type is table1.c.name.type + + expr = select([table1.c.name.collate('latin1_german2_ci').label('k1')]).order_by('k1') + self.assert_compile(expr,"SELECT mytable.name COLLATE latin1_german2_ci AS k1 FROM mytable ORDER BY k1") + + expr = select([collate('foo', 'latin1_german2_ci').label('k1')]) + self.assert_compile(expr,"SELECT :param_1 COLLATE latin1_german2_ci AS k1") + + expr = select([table1.c.name.collate('latin1_german2_ci').like('%x%')]) self.assert_compile(expr, - "SELECT mytable.name COLLATE somecol " + "SELECT mytable.name COLLATE latin1_german2_ci " "LIKE :param_1 AS anon_1 FROM mytable") - expr = select([table1.c.name.like(collate('%x%', 'somecol'))]) + expr = select([table1.c.name.like(collate('%x%', 'latin1_german2_ci'))]) self.assert_compile(expr, "SELECT mytable.name " - "LIKE :param_1 COLLATE somecol AS anon_1 " + "LIKE :param_1 COLLATE latin1_german2_ci AS anon_1 " "FROM mytable") expr = select([table1.c.name.collate('col1').like( @@ -835,10 +865,14 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today "LIKE :param_1 COLLATE col2 AS anon_1 " "FROM mytable") - expr = select([func.concat('a', 'b').collate('somecol').label('x')]) + expr = select([func.concat('a', 'b').collate('latin1_german2_ci').label('x')]) self.assert_compile(expr, "SELECT concat(:param_1, :param_2) " - "COLLATE somecol AS x") + "COLLATE latin1_german2_ci AS x") + + + expr = select([table1.c.name]).order_by(table1.c.name.collate('latin1_german2_ci')) + self.assert_compile(expr, "SELECT mytable.name FROM mytable ORDER BY mytable.name COLLATE latin1_german2_ci") def test_percent_chars(self): t = table("table%name", -- 2.47.3