- cleanup; lambdas removed from properties; properties mirror same-named functions...
author    Mike Bayer <mike_mp@zzzcomputing.com>
          Tue, 18 Dec 2007 05:40:06 +0000 (05:40 +0000)
committer Mike Bayer <mike_mp@zzzcomputing.com>
          Tue, 18 Dec 2007 05:40:06 +0000 (05:40 +0000)
- corresponding_column() integrates "require_embedded" flag with other set arithmetic

17 files changed:
CHANGES
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/threadlocal.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/dependency.py
lib/sqlalchemy/orm/dynamic.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/unitofwork.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/util.py
test/sql/selectable.py

diff --git a/CHANGES b/CHANGES
index 7f48b7e0892f7439949e4b387972905f4c614baa..1077b4c30d83d2a70dca5b93bc6657f2e16e1360 100644
--- a/CHANGES
+++ b/CHANGES
@@ -73,6 +73,18 @@ CHANGES
      issued directly by the ORM in the form of UPDATE statements, by setting
      the flag "passive_cascades=False".
 
+   - new synonym() behavior: an attribute will be placed on the mapped
+     class, if one does not exist already, in all cases. if a property
+     already exists on the class, the synonym will decorate the property
+     with the appropriate comparison operators so that it can be used in
+     column expressions just like any other mapped attribute (i.e. usable in
+     filter(), etc.) the "proxy=True" flag is deprecated and no longer means
+     anything. Additionally, the flag "map_column=True" will automatically
+     generate a ColumnProperty corresponding to the name of the synonym,
+     i.e.: 'somename':synonym('_somename', map_column=True) will map the
+     column named 'somename' to the attribute '_somename'. See the example
+     in the mapper docs. [ticket:801]
+
    - Query.select_from() now replaces all existing FROM criterion with
      the given argument; the previous behavior of constructing a list
      of FROM clauses was generally not useful as is required 
@@ -130,18 +142,6 @@ CHANGES
      disregarding any existing filter, join, group_by or other criterion
      which has been configured. [ticket:893]
           
-   - new synonym() behavior: an attribute will be placed on the mapped
-     class, if one does not exist already, in all cases. if a property
-     already exists on the class, the synonym will decorate the property
-     with the appropriate comparison operators so that it can be used in in
-     column expressions just like any other mapped attribute (i.e. usable in
-     filter(), etc.) the "proxy=True" flag is deprecated and no longer means
-     anything. Additionally, the flag "map_column=True" will automatically
-     generate a ColumnProperty corresponding to the name of the synonym,
-     i.e.: 'somename':synonym('_somename', map_column=True) will map the
-     column named 'somename' to the attribute '_somename'. See the example
-     in the mapper docs. [ticket:801]
-
    - added support for version_id_col in conjunction with inheriting mappers.
      version_id_col is typically set on the base mapper in an inheritance
      relationship where it takes effect for all inheriting mappers. 
@@ -159,7 +159,8 @@ CHANGES
      mapper.get_attr_by_column(), mapper.set_attr_by_column(), 
      mapper.pks_by_table, mapper.cascade_callable(), 
      MapperProperty.cascade_callable(), mapper.canload(),
-     mapper._mapper_registry, attributes.AttributeManager
+     mapper.save_obj(), mapper.delete_obj(), mapper._mapper_registry, 
+     attributes.AttributeManager
 
    - Assigning an incompatible collection type to a relation attribute now
      raises TypeError instead of sqlalchemy's ArgumentError.
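
The synonym() entry above lends itself to a short illustration. The following is a hedged sketch, not code from this commit: the users table, User class, and the somename/_somename names are hypothetical and only mirror the wording of the changelog entry, using the 0.4-era mapper() API.

    from sqlalchemy import MetaData, Table, Column, Integer, String
    from sqlalchemy.orm import mapper, synonym

    metadata = MetaData()
    users = Table('users', metadata,
        Column('id', Integer, primary_key=True),
        Column('somename', String(50)))

    class User(object):
        # an existing class-level property; per the entry, synonym() now
        # decorates it with comparison operators so it works in filter()
        def _get_somename(self):
            return self._somename
        def _set_somename(self, value):
            self._somename = value.lower()
        somename = property(_get_somename, _set_somename)

    # map_column=True generates the ColumnProperty for the column named
    # 'somename' under the '_somename' attribute, as described above
    mapper(User, users, properties={
        'somename': synonym('_somename', map_column=True)
    })

If this sketch is right, an expression such as User.somename == 'ed' becomes usable inside Query.filter() just like any other mapped attribute.
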
diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py
index 098bd33c8972a4eb9bbd9510a8d447a2d5c1b667..1654677b746e3bbe094a0edacbf9cc6a37a2e37c 100644
--- a/lib/sqlalchemy/databases/mssql.py
+++ b/lib/sqlalchemy/databases/mssql.py
@@ -915,7 +915,7 @@ class MSSQLCompiler(compiler.DefaultCompiler):
             # translate for schema-qualified table aliases
             t = self._schema_aliased_table(column.table)
             if t is not None:
-                return self.process(t.corresponding_column(column))
+                return self.process(expression._corresponding_column_or_error(t, column))
         return super(MSSQLCompiler, self).visit_column(column, **kwargs)
 
     def visit_binary(self, binary, **kwargs):
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 801d4e28c15b65ad03cdfeeaacbc6bb369d43634..3219e6c5b81706ba16f0c3b2454903babb79e612 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -543,12 +543,6 @@ class Connection(Connectable):
         self.__savepoint_seq = 0
         self.__branch = _branch
 
-    def _get_connection(self):
-        try:
-            return self.__connection
-        except AttributeError:
-            raise exceptions.InvalidRequestError("This Connection is closed")
-
     def _branch(self):
         """Return a new Connection which references this Connection's 
         engine and connection; but does not have close_with_result enabled,
@@ -559,16 +553,35 @@ class Connection(Connectable):
         """
         return Connection(self.engine, self.__connection, _branch=True)
 
-    dialect = property(lambda s:s.engine.dialect, doc="Dialect used by this Connection.")
-    connection = property(_get_connection, doc="The underlying DB-API connection managed by this Connection.")
-    should_close_with_result = property(lambda s:s.__close_with_result, doc="Indicates if this Connection should be closed when a corresponding ResultProxy is closed; this is essentially an auto-release mode.")
+    def dialect(self):
+        "Dialect used by this Connection."
+        
+        return self.engine.dialect
+    dialect = property(dialect)
+    
+    def connection(self):
+        "The underlying DB-API connection managed by this Connection."
 
-    info = property(lambda s: s._get_connection().info,
-                    doc=("A collection of per-DB-API connection instance "
-                         "properties."))
-    properties = property(lambda s: s._get_connection().info,
-                          doc=("An alias for the .info collection, will be "
-                               "removed in 0.5."))
+        try:
+            return self.__connection
+        except AttributeError:
+            raise exceptions.InvalidRequestError("This Connection is closed")
+    connection = property(connection)
+    
+    def should_close_with_result(self):
+        """Indicates if this Connection should be closed when a corresponding
+        ResultProxy is closed; this is essentially an auto-release mode.
+        """
+
+        return self.__close_with_result
+    should_close_with_result = property(should_close_with_result)
+
+    def info(self):
+        """A collection of per-DB-API connection instance properties."""
+        return self.connection.info
+    info = property(info)
+
+    properties = property(info.fget, doc="""An alias for the .info collection, will be removed in 0.5.""")
 
     def connect(self):
         """Returns self.
@@ -940,9 +953,15 @@ class Transaction(object):
         self._connection = connection
         self._parent = parent or self
         self._is_active = True
+    
+    def connection(self):
+        "The Connection object referenced by this Transaction"
+        return self._connection
+    connection = property(connection)
 
-    connection = property(lambda s:s._connection, doc="The Connection object referenced by this Transaction")
-    is_active = property(lambda s:s._is_active)
+    def is_active(self):
+        return self._is_active
+    is_active = property(is_active)
 
     def close(self):
         """Close this transaction.
@@ -1041,7 +1060,12 @@ class Engine(Connectable):
         self.engine = self
         self.logger = logging.instance_logger(self, echoflag=echo)
 
-    name = property(lambda s:sys.modules[s.dialect.__module__].descriptor()['name'], doc="String name of the [sqlalchemy.engine#Dialect] in use by this ``Engine``.")
+    def name(self):
+        "String name of the [sqlalchemy.engine#Dialect] in use by this ``Engine``."
+        
+        return sys.modules[self.dialect.__module__].descriptor()['name']
+    name = property(name)
+    
     echo = logging.echo_property()
     
     def __repr__(self):
@@ -1068,10 +1092,9 @@ class Engine(Connectable):
         finally:
             connection.close()
 
-    def _func(self):
+    def func(self):
         return expression._FunctionGenerator(bind=self)
-
-    func = property(_func)
+    func = property(func)
 
     def text(self, text, *args, **kwargs):
         """Return a sql.text() object for performing literal queries."""
@@ -1321,14 +1344,20 @@ class ResultProxy(object):
             self._rowcount = context.get_rowcount()
             self.close()
 
-    def _get_rowcount(self):
+    def rowcount(self):
         if self._rowcount is not None:
             return self._rowcount
         else:
             return self.context.get_rowcount()
-    rowcount = property(_get_rowcount)
-    lastrowid = property(lambda s:s.cursor.lastrowid)
-    out_parameters = property(lambda s:s.context.out_parameters)
+    rowcount = property(rowcount)
+    
+    def lastrowid(self):
+        return self.cursor.lastrowid
+    lastrowid = property(lastrowid)
+    
+    def out_parameters(self):
+        return self.context.out_parameters
+    out_parameters = property(out_parameters)
 
     def _init_metadata(self):
         self.__props = {}
@@ -1423,7 +1452,9 @@ class ResultProxy(object):
             if self.connection.should_close_with_result:
                 self.connection.close()
 
-    keys = property(lambda s:s.__keys)
+    def keys(self):
+        return self.__keys
+    keys = property(keys)
 
     def _has_key(self, row, key):
         try:
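
The property rewrites in engine/base.py above show the pattern this commit applies across the codebase: lambda-based properties, with their docstrings passed as the doc argument, become same-named functions that are then rebound with property(). A minimal before/after sketch of the idiom on a throwaway class (not taken from SQLAlchemy):

    # before: the getter is an anonymous lambda, the docstring lives in doc=
    class Before(object):
        def __init__(self):
            self._value = 42
        value = property(lambda s: s._value, doc="The stored value.")

    # after: the getter is a named function carrying the docstring, then the
    # name is rebound to property(); introspection and tracebacks now show a
    # real function instead of <lambda>
    class After(object):
        def __init__(self):
            self._value = 42
        def value(self):
            "The stored value."
            return self._value
        value = property(value)

    assert Before().value == After().value == 42
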
diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py
index f2b950f2ec8f384b78d7d393323af698138dfd3a..6122b61b239dd0b2fb8a38cbfe9b0cb7279976ba 100644
--- a/lib/sqlalchemy/engine/threadlocal.py
+++ b/lib/sqlalchemy/engine/threadlocal.py
@@ -93,7 +93,9 @@ class TLConnection(base.Connection):
         self.__session = session
         self.__opencount = 1
 
-    session = property(lambda s:s.__session)
+    def session(self):
+        return self.__session
+    session = property(session)
 
     def _increment_connect(self):
         self.__opencount += 1
@@ -132,8 +134,13 @@ class TLTransaction(base.Transaction):
         self._trans = trans
         self._session = session
 
-    connection = property(lambda s:s._trans.connection)
-    is_active = property(lambda s:s._trans.is_active)
+    def connection(self):
+        return self._trans.connection
+    connection = property(connection)
+    
+    def is_active(self):
+        return self._trans.is_active
+    is_active = property(is_active)
 
     def rollback(self):
         self._session.rollback()
@@ -168,12 +175,13 @@ class TLEngine(base.Engine):
         super(TLEngine, self).__init__(*args, **kwargs)
         self.context = util.ThreadLocal()
 
-    def _session(self):
+    def session(self):
+        "Returns the current thread's TLSession"
         if not hasattr(self.context, 'session'):
             self.context.session = TLSession(self)
         return self.context.session
 
-    session = property(_session, doc="Returns the current thread's TLSession")
+    session = property(session)
 
     def contextual_connect(self, **kwargs):
         """Return a TLConnection which is thread-locally scoped."""
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 6d5fae50771db9800f66ff54f78084bdc7d5209d..089522673c69126fd1bfc2456c50687873ac3c44 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -276,8 +276,10 @@ class ScalarAttributeImpl(AttributeImpl):
 
         state.dict[self.key] = value
         state.modified=True
-
-    type = property(lambda self: self.property.columns[0].type)
+    
+    def type(self):
+        return self.property.columns[0].type
+    type = property(type)
 
 class MutableScalarAttributeImpl(ScalarAttributeImpl):
     """represents a scalar value-holding InstrumentedAttribute, which can detect
diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py
index 8340ccdcc6e0cf154f86a3603174274a25431fac..c26e186bdf524ffa88411457cb0205a98e9f2a43 100644
--- a/lib/sqlalchemy/orm/dependency.py
+++ b/lib/sqlalchemy/orm/dependency.py
@@ -253,7 +253,7 @@ class OneToManyDP(DependencyProcessor):
             child = getattr(child, '_state', child)
         source = state
         dest = child
-        if dest is None or (not self.post_update and uowcommit.state_is_deleted(dest)):
+        if dest is None or (not self.post_update and uowcommit.is_deleted(dest)):
             return
         self._verify_canload(child)
         self.syncrules.execute(source, dest, source, child, clearkeys)
@@ -363,7 +363,7 @@ class ManyToOneDP(DependencyProcessor):
     def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
         source = child
         dest = state
-        if dest is None or (not self.post_update and uowcommit.state_is_deleted(dest)):
+        if dest is None or (not self.post_update and uowcommit.is_deleted(dest)):
             return
         self._verify_canload(child)
         self.syncrules.execute(source, dest, dest, child, clearkeys)
@@ -491,13 +491,13 @@ class MapperStub(object):
     def polymorphic_iterator(self):
         return iter([self])
         
-    def register_dependencies(self, uowcommit):
+    def _register_dependencies(self, uowcommit):
         pass
 
-    def save_obj(self, *args, **kwargs):
+    def _save_obj(self, *args, **kwargs):
         pass
 
-    def delete_obj(self, *args, **kwargs):
+    def _delete_obj(self, *args, **kwargs):
         pass
 
     def primary_mapper(self):
diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py
index ea99d65148e27ebd36fcff657725acca44f5eed7..fe781ab05e4bbb7d7afa44f8ccc81c8e2b26cf63 100644
--- a/lib/sqlalchemy/orm/dynamic.py
+++ b/lib/sqlalchemy/orm/dynamic.py
@@ -93,9 +93,9 @@ class AppenderQuery(Query):
         else:
             return sess
     
-    def _get_session(self):
+    def session(self):
         return self.__session()
-    session = property(_get_session)
+    session = property(session)
     
     def __iter__(self):
         sess = self.__session()
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 6294338ebb89c94a9596ec3a6d8b954eb913c1de..95d118ee406676fd198643e8931067bf8ff9763d 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -7,11 +7,10 @@
 import weakref, warnings
 from itertools import chain
 from sqlalchemy import sql, util, exceptions, logging
-from sqlalchemy.sql import expression, visitors, operators
-from sqlalchemy.sql import util as sqlutil
-from sqlalchemy.orm import util as mapperutil
-from sqlalchemy.orm.util import ExtensionCarrier, create_row_adapter
+from sqlalchemy.sql import expression, visitors, operators, util as sqlutil
+from sqlalchemy.sql.expression import _corresponding_column_or_error
 from sqlalchemy.orm import sync, attributes
+from sqlalchemy.orm.util import ExtensionCarrier, create_row_adapter, state_str, instance_str
 from sqlalchemy.orm.interfaces import MapperProperty, EXT_CONTINUE, PropComparator
 
 __all__ = ['Mapper', 'class_mapper', 'object_mapper', '_mapper_registry']
@@ -337,7 +336,7 @@ class Mapper(object):
                 self.inherits._add_polymorphic_mapping(self.polymorphic_identity, self)
                 if self.polymorphic_on is None:
                     if self.inherits.polymorphic_on is not None:
-                        self.polymorphic_on = self.mapped_table.corresponding_column(self.inherits.polymorphic_on, raiseerr=False)
+                        self.polymorphic_on = self.mapped_table.corresponding_column(self.inherits.polymorphic_on)
                     else:
                         raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity))
 
@@ -440,10 +439,10 @@ class Mapper(object):
             primary_key = expression.ColumnSet()
 
             for col in (self.primary_key_argument or self._pks_by_table[self.mapped_table]):
-                c = self.mapped_table.corresponding_column(col, raiseerr=False)
+                c = self.mapped_table.corresponding_column(col)
                 if c is None:
                     for cc in self._equivalent_columns[col]:
-                        c = self.mapped_table.corresponding_column(cc, raiseerr=False)
+                        c = self.mapped_table.corresponding_column(cc)
                         if c is not None:
                             break
                     else:
@@ -462,7 +461,7 @@ class Mapper(object):
                         break
                     for cc in c.foreign_keys:
                         cc = cc.column
-                        c2 = self.mapped_table.corresponding_column(cc, raiseerr=False)
+                        c2 = self.mapped_table.corresponding_column(cc)
                         if c2 is not None:
                             c = c2
                             tried.add(c)
@@ -651,7 +650,7 @@ class Mapper(object):
             elif prop is None:
                 mapped_column = []
                 for c in columns:
-                    mc = self.mapped_table.corresponding_column(c, raiseerr=False)
+                    mc = self.mapped_table.corresponding_column(c)
                     if not mc:
                         raise exceptions.ArgumentError("Column '%s' is not represented in mapper's table.  Use the `column_property()` function to force this column to be mapped as a read-only attribute." % str(c))
                     mapped_column.append(mc)
@@ -664,7 +663,7 @@ class Mapper(object):
 
         if isinstance(prop, ColumnProperty):
             # relate the mapper's "select table" to the given ColumnProperty
-            col = self.select_table.corresponding_column(prop.columns[0], raiseerr=False)
+            col = self.select_table.corresponding_column(prop.columns[0])
             # col might not be present! the selectable given to the mapper need not include "deferred"
             # columns (included in zblog tests)
             if col is None:
@@ -713,10 +712,10 @@ class Mapper(object):
             if self._init_properties is not None:
                 for key, prop in self._init_properties.iteritems():
                     if expression.is_column(prop):
-                        props[key] = self.select_table.corresponding_column(prop)
+                        props[key] = _corresponding_column_or_error(self.select_table, prop)
                     elif (isinstance(prop, list) and expression.is_column(prop[0])):
-                        props[key] = [self.select_table.corresponding_column(c) for c in prop]
-            self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, properties=props, _polymorphic_map=self.polymorphic_map, polymorphic_on=self.select_table.corresponding_column(self.polymorphic_on), primary_key=self.primary_key_argument)
+                        props[key] = [_corresponding_column_or_error(self.select_table, c) for c in prop]
+            self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, properties=props, _polymorphic_map=self.polymorphic_map, polymorphic_on=_corresponding_column_or_error(self.select_table, self.polymorphic_on), primary_key=self.primary_key_argument)
 
     def _compile_class(self):
         """If this mapper is to be a primary mapper (i.e. the
@@ -919,27 +918,27 @@ class Mapper(object):
     def _set_attr_by_column(self, obj, column, value):
         self._get_col_to_prop(column).setattr(obj._state, column, value)
 
-    def save_obj(self, states, uowtransaction, postupdate=False, post_update_cols=None, single=False):
+    def _save_obj(self, states, uowtransaction, postupdate=False, post_update_cols=None, single=False):
         """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects.
 
         This is called within the context of a UOWTransaction during a
         flush operation.
 
-        `save_obj` issues SQL statements not just for instances mapped
+        `_save_obj` issues SQL statements not just for instances mapped
         directly by this mapper, but for instances mapped by all
         inheriting mappers as well.  This is to maintain proper insert
         ordering among a polymorphic chain of instances. Therefore
-        save_obj is typically called only on a *base mapper*, or a
+        _save_obj is typically called only on a *base mapper*, or a
         mapper which does not inherit from any other mapper.
         """
 
         if self.__should_log_debug:
-            self.__log_debug("save_obj() start, " + (single and "non-batched" or "batched"))
+            self.__log_debug("_save_obj() start, " + (single and "non-batched" or "batched"))
 
-        # if batch=false, call save_obj separately for each object
+        # if batch=false, call _save_obj separately for each object
         if not single and not self.batch:
             for state in states:
-                self.save_obj([state], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True)
+                self._save_obj([state], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True)
             return
 
         # if session has a connection callable, 
@@ -970,11 +969,11 @@ class Mapper(object):
             mapper = _state_mapper(state)
             instance_key = mapper._identity_key_from_state(state)
             if not postupdate and not has_identity and instance_key in uowtransaction.uow.identity_map:
-                existing = uowtransaction.uow.identity_map[instance_key]
+                existing = uowtransaction.uow.identity_map[instance_key]._state
                 if not uowtransaction.is_deleted(existing):
-                    raise exceptions.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (mapperutil.state_str(state), str(instance_key), mapperutil.instance_str(existing)))
+                    raise exceptions.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (state_str(state), str(instance_key), state_str(existing)))
                 if self.__should_log_debug:
-                    self.__log_debug("detected row switch for identity %s.  will update %s, remove %s from transaction" % (instance_key, mapperutil.state_str(state), mapperutil.instance_str(existing)))
+                    self.__log_debug("detected row switch for identity %s.  will update %s, remove %s from transaction" % (instance_key, state_str(state), state_str(existing)))
                 uowtransaction.set_row_switch(existing)
 
         inserted_objects = util.Set()
@@ -997,7 +996,7 @@ class Mapper(object):
                 instance_key = mapper._identity_key_from_state(state)
 
                 if self.__should_log_debug:
-                    self.__log_debug("save_obj() table '%s' instance %s identity %s" % (table.name, mapperutil.state_str(state), str(instance_key)))
+                    self.__log_debug("_save_obj() table '%s' instance %s identity %s" % (table.name, state_str(state), str(instance_key)))
 
                 isinsert = not instance_key in uowtransaction.uow.identity_map and not postupdate and not has_identity
                 params = {}
@@ -1153,7 +1152,7 @@ class Mapper(object):
         if deferred_props:
             _expire_state(state, deferred_props)
 
-    def delete_obj(self, states, uowtransaction):
+    def _delete_obj(self, states, uowtransaction):
         """Issue ``DELETE`` statements for a list of objects.
 
         This is called within the context of a UOWTransaction during a
@@ -1161,7 +1160,7 @@ class Mapper(object):
         """
 
         if self.__should_log_debug:
-            self.__log_debug("delete_obj() start")
+            self.__log_debug("_delete_obj() start")
 
         if 'connection_callable' in uowtransaction.mapper_flush_opts:
             connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
@@ -1223,7 +1222,7 @@ class Mapper(object):
                 if 'after_delete' in mapper.extension.methods:
                     mapper.extension.after_delete(mapper, connection, state.obj())
 
-    def register_dependencies(self, uowcommit):
+    def _register_dependencies(self, uowcommit):
         """Register ``DependencyProcessor`` instances with a
         ``unitofwork.UOWTransaction``.
 
@@ -1303,7 +1302,7 @@ class Mapper(object):
             state = instance._state
 
             if self.__should_log_debug:
-                self.__log_debug("_instance(): using existing instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey)))
+                self.__log_debug("_instance(): using existing instance %s identity %s" % (instance_str(instance), str(identitykey)))
 
             isnew = state.runid != context.runid
             currentload = not isnew
@@ -1337,7 +1336,7 @@ class Mapper(object):
                 instance = attributes.new_instance(self.class_)
                 
             if self.__should_log_debug:
-                self.__log_debug("_instance(): created new instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey)))
+                self.__log_debug("_instance(): created new instance %s identity %s" % (instance_str(instance), str(identitykey)))
             
             state = instance._state    
             instance._entity_name = self.entity_name
@@ -1460,7 +1459,7 @@ class Mapper(object):
         statement = sql.select(needs_tables, cond, use_labels=True)
         def post_execute(instance, **flags):
             if self.__should_log_debug:
-                self.__log_debug("Post query loading instance " + mapperutil.instance_str(instance))
+                self.__log_debug("Post query loading instance " + instance_str(instance))
 
             identitykey = self.identity_key_from_instance(instance)
 
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 027cefd692f20fc89c6ce182c75ef4c7466eba5e..441a1d7cd62343e5f445e0225539560396214da6 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -12,10 +12,11 @@ to handle flush-time dependency sorting and processing.
 """
 
 from sqlalchemy import sql, schema, util, exceptions, logging
-from sqlalchemy.sql import util as sql_util, visitors, operators, ColumnElement
+from sqlalchemy.sql.util import ClauseAdapter, ColumnsInClause
+from sqlalchemy.sql import visitors, operators, ColumnElement
 from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency, object_mapper
 from sqlalchemy.orm import session as sessionlib
-from sqlalchemy.orm import util as mapperutil
+from sqlalchemy.orm.util import CascadeOptions
 from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty
 from sqlalchemy.exceptions import ArgumentError
 import warnings
@@ -201,13 +202,13 @@ class PropertyLoader(StrategizedProperty):
         self.strategy_class = strategy_class
 
         if cascade is not None:
-            self.cascade = mapperutil.CascadeOptions(cascade)
+            self.cascade = CascadeOptions(cascade)
         else:
             if private:
                 util.warn_deprecated('private option is deprecated; see docs for details')
-                self.cascade = mapperutil.CascadeOptions("all, delete-orphan")
+                self.cascade = CascadeOptions("all, delete-orphan")
             else:
-                self.cascade = mapperutil.CascadeOptions("save-update, merge")
+                self.cascade = CascadeOptions("save-update, merge")
         
         if self.passive_deletes == 'all' and ("delete" in self.cascade or "delete-orphan" in self.cascade):
             raise exceptions.ArgumentError("Can't set passive_deletes='all' in conjunction with 'delete' or 'delete-orphan' cascade")
@@ -312,8 +313,10 @@ class PropertyLoader(StrategizedProperty):
     
     def _optimized_compare(self, value, value_is_parent=False):
         return self._get_strategy(strategies.LazyLoader).lazy_clause(value, reverse_direction=not value_is_parent)
-        
-    private = property(lambda s:s.cascade.delete_orphan)
+    
+    def private(self):
+        return self.cascade.delete_orphan
+    private = property(private)
 
     def create_strategy(self):
         if self.strategy_class:
@@ -456,7 +459,7 @@ class PropertyLoader(StrategizedProperty):
         # to the "polymorphic" selectable as needed).  since this is an API change, put an explicit check/
         # error message in case its the "old" way.
         if self.loads_polymorphic:
-            vis = sql_util.ColumnsInClause(self.mapper.select_table)
+            vis = ColumnsInClause(self.mapper.select_table)
             vis.traverse(self.primaryjoin)
             if self.secondaryjoin:
                 vis.traverse(self.secondaryjoin)
@@ -469,12 +472,12 @@ class PropertyLoader(StrategizedProperty):
 
         def col_is_part_of_mappings(col):
             if self.secondary is None:
-                return self.parent.mapped_table.corresponding_column(col, raiseerr=False) is not None or \
-                    self.target.corresponding_column(col, raiseerr=False) is not None
+                return self.parent.mapped_table.corresponding_column(col) is not None or \
+                    self.target.corresponding_column(col) is not None
             else:
-                return self.parent.mapped_table.corresponding_column(col, raiseerr=False) is not None or \
-                    self.target.corresponding_column(col, raiseerr=False) is not None or \
-                    self.secondary.corresponding_column(col, raiseerr=False) is not None
+                return self.parent.mapped_table.corresponding_column(col) is not None or \
+                    self.target.corresponding_column(col) is not None or \
+                    self.secondary.corresponding_column(col) is not None
 
         if self.foreign_keys:
             self._opposite_side = util.Set()
@@ -597,13 +600,13 @@ class PropertyLoader(StrategizedProperty):
             target_equivalents = self.mapper._get_equivalent_columns()
 
             if self.secondaryjoin:
-                self.polymorphic_secondaryjoin = sql_util.ClauseAdapter(self.mapper.select_table).traverse(self.secondaryjoin, clone=True)
+                self.polymorphic_secondaryjoin = ClauseAdapter(self.mapper.select_table).traverse(self.secondaryjoin, clone=True)
                 self.polymorphic_primaryjoin = self.primaryjoin
             else:
                 if self.direction is sync.ONETOMANY:
-                    self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True)
+                    self.polymorphic_primaryjoin = ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True)
                 elif self.direction is sync.MANYTOONE:
-                    self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True)
+                    self.polymorphic_primaryjoin = ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True)
                 self.polymorphic_secondaryjoin = None
 
             # load "polymorphic" versions of the columns present in "remote_side" - this is
@@ -612,7 +615,7 @@ class PropertyLoader(StrategizedProperty):
                 if self.secondary and self.secondary.columns.contains_column(c):
                     continue
                 for equiv in [c] + (c in target_equivalents and list(target_equivalents[c]) or []): 
-                    corr = self.mapper.select_table.corresponding_column(equiv, raiseerr=False)
+                    corr = self.mapper.select_table.corresponding_column(equiv)
                     if corr:
                         self.remote_side.add(corr)
                         break
@@ -686,11 +689,11 @@ class PropertyLoader(StrategizedProperty):
             if polymorphic_parent:
                 # adapt the "parent" side of our join condition to the "polymorphic" select of the parent
                 if self.direction is sync.ONETOMANY:
-                    primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True)
+                    primaryjoin = ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True)
                 elif self.direction is sync.MANYTOONE:
-                    primaryjoin = sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True)
+                    primaryjoin = ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True)
                 elif self.secondaryjoin:
-                    primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True)
+                    primaryjoin = ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True)
 
             if secondaryjoin is not None:
                 if secondary and not primary:
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 902a4fd3be14f5f7b78fc828a94df37a36706e7a..2c9a1d0ff52f2f921d4dd4c454cdfd75194c73bd 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -983,8 +983,8 @@ class Query(object):
                     cf.update(sql_util.find_columns(o))
 
             if adapt_criterion:
-                context.primary_columns = [from_obj.corresponding_column(c, raiseerr=False) or c for c in context.primary_columns]
-                cf = [from_obj.corresponding_column(c, raiseerr=False) or c for c in cf]
+                context.primary_columns = [from_obj.corresponding_column(c) or c for c in context.primary_columns]
+                cf = [from_obj.corresponding_column(c) or c for c in cf]
 
             s2 = sql.select(context.primary_columns + list(cf), whereclause, from_obj=context.from_clause, use_labels=True, correlate=False, order_by=util.to_list(order_by), **self._select_args())
             
@@ -1004,7 +1004,7 @@ class Query(object):
             statement.append_order_by(*context.eager_order_by)
         else:
             if adapt_criterion:
-                context.primary_columns = [from_obj.corresponding_column(c, raiseerr=False) or c for c in context.primary_columns]
+                context.primary_columns = [from_obj.corresponding_column(c) or c for c in context.primary_columns]
                 self._primary_adapter = mapperutil.create_row_adapter(from_obj, self.table)
 
             if adapt_criterion or self._distinct:
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index d6d1d1ff6ba8d1ff45496dc8ddca93d62ec3e6e0..541590b82d81c05db3bc17f682d0e7659429f08e 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -663,14 +663,14 @@ class Session(object):
             q = q.add_entity(ent)
         return q
 
-    def _sql(self):
+    def sql(self):
         class SQLProxy(object):
             def __getattr__(self, key):
                 def call(*args, **kwargs):
                     kwargs[engine] = self.engine
                     return getattr(sql, key)(*args, **kwargs)
 
-    sql = property(_sql)
+    sql = property(sql)
 
     def _autoflush(self):
         if self.autoflush and (self.transaction is None or self.transaction.autoflush):
@@ -1079,26 +1079,35 @@ class Session(object):
                 return True
         return False
 
-    dirty = property(lambda s:s.uow.locate_dirty(),
-                     doc="""A ``Set`` of all instances marked as 'dirty' within this ``Session``.
+    def dirty(self):
+        """Return a ``Set`` of all instances marked as 'dirty' within this ``Session``.
 
-                     Note that the 'dirty' state here is 'optimistic'; most attribute-setting or collection
-                     modification operations will mark an instance as 'dirty' and place it in this set,
-                     even if there is no net change to the attribute's value.  At flush time, the value
-                     of each attribute is compared to its previously saved value,
-                     and if there's no net change, no SQL operation will occur (this is a more expensive
-                     operation so it's only done at flush time).
+        Note that the 'dirty' state here is 'optimistic'; most attribute-setting or collection
+        modification operations will mark an instance as 'dirty' and place it in this set,
+        even if there is no net change to the attribute's value.  At flush time, the value
+        of each attribute is compared to its previously saved value,
+        and if there's no net change, no SQL operation will occur (this is a more expensive
+        operation so it's only done at flush time).
 
-                     To check if an instance has actionable net changes to its attributes, use the
-                     is_modified() method.
-                     """)
-
-    deleted = property(lambda s:util.IdentitySet(s.uow.deleted.values()),
-                       doc="A ``Set`` of all instances marked as 'deleted' within this ``Session``")
-
-    new = property(lambda s:util.IdentitySet(s.uow.new.values()),
-                   doc="A ``Set`` of all instances marked as 'new' within this ``Session``.")
+        To check if an instance has actionable net changes to its attributes, use the
+        is_modified() method.
+        """
 
+        return self.uow.locate_dirty()
+    dirty = property(dirty)
+
+    def deleted(self):
+        "Return a ``Set`` of all instances marked as 'deleted' within this ``Session``"
+        
+        return util.IdentitySet(self.uow.deleted.values())
+    deleted = property(deleted)
+
+    def new(self):
+        "Return a ``Set`` of all instances marked as 'new' within this ``Session``."
+        
+        return util.IdentitySet(self.uow.new.values())
+    new = property(new)
+    
 def _expire_state(state, attribute_names):
     """Standalone expire instance function.
 
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index 26ac3703ebbdec5727565040629cfcb8a756ee78..59d784ecf2951c3c82da7db4949b0caee07a77a6 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -164,11 +164,6 @@ class UnitOfWork(object):
 
     def flush(self, session, objects=None):
         """create a dependency tree of all pending SQL operations within this unit of work and execute."""
-        
-        # this context will track all the objects we want to save/update/delete,
-        # and organize a hierarchical dependency structure.  it also handles
-        # communication with the mappers and relationships to fire off SQL
-        # and synchronize attributes between related objects.
 
         dirty = [x for x in self.identity_map.all_states()
             if x.modified
@@ -325,35 +320,30 @@ class UOWTransaction(object):
         else:
             task.append(state, listonly, isdelete=isdelete, **kwargs)
 
-    def set_row_switch(self, obj):
+    def set_row_switch(self, state):
         """mark a deleted object as a 'row switch'.
         
         this indicates that an INSERT statement elsewhere corresponds to this DELETE;
         the INSERT is converted to an UPDATE and the DELETE does not occur.
         """
-        mapper = object_mapper(obj)
+        mapper = _state_mapper(state)
         task = self.get_task_by_mapper(mapper)
-        taskelement = task._objects[obj._state]
+        taskelement = task._objects[state]
         taskelement.isdelete = "rowswitch"
         
     def unregister_object(self, obj):
         """remove an object from its parent UOWTask.
         
-        called by mapper.save_obj() when an 'identity switch' is detected, so that
+        called by mapper._save_obj() when an 'identity switch' is detected, so that
         no further operations occur upon the instance."""
         mapper = object_mapper(obj)
         task = self.get_task_by_mapper(mapper)
         if obj._state in task._objects:
             task.delete(obj._state)
 
-    def is_deleted(self, obj):
-        """return true if the given object is marked as deleted within this UOWTransaction."""
+    def is_deleted(self, state):
+        """return true if the given state is marked as deleted within this UOWTransaction."""
         
-        mapper = object_mapper(obj)
-        task = self.get_task_by_mapper(mapper)
-        return task.is_deleted(obj._state)
-
-    def state_is_deleted(self, state):
         mapper = _state_mapper(state)
         task = self.get_task_by_mapper(mapper)
         return task.is_deleted(state)
@@ -375,11 +365,11 @@ class UOWTransaction(object):
                 base_task = self.tasks[base_mapper]
             else:
                 self.tasks[base_mapper] = base_task = UOWTask(self, base_mapper)
-                base_mapper.register_dependencies(self)
+                base_mapper._register_dependencies(self)
 
             if mapper not in self.tasks:
                 self.tasks[mapper] = task = UOWTask(self, mapper, base_task=base_task)
-                mapper.register_dependencies(self)
+                mapper._register_dependencies(self)
             else:
                 task = self.tasks[mapper]
                 
@@ -581,7 +571,7 @@ class UOWTask(object):
         # postupdates are UPDATED immeditely (for now)
         # convert post_update_cols list to a Set so that __hashcode__ is used to compare columns
         # instead of __eq__
-        self.mapper.save_obj([state], self.uowtransaction, postupdate=True, post_update_cols=util.Set(post_update_cols))
+        self.mapper._save_obj([state], self.uowtransaction, postupdate=True, post_update_cols=util.Set(post_update_cols))
 
     def delete(self, obj):
         """remove the given object from this UOWTask, if present."""
@@ -940,10 +930,10 @@ class UOWExecutor(object):
                 self.execute_delete_steps(trans, task)
 
     def save_objects(self, trans, task):
-        task.mapper.save_obj(task.polymorphic_tosave_objects, trans)
+        task.mapper._save_obj(task.polymorphic_tosave_objects, trans)
 
     def delete_objects(self, trans, task):
-        task.mapper.delete_obj(task.polymorphic_todelete_objects, trans)
+        task.mapper._delete_obj(task.polymorphic_todelete_objects, trans)
 
     def execute_dependency(self, trans, dep, isdelete):
         dep.execute(trans, isdelete)
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 6e31b46468244306b2d06c9d734eb3897bb6c18f..7b76183be01647f9a4a7f9338dbd466c434c3dd4 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -154,7 +154,7 @@ class AliasedClauses(object):
         """return the aliased version of the given column, creating a new label for it if not already
         present in this AliasedClauses."""
 
-        conv = self.alias.corresponding_column(column, raiseerr=False)
+        conv = self.alias.corresponding_column(column)
         if conv:
             return conv
 
@@ -199,13 +199,13 @@ def create_row_adapter(from_, to, equivalent_columns=None):
     
     map = {}
     for c in to.c:
-        corr = from_.corresponding_column(c, raiseerr=False)
+        corr = from_.corresponding_column(c)
         if corr:
             map[c] = corr
         elif equivalent_columns:
             if c in equivalent_columns:
                 for c2 in equivalent_columns[c]:
-                    corr = from_.corresponding_column(c2, raiseerr=False)
+                    corr = from_.corresponding_column(c2)
                     if corr:
                         map[c] = corr
                         break
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index 15b35b96a307f040b6aec6c7706bdcbb7a46af4d..e0d45870bbc0f5a65dd78dbb8861d86a29932389 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -20,8 +20,6 @@ objects as well as the visitor interface, so that the schema package
 import re, inspect
 from sqlalchemy import types, exceptions, util, databases
 from sqlalchemy.sql import expression, visitors
-import sqlalchemy
-
 
 URL = None
 
@@ -43,9 +41,6 @@ class SchemaItem(object):
             if item is not None:
                 item._set_parent(self)
 
-    def _get_parent(self):
-        raise NotImplementedError()
-
     def _set_parent(self, parent):
         """Associate with this SchemaItem's parent object."""
 
@@ -58,20 +53,12 @@ class SchemaItem(object):
     def __repr__(self):
         return "%s()" % self.__class__.__name__
 
-    def _get_bind(self, raiseerr=False):
-        """Return the engine or None if no engine."""
+    def bind(self):
+        """Return the connectable associated with this SchemaItem."""
 
-        if raiseerr:
-            m = self.metadata
-            e = m and m.bind or None
-            if e is None:
-                raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine or Connection.")
-            else:
-                return e
-        else:
-            m = self.metadata
-            return m and m.bind or None
-    bind = property(lambda s:s._get_bind())
+        m = self.metadata
+        return m and m.bind or None
+    bind = property(bind)
     
     def info(self):
         try:
@@ -231,7 +218,7 @@ class Table(SchemaItem, expression.TableClause):
             if autoload_with:
                 autoload_with.reflecttable(self, include_columns=include_columns)
             else:
-                metadata._get_bind(raiseerr=True).reflecttable(self, include_columns=include_columns)
+                _bind_or_error(metadata).reflecttable(self, include_columns=include_columns)
                 
         # initialize all the column, etc. objects.  done after
         # reflection to allow user-overrides
@@ -269,9 +256,6 @@ class Table(SchemaItem, expression.TableClause):
 
         constraint._set_parent(self)
 
-    def _get_parent(self):
-        return self.metadata
-
     def _set_parent(self, metadata):
         metadata.tables[_get_table_key(self.name, self.schema)] = self
         self.metadata = metadata
@@ -289,7 +273,7 @@ class Table(SchemaItem, expression.TableClause):
         """Return True if this table exists."""
 
         if bind is None:
-            bind = self._get_bind(raiseerr=True)
+            bind = _bind_or_error(self)
 
         def do(conn):
             return conn.dialect.has_table(conn, self.name, schema=self.schema)
@@ -463,9 +447,10 @@ class Column(SchemaItem, expression._ColumnClause):
         else:
             return self.description
 
-    def _get_bind(self):
+    def bind(self):
         return self.table.bind
-
+    bind = property(bind)
+    
     def references(self, column):
         """return true if this column references the given column via foreign key"""
         for fk in self.foreign_keys:
@@ -496,9 +481,6 @@ class Column(SchemaItem, expression._ColumnClause):
             [(self.table and "table=<%s>" % self.table.description or "")] +
             ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg])
 
-    def _get_parent(self):
-        return self.table
-
     def _set_parent(self, table):
         self.metadata = table.metadata
         if getattr(self, 'table', None) is not None:
@@ -622,15 +604,15 @@ class ForeignKey(SchemaItem):
     def references(self, table):
         """Return True if the given table is referenced by this ``ForeignKey``."""
 
-        return table.corresponding_column(self.column, False) is not None
+        return table.corresponding_column(self.column) is not None
     
     def get_referent(self, table):
         """return the column in the given table referenced by this ``ForeignKey``, or
         None if this ``ForeignKey`` does not reference the given table.
         """
-        return table.corresponding_column(self.column, False)
+        return table.corresponding_column(self.column)
         
-    def _init_column(self):
+    def column(self):
         # ForeignKey inits its remote column as late as possible, so tables can
         # be defined without dependencies
         if self._column is None:
@@ -674,10 +656,7 @@ class ForeignKey(SchemaItem):
             self.parent.type = self._column.type
         return self._column
 
-    column = property(lambda s: s._init_column())
-
-    def _get_parent(self):
-        return self.parent
+    column = property(column)
 
     def _set_parent(self, column):
         self.parent = column
@@ -704,9 +683,6 @@ class DefaultGenerator(SchemaItem):
         self.for_update = for_update
         self.metadata = util.assert_arg_type(metadata, (MetaData, type(None)), 'metadata')
 
-    def _get_parent(self):
-        return getattr(self, 'column', None)
-
     def _set_parent(self, column):
         self.column = column
         self.metadata = self.column.table.metadata
@@ -717,7 +693,7 @@ class DefaultGenerator(SchemaItem):
 
     def execute(self, bind=None, **kwargs):
         if bind is None:
-            bind = self._get_bind(raiseerr=True)
+            bind = _bind_or_error(self)
         return bind._execute_default(self, **kwargs)
 
     def __repr__(self):
@@ -798,14 +774,14 @@ class Sequence(DefaultGenerator):
         """Creates this sequence in the database."""
         
         if bind is None:
-            bind = self._get_bind(raiseerr=True)
+            bind = _bind_or_error(self)
         bind.create(self, checkfirst=checkfirst)
 
     def drop(self, bind=None, checkfirst=True):
         """Drops this sequence from the database."""
 
         if bind is None:
-            bind = self._get_bind(raiseerr=True)
+            bind = _bind_or_error(self)
         bind.drop(self, checkfirst=checkfirst)
 
 
@@ -838,20 +814,17 @@ class Constraint(SchemaItem):
     def copy(self):
         raise NotImplementedError()
 
-    def _get_parent(self):
-        return getattr(self, 'table', None)
-
 class CheckConstraint(Constraint):
     def __init__(self, sqltext, name=None):
         super(CheckConstraint, self).__init__(name)
         self.sqltext = sqltext
 
-    def _visit_name(self):
+    def __visit_name__(self):
         if isinstance(self.parent, Table):
             return "check_constraint"
         else:
             return "column_check_constraint"
-    __visit_name__ = property(_visit_name)
+    __visit_name__ = property(__visit_name__)
 
     def _set_parent(self, parent):
         self.parent = parent
@@ -976,9 +949,6 @@ class Index(SchemaItem):
         for column in args:
             self.append_column(column)
 
-    def _get_parent(self):
-        return self.table
-
     def _set_parent(self, table):
         self.table = table
         self.metadata = table.metadata
@@ -1002,17 +972,15 @@ class Index(SchemaItem):
         self.columns.append(column)
 
     def create(self, bind=None):
-        if bind is not None:
-            bind.create(self)
-        else:
-            self._get_bind(raiseerr=True).create(self)
+        if bind is None:
+            bind = _bind_or_error(self)
+        bind.create(self)
         return self
 
     def drop(self, bind=None):
-        if bind is not None:
-            bind.drop(self)
-        else:
-            self._get_bind(raiseerr=True).drop(self)
+        if bind is None:
+            bind = _bind_or_error(self)
+        bind.drop(self)
 
     def __str__(self):
         return repr(self)
@@ -1113,6 +1081,17 @@ class MetaData(SchemaItem):
             self._bind = bind
     connect = util.deprecated(connect)
 
+    def bind(self):
+        """An Engine or Connection to which this MetaData is bound.
+
+        This property may be assigned an ``Engine`` or
+        ``Connection``, or assigned a string or URL to
+        automatically create a basic ``Engine`` for this bind
+        with ``create_engine()``.
+        """
+        
+        return self._bind
+        
     def _bind_to(self, bind):
         """Bind this MetaData to an Engine, Connection, string or URL."""
 
@@ -1121,17 +1100,11 @@ class MetaData(SchemaItem):
             from sqlalchemy.engine.url import URL
 
         if isinstance(bind, (basestring, URL)):
-            self._bind = sqlalchemy.create_engine(bind)
+            from sqlalchemy import create_engine
+            self._bind = create_engine(bind)
         else:
             self._bind = bind
-
-    bind = property(lambda self: self._bind, _bind_to, doc=
-                    """An Engine or Connection to which this MetaData is bound.
-
-                    This property may be assigned an ``Engine`` or
-                    ``Connection``, or assigned a string or URL to
-                    automatically create a basic ``Engine`` for this bind
-                    with ``create_engine()``.""")
+    bind = property(bind, _bind_to)
     
     def clear(self):
         self.tables.clear()
@@ -1141,15 +1114,12 @@ class MetaData(SchemaItem):
         del self.tables[table.key]
         
     def table_iterator(self, reverse=True, tables=None):
-        from sqlalchemy.sql import util as sql_util
+        from sqlalchemy.sql.util import sort_tables
         if tables is None:
             tables = self.tables.values()
         else:
             tables = util.Set(tables).intersection(self.tables.values())
-        return iter(sql_util.sort_tables(tables, reverse=reverse))
-
-    def _get_parent(self):
-        return None
+        return iter(sort_tables(tables, reverse=reverse))
 
     def reflect(self, bind=None, schema=None, only=None):
         """Load all available table definitions from the database.
@@ -1184,7 +1154,7 @@ class MetaData(SchemaItem):
 
         reflect_opts = {'autoload': True}
         if bind is None:
-            bind = self._get_bind(raiseerr=True)
+            bind = _bind_or_error(self)
             conn = None
         else:
             reflect_opts['autoload_with'] = bind
@@ -1230,7 +1200,7 @@ class MetaData(SchemaItem):
         """
 
         if bind is None:
-            bind = self._get_bind(raiseerr=True)
+            bind = _bind_or_error(self)
         bind.create(self, checkfirst=checkfirst, tables=tables)
 
     def drop_all(self, bind=None, tables=None, checkfirst=True):
@@ -1249,17 +1219,9 @@ class MetaData(SchemaItem):
         """
 
         if bind is None:
-            bind = self._get_bind(raiseerr=True)
+            bind = _bind_or_error(self)
         bind.drop(self, checkfirst=checkfirst, tables=tables)
-
-    def _get_bind(self, raiseerr=False):
-        if not self.is_bound():
-            if raiseerr:
-                raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine or Connection.")
-            else:
-                return None
-        return self._bind
-
+    
 class ThreadLocalMetaData(MetaData):
     """A MetaData variant that presents a different ``bind`` in every thread.
 
@@ -1279,14 +1241,10 @@ class ThreadLocalMetaData(MetaData):
     __visit_name__ = 'metadata'
 
     def __init__(self):
-        """Construct a ThreadLocalMetaData.
-
-        Takes no arguments.
-        """
-        
+        """Construct a ThreadLocalMetaData."""
+    
         self.context = util.ThreadLocal()
         self.__engines = {}
-        
         super(ThreadLocalMetaData, self).__init__()
 
     # @deprecated
@@ -1315,18 +1273,14 @@ class ThreadLocalMetaData(MetaData):
         self._bind_to(bind)
     connect = util.deprecated(connect)
 
-    def _get_bind(self, raiseerr=False):
-        """The bound ``Engine`` or ``Connectable`` for this thread."""
+    def bind(self):
+        """The bound Engine or Connection for this thread.
+
+        This property may be assigned an Engine or Connection,
+        or assigned a string or URL to automatically create a
+        basic Engine for this bind with ``create_engine()``."""
         
-        if hasattr(self.context, '_engine'):
-            return self.context._engine
-        else:
-            if raiseerr:
-                raise exceptions.InvalidRequestError(
-                    "This ThreadLocalMetaData is not bound to any Engine or "
-                    "Connection.")
-            else: 
-                return None
+        return getattr(self.context, '_engine', None)
 
     def _bind_to(self, bind):
         """Bind to a Connectable in the caller's thread."""
@@ -1349,12 +1303,7 @@ class ThreadLocalMetaData(MetaData):
                 self.__engines[bind] = bind
             self.context._engine = bind
 
-    bind = property(_get_bind, _bind_to, doc=
-                    """The bound Engine or Connection for this thread.
-
-                    This property may be assigned an Engine or Connection,
-                    or assigned a string or URL to automatically create a
-                    basic Engine for this bind with ``create_engine()``.""")
+    bind = property(bind, _bind_to)
 
     def is_bound(self):
         """True if there is a bind for this thread."""
@@ -1368,8 +1317,13 @@ class ThreadLocalMetaData(MetaData):
             if hasattr(e, 'dispose'):
                 e.dispose()
 
-
 class SchemaVisitor(visitors.ClauseVisitor):
     """Define the visiting for ``SchemaItem`` objects."""
 
     __traverse_options__ = {'schema_visitor':True}
+
+def _bind_or_error(schemaitem):
+    bind = schemaitem.bind
+    if not bind:
+        raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine or Connection.")
+    return bind
\ No newline at end of file
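
Illustrative sketch (not part of the patch) of the behavior the new module-level _bind_or_error() helper preserves: schema-level operations still raise InvalidRequestError when no bind can be located, and succeed once a bind is set (the sqlite URL is only an example):

    from sqlalchemy import MetaData, Table, Column, Integer, create_engine
    from sqlalchemy import exceptions

    meta = MetaData()
    users = Table('users', meta, Column('id', Integer, primary_key=True))

    try:
        meta.create_all()       # no bind argument and meta.bind is None -> raises
    except exceptions.InvalidRequestError:
        pass

    meta.bind = create_engine('sqlite://')
    meta.create_all()           # bind located via _bind_or_error(self)
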
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 55001dc700857bfbdc3f20047cf244a2c6c0e88e..a448fa6d3d99d3f275bb3ba4dd74675f42fe2beb 100644 (file)
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -820,6 +820,12 @@ def _literal_as_binds(element, name=None, type_=None):
     else:
         return element
 
+def _corresponding_column_or_error(fromclause, column, require_embedded=False):
+    c = fromclause.corresponding_column(column, require_embedded=require_embedded)
+    if not c:
+        raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description))
+    return c
+
 def _selectable(element):
     if hasattr(element, '__selectable__'):
         return element.__selectable__()
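
A rough sketch (assumed usage of this internal helper, not from the patch) of the contract _corresponding_column_or_error() adds on top of corresponding_column(): the public method now simply returns None on no match, while the helper raises for callers that need a hard failure:

    from sqlalchemy import MetaData, Table, Column, Integer, select
    from sqlalchemy import exceptions
    from sqlalchemy.sql import expression

    meta = MetaData()
    t1 = Table('t1', meta, Column('a', Integer))
    t2 = Table('t2', meta, Column('b', Integer))

    s = select([t1])
    assert s.corresponding_column(t2.c.b) is None       # no correspondence -> None
    try:
        expression._corresponding_column_or_error(s, t2.c.b)
    except exceptions.InvalidRequestError:              # helper raises instead
        pass
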
@@ -958,13 +964,8 @@ class ClauseElement(object):
 
         return False
 
-    def _find_engine(self):
-        """Default strategy for locating an engine within the clause element.
-
-        Relies upon a local engine property, or looks in the *from*
-        objects which ultimately have to contain Tables or
-        TableClauses.
-        """
+    def bind(self):
+        """Returns the Engine or Connection to which this ClauseElement is bound, or None if none found."""
 
         try:
             if self._bind is not None:
@@ -979,8 +980,7 @@ class ClauseElement(object):
                 return engine
         else:
             return None
-
-    bind = property(lambda s:s._find_engine(), doc="""Returns the Engine or Connection to which this ClauseElement is bound, or None if none found.""")
+    bind = property(bind)
 
     def execute(self, *multiparams, **params):
         """Compile and execute this ``ClauseElement``."""
@@ -1406,7 +1406,6 @@ class ColumnElement(ClauseElement, _CompareMixin):
             return self._base_columns
         self._base_columns = util.Set([c for c in self.proxy_set if not hasattr(c, 'proxies')])
         return self._base_columns
-
     base_columns = property(base_columns)
     
     def proxy_set(self):
@@ -1603,7 +1602,7 @@ class FromClause(Selectable):
       from sqlalchemy.sql import util
       return util.ClauseAdapter(alias).traverse(self, clone=True)
 
-    def corresponding_column(self, column, raiseerr=True, require_embedded=False):
+    def corresponding_column(self, column, require_embedded=False):
         """Given a ``ColumnElement``, return the exported ``ColumnElement``
         object from this ``Selectable`` which corresponds to that
         original ``Column`` via a common ancestor column.
@@ -1611,10 +1610,6 @@ class FromClause(Selectable):
         column
           the target ``ColumnElement`` to be matched
 
-        raiseerr
-          if True, raise an error if the given ``ColumnElement`` could
-          not be matched. if False, non-matches will return None.
-
         require_embedded
           only return corresponding columns for the given
           ``ColumnElement``, if the given ``ColumnElement`` is
@@ -1624,12 +1619,6 @@ class FromClause(Selectable):
           of this ``FromClause``.
         """
 
-        if require_embedded and column not in self._get_all_embedded_columns():
-            if not raiseerr:
-                return None
-            else:
-                raise exceptions.InvalidRequestError("Column instance '%s' is not directly present within selectable '%s'" % (str(column), column.table.description))
-        
         # don't dig around if the column is locally present
         if self.c.contains_column(column):
             return column
@@ -1638,16 +1627,12 @@ class FromClause(Selectable):
         target_set = column.proxy_set
         for c in self.c + [self.oid_column]:
             i = c.proxy_set.intersection(target_set)
-            if i and (intersect is None or len(i) > len(intersect)):
+            if i and \
+                (not require_embedded or c.proxy_set.issuperset(target_set)) and \
+                (intersect is None or len(i) > len(intersect)):
                 col, intersect = c, i
-        if col:
-            return col
-
-        if not raiseerr:
-            return None
-        else:
-            raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), self.description))
-
+        return col
+        
     def description(self):
         """a brief description of this FromClause.
         
@@ -1666,17 +1651,6 @@ class FromClause(Selectable):
             if hasattr(self, attr):
                 delattr(self, attr)
 
-    def _get_all_embedded_columns(self):
-        if hasattr(self, '_embedded_columns'):
-            return self._embedded_columns
-        ret = util.Set()
-        class FindCols(visitors.ClauseVisitor):
-            def visit_column(self, col):
-                ret.add(col)
-        FindCols().traverse(self)
-        self._embedded_columns = ret
-        return ret
-
     def _expr_attr_func(name):
         def attr(self):
             try:
@@ -1684,12 +1658,11 @@ class FromClause(Selectable):
             except AttributeError:
                 self._export_columns()
                 return getattr(self, name)
-        return attr
+        return property(attr)
 
-    columns = property(_expr_attr_func('_columns'))
-    c = property(_expr_attr_func('_columns'))
-    primary_key = property(_expr_attr_func('_primary_key'))
-    foreign_keys = property(_expr_attr_func('_foreign_keys'))
+    columns = c = _expr_attr_func('_columns')
+    primary_key = _expr_attr_func('_primary_key')
+    foreign_keys = _expr_attr_func('_foreign_keys')
 
     def _export_columns(self, columns=None):
         """Initialize column collections."""
@@ -1881,14 +1854,14 @@ class _TextClause(ClauseElement):
             for b in bindparams:
                 self.bindparams[b.key] = b
 
-    def _get_type(self):
+    def type(self):
         if self.typemap is not None and len(self.typemap) == 1:
             return list(self.typemap)[0]
         else:
             return None
-    type = property(_get_type)
+    type = property(type)
 
-    columns = property(lambda s:[])
+    columns = []
 
     def _copy_internals(self, clone=_clone):
         self.bindparams = dict([(b.key, clone(b)) for b in self.bindparams.values()])
@@ -2329,7 +2302,12 @@ class Join(FromClause):
         else:
             return and_(*crit)
 
-    def _get_folded_equivalents(self, equivs=None):
+    def _folded_equivalents(self, equivs=None):
+        """Returns the column list of this Join with all equivalently-named, 
+        equated columns folded into one column, where 'equated' means they are 
+        equated to each other in the ON clause of this join.
+        """
+        
         if self.__folded_equivalents is not None:
             return self.__folded_equivalents
         if equivs is None:
@@ -2342,11 +2320,11 @@ class Join(FromClause):
         LocateEquivs().traverse(self.onclause)
         collist = []
         if isinstance(self.left, Join):
-            left = self.left._get_folded_equivalents(equivs)
+            left = self.left._folded_equivalents(equivs)
         else:
             left = list(self.left.columns)
         if isinstance(self.right, Join):
-            right = self.right._get_folded_equivalents(equivs)
+            right = self.right._folded_equivalents(equivs)
         else:
             right = list(self.right.columns)
         used = util.Set()
@@ -2359,10 +2337,7 @@ class Join(FromClause):
                 collist.append(c)
         self.__folded_equivalents = collist
         return self.__folded_equivalents
-
-    folded_equivalents = property(_get_folded_equivalents, doc="Returns the column list of this Join with all equivalently-named, "
-                                                            "equated columns folded into one column, where 'equated' means they are "
-                                                            "equated to each other in the ON clause of this join.")
+    folded_equivalents = property(_folded_equivalents)
 
     def select(self, whereclause = None, fold_equivalents=False, **kwargs):
         """Create a ``Select`` from this ``Join``.
@@ -2391,7 +2366,9 @@ class Join(FromClause):
 
         return select(collist, whereclause, from_obj=[self], **kwargs)
 
-    bind = property(lambda s:s.left.bind or s.right.bind)
+    def bind(self):
+        return self.left.bind or self.right.bind
+    bind = property(bind)
 
     def alias(self, name=None):
         """Create a ``Select`` out of this ``Join`` clause and return an ``Alias`` of it.
@@ -2474,8 +2451,10 @@ class Alias(FromClause):
 
     def _get_from_objects(self, **modifiers):
         return [self]
-
-    bind = property(lambda s: s.selectable.bind)
+    
+    def bind(self):
+        return self.selectable.bind
+    bind = property(bind)
 
 class _ColumnElementAdapter(ColumnElement):
     """Adapts a ClauseElement which may or may not be a
@@ -2486,9 +2465,14 @@ class _ColumnElementAdapter(ColumnElement):
     def __init__(self, elem):
         self.elem = elem
         self.type = getattr(elem, 'type', None)
-
-    key = property(lambda s: s.elem.key)
-    _label = property(lambda s: s.elem._label)
+    
+    def key(self):
+        return self.elem.key
+    key = property(key)
+    
+    def _label(self):
+        return self.elem._label
+    _label = property(_label)
 
     def _copy_internals(self, clone=_clone):
         self.elem = clone(self.elem)
@@ -2520,8 +2504,13 @@ class _FromGrouping(FromClause):
     def __init__(self, elem):
         self.elem = elem
 
-    columns = c = property(lambda s:s.elem.columns)
-    _hide_froms = property(lambda s:s.elem._hide_froms)
+    def columns(self):
+        return self.elem.columns
+    columns = c = property(columns)
+    
+    def _hide_froms(self):
+        return self.elem._hide_froms
+    _hide_froms = property(_hide_froms)
     
     def get_children(self, **kwargs):
         return self.elem,
@@ -2553,23 +2542,34 @@ class _Label(ColumnElement):
         self.obj = obj.self_group(against=operators.as_)
         self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None))
 
-    key = property(lambda s: s.name)
-    _label = property(lambda s: s.name)
-    proxies = property(lambda s:s.obj.proxies)
-    base_columns = property(lambda s:s.obj.base_columns)
-    proxy_set = property(lambda s:s.obj.proxy_set)
-    primary_key = property(lambda s:s.obj.primary_key)
-    foreign_keys = property(lambda s:s.obj.foreign_keys)
+    def key(self):
+        return self.name
+    key = property(key)
+    
+    def _label(self):
+        return self.name
+    _label = property(_label)
+    
+    def _proxy_attr(name):
+        def attr(self):
+            return getattr(self.obj, name)
+        return property(attr)
+            
+    proxies = _proxy_attr('proxies')
+    base_columns = _proxy_attr('base_columns')
+    proxy_set = _proxy_attr('proxy_set')
+    primary_key = _proxy_attr('primary_key')
+    foreign_keys = _proxy_attr('foreign_keys')
     
     def expression_element(self):
         return self.obj
 
-    def _copy_internals(self, clone=_clone):
-        self.obj = clone(self.obj)
-
     def get_children(self, **kwargs):
         return self.obj,
 
+    def _copy_internals(self, clone=_clone):
+        self.obj = clone(self.obj)
+
     def _get_from_objects(self, **modifiers):
         return self.obj._get_from_objects(**modifiers)
 
@@ -2623,13 +2623,8 @@ class _ColumnClause(ColumnElement):
         # ColumnClause is immutable
         return self
 
-    def _get_label(self):
-        """Generate a 'label' for this column.
-
-        The label is a product of the parent table name and column
-        name, and is treated as a unique identifier of this ``Column``
-        across all ``Tables`` and derived selectables for a particular
-        metadata collection.
+    def _label(self):
+        """Generate a 'label' string for this column.
         """
 
         # for a "literal" column, we've no idea what the text is
@@ -2647,7 +2642,7 @@ class _ColumnClause(ColumnElement):
                 self.__label = self.name
         return self.__label
 
-    _label = property(_get_label)
+    _label = property(_label)
 
     def label(self, name):
         # if going off the "__label" property and its None, we have
@@ -2903,10 +2898,9 @@ class _ScalarSelect(_Grouping):
             raise exceptions.InvalidRequestError("Scalar select can only be created from a Select object that has exactly one column expression.")
         self.type = cols[0].type
 
-    def _no_cols(self):
+    def columns(self):
         raise exceptions.InvalidRequestError("Scalar Select expression has no columns; use this object directly within a column-level expression.")
-    c = property(_no_cols)
-    columns = c
+    columns = c = property(columns)
 
     def self_group(self, **kwargs):
         return self
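
A brief illustrative sketch (not part of the commit) of the usage the reworked columns property above points to: a scalar select is embedded directly in a column-level expression rather than having its columns collection accessed; the tables here are made up:

    from sqlalchemy import MetaData, Table, Column, Integer, select, func

    meta = MetaData()
    customers = Table('customers', meta, Column('id', Integer, primary_key=True))
    orders = Table('orders', meta,
                   Column('id', Integer, primary_key=True),
                   Column('customer_id', Integer))

    order_count = select([func.count(orders.c.id)],
                         orders.c.customer_id == customers.c.id).as_scalar()

    # used directly as a column expression; accessing order_count.columns
    # would raise InvalidRequestError per the change above
    s = select([customers.c.id, order_count])
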
@@ -2979,14 +2973,15 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
             for t in s._table_iterator():
                 yield t
 
-    def _find_engine(self):
+    def bind(self):
         for s in self.selects:
-            e = s._find_engine()
+            e = s.bind
             if e:
                 return e
         else:
             return None
-
+    bind = property(bind)
+    
 class Select(_SelectBaseMixin, FromClause):
     """Represents a ``SELECT`` statement.
 
@@ -3115,15 +3110,18 @@ class Select(_SelectBaseMixin, FromClause):
         self._all_froms = froms
         return froms
 
-    def _get_inner_columns(self):
+    def inner_columns(self):
+        """a collection of all ColumnElement expressions which would 
+        be rendered into the columns clause of the resulting SELECT statement.
+        """        
+        
         for c in self._raw_columns:
             if isinstance(c, Selectable):
                 for co in c.columns:
                     yield co
             else:
                 yield c
-
-    inner_columns = property(_get_inner_columns, doc="""a collection of all ColumnElement expressions which would be rendered into the columns clause of the resulting SELECT statement.""")
+    inner_columns = property(inner_columns)
 
     def is_derived_from(self, fromclause):
         if self in util.Set(fromclause._cloned_set):
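
A quick sketch (illustrative) of what the inner_columns docstring above describes: it yields the raw expressions that will render in the columns clause (a Table's columns expanded, plus literal expressions as-is), rather than the proxied columns exported on .c:

    from sqlalchemy import MetaData, Table, Column, Integer, select, literal_column

    meta = MetaData()
    t = Table('t', meta, Column('x', Integer), Column('y', Integer))

    s = select([t, literal_column('1')])
    cols = list(s.inner_columns)    # t.x, t.y and the literal '1', exactly as
                                    # they will appear in the columns clause
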
@@ -3412,11 +3410,7 @@ class Select(_SelectBaseMixin, FromClause):
             if isinstance(t, TableClause):
                 yield t
 
-    def _find_engine(self):
-        """Try to return a Engine, either explicitly set in this
-        object, or searched within the from clauses for one.
-        """
-
+    def bind(self):
         if self._bind is not None:
             return self._bind
         for f in self._froms:
@@ -3436,7 +3430,8 @@ class Select(_SelectBaseMixin, FromClause):
                 self._bind = e
                 return e
         return None
-
+    bind = property(bind)
+    
 class _UpdateBase(ClauseElement):
     """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements."""
 
@@ -3459,9 +3454,10 @@ class _UpdateBase(ClauseElement):
         else:
             return parameters
             
-    def _find_engine(self):
+    def bind(self):
         return self.table.bind
-
+    bind = property(bind)
+    
 class Insert(_UpdateBase):
     def __init__(self, table, values=None, inline=False, **kwargs):
         self.table = table
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index d6b10a78a335e835eafcca7d548d4bb3b5e5255d..b45c0425c8dc1c793cfdecaafddd3a286172537d 100644 (file)
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -198,10 +198,10 @@ class ClauseAdapter(AbstractClauseProcessor):
         if self.exclude is not None:
             if col in self.exclude:
                 return None
-        newcol = self.selectable.corresponding_column(col, raiseerr=False, require_embedded=True)
+        newcol = self.selectable.corresponding_column(col, require_embedded=True)
         if newcol is None and self.equivalents is not None and col in self.equivalents:
             for equiv in self.equivalents[col]:
-                newcol = self.selectable.corresponding_column(equiv, raiseerr=False, require_embedded=True)
+                newcol = self.selectable.corresponding_column(equiv, require_embedded=True)
                 if newcol:
                     return newcol
         return newcol
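
For reference, a small sketch (illustrative only, mirroring the test assertions added below) of the require_embedded semantics ClauseAdapter now relies on: a correspondence is returned only when the given column is actually embedded within the target selectable:

    from sqlalchemy import MetaData, Table, Column, Integer, select

    meta = MetaData()
    t = Table('t', meta, Column('x', Integer))

    s = select([t])
    a = s.alias('a')

    # the alias embeds the select (and thus t.x), so a match is found
    assert a.corresponding_column(t.c.x, require_embedded=True) is a.c.x
    # the bare table does not embed the alias's derived column
    assert t.corresponding_column(a.c.x, require_embedded=True) is None
    # without the flag, common-ancestor correspondence still applies
    assert t.corresponding_column(a.c.x) is t.c.x
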
diff --git a/test/sql/selectable.py b/test/sql/selectable.py
index 4796288dfa5e26185853b2911da1ce434d6637bb..1b9959ec431966a0b8534e0de3938b7c6831f6eb 100755 (executable)
--- a/test/sql/selectable.py
+++ b/test/sql/selectable.py
@@ -24,6 +24,7 @@ table2 = Table('table2', metadata,
 
 class SelectableTest(AssertMixin):
     def testdistance(self):
+        # same column three times
         s = select([table.c.col1.label('c2'), table.c.col1, table.c.col1.label('c1')])
 
         # didn't do this yet...col.label().make_proxy() has same "distance" as col.make_proxy() so far
@@ -50,6 +51,9 @@ class SelectableTest(AssertMixin):
     def testselectontable(self):
         sel = select([table, table2], use_labels=True)
         assert sel.corresponding_column(table.c.col1) is sel.c.table1_col1
+        assert sel.corresponding_column(table.c.col1, require_embedded=True) is sel.c.table1_col1
+        assert table.corresponding_column(sel.c.table1_col1) is table.c.col1
+        assert table.corresponding_column(sel.c.table1_col1, require_embedded=True) is None
         
     def testjoinagainstjoin(self):
         j  = outerjoin(table, table2, table.c.col1==table2.c.col2)