]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- move inline "import" statements to use new "util.importlater()" construct. cuts
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 13 Nov 2010 18:19:36 +0000 (13:19 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 13 Nov 2010 18:19:36 +0000 (13:19 -0500)
down on clutter, timeit says there's a teeny performance gain, at least where
the access is compared against attr.subattr.  these aren't super-critical
calls anyway
- slight inlining in _class_to_mapper

12 files changed:
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/unitofwork.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/test/util.py
lib/sqlalchemy/types.py
lib/sqlalchemy/util.py
test/aaa_profiling/test_compiler.py

index ddaf62c7f17d2ebcc4de75d1cbcf053e6df19f32..aefccb63aefefc50649542e64a38dfd45774860f 100644 (file)
@@ -23,10 +23,7 @@ from sqlalchemy import util
 from sqlalchemy.orm import interfaces, collections, exc
 import sqlalchemy.exceptions as sa_exc
 
-# lazy imports
-_entity_info = None
-identity_equal = None
-state = None
+mapperutil = util.importlater("sqlalchemy.orm", "util")
 
 PASSIVE_NO_RESULT = util.symbol('PASSIVE_NO_RESULT')
 ATTR_WAS_SET = util.symbol('ATTR_WAS_SET')
@@ -571,7 +568,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
                                             compare_function=compare_function, 
                                             **kwargs)
         if compare_function is None:
-            self.is_equal = identity_equal
+            self.is_equal = mapperutil.identity_equal
 
     def delete(self, state, dict_):
         old = self.get(state, dict_)
index a14ad647a8242c4cfcd87063b831aeb4d86077c1..e9da4f5337fc3c82461c26f0237852b82ef32bc9 100644 (file)
@@ -31,6 +31,8 @@ from sqlalchemy.orm.util import (
      _state_mapper, class_mapper, instance_str, state_str,
      )
 import sys
