From faf4aca165cef9bbd8d90b7a4f4ccf2b3d986ea1 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 18 Dec 2007 05:40:06 +0000 Subject: [PATCH] - cleanup; lambdas removed from properties; properties mirror same-named functions (more like eventual decorator syntax); remove some old methods, factor out some "raiseerr" ugliness to outer lying functions. - corresponding_column() integrates "require_embedded" flag with other set arithmetic --- CHANGES | 27 ++-- lib/sqlalchemy/databases/mssql.py | 2 +- lib/sqlalchemy/engine/base.py | 83 ++++++++---- lib/sqlalchemy/engine/threadlocal.py | 18 ++- lib/sqlalchemy/orm/attributes.py | 6 +- lib/sqlalchemy/orm/dependency.py | 10 +- lib/sqlalchemy/orm/dynamic.py | 4 +- lib/sqlalchemy/orm/mapper.py | 57 ++++---- lib/sqlalchemy/orm/properties.py | 43 +++--- lib/sqlalchemy/orm/query.py | 6 +- lib/sqlalchemy/orm/session.py | 47 ++++--- lib/sqlalchemy/orm/unitofwork.py | 32 ++--- lib/sqlalchemy/orm/util.py | 6 +- lib/sqlalchemy/schema.py | 168 +++++++++-------------- lib/sqlalchemy/sql/expression.py | 194 +++++++++++++-------------- lib/sqlalchemy/sql/util.py | 4 +- test/sql/selectable.py | 4 + 17 files changed, 354 insertions(+), 357 deletions(-) diff --git a/CHANGES b/CHANGES index 7f48b7e089..1077b4c30d 100644 --- a/CHANGES +++ b/CHANGES @@ -73,6 +73,18 @@ CHANGES issued directly by the ORM in the form of UPDATE statements, by setting the flag "passive_cascades=False". + - new synonym() behavior: an attribute will be placed on the mapped + class, if one does not exist already, in all cases. if a property + already exists on the class, the synonym will decorate the property + with the appropriate comparison operators so that it can be used in in + column expressions just like any other mapped attribute (i.e. usable in + filter(), etc.) the "proxy=True" flag is deprecated and no longer means + anything. Additionally, the flag "map_column=True" will automatically + generate a ColumnProperty corresponding to the name of the synonym, + i.e.: 'somename':synonym('_somename', map_column=True) will map the + column named 'somename' to the attribute '_somename'. See the example + in the mapper docs. [ticket:801] + - Query.select_from() now replaces all existing FROM criterion with the given argument; the previous behavior of constructing a list of FROM clauses was generally not useful as is required @@ -130,18 +142,6 @@ CHANGES disregarding any existing filter, join, group_by or other criterion which has been configured. [ticket:893] - - new synonym() behavior: an attribute will be placed on the mapped - class, if one does not exist already, in all cases. if a property - already exists on the class, the synonym will decorate the property - with the appropriate comparison operators so that it can be used in in - column expressions just like any other mapped attribute (i.e. usable in - filter(), etc.) the "proxy=True" flag is deprecated and no longer means - anything. Additionally, the flag "map_column=True" will automatically - generate a ColumnProperty corresponding to the name of the synonym, - i.e.: 'somename':synonym('_somename', map_column=True) will map the - column named 'somename' to the attribute '_somename'. See the example - in the mapper docs. [ticket:801] - - added support for version_id_col in conjunction with inheriting mappers. version_id_col is typically set on the base mapper in an inheritance relationship where it takes effect for all inheriting mappers. @@ -159,7 +159,8 @@ CHANGES mapper.get_attr_by_column(), mapper.set_attr_by_column(), mapper.pks_by_table, mapper.cascade_callable(), MapperProperty.cascade_callable(), mapper.canload(), - mapper._mapper_registry, attributes.AttributeManager + mapper.save_obj(), mapper.delete_obj(), mapper._mapper_registry, + attributes.AttributeManager - Assigning an incompatible collection type to a relation attribute now raises TypeError instead of sqlalchemy's ArgumentError. diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 098bd33c89..1654677b74 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -915,7 +915,7 @@ class MSSQLCompiler(compiler.DefaultCompiler): # translate for schema-qualified table aliases t = self._schema_aliased_table(column.table) if t is not None: - return self.process(t.corresponding_column(column)) + return self.process(expression._corresponding_column_or_error(t, column)) return super(MSSQLCompiler, self).visit_column(column, **kwargs) def visit_binary(self, binary, **kwargs): diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 801d4e28c1..3219e6c5b8 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -543,12 +543,6 @@ class Connection(Connectable): self.__savepoint_seq = 0 self.__branch = _branch - def _get_connection(self): - try: - return self.__connection - except AttributeError: - raise exceptions.InvalidRequestError("This Connection is closed") - def _branch(self): """Return a new Connection which references this Connection's engine and connection; but does not have close_with_result enabled, @@ -559,16 +553,35 @@ class Connection(Connectable): """ return Connection(self.engine, self.__connection, _branch=True) - dialect = property(lambda s:s.engine.dialect, doc="Dialect used by this Connection.") - connection = property(_get_connection, doc="The underlying DB-API connection managed by this Connection.") - should_close_with_result = property(lambda s:s.__close_with_result, doc="Indicates if this Connection should be closed when a corresponding ResultProxy is closed; this is essentially an auto-release mode.") + def dialect(self): + "Dialect used by this Connection." + + return self.engine.dialect + dialect = property(dialect) + + def connection(self): + "The underlying DB-API connection managed by this Connection." - info = property(lambda s: s._get_connection().info, - doc=("A collection of per-DB-API connection instance " - "properties.")) - properties = property(lambda s: s._get_connection().info, - doc=("An alias for the .info collection, will be " - "removed in 0.5.")) + try: + return self.__connection + except AttributeError: + raise exceptions.InvalidRequestError("This Connection is closed") + connection = property(connection) + + def should_close_with_result(self): + """Indicates if this Connection should be closed when a corresponding + ResultProxy is closed; this is essentially an auto-release mode. + """ + + return self.__close_with_result + should_close_with_result = property(should_close_with_result) + + def info(self): + """A collection of per-DB-API connection instance properties.""" + return self.connection.info + info = property(info) + + properties = property(info, doc="""An alias for the .info collection, will be removed in 0.5.""") def connect(self): """Returns self. @@ -940,9 +953,15 @@ class Transaction(object): self._connection = connection self._parent = parent or self self._is_active = True + + def connection(self): + "The Connection object referenced by this Transaction" + return self._connection + connection = property(connection) - connection = property(lambda s:s._connection, doc="The Connection object referenced by this Transaction") - is_active = property(lambda s:s._is_active) + def is_active(self): + return self._is_active + is_active = property(is_active) def close(self): """Close this transaction. @@ -1041,7 +1060,12 @@ class Engine(Connectable): self.engine = self self.logger = logging.instance_logger(self, echoflag=echo) - name = property(lambda s:sys.modules[s.dialect.__module__].descriptor()['name'], doc="String name of the [sqlalchemy.engine#Dialect] in use by this ``Engine``.") + def name(self): + "String name of the [sqlalchemy.engine#Dialect] in use by this ``Engine``." + + return sys.modules[self.dialect.__module__].descriptor()['name'] + name = property(name) + echo = logging.echo_property() def __repr__(self): @@ -1068,10 +1092,9 @@ class Engine(Connectable): finally: connection.close() - def _func(self): + def func(self): return expression._FunctionGenerator(bind=self) - - func = property(_func) + func = property(func) def text(self, text, *args, **kwargs): """Return a sql.text() object for performing literal queries.""" @@ -1321,14 +1344,20 @@ class ResultProxy(object): self._rowcount = context.get_rowcount() self.close() - def _get_rowcount(self): + def rowcount(self): if self._rowcount is not None: return self._rowcount else: return self.context.get_rowcount() - rowcount = property(_get_rowcount) - lastrowid = property(lambda s:s.cursor.lastrowid) - out_parameters = property(lambda s:s.context.out_parameters) + rowcount = property(rowcount) + + def lastrowid(self): + return self.cursor.lastrowid + lastrowid = property(lastrowid) + + def out_parameters(self): + return self.context.out_parameters + out_parameters = property(out_parameters) def _init_metadata(self): self.__props = {} @@ -1423,7 +1452,9 @@ class ResultProxy(object): if self.connection.should_close_with_result: self.connection.close() - keys = property(lambda s:s.__keys) + def keys(self): + return self.__keys + keys = property(keys) def _has_key(self, row, key): try: diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py index f2b950f2ec..6122b61b23 100644 --- a/lib/sqlalchemy/engine/threadlocal.py +++ b/lib/sqlalchemy/engine/threadlocal.py @@ -93,7 +93,9 @@ class TLConnection(base.Connection): self.__session = session self.__opencount = 1 - session = property(lambda s:s.__session) + def session(self): + return self.__session + session = property(session) def _increment_connect(self): self.__opencount += 1 @@ -132,8 +134,13 @@ class TLTransaction(base.Transaction): self._trans = trans self._session = session - connection = property(lambda s:s._trans.connection) - is_active = property(lambda s:s._trans.is_active) + def connection(self): + return self._trans.connection + connection = property(connection) + + def is_active(self): + return self._trans.is_active + is_active = property(is_active) def rollback(self): self._session.rollback() @@ -168,12 +175,13 @@ class TLEngine(base.Engine): super(TLEngine, self).__init__(*args, **kwargs) self.context = util.ThreadLocal() - def _session(self): + def session(self): + "Returns the current thread's TLSession" if not hasattr(self.context, 'session'): self.context.session = TLSession(self) return self.context.session - session = property(_session, doc="Returns the current thread's TLSession") + session = property(session) def contextual_connect(self, **kwargs): """Return a TLConnection which is thread-locally scoped.""" diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 6d5fae5077..089522673c 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -276,8 +276,10 @@ class ScalarAttributeImpl(AttributeImpl): state.dict[self.key] = value state.modified=True - - type = property(lambda self: self.property.columns[0].type) + + def type(self): + self.property.columns[0].type + type = property(type) class MutableScalarAttributeImpl(ScalarAttributeImpl): """represents a scalar value-holding InstrumentedAttribute, which can detect diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index 8340ccdcc6..c26e186bdf 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -253,7 +253,7 @@ class OneToManyDP(DependencyProcessor): child = getattr(child, '_state', child) source = state dest = child - if dest is None or (not self.post_update and uowcommit.state_is_deleted(dest)): + if dest is None or (not self.post_update and uowcommit.is_deleted(dest)): return self._verify_canload(child) self.syncrules.execute(source, dest, source, child, clearkeys) @@ -363,7 +363,7 @@ class ManyToOneDP(DependencyProcessor): def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): source = child dest = state - if dest is None or (not self.post_update and uowcommit.state_is_deleted(dest)): + if dest is None or (not self.post_update and uowcommit.is_deleted(dest)): return self._verify_canload(child) self.syncrules.execute(source, dest, dest, child, clearkeys) @@ -491,13 +491,13 @@ class MapperStub(object): def polymorphic_iterator(self): return iter([self]) - def register_dependencies(self, uowcommit): + def _register_dependencies(self, uowcommit): pass - def save_obj(self, *args, **kwargs): + def _save_obj(self, *args, **kwargs): pass - def delete_obj(self, *args, **kwargs): + def _delete_obj(self, *args, **kwargs): pass def primary_mapper(self): diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index ea99d65148..fe781ab05e 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -93,9 +93,9 @@ class AppenderQuery(Query): else: return sess - def _get_session(self): + def session(self): return self.__session() - session = property(_get_session) + session = property(session) def __iter__(self): sess = self.__session() diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 6294338ebb..95d118ee40 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -7,11 +7,10 @@ import weakref, warnings from itertools import chain from sqlalchemy import sql, util, exceptions, logging -from sqlalchemy.sql import expression, visitors, operators -from sqlalchemy.sql import util as sqlutil -from sqlalchemy.orm import util as mapperutil -from sqlalchemy.orm.util import ExtensionCarrier, create_row_adapter +from sqlalchemy.sql import expression, visitors, operators, util as sqlutil +from sqlalchemy.sql.expression import _corresponding_column_or_error from sqlalchemy.orm import sync, attributes +from sqlalchemy.orm.util import ExtensionCarrier, create_row_adapter, state_str, instance_str from sqlalchemy.orm.interfaces import MapperProperty, EXT_CONTINUE, PropComparator __all__ = ['Mapper', 'class_mapper', 'object_mapper', '_mapper_registry'] @@ -337,7 +336,7 @@ class Mapper(object): self.inherits._add_polymorphic_mapping(self.polymorphic_identity, self) if self.polymorphic_on is None: if self.inherits.polymorphic_on is not None: - self.polymorphic_on = self.mapped_table.corresponding_column(self.inherits.polymorphic_on, raiseerr=False) + self.polymorphic_on = self.mapped_table.corresponding_column(self.inherits.polymorphic_on) else: raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity)) @@ -440,10 +439,10 @@ class Mapper(object): primary_key = expression.ColumnSet() for col in (self.primary_key_argument or self._pks_by_table[self.mapped_table]): - c = self.mapped_table.corresponding_column(col, raiseerr=False) + c = self.mapped_table.corresponding_column(col) if c is None: for cc in self._equivalent_columns[col]: - c = self.mapped_table.corresponding_column(cc, raiseerr=False) + c = self.mapped_table.corresponding_column(cc) if c is not None: break else: @@ -462,7 +461,7 @@ class Mapper(object): break for cc in c.foreign_keys: cc = cc.column - c2 = self.mapped_table.corresponding_column(cc, raiseerr=False) + c2 = self.mapped_table.corresponding_column(cc) if c2 is not None: c = c2 tried.add(c) @@ -651,7 +650,7 @@ class Mapper(object): elif prop is None: mapped_column = [] for c in columns: - mc = self.mapped_table.corresponding_column(c, raiseerr=False) + mc = self.mapped_table.corresponding_column(c) if not mc: raise exceptions.ArgumentError("Column '%s' is not represented in mapper's table. Use the `column_property()` function to force this column to be mapped as a read-only attribute." % str(c)) mapped_column.append(mc) @@ -664,7 +663,7 @@ class Mapper(object): if isinstance(prop, ColumnProperty): # relate the mapper's "select table" to the given ColumnProperty - col = self.select_table.corresponding_column(prop.columns[0], raiseerr=False) + col = self.select_table.corresponding_column(prop.columns[0]) # col might not be present! the selectable given to the mapper need not include "deferred" # columns (included in zblog tests) if col is None: @@ -713,10 +712,10 @@ class Mapper(object): if self._init_properties is not None: for key, prop in self._init_properties.iteritems(): if expression.is_column(prop): - props[key] = self.select_table.corresponding_column(prop) + props[key] = _corresponding_column_or_error(self.select_table, prop) elif (isinstance(prop, list) and expression.is_column(prop[0])): - props[key] = [self.select_table.corresponding_column(c) for c in prop] - self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, properties=props, _polymorphic_map=self.polymorphic_map, polymorphic_on=self.select_table.corresponding_column(self.polymorphic_on), primary_key=self.primary_key_argument) + props[key] = [_corresponding_column_or_error(self.select_table, c) for c in prop] + self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, properties=props, _polymorphic_map=self.polymorphic_map, polymorphic_on=_corresponding_column_or_error(self.select_table, self.polymorphic_on), primary_key=self.primary_key_argument) def _compile_class(self): """If this mapper is to be a primary mapper (i.e. the @@ -919,27 +918,27 @@ class Mapper(object): def _set_attr_by_column(self, obj, column, value): self._get_col_to_prop(column).setattr(obj._state, column, value) - def save_obj(self, states, uowtransaction, postupdate=False, post_update_cols=None, single=False): + def _save_obj(self, states, uowtransaction, postupdate=False, post_update_cols=None, single=False): """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects. This is called within the context of a UOWTransaction during a flush operation. - `save_obj` issues SQL statements not just for instances mapped + `_save_obj` issues SQL statements not just for instances mapped directly by this mapper, but for instances mapped by all inheriting mappers as well. This is to maintain proper insert ordering among a polymorphic chain of instances. Therefore - save_obj is typically called only on a *base mapper*, or a + _save_obj is typically called only on a *base mapper*, or a mapper which does not inherit from any other mapper. """ if self.__should_log_debug: - self.__log_debug("save_obj() start, " + (single and "non-batched" or "batched")) + self.__log_debug("_save_obj() start, " + (single and "non-batched" or "batched")) - # if batch=false, call save_obj separately for each object + # if batch=false, call _save_obj separately for each object if not single and not self.batch: for state in states: - self.save_obj([state], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True) + self._save_obj([state], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True) return # if session has a connection callable, @@ -970,11 +969,11 @@ class Mapper(object): mapper = _state_mapper(state) instance_key = mapper._identity_key_from_state(state) if not postupdate and not has_identity and instance_key in uowtransaction.uow.identity_map: - existing = uowtransaction.uow.identity_map[instance_key] + existing = uowtransaction.uow.identity_map[instance_key]._state if not uowtransaction.is_deleted(existing): - raise exceptions.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (mapperutil.state_str(state), str(instance_key), mapperutil.instance_str(existing))) + raise exceptions.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (state_str(state), str(instance_key), state_str(existing))) if self.__should_log_debug: - self.__log_debug("detected row switch for identity %s. will update %s, remove %s from transaction" % (instance_key, mapperutil.state_str(state), mapperutil.instance_str(existing))) + self.__log_debug("detected row switch for identity %s. will update %s, remove %s from transaction" % (instance_key, state_str(state), state_str(existing))) uowtransaction.set_row_switch(existing) inserted_objects = util.Set() @@ -997,7 +996,7 @@ class Mapper(object): instance_key = mapper._identity_key_from_state(state) if self.__should_log_debug: - self.__log_debug("save_obj() table '%s' instance %s identity %s" % (table.name, mapperutil.state_str(state), str(instance_key))) + self.__log_debug("_save_obj() table '%s' instance %s identity %s" % (table.name, state_str(state), str(instance_key))) isinsert = not instance_key in uowtransaction.uow.identity_map and not postupdate and not has_identity params = {} @@ -1153,7 +1152,7 @@ class Mapper(object): if deferred_props: _expire_state(state, deferred_props) - def delete_obj(self, states, uowtransaction): + def _delete_obj(self, states, uowtransaction): """Issue ``DELETE`` statements for a list of objects. This is called within the context of a UOWTransaction during a @@ -1161,7 +1160,7 @@ class Mapper(object): """ if self.__should_log_debug: - self.__log_debug("delete_obj() start") + self.__log_debug("_delete_obj() start") if 'connection_callable' in uowtransaction.mapper_flush_opts: connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] @@ -1223,7 +1222,7 @@ class Mapper(object): if 'after_delete' in mapper.extension.methods: mapper.extension.after_delete(mapper, connection, state.obj()) - def register_dependencies(self, uowcommit): + def _register_dependencies(self, uowcommit): """Register ``DependencyProcessor`` instances with a ``unitofwork.UOWTransaction``. @@ -1303,7 +1302,7 @@ class Mapper(object): state = instance._state if self.__should_log_debug: - self.__log_debug("_instance(): using existing instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey))) + self.__log_debug("_instance(): using existing instance %s identity %s" % (instance_str(instance), str(identitykey))) isnew = state.runid != context.runid currentload = not isnew @@ -1337,7 +1336,7 @@ class Mapper(object): instance = attributes.new_instance(self.class_) if self.__should_log_debug: - self.__log_debug("_instance(): created new instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey))) + self.__log_debug("_instance(): created new instance %s identity %s" % (instance_str(instance), str(identitykey))) state = instance._state instance._entity_name = self.entity_name @@ -1460,7 +1459,7 @@ class Mapper(object): statement = sql.select(needs_tables, cond, use_labels=True) def post_execute(instance, **flags): if self.__should_log_debug: - self.__log_debug("Post query loading instance " + mapperutil.instance_str(instance)) + self.__log_debug("Post query loading instance " + instance_str(instance)) identitykey = self.identity_key_from_instance(instance) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 027cefd692..441a1d7cd6 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -12,10 +12,11 @@ to handle flush-time dependency sorting and processing. """ from sqlalchemy import sql, schema, util, exceptions, logging -from sqlalchemy.sql import util as sql_util, visitors, operators, ColumnElement +from sqlalchemy.sql.util import ClauseAdapter, ColumnsInClause +from sqlalchemy.sql import visitors, operators, ColumnElement from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency, object_mapper from sqlalchemy.orm import session as sessionlib -from sqlalchemy.orm import util as mapperutil +from sqlalchemy.orm.util import CascadeOptions from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty from sqlalchemy.exceptions import ArgumentError import warnings @@ -201,13 +202,13 @@ class PropertyLoader(StrategizedProperty): self.strategy_class = strategy_class if cascade is not None: - self.cascade = mapperutil.CascadeOptions(cascade) + self.cascade = CascadeOptions(cascade) else: if private: util.warn_deprecated('private option is deprecated; see docs for details') - self.cascade = mapperutil.CascadeOptions("all, delete-orphan") + self.cascade = CascadeOptions("all, delete-orphan") else: - self.cascade = mapperutil.CascadeOptions("save-update, merge") + self.cascade = CascadeOptions("save-update, merge") if self.passive_deletes == 'all' and ("delete" in self.cascade or "delete-orphan" in self.cascade): raise exceptions.ArgumentError("Can't set passive_deletes='all' in conjunction with 'delete' or 'delete-orphan' cascade") @@ -312,8 +313,10 @@ class PropertyLoader(StrategizedProperty): def _optimized_compare(self, value, value_is_parent=False): return self._get_strategy(strategies.LazyLoader).lazy_clause(value, reverse_direction=not value_is_parent) - - private = property(lambda s:s.cascade.delete_orphan) + + def private(self): + return self.cascade.delete_orphan + private = property(private) def create_strategy(self): if self.strategy_class: @@ -456,7 +459,7 @@ class PropertyLoader(StrategizedProperty): # to the "polymorphic" selectable as needed). since this is an API change, put an explicit check/ # error message in case its the "old" way. if self.loads_polymorphic: - vis = sql_util.ColumnsInClause(self.mapper.select_table) + vis = ColumnsInClause(self.mapper.select_table) vis.traverse(self.primaryjoin) if self.secondaryjoin: vis.traverse(self.secondaryjoin) @@ -469,12 +472,12 @@ class PropertyLoader(StrategizedProperty): def col_is_part_of_mappings(col): if self.secondary is None: - return self.parent.mapped_table.corresponding_column(col, raiseerr=False) is not None or \ - self.target.corresponding_column(col, raiseerr=False) is not None + return self.parent.mapped_table.corresponding_column(col) is not None or \ + self.target.corresponding_column(col) is not None else: - return self.parent.mapped_table.corresponding_column(col, raiseerr=False) is not None or \ - self.target.corresponding_column(col, raiseerr=False) is not None or \ - self.secondary.corresponding_column(col, raiseerr=False) is not None + return self.parent.mapped_table.corresponding_column(col) is not None or \ + self.target.corresponding_column(col) is not None or \ + self.secondary.corresponding_column(col) is not None if self.foreign_keys: self._opposite_side = util.Set() @@ -597,13 +600,13 @@ class PropertyLoader(StrategizedProperty): target_equivalents = self.mapper._get_equivalent_columns() if self.secondaryjoin: - self.polymorphic_secondaryjoin = sql_util.ClauseAdapter(self.mapper.select_table).traverse(self.secondaryjoin, clone=True) + self.polymorphic_secondaryjoin = ClauseAdapter(self.mapper.select_table).traverse(self.secondaryjoin, clone=True) self.polymorphic_primaryjoin = self.primaryjoin else: if self.direction is sync.ONETOMANY: - self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True) + self.polymorphic_primaryjoin = ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True) elif self.direction is sync.MANYTOONE: - self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True) + self.polymorphic_primaryjoin = ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True) self.polymorphic_secondaryjoin = None # load "polymorphic" versions of the columns present in "remote_side" - this is @@ -612,7 +615,7 @@ class PropertyLoader(StrategizedProperty): if self.secondary and self.secondary.columns.contains_column(c): continue for equiv in [c] + (c in target_equivalents and list(target_equivalents[c]) or []): - corr = self.mapper.select_table.corresponding_column(equiv, raiseerr=False) + corr = self.mapper.select_table.corresponding_column(equiv) if corr: self.remote_side.add(corr) break @@ -686,11 +689,11 @@ class PropertyLoader(StrategizedProperty): if polymorphic_parent: # adapt the "parent" side of our join condition to the "polymorphic" select of the parent if self.direction is sync.ONETOMANY: - primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) + primaryjoin = ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) elif self.direction is sync.MANYTOONE: - primaryjoin = sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) + primaryjoin = ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) elif self.secondaryjoin: - primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) + primaryjoin = ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) if secondaryjoin is not None: if secondary and not primary: diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 902a4fd3be..2c9a1d0ff5 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -983,8 +983,8 @@ class Query(object): cf.update(sql_util.find_columns(o)) if adapt_criterion: - context.primary_columns = [from_obj.corresponding_column(c, raiseerr=False) or c for c in context.primary_columns] - cf = [from_obj.corresponding_column(c, raiseerr=False) or c for c in cf] + context.primary_columns = [from_obj.corresponding_column(c) or c for c in context.primary_columns] + cf = [from_obj.corresponding_column(c) or c for c in cf] s2 = sql.select(context.primary_columns + list(cf), whereclause, from_obj=context.from_clause, use_labels=True, correlate=False, order_by=util.to_list(order_by), **self._select_args()) @@ -1004,7 +1004,7 @@ class Query(object): statement.append_order_by(*context.eager_order_by) else: if adapt_criterion: - context.primary_columns = [from_obj.corresponding_column(c, raiseerr=False) or c for c in context.primary_columns] + context.primary_columns = [from_obj.corresponding_column(c) or c for c in context.primary_columns] self._primary_adapter = mapperutil.create_row_adapter(from_obj, self.table) if adapt_criterion or self._distinct: diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index d6d1d1ff6b..541590b82d 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -663,14 +663,14 @@ class Session(object): q = q.add_entity(ent) return q - def _sql(self): + def sql(self): class SQLProxy(object): def __getattr__(self, key): def call(*args, **kwargs): kwargs[engine] = self.engine return getattr(sql, key)(*args, **kwargs) - sql = property(_sql) + sql = property(sql) def _autoflush(self): if self.autoflush and (self.transaction is None or self.transaction.autoflush): @@ -1079,26 +1079,35 @@ class Session(object): return True return False - dirty = property(lambda s:s.uow.locate_dirty(), - doc="""A ``Set`` of all instances marked as 'dirty' within this ``Session``. + def dirty(self): + """Return a ``Set`` of all instances marked as 'dirty' within this ``Session``. - Note that the 'dirty' state here is 'optimistic'; most attribute-setting or collection - modification operations will mark an instance as 'dirty' and place it in this set, - even if there is no net change to the attribute's value. At flush time, the value - of each attribute is compared to its previously saved value, - and if there's no net change, no SQL operation will occur (this is a more expensive - operation so it's only done at flush time). + Note that the 'dirty' state here is 'optimistic'; most attribute-setting or collection + modification operations will mark an instance as 'dirty' and place it in this set, + even if there is no net change to the attribute's value. At flush time, the value + of each attribute is compared to its previously saved value, + and if there's no net change, no SQL operation will occur (this is a more expensive + operation so it's only done at flush time). - To check if an instance has actionable net changes to its attributes, use the - is_modified() method. - """) - - deleted = property(lambda s:util.IdentitySet(s.uow.deleted.values()), - doc="A ``Set`` of all instances marked as 'deleted' within this ``Session``") - - new = property(lambda s:util.IdentitySet(s.uow.new.values()), - doc="A ``Set`` of all instances marked as 'new' within this ``Session``.") + To check if an instance has actionable net changes to its attributes, use the + is_modified() method. + """ + return self.uow.locate_dirty() + dirty = property(dirty) + + def deleted(self): + "Return a ``Set`` of all instances marked as 'deleted' within this ``Session``" + + return util.IdentitySet(self.uow.deleted.values()) + deleted = property(deleted) + + def new(self): + "Return a ``Set`` of all instances marked as 'new' within this ``Session``." + + return util.IdentitySet(self.uow.new.values()) + new = property(new) + def _expire_state(state, attribute_names): """Standalone expire instance function. diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 26ac3703eb..59d784ecf2 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -164,11 +164,6 @@ class UnitOfWork(object): def flush(self, session, objects=None): """create a dependency tree of all pending SQL operations within this unit of work and execute.""" - - # this context will track all the objects we want to save/update/delete, - # and organize a hierarchical dependency structure. it also handles - # communication with the mappers and relationships to fire off SQL - # and synchronize attributes between related objects. dirty = [x for x in self.identity_map.all_states() if x.modified @@ -325,35 +320,30 @@ class UOWTransaction(object): else: task.append(state, listonly, isdelete=isdelete, **kwargs) - def set_row_switch(self, obj): + def set_row_switch(self, state): """mark a deleted object as a 'row switch'. this indicates that an INSERT statement elsewhere corresponds to this DELETE; the INSERT is converted to an UPDATE and the DELETE does not occur. """ - mapper = object_mapper(obj) + mapper = _state_mapper(state) task = self.get_task_by_mapper(mapper) - taskelement = task._objects[obj._state] + taskelement = task._objects[state] taskelement.isdelete = "rowswitch" def unregister_object(self, obj): """remove an object from its parent UOWTask. - called by mapper.save_obj() when an 'identity switch' is detected, so that + called by mapper._save_obj() when an 'identity switch' is detected, so that no further operations occur upon the instance.""" mapper = object_mapper(obj) task = self.get_task_by_mapper(mapper) if obj._state in task._objects: task.delete(obj._state) - def is_deleted(self, obj): - """return true if the given object is marked as deleted within this UOWTransaction.""" + def is_deleted(self, state): + """return true if the given state is marked as deleted within this UOWTransaction.""" - mapper = object_mapper(obj) - task = self.get_task_by_mapper(mapper) - return task.is_deleted(obj._state) - - def state_is_deleted(self, state): mapper = _state_mapper(state) task = self.get_task_by_mapper(mapper) return task.is_deleted(state) @@ -375,11 +365,11 @@ class UOWTransaction(object): base_task = self.tasks[base_mapper] else: self.tasks[base_mapper] = base_task = UOWTask(self, base_mapper) - base_mapper.register_dependencies(self) + base_mapper._register_dependencies(self) if mapper not in self.tasks: self.tasks[mapper] = task = UOWTask(self, mapper, base_task=base_task) - mapper.register_dependencies(self) + mapper._register_dependencies(self) else: task = self.tasks[mapper] @@ -581,7 +571,7 @@ class UOWTask(object): # postupdates are UPDATED immeditely (for now) # convert post_update_cols list to a Set so that __hashcode__ is used to compare columns # instead of __eq__ - self.mapper.save_obj([state], self.uowtransaction, postupdate=True, post_update_cols=util.Set(post_update_cols)) + self.mapper._save_obj([state], self.uowtransaction, postupdate=True, post_update_cols=util.Set(post_update_cols)) def delete(self, obj): """remove the given object from this UOWTask, if present.""" @@ -940,10 +930,10 @@ class UOWExecutor(object): self.execute_delete_steps(trans, task) def save_objects(self, trans, task): - task.mapper.save_obj(task.polymorphic_tosave_objects, trans) + task.mapper._save_obj(task.polymorphic_tosave_objects, trans) def delete_objects(self, trans, task): - task.mapper.delete_obj(task.polymorphic_todelete_objects, trans) + task.mapper._delete_obj(task.polymorphic_todelete_objects, trans) def execute_dependency(self, trans, dep, isdelete): dep.execute(trans, isdelete) diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 6e31b46468..7b76183be0 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -154,7 +154,7 @@ class AliasedClauses(object): """return the aliased version of the given column, creating a new label for it if not already present in this AliasedClauses.""" - conv = self.alias.corresponding_column(column, raiseerr=False) + conv = self.alias.corresponding_column(column) if conv: return conv @@ -199,13 +199,13 @@ def create_row_adapter(from_, to, equivalent_columns=None): map = {} for c in to.c: - corr = from_.corresponding_column(c, raiseerr=False) + corr = from_.corresponding_column(c) if corr: map[c] = corr elif equivalent_columns: if c in equivalent_columns: for c2 in equivalent_columns[c]: - corr = from_.corresponding_column(c2, raiseerr=False) + corr = from_.corresponding_column(c2) if corr: map[c] = corr break diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 15b35b96a3..e0d45870bb 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -20,8 +20,6 @@ objects as well as the visitor interface, so that the schema package import re, inspect from sqlalchemy import types, exceptions, util, databases from sqlalchemy.sql import expression, visitors -import sqlalchemy - URL = None @@ -43,9 +41,6 @@ class SchemaItem(object): if item is not None: item._set_parent(self) - def _get_parent(self): - raise NotImplementedError() - def _set_parent(self, parent): """Associate with this SchemaItem's parent object.""" @@ -58,20 +53,12 @@ class SchemaItem(object): def __repr__(self): return "%s()" % self.__class__.__name__ - def _get_bind(self, raiseerr=False): - """Return the engine or None if no engine.""" + def bind(self): + """Return the connectable associated with this SchemaItem.""" - if raiseerr: - m = self.metadata - e = m and m.bind or None - if e is None: - raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine or Connection.") - else: - return e - else: - m = self.metadata - return m and m.bind or None - bind = property(lambda s:s._get_bind()) + m = self.metadata + return m and m.bind or None + bind = property(bind) def info(self): try: @@ -231,7 +218,7 @@ class Table(SchemaItem, expression.TableClause): if autoload_with: autoload_with.reflecttable(self, include_columns=include_columns) else: - metadata._get_bind(raiseerr=True).reflecttable(self, include_columns=include_columns) + _bind_or_error(metadata).reflecttable(self, include_columns=include_columns) # initialize all the column, etc. objects. done after # reflection to allow user-overrides @@ -269,9 +256,6 @@ class Table(SchemaItem, expression.TableClause): constraint._set_parent(self) - def _get_parent(self): - return self.metadata - def _set_parent(self, metadata): metadata.tables[_get_table_key(self.name, self.schema)] = self self.metadata = metadata @@ -289,7 +273,7 @@ class Table(SchemaItem, expression.TableClause): """Return True if this table exists.""" if bind is None: - bind = self._get_bind(raiseerr=True) + bind = _bind_or_error(self) def do(conn): return conn.dialect.has_table(conn, self.name, schema=self.schema) @@ -463,9 +447,10 @@ class Column(SchemaItem, expression._ColumnClause): else: return self.description - def _get_bind(self): + def bind(self): return self.table.bind - + bind = property(bind) + def references(self, column): """return true if this column references the given column via foreign key""" for fk in self.foreign_keys: @@ -496,9 +481,6 @@ class Column(SchemaItem, expression._ColumnClause): [(self.table and "table=<%s>" % self.table.description or "")] + ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg]) - def _get_parent(self): - return self.table - def _set_parent(self, table): self.metadata = table.metadata if getattr(self, 'table', None) is not None: @@ -622,15 +604,15 @@ class ForeignKey(SchemaItem): def references(self, table): """Return True if the given table is referenced by this ``ForeignKey``.""" - return table.corresponding_column(self.column, False) is not None + return table.corresponding_column(self.column) is not None def get_referent(self, table): """return the column in the given table referenced by this ``ForeignKey``, or None if this ``ForeignKey`` does not reference the given table. """ - return table.corresponding_column(self.column, False) + return table.corresponding_column(self.column) - def _init_column(self): + def column(self): # ForeignKey inits its remote column as late as possible, so tables can # be defined without dependencies if self._column is None: @@ -674,10 +656,7 @@ class ForeignKey(SchemaItem): self.parent.type = self._column.type return self._column - column = property(lambda s: s._init_column()) - - def _get_parent(self): - return self.parent + column = property(column) def _set_parent(self, column): self.parent = column @@ -704,9 +683,6 @@ class DefaultGenerator(SchemaItem): self.for_update = for_update self.metadata = util.assert_arg_type(metadata, (MetaData, type(None)), 'metadata') - def _get_parent(self): - return getattr(self, 'column', None) - def _set_parent(self, column): self.column = column self.metadata = self.column.table.metadata @@ -717,7 +693,7 @@ class DefaultGenerator(SchemaItem): def execute(self, bind=None, **kwargs): if bind is None: - bind = self._get_bind(raiseerr=True) + bind = _bind_or_error(self) return bind._execute_default(self, **kwargs) def __repr__(self): @@ -798,14 +774,14 @@ class Sequence(DefaultGenerator): """Creates this sequence in the database.""" if bind is None: - bind = self._get_bind(raiseerr=True) + bind = _bind_or_error(self) bind.create(self, checkfirst=checkfirst) def drop(self, bind=None, checkfirst=True): """Drops this sequence from the database.""" if bind is None: - bind = self._get_bind(raiseerr=True) + bind = _bind_or_error(self) bind.drop(self, checkfirst=checkfirst) @@ -838,20 +814,17 @@ class Constraint(SchemaItem): def copy(self): raise NotImplementedError() - def _get_parent(self): - return getattr(self, 'table', None) - class CheckConstraint(Constraint): def __init__(self, sqltext, name=None): super(CheckConstraint, self).__init__(name) self.sqltext = sqltext - def _visit_name(self): + def __visit_name__(self): if isinstance(self.parent, Table): return "check_constraint" else: return "column_check_constraint" - __visit_name__ = property(_visit_name) + __visit_name__ = property(__visit_name__) def _set_parent(self, parent): self.parent = parent @@ -976,9 +949,6 @@ class Index(SchemaItem): for column in args: self.append_column(column) - def _get_parent(self): - return self.table - def _set_parent(self, table): self.table = table self.metadata = table.metadata @@ -1002,17 +972,15 @@ class Index(SchemaItem): self.columns.append(column) def create(self, bind=None): - if bind is not None: - bind.create(self) - else: - self._get_bind(raiseerr=True).create(self) + if bind is None: + bind = _bind_or_error(self) + bind.create(self) return self def drop(self, bind=None): - if bind is not None: - bind.drop(self) - else: - self._get_bind(raiseerr=True).drop(self) + if bind is None: + bind = _bind_or_error(self) + bind.drop(self) def __str__(self): return repr(self) @@ -1113,6 +1081,17 @@ class MetaData(SchemaItem): self._bind = bind connect = util.deprecated(connect) + def bind(self): + """An Engine or Connection to which this MetaData is bound. + + This property may be assigned an ``Engine`` or + ``Connection``, or assigned a string or URL to + automatically create a basic ``Engine`` for this bind + with ``create_engine()``. + """ + + return self._bind + def _bind_to(self, bind): """Bind this MetaData to an Engine, Connection, string or URL.""" @@ -1121,17 +1100,11 @@ class MetaData(SchemaItem): from sqlalchemy.engine.url import URL if isinstance(bind, (basestring, URL)): - self._bind = sqlalchemy.create_engine(bind) + from sqlalchemy import create_engine + self._bind = create_engine(bind) else: self._bind = bind - - bind = property(lambda self: self._bind, _bind_to, doc= - """An Engine or Connection to which this MetaData is bound. - - This property may be assigned an ``Engine`` or - ``Connection``, or assigned a string or URL to - automatically create a basic ``Engine`` for this bind - with ``create_engine()``.""") + bind = property(bind, _bind_to) def clear(self): self.tables.clear() @@ -1141,15 +1114,12 @@ class MetaData(SchemaItem): del self.tables[table.key] def table_iterator(self, reverse=True, tables=None): - from sqlalchemy.sql import util as sql_util + from sqlalchemy.sql.util import sort_tables if tables is None: tables = self.tables.values() else: tables = util.Set(tables).intersection(self.tables.values()) - return iter(sql_util.sort_tables(tables, reverse=reverse)) - - def _get_parent(self): - return None + return iter(sort_tables(tables, reverse=reverse)) def reflect(self, bind=None, schema=None, only=None): """Load all available table definitions from the database. @@ -1184,7 +1154,7 @@ class MetaData(SchemaItem): reflect_opts = {'autoload': True} if bind is None: - bind = self._get_bind(raiseerr=True) + bind = _bind_or_error(self) conn = None else: reflect_opts['autoload_with'] = bind @@ -1230,7 +1200,7 @@ class MetaData(SchemaItem): """ if bind is None: - bind = self._get_bind(raiseerr=True) + bind = _bind_or_error(self) bind.create(self, checkfirst=checkfirst, tables=tables) def drop_all(self, bind=None, tables=None, checkfirst=True): @@ -1249,17 +1219,9 @@ class MetaData(SchemaItem): """ if bind is None: - bind = self._get_bind(raiseerr=True) + bind = _bind_or_error(self) bind.drop(self, checkfirst=checkfirst, tables=tables) - - def _get_bind(self, raiseerr=False): - if not self.is_bound(): - if raiseerr: - raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine or Connection.") - else: - return None - return self._bind - + class ThreadLocalMetaData(MetaData): """A MetaData variant that presents a different ``bind`` in every thread. @@ -1279,14 +1241,10 @@ class ThreadLocalMetaData(MetaData): __visit_name__ = 'metadata' def __init__(self): - """Construct a ThreadLocalMetaData. - - Takes no arguments. - """ - + """Construct a ThreadLocalMetaData.""" + self.context = util.ThreadLocal() self.__engines = {} - super(ThreadLocalMetaData, self).__init__() # @deprecated @@ -1315,18 +1273,14 @@ class ThreadLocalMetaData(MetaData): self._bind_to(bind) connect = util.deprecated(connect) - def _get_bind(self, raiseerr=False): - """The bound ``Engine`` or ``Connectable`` for this thread.""" + def bind(self): + """The bound Engine or Connection for this thread. + + This property may be assigned an Engine or Connection, + or assigned a string or URL to automatically create a + basic Engine for this bind with ``create_engine()``.""" - if hasattr(self.context, '_engine'): - return self.context._engine - else: - if raiseerr: - raise exceptions.InvalidRequestError( - "This ThreadLocalMetaData is not bound to any Engine or " - "Connection.") - else: - return None + return getattr(self.context, '_engine', None) def _bind_to(self, bind): """Bind to a Connectable in the caller's thread.""" @@ -1349,12 +1303,7 @@ class ThreadLocalMetaData(MetaData): self.__engines[bind] = bind self.context._engine = bind - bind = property(_get_bind, _bind_to, doc= - """The bound Engine or Connection for this thread. - - This property may be assigned an Engine or Connection, - or assigned a string or URL to automatically create a - basic Engine for this bind with ``create_engine()``.""") + bind = property(bind, _bind_to) def is_bound(self): """True if there is a bind for this thread.""" @@ -1368,8 +1317,13 @@ class ThreadLocalMetaData(MetaData): if hasattr(e, 'dispose'): e.dispose() - class SchemaVisitor(visitors.ClauseVisitor): """Define the visiting for ``SchemaItem`` objects.""" __traverse_options__ = {'schema_visitor':True} + +def _bind_or_error(schemaitem): + bind = schemaitem.bind + if not bind: + raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine or Connection.") + return bind \ No newline at end of file diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 55001dc700..a448fa6d3d 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -820,6 +820,12 @@ def _literal_as_binds(element, name=None, type_=None): else: return element +def _corresponding_column_or_error(fromclause, column, require_embedded=False): + c = fromclause.corresponding_column(column, require_embedded=require_embedded) + if not c: + raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description)) + return c + def _selectable(element): if hasattr(element, '__selectable__'): return element.__selectable__() @@ -958,13 +964,8 @@ class ClauseElement(object): return False - def _find_engine(self): - """Default strategy for locating an engine within the clause element. - - Relies upon a local engine property, or looks in the *from* - objects which ultimately have to contain Tables or - TableClauses. - """ + def bind(self): + """Returns the Engine or Connection to which this ClauseElement is bound, or None if none found.""" try: if self._bind is not None: @@ -979,8 +980,7 @@ class ClauseElement(object): return engine else: return None - - bind = property(lambda s:s._find_engine(), doc="""Returns the Engine or Connection to which this ClauseElement is bound, or None if none found.""") + bind = property(bind) def execute(self, *multiparams, **params): """Compile and execute this ``ClauseElement``.""" @@ -1406,7 +1406,6 @@ class ColumnElement(ClauseElement, _CompareMixin): return self._base_columns self._base_columns = util.Set([c for c in self.proxy_set if not hasattr(c, 'proxies')]) return self._base_columns - base_columns = property(base_columns) def proxy_set(self): @@ -1603,7 +1602,7 @@ class FromClause(Selectable): from sqlalchemy.sql import util return util.ClauseAdapter(alias).traverse(self, clone=True) - def corresponding_column(self, column, raiseerr=True, require_embedded=False): + def corresponding_column(self, column, require_embedded=False): """Given a ``ColumnElement``, return the exported ``ColumnElement`` object from this ``Selectable`` which corresponds to that original ``Column`` via a common anscestor column. @@ -1611,10 +1610,6 @@ class FromClause(Selectable): column the target ``ColumnElement`` to be matched - raiseerr - if True, raise an error if the given ``ColumnElement`` could - not be matched. if False, non-matches will return None. - require_embedded only return corresponding columns for the given ``ColumnElement``, if the given ``ColumnElement`` is @@ -1624,12 +1619,6 @@ class FromClause(Selectable): of this ``FromClause``. """ - if require_embedded and column not in self._get_all_embedded_columns(): - if not raiseerr: - return None - else: - raise exceptions.InvalidRequestError("Column instance '%s' is not directly present within selectable '%s'" % (str(column), column.table.description)) - # dont dig around if the column is locally present if self.c.contains_column(column): return column @@ -1638,16 +1627,12 @@ class FromClause(Selectable): target_set = column.proxy_set for c in self.c + [self.oid_column]: i = c.proxy_set.intersection(target_set) - if i and (intersect is None or len(i) > len(intersect)): + if i and \ + (not require_embedded or c.proxy_set.issuperset(target_set)) and \ + (intersect is None or len(i) > len(intersect)): col, intersect = c, i - if col: - return col - - if not raiseerr: - return None - else: - raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), self.description)) - + return col + def description(self): """a brief description of this FromClause. @@ -1666,17 +1651,6 @@ class FromClause(Selectable): if hasattr(self, attr): delattr(self, attr) - def _get_all_embedded_columns(self): - if hasattr(self, '_embedded_columns'): - return self._embedded_columns - ret = util.Set() - class FindCols(visitors.ClauseVisitor): - def visit_column(self, col): - ret.add(col) - FindCols().traverse(self) - self._embedded_columns = ret - return ret - def _expr_attr_func(name): def attr(self): try: @@ -1684,12 +1658,11 @@ class FromClause(Selectable): except AttributeError: self._export_columns() return getattr(self, name) - return attr + return property(attr) - columns = property(_expr_attr_func('_columns')) - c = property(_expr_attr_func('_columns')) - primary_key = property(_expr_attr_func('_primary_key')) - foreign_keys = property(_expr_attr_func('_foreign_keys')) + columns = c = _expr_attr_func('_columns') + primary_key = _expr_attr_func('_primary_key') + foreign_keys = _expr_attr_func('_foreign_keys') def _export_columns(self, columns=None): """Initialize column collections.""" @@ -1881,14 +1854,14 @@ class _TextClause(ClauseElement): for b in bindparams: self.bindparams[b.key] = b - def _get_type(self): + def type(self): if self.typemap is not None and len(self.typemap) == 1: return list(self.typemap)[0] else: return None - type = property(_get_type) + type = property(type) - columns = property(lambda s:[]) + columns = [] def _copy_internals(self, clone=_clone): self.bindparams = dict([(b.key, clone(b)) for b in self.bindparams.values()]) @@ -2329,7 +2302,12 @@ class Join(FromClause): else: return and_(*crit) - def _get_folded_equivalents(self, equivs=None): + def _folded_equivalents(self, equivs=None): + """Returns the column list of this Join with all equivalently-named, + equated columns folded into one column, where 'equated' means they are + equated to each other in the ON clause of this join. + """ + if self.__folded_equivalents is not None: return self.__folded_equivalents if equivs is None: @@ -2342,11 +2320,11 @@ class Join(FromClause): LocateEquivs().traverse(self.onclause) collist = [] if isinstance(self.left, Join): - left = self.left._get_folded_equivalents(equivs) + left = self.left._folded_equivalents(equivs) else: left = list(self.left.columns) if isinstance(self.right, Join): - right = self.right._get_folded_equivalents(equivs) + right = self.right._folded_equivalents(equivs) else: right = list(self.right.columns) used = util.Set() @@ -2359,10 +2337,7 @@ class Join(FromClause): collist.append(c) self.__folded_equivalents = collist return self.__folded_equivalents - - folded_equivalents = property(_get_folded_equivalents, doc="Returns the column list of this Join with all equivalently-named, " - "equated columns folded into one column, where 'equated' means they are " - "equated to each other in the ON clause of this join.") + folded_equivalents = property(_folded_equivalents) def select(self, whereclause = None, fold_equivalents=False, **kwargs): """Create a ``Select`` from this ``Join``. @@ -2391,7 +2366,9 @@ class Join(FromClause): return select(collist, whereclause, from_obj=[self], **kwargs) - bind = property(lambda s:s.left.bind or s.right.bind) + def bind(self): + return self.left.bind or self.right.bind + bind = property(bind) def alias(self, name=None): """Create a ``Select`` out of this ``Join`` clause and return an ``Alias`` of it. @@ -2474,8 +2451,10 @@ class Alias(FromClause): def _get_from_objects(self, **modifiers): return [self] - - bind = property(lambda s: s.selectable.bind) + + def bind(self): + return self.selectable.bind + bind = property(bind) class _ColumnElementAdapter(ColumnElement): """Adapts a ClauseElement which may or may not be a @@ -2486,9 +2465,14 @@ class _ColumnElementAdapter(ColumnElement): def __init__(self, elem): self.elem = elem self.type = getattr(elem, 'type', None) - - key = property(lambda s: s.elem.key) - _label = property(lambda s: s.elem._label) + + def key(self): + return self.elem.key + key = property(key) + + def _label(self): + return self.elem._label + _label = property(_label) def _copy_internals(self, clone=_clone): self.elem = clone(self.elem) @@ -2520,8 +2504,13 @@ class _FromGrouping(FromClause): def __init__(self, elem): self.elem = elem - columns = c = property(lambda s:s.elem.columns) - _hide_froms = property(lambda s:s.elem._hide_froms) + def columns(self): + return self.elem.columns + columns = c = property(columns) + + def _hide_froms(self): + return self.elem._hide_froms + _hide_froms = property(_hide_froms) def get_children(self, **kwargs): return self.elem, @@ -2553,23 +2542,34 @@ class _Label(ColumnElement): self.obj = obj.self_group(against=operators.as_) self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None)) - key = property(lambda s: s.name) - _label = property(lambda s: s.name) - proxies = property(lambda s:s.obj.proxies) - base_columns = property(lambda s:s.obj.base_columns) - proxy_set = property(lambda s:s.obj.proxy_set) - primary_key = property(lambda s:s.obj.primary_key) - foreign_keys = property(lambda s:s.obj.foreign_keys) + def key(self): + return self.name + key = property(key) + + def _label(self): + return self.name + _label = property(_label) + + def _proxy_attr(name): + def attr(self): + return getattr(self.obj, name) + return property(attr) + + proxies = _proxy_attr('proxies') + base_columns = _proxy_attr('base_columns') + proxy_set = _proxy_attr('proxy_set') + primary_key = _proxy_attr('primary_key') + foreign_keys = _proxy_attr('foreign_keys') def expression_element(self): return self.obj - def _copy_internals(self, clone=_clone): - self.obj = clone(self.obj) - def get_children(self, **kwargs): return self.obj, + def _copy_internals(self, clone=_clone): + self.obj = clone(self.obj) + def _get_from_objects(self, **modifiers): return self.obj._get_from_objects(**modifiers) @@ -2623,13 +2623,8 @@ class _ColumnClause(ColumnElement): # ColumnClause is immutable return self - def _get_label(self): - """Generate a 'label' for this column. - - The label is a product of the parent table name and column - name, and is treated as a unique identifier of this ``Column`` - across all ``Tables`` and derived selectables for a particular - metadata collection. + def _label(self): + """Generate a 'label' string for this column. """ # for a "literal" column, we've no idea what the text is @@ -2647,7 +2642,7 @@ class _ColumnClause(ColumnElement): self.__label = self.name return self.__label - _label = property(_get_label) + _label = property(_label) def label(self, name): # if going off the "__label" property and its None, we have @@ -2903,10 +2898,9 @@ class _ScalarSelect(_Grouping): raise exceptions.InvalidRequestError("Scalar select can only be created from a Select object that has exactly one column expression.") self.type = cols[0].type - def _no_cols(self): + def columns(self): raise exceptions.InvalidRequestError("Scalar Select expression has no columns; use this object directly within a column-level expression.") - c = property(_no_cols) - columns = c + columns = c = property(columns) def self_group(self, **kwargs): return self @@ -2979,14 +2973,15 @@ class CompoundSelect(_SelectBaseMixin, FromClause): for t in s._table_iterator(): yield t - def _find_engine(self): + def bind(self): for s in self.selects: - e = s._find_engine() + e = s.bind if e: return e else: return None - + bind = property(bind) + class Select(_SelectBaseMixin, FromClause): """Represents a ``SELECT`` statement. @@ -3115,15 +3110,18 @@ class Select(_SelectBaseMixin, FromClause): self._all_froms = froms return froms - def _get_inner_columns(self): + def inner_columns(self): + """a collection of all ColumnElement expressions which would + be rendered into the columns clause of the resulting SELECT statement. + """ + for c in self._raw_columns: if isinstance(c, Selectable): for co in c.columns: yield co else: yield c - - inner_columns = property(_get_inner_columns, doc="""a collection of all ColumnElement expressions which would be rendered into the columns clause of the resulting SELECT statement.""") + inner_columns = property(inner_columns) def is_derived_from(self, fromclause): if self in util.Set(fromclause._cloned_set): @@ -3412,11 +3410,7 @@ class Select(_SelectBaseMixin, FromClause): if isinstance(t, TableClause): yield t - def _find_engine(self): - """Try to return a Engine, either explicitly set in this - object, or searched within the from clauses for one. - """ - + def bind(self): if self._bind is not None: return self._bind for f in self._froms: @@ -3436,7 +3430,8 @@ class Select(_SelectBaseMixin, FromClause): self._bind = e return e return None - + bind = property(bind) + class _UpdateBase(ClauseElement): """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.""" @@ -3459,9 +3454,10 @@ class _UpdateBase(ClauseElement): else: return parameters - def _find_engine(self): + def bind(self): return self.table.bind - + bind = property(bind) + class Insert(_UpdateBase): def __init__(self, table, values=None, inline=False, **kwargs): self.table = table diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index d6b10a78a3..b45c0425c8 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -198,10 +198,10 @@ class ClauseAdapter(AbstractClauseProcessor): if self.exclude is not None: if col in self.exclude: return None - newcol = self.selectable.corresponding_column(col, raiseerr=False, require_embedded=True) + newcol = self.selectable.corresponding_column(col, require_embedded=True) if newcol is None and self.equivalents is not None and col in self.equivalents: for equiv in self.equivalents[col]: - newcol = self.selectable.corresponding_column(equiv, raiseerr=False, require_embedded=True) + newcol = self.selectable.corresponding_column(equiv, require_embedded=True) if newcol: return newcol return newcol diff --git a/test/sql/selectable.py b/test/sql/selectable.py index 4796288dfa..1b9959ec43 100755 --- a/test/sql/selectable.py +++ b/test/sql/selectable.py @@ -24,6 +24,7 @@ table2 = Table('table2', metadata, class SelectableTest(AssertMixin): def testdistance(self): + # same column three times s = select([table.c.col1.label('c2'), table.c.col1, table.c.col1.label('c1')]) # didnt do this yet...col.label().make_proxy() has same "distance" as col.make_proxy() so far @@ -50,6 +51,9 @@ class SelectableTest(AssertMixin): def testselectontable(self): sel = select([table, table2], use_labels=True) assert sel.corresponding_column(table.c.col1) is sel.c.table1_col1 + assert sel.corresponding_column(table.c.col1, require_embedded=True) is sel.c.table1_col1 + assert table.corresponding_column(sel.c.table1_col1) is table.c.col1 + assert table.corresponding_column(sel.c.table1_col1, require_embedded=True) is None def testjoinagainstjoin(self): j = outerjoin(table, table2, table.c.col1==table2.c.col2) -- 2.47.3