+sessionlib = util.importlater("sqlalchemy.orm", "session")
+properties = util.importlater("sqlalchemy.orm", "properties")
 
 __all__ = (
     'Mapper',
@@ -57,13 +59,6 @@ NO_ATTRIBUTE = util.symbol('NO_ATTRIBUTE')
 # lock used to synchronize the "mapper compile" step
 _COMPILE_MUTEX = util.threading.RLock()
 
-# initialize these lazily
-ColumnProperty = None
-RelationshipProperty = None
-ConcreteInheritedProperty = None
-_expire_state = None
-_state_session = None
-
 class Mapper(object):
     """Define the correlation of class attributes to database table
     columns.
@@ -596,7 +591,7 @@ class Mapper(object):
                     
             self._configure_property(
                             col.key, 
-                            ColumnProperty(col, _instrument=instrument),
+                            properties.ColumnProperty(col, _instrument=instrument),
                             init=False, setparent=True)
 
     def _adapt_inherited_property(self, key, prop, init):
@@ -605,7 +600,7 @@ class Mapper(object):
         elif key not in self._props:
             self._configure_property(
                             key, 
-                            ConcreteInheritedProperty(), 
+                            properties.ConcreteInheritedProperty(), 
                             init=init, setparent=True)
             
     def _configure_property(self, key, prop, init=True, setparent=True):
@@ -613,7 +608,7 @@ class Mapper(object):
 
         if not isinstance(prop, MapperProperty):
             # we were passed a Column or a list of Columns; 
-            # generate a ColumnProperty
+            # generate a properties.ColumnProperty
             columns = util.to_list(prop)
             column = columns[0]
             if not expression.is_column(column):
@@ -623,12 +618,12 @@ class Mapper(object):
 
             prop = self._props.get(key, None)
 
-            if isinstance(prop, ColumnProperty):
+            if isinstance(prop, properties.ColumnProperty):
                 # TODO: the "property already exists" case is still not 
                 # well defined here. assuming single-column, etc.
 
                 if prop.parent is not self:
-                    # existing ColumnProperty from an inheriting mapper.
+                    # existing properties.ColumnProperty from an inheriting mapper.
                     # make a copy and append our column to it
                     prop = prop.copy()
                 else:
@@ -643,9 +638,9 @@ class Mapper(object):
                 # this hypothetically changes to 
                 # prop.columns.insert(0, column) when we do [ticket:1892]
                 prop.columns.append(column)
-                self._log("appending to existing ColumnProperty %s" % (key))
+                self._log("appending to existing properties.ColumnProperty %s" % (key))
                              
-            elif prop is None or isinstance(prop, ConcreteInheritedProperty):
+            elif prop is None or isinstance(prop, properties.ConcreteInheritedProperty):
                 mapped_column = []
                 for c in columns:
                     mc = self.mapped_table.corresponding_column(c)
@@ -666,7 +661,7 @@ class Mapper(object):
                             "force this column to be mapped as a read-only "
                             "attribute." % (key, self, c))
                     mapped_column.append(mc)
-                prop = ColumnProperty(*mapped_column)
+                prop = properties.ColumnProperty(*mapped_column)
             else:
                 raise sa_exc.ArgumentError(
                     "WARNING: when configuring property '%s' on %s, "
@@ -680,7 +675,7 @@ class Mapper(object):
                     "columns get mapped." % 
                     (key, self, column.key, prop))
 
-        if isinstance(prop, ColumnProperty):
+        if isinstance(prop, properties.ColumnProperty):
             col = self.mapped_table.corresponding_column(prop.columns[0])
             
             # if the column is not present in the mapped table, 
@@ -719,7 +714,7 @@ class Mapper(object):
                                     col not in self._cols_by_table[col.table]:
                     self._cols_by_table[col.table].add(col)
             
-            # if this ColumnProperty represents the "polymorphic
+            # if this properties.ColumnProperty represents the "polymorphic
             # discriminator" column, mark it.  We'll need this when rendering
             # columns in SELECT statements.
             if not hasattr(prop, '_is_polymorphic_discriminator'):
@@ -1897,7 +1892,7 @@ class Mapper(object):
             )
             
             if readonly:
-                _expire_state(state, state.dict, readonly)
+                sessionlib._expire_state(state, state.dict, readonly)
 
             # if eager_defaults option is enabled,
             # refresh whatever has been expired.
@@ -1939,7 +1934,7 @@ class Mapper(object):
                 self._set_state_attr_by_column(state, dict_, c, params[c.key])
 
         if postfetch_cols:
-            _expire_state(state, state.dict, 
+            sessionlib._expire_state(state, state.dict, 
                                 [self._columntoproperty[c].key 
                                 for c in postfetch_cols]
                             )
@@ -2437,7 +2432,7 @@ def _load_scalar_attributes(state, attribute_names):
     """initiate a column-based attribute refresh operation."""
     
     mapper = _state_mapper(state)
-    session = _state_session(state)
+    session = sessionlib._state_session(state)
     if not session:
         raise orm_exc.DetachedInstanceError(
                     "Instance %s is not bound to a Session; "
index 0cbbf630d4fb36a4aeb42bbdd34fb91e0476186a..feee041cea4127ecc1bc8845676a2985abe95b07 100644 (file)
@@ -1483,6 +1483,3 @@ class RelationshipProperty(StrategizedProperty):
 PropertyLoader = RelationProperty = RelationshipProperty
 log.class_logger(RelationshipProperty)
 
-mapper.ColumnProperty = ColumnProperty
-mapper.RelationshipProperty = RelationshipProperty
-mapper.ConcreteInheritedProperty = ConcreteInheritedProperty
index 80c353ebca58446ebbc3d16a7a5ac4630429f7e0..b9b935c88e3262aad570de919eea71a3c04dae94 100644 (file)
@@ -1695,8 +1695,3 @@ def _state_session(state):
             pass
     return None
 
-# Lazy initialization to avoid circular imports
-unitofwork._state_session = _state_session
-from sqlalchemy.orm import mapper
-mapper._expire_state = _expire_state
-mapper._state_session = _state_session
index a9808e6ba6e4f29ccc6a54ecc7b5a7788dae0eb8..673591e8e352b082ce35c28bfe482696c05ecb10 100644 (file)
@@ -16,9 +16,7 @@ from sqlalchemy import util, topological
 from sqlalchemy.orm import attributes, interfaces
 from sqlalchemy.orm import util as mapperutil
 from sqlalchemy.orm.util import _state_mapper
-
-# Load lazily
-_state_session = None
+session = util.importlater("sqlalchemy.orm", "session")
 
 class UOWEventHandler(interfaces.AttributeExtension):
     """An event handler added to all relationship attributes which handles
@@ -34,7 +32,7 @@ class UOWEventHandler(interfaces.AttributeExtension):
         # process "save_update" cascade rules for when 
         # an instance is appended to the list of another instance
 
-        sess = _state_session(state)
+        sess = session._state_session(state)
         if sess:
             prop = _state_mapper(state).get_property(self.key)
             if prop.cascade.save_update and \
@@ -44,7 +42,7 @@ class UOWEventHandler(interfaces.AttributeExtension):
         return item
         
     def remove(self, state, item, initiator):
-        sess = _state_session(state)
+        sess = session._state_session(state)
         if sess:
             prop = _state_mapper(state).get_property(self.key)
             # expunge pending orphans
@@ -59,7 +57,7 @@ class UOWEventHandler(interfaces.AttributeExtension):
         if oldvalue is newvalue:
             return newvalue
 
-        sess = _state_session(state)
+        sess = session._state_session(state)
         if sess:
             prop = _state_mapper(state).get_property(self.key)
             if newvalue is not None and \
index f79a8449fc4a5a179bcc9a78772278193163043e..db28089efb8850caffbb7cf6b33126854dc7f6c6 100644 (file)
@@ -13,7 +13,7 @@ from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE,\
                                 AttributeExtension
 from sqlalchemy.orm import attributes, exc
 
-mapperlib = None
+mapperlib = util.importlater("sqlalchemy.orm", "mapperlib")
 
 all_cascades = frozenset(("delete", "delete-orphan", "all", "merge",
                           "expunge", "save-update", "refresh-expire",
@@ -532,10 +532,6 @@ def _entity_info(entity, compile=True):
     if isinstance(entity, AliasedClass):
         return entity._AliasedClass__mapper, entity._AliasedClass__alias, True
 
-    global mapperlib
-    if mapperlib is None:
-        from sqlalchemy.orm import mapperlib
-    
     if isinstance(entity, mapperlib.Mapper):
         mapper = entity
         
@@ -630,24 +626,28 @@ def class_mapper(class_, compile=True):
 def _class_to_mapper(class_or_mapper, compile=True):
     if _is_aliased_class(class_or_mapper):
         return class_or_mapper._AliasedClass__mapper
+
     elif isinstance(class_or_mapper, type):
-        return class_mapper(class_or_mapper, compile=compile)
-    elif hasattr(class_or_mapper, 'compile'):
-        if compile:
-            return class_or_mapper.compile()
-        else:
-            return class_or_mapper
+        try:
+            class_manager = attributes.manager_of_class(class_or_mapper)
+            mapper = class_manager.mapper
+        except exc.NO_STATE:
+            raise exc.UnmappedClassError(class_or_mapper)
+    elif isinstance(class_or_mapper, mapperlib.Mapper):
+        mapper = class_or_mapper
     else:
         raise exc.UnmappedClassError(class_or_mapper)
+        
+    if compile:
+        return mapper.compile()
+    else:
+        return mapper
 
 def has_identity(object):
     state = attributes.instance_state(object)
     return state.has_identity
 
 def _is_mapped_class(cls):
-    global mapperlib
-    if mapperlib is None:
-        from sqlalchemy.orm import mapperlib
     if isinstance(cls, (AliasedClass, mapperlib.Mapper)):
         return True
     if isinstance(cls, expression.ClauseElement):
@@ -690,8 +690,3 @@ def identity_equal(a, b):
         return False
     return state_a.key == state_b.key
 
-
-# TODO: Avoid circular import.
-attributes.identity_equal = identity_equal
-attributes._is_aliased_class = _is_aliased_class
-attributes._entity_info = _entity_info
index 069e58ceddc65eaaf87f52099668cfdb43b1dcce..8e937968d0b0c4190a8d8b13b4edbfbef15c3fec 100644 (file)
@@ -32,7 +32,9 @@ import re, inspect
 from sqlalchemy import exc, util, dialects
 from sqlalchemy.sql import expression, visitors
 
-URL = None
+sqlutil = util.importlater("sqlalchemy.sql", "util")
+url = util.importlater("sqlalchemy.engine", "url")
+
 
 __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index',
            'ForeignKeyConstraint', 'PrimaryKeyConstraint', 'CheckConstraint',
@@ -1979,11 +1981,7 @@ class MetaData(SchemaItem):
     def _bind_to(self, bind):
         """Bind this MetaData to an Engine, Connection, string or URL."""
 
-        global URL
-        if URL is None:
-            from sqlalchemy.engine.url import URL
-
-        if isinstance(bind, (basestring, URL)):
+        if isinstance(bind, (basestring, url.URL)):
             from sqlalchemy import create_engine
             self._bind = create_engine(bind)
         else:
@@ -2007,8 +2005,7 @@ class MetaData(SchemaItem):
         """Returns a list of ``Table`` objects sorted in order of
         dependency.
         """
-        from sqlalchemy.sql.util import sort_tables
-        return sort_tables(self.tables.itervalues())
+        return sqlutil.sort_tables(self.tables.itervalues())
         
     def reflect(self, bind=None, schema=None, views=False, only=None):
         """Load all available table definitions from the database.
@@ -2205,11 +2202,7 @@ class ThreadLocalMetaData(MetaData):
     def _bind_to(self, bind):
         """Bind to a Connectable in the caller's thread."""
 
-        global URL
-        if URL is None:
-            from sqlalchemy.engine.url import URL
-
-        if isinstance(bind, (basestring, URL)):
+        if isinstance(bind, (basestring, url.URL)):
             try:
                 self.context._engine = self.__engines[bind]
             except KeyError:
index 625893a68a1085785c45e8a2ad715fabd9b0d213..c3dc339a50dd9adaac92e502228e2b0c9767da94 100644 (file)
@@ -29,13 +29,15 @@ to stay the same in future releases.
 import itertools, re
 from operator import attrgetter
 
-from sqlalchemy import util, exc #, types as sqltypes
+from sqlalchemy import util, exc
 from sqlalchemy.sql import operators
 from sqlalchemy.sql.visitors import Visitable, cloned_traverse
 import operator
 
-functions, sql_util, sqltypes = None, None, None
-DefaultDialect = None
+functions = util.importlater("sqlalchemy.sql", "functions")
+sqlutil = util.importlater("sqlalchemy.sql", "util")
+sqltypes = util.importlater("sqlalchemy", "types")
+default = util.importlater("sqlalchemy.engine", "default")
 
 __all__ = [
     'Alias', 'ClauseElement', 'ColumnCollection', 'ColumnElement',
@@ -957,9 +959,6 @@ class _FunctionGenerator(object):
         o = self.opts.copy()
         o.update(kwargs)
         if len(self.__names) == 1:
-            global functions
-            if functions is None:
-                from sqlalchemy.sql import functions
             func = getattr(functions, self.__names[-1].lower(), None)
             if func is not None and \
                     isinstance(func, type) and \
@@ -1205,10 +1204,7 @@ class ClauseElement(Visitable):
         dictionary.
         
         """
-        global sql_util
-        if sql_util is None:
-            from sqlalchemy.sql import util as sql_util
-        return sql_util.Annotated(self, values)
+        return sqlutil.Annotated(self, values)
 
     def _deannotate(self):
         """return a copy of this ClauseElement with an empty annotations
@@ -1389,10 +1385,7 @@ class ClauseElement(Visitable):
                 dialect = self.bind.dialect
                 bind = self.bind
             else:
-                global DefaultDialect
-                if DefaultDialect is None:
-                    from sqlalchemy.engine.default import DefaultDialect
-                dialect = DefaultDialect()
+                dialect = default.DefaultDialect()
         compiler = self._compiler(dialect, bind=bind, **kw)
         compiler.compile()
         return compiler
@@ -2154,10 +2147,7 @@ class FromClause(Selectable):
         
         """
 
-        global sql_util
-        if sql_util is None:
-            from sqlalchemy.sql import util as sql_util
-        return sql_util.ClauseAdapter(alias).traverse(self)
+        return sqlutil.ClauseAdapter(alias).traverse(self)
 
     def correspond_on_equivalents(self, column, equivalents):
         """Return corresponding_column for the given column, or if None
@@ -3098,10 +3088,7 @@ class Join(FromClause):
         columns = [c for c in self.left.columns] + \
                         [c for c in self.right.columns]
 
-        global sql_util
-        if not sql_util:
-            from sqlalchemy.sql import util as sql_util
-        self._primary_key.extend(sql_util.reduce_columns(
+        self._primary_key.extend(sqlutil.reduce_columns(
                 (c for c in columns if c.primary_key), self.onclause))
         self._columns.update((col._label, col) for col in columns)
         self._foreign_keys.update(itertools.chain(
@@ -3118,14 +3105,11 @@ class Join(FromClause):
         return self.left, self.right, self.onclause
 
     def _match_primaries(self, left, right):
-        global sql_util
-        if not sql_util:
-            from sqlalchemy.sql import util as sql_util
         if isinstance(left, Join):
             left_right = left.right
         else:
             left_right = None
-        return sql_util.join_condition(left, right, a_subset=left_right)
+        return sqlutil.join_condition(left, right, a_subset=left_right)
 
     def select(self, whereclause=None, fold_equivalents=False, **kwargs):
         """Create a :class:`Select` from this :class:`Join`.
@@ -3145,11 +3129,8 @@ class Join(FromClause):
           underlying :func:`select()` function.
 
         """
-        global sql_util
-        if not sql_util:
-            from sqlalchemy.sql import util as sql_util
         if fold_equivalents:
-            collist = sql_util.folded_equivalents(self)
+            collist = sqlutil.folded_equivalents(self)
         else:
             collist = [self.left, self.right]
 
index ff2c3d7b79bbc373a79e7b3b6729e0a1dec59f03..f2b6b49ea29098a0104d661dcbd54b2c2d264b69 100644 (file)
@@ -22,8 +22,6 @@ else:
     def lazy_gc():
         pass
 
-
-
 def picklers():
     picklers = set()
     # Py2K
index ee1fdc67f57b15147a728228e3962b5c41b63d74..9f322d1eb584e2c30abf852bf2ebbf551fbf830e 100644 (file)
@@ -34,8 +34,8 @@ from sqlalchemy.sql.visitors import Visitable
 from sqlalchemy import util
 from sqlalchemy import processors
 import collections
+default = util.importlater("sqlalchemy.engine", "default")
 
-DefaultDialect = None
 NoneType = type(None)
 if util.jython:
     import array
@@ -143,10 +143,7 @@ class AbstractType(Visitable):
             mod = ".".join(tokens)
             return getattr(__import__(mod).dialects, tokens[-1]).dialect()
         else:
-            global DefaultDialect
-            if DefaultDialect is None:
-                from sqlalchemy.engine.default import DefaultDialect
-            return DefaultDialect()
+            return default.DefaultDialect()
         
     def __str__(self):
         # Py3K
index 922d5bee527f63251c4929bf590532b9604cb3d3..8665cd0d4ac699d8ec075ae9bcfa7427aec4c2c9 100644 (file)
@@ -1558,7 +1558,51 @@ class group_expirable_memoized_property(object):
         self.attributes.append(fn.__name__)
         return memoized_property(fn)
 
-
+class importlater(object):
+    """Deferred import object.
+    
+    e.g.::
+    
+        somesubmod = importlater("mypackage.somemodule", "somesubmod")
+        
+    is equivalent to::
+    
+        from mypackage.somemodule import somesubmod
+        
+    except evaluted upon attribute access to "somesubmod".
+    
+    """
+    def __init__(self, path, addtl=None):
+        self._il_path = path
+        self._il_addtl = addtl
+    
+    @memoized_property
+    def _il_module(self):
+        m = __import__(self._il_path)
+        for token in self._il_path.split(".")[1:]:
+            m = getattr(m, token)
+        if self._il_addtl:
+            try:
+                return getattr(m, self._il_addtl)
+            except AttributeError:
+                raise AttributeError(
+                        "Module %s has no attribute '%s'" % 
+                        (self._il_path, self._il_addtl)
+                    )
+        else:
+            return m
+        
+    def __getattr__(self, key):
+        try:
+            attr = getattr(self._il_module, key)
+        except AttributeError:
+            raise AttributeError(
+                        "Module %s has no attribute '%s'" % 
+                        (self._il_path, key)
+                    )
+        self.__dict__[key] = attr
+        return attr
+        
 class WeakIdentityMapping(weakref.WeakKeyDictionary):
     """A WeakKeyDictionary with an object identity index.
 
index a7f64410b1e7eb439e2feaa03a9783fa4616f657..bc589c0b2fdb87447d0d875140c5a2ea8f58695c 100644 (file)
@@ -5,6 +5,7 @@ from sqlalchemy.test import *
 class CompileTest(TestBase, AssertsExecutionResults):
     @classmethod
     def setup_class(cls):
+        
         global t1, t2, metadata
         metadata = MetaData()
         t1 = Table('t1', metadata,
@@ -15,6 +16,10 @@ class CompileTest(TestBase, AssertsExecutionResults):
             Column('c1', Integer, primary_key=True),
             Column('c2', String(30)))
 
+        # do a "compile" ahead of time to load
+        # deferred imports
+        t1.insert().compile()
+
         # go through all the TypeEngine
         # objects in use and pre-load their _type_affinity
         # entries.